Skip to content

Commit

Permalink
DYNAMO RNG seed update optimization (pytorch#7884)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored and yitongh committed Dec 11, 2024
1 parent a529a16 commit a39d28a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
6 changes: 6 additions & 0 deletions test/dynamo/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,18 @@ def random_op(self, a):
return torch.randn(5, 5, device=a.device) + a

def test_random_op_different_result_each_run(self):
  """A compiled random op must return a different result on each call.

  Runs `self.random_op` through `torch.compile` (openxla backend) three
  times with the same input and asserts the outputs all differ, i.e. the
  RNG seed is not baked into the traced graph. Also checks the transfer
  metrics: the seed bookkeeping in the dynamo bridge must not trigger a
  device-to-host transfer, while the seed update does incur a
  host-to-device transfer.
  """
  # Drain outstanding device ops and reset metrics so the transfer
  # counters asserted below reflect only what this test triggers.
  xm.wait_device_ops()
  met.clear_all()
  dynamo_random_op = torch.compile(
      self.random_op, backend="openxla", fullgraph=True)
  t = torch.randn(5, 5).to(xm.xla_device())
  dynamo_res_1 = dynamo_random_op(t)
  dynamo_res_2 = dynamo_random_op(t)
  dynamo_res_3 = dynamo_random_op(t)
  # Retrieving/updating the rng seed in the bridge should not cause a
  # transferFromDevice (the seed is read on the host side).
  self.assertNotIn("TransferFromDeviceTime", met.metric_names())
  # Updating the rng seed will result in a transferToDevice.
  self.assertIn("TransferToDeviceTime", met.metric_names())
  self.assertFalse(torch.allclose(dynamo_res_1, dynamo_res_2))
  self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))

Expand Down
4 changes: 3 additions & 1 deletion torch_xla/core/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ def __call__(self, args):
# update random seed here to avoid random operations always return
# the same result. The seed update logic is the same as `mark_step` in
# https://github.com/pytorch/pytorch/blob/6af6b8f728426fb7551630e28148c0017fa501bc/torch/csrc/lazy/core/lazy_graph_executor.cpp#L144C18-L144C51
# Note: don't do `inp.item()` here since it will trigger a transferFromDevice
xm.set_rng_state(
(1012031 + inp.item() * 7012063) % 18446744073709551615, str_device)
(1012031 + torch_xla._XLAC._xla_get_rng_seed() * 7012063) %
18446744073709551615, str_device)
elif arg_idx is None:
assert traced_xla_value is not None, "Traced Tensor cannot be None."
inp = traced_xla_value
Expand Down

0 comments on commit a39d28a

Please sign in to comment.