This document is relevant for: Inf2, Trn1, Trn1n

nki.language.sequential_range

nki.language.sequential_range(*args)

Create a sequence of numbers for use as sequential loop iterators in NKI. Use sequential_range when the loop has a loop-carried dependency, that is, when an iteration reads a value written by an earlier iteration. Associative reductions are not considered loop-carried dependencies in this context; see affine_range for an example of such a reduction.

Notes:

  • Inside an NKI kernel, any use of Python range(...) is replaced with sequential_range(...) by the Neuron compiler.

  • Using sequential_range prevents the Neuron compiler from unrolling the loop until the compiler backend, which typically results in better compilation time than the fully unrolled iterator static_range.

  • Using sequential_range tells the Neuron compiler to respect dependencies between loop iterations and to perform much more conservative loop-level optimizations than it would with affine_range.

  • Incorrectly using affine_range instead of sequential_range when there is a loop-carried dependency is unsafe and could lead to numerical errors.
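The distinction between a loop-carried dependency and an associative reduction can be illustrated with a plain-Python sketch. This is not NKI code and does not run on a Neuron device; the function names and the recurrence are hypothetical and chosen only for illustration. The first loop gives a different answer when its iterations are reordered, while the second does not.

```python
# Plain-Python illustration only; none of these names are NKI APIs.

def running_recurrence(values, order):
    """Loop-carried dependency: each step reads the state written by the previous step."""
    state = 0
    for i in order:
        state = state * 2 + values[i]  # result depends on the order of iterations
    return state


def accumulate_sum(values, order):
    """Associative reduction: the final sum is the same for any iteration order."""
    total = 0
    for i in order:
        total += values[i]
    return total


values = [1, 2, 3, 4]
in_order = [0, 1, 2, 3]
reordered = [3, 2, 1, 0]

recurrence_in_order = running_recurrence(values, in_order)
recurrence_reordered = running_recurrence(values, reordered)
sum_in_order = accumulate_sum(values, in_order)
sum_reordered = accumulate_sum(values, reordered)

print(recurrence_in_order, recurrence_reordered)  # the two values differ
print(sum_in_order, sum_reordered)                # the two values match
```

A loop shaped like running_recurrence is the kind that calls for sequential_range, whereas a loop shaped like accumulate_sum corresponds to the associative reductions described for affine_range.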

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

#######################################################################
# Example 1: Loop carried dependency from tiling tensor_tensor_scan
# Both sbuf tensor input0 and input1 shapes: [128, 2048]
# Perform a scan operation between the two inputs using a tile size of [128, 512]
# Store the scan output to another [128, 2048] tensor
# (input0, input1 and output are defined by the enclosing NKI kernel)
#######################################################################

# Loop iterations communicate through this init tensor
init = nl.zeros((128, 1), dtype=input0.dtype)

# This loop will only produce correct results if the iterations are performed in order
for i_input in nl.sequential_range(input0.shape[1] // 512):
  offset = i_input * 512

  # Depends on scan result from the previous loop iteration
  result = nisa.tensor_tensor_scan(input0[:, offset:offset+512],
                                   input1[:, offset:offset+512],
                                   initial=init,
                                   op0=nl.multiply, op1=nl.add)

  nl.store(output[0:input0.shape[0], offset:offset+512], result)

  # Prepare initial result for scan in the next loop iteration
  init[:, :] = result[:, 511]
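
To see why the iterations in Example 1 must run in order, a pure-Python reference model may help. It assumes that tensor_tensor_scan with op0=nl.multiply and op1=nl.add computes result[i] = data0[i] * result[i-1] + data1[i] along the free dimension, seeded by initial. The helpers scan_row and tiled_scan are hypothetical and not part of NKI, and the model covers a single row with a small tile size rather than a [128, 512] tile.

```python
# Pure-Python reference model; scan_row and tiled_scan are hypothetical helpers,
# not NKI APIs. The recurrence below is an assumption about the scan semantics.

def scan_row(data0, data1, initial):
    """Assumed recurrence: out[i] = data0[i] * out[i-1] + data1[i], with out[-1] = initial."""
    out = []
    prev = initial
    for a, b in zip(data0, data1):
        prev = a * prev + b
        out.append(prev)
    return out


def tiled_scan(data0, data1, tile, carry=True):
    """Scan one tile at a time; with carry=True each tile is seeded by the previous tile's last value."""
    out = []
    init = 0
    for start in range(0, len(data0), tile):
        result = scan_row(data0[start:start + tile], data1[start:start + tile], init)
        out.extend(result)
        if carry:
            init = result[-1]  # mirrors init[:, :] = result[:, 511] in Example 1
    return out


row0 = [2, 1, 3, 1, 2, 2, 1, 3]
row1 = [1, 1, 1, 1, 1, 1, 1, 1]

full = scan_row(row0, row1, 0)                        # untiled reference
tiled = tiled_scan(row0, row1, tile=4)                # carries init across tiles
broken = tiled_scan(row0, row1, tile=4, carry=False)  # treats tiles as independent

print(tiled == full)   # carrying init reproduces the untiled scan
print(broken == full)  # dropping the carry changes the second tile
```

In this model the second tile is only correct when it starts from the last value of the first tile, which is the dependency that sequential_range asks the compiler to preserve.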
