1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
use mdarray::{view, tensor, DSlice, Expression};
/// Accumulates the matrix product `a * b` into `c`, i.e. computes `c += a * b`.
///
/// Shapes: `a` is `m x k`, `b` is `k x n` and `c` is `m x n`
/// (dimension 0 = rows, dimension 1 = columns). `c` is *not* cleared first,
/// so pass a zero-filled `c` to obtain the plain product.
///
/// The loops walk the columns of `c` and `b` together. For each column `j` they
/// add `a[:, k] * b[k, j]` into `c[:, j]` for every `k`, using fused multiply-add.
///
/// # Panics
///
/// Panics if the operand shapes are incompatible. Without these checks the
/// nested `zip`s would silently stop at the shorter operand and leave `c`
/// only partially updated.
fn matmul(a: &DSlice<f64, 2>, b: &DSlice<f64, 2>, c: &mut DSlice<f64, 2>) {
    assert_eq!(
        a.dim(1),
        b.dim(0),
        "matmul: column count of `a` must equal row count of `b`"
    );
    assert_eq!(
        c.dim(0),
        a.dim(0),
        "matmul: row count of `c` must equal row count of `a`"
    );
    assert_eq!(
        c.dim(1),
        b.dim(1),
        "matmul: column count of `c` must equal column count of `b`"
    );

    for (mut cj, bj) in c.cols_mut().zip(b.cols()) {
        // c[:, j] += a[:, k] * b[k, j] for every k.
        for (ak, bkj) in a.cols().zip(bj) {
            for (cij, aik) in cj.expr_mut().zip(ak) {
                *cij = aik.mul_add(*bkj, *cij);
            }
        }
    }
}
/// Demo driver. Multiplies two small matrices with `matmul` and checks the result.
/// Along the way it prints (via `dbg!`, to stderr) the concrete `mdarray` type
/// of each value produced by the constructor macros, slicing, indexing and
/// permuting.
fn main() {
// `a` is 3x2 and `b` is 2x2 when the inner arrays are read as rows. That reading
// is consistent with the product asserted below. Both are built with `view!`.
let a = view![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]];
let b = view![[0.0, 1.0], [1.0, 1.0]];
// `c` is a 3x2 zero-filled tensor built with `tensor!`. `matmul` adds into its
// output, so starting from zeros makes the result exactly a * b.
let mut c = tensor![[0.0; 2]; 3];
// Show the concrete types the `view!` / `tensor!` macros produce.
dbg!(std::any::type_name_of_val(&a));
dbg!(std::any::type_name_of_val(&b));
dbg!(std::any::type_name_of_val(&c));
matmul(&a, &b, &mut c);
// Expected product a * b, row by row.
assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]);
// slice: index 1 in dimension 0 and the full range in dimension 1.
// Presumably this is row 1 of `a`; the dbg! below shows the exact view type.
let d = a.view(1, ..);
dbg!(std::any::type_name_of_val(&d));
// index: a single usize index on the 2-D tensor `c`. The element is copied out,
// so `e` is an f64. NOTE(review): this looks like linear (flat) indexing. Under
// row-major storage it would be c[2][0] == 6.0; confirm against mdarray docs.
let e = c[4];
dbg!(std::any::type_name_of_val(&e));
// permute: reorders the axes as (1, 0), i.e. presumably a transposed view of `c`.
let f = c.permute::<1, 0>();
dbg!(std::any::type_name_of_val(&f));
}
|