# Source code for capytaine.tools.lazy_matrices
"""Lazy matrix where the rows are computed and stored on demand when the matrix-vector product is requested."""
from typing import Callable
import numpy as np
[docs]
def slices(start, stop, chunk_size):
    """Generator returning slices covering the range from start to stop with `chunk_size` elements per slice.

    The last slice is truncated so that it ends at `stop`.
    Nothing is yielded when `start >= stop`.

    Parameters
    ----------
    start: int
        First index covered by the slices.
    stop: int
        End of the covered range (exclusive).
    chunk_size: int
        Maximal number of indices per slice. Must be strictly positive.

    Yields
    ------
    slice
        Consecutive, non-overlapping slices of the form ``slice(i, j)``.

    Raises
    ------
    ValueError
        If `chunk_size` is not strictly positive.
        As this is a generator, the error is raised when it is first advanced.

    >>> list(slices(0, 50, 15))
    [slice(0, 15, None), slice(15, 30, None), slice(30, 45, None), slice(45, 50, None)]
    """
    # A zero or negative step would never reach `stop` and loop forever.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be strictly positive, got {chunk_size!r}")
    i = start
    while i < stop:
        # Clip the upper bound so that the last slice does not overshoot `stop`.
        yield slice(i, min(i + chunk_size, stop))
        i += chunk_size
[docs]
class LazyMatrix:
    """Matrix (2D array) whose rows are recomputed on demand instead of being stored.

    Only two operations are provided: conversion to a full numpy array
    (`__array__`) and the matrix-vector product with a 1D numpy array
    (`__matmul__`). Both evaluate `row_constructor` chunk by chunk.
    """

    def __init__(self, row_constructor, shape, *, chunk_size=10, dtype=float):
        """
        A matrix (2D array) that is never fully stored in memory, but instead recomputed from a `row_constructor` method when required.

        Parameters
        ----------
        row_constructor: callable
            Function returning a numpy array containing a few rows of the matrix.
            We assume that row_constructor(slice(n, n+m)) returns a numpy array of shape (m, d) corresponding to the m rows of indices between n and n+m.
            The dtype of the output of row_constructor should match self.dtype.
        shape: 2-ple of int
            The shape of the matrix.
        chunk_size: int
            The number of rows requested to row_constructor at each call.
        dtype: numpy.dtype
            The type of data contained in the matrix.
        """
        # The constructor is called with `slice` objects (see `_slices` below).
        self.row_constructor: Callable[[slice], np.ndarray] = row_constructor
        self.shape = shape
        self.chunk_size = chunk_size
        self.dtype = dtype
        self.ndim = 2  # Other shapes not implemented
        # Row slices are computed once and reused by every operation.
        self._slices = list(slices(0, self.shape[0], self.chunk_size))

    def __array__(self, dtype=None, copy=True):
        """Build the full matrix as a numpy array.

        Parameters
        ----------
        dtype: numpy.dtype, optional
            Type of the returned array. Defaults to `self.dtype`.
        copy: bool or None, optional
            `True` or `None` return a freshly built array.
            `False` is not supported, since the rows always have to be recomputed.

        Returns
        -------
        numpy.ndarray
            The concatenation of all the chunks of rows, cast to `dtype`.

        Raises
        ------
        NotImplementedError
            If a conversion without copy is requested.
        """
        # `copy=None` means "copy if needed" in the NumPy 2 protocol, so it is accepted.
        if copy is not None and not copy:
            raise NotImplementedError(
                "LazyMatrix cannot be converted to an array without copy."
            )
        if dtype is None:
            dtype = self.dtype
        if not self._slices:
            # Matrix without rows: np.concatenate would fail on an empty list.
            return np.zeros(self.shape, dtype=dtype)
        rows = [self.row_constructor(sl) for sl in self._slices]
        return np.concatenate(rows).astype(dtype)

    def __matmul__(self, other):
        """Matrix-vector product `self @ other`.

        Parameters
        ----------
        other: numpy.ndarray
            1D array of length `self.shape[1]`.

        Returns
        -------
        numpy.ndarray or NotImplemented
            The product as a 1D array, or `NotImplemented` for any other kind of operand.
        """
        if isinstance(other, np.ndarray) and other.ndim == 1 and other.shape[0] == self.shape[1]:
            # Only matrix-vector product is actually implemented
            if not self._slices:
                # Matrix without rows: the product is an empty vector.
                return np.zeros((0,), dtype=np.result_type(np.dtype(self.dtype), other.dtype))
            # Compute `chunk_size` rows and multiply them by `other`
            output_chunks = [self.row_constructor(sl) @ other for sl in self._slices]
            return np.concatenate(output_chunks)
        else:
            # Usually fallback on building the full matrix with __array__ above.
            return NotImplemented