fortrangoingonforty/armfortas / 9f540f9

Browse files

Lower mixed range+vector subscript section to fresh descriptor gather

Authored by espadonne
Committed by mfwolffe
SHA
9f540f9115d03ccf351a9d47fbf3b6fe09744185
Parents
6681f99
Tree
4447a2e

1 changed file

Status | File | + | -
M | src/ir/lower/core.rs | 384 | 9
src/ir/lower/core.rs (modified)
@@ -21783,6 +21783,12 @@ pub(super) fn lower_section_write_nd(
2178321783
         const_stride: Option<i64>,
2178421784
         decl_lo: i64,
2178521785
         cum_stride: i64,
21786
+        // F2018 §9.5.3.3 vector subscript: when an Element subscript
21787
+        // is itself an array, the iteration counter walks 1..N (with
21788
+        // N = size of the index array) and the value used in flat-
21789
+        // offset arithmetic is `vector_idx_desc[counter-1]`. None for
21790
+        // ordinary Range / scalar Element dims.
21791
+        vector_idx: Option<(ValueId, IrType)>,
2178621792
     }
2178721793
 
2178821794
     let mut dims: Vec<DimSlice> = Vec::with_capacity(args.len());
@@ -21792,7 +21798,7 @@ pub(super) fn lower_section_write_nd(
2179221798
         let decl_hi = decl_lo + decl_ext - 1;
2179321799
 
2179421800
         let counter = b.alloca(IrType::Int(IntWidth::I32));
21795
-        let (start_val, end_val, stride_val, const_stride) = match &arg.value {
21801
+        let (start_val, end_val, stride_val, const_stride, vector_idx) = match &arg.value {
2179621802
             SectionSubscript::Range { start, end, stride } => {
2179721803
                 let start_v = match start {
2179821804
                     Some(e) => super::expr::lower_expr_ctx(b, ctx, e),
@@ -21807,12 +21813,34 @@ pub(super) fn lower_section_write_nd(
2180721813
                     None => b.const_i32(1),
2180821814
                 };
2180921815
                 let cs = stride.as_ref().and_then(eval_const_int);
21810
-                (start_v, end_v, stride_v, cs)
21816
+                (start_v, end_v, stride_v, cs, None)
2181121817
             }
2181221818
             SectionSubscript::Element(e) => {
21813
-                let v = super::expr::lower_expr_ctx(b, ctx, e);
21814
-                // Single-element dimension: start == end, stride 1.
21815
-                (v, v, b.const_i32(1), Some(1))
21819
+                if expr_returns_array(e, &ctx.locals, ctx.st) {
21820
+                    let (idx_desc, idx_elem_ty) = lower_array_expr_descriptor(
21821
+                        b,
21822
+                        &ctx.locals,
21823
+                        e,
21824
+                        ctx.st,
21825
+                        Some(ctx.type_layouts),
21826
+                        Some(ctx.internal_funcs),
21827
+                        Some(ctx.contained_host_refs),
21828
+                        Some(ctx.descriptor_params),
21829
+                    )
21830
+                    .expect("vector subscript array must produce descriptor");
21831
+                    let n = b.call(
21832
+                        FuncRef::External("afs_array_size".into()),
21833
+                        vec![idx_desc],
21834
+                        IrType::Int(IntWidth::I64),
21835
+                    );
21836
+                    let n_i32 = b.int_trunc(n, IntWidth::I32);
21837
+                    let one = b.const_i32(1);
21838
+                    (one, n_i32, one, Some(1), Some((idx_desc, idx_elem_ty)))
21839
+                } else {
21840
+                    let v = super::expr::lower_expr_ctx(b, ctx, e);
21841
+                    // Single-element dimension: start == end, stride 1.
21842
+                    (v, v, b.const_i32(1), Some(1), None)
21843
+                }
2181621844
             }
2181721845
         };
2181821846
         b.store(start_val, counter);
@@ -21824,6 +21852,7 @@ pub(super) fn lower_section_write_nd(
2182421852
             const_stride,
2182521853
             decl_lo,
2182621854
             cum_stride,
21855
+            vector_idx,
2182721856
         });
2182821857
         cum_stride *= decl_ext.max(1);
2182921858
     }
@@ -21894,12 +21923,25 @@ pub(super) fn lower_section_write_nd(
2189421923
             // Borrow `dims` immutably while iterating it; the loop
2189521924
             // body needs &mut b so we collect the per-dim values
2189621925
             // first, then emit the IR for the sum afterwards.
21897
-            let dim_data: Vec<(ValueId, i64, i64)> = dims
21926
+            // For vector-subscript dims, load the actual index from
21927
+            // the index array at position counter-1 (1-based); for
21928
+            // ordinary dims use the counter directly. The decl_lo
21929
+            // adjustment is the array's declared lower bound, which
21930
+            // applies to both kinds of indices.
21931
+            let dim_data: Vec<(ValueId, i64, i64, Option<(ValueId, IrType)>)> = dims
2189821932
                 .iter()
21899
-                .map(|d| (d.counter, d.decl_lo, d.cum_stride))
21933
+                .map(|d| (d.counter, d.decl_lo, d.cum_stride, d.vector_idx.clone()))
2190021934
                 .collect();
21901
-            for (counter, decl_lo, cum_stride_d) in dim_data {
21902
-                let cnt = b.load(counter);
21935
+            for (counter, decl_lo, cum_stride_d, vector_idx) in dim_data {
21936
+                let cnt = if let Some((idx_desc, idx_elem_ty)) = vector_idx {
21937
+                    let i_val = b.load(counter);
21938
+                    let one = b.const_i32(1);
21939
+                    let zero_based = b.isub(i_val, one);
21940
+                    let zero_based_i64 = widen_idx_to_i64(b, zero_based);
21941
+                    load_rank1_array_desc_elem(b, idx_desc, &idx_elem_ty, zero_based_i64)
21942
+                } else {
21943
+                    b.load(counter)
21944
+                };
2190321945
                 let lo_const = b.const_i32(decl_lo as i32);
2190421946
                 let zero_based = b.isub(cnt, lo_const);
2190521947
                 let zero_based64 = widen_idx_to_i64(b, zero_based);
@@ -27299,6 +27341,34 @@ pub(super) fn lower_array_expr_descriptor(
2729927341
                         };
2730027342
                         return Some((desc, info.ty.clone()));
2730127343
                     }
27344
+                    // F2018 §9.5.3.3: any subscript that is itself an
27345
+                    // array expression is a vector subscript. When mixed
27346
+                    // with ordinary ranges (e.g. `A(:, pivots)`), the
27347
+                    // afs_create_section runtime can't represent it, so
27348
+                    // route those through a per-element gather that
27349
+                    // builds a freshly allocated descriptor.
27350
+                    let any_vector_subscript = args.iter().any(|arg| {
27351
+                        if let crate::ast::expr::SectionSubscript::Element(e) = &arg.value {
27352
+                            expr_returns_array(e, locals, st)
27353
+                        } else {
27354
+                            false
27355
+                        }
27356
+                    });
27357
+                    if any_vector_subscript {
27358
+                        if let Some(desc) = lower_array_section_with_vector_subscripts(
27359
+                            b,
27360
+                            locals,
27361
+                            info,
27362
+                            args,
27363
+                            st,
27364
+                            type_layouts,
27365
+                            internal_funcs,
27366
+                            contained_host_refs,
27367
+                            descriptor_params,
27368
+                        ) {
27369
+                            return Some((desc, info.ty.clone()));
27370
+                        }
27371
+                    }
2730227372
                     if args.iter().any(|arg| {
2730327373
                         matches!(arg.value, crate::ast::expr::SectionSubscript::Range { .. })
2730427374
                     }) {
@@ -29242,6 +29312,311 @@ pub(super) fn first_arg_is_complex(
2924229312
         .unwrap_or(false)
2924329313
 }
2924429314
 
29315
+/// F2018 §9.5.3.3: array section with one or more vector subscripts mixed
29316
+/// with ordinary range subscripts. afs_create_section can't represent a
29317
+/// vector subscript, so when one is present we materialize the result
29318
+/// into a freshly-allocated descriptor by gathering element-by-element.
29319
+/// Returns Some(result_desc) on success, or None to fall through to the
29320
+/// normal `afs_create_section` path (no vector subscripts present, or
29321
+/// the index array element type isn't an integer kind we handle).
29322
+pub(super) fn lower_array_section_with_vector_subscripts(
29323
+    b: &mut FuncBuilder,
29324
+    locals: &HashMap<String, LocalInfo>,
29325
+    info: &LocalInfo,
29326
+    args: &[crate::ast::expr::Argument],
29327
+    st: &SymbolTable,
29328
+    type_layouts: Option<&crate::sema::type_layout::TypeLayoutRegistry>,
29329
+    internal_funcs: Option<&HashMap<String, u32>>,
29330
+    contained_host_refs: Option<&HashMap<String, Vec<String>>>,
29331
+    descriptor_params: Option<&HashMap<String, Vec<bool>>>,
29332
+) -> Option<ValueId> {
29333
+    use crate::ast::expr::SectionSubscript;
29334
+
29335
+    enum DimKind {
29336
+        Range {
29337
+            start: ValueId,
29338
+            stride: ValueId,
29339
+            result_dim: usize,
29340
+        },
29341
+        Vector {
29342
+            idx_desc: ValueId,
29343
+            idx_elem_ty: IrType,
29344
+            result_dim: usize,
29345
+        },
29346
+        Scalar {
29347
+            val: ValueId,
29348
+        },
29349
+    }
29350
+
29351
+    // Source descriptor — uniform handle whether the local is descriptor-
29352
+    // backed or a contiguous stack array.
29353
+    let source_desc = if local_uses_array_descriptor(info) {
29354
+        array_descriptor_addr(b, info)
29355
+    } else {
29356
+        materialize_array_descriptor_for_info(b, info)
29357
+    };
29358
+
29359
+    let elem_ty = info.ty.clone();
29360
+    let elem_bytes = ir_scalar_byte_size(&elem_ty);
29361
+    let zero32 = b.const_i32(0);
29362
+    let zero64 = b.const_i64(0);
29363
+    let one64 = b.const_i64(1);
29364
+
29365
+    let mut dim_kinds: Vec<DimKind> = Vec::with_capacity(args.len());
29366
+    let mut result_extents: Vec<ValueId> = Vec::new();
29367
+    let mut has_vector = false;
29368
+
29369
+    for (d, arg) in args.iter().enumerate() {
29370
+        let dim_lo_off = 24 + (d as i64) * 24;
29371
+        let dim_ext_off = dim_lo_off + 8;
29372
+        match &arg.value {
29373
+            SectionSubscript::Range { start, end, stride } => {
29374
+                let start_v = match start {
29375
+                    Some(e) => {
29376
+                        let raw = super::expr::lower_expr_with_optional_layouts(b, locals, e, st, type_layouts);
29377
+                        widen_idx_to_i64(b, raw)
29378
+                    }
29379
+                    None => load_array_desc_i64_field(b, source_desc, dim_lo_off),
29380
+                };
29381
+                let end_v = match end {
29382
+                    Some(e) => {
29383
+                        let raw = super::expr::lower_expr_with_optional_layouts(b, locals, e, st, type_layouts);
29384
+                        widen_idx_to_i64(b, raw)
29385
+                    }
29386
+                    None => {
29387
+                        let lo = load_array_desc_i64_field(b, source_desc, dim_lo_off);
29388
+                        let ext = load_array_desc_i64_field(b, source_desc, dim_ext_off);
29389
+                        let lo_minus_one = b.isub(lo, one64);
29390
+                        b.iadd(lo_minus_one, ext)
29391
+                    }
29392
+                };
29393
+                let stride_v = match stride {
29394
+                    Some(e) => {
29395
+                        let raw = super::expr::lower_expr_with_optional_layouts(b, locals, e, st, type_layouts);
29396
+                        widen_idx_to_i64(b, raw)
29397
+                    }
29398
+                    None => one64,
29399
+                };
29400
+                // result_extent = max((end - start)/stride + 1, 0)
29401
+                let diff = b.isub(end_v, start_v);
29402
+                let bumped = b.iadd(diff, stride_v);
29403
+                let raw_extent = b.idiv(bumped, stride_v);
29404
+                let nonneg = b.icmp(CmpOp::Ge, raw_extent, zero64);
29405
+                let extent = b.select(nonneg, raw_extent, zero64);
29406
+                let result_dim = result_extents.len();
29407
+                result_extents.push(extent);
29408
+                dim_kinds.push(DimKind::Range {
29409
+                    start: start_v,
29410
+                    stride: stride_v,
29411
+                    result_dim,
29412
+                });
29413
+            }
29414
+            SectionSubscript::Element(e) if expr_returns_array(e, locals, st) => {
29415
+                has_vector = true;
29416
+                let (idx_desc, idx_elem_ty) = lower_array_expr_descriptor(
29417
+                    b,
29418
+                    locals,
29419
+                    e,
29420
+                    st,
29421
+                    type_layouts,
29422
+                    internal_funcs,
29423
+                    contained_host_refs,
29424
+                    descriptor_params,
29425
+                )?;
29426
+                if !matches!(idx_elem_ty, IrType::Int(_)) {
29427
+                    return None;
29428
+                }
29429
+                let size = b.call(
29430
+                    FuncRef::External("afs_array_size".into()),
29431
+                    vec![idx_desc],
29432
+                    IrType::Int(IntWidth::I64),
29433
+                );
29434
+                let result_dim = result_extents.len();
29435
+                result_extents.push(size);
29436
+                dim_kinds.push(DimKind::Vector {
29437
+                    idx_desc,
29438
+                    idx_elem_ty,
29439
+                    result_dim,
29440
+                });
29441
+            }
29442
+            SectionSubscript::Element(e) => {
29443
+                let raw = super::expr::lower_expr_with_optional_layouts(b, locals, e, st, type_layouts);
29444
+                let val = widen_idx_to_i64(b, raw);
29445
+                dim_kinds.push(DimKind::Scalar { val });
29446
+            }
29447
+        }
29448
+    }
29449
+
29450
+    if !has_vector {
29451
+        return None;
29452
+    }
29453
+
29454
+    let result_rank = result_extents.len();
29455
+    if result_rank == 0 {
29456
+        return None;
29457
+    }
29458
+
29459
+    // Allocate a fresh descriptor for the result with rank=result_rank and
29460
+    // each kept dim having lower_bound 1 and extent = result_extents[k].
29461
+    // afs_allocate_array fills in the heap base and computes column-major
29462
+    // strides (1, ext0, ext0*ext1, ...).
29463
+    let result_desc = b.alloca(IrType::Array(Box::new(IrType::Int(IntWidth::I8)), 384));
29464
+    let sz384 = b.const_i64(384);
29465
+    b.call(
29466
+        FuncRef::External("memset".into()),
29467
+        vec![result_desc, zero32, sz384],
29468
+        IrType::Ptr(Box::new(IrType::Int(IntWidth::I8))),
29469
+    );
29470
+    let dim_buf_words = (3 * result_rank) as u64;
29471
+    let dim_buf = b.alloca(IrType::Array(Box::new(IrType::Int(IntWidth::I64)), dim_buf_words));
29472
+    for (k, ext) in result_extents.iter().enumerate() {
29473
+        let lo_idx = b.const_i64((k * 3) as i64);
29474
+        let lo_dst = b.gep(dim_buf, vec![lo_idx], IrType::Int(IntWidth::I64));
29475
+        b.store(one64, lo_dst);
29476
+        let ub_idx = b.const_i64((k * 3 + 1) as i64);
29477
+        let ub_dst = b.gep(dim_buf, vec![ub_idx], IrType::Int(IntWidth::I64));
29478
+        b.store(*ext, ub_dst);
29479
+        let str_idx = b.const_i64((k * 3 + 2) as i64);
29480
+        let str_dst = b.gep(dim_buf, vec![str_idx], IrType::Int(IntWidth::I64));
29481
+        b.store(one64, str_dst);
29482
+    }
29483
+    let elem_size_v = b.const_i64(elem_bytes);
29484
+    let rank_v = b.const_i32(result_rank as i32);
29485
+    let stat = b.alloca(IrType::Int(IntWidth::I32));
29486
+    b.store(zero32, stat);
29487
+    b.call(
29488
+        FuncRef::External("afs_allocate_array".into()),
29489
+        vec![result_desc, elem_size_v, rank_v, dim_buf, stat],
29490
+        IrType::Void,
29491
+    );
29492
+
29493
+    // Cache source per-dim lower bound and stride (in elements). These are
29494
+    // queried once before the loops to avoid reloading on every iteration.
29495
+    let mut src_los: Vec<ValueId> = Vec::with_capacity(args.len());
29496
+    let mut src_strides: Vec<ValueId> = Vec::with_capacity(args.len());
29497
+    for d in 0..args.len() {
29498
+        let lo_off = 24 + (d as i64) * 24;
29499
+        let stride_off = lo_off + 16;
29500
+        src_los.push(load_array_desc_i64_field(b, source_desc, lo_off));
29501
+        src_strides.push(load_array_desc_i64_field(b, source_desc, stride_off));
29502
+    }
29503
+    let src_base = b.load_typed(source_desc, IrType::Ptr(Box::new(IrType::Int(IntWidth::I8))));
29504
+    let dst_base = b.load_typed(result_desc, IrType::Ptr(Box::new(IrType::Int(IntWidth::I8))));
29505
+
29506
+    // One i64 counter per kept dim, all 0-based.
29507
+    let mut counters: Vec<ValueId> = Vec::with_capacity(result_rank);
29508
+    for _ in 0..result_rank {
29509
+        let c = b.alloca(IrType::Int(IntWidth::I64));
29510
+        b.store(zero64, c);
29511
+        counters.push(c);
29512
+    }
29513
+
29514
+    // Nested loop scaffolding: outermost = highest result dim, innermost = 0
29515
+    // (column-major iteration matching the result descriptor's stride layout).
29516
+    let mut checks: Vec<BlockId> = Vec::with_capacity(result_rank);
29517
+    let mut bodies: Vec<BlockId> = Vec::with_capacity(result_rank);
29518
+    let mut incrs: Vec<BlockId> = Vec::with_capacity(result_rank);
29519
+    let mut exits: Vec<BlockId> = Vec::with_capacity(result_rank);
29520
+    for k in 0..result_rank {
29521
+        checks.push(b.create_block(&format!("vsg_check_d{}", k)));
29522
+        bodies.push(b.create_block(&format!("vsg_body_d{}", k)));
29523
+        incrs.push(b.create_block(&format!("vsg_incr_d{}", k)));
29524
+        exits.push(b.create_block(&format!("vsg_exit_d{}", k)));
29525
+    }
29526
+    let outer = result_rank - 1;
29527
+    b.branch(checks[outer], vec![]);
29528
+
29529
+    let elem_bytes_v = b.const_i64(elem_bytes);
29530
+    for k_rev in 0..result_rank {
29531
+        let k = result_rank - 1 - k_rev;
29532
+
29533
+        b.set_block(checks[k]);
29534
+        let cur = b.load(counters[k]);
29535
+        let done = b.icmp(CmpOp::Ge, cur, result_extents[k]);
29536
+        b.cond_branch(done, exits[k], vec![], bodies[k], vec![]);
29537
+
29538
+        b.set_block(bodies[k]);
29539
+        if k == 0 {
29540
+            // Innermost body: gather one source element into the result.
29541
+            // Source byte offset = sum_d (zero_based_src_idx_d * src_stride_d * elem_bytes)
29542
+            let mut src_byte_off: Option<ValueId> = None;
29543
+            // Borrow checker: collect per-dim closures of values needed before mutating b.
29544
+            // dim_kinds has indices into counters, so we walk it directly.
29545
+            for (d, kind) in dim_kinds.iter().enumerate() {
29546
+                let zero_based = match kind {
29547
+                    DimKind::Range { start, stride, result_dim } => {
29548
+                        let cnt = b.load(counters[*result_dim]);
29549
+                        let off = b.imul(cnt, *stride);
29550
+                        let src_idx = b.iadd(*start, off);
29551
+                        b.isub(src_idx, src_los[d])
29552
+                    }
29553
+                    DimKind::Vector { idx_desc, idx_elem_ty, result_dim } => {
29554
+                        let cnt = b.load(counters[*result_dim]);
29555
+                        let raw = load_rank1_array_desc_elem(b, *idx_desc, idx_elem_ty, cnt);
29556
+                        let widened = match idx_elem_ty {
29557
+                            IrType::Int(IntWidth::I64) => raw,
29558
+                            IrType::Int(_) => b.int_extend(raw, IntWidth::I64, true),
29559
+                            _ => return None,
29560
+                        };
29561
+                        b.isub(widened, src_los[d])
29562
+                    }
29563
+                    DimKind::Scalar { val } => b.isub(*val, src_los[d]),
29564
+                };
29565
+                let stride_term = b.imul(zero_based, src_strides[d]);
29566
+                let bytes_term = b.imul(stride_term, elem_bytes_v);
29567
+                src_byte_off = Some(match src_byte_off {
29568
+                    Some(prev) => b.iadd(prev, bytes_term),
29569
+                    None => bytes_term,
29570
+                });
29571
+            }
29572
+            let src_off = src_byte_off.unwrap_or(zero64);
29573
+            let src_p = b.gep(src_base, vec![src_off], IrType::Int(IntWidth::I8));
29574
+            let elem_val = b.load_typed(src_p, elem_ty.clone());
29575
+
29576
+            // Result byte offset (column-major): for each kept dim k2,
29577
+            // contribution = counters[k2] * (prod_{j<k2} ext_j) * elem_bytes.
29578
+            let mut dst_byte_off: Option<ValueId> = None;
29579
+            let mut cum_ext: Option<ValueId> = None;
29580
+            for k2 in 0..result_rank {
29581
+                let cnt = b.load(counters[k2]);
29582
+                let stride_elems = cum_ext.unwrap_or(one64);
29583
+                let term_elements = b.imul(cnt, stride_elems);
29584
+                let term_bytes = b.imul(term_elements, elem_bytes_v);
29585
+                dst_byte_off = Some(match dst_byte_off {
29586
+                    Some(prev) => b.iadd(prev, term_bytes),
29587
+                    None => term_bytes,
29588
+                });
29589
+                cum_ext = Some(match cum_ext {
29590
+                    Some(prev) => b.imul(prev, result_extents[k2]),
29591
+                    None => result_extents[k2],
29592
+                });
29593
+            }
29594
+            let dst_off = dst_byte_off.unwrap_or(zero64);
29595
+            let dst_p = b.gep(dst_base, vec![dst_off], IrType::Int(IntWidth::I8));
29596
+            b.store(elem_val, dst_p);
29597
+            b.branch(incrs[0], vec![]);
29598
+        } else {
29599
+            // Reset next-inner counter to 0 and dive in.
29600
+            b.store(zero64, counters[k - 1]);
29601
+            b.branch(checks[k - 1], vec![]);
29602
+        }
29603
+
29604
+        b.set_block(incrs[k]);
29605
+        let cur2 = b.load(counters[k]);
29606
+        let next = b.iadd(cur2, one64);
29607
+        b.store(next, counters[k]);
29608
+        b.branch(checks[k], vec![]);
29609
+
29610
+        b.set_block(exits[k]);
29611
+        if k < result_rank - 1 {
29612
+            b.branch(incrs[k + 1], vec![]);
29613
+        }
29614
+    }
29615
+
29616
+    b.set_block(exits[outer]);
29617
+    Some(result_desc)
29618
+}
29619
+
2924529620
 /// Lower an array section expression: a(1:10:2) → create section descriptor.
2924629621
 pub(super) fn lower_array_section(
2924729622
     b: &mut FuncBuilder,