In NumPy, the left-hand side of a matrix multiplication can have as many axes as desired. With a 2-D right-hand side, the only requirement is that the length of the left-hand side's last axis matches the length of the right-hand side's first axis (axis 0). For example:
```python
import numpy as np

x = np.random.random((3, 2, 5, 9, 12))
y = np.random.random((12, 13))
(x @ y).shape
# (3, 2, 5, 9, 13)
```

In ndarray, you can't do this directly:
```rust
// Doesn't compile:
use ndarray::prelude::*;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;

fn main() {
    let x: Array<f64, Ix3> = Array::random(
        (12, 4, 3),
        Uniform::new(0., 1.).unwrap()
    );
    let y: Array<f64, Ix2> = Array::random(
        (3, 2),
        Uniform::new(0., 1.).unwrap()
    );
    let x_y = x.dot(&y);
    println!("{}", x_y);
}
```

Compiler output:

```text
$ cargo run
   Compiling playground v0.1.0 (/home/connor/RustroverProjects/playground)
error[E0275]: overflow evaluating the requirement `&ArrayBase<_, _, _>: Not`
  --> src/main.rs:15:17
   |
15 |     let x_y = x.dot(&y);
   |                 ^^^
   |
   = help: consider increasing the recursion limit by adding a `#![recursion_limit = "256"]` attribute to your crate (`playground`)
   = note: required for `&ArrayBase<_, _, _>` to implement `Not`
   = note: 127 redundant requirements hidden
   = note: required for `&ArrayBase<OwnedRepr<f64>, Dim<[usize; 3]>, f64>` to implement `Not`

For more information about this error, try `rustc --explain E0275`.
error: could not compile `playground` (bin "playground") due to 1 previous error
```
Emulating the behavior of the NumPy example above takes a non-trivial amount of work. For example, when `x` has 3 axes, you have to collapse its leading axes into one, multiply, and then restore the original shape:

```rust
use ndarray::prelude::*;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;

fn main() {
    let x: Array<f64, Ix3> = Array::random((12, 4, 3), Uniform::new(0., 1.).unwrap());
    let y = Array::random((3, 2), Uniform::new(0., 1.).unwrap());

    // Axis lengths of x, and the number of columns of y.
    let (a, b, c) = (x.len_of(Axis(0)), x.len_of(Axis(1)), x.len_of(Axis(2)));
    let d = y.len_of(Axis(1));

    // Collapse x to (a * b, c), multiply by y (c, d), then restore (a, b, d).
    let x_y: Array3<f64> = x
        .to_shape((a * b, c)).unwrap()
        .dot(&y)
        .to_shape((a, b, d)).unwrap()
        .to_owned();
    println!("{:?}", x_y);
}
```

Therefore, I think it would be nice to have `dot()` implemented for arrays with more axes than `Ix2`.
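The requested behavior amounts to the reshape trick above, generalized to any number of leading axes. Below is a dependency-free sketch of that computation. It assumes both arrays are plain row-major (C-order) `Vec<f64>` buffers with explicit shapes rather than ndarray types. `batched_dot` is a hypothetical name, not an existing ndarray function.

```rust
/// Hypothetical helper (not an ndarray API): NumPy-style `x @ y` where `x`
/// has any number of axes and `y` is 2-D. Arrays are flat row-major buffers.
fn batched_dot(
    x: &[f64],
    x_shape: &[usize],
    y: &[f64],
    y_shape: (usize, usize),
) -> (Vec<f64>, Vec<usize>) {
    let (k, n) = y_shape;
    // Split off the last axis of `x`; every axis before it acts as a batch axis.
    let (&x_last, leading) = x_shape
        .split_last()
        .expect("x must have at least one axis");
    assert_eq!(x_last, k, "last axis of x must match first axis of y");
    assert_eq!(x.len(), x_shape.iter().product::<usize>(), "x buffer/shape mismatch");
    assert_eq!(y.len(), k * n, "y buffer/shape mismatch");

    // Collapse all leading axes into one row axis: (a, b, ..., k) -> (m, k).
    // In row-major order this only reinterprets the buffer; no data moves.
    let m: usize = leading.iter().product();

    // Ordinary (m, k) x (k, n) matrix product.
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for p in 0..k {
            let xv = x[i * k + p];
            for j in 0..n {
                out[i * n + j] += xv * y[p * n + j];
            }
        }
    }

    // Restore the leading axes and append n: (m, n) -> (a, b, ..., n).
    let mut out_shape = leading.to_vec();
    out_shape.push(n);
    (out, out_shape)
}

fn main() {
    // Shapes from the ndarray example: (12, 4, 3) @ (3, 2) -> (12, 4, 2).
    let x = vec![1.0; 12 * 4 * 3];
    let y = vec![1.0; 3 * 2];
    let (_, shape) = batched_dot(&x, &[12, 4, 3], &y, (3, 2));
    println!("{:?}", shape); // prints [12, 4, 2]

    // Small worked case: x has shape (2, 1, 2) and y is a 2x2 matrix of ones,
    // so each output entry is the sum of the matching row of x.
    let (vals, shape) = batched_dot(&[1.0, 2.0, 3.0, 4.0], &[2, 1, 2], &[1.0; 4], (2, 2));
    println!("{:?} {:?}", vals, shape); // prints [3.0, 3.0, 7.0, 7.0] [2, 1, 2]
}
```

Within ndarray itself, the same idea could presumably be built from the pieces used in the workaround above: `to_shape` to collapse the leading axes, the existing 2-D `dot`, and `to_shape` again to restore them. Whether a built-in version should also accept a stacked right-hand side, as NumPy's `@` does, is a separate design question.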