summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock2
-rw-r--r--Cargo.toml2
-rw-r--r--src/main.rs10
3 files changed, 8 insertions, 6 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 747f12a..f24c693 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5,7 +5,7 @@ version = 3
[[package]]
name = "mdarray"
version = "0.6.1"
-source = "git+https://github.com/fre-hu/mdarray.git?rev=f7aaba4fe618edb5b55f36e8a9425f165f38163c#f7aaba4fe618edb5b55f36e8a9425f165f38163c"
+source = "git+https://github.com/fre-hu/mdarray.git?rev=06cb5e371326e674c0e9eebe0db5c4e7e200b800#06cb5e371326e674c0e9eebe0db5c4e7e200b800"
[[package]]
name = "mdarray-test"
diff --git a/Cargo.toml b/Cargo.toml
index b0e8e33..4a27b26 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,4 +4,4 @@ version = "0.1.0"
edition = "2021"
[dependencies]
-mdarray = { git = "https://github.com/fre-hu/mdarray.git", rev = "f7aaba4fe618edb5b55f36e8a9425f165f38163c" }
+mdarray = { git = "https://github.com/fre-hu/mdarray.git", rev = "06cb5e371326e674c0e9eebe0db5c4e7e200b800" }
diff --git a/src/main.rs b/src/main.rs
index a8eec44..021e799 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,4 @@
-use mdarray::{view, array, tensor, Slice, Dim, Const, Dyn, expr::Expression};
+use mdarray::{view, array, tensor, Slice, Dim, Const, Dyn, Rank, expr::Expression, Dense};
// Indexing convention: C_ij <- A_ik * B_kj
fn matmul<D0: Dim, D1: Dim, D2: Dim>(
@@ -19,13 +19,14 @@ fn main() {
let b = array![[0.0, 1.0], [1.0, 1.0]];
let b = b.reshape((Const::<2>, Dyn(!0)));
- let mut c = tensor![[0.0; 2]; 3];
+ // .into_dyn() replaces .reshape(DynRank::from_dims(&[2, 3])).into():
+ let mut c: mdarray::Tensor<f64> = tensor![[0.0; 2]; 3].into_dyn();
dbg!(std::any::type_name_of_val(&a));
dbg!(std::any::type_name_of_val(&b));
dbg!(std::any::type_name_of_val(&c));
- matmul(&a.reshape((Dyn(3), Const::<2>)), &b, &mut c);
+ matmul(&a.reshape((Dyn(3), Const::<2>)), &b, &mut c.remap_mut());
assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]);
@@ -38,6 +39,7 @@ fn main() {
dbg!(std::any::type_name_of_val(&e));
// permute
- let f = c.permute::<1, 0>();
+ let f = c.remap::<Rank<2>, Dense>();
+ let f = f.permute::<1, 0>();
dbg!(std::any::type_name_of_val(&f));
}