This document is relevant for: Inf2, Trn1, Trn1n
nki.language.sequential_range
- nki.language.sequential_range(*args)
Create a sequence of numbers for use as sequential loop iterators in NKI.
Use sequential_range when there is a loop-carried dependency, that is, when an iteration consumes a value produced by an earlier iteration. Note that associative reductions are not considered loop-carried dependencies in this context; see affine_range for an example of such an associative reduction.

Notes:
- Inside an NKI kernel, the Neuron compiler replaces any use of Python range(...) with sequential_range(...).
- Using sequential_range prevents the Neuron compiler from unrolling the loop until it enters the compiler backend, which typically results in better compilation time than the fully unrolled static_range iterator.
- Using sequential_range tells the Neuron compiler to respect inter-iteration dependencies, so it performs much more conservative loop-level optimizations than it does for affine_range.
- Incorrectly using affine_range instead of sequential_range when a loop-carried dependency exists is considered unsafe and could lead to numerical errors.
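The difference between a loop-carried dependency and an associative reduction can be seen without NKI at all. The following is a minimal plain-Python sketch, not NKI code: the helpers tiled_prefix_sum and tiled_total are hypothetical, and modeling the freedom a compiler has under affine_range as a permutation of the iteration order is a simplification.

```python
from itertools import permutations


def tiled_prefix_sum(tiles, order):
    """Inclusive prefix sum over a list of tiles.

    Each iteration starts from the carry left by whichever iteration ran
    before it, so the result depends on execution order: this is a
    loop-carried dependency.
    """
    out = [None] * len(tiles)
    carry = 0
    for i in order:
        acc = carry
        row = []
        for x in tiles[i]:
            acc += x
            row.append(acc)
        out[i] = row
        carry = acc  # handed to the next iteration that runs
    return out


def tiled_total(tiles, order):
    """Associative reduction: add each tile's sum into one accumulator.

    Integer addition is associative and commutative, so any execution
    order yields the same total.
    """
    total = 0
    for i in order:
        total += sum(tiles[i])
    return total


tiles = [[1, 2], [3, 4], [5, 6]]
in_order = tuple(range(len(tiles)))
reference_scan = tiled_prefix_sum(tiles, in_order)
reference_total = tiled_total(tiles, in_order)

# Try every possible iteration order.
for order in permutations(in_order):
    scan_ok = tiled_prefix_sum(tiles, order) == reference_scan
    total_ok = tiled_total(tiles, order) == reference_total
    print(order, "scan:", "ok" if scan_ok else "WRONG",
          "| total:", "ok" if total_ok else "WRONG")
```

Only the in-order run (0, 1, 2) reproduces the prefix sum, while every order reproduces the total. Loops shaped like tiled_prefix_sum call for sequential_range; loops shaped like tiled_total are the associative-reduction case that affine_range covers.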
```python
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

#######################################################################
# Example 1: Loop carried dependency from tiling tensor_tensor_scan
# Both sbuf tensor input0 and input1 shapes: [128, 2048]
# Perform a scan operation between the two inputs using a tile size of [128, 512]
# Store the scan output to another [128, 2048] tensor
#######################################################################

# input0, input1 (SBUF tiles) and output (the destination tensor) are
# assumed to be defined by the enclosing NKI kernel.

# Loop iterations communicate through this init tensor
init = nl.zeros((128, 1), dtype=input0.dtype)

# This loop will only produce correct results if the iterations are performed in order
for i_input in nl.sequential_range(input0.shape[1] // 512):
    offset = i_input * 512

    # Depends on scan result from the previous loop iteration
    result = nisa.tensor_tensor_scan(input0[:, offset:offset+512],
                                     input1[:, offset:offset+512],
                                     initial=init,
                                     op0=nl.multiply, op1=nl.add)

    nl.store(output[0:input0.shape[0], offset:offset+512], result)

    # Prepare initial result for scan in the next loop iteration
    init[:, :] = result[:, 511]
```
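To sanity-check the logic of the example on the host, here is a hypothetical pure-Python reference for a single partition row (each of the 128 partitions is scanned independently). It assumes the scan recurrence result[j] = data0[j] * result[j - 1] + data1[j] with result[-1] = initial, which is our reading of op0=nl.multiply, op1=nl.add rather than a statement of how tensor_tensor_scan is implemented. The names scan_row and tiled_scan_row are illustrative, and a row of 8 elements with a tile size of 4 stands in for the 2048/512 shapes.

```python
def scan_row(data0, data1, initial):
    """Untiled scan of one row, assuming
    result[j] = data0[j] * result[j - 1] + data1[j], result[-1] = initial."""
    out = []
    prev = initial
    for a, b in zip(data0, data1):
        prev = a * prev + b  # op0 = multiply, then op1 = add
        out.append(prev)
    return out


def tiled_scan_row(data0, data1, tile, carry_init=True):
    """Tiled scan mirroring the kernel loop.

    With carry_init=True each tile starts from the last element of the
    previous tile, like `init[:, :] = result[:, 511]`. With
    carry_init=False every tile restarts from 0, i.e. the iterations do not
    communicate. Assumes len(data0) is a multiple of tile.
    """
    out = [0] * len(data0)
    init = 0  # mirrors init = nl.zeros((128, 1), ...)
    for i_input in range(len(data0) // tile):  # strictly in order
        offset = i_input * tile
        res = scan_row(data0[offset:offset + tile],
                       data1[offset:offset + tile],
                       init if carry_init else 0)
        out[offset:offset + tile] = res
        init = res[-1]
    return out


data0 = [1, 2, 1, 3, 2, 1, 1, 2]
data1 = [1, 1, 2, 1, 0, 3, 1, 1]

print(scan_row(data0, data1, 0))                               # [1, 3, 5, 16, 32, 35, 36, 73]
print(tiled_scan_row(data0, data1, tile=4))                    # [1, 3, 5, 16, 32, 35, 36, 73]
print(tiled_scan_row(data0, data1, tile=4, carry_init=False))  # [1, 3, 5, 16, 0, 3, 4, 9]
```

The tiled scan matches the untiled one only because each tile starts from the previous tile's last element, which is exactly the ordering that sequential_range preserves.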