diff options
| author | Christoph Groth <christoph.groth@cea.fr> | 2024-10-24 12:52:52 +0200 |
|---|---|---|
| committer | Christoph Groth <christoph.groth@cea.fr> | 2025-01-09 13:58:07 +0100 |
| commit | 91882899908ed62ff37636d1670e8e46df48a9d4 (patch) | |
| tree | 67f2f8bec65c6754951223afee1c3f4f3f712c91 | |
| parent | 52600bbfe3e48d66d1dce00bcde3e2012807c9e0 (diff) | |
Adapt to new API and the switch to row major
| -rw-r--r-- | Cargo.lock | 3 | ||||
| -rw-r--r-- | Cargo.toml | 2 | ||||
| -rw-r--r-- | src/main.rs | 12 |
3 files changed, 8 insertions, 9 deletions
@@ -5,8 +5,7 @@ version = 3 [[package]] name = "mdarray" version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d82a5ed5640d5075b3fdfe2c0921fc473bc0977c5707c248b76bd43b56dcd" +source = "git+https://github.com/fre-hu/mdarray.git?rev=b93aaa58b8b0139f4058247df04f9cf765e0717d#b93aaa58b8b0139f4058247df04f9cf765e0717d" [[package]] name = "mdarray-test" @@ -4,4 +4,4 @@ version = "0.1.0" edition = "2021" [dependencies] -mdarray = "0.6.1" +mdarray = { git = "https://github.com/fre-hu/mdarray.git", rev = "b93aaa58b8b0139f4058247df04f9cf765e0717d" } diff --git a/src/main.rs b/src/main.rs index c0ee447..31d810b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ -use mdarray::{expr, grid, DSpan, Expression}; +use mdarray::{view, tensor, DSlice, Expression}; -fn matmul(a: &DSpan<f64, 2>, b: &DSpan<f64, 2>, c: &mut DSpan<f64, 2>) { +fn matmul(a: &DSlice<f64, 2>, b: &DSlice<f64, 2>, c: &mut DSlice<f64, 2>) { for (mut cj, bj) in c.cols_mut().zip(b.cols()) { for (ak, bkj) in a.cols().zip(bj) { for (cij, aik) in cj.expr_mut().zip(ak) { @@ -11,10 +11,10 @@ fn matmul(a: &DSpan<f64, 2>, b: &DSpan<f64, 2>, c: &mut DSpan<f64, 2>) { } fn main() { - let a = expr![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - let b = expr![[0.0, 1.0], [1.0, 1.0]]; + let a = view![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; + let b = view![[0.0, 1.0], [1.0, 1.0]]; - let mut c = grid![[0.0; 3]; 2]; + let mut c = tensor![[0.0; 2]; 3]; dbg!(std::any::type_name_of_val(&a)); dbg!(std::any::type_name_of_val(&b)); @@ -22,7 +22,7 @@ fn main() { matmul(&a, &b, &mut c); - assert_eq!(c, expr![[4.0, 5.0, 6.0], [5.0, 7.0, 9.0]]); + assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]); // slice let d = a.view(1, ..); |
