From 04ed4ce7f3771fbeb548beac7c21d2fe92cc42e9 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 15 Sep 2022 13:44:29 +0200 Subject: [PATCH 1/2] Enable torchdynamo tests --- .../Dockerfile | 18 ++++++++++++++++++ src/transformers/trainer.py | 3 +-- tests/trainer/test_trainer.py | 11 ++++++++++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/docker/transformers-pytorch-deepspeed-nightly-gpu/Dockerfile b/docker/transformers-pytorch-deepspeed-nightly-gpu/Dockerfile index 1854d9f4b38d48..573e09c22a9c05 100644 --- a/docker/transformers-pytorch-deepspeed-nightly-gpu/Dockerfile +++ b/docker/transformers-pytorch-deepspeed-nightly-gpu/Dockerfile @@ -27,6 +27,24 @@ RUN python3 -m pip uninstall -y deepspeed # RUN git clone https://github.com/microsoft/DeepSpeed && cd DeepSpeed && rm -rf build && \ # DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 python3 -m pip install . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1 +# For `torchdynamo` tests +# (see https://github.com/huggingface/transformers/pull/17765) +RUN git clone https://github.com/pytorch/functorch +RUN python3 -m pip install --no-cache-dir ./functorch[aot] +RUN cd functorch && python3 setup.py develop + +RUN git clone https://github.com/pytorch/torchdynamo +RUN python3 -m pip install -r ./torchdynamo/requirements.txt +RUN cd torchdynamo && python3 setup.py develop + +# install TensorRT +RUN python3 -m pip install --no-cache-dir -U nvidia-pyindex +RUN python3 -m pip install --no-cache-dir -U nvidia-tensorrt==8.2.4.2 + +# install torch_tensorrt (fx path) +RUN git clone https://github.com/pytorch/TensorRT.git +RUN cd TensorRT/py && python3 setup.py install --fx-only + # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. RUN cd transformers && python3 setup.py develop diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 27e44ea0ba0bd4..6cae5a6ea0069f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -638,14 +638,13 @@ def __init__( raise RuntimeError("Torchdynamo is not installed.") import torchdynamo from torchdynamo.optimizations import backends - from torchdynamo.optimizations.training import aot_autograd_speedup_strategy def get_ctx(): # Normal if args.torchdynamo == "eager": return torchdynamo.optimize("eager") elif args.torchdynamo == "nvfuser": - return torchdynamo.optimize(aot_autograd_speedup_strategy) + return torchdynamo.optimize("aot_nvfuser") # TensorRT if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]: if not is_torch_tensorrt_fx_available(): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index f48265ffa58168..512279c024a5c3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1799,6 +1799,7 @@ def test_fp16_full_eval(self): @require_torchdynamo @require_torch_tensorrt_fx def test_torchdynamo_full_eval(self): + import torchdynamo # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu n_gpus = get_gpu_count() @@ -1820,11 +1821,13 @@ def test_torchdynamo_full_eval(self): metrics = trainer.evaluate() self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss) del trainer + torchdynamo.reset() # 3. TorchDynamo nvfuser trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser") metrics = trainer.evaluate() self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss) + torchdynamo.reset() # 4. TorchDynamo fx2trt trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt") @@ -1832,6 +1835,7 @@ def test_torchdynamo_full_eval(self): t1 = metrics["eval_loss"] t2 = original_eval_loss self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss) + torchdynamo.reset() # 5. TorchDynamo fx2trt-fp16 trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16") @@ -1840,11 +1844,14 @@ def test_torchdynamo_full_eval(self): t2 = original_eval_loss # fp16 has accuracy accuracy degradation self.assertLess(np.max(np.abs(t1 - t2)), 1e-3) + torchdynamo.reset() @require_torch_non_multi_gpu @require_torchdynamo def test_torchdynamo_memory(self): # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu + import torchdynamo + class CustomTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): x = inputs["x"] @@ -1861,7 +1868,7 @@ def __init__(self): def forward(self, x): for _ in range(20): - x = torch.nn.functional.relu(x) + x = torch.cos(x) return x mod = MyModule() @@ -1881,6 +1888,7 @@ def forward(self, x): orig_loss = trainer.training_step(mod, {"x": a}) orig_peak_mem = torch.cuda.max_memory_allocated() + torchdynamo.reset() del trainer # 2. TorchDynamo nvfuser @@ -1899,6 +1907,7 @@ def forward(self, x): loss = trainer.training_step(mod, {"x": a}) peak_mem = torch.cuda.max_memory_allocated() + torchdynamo.reset() del trainer # Functional check From b3c7b611d0fd281e179af537a007671130b36340 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Thu, 15 Sep 2022 19:08:07 +0200 Subject: [PATCH 2/2] make style --- tests/trainer/test_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 512279c024a5c3..a8f4c11dcc4101 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1800,6 +1800,7 @@ def test_fp16_full_eval(self): @require_torch_tensorrt_fx def test_torchdynamo_full_eval(self): import torchdynamo + # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu n_gpus = get_gpu_count()