From e39d3089b5b1a8583c6750bedaf639de73cf0316 Mon Sep 17 00:00:00 2001 From: Christoph Groth Date: Fri, 25 Oct 2024 17:46:41 +0200 Subject: Optimize inner loop by using constant dimension --- src/main.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main.rs b/src/main.rs index 1627396..70bb8a8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use mdarray::{view, tensor, Slice, Expression, Dim}; +use mdarray::{view, array, tensor, Slice, Expression, Dim, Const, Dyn}; // Indexing convention: C_ij <- A_ik * B_kj fn matmul( @@ -16,7 +16,8 @@ fn matmul( fn main() { let a = view![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; - let b = view![[0.0, 1.0], [1.0, 1.0]]; + let b = array![[0.0, 1.0], [1.0, 1.0]]; + let b = b.reshape((Const::<2>, Dyn(2))); let mut c = tensor![[0.0; 2]; 3]; @@ -24,7 +25,7 @@ fn main() { dbg!(std::any::type_name_of_val(&b)); dbg!(std::any::type_name_of_val(&c)); - matmul(&a, &b, &mut c); + matmul(&a.reshape((Dyn(3), Const::<2>)), &b, &mut c); assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]); -- cgit v1.2.3-74-g4815