From 2ec9bbe18b57d0673a7065c23ca5d40184bf55ce Mon Sep 17 00:00:00 2001 From: Christoph Groth Date: Thu, 9 Jan 2025 21:59:23 +0100 Subject: New conversions allow to use a dynamic rank array with matmul --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/main.rs | 10 ++++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 747f12a..f24c693 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5,7 +5,7 @@ version = 3 [[package]] name = "mdarray" version = "0.6.1" -source = "git+https://github.com/fre-hu/mdarray.git?rev=f7aaba4fe618edb5b55f36e8a9425f165f38163c#f7aaba4fe618edb5b55f36e8a9425f165f38163c" +source = "git+https://github.com/fre-hu/mdarray.git?rev=06cb5e371326e674c0e9eebe0db5c4e7e200b800#06cb5e371326e674c0e9eebe0db5c4e7e200b800" [[package]] name = "mdarray-test" diff --git a/Cargo.toml b/Cargo.toml index b0e8e33..4a27b26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,4 @@ version = "0.1.0" edition = "2021" [dependencies] -mdarray = { git = "https://github.com/fre-hu/mdarray.git", rev = "f7aaba4fe618edb5b55f36e8a9425f165f38163c" } +mdarray = { git = "https://github.com/fre-hu/mdarray.git", rev = "06cb5e371326e674c0e9eebe0db5c4e7e200b800" } diff --git a/src/main.rs b/src/main.rs index a8eec44..021e799 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use mdarray::{view, array, tensor, Slice, Dim, Const, Dyn, expr::Expression}; +use mdarray::{view, array, tensor, Slice, Dim, Const, Dyn, Rank, expr::Expression, Dense}; // Indexing convention: C_ij <- A_ik * B_kj fn matmul( @@ -19,13 +19,14 @@ fn main() { let b = array![[0.0, 1.0], [1.0, 1.0]]; let b = b.reshape((Const::<2>, Dyn(!0))); - let mut c = tensor![[0.0; 2]; 3]; + // .into_dyn() replaces .reshape(DynRank::from_dims(&[2, 3])).into(): + let mut c: mdarray::Tensor = tensor![[0.0; 2]; 3].into_dyn(); dbg!(std::any::type_name_of_val(&a)); dbg!(std::any::type_name_of_val(&b)); dbg!(std::any::type_name_of_val(&c)); - matmul(&a.reshape((Dyn(3), Const::<2>)), &b, &mut c); + matmul(&a.reshape((Dyn(3), Const::<2>)), &b, &mut c.remap_mut()); assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]); @@ -38,6 +39,7 @@ fn main() { dbg!(std::any::type_name_of_val(&e)); // permute - let f = c.permute::<1, 0>(); + let f = c.remap::, Dense>(); + let f = f.permute::<1, 0>(); dbg!(std::any::type_name_of_val(&f)); } -- cgit v1.2.3-74-g4815