View & Manipulation Operations#

reshape#

def reshape(arg: nabla.core.tensor.Tensor, shape: tuple[int, ...]) -> nabla.core.tensor.Tensor:

Reshape a tensor to the given shape.


transpose#

def transpose(arg: nabla.core.tensor.Tensor, axis_1: int = -2, axis_2: int = -1) -> nabla.core.tensor.Tensor:

Swap two axes of a tensor. By default (axis_1=-2, axis_2=-1) the last two axes are swapped.


permute#

def permute(input_tensor: nabla.core.tensor.Tensor, axes: tuple[int, ...]) -> nabla.core.tensor.Tensor:

Permute (reorder) the dimensions of a tensor.

Examples

>>> x = nb.ones((2, 3, 4))  # shape (2, 3, 4)
>>> y = permute(x, (2, 0, 1))  # shape (4, 2, 3)
>>> # Output dim i is input dim axes[i]: dim 2 -> position 0, dim 0 -> position 1, dim 1 -> position 2

concatenate#

def concatenate(args: list[nabla.core.tensor.Tensor], axis: int = 0) -> nabla.core.tensor.Tensor:

Concatenate tensors along an existing axis.

Parameters#

args: List of tensors to concatenate
axis: Axis along which to concatenate tensors (default: 0)

Returns#

Concatenated tensor

stack#

def stack(tensors: list[nabla.core.tensor.Tensor], axis: int = 0) -> nabla.core.tensor.Tensor:

Stack tensors along a new axis.

Parameters#

tensors: List of tensors to stack
axis: Axis along which to stack the tensors (default: 0)

Returns#

Stacked tensor

split#

def split(arg: nabla.core.tensor.Tensor, sizes: list[int], axis: int = 0) -> list[nabla.core.tensor.Tensor]:

Split a tensor into multiple sub-tensors along a specified axis.

Parameters#

arg: Input tensor to split
sizes: List of sizes for each split along the specified axis
axis: Axis along which to split the tensor (default: 0)

Returns#

List of sub-tensors resulting from the split

squeeze#

def squeeze(arg: nabla.core.tensor.Tensor, axes: list[int] | None = None) -> nabla.core.tensor.Tensor:

Squeeze tensor by removing dimensions of size 1.


unsqueeze#

def unsqueeze(arg: nabla.core.tensor.Tensor, axes: list[int] | None = None) -> nabla.core.tensor.Tensor:

Unsqueeze tensor by adding dimensions of size 1.


pad#

def pad(arg: nabla.core.tensor.Tensor, slices: list[slice], target_shape: tuple[int, ...]) -> nabla.core.tensor.Tensor:

Place a smaller tensor into a larger zero-filled tensor at the location specified by slices.

This is the counterpart of tensor slicing: given slices, a small tensor, and a target shape, it creates a larger tensor where the small tensor is placed at the sliced location and everything else is zero. Slicing the result with the same slices recovers the input tensor.

Parameters#

arg: Input tensor (the smaller tensor to be placed)
slices: List of slice objects defining where to place the tensor
target_shape: The shape of the output tensor

Returns#

Larger tensor with input placed at sliced location, zeros elsewhere

broadcast_to#

def broadcast_to(arg: nabla.core.tensor.Tensor, shape: tuple[int, ...]) -> nabla.core.tensor.Tensor:

Broadcast a tensor to the target shape.


tensor_slice#

def tensor_slice(arg: nabla.core.tensor.Tensor, slices: list[slice], squeeze_axes: list[int] | None = None) -> nabla.core.tensor.Tensor:

Slice a tensor along the specified dimensions.

Parameters#

arg: Input tensor to slice
slices: List of slice objects defining the slicing for each dimension
squeeze_axes: List of axes that should be squeezed (for JAX compatibility)

Returns#

Sliced tensor

shallow_copy#

def shallow_copy(arg: nabla.core.tensor.Tensor) -> nabla.core.tensor.Tensor:

Create a shallow copy of the tensor.


transpose_batch_dims#

def transpose_batch_dims(arg: nabla.core.tensor.Tensor, axis_1: int = -2, axis_2: int = -1) -> nabla.core.tensor.Tensor:

Swap two batch dimensions of a tensor. By default (axis_1=-2, axis_2=-1) the last two batch dimensions are swapped.

This operation swaps two axes in the batch_dims of a Tensor, similar to how regular transpose works on shape dimensions. The shape dimensions remain unchanged.

Examples

>>> import nabla as nb
>>> # Tensor with batch_dims=(2, 3, 4) and shape=(5, 6)
>>> x = nb.ones((5, 6))
>>> x.batch_dims = (2, 3, 4)  # Simulated for example
>>> y = transpose_batch_dims(x, -3, -1)  # Swap first and last batch dims
>>> # Result has batch_dims=(4, 3, 2) and shape=(5, 6)

permute_batch_dims#

def permute_batch_dims(input_tensor: nabla.core.tensor.Tensor, axes: tuple[int, ...]) -> nabla.core.tensor.Tensor:

Permute (reorder) the batch dimensions of a tensor.

This operation reorders the batch_dims of a Tensor according to the given axes, similar to how regular permute works on shape dimensions. The shape dimensions remain unchanged.

Examples

>>> import nabla as nb
>>> # Tensor with batch_dims=(2, 3, 4) and shape=(5, 6)
>>> x = nb.ones((5, 6))
>>> x.batch_dims = (2, 3, 4)  # Simulated for example
>>> y = permute_batch_dims(x, (-1, -3, -2))  # Reorder as (4, 2, 3)
>>> # Result has batch_dims=(4, 2, 3) and shape=(5, 6)

broadcast_batch_dims#

def broadcast_batch_dims(arg: nabla.core.tensor.Tensor, batch_dims: tuple[int, ...]) -> nabla.core.tensor.Tensor:

Broadcast a tensor to the target batch_dims.


squeeze_batch_dims#

def squeeze_batch_dims(arg: nabla.core.tensor.Tensor, axes: list[int] | None = None) -> nabla.core.tensor.Tensor:

Squeeze tensor by removing batch dimensions of size 1.

Parameters#

arg: Input tensor
axes: List of batch dimension axes to squeeze. If None, returns tensor unchanged.

Returns#

Tensor with specified batch dimensions of size 1 removed

unsqueeze_batch_dims#

def unsqueeze_batch_dims(arg: nabla.core.tensor.Tensor, axes: list[int] | None = None) -> nabla.core.tensor.Tensor:

Unsqueeze tensor by adding batch dimensions of size 1.

Parameters#

arg: Input tensor
axes: List of positions where to insert batch dimensions of size 1.
      If None, returns tensor unchanged.

Returns#

Tensor with batch dimensions of size 1 added at specified positions