View & Shape#
reshape#
def reshape(x: 'Tensor', shape: 'tuple[int, ...]') -> 'Tensor':
Return a tensor with the same data as x reshaped to shape.
The total number of elements must be preserved. A single -1
in shape is automatically inferred.
Parameters
x– Input tensor.shape– Target shape.
Returns
Tensor with the same data and new shape.
transpose#
def transpose(x: 'Tensor', axis1: 'int', axis2: 'int') -> 'Tensor':
Swap (transpose) two dimensions of x.
Parameters
x– Input tensor.axis1– First axis. Supports negative indexing.axis2– Second axis. Supports negative indexing.
Returns
View with axis1 and axis2 swapped.
permute#
def permute(x: 'Tensor', order: 'tuple[int, ...]') -> 'Tensor':
Reorder the dimensions of x according to order.
Parameters
x– Input tensor of rankN.order– A permutation of(0, 1, ..., N-1)giving the new dimension ordering. Equivalent to NumPy’stranspose(axes=order).
Returns
Tensor with dimensions reordered as specified.
unsqueeze#
def unsqueeze(x: 'Tensor', axis: 'int' = 0) -> 'Tensor':
Insert a size-1 dimension at axis into x’s shape.
Parameters
x– Input tensor.axis– Position at which to insert the new dimension. Supports negative indexing.
Returns
Tensor with one additional dimension of size 1.
squeeze#
def squeeze(x: 'Tensor', axis: 'int' = 0) -> 'Tensor':
Remove the size-1 dimension at axis from x’s shape.
Parameters
x– Input tensor. The dimension at axis must be 1.axis– Dimension to remove. Supports negative indexing. PassNoneto squeeze all size-1 dimensions.
Returns
Tensor with the specified dimension removed.
flatten#
def flatten(x: 'Tensor', start_dim: 'int' = 0, end_dim: 'int' = -1) -> 'Tensor':
Flatten a contiguous range of dimensions into one.
Parameters
x– Input tensor.start_dim– First dimension to flatten (inclusive). Default:0.end_dim– Last dimension to flatten (inclusive). Default:-1(last dimension).
Returns
Tensor with dimensions start_dim through end_dim collapsed
into a single dimension.
broadcast_to#
def broadcast_to(x: 'Tensor', shape: 'tuple[int, ...]') -> 'Tensor':
Broadcast x to a new shape.
Leading dimensions are added as needed (NumPy-style broadcasting). Size-1 dimensions in x are expanded to match shape.
Parameters
x– Input tensor.shape– Target output shape. Must be broadcast-compatible withx.shape.
Returns
Tensor of the given shape sharing data with x where possible.
concatenate#
def concatenate(tensors: 'Sequence[Tensor]', axis: 'int' = 0) -> 'Tensor':
Concatenate a sequence of tensors along an existing axis.
Parameters
tensors– Non-empty sequence of tensors with the same shape except along axis.axis– Axis along which to concatenate. Default:0.
Returns
Tensor whose axis dimension is the sum of the inputs’.
stack#
def stack(tensors: 'list[Tensor]', axis: 'int' = 0) -> 'Tensor':
Stack a sequence of tensors along a new axis.
All tensors must have the same shape. The result has one more dimension than the inputs.
Parameters
tensors– List of tensors with identical shapes.axis– Position of the new dimension in the output. Default:0.
Returns
Tensor of shape tensors[0].shape[:axis] + (N,) + tensors[0].shape[axis:]
where N is the number of input tensors.
gather#
def gather(x: 'Tensor', indices: 'Tensor', axis: 'int' = 0) -> 'Tensor':
Gather elements from x along axis using indices.
Equivalent to x.take(indices, axis=axis) in NumPy.
Parameters
x– Source tensor.indices– Integer tensor of indices. Shape can differ from x.axis– Axis along which to index x. Default:0.
Returns
Tensor with the same dtype as x and shape
x.shape[:axis] + indices.shape + x.shape[axis+1:].
scatter#
def scatter(x: 'Tensor', indices: 'Tensor', updates: 'Tensor', axis: 'int' = 0) -> 'Tensor':
Scatter updates into x at indices along axis.
Functional (out-of-place). Equivalent to
out = x.copy(); out[indices] = updates along axis.
Parameters
x– Base tensor that receives the scattered values.indices– Integer index tensor specifying where to write.updates– Values to write into x. Must be compatible withx.shape[:axis] + indices.shape + x.shape[axis+1:].axis– Axis along which to scatter. Default:0.
Returns
New tensor equal to x except at positions specified by indices.
slice_tensor#
def slice_tensor(x: 'Tensor', start: 'Any', size: 'Any') -> 'Tensor':
Extract a rectangular slice from x.
Parameters
x– Input tensor.start– Sequence of per-dimension start indices (supports negative).size– Sequence of per-dimension slice sizes.
Returns
Tensor of shape size containing the requested slice.
slice_update#
def slice_update(x: 'Tensor', update: 'Tensor', start: 'Any', size: 'Any') -> 'Tensor':
Return x with a rectangular region replaced by update.
This is a functional (out-of-place) operation. The original x is not modified. Supports autograd.
Parameters
x– Base tensor to update.update– Values to write into x. Must have shape size.start– Sequence of per-dimension start indices.size– Sequence of per-dimension region sizes.
Returns
New tensor equal to x except at the specified slice.
moveaxis#
def moveaxis(x: 'Tensor', source: 'int', destination: 'int') -> 'Tensor':
Move axis source to position destination.
Parameters
x– Input tensor.source– Original axis position. Supports negative indexing.destination– Target axis position. Supports negative indexing.
Returns
Tensor with the axis at source moved to destination.
swap_axes#
def swap_axes(x: 'Tensor', axis1: 'int', axis2: 'int') -> 'Tensor':
Swap (transpose) two dimensions of x.
Parameters
x– Input tensor.axis1– First axis. Supports negative indexing.axis2– Second axis. Supports negative indexing.
Returns
View with axis1 and axis2 swapped.
flip#
def flip(x: 'Tensor', axis: 'int') -> 'Tensor':
Reverse the elements of x along the specified axis.
Parameters
x– Input tensor.axis– The axis along which to reverse. Supports negative indexing.
Returns
Tensor with elements reversed along axis. Shape is unchanged.
pad#
def pad(x: 'Tensor', paddings: 'list[tuple[int, int]]' = None, mode: 'str' = 'constant', value: 'float' = 0.0, **kwargs) -> 'Tensor':
Pad a tensor with a constant value (or a specific padding mode).
Parameters
x– Input tensor.paddings– List of(before, after)tuples, one per logical dimension. Also accepted via thepad_widthkeyword alias.mode– Padding mode. Currently only"constant"is supported.value– Fill value for constant padding. Default:0.0.
Returns
Padded tensor. Each dimension i grows by
paddings[i][0] + paddings[i][1] elements.
rebind#
def rebind(x: 'Tensor', shape: 'tuple[int, ...]', **kwargs) -> 'Tensor':
Rebind a tensor to a new symbolic shape without changing the data.
Used to introduce or update shape constraints known at compile time. Has no gradient — the cotangent is passed through unchanged.
Parameters
x– Input tensor.shape– New shape annotation (can include symbolic dimensions).
Returns
Tensor with updated shape metadata.
as_interleaved_complex#
def as_interleaved_complex(x: 'Tensor') -> 'Tensor':
Reinterpret a real tensor with last dim 2 as a complex tensor.
Parameters
x– Real tensor of shape(..., 2).
Returns
Complex-valued tensor of shape (...).
view_as_real_interleaved#
def view_as_real_interleaved(x: 'Tensor') -> 'Tensor':
Reinterpret a complex tensor as a real tensor with an extra trailing 2-dim.
Parameters
x– Complex tensor of shape(...).
Returns
Real tensor of shape (..., 2) where the last axis contains
[real, imag] components.
broadcast_to_physical#
def broadcast_to_physical(x: 'Tensor', shape: 'tuple[int, ...]') -> 'Tensor':
Broadcast x to shape in the physical tensor layout.
Unlike :func:broadcast_to, this operates on the physical shape
(including batch dimensions added by vmap). Used internally by
transforms and physical gradient rules.
Parameters
x– Input tensor.shape– Target physical shape.
Returns
Tensor broadcast to the given physical shape.
squeeze_physical#
def squeeze_physical(x: 'Tensor', axis: 'int' = 0) -> 'Tensor':
Remove the size-1 dimension at axis in the physical tensor layout.
Counterpart to :func:unsqueeze_physical. Used internally by transforms.
unsqueeze_physical#
def unsqueeze_physical(x: 'Tensor', axis: 'int' = 0) -> 'Tensor':
Insert a size-1 dimension at axis in the physical tensor layout.
Unlike :func:unsqueeze, this operates on the physical shape (which
includes batch dimensions added by vmap). Used internally by
transforms that manipulate the physical layout directly.
incr_batch_dims#
def incr_batch_dims(x: 'Tensor') -> 'Tensor':
Increment batch_dims counter (first physical dim becomes batch dim).
decr_batch_dims#
def decr_batch_dims(x: 'Tensor') -> 'Tensor':
Decrement batch_dims counter (first batch dim becomes logical dim).
move_axis_to_batch_dims#
def move_axis_to_batch_dims(x: 'Tensor', axis: 'int') -> 'Tensor':
Move a logical axis into the batch dimensions (3 ops: calc + moveaxis_physical + incr).
move_axis_from_batch_dims#
def move_axis_from_batch_dims(x: 'Tensor', batch_axis: 'int' = 0, logical_destination: 'int' = 0) -> 'Tensor':
Move a batch dimension to logical axis (3 ops: calc + moveaxis_physical + decr).