This document is relevant for: Inf2, Trn1, Trn2

PyTorch NeuronX Analyze API for Inference

torch_neuronx.analyze(func, example_inputs, compiler_workdir=None)

Checks which operations in func are supported by testing each operator against neuronx-cc.

Parameters:
  • func (Module, callable) – The function or module that will be run with example_inputs to record the computation graph.

  • example_inputs (Tensor, tuple[Tensor]) – An example input tensor, or a tuple of example input tensors, that will be passed to func while tracing.

Keyword Arguments:
  • compiler_workdir (str) – Work directory used by neuronx-cc. This can be useful for debugging or inspecting intermediate neuronx-cc outputs.

  • additional_ignored_ops (set) – A set of aten operators to exclude from the analysis. Default is an empty set.

  • max_workers (int) – The maximum number of worker threads to spawn. Default is 4.

  • is_hf_transformers (bool) – If the model is a Hugging Face Transformers model, enabling this option is recommended to prevent deadlocks. Default is False.

  • cleanup (bool) – Whether to delete the compiler artifact directories that analyze generates once it finishes. Default is False.

Returns:

A JSON-like dict containing the supported operators with their counts, and the unsupported operators with their failure mode and the location of each operator in the Python code.

Return type:

Dict
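Because the return value is an ordinary Python dict, it can be post-processed before deciding whether to compile a model. The sketch below is a minimal illustration, not part of the torch_neuronx API: summarize_report is a hypothetical helper, and report is a hand-written sample with the same keys as the output shown in the examples on this page. In practice, report would be the value returned by torch_neuronx.analyze.

```python
# Hand-written sample with the same shape as the analyze() output on this page.
# In real use, this would be: report = torch_neuronx.analyze(model, ex_input)
report = {
    "torch_neuronx_version": "1.13.0.1.5.0",
    "neuronx_cc_version": "2.0.0.11796a0+24a26e112",
    "support_percentage": "100.00%",
    "supported_operators": {
        "aten::linear": 3,
        "aten::relu": 2,
        "aten::log_softmax": 1,
    },
    "unsupported_operators": [],
}


def summarize_report(report):
    """Hypothetical helper: condense an analyze() report into a few values."""
    # "support_percentage" is a string such as "100.00%"; strip the sign to get a float.
    percentage = float(report["support_percentage"].rstrip("%"))
    # Total number of supported operator occurrences across all aten kinds.
    supported_count = sum(report["supported_operators"].values())
    # Each entry in "unsupported_operators" describes one failing operator.
    unsupported_kinds = sorted({op["kind"] for op in report["unsupported_operators"]})
    return {
        "support_percentage": percentage,
        "supported_count": supported_count,
        "unsupported_kinds": unsupported_kinds,
        "fully_supported": not report["unsupported_operators"],
    }


summary = summarize_report(report)
print(summary)
```

The helper reads only support_percentage, supported_operators, and unsupported_operators, so it does not depend on the version fields, which differ between releases.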

Examples

Fully supported model

import json

import torch
import torch.nn as nn
import torch_neuronx

# A small multilayer perceptron: three linear layers with ReLU activations.
class MLP(nn.Module):
    def __init__(self, input_size=28*28, output_size=10, layers=(120, 84)):
        super().__init__()
        self.fc1 = nn.Linear(input_size, layers[0])
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(layers[0], layers[1])
        self.fc3 = nn.Linear(layers[1], output_size)

    def forward(self, x):
        f1 = self.fc1(x)
        r1 = self.relu(f1)
        f2 = self.fc2(r1)
        r2 = self.relu(f2)
        f3 = self.fc3(r2)
        return torch.log_softmax(f3, dim=1)

model = MLP()
# A batch of 32 flattened 28x28 inputs.
ex_input = torch.rand([32, 784])

# Check operator support and pretty-print the resulting dict.
model_support = torch_neuronx.analyze(model, ex_input)
print(json.dumps(model_support, indent=4))
{
    "torch_neuronx_version": "1.13.0.1.5.0",
    "neuronx_cc_version": "2.0.0.11796a0+24a26e112",
    "support_percentage": "100.00%",
    "supported_operators": {
        "aten::linear": 3,
        "aten::relu": 2,
        "aten::log_softmax": 1
    },
    "unsupported_operators": []
}

Unsupported Model/Operator

import json
import torch
import torch_neuronx

def fft(x):
    return torch.fft.fft(x)

model = fft
ex_input = torch.arange(4)

model_support = torch_neuronx.analyze(model, ex_input)
print(json.dumps(model_support, indent=4))
{
   "torch_neuronx_version": "1.13.0.1.5.0",
   "neuronx_cc_version": "2.0.0.11796a0+24a26e112",
   "support_percentage": "0.00%",
   "supported_operators": {},
   "unsupported_operators": [
      {
         "kind": "aten::fft_fft",
         "failureAt": "neuronx-cc",
         "call": "test.py(6): fft\n/home/ubuntu/testdir/venv/lib/python3.8/site-packages/torch_neuronx/xla_impl/analyze.py(35): forward\n/home/ubuntu/testdir/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(1182): _slow_forward\n/home/ubuntu/testdir/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(1194): _call_impl\n/home/ubuntu/testdir/venv/lib/python3.8/site-packages/torch/jit/_trace.py(976): trace_module\n/home/ubuntu/testdir/venv/lib/python3.8/site-packages/torch/jit/_trace.py(759): trace\n/home/ubuntu/testdir/venv/lib/python3.8/site-packages/torch_neuronx/xla_impl/analyze.py(302): analyze\ntest.py(11): <module>\n",
         "opGraph": "graph(%x : Long(4, strides=[1], requires_grad=0, device=cpu),\n      %neuron_4 : NoneType,\n      %neuron_5 : int,\n      %neuron_6 : NoneType):\n  %neuron_7 : ComplexFloat(4, strides=[1], requires_grad=0, device=cpu) = aten::fft_fft(%x, %neuron_4, %neuron_5, %neuron_6)\n  return (%neuron_7)\n"
      }
   ]
}

Note: the failureAt field can be either “neuronx-cc” or “Lowering to HLO”. If the field is “neuronx-cc”, the provided operator configuration failed to compile with neuronx-cc. This can indicate either that the operator configuration is unsupported or that there is a bug affecting that operator configuration. If the field is “Lowering to HLO”, the failure occurred while lowering the operator to HLO.
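For models with many unsupported operators, it can help to group the entries by failureAt and keep only the stack frames that point into your own code. The sketch below assumes that the call field is a newline-separated list of frames in the path(line): function form seen in the fft example above, listed innermost first, and it treats any frame whose path contains site-packages as library code. user_frames and group_by_stage are hypothetical helpers, not part of torch_neuronx.

```python
import re
from collections import defaultdict

# One unsupported-operator entry, with the "call" stack abbreviated from the fft example.
unsupported = [
    {
        "kind": "aten::fft_fft",
        "failureAt": "neuronx-cc",
        "call": (
            "test.py(6): fft\n"
            "/home/ubuntu/testdir/venv/lib/python3.8/site-packages/torch_neuronx/xla_impl/analyze.py(35): forward\n"
            "/home/ubuntu/testdir/venv/lib/python3.8/site-packages/torch/jit/_trace.py(759): trace\n"
            "test.py(11): <module>\n"
        ),
    },
]

# Matches frames such as "test.py(6): fft" and captures path, line, and function.
FRAME_RE = re.compile(r"^(?P<path>.+)\((?P<line>\d+)\): (?P<func>.+)$")


def user_frames(call):
    """Hypothetical helper: return (path, line, func) for frames outside site-packages."""
    frames = []
    for raw in call.splitlines():
        match = FRAME_RE.match(raw)
        if match and "site-packages" not in match["path"]:
            frames.append((match["path"], int(match["line"]), match["func"]))
    return frames


def group_by_stage(entries):
    """Hypothetical helper: map each failureAt value to the operators that failed there."""
    groups = defaultdict(list)
    for entry in entries:
        frames = user_frames(entry["call"])
        # Assuming innermost-first ordering, the first user frame is the operator's call site.
        location = frames[0] if frames else None
        groups[entry["failureAt"]].append((entry["kind"], location))
    return dict(groups)


for stage, ops in group_by_stage(unsupported).items():
    for kind, location in ops:
        print(f"{stage}: {kind} at {location}")
```

In practice, unsupported would be model_support["unsupported_operators"] from a real analyze call. Filtering on site-packages is a heuristic; a project installed into site-packages would need a different filter.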
