summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorChristoph Groth <christoph.groth@cea.fr>2024-10-25 17:46:41 +0200
committerChristoph Groth <christoph.groth@cea.fr>2025-01-09 13:58:20 +0100
commite39d3089b5b1a8583c6750bedaf639de73cf0316 (patch)
treeb173dff9bd87a575bad8ee4f3bcf93c98ba9d477 /src
parent0cd1f377cbf1a555d6de211b5ef02f9c65f1db2e (diff)
Optimize inner loop by using constant dimension
Diffstat (limited to 'src')
-rw-r--r--src/main.rs7
1 files changed, 4 insertions, 3 deletions
diff --git a/src/main.rs b/src/main.rs
index 1627396..70bb8a8 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,4 @@
-use mdarray::{view, tensor, Slice, Expression, Dim};
+use mdarray::{view, array, tensor, Slice, Expression, Dim, Const, Dyn};
// Indexing convention: C_ij <- A_ik * B_kj
fn matmul<D0: Dim, D1: Dim, D2: Dim>(
@@ -16,7 +16,8 @@ fn matmul<D0: Dim, D1: Dim, D2: Dim>(
fn main() {
let a = view![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]];
- let b = view![[0.0, 1.0], [1.0, 1.0]];
+ let b = array![[0.0, 1.0], [1.0, 1.0]];
+ let b = b.reshape((Const::<2>, Dyn(2)));
let mut c = tensor![[0.0; 2]; 3];
@@ -24,7 +25,7 @@ fn main() {
dbg!(std::any::type_name_of_val(&b));
dbg!(std::any::type_name_of_val(&c));
- matmul(&a, &b, &mut c);
+ matmul(&a.reshape((Dyn(3), Const::<2>)), &b, &mut c);
assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]);