.. _torch-hf-bert-finetune:

PyTorch Neuron for Trainium Hugging Face BERT MRPC task finetuning using Hugging Face Trainer API
====================================================================================================

In this tutorial, we show how to run a Hugging Face script that uses the Hugging Face Trainer API to do fine-tuning on Trainium. The example follows the Hugging Face ``text-classification`` example, which fine-tunes a BERT-base model for sequence classification on the GLUE benchmark.

.. contents:: Table of Contents
   :local:
   :depth: 2

.. include:: ../note-performance.txt

Setup and compilation
---------------------

Before running the tutorial, please follow the installation instructions at :ref:`Install PyTorch Neuron on Trn1 `.

Please set the storage of the instance to *512GB* or more if you also want to run through the BERT pretraining and GPT pretraining tutorials.

For all the commands below, make sure you are in the virtual environment that you created above before you run them:

.. code:: shell

   source ~/aws_neuron_venv_pytorch/bin/activate

First, we install a recent version of the Hugging Face transformers, scikit-learn, and evaluate packages in our environment, and download the transformers source matching the installed version. In this example, we use the text classification example from the Hugging Face transformers source:

.. code:: bash

   export HF_VER=4.27.4
   pip install -U transformers==$HF_VER datasets evaluate scikit-learn
   cd ~/
   git clone https://github.com/huggingface/transformers --branch v$HF_VER
   cd ~/transformers/examples/pytorch/text-classification

Single-worker training
----------------------

We will run MRPC task fine-tuning following the example in README.md located in the path ``~/transformers/examples/pytorch/text-classification``. In this part of the tutorial we use the Hugging Face model hub's pretrained ``bert-large-uncased`` model.

.. note::

   If you are using older versions of transformers (<4.27.0) or PyTorch Neuron (<1.13.0), please see the section :ref:`workarounds_for_older_versions` for necessary workarounds.

We use full BF16 casting (``XLA_USE_BF16=1``) and the compiler flag ``--model-type=transformer`` to enable best performance.

First, create a ``run.sh`` script containing the fine-tuning command and change it to executable, as sketched below.
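A minimal sketch of such a script is shown below. The model name, MRPC task, ``XLA_USE_BF16=1``, and the ``--model-type=transformer`` compiler flag come from the text above; the sequence length, batch size, learning rate, epoch count, and output directory are placeholder assumptions to adjust for your run. Paste the commands into your terminal:

.. code:: bash

   tee run.sh > /dev/null <<'EOF'
   #!/usr/bin/env bash
   # Single-worker MRPC fine-tuning sketch. The hyperparameters and the output
   # directory below are placeholders; adjust them for your run.
   export NEURON_CC_FLAGS="--model-type=transformer"
   XLA_USE_BF16=1 python3 ./run_glue.py \
       --model_name_or_path bert-large-uncased \
       --task_name mrpc \
       --do_train \
       --do_eval \
       --max_seq_length 128 \
       --per_device_train_batch_size 8 \
       --learning_rate 2e-5 \
       --num_train_epochs 5 \
       --overwrite_output_dir \
       --output_dir /tmp/mrpc/
   EOF
   chmod +x run.sh

The first run triggers compilation of the training graphs, which can take a long time. As noted in the known issues below, the ``neuron_parallel_compile`` tool can extract the graphs from a short trial run and compile them in parallel ahead of the actual run; it wraps the normal training command:

.. code:: bash

   # Trial run that extracts and compiles the training graphs in parallel;
   # the compiled graphs are cached and reused by the actual run.
   neuron_parallel_compile ./run.sh

   # Actual fine-tuning run.
   ./run.sh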
.. _workarounds_for_older_versions:

Workarounds for older versions
------------------------------

If you are running an older transformers version (>= 4.21.0 and < 4.25.1) together with ``XLA_USE_BF16=1`` or ``XLA_DOWNCAST_BF16=1``, add the following code snippet at the top of ``run_glue.py`` to avoid NaN losses at the first step (see the resolved issues below for details):

.. code:: python

   # Workaround for NaNs seen with transformers version >= 4.21.0
   # https://github.com/aws-neuron/aws-neuron-sdk/issues/593
   import os
   import torch
   import transformers
   if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"):
       transformers.modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16

.. _known_issues:

Known issues and limitations
----------------------------

The following are currently known issues:

- Long compilation times: this can be alleviated with the ``neuron_parallel_compile`` tool, which extracts graphs from a short trial run and compiles them in parallel ahead of the actual run, as shown above.

- When precompiling using a batch size of 16 on trn1.2xlarge, you will see ``ERROR ||PARALLEL_COMPILE||: parallel compilation with neuronx-cc exited with error.Received error code: -9``. To work around this error, set ``NEURON_PARALLEL_COMPILE_MAX_RETRIES=1`` in the environment.

- With release 2.6 and transformers==4.25.1, using the ``neuron_parallel_compile`` tool to run the ``run_glue.py`` script with both training and evaluation options (``--do_train`` and ``--do_eval``) produces the harmless error ``ValueError: Target is multiclass but average='binary'``.

- Reduced accuracy for RoBERTa-Large is seen with Neuron PyTorch 1.12 (release 2.6) in FP32 mode with compiler BF16 autocast. The workaround is to set ``NEURON_CC_FLAGS="--auto-cast none"`` or to set ``NEURON_RT_STOCHASTIC_ROUNDING_EN=1``.

- When using DDP in PT 1.13, compilation of one graph will fail with a "Killed" error message for ``bert-large-uncased``. For ``bert-base-cased``, the final MRPC evaluation accuracy is 31%, which is lower than expected. These issues are being investigated and will be fixed in an upcoming release. For now, DDP is disabled with the workaround shown above in :ref:`multi_worker_training`.

- When using DDP in PT 1.13 with ``neuron_parallel_compile`` precompilation, you will hit the error ``Rank 1 has 393 params, while rank 0 has inconsistent 0 params.``. To work around this error, add the following code snippet at the top of ``run_glue.py`` to skip the problematic shape verification code during precompilation:

  .. code:: python

     import os
     if os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None):
         # Skip DDP parameter-shape verification while Neuron is only extracting graphs.
         import torch.distributed as dist
         _verify_param_shape_across_processes = lambda process_group, tensors, logger=None: True

- Variable input sizes: when fine-tuning models such as ``dslim/bert-base-NER`` using the token-classification example, you may encounter timeouts (many "socket.h:524 CCOM WARN Timeout waiting for RX" messages) and an execution hang. This occurs because the NER dataset has varying sample sizes, which causes many recompilations and compiled-graph (NEFF) reloads. Furthermore, different data-parallel workers can execute different compiled graphs. This multiple-program multiple-data behavior is currently unsupported. To work around this issue, pad to maximum length using the Trainer API option ``--pad_to_max_length`` (see the sketch after this list).

- When running Hugging Face GPT fine-tuning with transformers version >= 4.21.0 and using ``XLA_USE_BF16=1`` or ``XLA_DOWNCAST_BF16=1``, you might see NaNs in the loss immediately at the first step. This issue occurs due to large negative constants used to implement attention masking (https://github.com/huggingface/transformers/pull/17306). To work around this issue, use transformers version <= 4.20.0.

- When using the Trainer API option ``--bf16``, you will see ``RuntimeError: No CUDA GPUs are available``. To work around this error, add ``import torch; torch.cuda.is_bf16_supported = lambda: True`` to the Python script (i.e. ``run_glue.py``). (The Trainer API option ``--fp16`` is not yet supported.)
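For the variable-input-size issue above, a hypothetical invocation of the token-classification example script (``run_ner.py``), assuming the CoNLL-2003 dataset and placeholder hyperparameters, would pass ``--pad_to_max_length`` so that every step executes the same compiled graph:

.. code:: bash

   # Padding every sample to max_seq_length keeps input shapes constant and
   # avoids recompilation; the dataset name and hyperparameters are placeholders.
   cd ~/transformers/examples/pytorch/token-classification
   XLA_USE_BF16=1 python3 ./run_ner.py \
       --model_name_or_path dslim/bert-base-NER \
       --dataset_name conll2003 \
       --pad_to_max_length \
       --max_seq_length 128 \
       --do_train \
       --do_eval \
       --overwrite_output_dir \
       --output_dir /tmp/ner/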
The following are resolved issues:

- Using the ``neuron_parallel_compile`` tool to run the ``run_glue.py`` script with both training and evaluation options (``--do_train`` and ``--do_eval``), you will encounter an INVALID_ARGUMENT error. To avoid this, enable only training for parallel compile (``--do_train``). This will cause compilations during the evaluation step. The INVALID_ARGUMENT error is fixed in release 2.6 together with the latest transformers package version 4.25.1.

- When running the Hugging Face BERT (any size) fine-tuning tutorial or pretraining tutorial with transformers version >= 4.21.0 and < 4.25.1 and using ``XLA_USE_BF16=1`` or ``XLA_DOWNCAST_BF16=1``, you will see NaNs in the loss immediately at the first step. More details on the issue can be found at `pytorch/xla#4152 <https://github.com/pytorch/xla/issues/4152>`_. The workaround is to use transformers version < 4.21.0 or >= 4.25.1, or to add ``transformers.modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16`` to your Python script (i.e. ``run_glue.py``).

- Some recompilation is seen at the epoch boundary even after ``neuron_parallel_compile`` is used. This can be fixed by using the same number of epochs both during precompilation and during the actual run (see the sketch after this list).

- When running multi-worker training, you may see the process getting killed at the time of model saving on trn1.2xlarge. This happens because the transformers ``trainer.save_model`` API uses ``xm.save`` for saving models. This API is known to cause high host memory usage in a multi-worker setting (see "Saving and Loading XLA Tensors" in the PyTorch/XLA documentation). Coupled with a compilation happening at the same time, this results in host OOM. To avoid this issue, precompile all the graphs in multi-worker training by first running the multi-worker training with ``neuron_parallel_compile``.
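For the epoch-boundary recompilation item above, a minimal sketch of a consistent precompile-then-run sequence, assuming the ``run.sh`` script from the single-worker section, is:

.. code:: bash

   # Precompile with the same script, and therefore the same --num_train_epochs
   # value, as the actual run; precompiling a shortened trial (for example, a
   # single epoch) would leave the epoch-boundary graphs uncompiled.
   neuron_parallel_compile ./run.sh
   ./run.sh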