summaryrefslogtreecommitdiff
path: root/src/main.rs
blob: 87a6477ad398550ed8e2b42e0988198c3b8fb472 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
use mdarray::{view, array, tensor, Slice, Dim, Const, Dyn, expr::Expression};

/// Accumulates the matrix product of `a` and `b` into `c`:
/// `C_ij += Σ_k A_ik * B_kj`.
///
/// Shapes follow the indexing convention: `a` is `(D0, D1)`, `b` is `(D1, D2)`
/// and `c` is `(D0, D2)`. The existing contents of `c` are added to rather than
/// overwritten, so pass a zero-filled `c` to get the plain product `a · b`.
/// Each update uses [`f64::mul_add`], i.e. `A_ik * B_kj + C_ij` with a single
/// rounding.
///
/// # Panics
///
/// Panics if the runtime extents disagree (rows of `a` vs rows of `c`, columns
/// of `a` vs rows of `b`, columns of `b` vs columns of `c`). The generic
/// parameters only force the dimension *types* to agree; with `Dyn` dimensions
/// the lengths are runtime values and can still differ.
fn matmul<D0: Dim, D1: Dim, D2: Dim>(
    a: &Slice<f64, (D0, D1)>,
    b: &Slice<f64, (D1, D2)>,
    c: &mut Slice<f64, (D0, D2)>,
) {
    // Checked up front so a mismatch reports which operands disagree.
    // NOTE(review): whether mdarray's `zip` checks lengths itself or stops at the
    // shorter operand is not visible here; these asserts make the contract explicit.
    assert_eq!(a.dim(0), c.dim(0), "matmul: row count of `a` must equal row count of `c`");
    assert_eq!(a.dim(1), b.dim(0), "matmul: column count of `a` must equal row count of `b`");
    assert_eq!(b.dim(1), c.dim(1), "matmul: column count of `b` must equal column count of `c`");

    // i-k-j order: for each output row i, add A_ik times row k of `b` into row i
    // of `c`, so the innermost loop pairs a row of `b` with a row of `c`.
    for (mut ci, ai) in c.rows_mut().zip(a.rows()) {
        for (aik, bk) in ai.zip(b.rows()) {
            for (cij, bkj) in ci.expr_mut().zip(bk) {
                *cij = aik.mul_add(*bkj, *cij);
            }
        }
    }
}

/// Scratch demo of the `mdarray` API. It builds small arrays with the `view!`,
/// `array!` and `tensor!` macros, multiplies them with `matmul`, and asserts the
/// result. It then prints, via `dbg!` (to stderr), the concrete types produced by
/// construction, slicing, indexing and permuting.
fn main() {
    // `a` has 3 rows of 2 elements each (3 x 2).
    let a = view![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]];
    let b = array![[0.0, 1.0], [1.0, 1.0]];
    // Re-type `b` as (Const<2>, Dyn). Together with the reshape of `a` in the
    // `matmul` call below, this makes matmul's shared generics unify as
    // D0 = Dyn, D1 = Const<2>, D2 = Dyn.
    let b = b.reshape((Const::<2>, Dyn(2)));

    // Zero-filled 3 x 2 output; `matmul` accumulates into it (C += A·B).
    let mut c = tensor![[0.0; 2]; 3];

    // Show the concrete types the construction macros / reshape produced.
    dbg!(std::any::type_name_of_val(&a));
    dbg!(std::any::type_name_of_val(&b));
    dbg!(std::any::type_name_of_val(&c));

    // `a` is re-typed as (Dyn, Const<2>) for the call. Since `c` is taken as
    // (D0, D2) = (Dyn, Dyn), `c`'s shape type must be (Dyn, Dyn) for this to compile.
    matmul(&a.reshape((Dyn(3), Const::<2>)), &b, &mut c);

    // Expected a·b, e.g. row 0 is [1*0 + 4*1, 1*1 + 4*1] = [4, 5].
    assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]);

    // slice: index 1 on the first axis and the full range on the second,
    // presumably row 1 of `a`, i.e. [2.0, 5.0].
    let d = a.view(1, ..);
    dbg!(std::any::type_name_of_val(&d));

    // index: a single `usize` applied to the 2-D `c`.
    // NOTE(review): presumably a flat (linear) element index; confirm the
    // ordering against mdarray's `Index` impls.
    let e = c[4];
    dbg!(std::any::type_name_of_val(&e));

    // permute: reorder the axes as (1, 0), i.e. swap rows and columns of `c`.
    let f = c.permute::<1, 0>();
    dbg!(std::any::type_name_of_val(&f));
}