summaryrefslogtreecommitdiff
path: root/src/main.rs
blob: c0ee447a023d2449d96698d59590a03d99560f06 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
use mdarray::{expr, grid, DSpan, Expression};

/// Accumulates the matrix product of `a` and `b` into `c` (`c += a * b`).
///
/// The computation walks the columns of `c` and `b` in lockstep. For each
/// such pair, every column of `a` is scaled by the matching element of the
/// current `b` column and added element-wise into the current `c` column.
///
/// `c` is not cleared first, so callers wanting a plain product must pass a
/// zero-filled `c`.
///
/// NOTE(review): no shape validation is done here; what happens when the
/// operand shapes do not line up is decided by `zip` — confirm against the
/// mdarray documentation.
fn matmul(a: &DSpan<f64, 2>, b: &DSpan<f64, 2>, c: &mut DSpan<f64, 2>) {
    let column_pairs = c.cols_mut().zip(b.cols());

    for (mut out_col, rhs_col) in column_pairs {
        // Pair each column of `a` with the element of `b` that scales it.
        let weighted_lhs_cols = a.cols().zip(rhs_col);

        for (lhs_col, weight) in weighted_lhs_cols {
            let updates = out_col.expr_mut().zip(lhs_col);

            for (acc, lhs) in updates {
                // Fused multiply-add: `lhs * weight + acc` with a single rounding.
                let fused = lhs.mul_add(*weight, *acc);
                *acc = fused;
            }
        }
    }
}

/// Runs `matmul` on small literal operands, asserts the expected product, and
/// prints the concrete type of each intermediate value with `dbg!`.
fn main() {
    // Operands built from nested literals with the `expr!` macro.
    let a = expr![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let b = expr![[0.0, 1.0], [1.0, 1.0]];

    // Output buffer, zero-filled because `matmul` adds into its third argument.
    let mut c = grid![[0.0; 3]; 2];

    // Print the concrete types produced by the `expr!` and `grid!` macros.
    dbg!(std::any::type_name_of_val(&a));
    dbg!(std::any::type_name_of_val(&b));
    dbg!(std::any::type_name_of_val(&c));

    matmul(&a, &b, &mut c);

    // Check the accumulated result against the expected product.
    assert_eq!(c, expr![[4.0, 5.0, 6.0], [5.0, 7.0, 9.0]]);

    // slice: call `view` with a fixed index and a full range, then print the
    // resulting type.
    // NOTE(review): presumably this selects a lower-rank sub-view of `a` —
    // confirm which axis the index applies to in the mdarray docs.
    let d = a.view(1, ..);
    dbg!(std::any::type_name_of_val(&d));

    // index: index `c` with a single `usize` and print the type of the result.
    // NOTE(review): presumably a linear (flattened) element index — confirm.
    let e = c[4];
    dbg!(std::any::type_name_of_val(&e));

    // permute: call `permute::<1, 0>()` on `c` and print the resulting type.
    // NOTE(review): presumably swaps the two axes (a transposed view) — confirm.
    let f = c.permute::<1, 0>();
    dbg!(std::any::type_name_of_val(&f));
}