spateo.alignment.methods.backend_ot

Multi-lib backend for POT

The goal is to write backend-agnostic code. Whether you’re using NumPy, PyTorch, JAX, CuPy, or TensorFlow, POT code should work regardless. To achieve that, POT provides backend classes that implement functions in their respective backends while imitating the NumPy API. By convention, we use nx instead of np to refer to the backend.

Examples

>>> from ot.utils import list_to_array
>>> from ot.backend import get_backend
>>> def f(a, b):  # the function does not know which backend to use
...     a, b = list_to_array(a, b)  # if a list is given, make it an array
...     nx = get_backend(a, b)  # infer the backend from the arguments
...     c = nx.dot(a, b)  # now use the backend to do any calculation
...     return c

Warning

The TensorFlow backend only works with TensorFlow's NumPy API enabled. To activate it, run the following:

from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

Performance

Sinkhorn Knopp - Averaged on 100 runs

Bitsize: 32 bits

             | CPU                            | GPU
Sample size  | Numpy    Pytorch   Tensorflow  | Cupy     Jax      Pytorch   Tensorflow
50           | 0.0008   0.0022    0.0151      | 0.0095   0.0193   0.0051    0.0293
100          | 0.0005   0.0013    0.0097      | 0.0057   0.0115   0.0029    0.0173
500          | 0.0009   0.0016    0.0110      | 0.0058   0.0115   0.0029    0.0166
1000         | 0.0021   0.0021    0.0145      | 0.0056   0.0118   0.0029    0.0168
2000         | 0.0069   0.0043    0.0278      | 0.0059   0.0118   0.0030    0.0165
5000         | 0.0707   0.0314    0.1395      | 0.0074   0.0125   0.0035    0.0198

Bitsize: 64 bits

             | CPU                            | GPU
Sample size  | Numpy    Pytorch   Tensorflow  | Cupy     Jax      Pytorch   Tensorflow
50           | 0.0008   0.0020    0.0154      | 0.0093   0.0191   0.0051    0.0328
100          | 0.0005   0.0013    0.0094      | 0.0056   0.0114   0.0029    0.0169
500          | 0.0013   0.0017    0.0120      | 0.0059   0.0116   0.0029    0.0168
1000         | 0.0034   0.0027    0.0177      | 0.0058   0.0118   0.0029    0.0167
2000         | 0.0146   0.0075    0.0436      | 0.0059   0.0120   0.0029    0.0165
5000         | 0.1467   0.0568    0.2468      | 0.0077   0.0146   0.0045    0.0204

Attributes

Classes

Backend

Backend abstract class.

NumpyBackend

NumPy implementation of the backend

JaxBackend

JAX implementation of the backend

TorchBackend

PyTorch implementation of the backend

CupyBackend

CuPy implementation of the backend

TensorflowBackend

TensorFlow implementation of the backend

Functions

_register_backend_implementation(backend_impl)

_get_backend_instance(backend_impl)

_check_args_backend(backend_impl, args)

get_backend_list()

Returns instances of all available backends.

get_available_backend_implementations()

Returns the list of available backend implementations.

get_backend(*args)

Returns the proper backend for a list of input arrays

to_numpy(*args)

Returns numpy arrays from any compatible backend

Module Contents

spateo.alignment.methods.backend_ot.DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'[source]
spateo.alignment.methods.backend_ot.DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'[source]
spateo.alignment.methods.backend_ot.DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY'[source]
spateo.alignment.methods.backend_ot.DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW'[source]
spateo.alignment.methods.backend_ot.torch_type[source]
spateo.alignment.methods.backend_ot.jax_type[source]
spateo.alignment.methods.backend_ot.cp_type[source]
spateo.alignment.methods.backend_ot.tf_type[source]
spateo.alignment.methods.backend_ot.str_type_error = 'All array should be from the same type/backend. Current types are : {}'[source]
spateo.alignment.methods.backend_ot._BACKEND_IMPLEMENTATIONS = [][source]
spateo.alignment.methods.backend_ot._BACKENDS[source]
spateo.alignment.methods.backend_ot._register_backend_implementation(backend_impl)[source]
spateo.alignment.methods.backend_ot._get_backend_instance(backend_impl)[source]
spateo.alignment.methods.backend_ot._check_args_backend(backend_impl, args)[source]
spateo.alignment.methods.backend_ot.get_backend_list()[source]

Returns instances of all available backends.

Note that the function forces all detected implementations to be instantiated, even if a specific backend has not been used before. Be careful, as instantiating a backend might lead to side effects such as GPU memory pre-allocation. See the documentation for more details. If you only need to know which implementations are available, use ot.backend.get_available_backend_implementations(), which does not force an instance of the backend object to be created.
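The difference between listing implementations and instantiating them can be sketched with a small registry. This is an illustrative pattern only, not the module's code: FakeGpuBackend, implementations, instances, available_implementations and backend_instances are all hypothetical names, and the counter stands in for a costly side effect such as GPU memory pre-allocation.

```python
class FakeGpuBackend:
    """Hypothetical backend whose construction has a visible side effect."""

    instances_created = 0

    def __init__(self):
        # Stand-in for an expensive step such as GPU memory pre-allocation.
        FakeGpuBackend.instances_created += 1


# Registered implementations (classes only) and lazily created instances.
implementations = [FakeGpuBackend]
instances = {}


def available_implementations():
    """Return the registered classes without instantiating any of them."""
    return list(implementations)


def backend_instances():
    """Instantiate every registered class (at most once) and return the instances."""
    for impl in implementations:
        if impl not in instances:
            instances[impl] = impl()
    return [instances[impl] for impl in implementations]


names = [impl.__name__ for impl in available_implementations()]
created_before = FakeGpuBackend.instances_created  # nothing instantiated yet
backends = backend_instances()
created_after = FakeGpuBackend.instances_created  # one instance now exists
```

In this sketch, querying the available classes leaves the counter untouched, while asking for instances triggers construction exactly once per class.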

spateo.alignment.methods.backend_ot.get_available_backend_implementations()[source]

Returns the list of available backend implementations.

spateo.alignment.methods.backend_ot.get_backend(*args)[source]

Returns the proper backend for a list of input arrays

Accepts None entries in the arguments, and ignores them

Raises TypeError if the arrays are not all from the same backend
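The behaviour described here (ignore None entries, require a single shared array type) can be sketched as follows. This is a minimal illustration under stated assumptions, not the module's implementation: TinyNumpyBackend and pick_backend are hypothetical names, only a NumPy backend is registered, and raising ValueError when no array is given is an assumption of this sketch.

```python
import numpy as np


class TinyNumpyBackend:
    """Hypothetical stand-in for a backend class (not the module's NumpyBackend)."""

    __name__ = "numpy"
    __type__ = np.ndarray

    def dot(self, a, b):
        return np.dot(a, b)


def pick_backend(*args, backends=(TinyNumpyBackend(),)):
    """Illustrative dispatch: skip None entries, require one shared array type."""
    arrays = [a for a in args if a is not None]
    if not arrays:
        # Assumption of this sketch: at least one real array is required.
        raise ValueError("at least one non-None array is required")
    for backend in backends:
        if all(isinstance(a, backend.__type__) for a in arrays):
            return backend
    raise TypeError(
        "All arrays should be from the same type/backend. Current types are: "
        + ", ".join(type(a).__name__ for a in arrays)
    )


# None is ignored; both remaining arguments are NumPy arrays.
nx = pick_backend(np.ones(3), None, np.arange(3.0))
result = nx.dot(np.ones(3), np.arange(3.0))
```

Passing a mix such as a NumPy array and a plain list would raise TypeError in this sketch, which mirrors the rule stated above.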

spateo.alignment.methods.backend_ot.to_numpy(*args)[source]

Returns numpy arrays from any compatible backend

class spateo.alignment.methods.backend_ot.Backend[source]

Backend abstract class. Implementations: JaxBackend, NumpyBackend, TorchBackend, CupyBackend, TensorflowBackend

  • The __name__ class attribute refers to the name of the backend.

  • The __type__ class attribute refers to the data structure used by the backend.

__name__ = None[source]
__type__ = None[source]
__type_list__ = None[source]
rng_ = None[source]
__str__()[source]
to_numpy(*arrays)[source]

Returns the numpy version of tensors

abstract _to_numpy(a)[source]

Returns the numpy version of a tensor

from_numpy(*arrays, type_as=None)[source]

Creates tensors cloning a numpy array, with the given precision (defaulting to input’s precision) and the given device (in case of GPUs)

abstract _from_numpy(a, type_as=None)[source]

Creates a tensor cloning a numpy array, with the given precision (defaulting to input’s precision) and the given device (in case of GPUs)

abstract set_gradients(val, inputs, grads)[source]

Define the gradients for the value val wrt the inputs

detach(*arrays)[source]

Detach the tensors from the computation graph

See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html

abstract _detach(a)[source]

Detach the tensor from the computation graph

abstract zeros(shape, type_as=None)[source]

Creates a tensor full of zeros.

This function follows the api from numpy.zeros

See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html

abstract ones(shape, type_as=None)[source]

Creates a tensor full of ones.

This function follows the api from numpy.ones

See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html

abstract arange(stop, start=0, step=1, type_as=None)[source]

Returns evenly spaced values within a given interval.

This function follows the api from numpy.arange

See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
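Note that the signature above takes stop first, whereas numpy.arange takes (start, stop, step). A minimal sketch of how such a wrapper could map its arguments onto NumPy; the name arange_stop_first is hypothetical, type_as is left out for brevity, and the mapping is an assumption based on the signature shown rather than the module's code.

```python
import numpy as np


def arange_stop_first(stop, start=0, step=1):
    """Hypothetical wrapper: 'stop' comes first, then optional start and step."""
    return np.arange(start, stop, step)


a = arange_stop_first(5)                    # values 0, 1, 2, 3, 4
b = arange_stop_first(10, start=4, step=2)  # values 4, 6, 8
```

Passing start and step by keyword avoids confusion with the NumPy positional order.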

abstract full(shape, fill_value, type_as=None)[source]

Creates a tensor with given shape, filled with given value.

This function follows the api from numpy.full

See: https://numpy.org/doc/stable/reference/generated/numpy.full.html

abstract eye(N, M=None, type_as=None)[source]

Creates the identity matrix of given size.

This function follows the api from numpy.eye

See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html

abstract sum(a, axis=None, keepdims=False)[source]

Sums tensor elements over given dimensions.

This function follows the api from numpy.sum

See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html

abstract cumsum(a, axis=None)[source]

Returns the cumulative sum of tensor elements over given dimensions.

This function follows the api from numpy.cumsum

See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html

abstract max(a, axis=None, keepdims=False)[source]

Returns the maximum of an array or maximum along given dimensions.

This function follows the api from numpy.amax

See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html

abstract min(a, axis=None, keepdims=False)[source]

Returns the minimum of an array or minimum along given dimensions.

This function follows the api from numpy.amin

See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html

abstract maximum(a, b)[source]

Returns element-wise maximum of array elements.

This function follows the api from numpy.maximum

See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html

abstract minimum(a, b)[source]

Returns element-wise minimum of array elements.

This function follows the api from numpy.minimum

See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html

abstract sign(a)[source]

Returns an element-wise indication of the sign of a number.

This function follows the api from numpy.sign

See: https://numpy.org/doc/stable/reference/generated/numpy.sign.html

abstract dot(a, b)[source]

Returns the dot product of two tensors.

This function follows the api from numpy.dot

See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html

abstract abs(a)[source]

Computes the absolute value element-wise.

This function follows the api from numpy.absolute

See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html

abstract exp(a)[source]

Computes the exponential value element-wise.

This function follows the api from numpy.exp

See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html

abstract log(a)[source]

Computes the natural logarithm, element-wise.

This function follows the api from numpy.log

See: https://numpy.org/doc/stable/reference/generated/numpy.log.html

abstract sqrt(a)[source]

Returns the non-negative square root of a tensor, element-wise.

This function follows the api from numpy.sqrt

See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html

abstract power(a, exponents)[source]

First tensor elements raised to powers from second tensor, element-wise.

This function follows the api from numpy.power

See: https://numpy.org/doc/stable/reference/generated/numpy.power.html

abstract norm(a, axis=None, keepdims=False)[source]

Computes the matrix Frobenius norm.

This function follows the api from numpy.linalg.norm

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html

abstract any(a)[source]

Tests whether any tensor element along given dimensions evaluates to True.

This function follows the api from numpy.any

See: https://numpy.org/doc/stable/reference/generated/numpy.any.html

abstract isnan(a)[source]

Tests element-wise for NaN and returns result as a boolean tensor.

This function follows the api from numpy.isnan

See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html

abstract isinf(a)[source]

Tests element-wise for positive or negative infinity and returns result as a boolean tensor.

This function follows the api from numpy.isinf

See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html

abstract einsum(subscripts, *operands)[source]

Evaluates the Einstein summation convention on the operands.

This function follows the api from numpy.einsum

See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html

abstract sort(a, axis=-1)[source]

Returns a sorted copy of a tensor.

This function follows the api from numpy.sort

See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html

abstract argsort(a, axis=None)[source]

Returns the indices that would sort a tensor.

This function follows the api from numpy.argsort

See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html

abstract searchsorted(a, v, side='left')[source]

Finds indices where elements should be inserted to maintain order in given tensor.

This function follows the api from numpy.searchsorted

See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html

abstract flip(a, axis=None)[source]

Reverses the order of elements in a tensor along given dimensions.

This function follows the api from numpy.flip

See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html

abstract clip(a, a_min, a_max)[source]

Limits the values in a tensor.

This function follows the api from numpy.clip

See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html

abstract repeat(a, repeats, axis=None)[source]

Repeats elements of a tensor.

This function follows the api from numpy.repeat

See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html

abstract take_along_axis(arr, indices, axis)[source]

Gathers elements of a tensor along given dimensions.

This function follows the api from numpy.take_along_axis

See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html

abstract concatenate(arrays, axis=0)[source]

Joins a sequence of tensors along an existing dimension.

This function follows the api from numpy.concatenate

See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html

abstract zero_pad(a, pad_width, value=0)[source]

Pads a tensor with a given value (0 by default).

This function follows the api from numpy.pad

See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
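As a rough illustration of constant padding with numpy.pad, the sketch below pads one row at the top and two columns on the right. The helper name zero_pad_sketch is hypothetical, and treating pad_width as NumPy's ((before, after), ...) tuples is an assumption based on the reference to numpy.pad.

```python
import numpy as np


def zero_pad_sketch(a, pad_width, value=0):
    """Pad 'a' with a constant; pad_width uses NumPy's ((before, after), ...) form."""
    return np.pad(a, pad_width, mode="constant", constant_values=value)


x = np.array([[1, 2], [3, 4]])
padded = zero_pad_sketch(x, ((1, 0), (0, 2)))            # pad with zeros
padded_neg = zero_pad_sketch(x, ((0, 0), (1, 0)), value=-1)  # pad with -1
```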

abstract argmax(a, axis=None)[source]

Returns the indices of the maximum values of a tensor along given dimensions.

This function follows the api from numpy.argmax

See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html

abstract argmin(a, axis=None)[source]

Returns the indices of the minimum values of a tensor along given dimensions.

This function follows the api from numpy.argmin

See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html

abstract mean(a, axis=None)[source]

Computes the arithmetic mean of a tensor along given dimensions.

This function follows the api from numpy.mean

See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html

abstract median(a, axis=None)[source]

Computes the median of a tensor along given dimensions.

This function follows the api from numpy.median

See: https://numpy.org/doc/stable/reference/generated/numpy.median.html

abstract std(a, axis=None)[source]

Computes the standard deviation of a tensor along given dimensions.

This function follows the api from numpy.std

See: https://numpy.org/doc/stable/reference/generated/numpy.std.html

abstract linspace(start, stop, num, type_as=None)[source]

Returns a specified number of evenly spaced values over a given interval.

This function follows the api from numpy.linspace

See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html

abstract meshgrid(a, b)[source]

Returns coordinate matrices from coordinate vectors (Numpy convention).

This function follows the api from numpy.meshgrid

See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html

abstract diag(a, k=0)[source]

Extracts or constructs a diagonal tensor.

This function follows the api from numpy.diag

See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html

abstract unique(a, return_inverse=False)[source]

Finds unique elements of given tensor.

This function follows the api from numpy.unique

See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html

abstract logsumexp(a, axis=None)[source]

Computes the log of the sum of exponentials of input elements.

This function follows the api from scipy.special.logsumexp

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
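For intuition, log-sum-exp is usually computed by subtracting the maximum before exponentiating so that large inputs do not overflow. The following is a NumPy-only sketch of that standard trick, assuming finite inputs; logsumexp_sketch is a hypothetical name and this is not claimed to be how any backend implements it.

```python
import numpy as np


def logsumexp_sketch(a, axis=None):
    """Numerically stable log(sum(exp(a))) for finite inputs."""
    a = np.asarray(a, dtype=float)
    a_max = np.max(a, axis=axis, keepdims=True)
    # Shift by the maximum so the largest exponent is exp(0) = 1.
    out = np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True)) + a_max
    return np.squeeze(out, axis=axis)


big = logsumexp_sketch([1000.0, 1000.0])               # naive exp would overflow
rows = logsumexp_sketch(np.log([[1.0, 3.0], [2.0, 2.0]]), axis=1)
```

Here big equals 1000 + log(2), and each entry of rows equals log(4).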

abstract stack(arrays, axis=0)[source]

Joins a sequence of tensors along a new dimension.

This function follows the api from numpy.stack

See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html

abstract outer(a, b)[source]

Computes the outer product between two vectors.

This function follows the api from numpy.outer

See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html

abstract reshape(a, shape)[source]

Gives a new shape to a tensor without changing its data.

This function follows the api from numpy.reshape

See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html

abstract seed(seed=None)[source]

Sets the seed for the random generator.

This function follows the api from numpy.random.seed

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html

abstract rand(*size, type_as=None)[source]

Generate uniform random numbers.

This function follows the api from numpy.random.rand

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.rand.html

abstract randn(*size, type_as=None)[source]

Generate normal Gaussian random numbers.

This function follows the api from numpy.random.randn

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.randn.html

abstract coo_matrix(data, rows, cols, shape=None, type_as=None)[source]

Creates a sparse tensor in COOrdinate format.

This function follows the api from scipy.sparse.coo_matrix

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html

abstract issparse(a)[source]

Checks whether or not the input tensor is a sparse tensor.

This function follows the api from scipy.sparse.issparse

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html

abstract tocsr(a)[source]

Converts this matrix to Compressed Sparse Row format.

This function follows the api from scipy.sparse.coo_matrix.tocsr

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html

abstract eliminate_zeros(a, threshold=0.0)[source]

Removes entries smaller than the given threshold from the sparse tensor.

This function follows the api from scipy.sparse.csr_matrix.eliminate_zeros

See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html

abstract todense(a)[source]

Converts a sparse tensor to a dense tensor.

This function follows the api from scipy.sparse.csr_matrix.toarray

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html

abstract where(condition, x, y)[source]

Returns elements chosen from x or y depending on condition.

This function follows the api from numpy.where

See: https://numpy.org/doc/stable/reference/generated/numpy.where.html

abstract copy(a)[source]

Returns a copy of the given tensor.

This function follows the api from numpy.copy

See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html

abstract allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)[source]

Returns True if two arrays are element-wise equal within a tolerance.

This function follows the api from numpy.allclose

See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html

abstract dtype_device(a)[source]

Returns the dtype and the device of the given tensor.

abstract assert_same_dtype_device(a, b)[source]

Checks whether or not the two given inputs have the same dtype as well as the same device

abstract squeeze(a, axis=None)[source]

Remove axes of length one from a.

This function follows the api from numpy.squeeze.

See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html

abstract bitsize(type_as)[source]

Gives the number of bits used by the data type of the given tensor.

abstract device_type(type_as)[source]

Returns CPU or GPU depending on the device where the given tensor is located.

abstract _bench(callable, *args, n_runs=1, warmup_runs=1)[source]

Executes a benchmark of the given callable with the given arguments.

abstract solve(a, b)[source]

Solves a linear matrix equation, or system of linear scalar equations.

This function follows the api from numpy.linalg.solve.

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html

abstract trace(a)[source]

Returns the sum along diagonals of the array.

This function follows the api from numpy.trace.

See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html

abstract inv(a)[source]

Computes the inverse of a matrix.

This function follows the api from scipy.linalg.inv.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html

abstract sqrtm(a)[source]

Computes the matrix square root. Requires input to be positive definite.

This function follows the api from scipy.linalg.sqrtm.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html

abstract eigh(a)[source]

Computes the eigenvalues and eigenvectors of a symmetric tensor.

This function follows the api from scipy.linalg.eigh.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html

abstract kl_div(p, q, mass=False, eps=1e-16)[source]

Computes the (Generalized) Kullback-Leibler divergence.

This function follows the api from scipy.stats.entropy.

Parameter eps is used to avoid numerical errors and is added in the log.

\[KL(p,q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + eps) \rangle + \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle\]

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
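The formula above can be transcribed directly into NumPy. This is a sketch for dense, strictly positive inputs; kl_div_sketch is a hypothetical name and the code simply mirrors the displayed expression rather than any backend's implementation.

```python
import numpy as np


def kl_div_sketch(p, q, mass=False, eps=1e-16):
    """<p, log(p / q + eps)>, plus <q - p, 1> when mass is True."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    value = np.sum(p * np.log(p / q + eps))
    if mass:
        # Generalized KL: account for a difference in total mass.
        value = value + np.sum(q - p)
    return value


p = np.array([0.2, 0.3, 0.5])
same = kl_div_sketch(p, p)                    # identical inputs
scaled = kl_div_sketch(p, 2 * p, mass=True)   # q carries twice the mass of p
```

With identical inputs the value is zero up to rounding; with q = 2p and mass=True it is 1 - log(2), since p sums to one.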

abstract isfinite(a)[source]

Tests element-wise for finiteness (not infinity and not Not a Number).

This function follows the api from numpy.isfinite.

See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html

abstract array_equal(a, b)[source]

True if two arrays have the same shape and elements, False otherwise.

This function follows the api from numpy.array_equal.

See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html

abstract is_floating_point(a)[source]

Returns whether or not the input consists of floats

abstract tile(a, reps)[source]

Construct an array by repeating a the number of times given by reps

See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html

abstract floor(a)[source]

Return the floor of the input element-wise

See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html

abstract prod(a, axis=None)[source]

Return the product of all elements.

See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html

abstract sort2(a, axis=None)[source]

Return the sorted array and the indices to sort the array

See: https://pytorch.org/docs/stable/generated/torch.sort.html
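In NumPy terms, returning both the sorted values and the sorting indices can be sketched with argsort followed by take_along_axis. The name sort2_sketch is hypothetical, and defaulting to the last axis is a choice made for this example.

```python
import numpy as np


def sort2_sketch(a, axis=-1):
    """Return (sorted values, indices that sort 'a') along the given axis."""
    indices = np.argsort(a, axis=axis)
    values = np.take_along_axis(a, indices, axis=axis)
    return values, indices


a = np.array([[3, 1, 2], [9, 7, 8]])
values, indices = sort2_sketch(a)
```

Indexing a with the returned indices along the same axis reproduces the sorted values.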

abstract qr(a)[source]

Return the QR factorization

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html

abstract atan2(a, b)[source]

Element-wise arctangent of a / b, choosing the quadrant correctly

See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html

abstract transpose(a, axes=None)[source]

Returns a tensor that is a transposed version of a. If axes is None, the order of the dimensions is reversed; otherwise the dimensions are permuted according to axes.

See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html

abstract matmul(a, b)[source]

Matrix product of two arrays.

See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul

abstract nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[source]

Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.

See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
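A short NumPy example of the replacement rules, using numpy.nan_to_num directly (the array values are arbitrary illustration data):

```python
import numpy as np

x = np.array([np.nan, np.inf, -np.inf, 1.5])

# Defaults: NaN becomes 0.0, infinities become the largest finite floats.
default_fill = np.nan_to_num(x)

# User-defined replacements for each special value.
custom_fill = np.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
```

Because copy defaults to True, x itself is left unchanged.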

class spateo.alignment.methods.backend_ot.NumpyBackend[source]

Bases: Backend

NumPy implementation of the backend

  • __name__ is “numpy”

  • __type__ is np.ndarray

__name__ = 'numpy'[source]
__type__[source]
__type_list__[source]
rng_[source]
_to_numpy(a)[source]

Returns the numpy version of a tensor

_from_numpy(a, type_as=None)[source]

Creates a tensor cloning a numpy array, with the given precision (defaulting to input’s precision) and the given device (in case of GPUs)

set_gradients(val, inputs, grads)[source]

Define the gradients for the value val wrt the inputs

_detach(a)[source]

Detach the tensor from the computation graph

zeros(shape, type_as=None)[source]

Creates a tensor full of zeros.

This function follows the api from numpy.zeros

See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html

ones(shape, type_as=None)[source]

Creates a tensor full of ones.

This function follows the api from numpy.ones

See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html

arange(stop, start=0, step=1, type_as=None)[source]

Returns evenly spaced values within a given interval.

This function follows the api from numpy.arange

See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html

full(shape, fill_value, type_as=None)[source]

Creates a tensor with given shape, filled with given value.

This function follows the api from numpy.full

See: https://numpy.org/doc/stable/reference/generated/numpy.full.html

eye(N, M=None, type_as=None)[source]

Creates the identity matrix of given size.

This function follows the api from numpy.eye

See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html

sum(a, axis=None, keepdims=False)[source]

Sums tensor elements over given dimensions.

This function follows the api from numpy.sum

See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html

cumsum(a, axis=None)[source]

Returns the cumulative sum of tensor elements over given dimensions.

This function follows the api from numpy.cumsum

See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html

max(a, axis=None, keepdims=False)[source]

Returns the maximum of an array or maximum along given dimensions.

This function follows the api from numpy.amax

See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html

min(a, axis=None, keepdims=False)[source]

Returns the minimum of an array or minimum along given dimensions.

This function follows the api from numpy.amin

See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html

maximum(a, b)[source]

Returns element-wise maximum of array elements.

This function follows the api from numpy.maximum

See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html

minimum(a, b)[source]

Returns element-wise minimum of array elements.

This function follows the api from numpy.minimum

See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html

sign(a)[source]

Returns an element-wise indication of the sign of a number.

This function follows the api from numpy.sign

See: https://numpy.org/doc/stable/reference/generated/numpy.sign.html

dot(a, b)[source]

Returns the dot product of two tensors.

This function follows the api from numpy.dot

See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html

abs(a)[source]

Computes the absolute value element-wise.

This function follows the api from numpy.absolute

See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html

exp(a)[source]

Computes the exponential value element-wise.

This function follows the api from numpy.exp

See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html

log(a)[source]

Computes the natural logarithm, element-wise.

This function follows the api from numpy.log

See: https://numpy.org/doc/stable/reference/generated/numpy.log.html

sqrt(a)[source]

Returns the non-negative square root of a tensor, element-wise.

This function follows the api from numpy.sqrt

See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html

power(a, exponents)[source]

First tensor elements raised to powers from second tensor, element-wise.

This function follows the api from numpy.power

See: https://numpy.org/doc/stable/reference/generated/numpy.power.html

norm(a, axis=None, keepdims=False)[source]

Computes the matrix Frobenius norm.

This function follows the api from numpy.linalg.norm

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html

any(a)[source]

Tests whether any tensor element along given dimensions evaluates to True.

This function follows the api from numpy.any

See: https://numpy.org/doc/stable/reference/generated/numpy.any.html

isnan(a)[source]

Tests element-wise for NaN and returns result as a boolean tensor.

This function follows the api from numpy.isnan

See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html

isinf(a)[source]

Tests element-wise for positive or negative infinity and returns result as a boolean tensor.

This function follows the api from numpy.isinf

See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html

einsum(subscripts, *operands)[source]

Evaluates the Einstein summation convention on the operands.

This function follows the api from numpy.einsum

See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html

sort(a, axis=-1)[source]

Returns a sorted copy of a tensor.

This function follows the api from numpy.sort

See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html

argsort(a, axis=-1)[source]

Returns the indices that would sort a tensor.

This function follows the api from numpy.argsort

See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html

searchsorted(a, v, side='left')[source]

Finds indices where elements should be inserted to maintain order in given tensor.

This function follows the api from numpy.searchsorted

See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html

flip(a, axis=None)[source]

Reverses the order of elements in a tensor along given dimensions.

This function follows the api from numpy.flip

See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html

outer(a, b)[source]

Computes the outer product between two vectors.

This function follows the api from numpy.outer

See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html

clip(a, a_min, a_max)[source]

Limits the values in a tensor.

This function follows the api from numpy.clip

See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html

repeat(a, repeats, axis=None)[source]

Repeats elements of a tensor.

This function follows the api from numpy.repeat

See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html

take_along_axis(arr, indices, axis)[source]

Gathers elements of a tensor along given dimensions.

This function follows the api from numpy.take_along_axis

See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html

concatenate(arrays, axis=0)[source]

Joins a sequence of tensors along an existing dimension.

This function follows the api from numpy.concatenate

See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html

zero_pad(a, pad_width, value=0)[source]

Pads a tensor with a given value (0 by default).

This function follows the api from numpy.pad

See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html

argmax(a, axis=None)[source]

Returns the indices of the maximum values of a tensor along given dimensions.

This function follows the api from numpy.argmax

See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html

argmin(a, axis=None)[source]

Returns the indices of the minimum values of a tensor along given dimensions.

This function follows the api from numpy.argmin

See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html

mean(a, axis=None)[source]

Computes the arithmetic mean of a tensor along given dimensions.

This function follows the api from numpy.mean

See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html

median(a, axis=None)[source]

Computes the median of a tensor along given dimensions.

This function follows the api from numpy.median

See: https://numpy.org/doc/stable/reference/generated/numpy.median.html

std(a, axis=None)[source]

Computes the standard deviation of a tensor along given dimensions.

This function follows the api from numpy.std

See: https://numpy.org/doc/stable/reference/generated/numpy.std.html

linspace(start, stop, num, type_as=None)[source]

Returns a specified number of evenly spaced values over a given interval.

This function follows the api from numpy.linspace

See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html

meshgrid(a, b)[source]

Returns coordinate matrices from coordinate vectors (Numpy convention).

This function follows the api from numpy.meshgrid

See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html

diag(a, k=0)[source]

Extracts or constructs a diagonal tensor.

This function follows the api from numpy.diag

See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html

unique(a, return_inverse=False)[source]

Finds unique elements of given tensor.

This function follows the api from numpy.unique

See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html

logsumexp(a, axis=None)[source]

Computes the log of the sum of exponentials of input elements.

This function follows the api from scipy.special.logsumexp

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
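The quantity is usually computed by shifting the inputs by their maximum so that the exponentials cannot overflow. The sketch below shows that standard trick for a flat sequence in plain Python; it is an illustration, not the backend's implementation.

```python
import math


def logsumexp(values):
    """log(sum(exp(v))) computed with the usual max-shift for stability."""
    m = max(values)
    if math.isinf(m):
        return m  # all -inf, or at least one +inf
    return m + math.log(sum(math.exp(v - m) for v in values))


print(logsumexp([1000.0, 1000.0]))  # finite, whereas math.exp(1000.0) overflows
```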

stack(arrays, axis=0)[source]

Joins a sequence of tensors along a new dimension.

This function follows the api from numpy.stack

See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html

reshape(a, shape)[source]

Gives a new shape to a tensor without changing its data.

This function follows the api from numpy.reshape

See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html

seed(seed=None)[source]

Sets the seed for the random generator.

This function follows the api from numpy.random.seed

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html

rand(*size, type_as=None)[source]

Generate uniform random numbers.

This function follows the api from numpy.random.rand

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.rand.html

randn(*size, type_as=None)[source]

Generates standard normal (Gaussian) random numbers.

This function follows the api from numpy.random.randn

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.randn.html

coo_matrix(data, rows, cols, shape=None, type_as=None)[source]

Creates a sparse tensor in COOrdinate format.

This function follows the api from scipy.sparse.coo_matrix

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
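In COOrdinate format, entry k of data is stored at position (rows[k], cols[k]). The sketch below expands such triplets into a dense nested list; coo_to_dense is a hypothetical helper, and the summing of duplicate coordinates mirrors what scipy.sparse.coo_matrix does on conversion.

```python
def coo_to_dense(data, rows, cols, shape):
    """Expand COO triplets into a dense list of lists.

    Entries that share the same (row, col) are summed, which is how
    scipy.sparse.coo_matrix treats duplicates when converting.
    """
    dense = [[0.0] * shape[1] for _ in range(shape[0])]
    for value, i, j in zip(data, rows, cols):
        dense[i][j] += value
    return dense


dense = coo_to_dense([1.0, 2.0, 3.0], rows=[0, 1, 1], cols=[0, 2, 2], shape=(2, 3))
print(dense)  # [[1.0, 0.0, 0.0], [0.0, 0.0, 5.0]]
```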

issparse(a)[source]

Checks whether or not the input tensor is a sparse tensor.

This function follows the api from scipy.sparse.issparse

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html

tocsr(a)[source]

Converts this matrix to Compressed Sparse Row format.

This function follows the api from scipy.sparse.coo_matrix.tocsr

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html

eliminate_zeros(a, threshold=0.0)[source]

Removes entries smaller than the given threshold from the sparse tensor.

This function follows the api from scipy.sparse.csr_matrix.eliminate_zeros

See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html

todense(a)[source]

Converts a sparse tensor to a dense tensor.

This function follows the api from scipy.sparse.csr_matrix.toarray

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html

where(condition, x=None, y=None)[source]

Returns elements chosen from x or y depending on condition.

This function follows the api from numpy.where

See: https://numpy.org/doc/stable/reference/generated/numpy.where.html

copy(a)[source]

Returns a copy of the given tensor.

This function follows the api from numpy.copy

See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html

allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)[source]

Returns True if two arrays are element-wise equal within a tolerance.

This function follows the api from numpy.allclose

See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
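numpy.allclose tests |a - b| <= atol + rtol * |b| element-wise, so the comparison is not symmetric in a and b. A plain-Python sketch of that rule for two equal-length sequences follows; it leaves out equal_nan and the handling of infinities.

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Element-wise |a - b| <= atol + rtol * |b| on two equal-length sequences."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


print(allclose([1.0, 100.0], [1.0, 100.0005]))  # True: 5e-4 is within 1e-8 + 1e-5 * 100.0005
print(allclose([1.0, 100.0], [1.0, 100.01]))    # False: 1e-2 exceeds the tolerance
```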

dtype_device(a)[source]

Returns the dtype and the device of the given tensor.

assert_same_dtype_device(a, b)[source]

Checks whether or not the two given inputs have the same dtype as well as the same device.

squeeze(a, axis=None)[source]

Remove axes of length one from a.

This function follows the api from numpy.squeeze.

See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html

bitsize(type_as)[source]

Gives the number of bits used by the data type of the given tensor.

device_type(type_as)[source]

Returns CPU or GPU depending on the device where the given tensor is located.

_bench(callable, *args, n_runs=1, warmup_runs=1)[source]

Executes a benchmark of the given callable with the given arguments.
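A benchmark with these two parameters typically runs the callable warmup_runs times without timing it and then averages the wall-clock time over n_runs timed calls. The sketch below shows that pattern with the standard library only; the function name bench and the returned value (average seconds per call) are assumptions, not the backend's actual return format.

```python
import time


def bench(func, *args, n_runs=1, warmup_runs=1):
    """Average wall-clock seconds per call after untimed warm-up calls."""
    for _ in range(warmup_runs):
        func(*args)  # warm caches or JIT compilation, not timed
    start = time.perf_counter()
    for _ in range(n_runs):
        func(*args)
    return (time.perf_counter() - start) / n_runs


calls = []
avg = bench(lambda n: calls.append(sum(range(n))), 10_000, n_runs=5, warmup_runs=2)
print(len(calls), avg >= 0.0)  # 7 True
```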

solve(a, b)[source]

Solves a linear matrix equation, or system of linear scalar equations.

This function follows the api from numpy.linalg.solve.

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html

trace(a)[source]

Returns the sum along diagonals of the array.

This function follows the api from numpy.trace.

See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html

inv(a)[source]

Computes the inverse of a matrix.

This function follows the api from scipy.linalg.inv.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html

sqrtm(a)[source]

Computes the matrix square root. Requires input to be positive definite.

This function follows the api from scipy.linalg.sqrtm.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html

eigh(a)[source]

Computes the eigenvalues and eigenvectors of a symmetric tensor.

This function follows the api from scipy.linalg.eigh.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html

kl_div(p, q, mass=False, eps=1e-16)[source]

Computes the (Generalized) Kullback-Leibler divergence.

This function follows the api from scipy.stats.entropy.

Parameter eps is used to avoid numerical errors and is added in the log.

\[KL(p,q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + eps) \rangle + \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle\]

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
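The formula can be read term by term: an inner product of p with the logarithm of the ratio (with eps inside the log), plus the total mass difference when mass=True. The sketch below evaluates it for flat sequences in plain Python, as a direct transcription of the formula rather than the backend's code.

```python
import math


def kl_div(p, q, mass=False, eps=1e-16):
    """Generalized KL divergence, with eps added inside the logarithm."""
    value = sum(pi * math.log(pi / qi + eps) for pi, qi in zip(p, q))
    if mass:
        # extra term <q - p, 1>, non-zero only when the total masses differ
        value += sum(qi - pi for pi, qi in zip(p, q))
    return value


p = [0.5, 0.5]
q = [0.25, 0.75]
print(kl_div(p, q))                               # positive for p != q
print(kl_div([1.0, 1.0], [2.0, 2.0], mass=True))  # unnormalized inputs
```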

isfinite(a)[source]

Tests element-wise for finiteness (not infinity and not Not a Number).

This function follows the api from numpy.isfinite.

See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html

array_equal(a, b)[source]

True if two arrays have the same shape and elements, False otherwise.

This function follows the api from numpy.array_equal.

See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html

is_floating_point(a)[source]

Returns whether or not the input consists of floats.

tile(a, reps)[source]

Constructs an array by repeating a the number of times given by reps.

See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html

floor(a)[source]

Returns the floor of the input, element-wise.

See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html

prod(a, axis=0)[source]

Returns the product of the elements along the given axis (axis 0 by default).

See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html

sort2(a, axis=-1)[source]

Returns the sorted array and the indices that sort the array.

See: https://pytorch.org/docs/stable/generated/torch.sort.html
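The two return values are related: indexing the input with the returned indices reproduces the sorted values. A plain-Python sketch for a 1-D sequence follows; it is an illustration of the return shape only.

```python
def sort2(values):
    """Return (sorted values, indices that sort the input) for a 1-D sequence."""
    order = sorted(range(len(values)), key=values.__getitem__)
    return [values[i] for i in order], order


sorted_vals, order = sort2([3.0, 1.0, 2.0])
print(sorted_vals)  # [1.0, 2.0, 3.0]
print(order)        # [1, 2, 0]
```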

qr(a)[source]

Returns the QR factorization.

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html

atan2(a, b)[source]

Element-wise arctangent of a / b, choosing the quadrant correctly.

See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html

transpose(a, axes=None)[source]

Returns a tensor that is a transposed version of a. If axes is None, the order of the axes is reversed; otherwise the axes are permuted according to axes.

See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html

matmul(a, b)[source]

Matrix product of two arrays.

See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul

nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[source]

Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.

See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num

class spateo.alignment.methods.backend_ot.JaxBackend[source]

Bases: Backend

JAX implementation of the backend

  • __name__ is “jax”

  • __type__ is jax.numpy.ndarray

__name__ = 'jax'[source]
__type__[source]
__type_list__ = None[source]
rng_ = None[source]
jax_new_version[source]
_to_numpy(a)[source]

Returns the numpy version of a tensor

_get_device(a)[source]
_change_device(a, type_as)[source]
_from_numpy(a, type_as=None)[source]

Creates a tensor cloning a numpy array, with the given precision (defaulting to input’s precision) and the given device (in case of GPUs)

set_gradients(val, inputs, grads)[source]

Define the gradients for the value val wrt the inputs

_detach(a)[source]

Detach the tensor from the computation graph

zeros(shape, type_as=None)[source]

Creates a tensor full of zeros.

This function follows the api from numpy.zeros

See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html

ones(shape, type_as=None)[source]

Creates a tensor full of ones.

This function follows the api from numpy.ones

See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html

arange(stop, start=0, step=1, type_as=None)[source]

Returns evenly spaced values within a given interval.

This function follows the api from numpy.arange

See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
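Note that the signature lists stop first, whereas numpy.arange takes (start, stop, step). The sketch below is a hypothetical stand-in that mirrors the documented argument order with Python's range; it assumes the values are forwarded as numpy.arange(start, stop, step), which this page does not state explicitly.

```python
def arange(stop, start=0, step=1):
    """Hypothetical stand-in mirroring the documented argument order."""
    return list(range(start, stop, step))


print(arange(5))                    # [0, 1, 2, 3, 4]
print(arange(5, 2))                 # [2, 3, 4]  (stop=5, start=2)
print(arange(10, start=4, step=3))  # [4, 7]
```

Passing start and step by keyword avoids confusion with the NumPy positional order.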

full(shape, fill_value, type_as=None)[source]

Creates a tensor with given shape, filled with given value.

This function follows the api from numpy.full

See: https://numpy.org/doc/stable/reference/generated/numpy.full.html

eye(N, M=None, type_as=None)[source]

Creates the identity matrix of given size.

This function follows the api from numpy.eye

See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html

sum(a, axis=None, keepdims=False)[source]

Sums tensor elements over given dimensions.

This function follows the api from numpy.sum

See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html

cumsum(a, axis=None)[source]

Returns the cumulative sum of tensor elements over given dimensions.

This function follows the api from numpy.cumsum

See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html

max(a, axis=None, keepdims=False)[source]

Returns the maximum of an array or maximum along given dimensions.

This function follows the api from numpy.amax

See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html

min(a, axis=None, keepdims=False)[source]

Returns the minimum of an array or minimum along given dimensions.

This function follows the api from numpy.amin

See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html

maximum(a, b)[source]

Returns element-wise maximum of array elements.

This function follows the api from numpy.maximum

See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html

minimum(a, b)[source]

Returns element-wise minimum of array elements.

This function follows the api from numpy.minimum

See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html

sign(a)[source]

Returns an element-wise indication of the sign of a number.

This function follows the api from numpy.sign

See: https://numpy.org/doc/stable/reference/generated/numpy.sign.html

dot(a, b)[source]

Returns the dot product of two tensors.

This function follows the api from numpy.dot

See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html

abs(a)[source]

Computes the absolute value element-wise.

This function follows the api from numpy.absolute

See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html

exp(a)[source]

Computes the exponential value element-wise.

This function follows the api from numpy.exp

See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html

log(a)[source]

Computes the natural logarithm, element-wise.

This function follows the api from numpy.log

See: https://numpy.org/doc/stable/reference/generated/numpy.log.html

sqrt(a)[source]

Returns the non-negative square root of a tensor, element-wise.

This function follows the api from numpy.sqrt

See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html

power(a, exponents)[source]

First tensor elements raised to powers from second tensor, element-wise.

This function follows the api from numpy.power

See: https://numpy.org/doc/stable/reference/generated/numpy.power.html

norm(a, axis=None, keepdims=False)[source]

Computes the matrix Frobenius norm.

This function follows the api from numpy.linalg.norm

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html

any(a)[source]

Tests whether any tensor element along given dimensions evaluates to True.

This function follows the api from numpy.any

See: https://numpy.org/doc/stable/reference/generated/numpy.any.html

isnan(a)[source]

Tests element-wise for NaN and returns result as a boolean tensor.

This function follows the api from numpy.isnan

See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html

isinf(a)[source]

Tests element-wise for positive or negative infinity and returns result as a boolean tensor.

This function follows the api from numpy.isinf

See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html

einsum(subscripts, *operands)[source]

Evaluates the Einstein summation convention on the operands.

This function follows the api from numpy.einsum

See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html

sort(a, axis=-1)[source]

Returns a sorted copy of a tensor.

This function follows the api from numpy.sort

See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html

argsort(a, axis=-1)[source]

Returns the indices that would sort a tensor.

This function follows the api from numpy.argsort

See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html

searchsorted(a, v, side='left')[source]

Finds indices where elements should be inserted to maintain order in given tensor.

This function follows the api from numpy.searchsorted

See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
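The side argument only matters when v already occurs in a: 'left' returns the first valid insertion index and 'right' the last. For a sorted 1-D list, Python's bisect module has the same semantics, which the sketch below uses as a reference; the wrapper function is illustrative.

```python
import bisect


def searchsorted(a, v, side="left"):
    """Insertion index of v in the sorted list a, as in numpy.searchsorted."""
    if side == "left":
        return bisect.bisect_left(a, v)
    return bisect.bisect_right(a, v)


a = [1, 2, 2, 3]
print(searchsorted(a, 2))                # 1
print(searchsorted(a, 2, side="right"))  # 3
```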

flip(a, axis=None)[source]

Reverses the order of elements in a tensor along given dimensions.

This function follows the api from numpy.flip

See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html

outer(a, b)[source]

Computes the outer product between two vectors.

This function follows the api from numpy.outer

See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html

clip(a, a_min, a_max)[source]

Limits the values in a tensor.

This function follows the api from numpy.clip

See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html

repeat(a, repeats, axis=None)[source]

Repeats elements of a tensor.

This function follows the api from numpy.repeat

See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html

take_along_axis(arr, indices, axis)[source]

Gathers elements of a tensor along given dimensions.

This function follows the api from numpy.take_along_axis

See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
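A common use is to apply the indices returned by argsort: gathering along the same axis reproduces the sorted tensor. The plain-Python sketch below shows this for the last axis of a 2-D nested list; both helper names are hypothetical.

```python
def argsort_rows(matrix):
    """Indices that sort each row (last axis) of a 2-D nested list."""
    return [sorted(range(len(row)), key=row.__getitem__) for row in matrix]


def take_along_last_axis(matrix, indices):
    """Gather matrix[i][indices[i][j]] for every position (i, j)."""
    return [[row[j] for j in idx] for row, idx in zip(matrix, indices)]


m = [[3, 1, 2], [9, 7, 8]]
idx = argsort_rows(m)
print(idx)                           # [[1, 2, 0], [1, 2, 0]]
print(take_along_last_axis(m, idx))  # [[1, 2, 3], [7, 8, 9]]
```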

concatenate(arrays, axis=0)[source]

Joins a sequence of tensors along an existing dimension.

This function follows the api from numpy.concatenate

See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html

zero_pad(a, pad_width, value=0)[source]

Pads a tensor with a given value (0 by default).

This function follows the api from numpy.pad

See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html

argmax(a, axis=None)[source]

Returns the indices of the maximum values of a tensor along given dimensions.

This function follows the api from numpy.argmax

See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html

argmin(a, axis=None)[source]

Returns the indices of the minimum values of a tensor along given dimensions.

This function follows the api from numpy.argmin

See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html

mean(a, axis=None)[source]

Computes the arithmetic mean of a tensor along given dimensions.

This function follows the api from numpy.mean

See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html

median(a, axis=None)[source]

Computes the median of a tensor along given dimensions.

This function follows the api from numpy.median

See: https://numpy.org/doc/stable/reference/generated/numpy.median.html

std(a, axis=None)[source]

Computes the standard deviation of a tensor along given dimensions.

This function follows the api from numpy.std

See: https://numpy.org/doc/stable/reference/generated/numpy.std.html

linspace(start, stop, num, type_as=None)[source]

Returns a specified number of evenly spaced values over a given interval.

This function follows the api from numpy.linspace

See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html

meshgrid(a, b)[source]

Returns coordinate matrices from coordinate vectors (Numpy convention).

This function follows the api from numpy.meshgrid

See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html

diag(a, k=0)[source]

Extracts or constructs a diagonal tensor.

This function follows the api from numpy.diag

See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html

unique(a, return_inverse=False)[source]

Finds unique elements of given tensor.

This function follows the api from numpy.unique

See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
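With return_inverse=True, numpy.unique also returns, for each input element, its position in the sorted array of unique values, so the input can be rebuilt from the two outputs. A plain-Python sketch for a 1-D sequence, for illustration only:

```python
def unique(values, return_inverse=False):
    """Sorted unique values, optionally with the indices that rebuild the input."""
    uniq = sorted(set(values))
    if not return_inverse:
        return uniq
    position = {v: i for i, v in enumerate(uniq)}
    return uniq, [position[v] for v in values]


u, inv = unique([3, 1, 3, 2, 1], return_inverse=True)
print(u)                    # [1, 2, 3]
print(inv)                  # [2, 0, 2, 1, 0]
print([u[i] for i in inv])  # [3, 1, 3, 2, 1]
```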

logsumexp(a, axis=None)[source]

Computes the log of the sum of exponentials of input elements.

This function follows the api from scipy.special.logsumexp

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html

stack(arrays, axis=0)[source]

Joins a sequence of tensors along a new dimension.

This function follows the api from numpy.stack

See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html

reshape(a, shape)[source]

Gives a new shape to a tensor without changing its data.

This function follows the api from numpy.reshape

See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html

seed(seed=None)[source]

Sets the seed for the random generator.

This function follows the api from numpy.random.seed

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html

rand(*size, type_as=None)[source]

Generate uniform random numbers.

This function follows the api from numpy.random.rand

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.rand.html

randn(*size, type_as=None)[source]

Generates standard normal (Gaussian) random numbers.

This function follows the api from numpy.random.randn

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.randn.html

coo_matrix(data, rows, cols, shape=None, type_as=None)[source]

Creates a sparse tensor in COOrdinate format.

This function follows the api from scipy.sparse.coo_matrix

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html

issparse(a)[source]

Checks whether or not the input tensor is a sparse tensor.

This function follows the api from scipy.sparse.issparse

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html

tocsr(a)[source]

Converts this matrix to Compressed Sparse Row format.

This function follows the api from scipy.sparse.coo_matrix.tocsr

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html

eliminate_zeros(a, threshold=0.0)[source]

Removes entries smaller than the given threshold from the sparse tensor.

This function follows the api from scipy.sparse.csr_matrix.eliminate_zeros

See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html

todense(a)[source]

Converts a sparse tensor to a dense tensor.

This function follows the api from scipy.sparse.csr_matrix.toarray

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html

where(condition, x=None, y=None)[source]

Returns elements chosen from x or y depending on condition.

This function follows the api from numpy.where

See: https://numpy.org/doc/stable/reference/generated/numpy.where.html

copy(a)[source]

Returns a copy of the given tensor.

This function follows the api from numpy.copy

See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html

allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)[source]

Returns True if two arrays are element-wise equal within a tolerance.

This function follows the api from numpy.allclose

See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html

dtype_device(a)[source]

Returns the dtype and the device of the given tensor.

assert_same_dtype_device(a, b)[source]

Checks whether or not the two given inputs have the same dtype as well as the same device.

squeeze(a, axis=None)[source]

Remove axes of length one from a.

This function follows the api from numpy.squeeze.

See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html

bitsize(type_as)[source]

Gives the number of bits used by the data type of the given tensor.

device_type(type_as)[source]

Returns CPU or GPU depending on the device where the given tensor is located.

_bench(callable, *args, n_runs=1, warmup_runs=1)[source]

Executes a benchmark of the given callable with the given arguments.

solve(a, b)[source]

Solves a linear matrix equation, or system of linear scalar equations.

This function follows the api from numpy.linalg.solve.

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html

trace(a)[source]

Returns the sum along diagonals of the array.

This function follows the api from numpy.trace.

See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html

inv(a)[source]

Computes the inverse of a matrix.

This function follows the api from scipy.linalg.inv.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html

sqrtm(a)[source]

Computes the matrix square root. Requires input to be positive definite.

This function follows the api from scipy.linalg.sqrtm.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html

eigh(a)[source]

Computes the eigenvalues and eigenvectors of a symmetric tensor.

This function follows the api from scipy.linalg.eigh.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html

kl_div(p, q, mass=False, eps=1e-16)[source]

Computes the (Generalized) Kullback-Leibler divergence.

This function follows the api from scipy.stats.entropy.

Parameter eps is used to avoid numerical errors and is added in the log.

\[KL(p,q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + eps) \rangle + \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle\]

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html

isfinite(a)[source]

Tests element-wise for finiteness (not infinity and not Not a Number).

This function follows the api from numpy.isfinite.

See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html

array_equal(a, b)[source]

True if two arrays have the same shape and elements, False otherwise.

This function follows the api from numpy.array_equal.

See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html

is_floating_point(a)[source]

Returns whether or not the input consists of floats.

tile(a, reps)[source]

Constructs an array by repeating a the number of times given by reps.

See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html
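tile and repeat are easy to confuse: repeat duplicates each element in place, while tile duplicates the whole array. The plain-Python sketch below contrasts the two on a 1-D list with a scalar count; both functions are illustrative stand-ins for the NumPy behaviour.

```python
def repeat(values, repeats):
    """Repeat each element in place, like numpy.repeat on a 1-D array."""
    return [v for v in values for _ in range(repeats)]


def tile(values, reps):
    """Repeat the whole sequence, like numpy.tile on a 1-D array."""
    return list(values) * reps


print(repeat([1, 2], 3))  # [1, 1, 1, 2, 2, 2]
print(tile([1, 2], 3))    # [1, 2, 1, 2, 1, 2]
```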

floor(a)[source]

Returns the floor of the input, element-wise.

See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html

prod(a, axis=0)[source]

Returns the product of the elements along the given axis (axis 0 by default).

See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html

sort2(a, axis=-1)[source]

Returns the sorted array and the indices that sort the array.

See: https://pytorch.org/docs/stable/generated/torch.sort.html

qr(a)[source]

Returns the QR factorization.

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html

atan2(a, b)[source]

Element-wise arctangent of a / b, choosing the quadrant correctly.

See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html

transpose(a, axes=None)[source]

Returns a tensor that is a transposed version of a. If axes is None, the order of the axes is reversed; otherwise the axes are permuted according to axes.

See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html

matmul(a, b)[source]

Matrix product of two arrays.

See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul

nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[source]

Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.

See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num

class spateo.alignment.methods.backend_ot.TorchBackend[source]

Bases: Backend

PyTorch implementation of the backend

  • __name__ is “torch”

  • __type__ is torch.Tensor

__name__ = 'torch'[source]
__type__[source]
__type_list__ = None[source]
rng_ = None[source]
ValFunction[source]
_to_numpy(a)[source]

Returns the numpy version of a tensor

_from_numpy(a, type_as=None)[source]

Creates a tensor cloning a numpy array, with the given precision (defaulting to input’s precision) and the given device (in case of GPUs)

set_gradients(val, inputs, grads)[source]

Define the gradients for the value val wrt the inputs

_detach(a)[source]

Detach the tensor from the computation graph

zeros(shape, type_as=None)[source]

Creates a tensor full of zeros.

This function follows the api from numpy.zeros

See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html

ones(shape, type_as=None)[source]

Creates a tensor full of ones.

This function follows the api from numpy.ones

See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html

arange(stop, start=0, step=1, type_as=None)[source]

Returns evenly spaced values within a given interval.

This function follows the api from numpy.arange

See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html

full(shape, fill_value, type_as=None)[source]

Creates a tensor with given shape, filled with given value.

This function follows the api from numpy.full

See: https://numpy.org/doc/stable/reference/generated/numpy.full.html

eye(N, M=None, type_as=None)[source]

Creates the identity matrix of given size.

This function follows the api from numpy.eye

See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html

sum(a, axis=None, keepdims=False)[source]

Sums tensor elements over given dimensions.

This function follows the api from numpy.sum

See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html
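With keepdims=True the reduced axis is kept with length one, so the result still broadcasts against the input. The plain-Python sketch below shows the shape difference for a reduction over the last axis of a 2-D nested list; sum_axis1 is a hypothetical helper.

```python
def sum_axis1(matrix, keepdims=False):
    """Row sums of a 2-D nested list, optionally keeping the reduced axis."""
    totals = [sum(row) for row in matrix]
    return [[t] for t in totals] if keepdims else totals


m = [[1, 2, 3], [4, 5, 6]]
print(sum_axis1(m))                 # [6, 15]        shape (2,)
print(sum_axis1(m, keepdims=True))  # [[6], [15]]    shape (2, 1)
```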

cumsum(a, axis=None)[source]

Returns the cumulative sum of tensor elements over given dimensions.

This function follows the api from numpy.cumsum

See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html

max(a, axis=None, keepdims=False)[source]

Returns the maximum of an array or maximum along given dimensions.

This function follows the api from numpy.amax

See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html

min(a, axis=None, keepdims=False)[source]

Returns the minimum of an array or minimum along given dimensions.

This function follows the api from numpy.amin

See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html

maximum(a, b)[source]

Returns element-wise maximum of array elements.

This function follows the api from numpy.maximum

See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html

minimum(a, b)[source]

Returns element-wise minimum of array elements.

This function follows the api from numpy.minimum

See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html

sign(a)[source]

Returns an element-wise indication of the sign of a number.

This function follows the api from numpy.sign

See: https://numpy.org/doc/stable/reference/generated/numpy.sign.html

dot(a, b)[source]

Returns the dot product of two tensors.

This function follows the api from numpy.dot

See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html

abs(a)[source]

Computes the absolute value element-wise.

This function follows the api from numpy.absolute

See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html

exp(a)[source]

Computes the exponential value element-wise.

This function follows the api from numpy.exp

See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html

log(a)[source]

Computes the natural logarithm, element-wise.

This function follows the api from numpy.log

See: https://numpy.org/doc/stable/reference/generated/numpy.log.html

sqrt(a)[source]

Returns the non-negative square root of a tensor, element-wise.

This function follows the api from numpy.sqrt

See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html

power(a, exponents)[source]

First tensor elements raised to powers from second tensor, element-wise.

This function follows the api from numpy.power

See: https://numpy.org/doc/stable/reference/generated/numpy.power.html

norm(a, axis=None, keepdims=False)[source]

Computes the matrix Frobenius norm.

This function follows the api from numpy.linalg.norm

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html

any(a)[source]

Tests whether any tensor element along given dimensions evaluates to True.

This function follows the api from numpy.any

See: https://numpy.org/doc/stable/reference/generated/numpy.any.html

isnan(a)[source]

Tests element-wise for NaN and returns result as a boolean tensor.

This function follows the api from numpy.isnan

See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html

isinf(a)[source]

Tests element-wise for positive or negative infinity and returns result as a boolean tensor.

This function follows the api from numpy.isinf

See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html

einsum(subscripts, *operands)[source]

Evaluates the Einstein summation convention on the operands.

This function follows the api from numpy.einsum

See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html

sort(a, axis=-1)[source]

Returns a sorted copy of a tensor.

This function follows the api from numpy.sort

See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html

argsort(a, axis=-1)[source]

Returns the indices that would sort a tensor.

This function follows the api from numpy.argsort

See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html

searchsorted(a, v, side='left')[source]

Finds indices where elements should be inserted to maintain order in given tensor.

This function follows the api from numpy.searchsorted

See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html

flip(a, axis=None)[source]

Reverses the order of elements in a tensor along given dimensions.

This function follows the api from numpy.flip

See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html

outer(a, b)[source]

Computes the outer product between two vectors.

This function follows the api from numpy.outer

See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html

clip(a, a_min, a_max)[source]

Limits the values in a tensor.

This function follows the api from numpy.clip

See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html

repeat(a, repeats, axis=None)[source]

Repeats elements of a tensor.

This function follows the api from numpy.repeat

See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html

take_along_axis(arr, indices, axis)[source]

Gathers elements of a tensor along given dimensions.

This function follows the api from numpy.take_along_axis

See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html

concatenate(arrays, axis=0)[source]

Joins a sequence of tensors along an existing dimension.

This function follows the api from numpy.concatenate

See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html

zero_pad(a, pad_width, value=0)[source]

Pads a tensor with a given value (0 by default).

This function follows the api from numpy.pad

See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
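In numpy.pad, pad_width gives the number of values added before and after each axis. The sketch below covers only the 1-D case, assuming pad_width is a single (before, after) pair; zero_pad_1d is a hypothetical helper, not the backend method.

```python
def zero_pad_1d(values, pad_width, value=0):
    """Pad a 1-D sequence with `value`; pad_width is a (before, after) pair."""
    before, after = pad_width
    return [value] * before + list(values) + [value] * after


print(zero_pad_1d([1, 2, 3], (2, 1)))            # [0, 0, 1, 2, 3, 0]
print(zero_pad_1d([1, 2, 3], (1, 1), value=-1))  # [-1, 1, 2, 3, -1]
```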

argmax(a, axis=None)[source]

Returns the indices of the maximum values of a tensor along given dimensions.

This function follows the api from numpy.argmax

See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html

argmin(a, axis=None)[source]

Returns the indices of the minimum values of a tensor along given dimensions.

This function follows the api from numpy.argmin

See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html

mean(a, axis=None)[source]

Computes the arithmetic mean of a tensor along given dimensions.

This function follows the api from numpy.mean

See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html

median(a, axis=None)[source]

Computes the median of a tensor along given dimensions.

This function follows the api from numpy.median

See: https://numpy.org/doc/stable/reference/generated/numpy.median.html

std(a, axis=None)[source]

Computes the standard deviation of a tensor along given dimensions.

This function follows the api from numpy.std

See: https://numpy.org/doc/stable/reference/generated/numpy.std.html

linspace(start, stop, num, type_as=None)[source]

Returns a specified number of evenly spaced values over a given interval.

This function follows the api from numpy.linspace

See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html

meshgrid(a, b)[source]

Returns coordinate matrices from coordinate vectors (Numpy convention).

This function follows the api from numpy.meshgrid

See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
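
To make the "Numpy convention" concrete, here is a sketch using numpy.meshgrid with its default indexing. It assumes the convention meant is NumPy's default indexing="xy", which is not stated explicitly in this entry.

```python
import numpy as np

a = np.array([0.0, 1.0, 2.0])  # x coordinates, length 3
b = np.array([10.0, 20.0])     # y coordinates, length 2

# With the default indexing="xy", both outputs have shape (len(b), len(a)).
X, Y = np.meshgrid(a, b)
```

Each row of X repeats a and each column of Y repeats b, so (X[i, j], Y[i, j]) is the grid point (a[j], b[i]).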

diag(a, k=0)[source]

Extracts or constructs a diagonal tensor.

This function follows the api from numpy.diag

See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html

unique(a, return_inverse=False)[source]

Finds unique elements of given tensor.

This function follows the api from numpy.unique

See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html

logsumexp(a, axis=None)[source]

Computes the log of the sum of exponentials of input elements.

This function follows the api from scipy.special.logsumexp

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
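
The reason for a dedicated logsumexp is numerical stability. The sketch below shows the usual max-shift trick in plain NumPy; logsumexp_sketch is a hypothetical helper written for illustration, not the backend's implementation, and it assumes finite inputs.

```python
import numpy as np

def logsumexp_sketch(a, axis=None):
    """Hypothetical helper: log(sum(exp(a))) computed without overflow."""
    a = np.asarray(a, dtype=float)
    # Subtracting the maximum keeps every exponent <= 0, so exp() cannot overflow.
    a_max = np.max(a, axis=axis, keepdims=True)
    out = a_max + np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

big = np.array([1000.0, 1000.0])
stable = logsumexp_sketch(big)  # np.log(np.sum(np.exp(big))) would overflow to inf
```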

stack(arrays, axis=0)[source]

Joins a sequence of tensors along a new dimension.

This function follows the api from numpy.stack

See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html

reshape(a, shape)[source]

Gives a new shape to a tensor without changing its data.

This function follows the api from numpy.reshape

See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html

seed(seed=None)[source]

Sets the seed for the random generator.

This function follows the api from numpy.random.seed

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html

rand(*size, type_as=None)[source]

Generate uniform random numbers.

This function follows the api from numpy.random.rand

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.rand.html

randn(*size, type_as=None)[source]

Generates standard normal (Gaussian) random numbers.

This function follows the api from numpy.random.randn

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.randn.html

coo_matrix(data, rows, cols, shape=None, type_as=None)[source]

Creates a sparse tensor in COOrdinate format.

This function follows the api from scipy.sparse.coo_matrix

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html

issparse(a)[source]

Checks whether or not the input tensor is a sparse tensor.

This function follows the api from scipy.sparse.issparse

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html

tocsr(a)[source]

Converts this matrix to Compressed Sparse Row format.

This function follows the api from scipy.sparse.coo_matrix.tocsr

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html

eliminate_zeros(a, threshold=0.0)[source]

Removes entries smaller than the given threshold from the sparse tensor.

This function follows the api from scipy.sparse.csr_matrix.eliminate_zeros

See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html

todense(a)[source]

Converts a sparse tensor to a dense tensor.

This function follows the api from scipy.sparse.csr_matrix.toarray

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html

where(condition, x=None, y=None)[source]

Returns elements chosen from x or y depending on condition.

This function follows the api from numpy.where

See: https://numpy.org/doc/stable/reference/generated/numpy.where.html

copy(a)[source]

Returns a copy of the given tensor.

This function follows the api from numpy.copy

See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html

allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)[source]

Returns True if two arrays are element-wise equal within a tolerance.

This function follows the api from numpy.allclose

See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
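
The tolerance test behind numpy.allclose is |a - b| <= atol + rtol * |b|. The sketch below spells that out for finite inputs; allclose_sketch is a hypothetical helper and it ignores the equal_nan and infinity handling of the real function.

```python
import numpy as np

def allclose_sketch(a, b, rtol=1e-05, atol=1e-08):
    """Hypothetical helper: element-wise tolerance check for finite inputs."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # The relative term scales with b only, so the test is not symmetric in a and b.
    return bool(np.all(np.abs(a - b) <= atol + rtol * np.abs(b)))

close = allclose_sketch([1.0, 2.0], [1.0 + 1e-7, 2.0])
far = allclose_sketch([1.0, 2.0], [1.1, 2.0])
```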

dtype_device(a)[source]

Returns the dtype and the device of the given tensor.

assert_same_dtype_device(a, b)[source]

Checks whether or not the two given inputs have the same dtype as well as the same device

squeeze(a, axis=None)[source]

Remove axes of length one from a.

This function follows the api from numpy.squeeze.

See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html

bitsize(type_as)[source]

Gives the number of bits used by the data type of the given tensor.

device_type(type_as)[source]

Returns CPU or GPU depending on the device where the given tensor is located.

_bench(callable, *args, n_runs=1, warmup_runs=1)[source]

Executes a benchmark of the given callable with the given arguments.

solve(a, b)[source]

Solves a linear matrix equation, or system of linear scalar equations.

This function follows the api from numpy.linalg.solve.

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html

trace(a)[source]

Returns the sum along diagonals of the array.

This function follows the api from numpy.trace.

See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html

inv(a)[source]

Computes the inverse of a matrix.

This function follows the api from scipy.linalg.inv.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html

sqrtm(a)[source]

Computes the matrix square root. Requires the input to be positive definite.

This function follows the api from scipy.linalg.sqrtm.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html
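
For a symmetric positive definite input, one common way to obtain the matrix square root is through the eigendecomposition. The sketch below is an assumption about how this can be done in plain NumPy, not a statement of how any backend implements sqrtm; sqrtm_sketch is a hypothetical name.

```python
import numpy as np

def sqrtm_sketch(a):
    """Hypothetical helper: square root of a symmetric positive definite matrix."""
    # a = V diag(w) V^T with w > 0, so sqrt(a) = V diag(sqrt(w)) V^T.
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(w)) @ v.T

a = np.array([[2.0, 1.0], [1.0, 2.0]])
s = sqrtm_sketch(a)  # s @ s reproduces a up to rounding
```

If the input has a negative eigenvalue, np.sqrt returns NaN in this sketch, which is why positive definiteness matters.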

eigh(a)[source]

Computes the eigenvalues and eigenvectors of a symmetric tensor.

This function follows the api from scipy.linalg.eigh.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html

kl_div(p, q, mass=False, eps=1e-16)[source]

Computes the (Generalized) Kullback-Leibler divergence.

This function follows the api from scipy.stats.entropy.

Parameter eps is used to avoid numerical errors and is added in the log.

\[KL(p,q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + eps) \rangle + \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle\]

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
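
The formula above can be sketched in plain NumPy as follows. kl_div_sketch is a hypothetical helper that follows the formula term by term; it assumes strictly positive q and is not taken from any backend's source.

```python
import numpy as np

def kl_div_sketch(p, q, mass=False, eps=1e-16):
    """Hypothetical helper: (generalized) KL divergence as written in the formula."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # <p, log(p / q + eps)>: eps is added inside the log.
    value = np.sum(p * np.log(p / q + eps))
    if mass:
        # Generalized KL adds the mass difference <q - p, 1>.
        value = value + np.sum(q - p)
    return value

kl_plain = kl_div_sketch([0.5, 0.5], [0.25, 0.75])
kl_mass = kl_div_sketch([1.0, 1.0], [2.0, 2.0], mass=True)
```

When p and q both sum to one, the mass term vanishes and the two variants agree.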

isfinite(a)[source]

Tests element-wise for finiteness (not infinity and not Not a Number).

This function follows the api from numpy.isfinite.

See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html

array_equal(a, b)[source]

True if two arrays have the same shape and elements, False otherwise.

This function follows the api from numpy.array_equal.

See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html

is_floating_point(a)[source]

Returns whether or not the input consists of floats

tile(a, reps)[source]

Construct an array by repeating a the number of times given by reps

See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html

floor(a)[source]

Return the floor of the input element-wise

See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html

prod(a, axis=0)[source]

Return the product of all elements.

See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html

sort2(a, axis=-1)[source]

Return the sorted array and the indices to sort the array

See: https://pytorch.org/docs/stable/generated/torch.sort.html
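
NumPy has no single call that returns both the sorted values and the indices, so the behaviour described here can be sketched by combining argsort with take_along_axis. sort2_sketch is a hypothetical helper, shown only to illustrate the return value.

```python
import numpy as np

def sort2_sketch(a, axis=-1):
    """Hypothetical helper: return (sorted values, sorting indices) along axis."""
    idx = np.argsort(a, axis=axis)
    # Gathering with the argsort indices yields the sorted array.
    return np.take_along_axis(a, idx, axis=axis), idx

a = np.array([[3, 1, 2], [9, 7, 8]])
values, indices = sort2_sketch(a)
```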

qr(a)[source]

Return the QR factorization

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html

atan2(a, b)[source]

Element-wise arctangent of a / b, choosing the quadrant correctly.

See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html

transpose(a, axes=None)[source]

Returns a tensor that is a transposed version of a. By default (axes=None) the order of the dimensions is reversed; otherwise the dimensions are permuted according to axes.

See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html

matmul(a, b)[source]

Matrix product of two arrays.

See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul

nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[source]

Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.

See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
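
A short sketch of the default and user-defined replacements, using numpy.nan_to_num directly on the assumption that the backend method follows it; the replacement values 1e6 and -1e6 are arbitrary.

```python
import numpy as np

x = np.array([np.nan, np.inf, -np.inf, 1.5])

# Defaults: NaN -> 0.0, +inf -> largest finite float, -inf -> most negative finite float.
default = np.nan_to_num(x)

# User-defined replacements for each special value.
custom = np.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
```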

class spateo.alignment.methods.backend_ot.CupyBackend[source]

Bases: Backend

CuPy implementation of the backend

  • __name__ is “cupy”

  • __type__ is cp.ndarray

__name__ = 'cupy'[source]
__type__[source]
__type_list__ = None[source]
rng_ = None[source]
_to_numpy(a)[source]

Returns the numpy version of a tensor

_from_numpy(a, type_as=None)[source]

Creates a tensor cloning a numpy array, with the given precision (defaulting to input’s precision) and the given device (in case of GPUs)

set_gradients(val, inputs, grads)[source]

Define the gradients for the value val wrt the inputs

_detach(a)[source]

Detach the tensor from the computation graph

zeros(shape, type_as=None)[source]

Creates a tensor full of zeros.

This function follows the api from numpy.zeros

See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html

ones(shape, type_as=None)[source]

Creates a tensor full of ones.

This function follows the api from numpy.ones

See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html

arange(stop, start=0, step=1, type_as=None)[source]

Returns evenly spaced values within a given interval.

This function follows the api from numpy.arange

See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
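
Note that the signature above takes stop first and start second, the reverse of numpy.arange's positional order. The sketch below is a hypothetical NumPy stand-in with the same argument order, meant only to show how positional calls read; it is not the CuPy implementation.

```python
import numpy as np

def arange_sketch(stop, start=0, step=1):
    """Hypothetical stand-in with the backend's argument order (stop first)."""
    return np.arange(start, stop, step)

up_to_five = arange_sketch(5)    # 0, 1, 2, 3, 4
two_to_five = arange_sketch(5, 2)  # 2, 3, 4
numpy_order = np.arange(5, 2)    # empty: NumPy reads this as start=5, stop=2
```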

full(shape, fill_value, type_as=None)[source]

Creates a tensor with given shape, filled with given value.

This function follows the api from numpy.full

See: https://numpy.org/doc/stable/reference/generated/numpy.full.html

eye(N, M=None, type_as=None)[source]

Creates the identity matrix of given size.

This function follows the api from numpy.eye

See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html

sum(a, axis=None, keepdims=False)[source]

Sums tensor elements over given dimensions.

This function follows the api from numpy.sum

See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html

cumsum(a, axis=None)[source]

Returns the cumulative sum of tensor elements over given dimensions.

This function follows the api from numpy.cumsum

See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html

max(a, axis=None, keepdims=False)[source]

Returns the maximum of an array or maximum along given dimensions.

This function follows the api from numpy.amax

See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html

min(a, axis=None, keepdims=False)[source]

Returns the minimum of an array or minimum along given dimensions.

This function follows the api from numpy.amin

See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html

maximum(a, b)[source]

Returns element-wise maximum of array elements.

This function follows the api from numpy.maximum

See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html

minimum(a, b)[source]

Returns element-wise minimum of array elements.

This function follows the api from numpy.minimum

See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html

sign(a)[source]

Returns an element-wise indication of the sign of a number.

This function follows the api from numpy.sign

See: https://numpy.org/doc/stable/reference/generated/numpy.sign.html

abs(a)[source]

Computes the absolute value element-wise.

This function follows the api from numpy.absolute

See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html

exp(a)[source]

Computes the exponential value element-wise.

This function follows the api from numpy.exp

See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html

log(a)[source]

Computes the natural logarithm, element-wise.

This function follows the api from numpy.log

See: https://numpy.org/doc/stable/reference/generated/numpy.log.html

sqrt(a)[source]

Returns the non-negative square root of a tensor, element-wise.

This function follows the api from numpy.sqrt

See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html

power(a, exponents)[source]

First tensor elements raised to powers from second tensor, element-wise.

This function follows the api from numpy.power

See: https://numpy.org/doc/stable/reference/generated/numpy.power.html

dot(a, b)[source]

Returns the dot product of two tensors.

This function follows the api from numpy.dot

See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html

norm(a, axis=None, keepdims=False)[source]

Computes the matrix Frobenius norm.

This function follows the api from numpy.linalg.norm

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
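
The Frobenius norm is the square root of the sum of squared entries. The sketch below checks that definition against numpy.linalg.norm on a NumPy array; it assumes the CuPy method behaves the same way on cp.ndarray inputs.

```python
import numpy as np

a = np.array([[3.0, 4.0], [0.0, 0.0]])

# Definition: sqrt of the sum of squared entries.
fro_manual = np.sqrt(np.sum(a ** 2))

# numpy.linalg.norm with default arguments gives the same value for a matrix.
fro_numpy = np.linalg.norm(a)

# With axis given, the 2-norm is taken along that axis instead.
row_norms = np.linalg.norm(a, axis=1)
```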

any(a)[source]

Tests whether any tensor element along given dimensions evaluates to True.

This function follows the api from numpy.any

See: https://numpy.org/doc/stable/reference/generated/numpy.any.html

isnan(a)[source]

Tests element-wise for NaN and returns result as a boolean tensor.

This function follows the api from numpy.isnan

See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html

isinf(a)[source]

Tests element-wise for positive or negative infinity and returns result as a boolean tensor.

This function follows the api from numpy.isinf

See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html

einsum(subscripts, *operands)[source]

Evaluates the Einstein summation convention on the operands.

This function follows the api from numpy.einsum

See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html

sort(a, axis=-1)[source]

Returns a sorted copy of a tensor.

This function follows the api from numpy.sort

See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html

argsort(a, axis=-1)[source]

Returns the indices that would sort a tensor.

This function follows the api from numpy.argsort

See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html

searchsorted(a, v, side='left')[source]

Finds indices where elements should be inserted to maintain order in given tensor.

This function follows the api from numpy.searchsorted

See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html

flip(a, axis=None)[source]

Reverses the order of elements in a tensor along given dimensions.

This function follows the api from numpy.flip

See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html

outer(a, b)[source]

Computes the outer product between two vectors.

This function follows the api from numpy.outer

See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html

clip(a, a_min, a_max)[source]

Limits the values in a tensor.

This function follows the api from numpy.clip

See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html

repeat(a, repeats, axis=None)[source]

Repeats elements of a tensor.

This function follows the api from numpy.repeat

See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html

take_along_axis(arr, indices, axis)[source]

Gathers elements of a tensor along given dimensions.

This function follows the api from numpy.take_along_axis

See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
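
A typical use is gathering the values that an index-returning function points at. The sketch below uses NumPy as a stand-in for the CuPy backend and picks each row's minimum via its argmin index; the array is made up for the example.

```python
import numpy as np

a = np.array([[4.0, 1.0, 3.0], [2.0, 5.0, 0.0]])

idx = np.argmin(a, axis=1)  # one column index per row
# indices must have the same number of dimensions as a, hence idx[:, None].
row_min = np.take_along_axis(a, idx[:, None], axis=1)[:, 0]
```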

concatenate(arrays, axis=0)[source]

Joins a sequence of tensors along an existing dimension.

This function follows the api from numpy.concatenate

See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html

zero_pad(a, pad_width, value=0)[source]

Pads a tensor with a given value (0 by default).

This function follows the api from numpy.pad

See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html

argmax(a, axis=None)[source]

Returns the indices of the maximum values of a tensor along given dimensions.

This function follows the api from numpy.argmax

See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html

argmin(a, axis=None)[source]

Returns the indices of the minimum values of a tensor along given dimensions.

This function follows the api from numpy.argmin

See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html

mean(a, axis=None)[source]

Computes the arithmetic mean of a tensor along given dimensions.

This function follows the api from numpy.mean

See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html

median(a, axis=None)[source]

Computes the median of a tensor along given dimensions.

This function follows the api from numpy.median

See: https://numpy.org/doc/stable/reference/generated/numpy.median.html

std(a, axis=None)[source]

Computes the standard deviation of a tensor along given dimensions.

This function follows the api from numpy.std

See: https://numpy.org/doc/stable/reference/generated/numpy.std.html

linspace(start, stop, num, type_as=None)[source]

Returns a specified number of evenly spaced values over a given interval.

This function follows the api from numpy.linspace

See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html

meshgrid(a, b)[source]

Returns coordinate matrices from coordinate vectors (Numpy convention).

This function follows the api from numpy.meshgrid

See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html

diag(a, k=0)[source]

Extracts or constructs a diagonal tensor.

This function follows the api from numpy.diag

See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html

unique(a, return_inverse=False)[source]

Finds unique elements of given tensor.

This function follows the api from numpy.unique

See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
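
With return_inverse=True the call also returns, for every input element, its position in the array of unique values. The sketch below shows this with numpy.unique on a one-dimensional array, assuming the backend method returns the same pair.

```python
import numpy as np

a = np.array([3, 1, 3, 2, 1])

# values is sorted; inverse maps each element of a to its index in values.
values, inverse = np.unique(a, return_inverse=True)

# Indexing the unique values with the inverse rebuilds the original array.
rebuilt = values[inverse]
```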

logsumexp(a, axis=None)[source]

Computes the log of the sum of exponentials of input elements.

This function follows the api from scipy.special.logsumexp

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html

stack(arrays, axis=0)[source]

Joins a sequence of tensors along a new dimension.

This function follows the api from numpy.stack

See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html

reshape(a, shape)[source]

Gives a new shape to a tensor without changing its data.

This function follows the api from numpy.reshape

See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html

seed(seed=None)[source]

Sets the seed for the random generator.

This function follows the api from numpy.random.seed

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html

rand(*size, type_as=None)[source]

Generate uniform random numbers.

This function follows the api from numpy.random.rand

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.rand.html

randn(*size, type_as=None)[source]

Generates standard normal (Gaussian) random numbers.

This function follows the api from numpy.random.randn

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.randn.html

coo_matrix(data, rows, cols, shape=None, type_as=None)[source]

Creates a sparse tensor in COOrdinate format.

This function follows the api from scipy.sparse.coo_matrix

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html

issparse(a)[source]

Checks whether or not the input tensor is a sparse tensor.

This function follows the api from scipy.sparse.issparse

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html

tocsr(a)[source]

Converts this matrix to Compressed Sparse Row format.

This function follows the api from scipy.sparse.coo_matrix.tocsr

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html

eliminate_zeros(a, threshold=0.0)[source]

Removes entries smaller than the given threshold from the sparse tensor.

This function follows the api from scipy.sparse.csr_matrix.eliminate_zeros

See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html

todense(a)[source]

Converts a sparse tensor to a dense tensor.

This function follows the api from scipy.sparse.csr_matrix.toarray

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html
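
The sparse helpers above (coo_matrix, issparse, tocsr, eliminate_zeros, todense) are usually chained. The sketch below walks through that chain with scipy.sparse on the CPU. The thresholding step is an assumption about how entries below threshold are dropped, and the 1e-9 threshold is arbitrary.

```python
import numpy as np
from scipy.sparse import coo_matrix, issparse

data = np.array([1.0, 1e-12, 2.0])
rows = np.array([0, 1, 2])
cols = np.array([0, 1, 2])

# Build in COO format, then convert to CSR for efficient row operations.
m = coo_matrix((data, (rows, cols)), shape=(3, 3)).tocsr()
is_sparse = issparse(m)

# Assumed threshold step: zero out tiny stored entries, then drop explicit zeros.
m.data[np.abs(m.data) <= 1e-9] = 0.0
m.eliminate_zeros()

dense = m.toarray()  # back to a dense array
```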

where(condition, x=None, y=None)[source]

Returns elements chosen from x or y depending on condition.

This function follows the api from numpy.where

See: https://numpy.org/doc/stable/reference/generated/numpy.where.html

copy(a)[source]

Returns a copy of the given tensor.

This function follows the api from numpy.copy

See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html

allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)[source]

Returns True if two arrays are element-wise equal within a tolerance.

This function follows the api from numpy.allclose

See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html

dtype_device(a)[source]

Returns the dtype and the device of the given tensor.

assert_same_dtype_device(a, b)[source]

Checks whether or not the two given inputs have the same dtype as well as the same device

squeeze(a, axis=None)[source]

Remove axes of length one from a.

This function follows the api from numpy.squeeze.

See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html

bitsize(type_as)[source]

Gives the number of bits used by the data type of the given tensor.

device_type(type_as)[source]

Returns CPU or GPU depending on the device where the given tensor is located.

_bench(callable, *args, n_runs=1, warmup_runs=1)[source]

Executes a benchmark of the given callable with the given arguments.

solve(a, b)[source]

Solves a linear matrix equation, or system of linear scalar equations.

This function follows the api from numpy.linalg.solve.

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html

trace(a)[source]

Returns the sum along diagonals of the array.

This function follows the api from numpy.trace.

See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html

inv(a)[source]

Computes the inverse of a matrix.

This function follows the api from scipy.linalg.inv.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html

sqrtm(a)[source]

Computes the matrix square root. Requires the input to be positive definite.

This function follows the api from scipy.linalg.sqrtm.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html

eigh(a)[source]

Computes the eigenvalues and eigenvectors of a symmetric tensor.

This function follows the api from scipy.linalg.eigh.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html

kl_div(p, q, mass=False, eps=1e-16)[source]

Computes the (Generalized) Kullback-Leibler divergence.

This function follows the api from scipy.stats.entropy.

Parameter eps is used to avoid numerical errors and is added in the log.

\[KL(p,q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + eps) \rangle + \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle\]

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html

isfinite(a)[source]

Tests element-wise for finiteness (not infinity and not Not a Number).

This function follows the api from numpy.isfinite.

See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html

array_equal(a, b)[source]

True if two arrays have the same shape and elements, False otherwise.

This function follows the api from numpy.array_equal.

See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html

is_floating_point(a)[source]

Returns whether or not the input consists of floats

tile(a, reps)[source]

Construct an array by repeating a the number of times given by reps

See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html

floor(a)[source]

Return the floor of the input element-wise

See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html

prod(a, axis=0)[source]

Return the product of all elements.

See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html

sort2(a, axis=-1)[source]

Return the sorted array and the indices to sort the array

See: https://pytorch.org/docs/stable/generated/torch.sort.html

qr(a)[source]

Return the QR factorization

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html

atan2(a, b)[source]

Element-wise arctangent of a / b, choosing the quadrant correctly.

See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html

transpose(a, axes=None)[source]

Returns a tensor that is a transposed version of a. By default (axes=None) the order of the dimensions is reversed; otherwise the dimensions are permuted according to axes.

See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html

matmul(a, b)[source]

Matrix product of two arrays.

See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul

nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[source]

Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.

See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num

class spateo.alignment.methods.backend_ot.TensorflowBackend[source]

Bases: Backend

TensorFlow implementation of the backend

  • The __name__ class attribute refers to the name of the backend.

  • The __type__ class attribute refers to the data structure used by the backend.

__name__ = 'tf'[source]
__type__[source]
__type_list__ = None[source]
rng_ = None[source]
_to_numpy(a)[source]

Returns the numpy version of a tensor

_from_numpy(a, type_as=None)[source]

Creates a tensor cloning a numpy array, with the given precision (defaulting to input’s precision) and the given device (in case of GPUs)

set_gradients(val, inputs, grads)[source]

Define the gradients for the value val wrt the inputs

_detach(a)[source]

Detach the tensor from the computation graph

zeros(shape, type_as=None)[source]

Creates a tensor full of zeros.

This function follows the api from numpy.zeros

See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html

ones(shape, type_as=None)[source]

Creates a tensor full of ones.

This function follows the api from numpy.ones

See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html

arange(stop, start=0, step=1, type_as=None)[source]

Returns evenly spaced values within a given interval.

This function follows the api from numpy.arange

See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html

full(shape, fill_value, type_as=None)[source]

Creates a tensor with given shape, filled with given value.

This function follows the api from numpy.full

See: https://numpy.org/doc/stable/reference/generated/numpy.full.html

eye(N, M=None, type_as=None)[source]

Creates the identity matrix of given size.

This function follows the api from numpy.eye

See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html

sum(a, axis=None, keepdims=False)[source]

Sums tensor elements over given dimensions.

This function follows the api from numpy.sum

See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html

cumsum(a, axis=None)[source]

Returns the cumulative sum of tensor elements over given dimensions.

This function follows the api from numpy.cumsum

See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html

max(a, axis=None, keepdims=False)[source]

Returns the maximum of an array or maximum along given dimensions.

This function follows the api from numpy.amax

See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html

min(a, axis=None, keepdims=False)[source]

Returns the minimum of an array or minimum along given dimensions.

This function follows the api from numpy.amin

See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html

maximum(a, b)[source]

Returns element-wise maximum of array elements.

This function follows the api from numpy.maximum

See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html

minimum(a, b)[source]

Returns element-wise minimum of array elements.

This function follows the api from numpy.minimum

See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html

sign(a)[source]

Returns an element-wise indication of the sign of a number.

This function follows the api from numpy.sign

See: https://numpy.org/doc/stable/reference/generated/numpy.sign.html

dot(a, b)[source]

Returns the dot product of two tensors.

This function follows the api from numpy.dot

See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html

abs(a)[source]

Computes the absolute value element-wise.

This function follows the api from numpy.absolute

See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html

exp(a)[source]

Computes the exponential value element-wise.

This function follows the api from numpy.exp

See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html

log(a)[source]

Computes the natural logarithm, element-wise.

This function follows the api from numpy.log

See: https://numpy.org/doc/stable/reference/generated/numpy.log.html

sqrt(a)[source]

Returns the non-negative square root of a tensor, element-wise.

This function follows the api from numpy.sqrt

See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html

power(a, exponents)[source]

First tensor elements raised to powers from second tensor, element-wise.

This function follows the api from numpy.power

See: https://numpy.org/doc/stable/reference/generated/numpy.power.html

norm(a, axis=None, keepdims=False)[source]

Computes the matrix Frobenius norm.

This function follows the api from numpy.linalg.norm

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html

any(a)[source]

Tests whether any tensor element along given dimensions evaluates to True.

This function follows the api from numpy.any

See: https://numpy.org/doc/stable/reference/generated/numpy.any.html

isnan(a)[source]

Tests element-wise for NaN and returns result as a boolean tensor.

This function follows the api from numpy.isnan

See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html

isinf(a)[source]

Tests element-wise for positive or negative infinity and returns result as a boolean tensor.

This function follows the api from numpy.isinf

See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html

einsum(subscripts, *operands)[source]

Evaluates the Einstein summation convention on the operands.

This function follows the api from numpy.einsum

See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
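
As an illustration of the subscripts string, the sketch below computes the inner product between a cost matrix and a coupling, a quantity that appears often in optimal transport. It uses numpy.einsum as a stand-in for the TensorFlow backend, and M and G are made-up values.

```python
import numpy as np

M = np.array([[0.0, 1.0], [2.0, 3.0]])  # illustrative cost matrix
G = np.array([[0.5, 0.0], [0.0, 0.5]])  # illustrative coupling

# "ij,ij->" multiplies element-wise and sums over both indices: <M, G>.
cost = np.einsum("ij,ij->", M, G)

# "ij,jk->ik" sums over the shared index j: an ordinary matrix product.
product = np.einsum("ij,jk->ik", M, G)
```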

sort(a, axis=-1)[source]

Returns a sorted copy of a tensor.

This function follows the api from numpy.sort

See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html

argsort(a, axis=-1)[source]

Returns the indices that would sort a tensor.

This function follows the api from numpy.argsort

See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html

searchsorted(a, v, side='left')[source]

Finds indices where elements should be inserted to maintain order in given tensor.

This function follows the api from numpy.searchsorted

See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
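
The effect of side is easiest to see when the sorted tensor contains ties. A minimal sketch with numpy.searchsorted, on made-up values:

```python
import numpy as np

a = np.array([1.0, 2.0, 2.0, 5.0])  # must already be sorted
v = np.array([2.0, 3.0])

# side="left" returns the first valid insertion position,
# side="right" the last one (after any equal elements).
left = np.searchsorted(a, v, side="left")
right = np.searchsorted(a, v, side="right")
```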

flip(a, axis=None)[source]

Reverses the order of elements in a tensor along given dimensions.

This function follows the api from numpy.flip

See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html

outer(a, b)[source]

Computes the outer product between two vectors.

This function follows the api from numpy.outer

See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html

clip(a, a_min, a_max)[source]

Limits the values in a tensor.

This function follows the api from numpy.clip

See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html

repeat(a, repeats, axis=None)[source]

Repeats elements of a tensor.

This function follows the api from numpy.repeat

See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html

take_along_axis(arr, indices, axis)[source]

Gathers elements of a tensor along given dimensions.

This function follows the api from numpy.take_along_axis

See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
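
A typical use is to pick out the values that an index-returning function such as argmax points at. The sketch below uses NumPy directly; the data are illustrative.

```python
import numpy as np

a = np.array([[3, 1, 2], [7, 9, 8]])

# argmax gives one index per row; take_along_axis expects the index
# array to have the same number of dimensions as `a`.
idx = np.argmax(a, axis=1)[:, None]
row_max = np.take_along_axis(a, idx, axis=1)
```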

concatenate(arrays, axis=0)[source]

Joins a sequence of tensors along an existing dimension.

This function follows the api from numpy.concatenate

See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html

zero_pad(a, pad_width, value=0)[source]

Pads a tensor with a given value (0 by default).

This function follows the api from numpy.pad

See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
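
The pad_width argument takes one (before, after) pair per dimension. A sketch with numpy.pad in constant mode, which is assumed here to be the behaviour zero_pad corresponds to:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])

# One row before axis 0, two columns after axis 1.
pad_width = ((1, 0), (0, 2))
padded = np.pad(a, pad_width, mode="constant", constant_values=0)
```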

argmax(a, axis=None)[source]

Returns the indices of the maximum values of a tensor along given dimensions.

This function follows the api from numpy.argmax

See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html

argmin(a, axis=None)[source]

Returns the indices of the minimum values of a tensor along given dimensions.

This function follows the api from numpy.argmin

See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html

mean(a, axis=None)[source]

Computes the arithmetic mean of a tensor along given dimensions.

This function follows the api from numpy.mean

See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html

median(a, axis=None)[source]

Computes the median of a tensor along given dimensions.

This function follows the api from numpy.median

See: https://numpy.org/doc/stable/reference/generated/numpy.median.html

std(a, axis=None)[source]

Computes the standard deviation of a tensor along given dimensions.

This function follows the api from numpy.std

See: https://numpy.org/doc/stable/reference/generated/numpy.std.html

linspace(start, stop, num, type_as=None)[source]

Returns a specified number of evenly spaced values over a given interval.

This function follows the api from numpy.linspace

See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html

meshgrid(a, b)[source]

Returns coordinate matrices from coordinate vectors (Numpy convention).

This function follows the api from numpy.meshgrid

See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
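
Under the NumPy convention, the outputs have shape (len(b), len(a)): the first input varies along columns and the second along rows. A small sketch with numpy.meshgrid:

```python
import numpy as np

a = np.array([0, 1, 2])
b = np.array([10, 20])

X, Y = np.meshgrid(a, b)
# X repeats `a` on every row, Y repeats `b` down every column.
```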

diag(a, k=0)[source]

Extracts or constructs a diagonal tensor.

This function follows the api from numpy.diag

See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html

unique(a, return_inverse=False)[source]

Finds unique elements of given tensor.

This function follows the api from numpy.unique

See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
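
With return_inverse=True, the second output maps each original element to its position among the unique values. A sketch with numpy.unique on a 1-D example:

```python
import numpy as np

a = np.array([3, 1, 3, 2, 1])

u, inv = np.unique(a, return_inverse=True)
# Indexing the unique values with the inverse rebuilds the input.
rebuilt = u[inv]
```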

logsumexp(a, axis=None)[source]

Computes the log of the sum of exponentials of input elements.

This function follows the api from scipy.special.logsumexp

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
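
To show why a dedicated function is useful, here is a minimal NumPy-only sketch of the usual max-shift trick. The helper name logsumexp_sketch is hypothetical, it assumes finite inputs, and it is not claimed to be how any backend implements the method.

```python
import numpy as np

def logsumexp_sketch(a, axis=None):
    """Stable log(sum(exp(a))) via subtraction of the maximum (finite inputs assumed)."""
    a = np.asarray(a, dtype=float)
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out.item() if axis is None else np.squeeze(out, axis=axis)

# exp(1000) overflows in float64, but the shifted form stays finite.
x = np.array([1000.0, 1000.0])
result = logsumexp_sketch(x)
```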

stack(arrays, axis=0)[source]

Joins a sequence of tensors along a new dimension.

This function follows the api from numpy.stack

See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html

reshape(a, shape)[source]

Gives a new shape to a tensor without changing its data.

This function follows the api from numpy.reshape

See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html

seed(seed=None)[source]

Sets the seed for the random generator.

This function follows the api from numpy.random.seed

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html

rand(*size, type_as=None)[source]

Generate uniform random numbers.

This function follows the api from numpy.random.rand

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.rand.html

randn(*size, type_as=None)[source]

Generate normal Gaussian random numbers.

This function follows the api from numpy.random.randn

See: https://numpy.org/doc/stable/reference/random/generated/numpy.random.randn.html

_convert_to_index_for_coo(tensor)[source]

coo_matrix(data, rows, cols, shape=None, type_as=None)[source]

Creates a sparse tensor in COOrdinate format.

This function follows the api from scipy.sparse.coo_matrix

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html

issparse(a)[source]

Checks whether or not the input tensor is a sparse tensor.

This function follows the api from scipy.sparse.issparse

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html

tocsr(a)[source]

Converts this matrix to Compressed Sparse Row format.

This function follows the api from scipy.sparse.coo_matrix.tocsr

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html

eliminate_zeros(a, threshold=0.0)[source]

Removes entries smaller than the given threshold from the sparse tensor.

This function follows the api from scipy.sparse.csr_matrix.eliminate_zeros

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html

todense(a)[source]

Converts a sparse tensor to a dense tensor.

This function follows the api from scipy.sparse.csr_matrix.toarray

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html

where(condition, x=None, y=None)[source]

Returns elements chosen from x or y depending on condition.

This function follows the api from numpy.where

See: https://numpy.org/doc/stable/reference/generated/numpy.where.html
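
The two calling forms behave differently, which this sketch with numpy.where illustrates on made-up data:

```python
import numpy as np

x = np.array([-1.0, 0.5, 2.0])

# Three-argument form: element-wise selection between x and 0.0.
non_negative = np.where(x > 0, x, 0.0)

# One-argument form (x=None, y=None): a tuple of index arrays
# locating the True entries of the condition.
idx = np.where(x > 0)
```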

copy(a)[source]

Returns a copy of the given tensor.

This function follows the api from numpy.copy

See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html

allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)[source]

Returns True if two arrays are element-wise equal within a tolerance.

This function follows the api from numpy.allclose

See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
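
In numpy.allclose, two entries count as equal when |a - b| <= atol + rtol * |b|, so the relative term dominates for large values. A short sketch with illustrative numbers:

```python
import numpy as np

a = np.array([1.0, 100.0])
b = np.array([1.0 + 1e-9, 100.0 + 1e-4])

# Default tolerances: the 1e-4 gap on the second entry is within rtol * 100.
default_ok = np.allclose(a, b)

# With rtol=0 only the absolute tolerance applies, and 1e-4 exceeds 1e-8.
strict_ok = np.allclose(a, b, rtol=0.0, atol=1e-8)
```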

dtype_device(a)[source]

Returns the dtype and the device of the given tensor.

assert_same_dtype_device(a, b)[source]

Checks whether or not the two given inputs have the same dtype as well as the same device.

squeeze(a, axis=None)[source]

Remove axes of length one from a.

This function follows the api from numpy.squeeze.

See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html

bitsize(type_as)[source]

Gives the number of bits used by the data type of the given tensor.

device_type(type_as)[source]

Returns CPU or GPU depending on the device where the given tensor is located.

_bench(callable, *args, n_runs=1, warmup_runs=1)[source]

Executes a benchmark of the given callable with the given arguments.

solve(a, b)[source]

Solves a linear matrix equation, or system of linear scalar equations.

This function follows the api from numpy.linalg.solve.

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html
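
A minimal sketch with numpy.linalg.solve, the API being followed, on a small made-up system A x = b:

```python
import numpy as np

A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([9.0, 8.0])

# Solves A @ x = b without forming the inverse of A explicitly.
x = np.linalg.solve(A, b)
```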

trace(a)[source]

Returns the sum along diagonals of the array.

This function follows the api from numpy.trace.

See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html

inv(a)[source]

Computes the inverse of a matrix.

This function follows the api from scipy.linalg.inv.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html

sqrtm(a)[source]

Computes the matrix square root. Requires the input to be positive definite.

This function follows the api from scipy.linalg.sqrtm.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html
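
For a symmetric positive (semi-)definite input, one common way to obtain the square root is through an eigendecomposition. The sketch below shows that approach with NumPy; the helper name sqrtm_psd is hypothetical, and this is not claimed to be how each backend computes it.

```python
import numpy as np

def sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh(a)   # a = v @ diag(w) @ v.T
    w = np.clip(w, 0.0, None)  # guard against tiny negative eigenvalues from round-off
    return (v * np.sqrt(w)) @ v.T

a = np.array([[2.0, 1.0], [1.0, 2.0]])
s = sqrtm_psd(a)
```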

eigh(a)[source]

Computes the eigenvalues and eigenvectors of a symmetric tensor.

This function follows the api from scipy.linalg.eigh.

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.eigh.html

kl_div(p, q, mass=False, eps=1e-16)[source]

Computes the (Generalized) Kullback-Leibler divergence.

This function follows the api from scipy.stats.entropy.

Parameter eps is used to avoid numerical errors and is added in the log.

\[KL(p,q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + eps) \rangle + \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle\]

See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
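
The formula above can be transcribed almost literally. This is a NumPy sketch of the expression as written, under the assumption that p and q are strictly positive arrays of the same shape; the helper name kl_div_sketch is hypothetical.

```python
import numpy as np

def kl_div_sketch(p, q, mass=False, eps=1e-16):
    """<p, log(p / q + eps)>, plus <q - p, 1> when mass=True."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    value = np.sum(p * np.log(p / q + eps))
    if mass:
        # Generalized KL: correction term for inputs with different total mass.
        value = value + np.sum(q - p)
    return value

p = np.array([0.5, 0.5])
q = np.array([0.25, 0.75])
kl = kl_div_sketch(p, q)
```

When p and q both sum to one, the mass term vanishes and the two variants agree.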

isfinite(a)[source]

Tests element-wise for finiteness (not infinity and not Not a Number).

This function follows the api from numpy.isfinite.

See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html

array_equal(a, b)[source]

True if two arrays have the same shape and elements, False otherwise.

This function follows the api from numpy.array_equal.

See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html

is_floating_point(a)[source]

Returns whether or not the input consists of floats.

tile(a, reps)[source]

Constructs an array by repeating a the number of times given by reps.

See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html

floor(a)[source]

Returns the floor of the input, element-wise.

See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html

prod(a, axis=0)[source]

Return the product of all elements.

See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html

sort2(a, axis=-1)[source]

Returns the sorted array and the indices that sort the array.

See: https://pytorch.org/docs/stable/generated/torch.sort.html
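
NumPy has no single call returning both outputs, so an equivalent can be sketched by combining argsort and take_along_axis. The helper name sort2_sketch is hypothetical and the data are illustrative.

```python
import numpy as np

def sort2_sketch(a, axis=-1):
    """Return (sorted values, sorting indices), mirroring torch.sort's outputs."""
    idx = np.argsort(a, axis=axis)
    return np.take_along_axis(a, idx, axis=axis), idx

a = np.array([[3, 1, 2], [9, 7, 8]])
values, indices = sort2_sketch(a)
```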

qr(a)[source]

Returns the QR factorization.

See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html

atan2(a, b)[source]

Element-wise arctangent of a / b, choosing the quadrant correctly.

See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html

transpose(a, axes=None)[source]

Returns a tensor that is a transposed version of a. If axes is None, the order of the dimensions is reversed; otherwise the dimensions are permuted according to axes.

See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
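
The role of axes is clearest on a tensor with more than two dimensions. A sketch with numpy.transpose on an example shape:

```python
import numpy as np

a = np.zeros((2, 3, 4))

# axes=None reverses the order of the dimensions.
reversed_axes = np.transpose(a)

# An explicit permutation: swap the first two dimensions, keep the last.
swapped = np.transpose(a, axes=(1, 0, 2))
```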

matmul(a, b)[source]

Matrix product of two arrays.

See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul

nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[source]

Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.

See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
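
A short sketch with numpy.nan_to_num showing user-defined replacements; the replacement values are arbitrary choices for the example.

```python
import numpy as np

x = np.array([np.nan, np.inf, -np.inf, 1.0])

# With posinf/neginf left as None, infinities would instead become the
# largest/smallest finite values of the dtype.
cleaned = np.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
```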