fortrangoingonforty/armfortas / b9cfaba

Use canonical column-major strides in allocatable assignment dest

Authored by espadonne
Committed by mfwolffe
SHA: b9cfaba3d23e4b9661920b385bce4efe3be7bd7a
Parent: 5b3eccb
Tree: 889072e

1 changed file

Status  File                  +    -
M       runtime/src/array.rs  135  19

runtime/src/array.rs (modified)

@@ -1434,12 +1434,22 @@ pub extern "C" fn afs_assign_allocatable(
             dest.flags &= !DESC_ALLOCATED;
         }
 
-        // Allocate with source's shape.
+        // Allocate with source's shape, but compute canonical
+        // column-major strides (1, ext_0, ext_0*ext_1, ...) — the
+        // dest is freshly contiguous, so per-dim memory step must
+        // match Fortran's column-major convention used by
+        // afs_create_section / load_rank1_array_desc_elem. Setting
+        // stride=1 across the board collapsed dim_1+ accesses onto
+        // the dim_0 axis (e.g. allocatable A = transpose(reshape(...))
+        // produced descriptor with stride=(1,1) and any subsequent
+        // assumed-shape pass read overlapping bytes per "column").
         dest.rank = source.rank;
         dest.elem_size = source.elem_size;
+        let mut running_stride: i64 = 1;
         for i in 0..source.rank as usize {
             dest.dims[i] = source.dims[i];
-            dest.dims[i].stride = 1; // dest is always contiguous
+            dest.dims[i].stride = running_stride;
+            running_stride = running_stride.saturating_mul(source.dims[i].extent().max(1));
         }
 
         let bytes = dest.total_bytes();
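The new stride computation is the core of the fix: column-major strides are the running product of extents. A minimal standalone sketch of the same idea (the canonical_strides helper is illustrative, not part of the runtime):

    // Hypothetical helper, for illustration only: canonical
    // column-major strides are the running product of extents, so
    // dim d steps ext_0 * ... * ext_{d-1} elements per index.
    fn canonical_strides(extents: &[i64]) -> Vec<i64> {
        let mut strides = Vec::with_capacity(extents.len());
        let mut step: i64 = 1;
        for &ext in extents {
            strides.push(step);
            // max(1) keeps a zero-extent dim from zeroing every
            // later stride, matching the diff's extent().max(1).
            step = step.saturating_mul(ext.max(1));
        }
        strides
    }

    fn main() {
        // A 2x3x4 array: element (i,j,k) sits at offset i + 2*j + 6*k.
        assert_eq!(canonical_strides(&[2, 3, 4]), vec![1, 2, 6]);
    }
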
@@ -1454,11 +1464,68 @@ pub extern "C" fn afs_assign_allocatable(
         dest.flags = DESC_ALLOCATED | DESC_CONTIGUOUS;
     }
 
-    // Copy data. Use ptr::copy (not copy_nonoverlapping) to handle self-assignment.
+    // Copy data. Source may be non-contiguous (e.g. result of
+    // transpose() returns a descriptor with reversed dim strides
+    // pointing at the original buffer). A flat ptr::copy of
+    // total_bytes from source.base_addr would drag adjacent bytes
+    // forward without honoring per-dim strides — the same class of
+    // bug as the original afs_copy_array_data flat copy. Detect
+    // non-contiguous and walk every multi-index column-major.
+    //
+    // We only treat the source as non-contiguous when at least one
+    // dim's stride is *strictly greater* than its canonical
+    // column-major step. Strides smaller than canonical (e.g.
+    // afs_matmul's 2x2 result emitted with stride=(1,1) instead of
+    // (1,2)) describe an internally inconsistent descriptor whose
+    // base_addr still points at a flat contiguous buffer; walking
+    // those would re-read the same byte offset twice and drop the
+    // last element. The conservative choice is the flat copy that
+    // mirrors total_bytes — which the previous unconditional ptr::copy
+    // did silently for both kinds of source.
     let bytes = source.total_bytes();
     if bytes > 0 && !source.base_addr.is_null() && !dest.base_addr.is_null() {
-        unsafe {
-            ptr::copy(source.base_addr, dest.base_addr, bytes as usize);
+        let elem_size = source.elem_size;
+        let mut canonical: i64 = 1;
+        let mut strided = false;
+        for i in 0..source.rank as usize {
+            if source.dims[i].stride > canonical {
+                strided = true;
+                break;
+            }
+            canonical = canonical.saturating_mul(source.dims[i].extent().max(1));
+        }
+        if !strided {
+            unsafe {
+                ptr::copy(source.base_addr, dest.base_addr, bytes as usize);
+            }
+        } else {
+            let rank = source.rank as usize;
+            let extents: Vec<i64> = (0..rank).map(|i| source.dims[i].extent()).collect();
+            let strides: Vec<i64> = (0..rank).map(|i| source.dims[i].stride).collect();
+            let mut idx = vec![0i64; rank];
+            let total = source.total_elements();
+            for k in 0..total {
+                let mut src_off: i64 = 0;
+                for d in 0..rank {
+                    src_off += idx[d] * strides[d];
+                }
+                src_off *= elem_size;
+                let dst_off = k * elem_size;
+                unsafe {
+                    ptr::copy_nonoverlapping(
+                        source.base_addr.offset(src_off as isize),
+                        dest.base_addr.offset(dst_off as isize),
+                        elem_size as usize,
+                    );
+                }
+                for d in 0..rank {
+                    idx[d] += 1;
+                    if idx[d] < extents[d] {
+                        break;
+                    }
+                    idx[d] = 0;
+                }
+            }
         }
     }
     dest.set_scalar_type_tag(source.scalar_type_tag());
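The else branch walks the source like an odometer: bump the fastest-varying index, carry on rollover, and recompute the offset from per-dim strides. The sketch below replays that walk over a plain i64 slice (the flat-slice setup and names are illustrative; the runtime does the same walk over raw bytes through descriptors), using the transpose case the comment describes:

    // Gather a strided view into a flat column-major buffer.
    fn gather_column_major(src: &[i64], extents: &[i64], strides: &[i64]) -> Vec<i64> {
        let rank = extents.len();
        let total: i64 = extents.iter().product();
        let mut idx = vec![0i64; rank];
        let mut out = Vec::with_capacity(total as usize);
        for _ in 0..total {
            // Element offset = sum over dims of index * stride.
            let off: i64 = idx.iter().zip(strides).map(|(i, s)| i * s).sum();
            out.push(src[off as usize]);
            // Odometer increment: dim 0 varies fastest, carrying
            // into the next dim whenever an extent rolls over.
            for d in 0..rank {
                idx[d] += 1;
                if idx[d] < extents[d] {
                    break;
                }
                idx[d] = 0;
            }
        }
        out
    }

    fn main() {
        // Column-major 2x3 buffer for [[1,3,5],[2,4,6]].
        let src = [1, 2, 3, 4, 5, 6];
        // Its transpose is a 3x2 view with extents (3,2), strides (2,1):
        // dim 0's stride 2 exceeds the canonical step 1, so the runtime
        // would take the element-by-element branch and produce this.
        assert_eq!(gather_column_major(&src, &[3, 2], &[2, 1]), vec![1, 3, 5, 2, 4, 6]);
    }
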
@@ -1523,9 +1590,15 @@ pub extern "C" fn afs_assign_allocatable_convert(
         }
         dest_ref.rank = source_ref.rank;
         dest_ref.elem_size = dest_elem_size;
+        // Canonical column-major strides — see matching note in
+        // afs_assign_allocatable. dest is freshly contiguous; the
+        // per-dim memory step must be (1, ext_0, ext_0*ext_1, ...).
+        let mut running_stride: i64 = 1;
         for i in 0..source_ref.rank as usize {
             dest_ref.dims[i] = source_ref.dims[i];
-            dest_ref.dims[i].stride = 1;
+            dest_ref.dims[i].stride = running_stride;
+            running_stride =
+                running_stride.saturating_mul(source_ref.dims[i].extent().max(1));
         }
         let bytes = dest_ref.total_bytes();
         if bytes > 0 {
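Both copy paths gate the per-element walk on the same rule, which the next hunk mirrors: a stride strictly above its canonical step means real gaps in memory, while a stride below it marks an internally inconsistent descriptor over a still-flat buffer. A minimal sketch of that predicate (the needs_strided_walk name is hypothetical):

    fn needs_strided_walk(extents: &[i64], strides: &[i64]) -> bool {
        let mut canonical: i64 = 1;
        for (&ext, &stride) in extents.iter().zip(strides) {
            // Strictly greater than the canonical column-major step:
            // the view genuinely skips memory, so walk per element.
            if stride > canonical {
                return true;
            }
            canonical = canonical.saturating_mul(ext.max(1));
        }
        false
    }

    fn main() {
        // Transpose of a 2x3, viewed as 3x2: stride 2 > canonical 1.
        assert!(needs_strided_walk(&[3, 2], &[2, 1]));
        // Malformed 2x2 with stride=(1,1) instead of (1,2): below
        // canonical, the buffer is still flat, so keep the flat copy.
        assert!(!needs_strided_walk(&[2, 2], &[1, 1]));
    }
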
@@ -1548,29 +1621,72 @@ pub extern "C" fn afs_assign_allocatable_convert(
 
     let src_p = source_ref.base_addr;
     let dst_p = dest_ref.base_addr;
-    for i in 0..n {
+    let src_elem_size: i64 = match src_kind_tag {
+        0 => 1,
+        1 => 2,
+        2 | 4 => 4,
+        3 | 5 => 8,
+        _ => return,
+    };
+    // Source may be non-contiguous (e.g. transpose result, section).
+    // Walk each multi-index column-major and apply per-dim strides.
+    // Mirror the same-class detection used in afs_assign_allocatable
+    // and afs_copy_array_data: only apply per-dim strides when at
+    // least one stride is *strictly greater* than its canonical
+    // column-major step. A stride below canonical describes a
+    // malformed descriptor (e.g. a 2x2 matmul result with stride=(1,1)
+    // instead of (1,2)) whose underlying buffer is still flat
+    // contiguous; walking those re-reads the same offset twice.
+    let rank = source_ref.rank as usize;
+    let extents: Vec<i64> = (0..rank).map(|i| source_ref.dims[i].extent()).collect();
+    let raw_strides: Vec<i64> = (0..rank).map(|i| source_ref.dims[i].stride).collect();
+    let mut canonical_step: i64 = 1;
+    let mut canonical: Vec<i64> = Vec::with_capacity(rank);
+    let mut strided = false;
+    for d in 0..rank {
+        canonical.push(canonical_step);
+        if raw_strides[d] > canonical_step {
+            strided = true;
+        }
+        canonical_step = canonical_step.saturating_mul(extents[d].max(1));
+    }
+    let strides: &[i64] = if strided { &raw_strides } else { &canonical };
+    let mut idx = vec![0i64; rank];
+    for k in 0..n {
+        let mut src_off_elems: i64 = 0;
+        for d in 0..rank {
+            src_off_elems += idx[d] * strides[d];
+        }
+        let src_byte_off = src_off_elems * src_elem_size;
         let src_val_f64: f64 = unsafe {
             match src_kind_tag {
-                0 => *(src_p.add(i) as *const i8) as f64,
-                1 => *(src_p.add(2 * i) as *const i16) as f64,
-                2 => *(src_p.add(4 * i) as *const i32) as f64,
-                3 => *(src_p.add(8 * i) as *const i64) as f64,
-                4 => *(src_p.add(4 * i) as *const f32) as f64,
-                5 => *(src_p.add(8 * i) as *const f64),
+                0 => *(src_p.offset(src_byte_off as isize) as *const i8) as f64,
+                1 => *(src_p.offset(src_byte_off as isize) as *const i16) as f64,
+                2 => *(src_p.offset(src_byte_off as isize) as *const i32) as f64,
+                3 => *(src_p.offset(src_byte_off as isize) as *const i64) as f64,
+                4 => *(src_p.offset(src_byte_off as isize) as *const f32) as f64,
+                5 => *(src_p.offset(src_byte_off as isize) as *const f64),
                 _ => return,
             }
         };
         unsafe {
             match dest_kind_tag {
-                0 => *(dst_p.add(i) as *mut i8) = src_val_f64 as i8,
-                1 => *(dst_p.add(2 * i) as *mut i16) = src_val_f64 as i16,
-                2 => *(dst_p.add(4 * i) as *mut i32) = src_val_f64 as i32,
-                3 => *(dst_p.add(8 * i) as *mut i64) = src_val_f64 as i64,
-                4 => *(dst_p.add(4 * i) as *mut f32) = src_val_f64 as f32,
-                5 => *(dst_p.add(8 * i) as *mut f64) = src_val_f64,
+                0 => *(dst_p.add(k) as *mut i8) = src_val_f64 as i8,
+                1 => *(dst_p.add(2 * k) as *mut i16) = src_val_f64 as i16,
+                2 => *(dst_p.add(4 * k) as *mut i32) = src_val_f64 as i32,
+                3 => *(dst_p.add(8 * k) as *mut i64) = src_val_f64 as i64,
+                4 => *(dst_p.add(4 * k) as *mut f32) = src_val_f64 as f32,
+                5 => *(dst_p.add(8 * k) as *mut f64) = src_val_f64,
                 _ => return,
             }
         }
+        for d in 0..rank {
+            idx[d] += 1;
+            if idx[d] < extents[d] {
+                break;
+            }
+            idx[d] = 0;
+        }
     }
 }
 
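The conversion loop reads each source element at its strided byte offset, widens it to f64, and narrows it into the destination's kind. A minimal sketch of the read side under the same kind-tag numbering the diff uses (0=i8, 1=i16, 2=i32, 3=i64, 4=f32, 5=f64); the read_as_f64 helper is illustrative, not a runtime API:

    unsafe fn read_as_f64(p: *const u8, kind_tag: i32, byte_off: isize) -> Option<f64> {
        // Safety: the caller guarantees p + byte_off points at a
        // valid, properly aligned element of the tagged kind.
        unsafe {
            Some(match kind_tag {
                0 => *(p.offset(byte_off) as *const i8) as f64,
                1 => *(p.offset(byte_off) as *const i16) as f64,
                2 => *(p.offset(byte_off) as *const i32) as f64,
                3 => *(p.offset(byte_off) as *const i64) as f64,
                4 => *(p.offset(byte_off) as *const f32) as f64,
                5 => *(p.offset(byte_off) as *const f64),
                _ => return None,
            })
        }
    }

    fn main() {
        let src: [i32; 3] = [1, 2, 3];
        let p = src.as_ptr() as *const u8;
        // Element 2 of an i32 buffer (kind tag 2): byte offset 2 * 4.
        assert_eq!(unsafe { read_as_f64(p, 2, 8) }, Some(3.0));
    }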