fortrangoingonforty/armfortas / 697a46b

Browse files

Wire array bounds checks end to end

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
697a46b9a8a02ba57710b7b99cc18199b18ca026
Parents
74a593b
Tree
838b7e4

7 changed files

Status | File | Lines added | Lines removed
M runtime/src/array.rs 14 1
M src/ir/lower.rs 18 11
M src/opt/bce.rs 462 45
M src/opt/pipeline.rs 4 0
A test_programs/bounds_check_loop.f90 14 0
A tests/bounds_checks.rs 82 0
A tests/fixtures/bounds_check_oob.f90 7 0
runtime/src/array.rs (modified)
@@ -7,6 +7,20 @@
7
 use crate::descriptor::*;
7
 use crate::descriptor::*;
8
 use std::ptr;
8
 use std::ptr;
9
 
9
 
10
+// ---- BOUNDS CHECKS ----
11
+
12
+/// Abort if an array subscript is outside the legal closed interval.
13
+#[no_mangle]
14
+pub extern "C" fn afs_check_bounds(index: i64, lower: i64, upper: i64) {
15
+    if index < lower || index > upper {
16
+        eprintln!(
17
+            "Bounds check failed: index {} outside [{}, {}]",
18
+            index, lower, upper
19
+        );
20
+        std::process::exit(1);
21
+    }
22
+}
23
+
10
 // ---- ALLOCATE ----
24
 // ---- ALLOCATE ----
11
 
25
 
12
 /// Allocate an array described by the given dimensions.
26
 /// Allocate an array described by the given dimensions.
@@ -829,4 +843,3 @@ pub extern "C" fn afs_dot_product_int(
829
     }
843
     }
830
     dot
844
     dot
831
 }
845
 }
832
-
src/ir/lower.rs (modified)
@@ -4942,6 +4942,15 @@ fn lower_array_element(
4942
     b.load(elem_ptr)
4942
     b.load(elem_ptr)
4943
 }
4943
 }
4944
 
4944
 
4945
+fn emit_bounds_check(
4946
+    b: &mut FuncBuilder,
4947
+    index: ValueId,
4948
+    lower: ValueId,
4949
+    upper: ValueId,
4950
+) {
4951
+    b.runtime_call(RuntimeFunc::CheckBounds, vec![index, lower, upper], IrType::Void);
4952
+}
4953
+
4945
 /// Compute the column-major flat ELEMENT offset (i64) for an array
4954
 /// Compute the column-major flat ELEMENT offset (i64) for an array
4946
 /// subscript expression, returning a value suitable for `b.gep` (which
4955
 /// subscript expression, returning a value suitable for `b.gep` (which
4947
 /// scales by the GEP result element size).
4956
 /// scales by the GEP result element size).
@@ -4991,6 +5000,7 @@ fn compute_flat_elem_offset(
4991
             let p_up = b.gep(info.addr, vec![off_up], IrType::Int(IntWidth::I8));
5000
             let p_up = b.gep(info.addr, vec![off_up], IrType::Int(IntWidth::I8));
4992
             let lo = b.load_typed(p_lo, IrType::Int(IntWidth::I64));
5001
             let lo = b.load_typed(p_lo, IrType::Int(IntWidth::I64));
4993
             let up = b.load_typed(p_up, IrType::Int(IntWidth::I64));
5002
             let up = b.load_typed(p_up, IrType::Int(IntWidth::I64));
5003
+            emit_bounds_check(b, sub, lo, up);
4994
 
5004
 
4995
             let adjusted = b.isub(sub, lo);
5005
             let adjusted = b.isub(sub, lo);
4996
 
5006
 
@@ -5023,20 +5033,23 @@ fn compute_flat_elem_offset(
5023
             crate::ast::expr::SectionSubscript::Element(e) => lower_expr(b, locals, e, st),
5033
             crate::ast::expr::SectionSubscript::Element(e) => lower_expr(b, locals, e, st),
5024
             _ => b.const_i32(0),
5034
             _ => b.const_i32(0),
5025
         };
5035
         };
5036
+        let subscript64 = widen_idx_to_i64(b, subscript);
5026
 
5037
 
5027
         let (lower, extent) = if dim_idx < info.dims.len() {
5038
         let (lower, extent) = if dim_idx < info.dims.len() {
5028
             info.dims[dim_idx]
5039
             info.dims[dim_idx]
5029
         } else {
5040
         } else {
5030
             (1, 1)
5041
             (1, 1)
5031
         };
5042
         };
5032
-
5043
+        let upper = lower + extent - 1;
5033
-        let lower_val = b.const_i32(lower as i32);
5044
+        let lower_val = b.const_i64(lower);
5034
-        let adjusted = b.isub(subscript, lower_val);
5045
+        let upper_val = b.const_i64(upper);
5046
+        emit_bounds_check(b, subscript64, lower_val, upper_val);
5047
+        let adjusted = b.isub(subscript64, lower_val);
5035
 
5048
 
5036
         let dim_offset = if stride == 1 {
5049
         let dim_offset = if stride == 1 {
5037
             adjusted
5050
             adjusted
5038
         } else {
5051
         } else {
5039
-            let stride_val = b.const_i32(stride as i32);
5052
+            let stride_val = b.const_i64(stride);
5040
             b.imul(adjusted, stride_val)
5053
             b.imul(adjusted, stride_val)
5041
         };
5054
         };
5042
 
5055
 
@@ -5048,13 +5061,7 @@ fn compute_flat_elem_offset(
5048
         stride *= extent;
5061
         stride *= extent;
5049
     }
5062
     }
5050
 
5063
 
5051
-    let idx = flat_offset.unwrap_or_else(|| b.const_i32(0));
5064
+    flat_offset.unwrap_or_else(|| b.const_i64(0))
5052
-    // Widen index to i64 for pointer arithmetic. ARM64 GEP lowering
5053
-    // needs an i64 subscript so the codegen `mul` lands as
5054
-    // `mul x, x, x` instead of `mul x, w, x` (which the assembler
5055
-    // rejects). Applies to BOTH stack and allocatable arrays —
5056
-    // gating on `info.allocatable` was the audit's CRITICAL-1.
5057
-    widen_idx_to_i64(b, idx)
5058
 }
5065
 }
5059
 
5066
 
5060
 /// Widen an i32 (or smaller) index value to i64 for pointer
5067
 /// Widen an i32 (or smaller) index value to i64 for pointer
src/opt/bce.rs (modified)
@@ -9,14 +9,13 @@
9
 //!   and hi <= upper
9
 //!   and hi <= upper
10
 //! - Constant index within [lower, upper]
10
 //! - Constant index within [lower, upper]
11
 //!
11
 //!
12
-//! Note: Bounds check INSERTION (the lowerer adding CheckBounds calls
12
+//! Bounds checks are inserted during lowering for scalar array
13
-//! at array access sites) is deferred. This pass provides the
13
+//! element accesses. This pass removes the checks when the index is
14
-//! elimination framework for when insertion lands.
14
+//! provably safe.
15
 
15
 
16
-use std::collections::HashSet;
17
 use crate::ir::inst::*;
16
 use crate::ir::inst::*;
18
-use crate::ir::walk::find_natural_loops;
17
+use crate::ir::walk::{find_natural_loops, predecessors};
19
-use super::loop_utils::{resolve_const_int, loop_defined_values};
18
+use super::loop_utils::{find_preheader, resolve_const_int};
20
 use super::pass::Pass;
19
 use super::pass::Pass;
21
 
20
 
22
 pub struct Bce;
21
 pub struct Bce;
@@ -35,6 +34,7 @@ impl Pass for Bce {
35
 
34
 
36
 fn bce_function(func: &mut Function) -> bool {
35
 fn bce_function(func: &mut Function) -> bool {
37
     let loops = find_natural_loops(func);
36
     let loops = find_natural_loops(func);
37
+    let preds = predecessors(func);
38
     let mut to_remove: Vec<(BlockId, usize)> = Vec::new();
38
     let mut to_remove: Vec<(BlockId, usize)> = Vec::new();
39
 
39
 
40
     for block in &func.blocks {
40
     for block in &func.blocks {
@@ -45,7 +45,7 @@ fn bce_function(func: &mut Function) -> bool {
45
                     let lower = args[1];
45
                     let lower = args[1];
46
                     let upper = args[2];
46
                     let upper = args[2];
47
 
47
 
48
-                    if is_provably_safe(func, &loops, index, lower, upper) {
48
+                    if is_provably_safe(func, &loops, &preds, index, lower, upper) {
49
                         to_remove.push((block.id, inst_idx));
49
                         to_remove.push((block.id, inst_idx));
50
                     }
50
                     }
51
                 }
51
                 }
@@ -68,60 +68,214 @@ fn bce_function(func: &mut Function) -> bool {
68
 fn is_provably_safe(
68
 fn is_provably_safe(
69
     func: &Function,
69
     func: &Function,
70
     loops: &[crate::ir::walk::NaturalLoop],
70
     loops: &[crate::ir::walk::NaturalLoop],
71
+    preds: &std::collections::HashMap<BlockId, Vec<BlockId>>,
71
     index: ValueId,
72
     index: ValueId,
72
     lower: ValueId,
73
     lower: ValueId,
73
     upper: ValueId,
74
     upper: ValueId,
74
 ) -> bool {
75
 ) -> bool {
76
+    let index = strip_int_casts(func, index);
77
+
75
     // Case 1: constant index, constant bounds.
78
     // Case 1: constant index, constant bounds.
76
     if let (Some(idx), Some(lo), Some(hi)) = (
79
     if let (Some(idx), Some(lo), Some(hi)) = (
77
-        resolve_const_int(func, index),
80
+        resolve_int_scalar(func, index),
78
-        resolve_const_int(func, lower),
81
+        resolve_int_scalar(func, lower),
79
-        resolve_const_int(func, upper),
82
+        resolve_int_scalar(func, upper),
80
     ) {
83
     ) {
81
         return idx >= lo && idx <= hi;
84
         return idx >= lo && idx <= hi;
82
     }
85
     }
83
 
86
 
84
-    // Case 2: index is a loop IV, bounds are constants matching loop bounds.
87
+    // Case 2: index is a canonical loop IV, and the loop's closed range
85
-    // Check if the index is a block param of a loop header, and the loop's
88
+    // stays within the checked bounds.
86
-    // init/bound encompass [lower, upper].
89
+    let Some(lo_const) = resolve_int_scalar(func, lower) else { return false; };
90
+    let Some(hi_const) = resolve_int_scalar(func, upper) else { return false; };
87
     for lp in loops {
91
     for lp in loops {
88
-        let hdr = func.block(lp.header);
92
+        let Some((range_lo, range_hi)) = loop_index_range(func, lp, preds, index) else {
89
-        if hdr.params.len() != 1 { continue; }
93
+            continue;
90
-        let iv = hdr.params[0].id;
94
+        };
91
-        if iv != index { continue; }
95
+        if range_lo >= lo_const && range_hi <= hi_const {
92
-
96
+            return true;
93
-        // The IV is in-bounds if the loop's init >= lower and bound <= upper.
97
+        }
94
-        // Find init (from preheader's branch arg) and bound (from cmp block).
98
+    }
95
-        // For now, conservative: only eliminate if both lower and upper are
99
+
96
-        // constants and the loop is a standard counted loop.
100
+    false
97
-        let loop_defs = loop_defined_values(func, lp);
101
+}
98
-
102
+
99
-        // Find bound from comparison.
103
+fn loop_index_range(
100
-        for &bid in &lp.body {
104
+    func: &Function,
101
-            let block = func.block(bid);
105
+    lp: &crate::ir::walk::NaturalLoop,
102
-            for inst in &block.insts {
106
+    preds: &std::collections::HashMap<BlockId, Vec<BlockId>>,
103
-                if let InstKind::ICmp(CmpOp::Le, a, b) = &inst.kind {
107
+    index: ValueId,
104
-                    if *a == iv {
108
+) -> Option<(i64, i64)> {
105
-                        // IV <= b → loop upper = b
109
+    let header = func.block(lp.header);
106
-                        if let (Some(lo_const), Some(hi_const), Some(bound_const)) = (
110
+    let param_idx = header.params.iter().position(|param| param.id == index)?;
107
-                            resolve_const_int(func, lower),
111
+    let init = loop_init_const(func, lp, preds, param_idx)?;
108
-                            resolve_const_int(func, upper),
112
+    let (bound, dir) = loop_bound_const(func, lp.header, index)?;
109
-                            resolve_const_int(func, *b),
113
+    let step = loop_step_const(func, lp, param_idx, index)?;
110
-                        ) {
114
+
111
-                            // If loop runs from some init to bound_const,
115
+    match (dir, step.signum()) {
112
-                            // and lo_const <= init and bound_const <= hi_const,
116
+        (LoopDir::Ascending, 1) | (LoopDir::Descending, -1) => {
113
-                            // the access is safe.
117
+            Some((init.min(bound), init.max(bound)))
114
-                            if bound_const <= hi_const && lo_const <= 1 {
118
+        }
115
-                                return true;
119
+        _ => None,
116
-                            }
120
+    }
117
-                        }
121
+}
118
-                    }
122
+
123
+fn loop_init_const(
124
+    func: &Function,
125
+    lp: &crate::ir::walk::NaturalLoop,
126
+    preds: &std::collections::HashMap<BlockId, Vec<BlockId>>,
127
+    param_idx: usize,
128
+) -> Option<i64> {
129
+    let preheader = find_preheader(func, lp, preds)?;
130
+    match &func.block(preheader).terminator {
131
+        Some(Terminator::Branch(dest, args))
132
+            if *dest == lp.header && param_idx < args.len() =>
133
+        {
134
+            resolve_int_scalar(func, args[param_idx])
135
+        }
136
+        _ => None,
137
+    }
138
+}
139
+
140
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
141
+enum LoopDir {
142
+    Ascending,
143
+    Descending,
144
+}
145
+
146
+fn loop_bound_const(
147
+    func: &Function,
148
+    header: BlockId,
149
+    index: ValueId,
150
+) -> Option<(i64, LoopDir)> {
151
+    for inst in &func.block(header).insts {
152
+        let InstKind::ICmp(op, lhs, rhs) = &inst.kind else { continue };
153
+        match op {
154
+            CmpOp::Le => {
155
+                if *lhs == index {
156
+                    return resolve_int_scalar(func, *rhs)
157
+                        .map(|bound| (bound, LoopDir::Ascending));
158
+                }
159
+                if *rhs == index {
160
+                    return resolve_int_scalar(func, *lhs)
161
+                        .map(|bound| (bound, LoopDir::Descending));
119
                 }
162
                 }
120
             }
163
             }
164
+            CmpOp::Ge => {
165
+                if *lhs == index {
166
+                    return resolve_int_scalar(func, *rhs)
167
+                        .map(|bound| (bound, LoopDir::Descending));
168
+                }
169
+                if *rhs == index {
170
+                    return resolve_int_scalar(func, *lhs)
171
+                        .map(|bound| (bound, LoopDir::Ascending));
172
+                }
173
+            }
174
+            _ => {}
121
         }
175
         }
122
     }
176
     }
177
+    None
178
+}
123
 
179
 
124
-    false
180
+fn loop_step_const(
181
+    func: &Function,
182
+    lp: &crate::ir::walk::NaturalLoop,
183
+    param_idx: usize,
184
+    index: ValueId,
185
+) -> Option<i64> {
186
+    let mut step: Option<i64> = None;
187
+    for &latch in &lp.latches {
188
+        let next =
189
+            edge_arg_value(func.block(latch).terminator.as_ref()?, lp.header, param_idx)?;
190
+        let latch_step = update_step_const(func, index, next)?;
191
+        if latch_step == 0 {
192
+            return None;
193
+        }
194
+        match step {
195
+            Some(prev) if prev != latch_step => return None,
196
+            Some(_) => {}
197
+            None => step = Some(latch_step),
198
+        }
199
+    }
200
+    step
201
+}
202
+
203
+fn edge_arg_value(term: &Terminator, target: BlockId, param_idx: usize) -> Option<ValueId> {
204
+    match term {
205
+        Terminator::Branch(dest, args) if *dest == target => args.get(param_idx).copied(),
206
+        Terminator::CondBranch {
207
+            true_dest,
208
+            true_args,
209
+            false_dest,
210
+            false_args,
211
+            ..
212
+        } => {
213
+            if *true_dest == target {
214
+                true_args.get(param_idx).copied()
215
+            } else if *false_dest == target {
216
+                false_args.get(param_idx).copied()
217
+            } else {
218
+                None
219
+            }
220
+        }
221
+        _ => None,
222
+    }
223
+}
224
+
225
+fn update_step_const(func: &Function, index: ValueId, next: ValueId) -> Option<i64> {
226
+    let next = strip_int_casts(func, next);
227
+    if next == index {
228
+        return Some(0);
229
+    }
230
+
231
+    let kind = find_inst_kind(func, next)?;
232
+    match kind {
233
+        InstKind::IAdd(lhs, rhs) => {
234
+            if *lhs == index {
235
+                resolve_int_scalar(func, *rhs)
236
+            } else if *rhs == index {
237
+                resolve_int_scalar(func, *lhs)
238
+            } else {
239
+                None
240
+            }
241
+        }
242
+        InstKind::ISub(lhs, rhs) if *lhs == index => {
243
+            resolve_int_scalar(func, *rhs).map(|step| -step)
244
+        }
245
+        _ => None,
246
+    }
247
+}
248
+
249
+fn resolve_int_scalar(func: &Function, value: ValueId) -> Option<i64> {
250
+    let kind = find_inst_kind(func, value)?;
251
+    match kind {
252
+        InstKind::ConstInt(v, _) => Some(*v),
253
+        InstKind::IntExtend(src, _, _) | InstKind::IntTrunc(src, _) => {
254
+            resolve_int_scalar(func, *src)
255
+        }
256
+        _ => resolve_const_int(func, value),
257
+    }
258
+}
259
+
260
+fn strip_int_casts(func: &Function, mut value: ValueId) -> ValueId {
261
+    loop {
262
+        let Some(kind) = find_inst_kind(func, value) else { return value };
263
+        match kind {
264
+            InstKind::IntExtend(src, _, _) | InstKind::IntTrunc(src, _) => value = *src,
265
+            _ => return value,
266
+        }
267
+    }
268
+}
269
+
270
+fn find_inst_kind<'a>(func: &'a Function, value: ValueId) -> Option<&'a InstKind> {
271
+    for block in &func.blocks {
272
+        for inst in &block.insts {
273
+            if inst.id == value {
274
+                return Some(&inst.kind);
275
+            }
276
+        }
277
+    }
278
+    None
125
 }
279
 }
126
 
280
 
127
 #[cfg(test)]
281
 #[cfg(test)]
@@ -187,4 +341,267 @@ mod tests {
187
             .any(|i| matches!(i.kind, InstKind::RuntimeCall(RuntimeFunc::CheckBounds, _)));
341
             .any(|i| matches!(i.kind, InstKind::RuntimeCall(RuntimeFunc::CheckBounds, _)));
188
         assert!(!has_check, "CheckBounds should be removed");
342
         assert!(!has_check, "CheckBounds should be removed");
189
     }
343
     }
344
+
345
+    #[test]
346
+    fn bce_removes_canonical_loop_iv_check() {
347
+        let mut m = Module::new("test".into());
348
+        let mut f = Function::new("test".into(), vec![], IrType::Void);
349
+        let span = crate::lexer::Span {
350
+            file_id: 0,
351
+            start: crate::lexer::Position { line: 0, col: 0 },
352
+            end: crate::lexer::Position { line: 0, col: 0 },
353
+        };
354
+
355
+        let header = f.create_block("header");
356
+        let body = f.create_block("body");
357
+        let exit = f.create_block("exit");
358
+
359
+        let init = f.next_value_id();
360
+        f.register_type(init, IrType::Int(IntWidth::I32));
361
+        f.block_mut(f.entry).insts.push(Inst {
362
+            id: init,
363
+            ty: IrType::Int(IntWidth::I32),
364
+            span,
365
+            kind: InstKind::ConstInt(1, IntWidth::I32),
366
+        });
367
+        let zero = f.next_value_id();
368
+        f.register_type(zero, IrType::Int(IntWidth::I32));
369
+        f.block_mut(f.entry).insts.push(Inst {
370
+            id: zero,
371
+            ty: IrType::Int(IntWidth::I32),
372
+            span,
373
+            kind: InstKind::ConstInt(0, IntWidth::I32),
374
+        });
375
+        f.block_mut(f.entry).terminator = Some(Terminator::Branch(header, vec![init, zero]));
376
+
377
+        let iv = f.next_value_id();
378
+        f.register_type(iv, IrType::Int(IntWidth::I32));
379
+        let sum = f.next_value_id();
380
+        f.register_type(sum, IrType::Int(IntWidth::I32));
381
+        f.block_mut(header).params.push(BlockParam { id: iv, ty: IrType::Int(IntWidth::I32) });
382
+        f.block_mut(header).params.push(BlockParam { id: sum, ty: IrType::Int(IntWidth::I32) });
383
+
384
+        let bound = f.next_value_id();
385
+        f.register_type(bound, IrType::Int(IntWidth::I32));
386
+        f.block_mut(header).insts.push(Inst {
387
+            id: bound,
388
+            ty: IrType::Int(IntWidth::I32),
389
+            span,
390
+            kind: InstKind::ConstInt(4, IntWidth::I32),
391
+        });
392
+        let cond = f.next_value_id();
393
+        f.register_type(cond, IrType::Bool);
394
+        f.block_mut(header).insts.push(Inst {
395
+            id: cond,
396
+            ty: IrType::Bool,
397
+            span,
398
+            kind: InstKind::ICmp(CmpOp::Le, iv, bound),
399
+        });
400
+        f.block_mut(header).terminator = Some(Terminator::CondBranch {
401
+            cond,
402
+            true_dest: body,
403
+            true_args: vec![],
404
+            false_dest: exit,
405
+            false_args: vec![],
406
+        });
407
+
408
+        let iv64 = f.next_value_id();
409
+        f.register_type(iv64, IrType::Int(IntWidth::I64));
410
+        f.block_mut(body).insts.push(Inst {
411
+            id: iv64,
412
+            ty: IrType::Int(IntWidth::I64),
413
+            span,
414
+            kind: InstKind::IntExtend(iv, IntWidth::I64, true),
415
+        });
416
+        let lo = f.next_value_id();
417
+        f.register_type(lo, IrType::Int(IntWidth::I64));
418
+        f.block_mut(body).insts.push(Inst {
419
+            id: lo,
420
+            ty: IrType::Int(IntWidth::I64),
421
+            span,
422
+            kind: InstKind::ConstInt(1, IntWidth::I64),
423
+        });
424
+        let hi = f.next_value_id();
425
+        f.register_type(hi, IrType::Int(IntWidth::I64));
426
+        f.block_mut(body).insts.push(Inst {
427
+            id: hi,
428
+            ty: IrType::Int(IntWidth::I64),
429
+            span,
430
+            kind: InstKind::ConstInt(4, IntWidth::I64),
431
+        });
432
+        let check = f.next_value_id();
433
+        f.register_type(check, IrType::Void);
434
+        f.block_mut(body).insts.push(Inst {
435
+            id: check,
436
+            ty: IrType::Void,
437
+            span,
438
+            kind: InstKind::RuntimeCall(RuntimeFunc::CheckBounds, vec![iv64, lo, hi]),
439
+        });
440
+        let step = f.next_value_id();
441
+        f.register_type(step, IrType::Int(IntWidth::I32));
442
+        f.block_mut(body).insts.push(Inst {
443
+            id: step,
444
+            ty: IrType::Int(IntWidth::I32),
445
+            span,
446
+            kind: InstKind::ConstInt(1, IntWidth::I32),
447
+        });
448
+        let next_iv = f.next_value_id();
449
+        f.register_type(next_iv, IrType::Int(IntWidth::I32));
450
+        f.block_mut(body).insts.push(Inst {
451
+            id: next_iv,
452
+            ty: IrType::Int(IntWidth::I32),
453
+            span,
454
+            kind: InstKind::IAdd(iv, step),
455
+        });
456
+        let next_sum = f.next_value_id();
457
+        f.register_type(next_sum, IrType::Int(IntWidth::I32));
458
+        f.block_mut(body).insts.push(Inst {
459
+            id: next_sum,
460
+            ty: IrType::Int(IntWidth::I32),
461
+            span,
462
+            kind: InstKind::IAdd(sum, step),
463
+        });
464
+        f.block_mut(body).terminator =
465
+            Some(Terminator::Branch(header, vec![next_iv, next_sum]));
466
+        f.block_mut(exit).terminator = Some(Terminator::Return(None));
467
+        m.add_function(f);
468
+
469
+        let pass = Bce;
470
+        assert!(pass.run(&mut m), "canonical counted-loop bounds check should be removed");
471
+        let has_check = m.functions[0].block(body).insts.iter().any(|inst| {
472
+            matches!(inst.kind, InstKind::RuntimeCall(RuntimeFunc::CheckBounds, _))
473
+        });
474
+        assert!(!has_check, "loop body should no longer contain CheckBounds");
475
+    }
476
+
477
+    #[test]
478
+    fn bce_keeps_loop_check_when_bounds_are_tighter_than_trip_range() {
479
+        let mut m = Module::new("test".into());
480
+        let mut f = Function::new("test".into(), vec![], IrType::Void);
481
+        let span = crate::lexer::Span {
482
+            file_id: 0,
483
+            start: crate::lexer::Position { line: 0, col: 0 },
484
+            end: crate::lexer::Position { line: 0, col: 0 },
485
+        };
486
+
487
+        let header = f.create_block("header");
488
+        let body = f.create_block("body");
489
+        let exit = f.create_block("exit");
490
+
491
+        let init = f.next_value_id();
492
+        f.register_type(init, IrType::Int(IntWidth::I32));
493
+        f.block_mut(f.entry).insts.push(Inst {
494
+            id: init,
495
+            ty: IrType::Int(IntWidth::I32),
496
+            span,
497
+            kind: InstKind::ConstInt(1, IntWidth::I32),
498
+        });
499
+        let zero = f.next_value_id();
500
+        f.register_type(zero, IrType::Int(IntWidth::I32));
501
+        f.block_mut(f.entry).insts.push(Inst {
502
+            id: zero,
503
+            ty: IrType::Int(IntWidth::I32),
504
+            span,
505
+            kind: InstKind::ConstInt(0, IntWidth::I32),
506
+        });
507
+        f.block_mut(f.entry).terminator = Some(Terminator::Branch(header, vec![init, zero]));
508
+
509
+        let iv = f.next_value_id();
510
+        f.register_type(iv, IrType::Int(IntWidth::I32));
511
+        let sum = f.next_value_id();
512
+        f.register_type(sum, IrType::Int(IntWidth::I32));
513
+        f.block_mut(header).params.push(BlockParam { id: iv, ty: IrType::Int(IntWidth::I32) });
514
+        f.block_mut(header).params.push(BlockParam { id: sum, ty: IrType::Int(IntWidth::I32) });
515
+
516
+        let bound = f.next_value_id();
517
+        f.register_type(bound, IrType::Int(IntWidth::I32));
518
+        f.block_mut(header).insts.push(Inst {
519
+            id: bound,
520
+            ty: IrType::Int(IntWidth::I32),
521
+            span,
522
+            kind: InstKind::ConstInt(4, IntWidth::I32),
523
+        });
524
+        let cond = f.next_value_id();
525
+        f.register_type(cond, IrType::Bool);
526
+        f.block_mut(header).insts.push(Inst {
527
+            id: cond,
528
+            ty: IrType::Bool,
529
+            span,
530
+            kind: InstKind::ICmp(CmpOp::Le, iv, bound),
531
+        });
532
+        f.block_mut(header).terminator = Some(Terminator::CondBranch {
533
+            cond,
534
+            true_dest: body,
535
+            true_args: vec![],
536
+            false_dest: exit,
537
+            false_args: vec![],
538
+        });
539
+
540
+        let iv64 = f.next_value_id();
541
+        f.register_type(iv64, IrType::Int(IntWidth::I64));
542
+        f.block_mut(body).insts.push(Inst {
543
+            id: iv64,
544
+            ty: IrType::Int(IntWidth::I64),
545
+            span,
546
+            kind: InstKind::IntExtend(iv, IntWidth::I64, true),
547
+        });
548
+        let lo = f.next_value_id();
549
+        f.register_type(lo, IrType::Int(IntWidth::I64));
550
+        f.block_mut(body).insts.push(Inst {
551
+            id: lo,
552
+            ty: IrType::Int(IntWidth::I64),
553
+            span,
554
+            kind: InstKind::ConstInt(1, IntWidth::I64),
555
+        });
556
+        let hi = f.next_value_id();
557
+        f.register_type(hi, IrType::Int(IntWidth::I64));
558
+        f.block_mut(body).insts.push(Inst {
559
+            id: hi,
560
+            ty: IrType::Int(IntWidth::I64),
561
+            span,
562
+            kind: InstKind::ConstInt(3, IntWidth::I64),
563
+        });
564
+        let check = f.next_value_id();
565
+        f.register_type(check, IrType::Void);
566
+        f.block_mut(body).insts.push(Inst {
567
+            id: check,
568
+            ty: IrType::Void,
569
+            span,
570
+            kind: InstKind::RuntimeCall(RuntimeFunc::CheckBounds, vec![iv64, lo, hi]),
571
+        });
572
+        let step = f.next_value_id();
573
+        f.register_type(step, IrType::Int(IntWidth::I32));
574
+        f.block_mut(body).insts.push(Inst {
575
+            id: step,
576
+            ty: IrType::Int(IntWidth::I32),
577
+            span,
578
+            kind: InstKind::ConstInt(1, IntWidth::I32),
579
+        });
580
+        let next_iv = f.next_value_id();
581
+        f.register_type(next_iv, IrType::Int(IntWidth::I32));
582
+        f.block_mut(body).insts.push(Inst {
583
+            id: next_iv,
584
+            ty: IrType::Int(IntWidth::I32),
585
+            span,
586
+            kind: InstKind::IAdd(iv, step),
587
+        });
588
+        let next_sum = f.next_value_id();
589
+        f.register_type(next_sum, IrType::Int(IntWidth::I32));
590
+        f.block_mut(body).insts.push(Inst {
591
+            id: next_sum,
592
+            ty: IrType::Int(IntWidth::I32),
593
+            span,
594
+            kind: InstKind::IAdd(sum, step),
595
+        });
596
+        f.block_mut(body).terminator =
597
+            Some(Terminator::Branch(header, vec![next_iv, next_sum]));
598
+        f.block_mut(exit).terminator = Some(Terminator::Return(None));
599
+        m.add_function(f);
600
+
601
+        let pass = Bce;
602
+        assert!(
603
+            !pass.run(&mut m),
604
+            "loop trip range 1..4 exceeds checked upper bound 3, so CheckBounds must remain"
605
+        );
606
+    }
190
 }
607
 }
src/opt/pipeline.rs (modified)
@@ -26,6 +26,7 @@ use super::dead_func::DeadFuncElim;
26
 use super::sroa::Sroa;
26
 use super::sroa::Sroa;
27
 use super::gvn::Gvn;
27
 use super::gvn::Gvn;
28
 use super::global_lsf::GlobalLsf;
28
 use super::global_lsf::GlobalLsf;
29
+use super::bce::Bce;
29
 use super::fission::LoopFission;
30
 use super::fission::LoopFission;
30
 use super::fusion::LoopFusion;
31
 use super::fusion::LoopFusion;
31
 
32
 
@@ -141,6 +142,7 @@ pub fn build_pipeline(level: OptLevel) -> PassManager {
141
             pm.add(Box::new(Inline::for_level(OptLevel::O2)));
142
             pm.add(Box::new(Inline::for_level(OptLevel::O2)));
142
             pm.add(Box::new(SimplifyCfg));
143
             pm.add(Box::new(SimplifyCfg));
143
             pm.add(Box::new(DeadFuncElim));
144
             pm.add(Box::new(DeadFuncElim));
145
+            pm.add(Box::new(Bce));
144
             pm.add(Box::new(StrengthReduce));
146
             pm.add(Box::new(StrengthReduce));
145
             pm.add(Box::new(LocalLsf));
147
             pm.add(Box::new(LocalLsf));
146
             pm.add(Box::new(GlobalLsf));
148
             pm.add(Box::new(GlobalLsf));
@@ -168,6 +170,7 @@ pub fn build_pipeline(level: OptLevel) -> PassManager {
168
             pm.add(Box::new(Inline::for_level(OptLevel::Os)));
170
             pm.add(Box::new(Inline::for_level(OptLevel::Os)));
169
             pm.add(Box::new(SimplifyCfg));
171
             pm.add(Box::new(SimplifyCfg));
170
             pm.add(Box::new(DeadFuncElim));
172
             pm.add(Box::new(DeadFuncElim));
173
+            pm.add(Box::new(Bce));
171
             pm.add(Box::new(StrengthReduce));
174
             pm.add(Box::new(StrengthReduce));
172
             pm.add(Box::new(LocalLsf));
175
             pm.add(Box::new(LocalLsf));
173
             pm.add(Box::new(GlobalLsf));
176
             pm.add(Box::new(GlobalLsf));
@@ -192,6 +195,7 @@ pub fn build_pipeline(level: OptLevel) -> PassManager {
192
             pm.add(Box::new(Inline::for_level(OptLevel::O3)));
195
             pm.add(Box::new(Inline::for_level(OptLevel::O3)));
193
             pm.add(Box::new(SimplifyCfg));
196
             pm.add(Box::new(SimplifyCfg));
194
             pm.add(Box::new(DeadFuncElim));
197
             pm.add(Box::new(DeadFuncElim));
198
+            pm.add(Box::new(Bce));
195
             pm.add(Box::new(StrengthReduce));
199
             pm.add(Box::new(StrengthReduce));
196
             pm.add(Box::new(LocalLsf));
200
             pm.add(Box::new(LocalLsf));
197
             pm.add(Box::new(GlobalLsf));
201
             pm.add(Box::new(GlobalLsf));
test_programs/bounds_check_loop.f90 (added)
@@ -0,0 +1,14 @@
1
+program bounds_check_loop
2
+  implicit none
3
+  integer :: i, a(4), s
4
+
5
+  a = [1, 2, 3, 4]
6
+  s = 0
7
+  do i = 1, 4
8
+    s = s + a(i)
9
+  end do
10
+
11
+  print *, s
12
+end program bounds_check_loop
13
+! CHECK: 10
14
+! IR_CHECK: rt_call @__afs_check_bounds
tests/bounds_checks.rs (added)
@@ -0,0 +1,82 @@
1
+use std::collections::BTreeSet;
2
+use std::path::PathBuf;
3
+
4
+use armfortas::driver::OptLevel;
5
+use armfortas::testing::{capture_from_path, CaptureRequest, CapturedStage, RunCapture, Stage};
6
+
7
+fn fixture(path: &str) -> PathBuf {
8
+    let path = PathBuf::from(path);
9
+    assert!(path.exists(), "missing test fixture {}", path.display());
10
+    path
11
+}
12
+
13
+fn capture_text(request: CaptureRequest, stage: Stage) -> String {
14
+    let result = capture_from_path(&request).expect("capture should succeed");
15
+    match result.get(stage) {
16
+        Some(CapturedStage::Text(text)) => text.clone(),
17
+        Some(CapturedStage::Run(_)) => panic!("expected text stage for {}", stage.as_str()),
18
+        None => panic!("missing requested stage {}", stage.as_str()),
19
+    }
20
+}
21
+
22
+fn capture_run(request: CaptureRequest) -> RunCapture {
23
+    let result = capture_from_path(&request).expect("capture should succeed");
24
+    match result.get(Stage::Run) {
25
+        Some(CapturedStage::Run(run)) => run.clone(),
26
+        Some(CapturedStage::Text(_)) => panic!("expected run stage"),
27
+        None => panic!("missing requested stage {}", Stage::Run.as_str()),
28
+    }
29
+}
30
+
31
+#[test]
32
+fn lowering_inserts_bounds_checks_at_o0() {
33
+    let ir = capture_text(
34
+        CaptureRequest {
35
+            input: fixture("test_programs/bounds_check_loop.f90"),
36
+            requested: BTreeSet::from([Stage::Ir]),
37
+            opt_level: OptLevel::O0,
38
+        },
39
+        Stage::Ir,
40
+    );
41
+
42
+    assert!(
43
+        ir.contains("rt_call @__afs_check_bounds"),
44
+        "lowered IR should contain runtime bounds checks before optimization"
45
+    );
46
+}
47
+
48
+#[test]
49
+fn bce_removes_canonical_loop_bounds_checks_at_o2() {
50
+    let opt_ir = capture_text(
51
+        CaptureRequest {
52
+            input: fixture("test_programs/bounds_check_loop.f90"),
53
+            requested: BTreeSet::from([Stage::OptIr]),
54
+            opt_level: OptLevel::O2,
55
+        },
56
+        Stage::OptIr,
57
+    );
58
+
59
+    assert!(
60
+        !opt_ir.contains("rt_call @__afs_check_bounds"),
61
+        "O2 optimized IR should eliminate provably-safe loop bounds checks"
62
+    );
63
+}
64
+
65
+#[test]
66
+fn runtime_bounds_checks_trap_out_of_range_accesses() {
67
+    let run = capture_run(CaptureRequest {
68
+        input: fixture("tests/fixtures/bounds_check_oob.f90"),
69
+        requested: BTreeSet::from([Stage::Run]),
70
+        opt_level: OptLevel::O0,
71
+    });
72
+
73
+    assert_ne!(
74
+        run.exit_code, 0,
75
+        "out-of-range access should fail at runtime"
76
+    );
77
+    assert!(
78
+        run.stderr.contains("Bounds check failed"),
79
+        "runtime trap should explain the out-of-range access, stderr was:\n{}",
80
+        run.stderr
81
+    );
82
+}
tests/fixtures/bounds_check_oob.f90 (added)
@@ -0,0 +1,7 @@
1
+program bounds_check_oob
2
+  implicit none
3
+  integer :: a(4)
4
+
5
+  a = [1, 2, 3, 4]
6
+  print *, a(5)
7
+end program bounds_check_oob