Indexing Operations

Contents

Indexing Operations#

gather#

def gather(input_tensor: nabla.core.tensor.Tensor, indices: nabla.core.tensor.Tensor, axis: int = -1) -> nabla.core.tensor.Tensor:

Selects elements from an input tensor using indices along a specified axis.

This function is analogous to numpy.take_along_axis. It selects elements from input_tensor at the positions specified by indices.

Parameters

  • input_tensor : Tensor – The source tensor from which to gather values.

  • indices : Tensor – The tensor of indices to gather. Must be an integer-typed tensor.

  • axis : int, optional, default: -1 – The axis along which to gather. A negative value counts from the last dimension. Defaults to -1.

Returns

Tensor – A new tensor containing the elements of input_tensor at the given indices.

Examples

>>> import nabla as nb
>>> x = nb.tensor([[10, 20, 30], [40, 50, 60]])
>>> indices = nb.tensor([[0, 2], [1, 0]])
>>> # Gather along axis 1
>>> nb.gather(x, indices, axis=1)
Tensor([[10, 30],
       [50, 40]], dtype=int32)
>>> # Gather along axis 0
>>> indices = nb.tensor([[0, 1, 0]])
>>> nb.gather(x, indices, axis=0)
Tensor([[10, 50, 30]], dtype=int32)

scatter#

def scatter(target_shape: tuple, indices: nabla.core.tensor.Tensor, values: nabla.core.tensor.Tensor, axis: int = -1) -> nabla.core.tensor.Tensor:

Updates an tensor of zeros with given values at specified indices.

This function creates an tensor of shape target_shape filled with zeros and then places the values at the locations specified by indices along the given axis. This operation is the inverse of gather.

Parameters

  • target_shape : tuple – The shape of the output tensor.

  • indices : Tensor – An integer tensor specifying the indices to update.

  • values : Tensor – The tensor of values to scatter into the new tensor.

  • axis : int, optional, default: -1 – The axis along which to scatter. A negative value counts from the last dimension. Defaults to -1.

Returns

Tensor – A new tensor of shape target_shape with values scattered at the specified indices.

Examples

>>> import nabla as nb
>>> target_shape = (3, 4)
>>> indices = nb.tensor([0, 2, 1])
>>> values = nb.tensor([10, 20, 30])
>>> # Scatter values into a 1D target
>>> nb.scatter((4,), nb.tensor([0, 3, 1]), nb.tensor([1, 2, 3]))
Tensor([1, 3, 0, 2], dtype=int32)
>>> # Scatter rows into a 2D target along axis 0
>>> values_2d = nb.tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]])
>>> nb.scatter(target_shape, indices, values_2d, axis=0)
Tensor([[1, 1, 1, 1],
       [3, 3, 3, 3],
       [2, 2, 2, 2]], dtype=int32)