Batched Linear Operations

Some of SciPy’s linear algebra functions support N-dimensional array input. These operations have not been mathematically generalized to higher-order tensors; rather, the indicated operation is performed on a batch (or “stack”) of input scalars, vectors, and/or matrices.

Consider the linalg.det function, which maps a matrix to a scalar.

import numpy as np
from scipy import linalg
A = np.eye(3)
linalg.det(A)
np.float64(1.0)

Sometimes we need the determinant of each matrix in a batch of matrices that all have the same shape.

batch = [i*np.eye(3) for i in range(1, 4)]
batch
[array([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]]),
 array([[2., 0., 0.],
        [0., 2., 0.],
        [0., 0., 2.]]),
 array([[3., 0., 0.],
        [0., 3., 0.],
        [0., 0., 3.]])]

We could perform the operation for each element of the batch in a loop or list comprehension:

[linalg.det(A) for A in batch]
[np.float64(1.0), np.float64(8.0), np.float64(27.0)]

However, just as we might use NumPy broadcasting and vectorization rules to create the batch of matrices in the first place:

i = np.arange(1, 4).reshape(-1, 1, 1)
batch = i * np.eye(3)
batch
array([[[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]],

       [[2., 0., 0.],
        [0., 2., 0.],
        [0., 0., 2.]],

       [[3., 0., 0.],
        [0., 3., 0.],
        [0., 0., 3.]]])

we might also wish to perform the determinant operation on all of the matrices in one function call.

linalg.det(batch)
array([ 1.,  8., 27.])

In SciPy, we prefer the term “batch” instead of “stack” because the idea is generalized to N-dimensional batches. Suppose the input is a 2 x 4 batch of 3 x 3 matrices.

batch_shape = (2, 4)
i = np.arange(np.prod(batch_shape)).reshape(*batch_shape, 1, 1)
input = i * np.eye(3)

In this case, we say that the batch shape is (2, 4), and the core shape of the input is (3, 3). The net shape of the input is the sum (concatenation) of the batch shape and core shape.

input.shape
(2, 4, 3, 3)

Since each 3 x 3 matrix is converted to a zero-dimensional scalar, we say that the core shape of the output is (). The shape of the output is the sum (concatenation) of the batch shape and the output core shape, so the result is a 2 x 4 array.

output = linalg.det(input)
output
array([[  0.,   1.,   8.,  27.],
       [ 64., 125., 216., 343.]])
output.shape
(2, 4)
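The shape bookkeeping described above can be written as a small function. predict_output_shape is a hypothetical helper for illustration only (it is not part of SciPy): it treats the trailing input_core_ndim dimensions of the input shape as the core, keeps the rest as the batch shape, and appends the output core shape.

```python
def predict_output_shape(input_shape, input_core_ndim, output_core_shape):
    """Hypothetical helper: input batch shape + output core shape."""
    input_shape = tuple(input_shape)
    # Everything except the trailing `input_core_ndim` dimensions is batch shape.
    # (Slicing with len(...) - k also works when k == 0, unlike [:-k].)
    batch_shape = input_shape[:len(input_shape) - input_core_ndim]
    return batch_shape + tuple(output_core_shape)

# det-like function: (3, 3) core in, () core out
print(predict_output_shape((2, 4, 3, 3), 2, ()))      # (2, 4)
# A function mapping a matrix to a matrix of the same shape: (3, 3) core out
print(predict_output_shape((2, 4, 3, 3), 2, (3, 3)))  # (2, 4, 3, 3)
```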

Not all linear algebra functions map to scalars. For instance, the scipy.linalg.expm function maps from a matrix to a matrix with the same shape.

A = np.eye(3)
linalg.expm(A)
array([[2.71828183, 0.        , 0.        ],
       [0.        , 2.71828183, 0.        ],
       [0.        , 0.        , 2.71828183]])

In this case, the core shape of the output is (3, 3), so with a batch shape of (2, 4), we expect an output of shape (2, 4, 3, 3).

output = linalg.expm(input)
output.shape
(2, 4, 3, 3)
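One way to confirm that batching changes only how the work is dispatched, not the result, is to compare the batched call with an explicit loop over every index of the batch shape. The sketch below uses np.ndindex to visit each position of the (2, 4) batch; it illustrates the semantics and is not how SciPy implements batching internally. The array is named x here to avoid shadowing Python's built-in input.

```python
import numpy as np
from scipy import linalg

# Rebuild the (2, 4) batch of 3 x 3 matrices k*I, k = 0, ..., 7
batch_shape = (2, 4)
i = np.arange(np.prod(batch_shape)).reshape(*batch_shape, 1, 1)
x = i * np.eye(3)

batched = linalg.expm(x)  # one call over the whole batch

# Reference result: call expm once per core matrix, visiting every batch index
looped = np.empty_like(batched)
for idx in np.ndindex(*batch_shape):   # idx = (0, 0), (0, 1), ..., (1, 3)
    looped[idx] = linalg.expm(x[idx])  # x[idx] has shape (3, 3)

assert np.allclose(batched, looped)
# Entry (1, 2) holds 6*I, whose matrix exponential is exp(6)*I
assert np.allclose(batched[1, 2], np.exp(6) * np.eye(3))
```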

Generalization of these rules to functions with multiple inputs and outputs is straightforward. For instance, the scipy.linalg.eig function produces two outputs by default, a vector and a matrix.

evals, evecs = linalg.eig(A)
evals.shape, evecs.shape
((3,), (3, 3))

In this case, the core shape of the output vector is (3,) and the core shape of the output matrix is (3, 3). The shape of each output is the batch shape plus the core shape as before.

evals, evecs = linalg.eig(input)
evals.shape, evecs.shape
((2, 4, 3), (2, 4, 3, 3))
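The same loop generalizes to functions with several outputs. The helper apply_over_batch below is hypothetical (it is not part of SciPy) and only sketches the semantics, assuming a non-empty batch and a function that returns a tuple of arrays, as linalg.eig does by default. It allocates complex output arrays because eig can return complex eigenvalues and eigenvectors.

```python
import numpy as np
from scipy import linalg

def apply_over_batch(func, a):
    """Hypothetical helper (not SciPy API): apply `func` to each core matrix.

    `a` has shape batch_shape + (m, m); output number i gets shape
    batch_shape + (core shape of output i). Assumes a non-empty batch.
    """
    a = np.asarray(a)
    batch_shape = a.shape[:-2]
    outputs = None
    for idx in np.ndindex(*batch_shape):
        results = func(a[idx])  # a[idx] is one (m, m) core matrix
        if outputs is None:
            # Complex storage, since eigenvalues/eigenvectors may be complex
            outputs = [np.empty(batch_shape + np.shape(r), dtype=complex)
                       for r in results]
        for out, r in zip(outputs, results):
            out[idx] = r
    return tuple(outputs)

# Same (2, 4) batch of 3 x 3 matrices k*I used above
batch_shape = (2, 4)
i = np.arange(np.prod(batch_shape)).reshape(*batch_shape, 1, 1)
x = i * np.eye(3)

evals, evecs = apply_over_batch(linalg.eig, x)
print(evals.shape, evecs.shape)  # (2, 4, 3) (2, 4, 3, 3)
```

Each output's shape is the batch shape (2, 4) followed by that output's own core shape, matching the batched call above.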

When there is more than one input, there is no complication if the input shapes are identical.

evals, evecs = linalg.eig(input, b=input)
evals.shape, evecs.shape
((2, 4, 3), (2, 4, 3, 3))

When the input shapes are not identical, the rules follow from NumPy broadcasting. Each input can have its own batch shape as long as the batch shapes are broadcastable according to NumPy’s broadcasting rules. The net batch shape is the broadcast of the individual batch shapes, and the shape of each output is the net batch shape plus that output’s core shape.

rng = np.random.default_rng(2859239482)

# Define input core shapes
m = 3
core_shape_a = (m, m)
core_shape_b = (m, m)

# Define broadcastable batch shapes
batch_shape_a = (2, 4)
batch_shape_b = (5, 1, 4)

# Define output core shapes
core_shape_evals = (m,)
core_shape_evecs = (m, m)

# Predict shapes of outputs: broadcast batch shapes,
# and append output core shapes
net_batch_shape = np.broadcast_shapes(batch_shape_a, batch_shape_b)
output_shape_evals = net_batch_shape + core_shape_evals
output_shape_evecs = net_batch_shape + core_shape_evecs
output_shape_evals, output_shape_evecs
((5, 2, 4, 3), (5, 2, 4, 3, 3))
# Check predictions
input_a = rng.random(batch_shape_a + core_shape_a)
input_b = rng.random(batch_shape_b + core_shape_b)
evals, evecs = linalg.eig(input_a, b=input_b)
evals.shape, evecs.shape
((5, 2, 4, 3), (5, 2, 4, 3, 3))
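To see where these shapes come from, it may help to broadcast the inputs explicitly. The sketch below is a loop-based illustration, not SciPy's implementation: it uses np.broadcast_to to expand each input to the net batch shape (5, 2, 4) plus its own core shape, then solves one generalized eigenvalue problem per batch index. Entry (k, j, l) pairs input_a[j, l] with input_b[k, 0, l], because batch_shape_a is padded on the left to (1, 2, 4) and every length-1 axis is stretched.

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(2859239482)
m = 3
batch_shape_a = (2, 4)
batch_shape_b = (5, 1, 4)
input_a = rng.random(batch_shape_a + (m, m))
input_b = rng.random(batch_shape_b + (m, m))

net_batch_shape = np.broadcast_shapes(batch_shape_a, batch_shape_b)  # (5, 2, 4)
# Expand each input to net batch shape + its own core shape (read-only views)
a_full = np.broadcast_to(input_a, net_batch_shape + (m, m))
b_full = np.broadcast_to(input_b, net_batch_shape + (m, m))

evals = np.empty(net_batch_shape + (m,), dtype=complex)
evecs = np.empty(net_batch_shape + (m, m), dtype=complex)
for idx in np.ndindex(*net_batch_shape):
    # One generalized eigenproblem A v = lambda B v per batch index
    evals[idx], evecs[idx] = linalg.eig(a_full[idx], b=b_full[idx])

print(evals.shape, evecs.shape)  # (5, 2, 4, 3) (5, 2, 4, 3, 3)
```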