summaryrefslogtreecommitdiff
path: root/src/main.rs
blob: 655c90627a06c83f458810bd946983a15f0e30f8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use std::mem::MaybeUninit;

use mdarray::{view, tensor, Slice, Dim, Const, expr::Expression};
use mdarray as md;

// Indexing convention: C_ij <- A_ik * B_kj
/// Accumulating matrix product: `c += a * b`.
///
/// Indexing convention: C_ij <- C_ij + A_ik * B_kj. The product is *added* to
/// whatever `c` already contains; zero `c` beforehand to get a plain product.
/// Every update goes through `f64::mul_add` (fused multiply-add).
///
/// # Panics
///
/// Panics if the extents do not agree: `a` must be `m x k`, `b` must be
/// `k x n` and `c` must be `m x n`. The shared generic parameters tie the
/// dimension *types* together, but dynamic (`usize`) extents can still differ
/// at runtime. Without these checks the `zip`s below would silently truncate
/// to the shorter side and leave a partially computed `c`.
fn matmul<D0: Dim, D1: Dim, D2: Dim>(
    a: &Slice<f64, (D0, D1)>,
    b: &Slice<f64, (D1, D2)>,
    c: &mut Slice<f64, (D0, D2)>,
) {
    assert_eq!(a.dim(0), c.dim(0), "matmul: row counts of `a` and `c` differ");
    assert_eq!(a.dim(1), b.dim(0), "matmul: inner dimensions of `a` and `b` differ");
    assert_eq!(b.dim(1), c.dim(1), "matmul: column counts of `b` and `c` differ");

    // i-k-j loop order: for each output row, scale row k of `b` by A_ik and
    // add it into that row, so the innermost loop walks rows of `b` and `c`.
    for (mut ci, ai) in c.rows_mut().zip(a.rows()) {
        for (aik, bk) in ai.zip(b.rows()) {
            for (cij, bkj) in ci.expr_mut().zip(bk) {
                *cij = aik.mul_add(*bkj, *cij);
            }
        }
    }
}

fn main() {
    // 3x2 matrix view built from nested array literals.
    let a = view![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]];
    // 2x2 matrix [[0, 1], [1, 1]]: zero at (0, 0), one everywhere else.
    let b = md::Array::<f64, (Const::<2>, Const::<2>)>::from_fn(
        |i| if i[0] == 0 && i[1] == 0 { 0.0 } else { 1.0 }
    );
    // NOTE(review): `!0` (usize::MAX) appears to ask reshape to infer the
    // second extent as a dynamic dimension — confirm against mdarray docs.
    let b = b.reshape((Const::<2>, !0));

    // 3x2 zero-filled output with dynamic rank.
    // .into_dyn() replaces .reshape(DynRank::from_dims(&[3, 2])).into():
    let mut c: mdarray::Tensor<f64> = tensor![[0.0; 2]; 3].into_dyn();

    // `matmul` accumulates into `c`, so `c` must start zeroed (it does, above).
    // `remap_mut` hands `matmul` a mutable view of `c` typed with the rank-2
    // shape its signature requires.
    matmul(&a.reshape((3, Const::<2>)), &b, &mut c.remap_mut());

    assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]);

    // slice
    let _ = a.view(1, ..);

    // index (single scalar index into the 6-element 3x2 tensor)
    let _ = c[4];

    // permute & transpose
    assert_eq!(c.permute([1, 0]), c.transpose());

    // Arrays with rank > 6: 128 = 2^7 elements reshaped to rank 7, then the
    // axis order is reversed.
    let d: Vec<f64> = (0..128).map(|x| x as f64).collect();
    let d: md::DTensor<f64, 1> = d.into();
    let d = d.into_dyn();
    let d = d.reshape(&[2; 7]);
    let d = d.permute(&[6, 5, 4, 3, 2, 1, 0]);

    // Working with uninitialized memory.
    let mut e = tensor![[MaybeUninit::<f64>::uninit(); 3]; 3];
    for (i, val) in e.iter_mut().enumerate() {
        val.write(i as f64);
    }
    // SAFETY: the loop above writes through `iter_mut()` to every element of
    // `e`, so all of them are initialized before `assume_init` is called.
    let _ = unsafe { e.assume_init() };

    // Indexing: a single usize (`a[i]`), a per-axis array (`a[[i, 0]]`), a
    // slice whose length matches the rank (`c[&index[..2]]`, rank 2), and a
    // 7-element array for the rank-7 `d`.
    for i in 0..2 {
        let index = [0, i, 0, 0, i, 0, 0];
        println!("{} {} {} {}", a[i], a[[i, 0]], c[&index[..2]], d[index]);
    }
}