summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristoph Groth <christoph.groth@cea.fr>2024-10-24 12:52:52 +0200
committerChristoph Groth <christoph.groth@cea.fr>2025-01-09 13:58:07 +0100
commit91882899908ed62ff37636d1670e8e46df48a9d4 (patch)
tree67f2f8bec65c6754951223afee1c3f4f3f712c91
parent52600bbfe3e48d66d1dce00bcde3e2012807c9e0 (diff)
Adapt to new API and the switch to row major
-rw-r--r--Cargo.lock3
-rw-r--r--Cargo.toml2
-rw-r--r--src/main.rs12
3 files changed, 8 insertions, 9 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 84ba7d8..d0b78ee 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5,8 +5,7 @@ version = 3
[[package]]
name = "mdarray"
version = "0.6.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4e0d82a5ed5640d5075b3fdfe2c0921fc473bc0977c5707c248b76bd43b56dcd"
+source = "git+https://github.com/fre-hu/mdarray.git?rev=b93aaa58b8b0139f4058247df04f9cf765e0717d#b93aaa58b8b0139f4058247df04f9cf765e0717d"
[[package]]
name = "mdarray-test"
diff --git a/Cargo.toml b/Cargo.toml
index 79415b8..3f3a2d9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,4 +4,4 @@ version = "0.1.0"
edition = "2021"
[dependencies]
-mdarray = "0.6.1"
+mdarray = { git = "https://github.com/fre-hu/mdarray.git", rev = "b93aaa58b8b0139f4058247df04f9cf765e0717d" }
diff --git a/src/main.rs b/src/main.rs
index c0ee447..31d810b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,6 @@
-use mdarray::{expr, grid, DSpan, Expression};
+use mdarray::{view, tensor, DSlice, Expression};
-fn matmul(a: &DSpan<f64, 2>, b: &DSpan<f64, 2>, c: &mut DSpan<f64, 2>) {
+fn matmul(a: &DSlice<f64, 2>, b: &DSlice<f64, 2>, c: &mut DSlice<f64, 2>) {
for (mut cj, bj) in c.cols_mut().zip(b.cols()) {
for (ak, bkj) in a.cols().zip(bj) {
for (cij, aik) in cj.expr_mut().zip(ak) {
@@ -11,10 +11,10 @@ fn matmul(a: &DSpan<f64, 2>, b: &DSpan<f64, 2>, c: &mut DSpan<f64, 2>) {
}
fn main() {
- let a = expr![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
- let b = expr![[0.0, 1.0], [1.0, 1.0]];
+ let a = view![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]];
+ let b = view![[0.0, 1.0], [1.0, 1.0]];
- let mut c = grid![[0.0; 3]; 2];
+ let mut c = tensor![[0.0; 2]; 3];
dbg!(std::any::type_name_of_val(&a));
dbg!(std::any::type_name_of_val(&b));
@@ -22,7 +22,7 @@ fn main() {
matmul(&a, &b, &mut c);
- assert_eq!(c, expr![[4.0, 5.0, 6.0], [5.0, 7.0, 9.0]]);
+ assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]);
// slice
let d = a.view(1, ..);