diff options
| author | Christoph Groth <christoph.groth@cea.fr> | 2024-10-24 13:32:20 +0200 |
|---|---|---|
| committer | Christoph Groth <christoph.groth@cea.fr> | 2025-01-09 13:58:20 +0100 |
| commit | 0cd1f377cbf1a555d6de211b5ef02f9c65f1db2e (patch) | |
| tree | 4edf58fbdcc520d12a8dec07d2dddef3d9502fe3 /src/main.rs | |
| parent | 22a5aa82f4f2b14313238de770923d0ece8d906a (diff) | |
Make matmul generic over dimensions
Diffstat (limited to 'src/main.rs')
| -rw-r--r-- | src/main.rs | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/src/main.rs b/src/main.rs index edd202d..1627396 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,10 @@ -use mdarray::{view, tensor, DSlice, Expression}; +use mdarray::{view, tensor, Slice, Expression, Dim}; // Indexing convention: C_ij <- A_ik * B_kj -fn matmul(a: &DSlice<f64, 2>, b: &DSlice<f64, 2>, c: &mut DSlice<f64, 2>) { +fn matmul<D0: Dim, D1: Dim, D2: Dim>( + a: &Slice<f64, (D0, D1)>, + b: &Slice<f64, (D1, D2)>, + c: &mut Slice<f64, (D0, D2)>) { for (mut ci, ai) in c.rows_mut().zip(a.rows()) { for (aik, bk) in ai.zip(b.rows()) { for (cij, bkj) in ci.expr_mut().zip(bk) { |
