TensorView API Reference#

This topic is the API reference for the TensorView utility, which provides zero-copy view operations on NKI tensors.

When to Use#

Use TensorView when you need to:

  • Reshape without copying: Change tensor layout for different computation phases

  • Slice with strides: Extract non-contiguous elements efficiently

  • Permute dimensions: Transpose or reorder dimensions for matmul compatibility

  • Broadcast dimensions: Expand size-1 dimensions without data duplication

  • Chain operations: Combine multiple view transformations fluently

TensorView is essential for kernels that need to interpret the same data in multiple layouts (e.g., attention kernels that reshape between [B, S, H] and [B, num_heads, S, head_dim]).

API Reference#

Source code: aws-neuron/nki-library

TensorView#

class nkilib.core.utils.tensor_view.TensorView(base_tensor)#

A view wrapper around NKI tensors supporting various operations without copying data.

Parameters:

base_tensor (nl.ndarray) – The underlying NKI tensor.

shape: tuple[int, ...]#

Current shape of the view.

strides: tuple[int, ...]#

Stride of each dimension in elements.
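To make "stride in elements" concrete, the following pure-Python sketch computes row-major strides for a shape and the flat element offset that an index addresses. The helpers `contiguous_strides` and `element_offset` are hypothetical names for illustration and are not part of TensorView. The sketch assumes a densely packed row-major layout and ignores the special role of the SBUF partition dimension.

```python
def contiguous_strides(shape):
    """Row-major strides, in elements, for a densely packed tensor."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def element_offset(index, strides):
    """Flat element offset addressed by a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


shape = (128, 24, 64)
strides = contiguous_strides(shape)           # (1536, 64, 1)
offset = element_offset((1, 2, 3), strides)   # 1*1536 + 2*64 + 3*1 = 1667
```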

get_view()#

Generates the NKI tensor for this view using an array pattern.

Returns:

NKI tensor with the view pattern applied.

Return type:

nl.ndarray

slice(dim, start, end, step=1)#

Creates a sliced view along a dimension.

Parameters:
  • dim (int) – Dimension to slice.

  • start (int) – Start index (inclusive).

  • end (int) – End index (exclusive).

  • step (int) – Step size. Default 1.

Returns:

New TensorView with sliced dimension.

Return type:

TensorView
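As a rough model of what a strided slice does to view metadata, the sketch below updates a (shape, strides, offset) triple without touching any data. `slice_view` is a hypothetical helper, not the TensorView implementation. It assumes a positive step and strides measured in elements.

```python
def slice_view(shape, strides, offset, dim, start, end, step=1):
    """Return (shape, strides, offset) after slicing [start, end) with `step` along `dim`."""
    if step <= 0:
        raise ValueError("step must be positive in this sketch")
    if not 0 <= start <= end <= shape[dim]:
        raise ValueError("slice bounds out of range")
    new_size = (end - start + step - 1) // step  # ceil((end - start) / step)
    new_shape = shape[:dim] + (new_size,) + shape[dim + 1:]
    # Stepping over elements multiplies the stride of the sliced dimension.
    new_strides = strides[:dim] + (strides[dim] * step,) + strides[dim + 1:]
    # The first selected element moves the base offset forward.
    new_offset = offset + start * strides[dim]
    return new_shape, new_strides, new_offset


# Every other column of a (128, 256) row-major tensor.
sliced = slice_view((128, 256), (256, 1), 0, dim=1, start=0, end=256, step=2)
# sliced == ((128, 128), (256, 2), 0)
```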

permute(dims)#

Creates a permuted view by reordering dimensions.

Parameters:

dims (tuple[int, ...]) – New order of dimensions.

Returns:

New TensorView with permuted dimensions.

Return type:

TensorView

Note: For SBUF tensors, the partition dimension (dim 0) must remain at position 0.
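Conceptually, a permutation only reorders the shape and stride tuples. The sketch below shows that bookkeeping, with an optional check that mirrors the SBUF constraint in the note above. `permute_view` and its `partition_fixed` flag are hypothetical and are not TensorView's actual signature.

```python
def permute_view(shape, strides, dims, partition_fixed=False):
    """Return (shape, strides) with dimensions reordered according to `dims`."""
    if sorted(dims) != list(range(len(shape))):
        raise ValueError("dims must be a permutation of 0..ndim-1")
    if partition_fixed and dims[0] != 0:
        raise ValueError("partition dimension (dim 0) must stay at position 0")
    new_shape = tuple(shape[d] for d in dims)
    new_strides = tuple(strides[d] for d in dims)
    return new_shape, new_strides


# Swap the two middle dimensions of a (128, 4, 6, 64) row-major view.
permuted = permute_view((128, 4, 6, 64), (1536, 384, 64, 1), (0, 2, 1, 3),
                        partition_fixed=True)
# permuted == ((128, 6, 4, 64), (1536, 64, 384, 1))
```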

broadcast(dim, size)#

Expands a size-1 dimension to a larger size without copying.

Parameters:
  • dim (int) – Dimension to broadcast (must have size 1).

  • size (int) – New size for the dimension.

Returns:

New TensorView with broadcasted dimension.

Return type:

TensorView
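A common way to expand a size-1 dimension without copying, used for example by NumPy, is to give that dimension a stride of 0 so that every index along it maps to the same element. Whether TensorView represents broadcasts this way internally is an assumption. `broadcast_view` below is a hypothetical helper that illustrates the idea.

```python
def broadcast_view(shape, strides, dim, size):
    """Return (shape, strides) with a size-1 dimension expanded to `size`."""
    if shape[dim] != 1:
        raise ValueError("only size-1 dimensions can be broadcast")
    new_shape = shape[:dim] + (size,) + shape[dim + 1:]
    # Stride 0: moving along the broadcast dimension never changes the address.
    new_strides = strides[:dim] + (0,) + strides[dim + 1:]
    return new_shape, new_strides


def element_offset(index, strides):
    """Flat element offset addressed by a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


b_shape, b_strides = broadcast_view((128, 1, 64), (64, 64, 1), dim=1, size=32)
# b_shape == (128, 32, 64), b_strides == (64, 0, 1)
# Indices that differ only along the broadcast dimension share one element.
same = element_offset((5, 0, 7), b_strides) == element_offset((5, 31, 7), b_strides)
```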

reshape_dim(dim, shape)#

Reshapes a single dimension into multiple dimensions.

Parameters:
  • dim (int) – Dimension to reshape.

  • shape (tuple[int, ...]) – New sizes (can contain one -1 for inference).

Returns:

New TensorView with reshaped dimension.

Return type:

TensorView
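The sketch below models how one dimension could be split into several, including inference of a single -1 entry. `reshape_dim_view` is a hypothetical helper written for illustration. It assumes the innermost new dimension keeps the original stride and each outer one is scaled by the sizes inside it.

```python
def reshape_dim_view(shape, strides, dim, new_sizes):
    """Return (shape, strides) after splitting `dim` into `new_sizes`."""
    sizes = list(new_sizes)
    if sizes.count(-1) > 1:
        raise ValueError("at most one size may be -1")
    if -1 in sizes:
        known = 1
        for s in sizes:
            if s != -1:
                known *= s
        if known == 0 or shape[dim] % known != 0:
            raise ValueError("cannot infer the -1 size")
        sizes[sizes.index(-1)] = shape[dim] // known
    total = 1
    for s in sizes:
        total *= s
    if total != shape[dim]:
        raise ValueError("new sizes must multiply to the original size")
    # Walk from the innermost new dimension outward, accumulating strides.
    sub_strides = []
    running = strides[dim]
    for s in reversed(sizes):
        sub_strides.append(running)
        running *= s
    sub_strides.reverse()
    new_shape = shape[:dim] + tuple(sizes) + shape[dim + 1:]
    new_strides = strides[:dim] + tuple(sub_strides) + strides[dim + 1:]
    return new_shape, new_strides


split = reshape_dim_view((128, 24, 64), (1536, 64, 1), dim=1, new_sizes=(4, -1))
# split == ((128, 4, 6, 64), (1536, 384, 64, 1))
```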

flatten_dims(start_dim, end_dim)#

Flattens a range of contiguous dimensions into one.

Parameters:
  • start_dim (int) – First dimension to flatten (inclusive).

  • end_dim (int) – Last dimension to flatten (inclusive).

Returns:

New TensorView with flattened dimensions.

Return type:

TensorView
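Merging dimensions without a copy is only possible when the merged range can be described by a single stride. The sketch below checks that condition, interpreting "contiguous" as adjacent in memory, which is an assumption about the wording above. `flatten_dims_view` is a hypothetical helper, not the TensorView implementation.

```python
def flatten_dims_view(shape, strides, start_dim, end_dim):
    """Return (shape, strides) after merging dims start_dim..end_dim (inclusive)."""
    for d in range(start_dim, end_dim):
        # Each outer dimension must step exactly over the next inner one.
        if strides[d] != strides[d + 1] * shape[d + 1]:
            raise ValueError("dimensions are not mergeable without a copy")
    merged = 1
    for d in range(start_dim, end_dim + 1):
        merged *= shape[d]
    new_shape = shape[:start_dim] + (merged,) + shape[end_dim + 1:]
    new_strides = strides[:start_dim] + (strides[end_dim],) + strides[end_dim + 1:]
    return new_shape, new_strides


flat = flatten_dims_view((128, 4, 6, 64), (1536, 384, 64, 1), 1, 2)
# flat == ((128, 24, 64), (1536, 64, 1))
```

In this model, a view whose middle dimensions were permuted first, for example strides (1536, 64, 384, 1) with shape (128, 6, 4, 64), fails the check and cannot be flattened over dims 1 to 2.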

expand_dim(dim)#

Inserts a new dimension of size 1.

Parameters:

dim (int) – Position to insert the new dimension.

Returns:

New TensorView with added dimension.

Return type:

TensorView

squeeze_dim(dim)#

Removes a dimension of size 1.

Parameters:

dim (int) – Dimension to remove (must have size 1).

Returns:

New TensorView with removed dimension.

Return type:

TensorView
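In shape and stride terms, squeeze_dim undoes expand_dim. The sketch below models both with hypothetical helpers. The stride recorded for the inserted size-1 dimension is an arbitrary placeholder (0 here), since a size-1 dimension is never stepped along. The value TensorView actually stores is not specified in this reference.

```python
def expand_dim_view(shape, strides, dim):
    """Return (shape, strides) with a size-1 dimension inserted at `dim`."""
    new_shape = shape[:dim] + (1,) + shape[dim:]
    new_strides = strides[:dim] + (0,) + strides[dim:]  # placeholder stride
    return new_shape, new_strides


def squeeze_dim_view(shape, strides, dim):
    """Return (shape, strides) with the size-1 dimension `dim` removed."""
    if shape[dim] != 1:
        raise ValueError("only size-1 dimensions can be squeezed")
    return shape[:dim] + shape[dim + 1:], strides[:dim] + strides[dim + 1:]


expanded = expand_dim_view((128, 64), (64, 1), dim=1)
# expanded == ((128, 1, 64), (64, 0, 1))
restored = squeeze_dim_view(expanded[0], expanded[1], dim=1)
# restored == ((128, 64), (64, 1))
```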

select(dim, index)#

Selects a single element along a dimension, reducing dimensionality.

Parameters:
  • dim (int) – Dimension to select from.

  • index (int | nl.ndarray) – Index to select (int for static, nl.ndarray for dynamic).

Returns:

New TensorView with one fewer dimension.

Return type:

TensorView
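For a static integer index, selection can be modeled as advancing the base offset and dropping the dimension. The sketch below shows that bookkeeping only. `select_view` is a hypothetical helper, and dynamic (nl.ndarray) indices are outside its scope.

```python
def select_view(shape, strides, offset, dim, index):
    """Return (shape, strides, offset) after fixing `dim` at a static `index`."""
    if not 0 <= index < shape[dim]:
        raise IndexError("index out of range for the selected dimension")
    new_shape = shape[:dim] + shape[dim + 1:]
    new_strides = strides[:dim] + strides[dim + 1:]
    new_offset = offset + index * strides[dim]
    return new_shape, new_strides, new_offset


# A (128, S=16, 3, H=8, D=64) row-major view of fused QKV; pick K (index 1).
selected = select_view((128, 16, 3, 8, 64), (24576, 1536, 512, 64, 1), 0,
                       dim=2, index=1)
# selected == ((128, 16, 8, 64), (24576, 1536, 64, 1), 512)
```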

rearrange(src_pattern, dst_pattern, fixed_sizes=None)#

Rearranges dimensions using einops-style patterns.

Parameters:
  • src_pattern (tuple[str | tuple[str, ...], ...]) – Source dimension pattern with named dimensions.

  • dst_pattern (tuple[str | tuple[str, ...], ...]) – Destination dimension pattern.

  • fixed_sizes (dict[str, int], optional) – Dictionary mapping dimension names to sizes.

Returns:

New TensorView with rearranged dimensions.

Return type:

TensorView
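The sketch below illustrates how named sizes might be resolved from the source pattern, the current shape, and fixed_sizes, and how the destination shape then follows. The helpers `resolve_sizes` and `pattern_shape` are hypothetical. The rule that each grouped source dimension may contain at most one unknown name is an assumption, not documented TensorView behavior.

```python
def _names(entry):
    """Normalize a pattern entry to a tuple of dimension names."""
    return (entry,) if isinstance(entry, str) else tuple(entry)


def resolve_sizes(shape, src_pattern, fixed_sizes=None):
    """Map every name in `src_pattern` to a size."""
    if len(shape) != len(src_pattern):
        raise ValueError("pattern length must match the number of dimensions")
    sizes = dict(fixed_sizes or {})
    for dim_size, entry in zip(shape, src_pattern):
        names = _names(entry)
        unknown = [n for n in names if n not in sizes]
        known = 1
        for n in names:
            if n in sizes:
                known *= sizes[n]
        if len(unknown) > 1:
            raise ValueError(f"cannot infer sizes for {unknown}")
        if unknown:
            if dim_size % known != 0:
                raise ValueError("fixed sizes do not divide the dimension")
            sizes[unknown[0]] = dim_size // known
        elif known != dim_size:
            raise ValueError("fixed sizes do not match the dimension")
    return sizes


def pattern_shape(sizes, pattern):
    """Shape described by `pattern` given resolved name sizes."""
    out = []
    for entry in pattern:
        size = 1
        for n in _names(entry):
            size *= sizes[n]
        out.append(size)
    return tuple(out)


sizes = resolve_sizes((128, 512, 64), ('p', ('h', 'w'), 'c'), {'h': 32})
dst = pattern_shape(sizes, ('p', 'c', 'h', 'w'))  # (128, 64, 32, 16)
```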

reshape(new_shape)#

Reshapes the tensor to new dimensions.

Parameters:

new_shape (tuple[int, ...]) – New dimension shape.

Returns:

New TensorView with reshaped dimensions.

Return type:

TensorView

Note

General reshape is not yet implemented and will raise an error. Use reshape_dim for single-dimension reshaping.

has_dynamic_access()#

Checks if the tensor view uses dynamic indexing (via a prior select with an nl.ndarray index).

Returns:

True if the view has dynamic access, False otherwise.

Return type:

bool

Examples#

Reshape and Permute#

import nki
import nki.language as nl
from nkilib.core.utils.tensor_view import TensorView

@nki.jit
def kernel_reshape_permute(data_sb):
    view = TensorView(data_sb)  # Shape: (128, 24, 64)

    reshaped = view.reshape_dim(1, (4, 6))  # (128, 4, 6, 64)
    transposed = reshaped.permute((0, 2, 1, 3))  # (128, 6, 4, 64)

    result = transposed.get_view()

Slicing with Step#

import nki
from nkilib.core.utils.tensor_view import TensorView

@nki.jit
def kernel_strided_slice(data_sb):
    view = TensorView(data_sb)  # Shape: (128, 256)

    # Take every other element: indices 0, 2, 4, ...
    strided = view.slice(dim=1, start=0, end=256, step=2)  # (128, 128)

    result = strided.get_view()

Broadcasting#

import nki
from nkilib.core.utils.tensor_view import TensorView

@nki.jit
def kernel_broadcast(scale_sb, data_sb):
    # scale_sb shape: (128, 1, 64)
    # data_sb shape: (128, 32, 64)

    scale_view = TensorView(scale_sb)

    # Broadcast dim 1 from size 1 to 32
    broadcasted = scale_view.broadcast(dim=1, size=32)  # (128, 32, 64)

    # Now can multiply element-wise
    result = data_sb * broadcasted.get_view()

Einops-Style Rearrange#

import nki
from nkilib.core.utils.tensor_view import TensorView

@nki.jit
def kernel_rearrange(data_sb):
    view = TensorView(data_sb)  # Shape: (128, 512, 64)

    # Split dim 1 into (h, w) and move c forward: (p, h*w, c) -> (p, c, h, w)
    # h=32 is fixed, so w is inferred as 512 / 32 = 16
    rearranged = view.rearrange(
        src_pattern=('p', ('h', 'w'), 'c'),
        dst_pattern=('p', 'c', 'h', 'w'),
        fixed_sizes={'h': 32}
    )  # (128, 64, 32, 16)

    result = rearranged.get_view()

Chained Operations#

import nki
from nkilib.core.utils.tensor_view import TensorView

@nki.jit
def attention_reshape(qkv_sb, num_heads, head_dim):
    # qkv_sb shape: (128, seq_len, 3 * num_heads * head_dim)
    view = TensorView(qkv_sb)

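    # Chain: split the fused QKV dim -> select Q -> move heads ahead of sequence
    # In the shape comments below, S = seq_len, H = num_heads, D = head_dim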
    q_view = (view
        .reshape_dim(2, (3, num_heads, head_dim))  # (128, S, 3, H, D)
        .select(dim=2, index=0)                     # (128, S, H, D) - select Q
        .permute((0, 2, 1, 3)))                     # (128, H, S, D)

    q = q_view.get_view()

See Also#

  • stream_shuffle_broadcast - Hardware broadcast for partition dimension

  • SbufManager - Memory allocation with scope management