fortrangoingonforty/armfortas / dc5f135

Browse files

Test sum-with-unary across all four reduction shapes (incl. scalar tail)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
dc5f13556908525b30b4cacefd99f45deba4257b
Parents
d39cf3d
Tree
e9f8507

1 changed file

Status | File | + | -
M tests/vectorize_reduce_sum_unary.rs 19 13
tests/vectorize_reduce_sum_unary.rs (modified)
@@ -39,22 +39,15 @@ fn o3_vectorizes_sum_with_unary_load() {
3939
         },
4040
         Stage::OptIr,
4141
     );
42
-    // Expect the unary lifted into the vector lane (vneg / vabs)
43
-    // inside the body, plus a vreduce_sum at exit.
4442
     assert!(
45
-        o3_ir.contains("vneg"),
46
-        "expected vneg in IR:\n{}",
47
-        o3_ir
48
-    );
49
-    assert!(
50
-        o3_ir.contains("vabs"),
51
-        "expected vabs in IR:\n{}",
43
+        o3_ir.contains("vneg") && o3_ir.contains("vabs"),
44
+        "expected both vneg and vabs in IR:\n{}",
5245
         o3_ir
5346
     );
5447
     assert_eq!(
5548
         o3_ir.matches("vreduce_sum").count(),
56
-        2,
57
-        "expected two vreduce_sum:\n{}",
49
+        4,
50
+        "expected four vreduce_sum:\n{}",
5851
         o3_ir
5952
     );
6053
 
@@ -68,11 +61,24 @@ fn o3_vectorizes_sum_with_unary_load() {
6861
         .map(|l| l.trim())
6962
         .filter(|l| !l.is_empty())
7063
         .collect();
71
-    assert_eq!(trimmed.len(), 2, "expected two output lines:\n{}", stdout);
64
+    assert_eq!(trimmed.len(), 4, "expected four output lines:\n{}", stdout);
65
+    // sum(-i for i=1..32) = -528.
7266
     assert_eq!(trimmed[0], "-528", "neg sum wrong: got {:?}", trimmed[0]);
67
+    // sum(|i-16| for i=1..32) = 240 + 16 = 256.
7368
     assert!(
7469
         trimmed[1].starts_with("2.56"),
75
-        "abs sum wrong: got {:?}",
70
+        "f32 abs sum wrong: got {:?}",
7671
         trimmed[1]
7772
     );
73
+    assert!(
74
+        trimmed[2].starts_with("2.56"),
75
+        "f64 abs sum wrong: got {:?}",
76
+        trimmed[2]
77
+    );
78
+    // Trip = 31; head 28 + tail 3, sum |i-16| for i=1..31 = 240.
79
+    assert!(
80
+        trimmed[3].starts_with("2.4"),
81
+        "f32 abs+tail wrong: got {:?}",
82
+        trimmed[3]
83
+    );
7884
 }