Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run torchdynamo tests #19056

Merged
merged 2 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docker/transformers-pytorch-deepspeed-nightly-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,24 @@ RUN python3 -m pip uninstall -y deepspeed
# RUN git clone https://github.com/microsoft/DeepSpeed && cd DeepSpeed && rm -rf build && \
# DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 python3 -m pip install . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1

# For `torchdynamo` tests
# (see https://github.com/huggingface/transformers/pull/17765)
RUN git clone https://github.com/pytorch/functorch
RUN python3 -m pip install --no-cache-dir ./functorch[aot]
RUN cd functorch && python3 setup.py develop

RUN git clone https://github.com/pytorch/torchdynamo
RUN python3 -m pip install -r ./torchdynamo/requirements.txt
RUN cd torchdynamo && python3 setup.py develop

# install TensorRT
RUN python3 -m pip install --no-cache-dir -U nvidia-pyindex
RUN python3 -m pip install --no-cache-dir -U nvidia-tensorrt==8.2.4.2

# install torch_tensorrt (fx path)
RUN git clone https://github.com/pytorch/TensorRT.git
RUN cd TensorRT/py && python3 setup.py install --fx-only

# When installing in editable mode, `transformers` is not recognized as a package.
# This line must be added in order for Python to be aware of `transformers`.
RUN cd transformers && python3 setup.py develop
Expand Down
3 changes: 1 addition & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,14 +638,13 @@ def __init__(
raise RuntimeError("Torchdynamo is not installed.")
import torchdynamo
from torchdynamo.optimizations import backends
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

def get_ctx():
# Normal
if args.torchdynamo == "eager":
return torchdynamo.optimize("eager")
elif args.torchdynamo == "nvfuser":
return torchdynamo.optimize(aot_autograd_speedup_strategy)
return torchdynamo.optimize("aot_nvfuser")
# TensorRT
if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
if not is_torch_tensorrt_fx_available():
Expand Down
12 changes: 11 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,8 @@ def test_fp16_full_eval(self):
@require_torchdynamo
@require_torch_tensorrt_fx
def test_torchdynamo_full_eval(self):
import torchdynamo

# torchdynamo doesn't support DP/DDP at the moment, so a single GPU is required
n_gpus = get_gpu_count()

Expand All @@ -1820,18 +1822,21 @@ def test_torchdynamo_full_eval(self):
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
del trainer
torchdynamo.reset()

# 3. TorchDynamo nvfuser
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
torchdynamo.reset()

# 4. TorchDynamo fx2trt
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
torchdynamo.reset()

# 5. TorchDynamo fx2trt-fp16
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
Expand All @@ -1840,11 +1845,14 @@ def test_torchdynamo_full_eval(self):
t2 = original_eval_loss
# fp16 has accuracy degradation
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
torchdynamo.reset()

@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
import torchdynamo

class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
x = inputs["x"]
Expand All @@ -1861,7 +1869,7 @@ def __init__(self):

def forward(self, x):
for _ in range(20):
x = torch.nn.functional.relu(x)
x = torch.cos(x)
return x

mod = MyModule()
Expand All @@ -1881,6 +1889,7 @@ def forward(self, x):

orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
torchdynamo.reset()
del trainer

# 2. TorchDynamo nvfuser
Expand All @@ -1899,6 +1908,7 @@ def forward(self, x):

loss = trainer.training_step(mod, {"x": a})
peak_mem = torch.cuda.max_memory_allocated()
torchdynamo.reset()
del trainer

# Functional check
Expand Down