This document is relevant for: Trn2, Trn3

nki.collectives.all_to_all_v

nki.collectives.all_to_all_v(srcs: List, dsts: List, replica_group: ReplicaGroup, metadata_tensor, recv_counts_known: bool = False, has_rdispls: bool = False, priority: Optional[int] = None) -> None

Executes an all-to-all collective where each rank can send a different number of elements, known only at execution time (rather than at compile time).

Unlike all_to_all which splits/concatenates along a collective dimension, all_to_all_v treats tensors as flat element buffers. Per-rank send/recv counts and displacements are supplied via a uint32 metadata tensor, making per-rank payload sizes dynamic.

Current restrictions:

On instances with a NeuronSwitch fabric (see Trn3 architecture), all_to_all_v requires LNC=2 and more than one participating device. Multiple ranks per device are supported, with one constraint: if a device participates in a replica-group rank-list, all of that device's ranks (4 under LNC=2) must be included in that same rank-list, so each rank-list is a set of sequential ranks in the world (e.g. [[1, 2, 3, 4], [5, 6, 7, 8]]). To exclude a rank, keep it in the replica group and set its send_count to 0.

On other instances, all_to_all_v currently supports only inter-node replica groups: each rank-list contains same-indexed ranks from different nodes (here, each node is a separate Trn EC2 instance).
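As an illustration of the inter-node layout described above, the sketch below builds rank-lists of same-indexed ranks across nodes. It assumes global ranks are numbered node by node (rank = node * ranks_per_node + local index). That numbering and the helper name inter_node_rank_lists are assumptions made for this example; neither is part of the NKI API.

```python
def inter_node_rank_lists(num_nodes, ranks_per_node):
    """Group same-indexed ranks from different nodes into rank-lists.

    Assumes global rank = node * ranks_per_node + local_index, which is an
    illustrative numbering, not something the documentation specifies.
    """
    return [
        [node * ranks_per_node + local for node in range(num_nodes)]
        for local in range(ranks_per_node)
    ]


# Two nodes with four ranks each: every rank-list pairs one rank per node.
groups = inter_node_rank_lists(num_nodes=2, ranks_per_node=4)
print(groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Each inner list could then serve as one rank-list of a replica group, provided the real rank numbering on the cluster matches the assumption above.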

Parameters:
  • srcs – Input tensor list. Currently supports exactly one tensor. Must be HBM-backed.

  • dsts – Output tensor list. Currently supports exactly one tensor. Must be HBM-backed. The src and dst element counts may differ; sizes are validated against the metadata at execution time.

  • replica_group – ReplicaGroup defining which ranks participate.

  • metadata_tensor

    A uint32 tensor laid out contiguously in memory. Its shape depends on the backing buffer. In the shapes below, rows is 3 when has_rdispls=False and 4 when has_rdispls=True:

    • HBM: (rows, replica_group_size).

    • SBUF: (1, rows, replica_group_size) — the whole buffer must live on a single partition, so a trivial partition dim is prepended.

    For each other rank r in the replica group, the rows are:

    • Row 0 send_counts[r]: number of elements sent to rank r. Always an input.

    • Row 1 send_displs[r]: offset in elements within src where the chunk destined for rank r begins. Always an input.

    • Row 2 recv_counts[r]: number of elements received from rank r. Controlled by recv_counts_known — see that flag.

    • Row 3 recv_displs[r]: offset in elements within dst where the chunk from rank r is written. Only present when has_rdispls=True.

  • recv_counts_known

    Controls whether row 2 is populated by the collective during execution. Row 2 is never read as input.

    • True: row 2 is left untouched, avoiding a small per-rank writeback.

    • False (default): row 2 is an output — per-rank received counts are written during execution, and can be read after the op to learn received sizes.

  • has_rdispls

    • True: row 3 is an input; recv_displs must be populated. The chunk from sender rank r is written at dst[recv_displs[r] : recv_displs[r] + recv_counts[r]].

    • False: row 3 may be omitted from metadata_tensor (pass a 3-row tensor). Incoming chunks are laid out at equally spaced offsets, recv_displs[r] = (dst.total_elements / replica_group_size) * r, regardless of the actual recv_count per rank.

  • priority – DMA QoS priority level from 0 to 3, where a lower value means higher priority (NeuronCore-v4+ only).
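The metadata rows and the two flags can be easier to follow with a host-side model. The sketch below is plain Python and does not call NKI. It builds the metadata rows for each rank, then mimics the data movement described above, including the default equally spaced recv_displs. The names exclusive_prefix_sum, build_metadata and simulate_all_to_all_v are hypothetical helpers. Three choices are assumptions made for this example, not documented behavior: chunks are packed back to back in src, the default offset uses integer division, and a rank's chunk to itself is copied like any other.

```python
def exclusive_prefix_sum(counts):
    """Offsets that pack chunks back to back: [0, c0, c0 + c1, ...]."""
    offsets, running = [], 0
    for count in counts:
        offsets.append(running)
        running += count
    return offsets


def build_metadata(send_counts, recv_displs=None):
    """Return metadata rows as nested lists, shaped (rows, replica_group_size)."""
    rows = [
        list(send_counts),                  # row 0: send_counts
        exclusive_prefix_sum(send_counts),  # row 1: send_displs (packed src, an assumption)
        [0] * len(send_counts),             # row 2: recv_counts, filled in by the model
    ]
    if recv_displs is not None:
        rows.append(list(recv_displs))      # row 3: recv_displs (has_rdispls=True)
    return rows


def simulate_all_to_all_v(srcs, dst_sizes, metadata, recv_counts_known=False):
    """Host-side model: srcs[p] and metadata[p] belong to rank p of the group."""
    n = len(srcs)
    dsts = [[None] * size for size in dst_sizes]
    for me in range(n):
        has_rdispls = len(metadata[me]) == 4
        stride = dst_sizes[me] // n  # integer division is an assumption
        for r in range(n):
            count = metadata[r][0][me]  # what sender r sends to receiver me
            start = metadata[r][1][me]  # where that chunk begins in sender r's src
            chunk = srcs[r][start:start + count]
            offset = metadata[me][3][r] if has_rdispls else stride * r
            if len(chunk) != count or offset + count > dst_sizes[me]:
                raise ValueError(f"chunk from rank {r} to rank {me} does not fit")
            dsts[me][offset:offset + count] = chunk
            if not recv_counts_known:
                metadata[me][2][r] = count  # row 2 acts as an output
    return dsts


# Three ranks; send_counts[p][r] is how many elements rank p sends to rank r.
srcs = [[0, 1, 2], [10, 11, 12, 13], [20, 21, 22]]
send_counts = [[1, 2, 0], [0, 1, 3], [2, 0, 1]]
metadata = [build_metadata(counts) for counts in send_counts]
dsts = simulate_all_to_all_v(srcs, [9, 9, 9], metadata)
recv_counts = [rows[2] for rows in metadata]
```

With has_rdispls=False and 9-element destinations, each sender's chunk starts at offset 3 * r, so gaps remain where a sender delivered fewer than 3 elements. After the call, recv_counts on each rank mirrors what the other ranks declared in their send_counts. Passing recv_displs to build_metadata yields the 4-row form, and wrapping the rows in one more list gives the (1, rows, replica_group_size) shape described for SBUF.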
