Support mark_dynamic (#7812)
JackCaoG authored Aug 7, 2024
1 parent c26b19e commit 08d9595
Showing 2 changed files with 185 additions and 94 deletions.
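
For context, the usage pattern this commit adds test coverage for: marking an input dimension as dynamic with torch._dynamo.mark_dynamic so the openxla dynamo backend can reuse a single dynamic graph across batch sizes. A minimal sketch, not part of the diff; the linear model, shapes, and variable names are illustrative assumptions:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()
model = torch.nn.Linear(10, 20).to(device)
compiled = torch.compile(model, backend="openxla", dynamic=False)

x = torch.randn(8, 10).to(device)
# Mark dim 0 (the batch dimension) as dynamic before the first call so that
# later batch sizes can reuse the traced graph instead of re-extracting it.
torch._dynamo.mark_dynamic(x, 0)
out = compiled(x)
xm.wait_device_ops()  # block until the device finishes executing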
66 changes: 65 additions & 1 deletion test/dynamo/test_dynamo_dynamic_shape.py
@@ -75,6 +75,51 @@ def test_dynamic_shape_basic(self):
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

  def test_dynamic_shape_basic_with_mark_dynamic(self):
    torch_xla.manual_seed(100)
    device = torch_xla.device()
    # model setup
    dummy_linear, dummy_linear_xla, input, input_xla = self._get_linear_and_input(
        10, 40, 40, device)
    compiled_linear_xla = torch.compile(
        dummy_linear_xla, backend="openxla", dynamic=False)
    xm.wait_device_ops()
    met.clear_all()

    # first run
    res = dummy_linear(input)
    torch._dynamo.mark_dynamic(input_xla, 0)
    res_xla = compiled_linear_xla(input_xla)
    # TPU matmul happens in bf16, hence the loose tolerance.
    self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2, rtol=1e-4))
    # torch.compile should extract the compiled graph exactly once
    self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 1)
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # second run with a different batch size: the dynamic graph is reused,
    # so dynamo must not re-extract the compiled graph
    input = torch.randn(25, 10)
    input_xla = input.to(device)
    torch._dynamo.mark_dynamic(input_xla, 0)
    met.clear_all()
    res = dummy_linear(input)
    res_xla = compiled_linear_xla(input_xla)
    self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2, rtol=1e-4))
    self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

    # third run with yet another batch size: torch.compile should not recompile
    input = torch.randn(26, 10)
    input_xla = input.to(device)
    torch._dynamo.mark_dynamic(input_xla, 0)
    met.clear_all()
    res = dummy_linear(input)
    res_xla = compiled_linear_xla(input_xla)
    self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2, rtol=1e-4))
    self.assertNotIn('DynamoExtractCompiledGraph', met.counter_names())
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

  def test_dynamic_shape_multiple_batchs(self):
    torch_xla.manual_seed(100)
    device = torch_xla.device()
@@ -191,7 +236,7 @@ def test_dynamic_shape_mix_with_non_dynamic(self):
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

  def test_dynamic_shape_symint_as_return(self):
  def test_dynamic_decoder(self):
    device = torch_xla.device()
    config = DecoderOnlyConfig()
    config.num_hidden_layers = 2
@@ -210,6 +255,25 @@ def test_dynamic_shape_symint_as_return(self):
    # for other batch sizes.
    self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 2)

  def test_dynamic_shape_decoder_mark_dynamic(self):
    device = torch_xla.device()
    config = DecoderOnlyConfig()
    config.num_hidden_layers = 2
    config.hidden_size = 512
    seq_len = 512
    decoder_model = DecoderOnlyModel(config).to(device)
    compiled_decoder_model = torch.compile(
        decoder_model, backend="openxla", dynamic=False)
    xm.wait_device_ops()
    met.clear_all()
    for batch_size in [1, 2, 3, 4, 5]:
      input = torch.zeros(batch_size, seq_len, dtype=torch.int64).to(device)
      torch._dynamo.mark_dynamic(input, 0)
      res = compiled_decoder_model(input)
    # The first compile is static; from the second batch size on it is dynamic,
    # so torch.compile calls XLA's dynamo backend exactly twice.
    self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 2)

  def test_dynamic_shape_no_retracing(self):
    device = torch_xla.device()
    # model setup
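The tests above detect recompilation through torch_xla's debug counters. A standalone sketch of the same check, with the model and shapes assumed for illustration, mirroring the decoder test's expectations:

import torch
import torch_xla
import torch_xla.debug.metrics as met

device = torch_xla.device()
model = torch.nn.Linear(16, 16).to(device)
compiled = torch.compile(model, backend="openxla", dynamic=False)

met.clear_all()
for batch_size in [1, 2, 3]:
  x = torch.randn(batch_size, 16).to(device)
  torch._dynamo.mark_dynamic(x, 0)
  compiled(x)
# One static graph extraction for the first shape, one dynamic re-extraction
# on the first shape change; later batch sizes reuse the dynamic graph.
print(met.counter_value('DynamoExtractCompiledGraph'))  # expected: 2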
