This document is relevant for: Trn1, Trn2, Trn3

nki.collectives.all_to_all_v

nki.collectives.all_to_all_v(srcs, dsts, replica_group, metadata_tensor, recv_counts_known=False, has_rdispls=False)

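Perform a variable-length all-to-all exchange among the ranks of the given replica group, reading from the input tensor and writing to the output tensor. Each rank can send a different number of elements to each peer.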

Unlike all_to_all, which splits and concatenates along its collective_dim, all_to_all_v treats each tensor as a flat buffer of elements. Counts and displacements in the metadata tensor are therefore measured in elements of the tensor flattened in row-major order, not in slices along a particular dimension.
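The following sketch illustrates the flat-buffer semantics with a small NumPy reference model that simulates the exchange on the host. It is not NKI code and calls no nki API. The function name all_to_all_v_reference is hypothetical. The model also makes two assumptions that this page does not state. First, it uses MPI_Alltoallv-style indexing, where column j of the count and displacement rows refers to peer rank j. Second, when no receive displacements are supplied, received chunks are packed back to back in source-rank order.

```python
import numpy as np


def all_to_all_v_reference(srcs, send_counts, send_displs, recv_displs=None):
    """Hypothetical host-side model of a variable-length all-to-all.

    srcs[r]              source tensor of rank r (any shape, flattened row-major)
    send_counts[r][j]    elements rank r sends to rank j (assumed indexing)
    send_displs[r][j]    element offset of that chunk in rank r's flat source
    recv_displs[r][j]    element offset in rank r's flat destination where the
                         chunk from rank j lands (None: packed contiguously)
    """
    world_size = len(srcs)
    send_counts = np.asarray(send_counts, dtype=np.int64)
    send_displs = np.asarray(send_displs, dtype=np.int64)
    # Rank r receives from rank j exactly what rank j sends to rank r.
    recv_counts = send_counts.T
    if recv_displs is None:
        # Assumption: receive chunks are packed in source-rank order.
        recv_displs = np.zeros_like(recv_counts)
        recv_displs[:, 1:] = np.cumsum(recv_counts, axis=1)[:, :-1]
    recv_displs = np.asarray(recv_displs, dtype=np.int64)

    # Treat every tensor as a flat buffer of elements (row-major order).
    flat_srcs = [np.asarray(s).reshape(-1) for s in srcs]
    dst_sizes = (recv_displs + recv_counts).max(axis=1)
    dsts = [np.zeros(int(n), dtype=flat_srcs[0].dtype) for n in dst_sizes]
    for src in range(world_size):
        for dst in range(world_size):
            n = send_counts[src, dst]
            s = send_displs[src, dst]
            d = recv_displs[dst, src]
            dsts[dst][d:d + n] = flat_srcs[src][s:s + n]
    return dsts


# Two ranks, each holding a 2x3 source tensor.
srcs = [np.arange(6).reshape(2, 3), np.arange(10, 16).reshape(2, 3)]
dsts = all_to_all_v_reference(
    srcs,
    send_counts=[[2, 4], [1, 5]],  # rank 0 sends 2 to itself, 4 to rank 1
    send_displs=[[0, 2], [0, 1]],
)
print(dsts[0].tolist())  # [0, 1, 10]
print(dsts[1].tolist())  # [2, 3, 4, 5, 11, 12, 13, 14, 15]
```

Rank 0's 4-element chunk for rank 1 starts at flat offset 2 and crosses the boundary between the two rows of its 2x3 source. A split along a single dimension cannot express this layout, but a flat element range can.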

Parameters:
  • srcs – List of input tensors to redistribute; must contain exactly one tensor

  • dsts – List of output tensors that receive the results; must contain exactly one tensor

  • replica_group – ReplicaGroup defining the groups of ranks that participate in the collective

  • metadata_tensor – Metadata tensor of shape (2-4, world_size) and dtype uint32; the number of rows depends on recv_counts_known and has_rdispls. Row 0: send counts; row 1: send displacements; row 2 (optional): receive counts; row 3 (optional): receive displacements. All values are in elements.

  • recv_counts_known – If True, the metadata tensor includes receive counts (row 2). Defaults to False.

  • has_rdispls – If True, the metadata tensor includes receive displacements (row 3). Defaults to False.
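The row layout described above can be assembled on the host with NumPy. This is a minimal sketch, and build_a2av_metadata and exclusive_prefix_sum are hypothetical helpers that are not part of nki. The sketch makes two assumptions. First, column j of every row refers to peer rank j. Second, each rank's outgoing and incoming chunks are packed contiguously, so each displacement is the exclusive prefix sum of the counts. Following the row numbering above, the sketch emits row 3 only together with row 2. It does not show how the resulting array is placed into the metadata_tensor argument of a kernel.

```python
import numpy as np


def exclusive_prefix_sum(counts):
    """displs[0] = 0 and displs[j] = counts[0] + ... + counts[j-1]."""
    displs = np.zeros_like(counts)
    displs[1:] = np.cumsum(counts[:-1], dtype=counts.dtype)
    return displs


def build_a2av_metadata(send_counts, recv_counts=None, include_rdispls=False):
    """Hypothetical helper returning a (rows, world_size) uint32 array.

    Row 0: send counts, row 1: send displacements,
    row 2: receive counts (only if recv_counts is given),
    row 3: receive displacements (only if include_rdispls is also set).
    All values are element counts/offsets into the flattened tensors.
    """
    if include_rdispls and recv_counts is None:
        raise ValueError("this sketch expects recv_counts when include_rdispls is set")
    send_counts = np.asarray(send_counts, dtype=np.uint32)
    rows = [send_counts, exclusive_prefix_sum(send_counts)]
    if recv_counts is not None:
        recv_counts = np.asarray(recv_counts, dtype=np.uint32)
        rows.append(recv_counts)  # would pair with recv_counts_known=True
        if include_rdispls:
            rows.append(exclusive_prefix_sum(recv_counts))  # has_rdispls=True
    return np.stack(rows)


# Example for world_size = 4 (values are illustrative).
meta = build_a2av_metadata([3, 0, 5, 2], recv_counts=[1, 4, 2, 2], include_rdispls=True)
print(meta.shape, meta.dtype)  # (4, 4) uint32
print(meta.tolist())
# [[3, 0, 5, 2], [0, 3, 3, 8], [1, 4, 2, 2], [0, 1, 5, 7]]
```

A zero count, as for peer 1 above, leaves that peer's displacement equal to the next peer's, because the chunk occupies no elements.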
