From 76800fb8e66aabfa1b31300f80fcf986befbb15f Mon Sep 17 00:00:00 2001 From: Philipp Schmid <32632186+philschmid@users.noreply.github.com> Date: Tue, 6 Apr 2021 15:12:21 +0200 Subject: [PATCH] added new merged Trainer test (#11090) --- tests/sagemaker/test_multi_node_model_parallel.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/sagemaker/test_multi_node_model_parallel.py b/tests/sagemaker/test_multi_node_model_parallel.py index bca402bcba42f0..3135573653002c 100644 --- a/tests/sagemaker/test_multi_node_model_parallel.py +++ b/tests/sagemaker/test_multi_node_model_parallel.py @@ -1,4 +1,5 @@ import os +import subprocess import unittest from ast import literal_eval @@ -28,10 +29,23 @@ "instance_type": "ml.p3dn.24xlarge", "results": {"train_runtime": 700, "eval_accuracy": 0.3, "eval_loss": 1.2}, }, + { + "framework": "pytorch", + "script": "run_glue.py", + "model_name_or_path": "roberta-large", + "instance_type": "ml.p3dn.24xlarge", + "results": {"train_runtime": 700, "eval_accuracy": 0.3, "eval_loss": 1.2}, + }, ] ) class MultiNodeTest(unittest.TestCase): def setUp(self): + if self.framework == "pytorch": + subprocess.run( + f"cp ./examples/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), + encoding="utf-8", + check=True, + ) assert hasattr(self, "env") def create_estimator(self, instance_count):