This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

Release Notes for Neuron Component: Neuron Kernel Interface (NKI)#

These are the release notes for the Neuron Kernel Interface (NKI) component. They describe the changes, improvements, and bug fixes for every release of the AWS Neuron SDK.

Neuron Kernel Interface (NKI) [0.4.0] (Neuron 2.30.0 Release)#

Date of Release: 05/21/2026

New Features#

  • nki.language.abs_max and nki.language.abs_min: New callable APIs for element-wise absolute maximum and absolute minimum. These are trn3 only and run on the Vector Engine. Also usable as op0 with nki.isa.tensor_scalar and nki.isa.tensor_scalar_reduce. See nki.isa.tensor_scalar and nki.isa.tensor_scalar_reduce.

  • nki.isa.activate2: New trn3 Scalar Engine API that applies an activation function to the result of a two-stage tensor-scalar preprocessing pipeline (data op0 imm0) op1 imm1, with an optional reduction, all in a single instruction. Supports six (op0, op1) combinations (scale+bias, scale-only, bias-only, etc.) and optional operand reversal for non-commutative operations. Reduces instruction count compared to chaining nisa.tensor_scalar with nisa.activation. See nki.isa.activate2.

  • New opcodes for nki.isa.tensor_scalar and tensor_scalar_reduce: square and relu are now accepted as op0 on trn3. See nki.isa.tensor_scalar and nki.isa.tensor_scalar_reduce.

  • New activation and arithmetic opcodes in nki.language: nl.prelu (parametric ReLU, used as op=nl.prelu with nki.isa.activate2) and nl.bypass (pass-through op for nki.isa.activate2). See the supported activation functions and arithmetic operator tables in nki.api.shared.

  • tile_size bytes-aware constants: New properties on nki.language.tile_size expose SBUF and PSUM capacity in both elements and bytes:

    • tile_size.sbuf_size_bytes — total SBUF capacity across all 128 partitions, in bytes

    • tile_size.sbuf_fmax — per-partition usable SBUF free dimension in FP32 elements

    • tile_size.sbuf_fmax_bytes — per-partition usable SBUF free dimension in bytes

    • tile_size.psum_bank_fmax — PSUM bank capacity in FP32 elements

    • tile_size.psum_bank_fmax_bytes — PSUM bank capacity in bytes

    See nki.language.tile_size.

  • nki.isa.dma_compute oob_mode parameter: dma_compute now accepts an oob_mode parameter (oob_mode.error or oob_mode.skip) to control handling of out-of-bounds indices in indirect gather/scatter operations with vector_offset, mirroring existing dma_copy behavior. Validation ensures oob_mode.skip is used only with indirect indexing. See nki.isa.dma_compute.

  • nc_matmul float8_e4m3fn input dtype: On trn3, nc_matmul now accepts float8_e4m3fn (OCP FP8) as an input dtype, distinct from the legacy float8_e4m3. A new validation prevents mixing legacy float8_e4m3 with OCP float8_e4m3fn operands in the same matmul. See nki.isa.nc_matmul.

  • JAX dtype support: JAX scalar dtype types (jnp.bfloat16, jnp.float16, jnp.float32, etc.) are now accepted as kernel arguments and keyword arguments and automatically converted to the equivalent NKI dtype. Unsupported JAX dtypes raise TypeError. See NKI data types.
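
As a rough illustration of the nki.isa.activate2 data flow described above, the following plain-Python sketch models act((data op0 imm0) op1 imm1) on a list of floats. It is not the NKI API: the function name activate2_reference, its parameters (including reverse0 and reverse1 for operand reversal), and the use of operator.mul, operator.add, and operator.sub as opcodes are assumptions made for illustration only. The optional reduction is omitted. The real signature and supported opcodes are documented in nki.isa.activate2.

```python
import operator
from typing import Callable, List, Optional

BinaryOp = Optional[Callable[[float, float], float]]


def activate2_reference(
    data: List[float],
    op0: BinaryOp,
    imm0: float,
    op1: BinaryOp,
    imm1: float,
    act: Callable[[float], float],
    reverse0: bool = False,
    reverse1: bool = False,
) -> List[float]:
    """Hypothetical reference model: act((data op0 imm0) op1 imm1)."""

    def stage(x: float, op: BinaryOp, imm: float, reverse: bool) -> float:
        if op is None:  # stage skipped (e.g. scale-only or bias-only)
            return x
        # Operand reversal matters for non-commutative ops such as subtract.
        return op(imm, x) if reverse else op(x, imm)

    return [
        act(stage(stage(x, op0, imm0, reverse0), op1, imm1, reverse1))
        for x in data
    ]


def relu(x: float) -> float:
    return max(x, 0.0)


# Scale + bias followed by ReLU: relu(x * 2.0 + (-1.0))
scaled = activate2_reference(
    [-1.0, 0.5, 3.0], operator.mul, 2.0, operator.add, -1.0, relu
)

# Operand reversal on a non-commutative op: relu(10.0 - x)
reversed_sub = activate2_reference(
    [4.0, 12.0], operator.sub, 10.0, None, 0.0, relu, reverse0=True
)
```

Here scaled evaluates to [0.0, 0.0, 5.0] and reversed_sub to [6.0, 0.0]. The point of the single instruction is that both stages and the activation happen in one pass over the data.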

Improvements#

  • nki.isa.sendrecv — Removed the restriction that the src and dst partition dimension must be a multiple of 16. Note: sendrecv is an intra-LNC communication API and is only supported when running on LNC2 (trn2 or later). See nki.isa.sendrecv.

  • nki.isa.dma_transpose with indirect indexing — Relaxed the src innermost dimension constraint from exactly 128 to <= 128 when src uses an indirect access pattern (vector_offset). See nki.isa.dma_transpose.

  • nki.simulate default accuracy improved: NKI_PRECISE_FP=1 is now the default for CPU simulation. Low-precision dtypes (bfloat16, float8) are now modeled accurately instead of being approximated with float32, producing simulator results closer to hardware. Set NKI_PRECISE_FP=0 to restore the previous behavior. See nki.simulate.

  • NKI_SIMULATOR=1 environment variable: Setting NKI_SIMULATOR=1 now works with torch.Tensor inputs directly; no manual conversion to NumPy arrays is required. See nki.simulate.

  • Improved error messages for nested NKI calls: Kernel compilation errors now show the full Python call stack instead of only the innermost frame, making it easier to locate the call site that triggered the error.
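
To make the NKI_PRECISE_FP change more concrete, the sketch below shows, in plain Python, how a value stored as bfloat16 differs from the same value kept in float32. The helper to_bfloat16 is hypothetical and is not part of NKI or of the simulator. It only models round-to-nearest-even reduction of a float32 bit pattern to its upper 16 bits, and it does not handle NaN inputs.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Hypothetical helper: round a value to bfloat16 (upper 16 bits of float32).

    Uses round-to-nearest-even; NaN inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add half of the dropped range, plus the tie-breaking bit, then truncate.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


x = 1.001
as_fp32 = to_float32(x)   # float32 keeps the 0.001 offset
as_bf16 = to_bfloat16(x)  # bfloat16 has 8 significant bits, so this rounds to 1.0
```

A simulator that silently keeps bfloat16 data in float32 would carry the 0.001 offset through the kernel, while hardware would not, which is the kind of mismatch the new default is meant to reduce.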

Deprecated and Removed APIs#

  • nki.isa.tensor_copy_dynamic_src / nki.isa.tensor_copy_dynamic_dst — Removed. Use nisa.tensor_copy() with .ap() and scalar_offset instead. See nki.isa.tensor_copy.

  • nki.language.tile_size.total_available_sbuf_size — Deprecated. Despite the name, this attribute returns the usable SBUF free dimension per partition, not total SBUF capacity. Use tile_size.sbuf_size_bytes for total SBUF capacity across all partitions, or tile_size.sbuf_fmax_bytes for the per-partition size. The deprecated attribute continues to work and returns the same value as before. See nki.language.tile_size.

Breaking Changes#

  • nisa.dma_transpose — Now enforces that dst.shape matches the transposed src.shape exactly, including rank. Previously, a lower-rank dst.shape was silently padded to match a higher-rank src.shape (e.g., a 3D dst against a 4D src). The compiler now raises an assertion error if the ranks differ. To migrate: either match the dst rank to the src rank (e.g., use dst.shape=(128, 1, 1, 4096) for a 4D src), or use a src and axes of the same rank as the intended dst (e.g., a 3D src with axes=(2, 1, 0) instead of a 4D src with axes=(3, 1, 2, 0)). See nki.isa.dma_transpose.

  • neuronxcc.nki.* namespace — Usage of the deprecated neuronxcc.nki.* namespace inside NKI kernels now raises a compilation error instead of a warning. To migrate, follow the NKI Beta 2 Migration Guide.
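
The stricter nisa.dma_transpose shape rule can be sketched in plain Python. The helpers below are hypothetical (they are not NKI functions) and assume that axes follows the NumPy permutation convention, where dst.shape[i] == src.shape[axes[i]]. The 4D source shape is chosen only so that the result matches the dst.shape=(128, 1, 1, 4096) example in the text.

```python
from typing import Sequence, Tuple


def expected_dst_shape(src_shape: Sequence[int], axes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: shape of src after permuting its dimensions by axes."""
    if sorted(axes) != list(range(len(src_shape))):
        raise ValueError(f"axes {tuple(axes)} is not a permutation of the src dimensions")
    return tuple(src_shape[a] for a in axes)


def check_dst_shape(src_shape, axes, dst_shape) -> None:
    """Raise AssertionError unless dst_shape equals the transposed src shape, rank included."""
    expected = expected_dst_shape(src_shape, axes)
    if tuple(dst_shape) != expected:
        raise AssertionError(
            f"dst.shape {tuple(dst_shape)} does not match transposed src.shape {expected}"
        )


src_shape = (4096, 1, 1, 128)  # assumed 4D source, for illustration only
axes = (3, 1, 2, 0)

check_dst_shape(src_shape, axes, (128, 1, 1, 4096))  # same rank: accepted

try:
    check_dst_shape(src_shape, axes, (128, 1, 4096))  # 3D dst against a 4D src
    rank_mismatch_rejected = False
except AssertionError:
    rank_mismatch_rejected = True
```

Running a check like this over existing call sites is one way to find kernels that relied on the old silent padding.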

Bug Fixes#

  • nki.simulate correctness fixes: The CPU simulator was corrected in several areas to match hardware behavior more closely:

    • nki.isa.dma_copy with oob_mode=oob_mode.skip no longer casts integer tiles to float32 in the OOB skip path — integer bit patterns are now preserved.

    • nki.isa.nc_stream_shuffle with mask=255 now preserves existing destination data instead of zeroing it.

    • nki.isa.local_gather now produces correct results, including when the destination uses a sub-view or .ap() access pattern.

    • nki.isa.nc_matmul with 3D+ operand shapes now copies all elements (previously higher dimensions were silently dropped, producing zero-filled results).

    • nki.isa.quantize_mx with float8_e4m3fn_x4 output now simulates correctly.

    • nki.isa.iota and nki.isa.affine_select now handle dynamic register offsets correctly.

    • MX x4 packed dtypes (float8_e4m3fn_x4, float8_e5m2_x4) now simulate correctly when targeting trn3.

    • nki.isa.dma_compute no longer applies a fictional additive scale mode that the hardware does not support.

    • nl.logical_and in CPU simulation now produces correct results.

    See nki.simulate.

  • Fixed NKIObject subclasses decorated with @dataclass(frozen=True) failing to instantiate with FrozenInstanceError. Frozen NKI object subclasses can now be constructed normally.

  • Fixed nki.jit cache misses when NKIObject subclasses are decorated with @dataclass. @dataclass removes __hash__ by default, which prevented the cache key from being computed. NKIObject dataclass subclasses are now handled with a consistent cache key regardless of hashability. See nki.jit.

  • Fixed bool output tensors being returned as uint8 from PyTorch Native kernels. Bool dtype is now preserved end-to-end so output tensors are returned as torch.bool instead of torch.uint8.

  • Fixed hardware race conditions in dynamic loop kernels when loop-body memory accesses overlapped with pre-loop accesses. Cross-scope memory dependencies are now tracked correctly across loop boundaries.

  • Fixed silent, undefined behavior when writing to the induction variable of an nl.dynamic_range loop. Writing to the induction variable (e.g., to try to break out of the loop early) had no effect but did not surface any error. Such writes now raise AssertionError at trace time. Additionally, the induction variable is now a VirtualRegister (previously a bare scalar), so it can be used as a scalar_offset in access patterns (e.g., nisa.dma_copy with a per-iteration dynamic offset) — resolving a 0.3.0 known issue. See nki.language.dynamic_range.

  • Fixed nki.isa.nc_transpose silently ignoring the engine argument. An explicit engine=engine.vector or engine=engine.tensor is now honored; engine=engine.unknown (the default) continues to auto-select based on the destination buffer. See nki.isa.nc_transpose.

  • Fixed nki.jit kernel cache misses when a kernel is invoked with None arguments. Previously, every such call triggered an unnecessary recompilation. See nki.jit.

  • Fixed nl.device_print failing verification on 1-D HBM tensors (e.g., after linearization). device_print now works with any tensor rank. See nki.language.device_print.
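
The @dataclass hashing behavior behind the nki.jit cache fix can be reproduced with the standard library alone. In this sketch, MutableConfig and FrozenConfig are hypothetical stand-ins, not NKIObject subclasses. With the default eq=True and frozen=False, @dataclass sets __hash__ to None, so instances cannot be hashed; with frozen=True it generates a field-based hash.

```python
from dataclasses import dataclass


@dataclass
class MutableConfig:  # hypothetical stand-in for a user-defined kernel config object
    tile: int


@dataclass(frozen=True)
class FrozenConfig:  # hypothetical stand-in; frozen dataclasses get a generated __hash__
    tile: int


def is_hashable(obj) -> bool:
    try:
        hash(obj)
        return True
    except TypeError:
        return False


mutable_hashable = is_hashable(MutableConfig(128))  # False: __hash__ is None
frozen_hashable = is_hashable(FrozenConfig(128))    # True: hash derived from fields
same_key = hash(FrozenConfig(128)) == hash(FrozenConfig(128))  # True: equal fields
```

Any cache that builds its key with hash() therefore fails on plain @dataclass instances unless it handles them separately, which is the situation the fix addresses.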

Known Issues#

Control Flow

  • Nested nl.dynamic_range loops with loop-carried values fail to compile with a “Could not find register” error. Workaround: restructure to avoid nested dynamic loops, or use nl.static_range / nl.affine_range for the outer loop when the trip count is known at compile time.

CPU Simulator

  • The CPU simulator has additional known limitations beyond those listed here. See the Simulation Limitations section of the simulator guide for the full list.

NKI Language (experimental)

The nki.language APIs are convenience wrappers around nki.isa instructions. They are experimental and have the following known limitations:

  • nki.language.divide is not supported — Division is not available as a hardware instruction. As a workaround, multiply by the reciprocal: nl.multiply(x, nl.reciprocal(y)).

  • nki.language.fmod and nki.language.mod are not supported — Modulo operations are not available as hardware instructions. These APIs work in simulation but fail when compiled for Trainium hardware.

  • nki.language.power does not support scalar exponents — nl.power(tile, scalar) is not supported. Use nl.power(tile, tile) instead, where both operands are tiles.

  • Binary operations do not support broadcasting — Operations like nl.add(a, b) require both operands to have the same shape. Broadcasting (e.g., adding a (128, 1) tile to a (128, 512) tile) is not yet supported.

  • nki.language.random_seed requires a tensor, not a scalar — Pass a [1, 1] tensor on SBUF instead of a Python integer. For example: nl.random_seed(nl.full((1, 1), 42, dtype=nl.int32, buffer=nl.sbuf)).

  • nki.language.rand and nki.language.random_seed engine behavior — On trn3, rand uses nisa.rand2 on the Vector Engine. On earlier targets, rand uses nisa.rng which may run on a different engine than random_seed, potentially causing random_seed to have no effect on rand output.

  • nki.language.matmul without transpose_x=True is not supported — Calling nl.matmul(x, y) without setting transpose_x=True will fail. As a workaround, always use nl.matmul(x, y, transpose_x=True) and pre-arrange data accordingly.

  • nki.language.copy uses lossy FP32 casting: nl.copy uses the Scalar Engine, which internally casts through float32. This is lossy for integer types with values exceeding float32 precision (e.g., int32 values with magnitude greater than 2^24). Additionally, cross-buffer copies (e.g., PSUM to SBUF) are not supported.
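
The precision loss noted for nl.copy can be illustrated without NKI by round-tripping integers through float32 with the standard struct module. This is only a model of an integer to float32 to integer cast. The helper name is hypothetical, and the sketch does not call nl.copy or reproduce the Scalar Engine's exact behavior.

```python
import struct


def roundtrip_through_float32(value: int) -> int:
    """Hypothetical helper: cast an integer to float32 and back to int."""
    as_fp32 = struct.unpack("<f", struct.pack("<f", float(value)))[0]
    return int(as_fp32)


small = 1_000_000
large = 2**24 + 1  # 16777217 needs 25 significant bits

small_after = roundtrip_through_float32(small)  # unchanged
large_after = roundtrip_through_float32(large)  # rounded to a neighbouring value
```

In this model small_after equals small, while large_after comes back as 16777216, so index or ID tensors with large values are the ones to keep away from a float32 cast.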

Neuron Kernel Interface (NKI) [0.3.0] (Neuron 2.29.0 Release)#

Date of Release: 04/09/2026

AWS Neuron SDK 2.29.0 introduces NKI 0.3.0, a significant update to the Neuron Kernel Interface for General Availability. NKI 0.3.0 features the NKI Standard Library (nki-stdlib), which provides developer-visible code for all NKI APIs and native language objects (e.g., NkiTensor). This release exposes additional Trainium capabilities through the NKI API and introduces the nki.language APIs. It includes a CPU Simulator, which executes NKI kernels entirely on CPU using NumPy, so developers can validate kernel logic on laptops and in CI environments without Trainium hardware. NKI 0.3.0 also adds the nki.typing module for declaring expected tensor shapes, a dedicated nki.isa.exponential instruction optimized for Softmax computation, matmul accumulation control, explicit memory address placement, and variable-length all-to-all collectives via nki.collectives.all_to_all_v. Finally, NKI 0.3.0 includes several breaking API changes that improve correctness and consistency, along with updated documentation.

For the full list of changes and update examples, see the NKI 0.3.0 Update Guide.

New Features#

  • NKI Standard Library (nki-stdlib): NKI 0.3.0 ships with the NKI Standard Library (nki-stdlib), which provides developer-visible code for all NKI APIs and native language objects (e.g., NkiTensor).

  • NKI CPU Simulator (Experimental): Executes NKI kernels entirely on CPU using NumPy, enabling local development, debugging, and functional correctness testing without Trainium hardware. Set the environment variable NKI_SIMULATOR=1 to run existing kernels without code changes, or wrap the kernel call with nki.simulate(kernel). See nki.simulate API Reference.

  • nki.language APIs (Experimental): Introduces nki.language APIs as convenience wrappers around nki.isa APIs, including nl.load, nl.store, nl.copy, nl.matmul, nl.transpose, nl.softmax, and other high-level operations. See nki.language API Reference.

  • nki.typing module: New module for type-annotating kernel tensor parameters. Use nt.tensor[shape] to declare expected tensor shapes.

  • nki.isa.exponential: Dedicated exponential instruction with max subtraction, faster than nisa.activation(op=nl.exp) and useful for Softmax calculation. Trn3 (NeuronCore-v4) only. See nki.isa.exponential.

  • nki.collectives.all_to_all_v: Variable-length all-to-all collective. Unlike all_to_all, uses a metadata tensor to specify per-rank send/recv counts. See nki.collectives API Reference.

  • Matmul accumulation: nc_matmul and nc_matmul_mx now have an accumulate parameter that controls whether the operation overwrites or accumulates on the destination PSUM tile. The default (accumulate=None) auto-detects, matching NKI 0.2.0 behavior. See nki.isa.nc_matmul.

  • Address placement: The address parameter was added to nki.language.ndarray for explicit memory placement. See nki.language.ndarray.
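
The "max subtraction" mentioned for nki.isa.exponential corresponds to the standard numerically stable Softmax formulation, exp(x - max(x)) / sum(exp(x - max(x))). The plain-Python sketch below shows why the subtraction matters. It is a reference for the math only: it does not use NKI and makes no claim about how the instruction is invoked.

```python
import math
from typing import List


def stable_softmax(row: List[float]) -> List[float]:
    """Softmax with max subtraction, so exp() never sees a large positive input."""
    row_max = max(row)
    exps = [math.exp(x - row_max) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1000.0, 1001.0, 1002.0]

# A naive math.exp(1000.0) raises OverflowError in Python; the shifted form is safe.
probs = stable_softmax(logits)

# Softmax is invariant to a constant shift, so this gives the same probabilities.
shifted_probs = stable_softmax([x - 1000.0 for x in logits])
```

Because every Softmax needs this subtract-then-exponentiate step, fusing it into one instruction avoids a separate pass over the tile.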

Deprecated and Removed APIs#

  • nki.isa.tensor_copy_dynamic_src / nki.isa.tensor_copy_dynamic_dst — Deprecated and scheduled for removal. Use nisa.tensor_copy() with .ap() and scalar_offset instead.

  • nki.jit(platform_target=...) — Deprecated. Set the target platform via the NEURON_PLATFORM_TARGET_OVERRIDE environment variable instead. This is a breaking change.

  • nki.jit(mode=...) — Deprecated and ignored. The NKI Compiler now auto-detects the framework from kernel arguments. This is a breaking change.

Breaking Changes#

Note

NKI 0.3.0 requires all NKI kernels in a model to be updated to NKI 0.3.0. Mixing NKI 0.3.0 and NKI 0.2.0 kernels in the same model is not supported. For models that have not yet been updated, continue using Neuron SDK 2.28.

  • nisa.dma_copy — No longer supports reading directly from PSUM. Copy the PSUM tensor to SBUF first using nisa.tensor_copy.

  • nisa.dma_copy — Enforces matching source and destination element types when using dge_mode=dge_mode.hwdge. Use .view() to reinterpret types.

  • nisa.dma_copy: dst_rmw_op and unique_indices parameters removed. Use nisa.dma_compute instead.

  • nisa.dma_compute: scales and reduce_op parameters swapped positions. scales is now optional. unique_indices parameter added. Update call sites to use the new parameter order: nisa.dma_compute(dst, srcs, reduce_op, scales=None, unique_indices=True).

  • nisa.memset — Enforces strict type matching between value and destination dtype. x4 packed types enforce value=0. Kernels that pass float values to integer-typed tensors (e.g., value=2.0 instead of value=2) will now raise an error at compile time.

  • nisa.sendrecv: use_gpsimd_dma replaced by dma_engine enum. Update existing kernels to use the new enum.

  • nisa.affine_select: offset moved from 3rd positional argument to keyword argument with default 0.

  • nisa.register_move: imm renamed to src, now accepts VirtualRegister. Update keyword argument from imm= to src=.

  • nki.collectives.collective_permute_implicit_current_processing_rank_id: num_channels parameter removed. Remove num_channels from call sites and pass channel_ids list to collective_permute_implicit() instead.

  • Output tensors must use buffer=nl.shared_hbm. Using nl.hbm causes compilation failures.

  • Raw integer enum constants no longer accepted. Use named enum members.

  • String buffer names no longer accepted. Use buffer objects (e.g., nl.sbuf).

  • Keyword-only argument separator (*) in kernel signatures is not supported.

  • is / is not operators are not supported. Use == / !=.

  • list kernel arguments are not supported. Convert to tuples.

For before-and-after code examples, see the NKI 0.3.0 Update Guide.
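
Two of the language restrictions above (the keyword-only separator and the is / is not operators) can be spotted before compilation with Python's standard ast module. The checker below is a hypothetical helper, not part of NKI, and it covers only these two restrictions. The sample kernel source is made up for the example.

```python
import ast
from typing import List


def find_unsupported_constructs(source: str) -> List[str]:
    """Hypothetical pre-flight check for two NKI 0.3.0 language restrictions."""
    problems = []
    for node in ast.walk(ast.parse(source)):
        # def kernel(a, *, flag=...) puts 'flag' in kwonlyargs.
        if isinstance(node, ast.FunctionDef) and node.args.kwonlyargs:
            names = ", ".join(arg.arg for arg in node.args.kwonlyargs)
            problems.append(f"line {node.lineno}: keyword-only arguments ({names})")
        # 'x is y' and 'x is not y' appear as Is / IsNot comparison operators.
        if isinstance(node, ast.Compare):
            for op in node.ops:
                if isinstance(op, (ast.Is, ast.IsNot)):
                    problems.append(f"line {node.lineno}: 'is' / 'is not' comparison")
    return problems


sample = '''
def my_kernel(a, b, *, scale=1.0):
    if b is not None:
        return a
    return b
'''

issues = find_unsupported_constructs(sample)
```

For the sample source, issues contains one entry for the keyword-only argument scale on line 2 and one for the is not comparison on line 3. The other restrictions in the list (buffer objects, enum members, tuple arguments) depend on runtime values and are not covered by this static check.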

Note

The previously announced removal of the neuronxcc.nki.* namespace has been postponed to a future release. Both the neuronxcc.nki.* and nki.* namespaces continue to be supported in this release.

Other Changes#

  • nki.isa.dma_engine alias repurposed as the dma_engine enum for DMA transfer engine selection.

  • nki.isa.iota: offset now optional with default 0.

  • nki.isa.core_barrier: engine default changed from unknown to gpsimd (no behavioral change).

  • nki.language.num_programs: axes default changed from None to 0.

  • nki.language.program_id: axis now defaults to 0.

  • nki.language.ndarray: buffer default changed from None to nl.sbuf.

  • nki.language.zeros: buffer default changed from None to nl.sbuf.

  • nki.language.sequential_range: stop and step now have default values (None and 1).

Bug Fixes#

  • Fixed incorrect axis handling in nisa.tensor_reduce. NKI 0.2.0 incorrectly allowed axis=1 to refer to the last free dimension even for 3D/4D tensors. NKI 0.3.0 corrects this so that axis values correspond to the actual tensor dimensions.

  • Fixed nisa.range_select silently overriding user-specified parameters. The on_false_value and reduce_cmd parameters were incorrectly ignored by the compiler — on_false_value was always set to -3.4028235e+38 and reduce_cmd was always set to reset_reduce, regardless of the values passed in. NKI 0.3.0 honors the reduce_cmd parameter and documents the FP32_MIN hardware constraint for on_false_value.

Known Issues#

Math Operations

  • nki.language.divide is not supported — Division is not available as a hardware instruction. As a workaround, multiply by the reciprocal: nl.multiply(x, nl.reciprocal(y)).

  • nki.language.fmod and nki.language.mod are not supported — Modulo operations are not available as hardware instructions. These APIs work in simulation but will fail when compiled for Trainium hardware.

  • nki.language.power does not support scalar exponents — nl.power(tile, scalar) is not supported. Use nl.power(tile, tile) instead, where both operands are tiles.

Broadcasting

  • Binary operations do not support broadcasting — Operations like nl.add(a, b) require both operands to have the same shape. Broadcasting (e.g., adding a (128, 1) tile to a (128, 512) tile) is not yet supported.

Random Number Generation

  • nki.language.random_seed requires a tensor, not a scalar — Pass a [1, 1] tensor on SBUF instead of a Python integer. For example: nl.random_seed(nl.full((1, 1), 42, dtype=nl.int32, buffer=nl.sbuf)).

  • nki.language.rand and nki.language.random_seed engine behavior — On NeuronCore-v4+ (Trn3+), rand uses nisa.rand2 on the Vector Engine. On earlier NeuronCores, rand uses nisa.rng which may run on a different engine than random_seed, potentially causing random_seed to have no effect on rand output.

Matrix Operations

  • nki.language.matmul without transpose_x=True is not supported — Calling nl.matmul(x, y) without setting transpose_x=True will fail. As a workaround, always use nl.matmul(x, y, transpose_x=True) and pre-arrange data accordingly.

Data Movement

  • nki.language.store does not support PSUM tiles directly — Storing a tile that resides in PSUM requires manually copying it to SBUF first using nisa.tensor_copy.

  • nki.language.copy uses lossy FP32 casting: nl.copy uses the Scalar Engine, which internally casts through FP32. This is lossy for integer types with values exceeding FP32 precision (e.g., int32 values with magnitude greater than 2^24). Additionally, cross-buffer copies (e.g., PSUM to SBUF) are not supported.

Control Flow

  • nki.language.dynamic_range loop variable cannot be used in index arithmetic — The induction variable of a dynamic_range loop is a scalar, not a register. It cannot be used as a scalar_offset in access patterns or in arithmetic expressions for computing tile offsets. Use nl.affine_range or nl.static_range if you need to compute offsets from the loop variable.

Multi-Core (LNC2)

  • LNC2 requires identical control flow across cores — When running with Logical NeuronCore 2 (LNC2), the NKI compiler expects each physical NeuronCore to execute identical control flow. Programs with dynamic control flow that differs across cores may deadlock or produce incorrect results. This constraint is not enforced at compile time.

Caching

  • NKI kernel caching assumes kernels are pure functions of their input arguments. If a kernel’s output depends on external state (such as global variables or closures over mutable objects), the cache may return stale results. This is undefined behavior. Always ensure kernel outputs are determined solely by kernel arguments.
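
The caching hazard can be illustrated with functools.lru_cache as a stand-in for the kernel cache. That substitution is an assumption made for illustration; NKI's cache is a separate mechanism. The first function below depends on a global that is not part of its arguments, so a cached call keeps returning the old result after the global changes.

```python
from functools import lru_cache

SCALE = 2  # external state that is not a function argument


@lru_cache(maxsize=None)
def impure_scale(x: int) -> int:
    """Result depends on the global SCALE, so caching by arguments alone is unsafe."""
    return x * SCALE


@lru_cache(maxsize=None)
def pure_scale(x: int, scale: int) -> int:
    """Result depends only on the arguments, so the cache key captures everything."""
    return x * scale


first = impure_scale(10)       # computed with SCALE == 2
SCALE = 3
stale = impure_scale(10)       # served from the cache, still based on SCALE == 2
fresh = pure_scale(10, SCALE)  # the new value is an argument, so it is recomputed
```

Here first and stale are both 20 even though SCALE is now 3, while fresh is 30. Passing such values as kernel arguments is the equivalent fix for NKI kernels.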

Compiler

  • Address rotation cannot be disabled — Address rotation, a backend compiler optimization that rotates tensor addresses for improved memory utilization, is enabled by default and cannot be opted out of in this release.

Collectives

  • nki.collectives.all_to_all_v(): has_rdispls=True has no effect on NeuronSwitch-based architectures (e.g., Trainium3 UltraServer); the receive layout is the same as has_rdispls=False.

Neuron Kernel Interface (NKI) (0.2.0) [2.28] (Neuron 2.28.0 Release)#

Date of Release: 02/26/2026

New Features#

Improvements#

  • Updated nki.isa APIs:

  • Compiler output improvements:

    • The compiler no longer truncates diagnostic output; users now receive the full set of warnings and errors

Breaking Changes#

  • nki.isa.nc_matmul parameter psumAccumulateFlag has been removed. This parameter had no effect on compilation or execution. Simply remove it from your kernel code.

  • nki.isa.nc_matmul parameter is_moving_zero has been renamed to is_moving_onezero to match hardware semantics, consistent with the companion is_stationary_onezero parameter. Kernels that passed is_moving_zero by name should update to is_moving_onezero.

  • nki.tensor has moved to nki.meta.tensor. Users should update their imports accordingly.

Note

The previously announced removal of the neuronxcc.nki.* namespace has been postponed from Neuron 2.28 to Neuron 2.29. Both the neuronxcc.nki.* and nki.* namespaces continue to be supported in this release. We encourage customers to migrate to the nki.* namespace using the NKI 0.2.0 Migration Guide.

Bug Fixes#

  • Fixed incorrect default value for on_false_value in nki.isa.range_select. The default was 0.0 instead of negative infinity (-inf). This caused range_select to write zeros for out-of-range elements instead of the expected negative-infinity sentinel, which could produce incorrect results in downstream reductions (e.g., max-pooling or top-k). See nki.isa.range_select.

  • Fixed default value parsing for keyword-only arguments in NKI kernels. When a Python function used keyword-only arguments with default values (arguments after * in the signature), the NKI compiler did not associate the defaults with their corresponding parameter names. This caused keyword-only arguments to appear as required even when they had defaults, leading to “missing argument” errors during kernel compilation.

  • Fixed wrong default for reduce_cmd in nki.isa.activation. The default was incorrectly set to ZeroAccumulate instead of Idle, causing the accumulator to be zeroed before every activation call even when no reduction was requested.

  • Fixed missing ALU operators (rsqrt, abs, power) in nki.isa.tensor_scalar and nki.isa.tensor_tensor. Passing these operators previously raised an “unsupported operator” error. See NKI Language Guide.

  • Fixed float8_e4m3fn to float8_e4m3 conversion for kernel inputs and outputs. When a tensor with dtype float8_e4m3fn was passed to the compiler, the automatic conversion to float8_e4m3 could fail with a size-check error. The conversion now validates sizes correctly before casting. See nki.language.float8_e4m3.

  • Fixed dynamic for loop incorrectly incrementing the loop induction variable. In loops with a runtime-determined trip count (sequential_range with non-constant bounds), the compiler generated incorrect increment code, causing the loop counter to never advance and the loop to run indefinitely or produce incorrect iteration values. See nki.language.sequential_range.

  • Fixed reshape of shared_hbm and private_hbm tensors failing partition size check. Reshape only recognized plain hbm memory as exempt from partition-dimension size validation. Tensors allocated in shared_hbm or private_hbm (used for cross-kernel and kernel-private storage) incorrectly triggered a “partition size mismatch” error when reshaped. See nki.language.shared_hbm and nki.language.private_hbm.

  • Fixed bias shape checking in nki.isa.activation. The bias parameter was not validated for shape correctness. A bias tensor with a free dimension other than 1 (e.g., shape (128, 64) instead of (128, 1)) was accepted without validation, which could produce incorrect results. The compiler now raises an error if the bias free dimension is not 1.

  • Fixed incorrect line numbers in stack traces and error reporting. An off-by-one error in the line offset calculation caused all reported line numbers to be shifted by one. Additionally, error location was sometimes lost when errors propagated across file boundaries.

  • Fixed invalid keyword arguments being silently ignored instead of raising an error. When calling an NKI API with a misspelled or unsupported keyword argument, the argument was ignored without warning. The compiler now validates all keyword argument names against the function signature and raises an unexpected keyword argument error for unrecognized names.

  • Fixed nki.jit in auto-detection mode returning an uncalled kernel object instead of executing the kernel. When nki.jit was used without specifying a framework mode (e.g., @nki.jit with no mode argument), the auto-detection path constructed the appropriate framework-specific kernel object but returned it without calling it. The user received a kernel object instead of the computed result, requiring an extra manual invocation. See nki.jit.

  • Fixed stale kernel object state between trace invocations. When tracing the same kernel multiple times (e.g., with different input shapes), compiler state was not fully reset between invocations, causing name collisions and incorrect results. The trace state is now fully reset before each invocation.

  • Improved ‘removed during code migration’ error messages with clear descriptions of unimplemented features. APIs not available in this release (nki.baremetal, nki.benchmark, nki.profile, nki.simulate_kernel) previously raised a generic NotImplementedError("removed during code migration") message. Each now raises a specific message naming the unsupported API. Additionally, calling an nki.jit kernel with no arguments now raises a clear error instead. See NKI 0.2.0 Migration Guide.

  • Fixed nested nki_jit decorators not being allowed. The NKI compiler only recognized @nki.jit-decorated functions when they were plain function objects. Nested decorators (e.g., @my_wrapper @nki.jit) wrapped the function in a non-function object, causing the compiler to skip it. The compiler now correctly unwraps decorator chains to find the underlying kernel function. See nki.jit.
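
Unwrapping a decorator chain to reach the underlying function can be sketched with the standard library. This is not the NKI compiler's implementation: fake_jit and my_wrapper are hypothetical decorators, and the sketch assumes each wrapper uses functools.wraps, which sets the __wrapped__ attribute that inspect.unwrap follows.

```python
import functools
import inspect


def fake_jit(fn):
    """Hypothetical stand-in for a kernel decorator; returns a marked wrapper."""
    @functools.wraps(fn)
    def jit_wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    jit_wrapper.is_kernel = True
    return jit_wrapper


def my_wrapper(fn):
    """Hypothetical user decorator stacked on top of the kernel decorator."""
    @functools.wraps(fn)
    def outer(*args, **kwargs):
        return fn(*args, **kwargs)
    return outer


def add_one(x):
    return x + 1


decorated = my_wrapper(fake_jit(add_one))

# inspect.unwrap follows the __wrapped__ links set by functools.wraps.
innermost = inspect.unwrap(decorated)
```

After unwrapping, innermost is the original add_one function, even though decorated itself is a wrapper object two levels removed from it. A wrapper that omits functools.wraps would break this chain.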

Known Issues#

  • nki.isa.range_select: The on_false_value and reduce_cmd parameters are incorrectly ignored by the NKI compiler. The on_false_value is always set to (-3.4028235e+38) and reduce_cmd is always set to reduce_cmd.reset_reduce, regardless of the values passed in.

Neuron Kernel Interface (NKI) (0.1.0) [2.27] (Neuron 2.27.0 Release)#

Date: 12/25/2025

Improvements#

Known Issues#

  • nki.isa.nc_matmul - is_moving_onezero was incorrectly named is_moving_zero in this release

  • NKI ISA semantic checks are not available with NKI 0.2.0. As a workaround, refer to the API docs.

  • NKI Collectives are not available with NKI 0.2.0

  • nki.benchmark and nki.profile are not available with NKI 0.2.0


Neuron Kernel Interface (NKI) (Beta) [2.26] (Neuron 2.26.0 Release)#

Date: 09/18/2025

Improvements#

  • new nki.language APIs:

    • nki.language.gelu_apprx_sigmoid - Gaussian Error Linear Unit activation function with sigmoid approximation.

    • nki.language.tile_size.total_available_sbuf_size to get total available SBUF size

  • new nki.isa APIs:

    • nki.isa.select_reduce - selectively copy elements with max reduction

    • nki.isa.sequence_bounds - compute sequence bounds of segment IDs

    • nki.isa.dma_transpose

      • axes param to define 4D transpose for some supported cases

      • dge_mode to specify Descriptor Generation Engine (DGE).

    • nl.gelu_apprx_sigmoid op support on nki.isa.activation

  • fixes / improvements:

    • nki.language.store supports PSUM buffers, with an additional copy inserted.

  • docs/tutorial improvements:

    • nki.isa.dma_transpose API doc and example

    • nki.simulate_kernel example improvement

    • use nl.fp32.min in tutorial code instead of a magic number

  • better error reporting:

    • indirect indexing on transpose

    • mask expressions


Neuron Kernel Interface (NKI) (Beta) [2.24] (Neuron 2.24.0 Release)#

Date: 06/24/2025

Improvements#

  • Extended the valid input range of sqrt, improving accuracy over a wider range of values.

  • nki.language.gather_flattened new API

  • nki.isa.nc_match_replace8 additional param dst_idx

  • improved docs/examples on nki.isa.nc_match_replace8, nki.isa.nc_stream_shuffle

  • improved error messages


Neuron Kernel Interface (NKI) (Beta) [2.23] (Neuron 2.23.0 Release)#

Date: 05/20/2025

Improvements#

  • nki.isa.range_select (for trn2) new instruction

  • abs and power ops are now supported on nki.isa tensor instructions

  • abs op supported on nki.isa.activation instruction

  • GpSIMD engine support added for 32-bit integer add and multiply in nki.isa tensor operations

  • nki.isa.tensor_copy_predicated support for reversing predicate.

  • nki.isa.tensor_copy_dynamic_src, tensor_copy_dynamic_dst engine selection.

  • nki.isa.dma_copy additional support with dge_mode, oob_mode, and in-place add rmw_op.

  • +=, -=, /=, *= operators now work consistently across loop types, PSUM, and SBUF

  • fixed simulation for instructions: nki.language.rand, random_seed, nki.isa.dropout

  • fixed simulation masking behavior

  • Added warning when the block dimension is used for SBUF and PSUM tensors, see: NKI Block Dimension Migration Guide


Neuron Kernel Interface (NKI) (Beta) [2.22] (Neuron 2.22.0 Release)#

Date: 04/03/2025

Improvements#

  • New modules and APIs:

    • nki.profile

    • nki.isa new APIs:

      • tensor_copy_dynamic_dst

      • tensor_copy_predicated

      • max8, nc_find_index8, nc_match_replace8

      • nc_stream_shuffle

    • nki.language new APIs: mod, fmod, reciprocal, broadcast_to, empty_like

  • Improvements:

    • nki.isa.nc_matmul now supports PE tiling feature

    • nki.isa.activation updated to support reduce operation and reduce commands

    • nki.isa.engine enum

    • engine parameter added to more nki.isa APIs that support engine selection (e.g., tensor_scalar, tensor_tensor, memset)

    • Documentation for nki.kernels has been moved to GitHub: https://aws-neuron.github.io/nki-samples. The source code can be viewed at aws-neuron/nki-samples.

      • These kernels are still shipped as part of Neuron package in neuronxcc.nki.kernels module

  • Documentation updates:


Neuron Kernel Interface (NKI) (Beta) [2.21] (Neuron 2.21.0 Release)#

Date: 12/16/2024

Improvements#

  • New modules and APIs:

    • nki.compiler module with Allocation Control and Kernel decorators, see guide for more info.

    • nki.isa: new APIs (activation_reduce, tensor_partition_reduce, scalar_tensor_tensor, tensor_scalar_reduce, tensor_copy, tensor_copy_dynamic_src, dma_copy), new activation functions (identity, silu, silu_dx), and target query APIs (nc_version, get_nc_version).

    • nki.language: new APIs (shared_identity_matrix, tan, silu, silu_dx, left_shift, right_shift, ds, spmd_dim, nc).

    • New datatype: float8_e5m2

    • New kernels (allocated_fused_self_attn_for_SD_small_head_size, allocated_fused_rms_norm_qkv) added, kernels moved to public repository.

  • Improvements:

    • Semantic analysis checks for nki.isa APIs to validate supported ops, dtypes, and tile shapes.

    • Standardized naming conventions with keyword arguments for common optional parameters.

    • Transition from function calls to kernel decorators (jit, benchmark, baremetal, simulate_kernel).

  • Documentation updates:


Neuron Kernel Interface (NKI) (Beta) (Neuron 2.20.1 Release)#

Date: 12/03/2024

Improvements#

  • NKI support for Trainium2, including full integration with Neuron Compiler. Users can directly shard NKI kernels across multiple NeuronCores from an SPMD launch grid. See Trainium2 Architecture Guide for the architecture specification.

  • New calling convention in NKI kernels, where kernel output tensors are explicitly returned from the kernel instead of pass-by-reference. See any NKI tutorial for code examples.


Neuron Kernel Interface (NKI) (Beta) [2.20] (Neuron 2.20.0 Release)#

Date: 09/16/2024

Improvements#

  • This release is the beta launch of the Neuron Kernel Interface (NKI). NKI is a programming interface that enables developers to build optimized compute kernels on top of Trainium and Inferentia. NKI empowers developers to enhance deep learning models with new capabilities, performance optimizations, and scientific innovation. It integrates natively with PyTorch and JAX, providing a Python-based programming environment with Triton-like syntax and tile-level semantics that offers a familiar programming experience. Additionally, for bare-metal access that precisely controls the instructions executed by the chip, this release includes a set of NKI APIs (nki.isa) that directly emit Neuron Instruction Set Architecture (ISA) instructions in NKI kernels.
