Type inference patch (#1848)
Fixes a type inference issue where the updated type was written to a temporary optional holder rather than recorded in the graph; a standalone sketch of this failure mode follows the type_inference.cpp hunk below.
jjsjann123 committed Jul 20, 2022
1 parent 102fe93 commit db89c65
Showing 2 changed files with 24 additions and 2 deletions.
22 changes: 22 additions & 0 deletions test/test_jit_cuda_fuser.py
@@ -4937,6 +4937,28 @@ def t2(x0, x1, x2):
         t2_jit = torch.jit.script(t2)
         self._run_helper(t2_jit, t2, x0, x1, x2, check_stride=True)
 
+    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_type_inference(self):
+        device = "cuda"
+        x0 = torch.randn(10, 128, device=device)
+        x1 = torch.rand_like(x0)
+        x2 = torch.rand_like(x0)
+
+        def t(x0, x1, x2, flag : bool = True):
+            x3 = 2.0 * x0
+            x4 = 2.0 * x1
+            x5 = 2.0 * x2
+            if flag:
+                return torch.stack([x3, x4, x5], dim=-1)
+            # second code path doesn't run through profiling,
+            # hence it relies on type inference from profiling information
+            return x0 + x1 + x2
+
+        t_jit = torch.jit.script(t)
+        self._run_helper(t_jit, t, x0, x1, x2, check_stride=True)
+
 
 class TestEnableDisableCudaFuser(JitTestCase):
     def setUp(self):
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/type_inference.cpp
@@ -40,8 +40,8 @@ void copyScalarTypeAndDeviceToOutput(
       out != nullptr,
       "Expect target node's type pointer to be non-nullptr, but get nullptr");
   if (!hasTypeAndDevice(out)) {
-    out->scalarType() = dtype;
-    out->device() = device;
+    node->output(index)->setType(
+        TensorType::create(dtype, device, c10::nullopt, c10::nullopt));
   }
 }
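
Why the removed lines were a silent no-op: accessors like TensorType::scalarType() hand back their value wrapped in an optional, and if that optional is returned by value (which the commit message's "temporary optional holder" suggests), assigning through the accessor only mutates a temporary. The minimal standalone sketch below models that pattern with std::optional; FakeTensorType and its members are hypothetical stand-ins, not the real JIT API. The actual fix above instead calls setType() on the node's output so the inferred type is persisted in the graph.

// Standalone C++17 sketch of the bug pattern (hypothetical types; only
// models an accessor that returns an optional *by value*).
#include <cassert>
#include <optional>

struct FakeTensorType {
  // Returns a copy of the stored optional; assigning through it
  // mutates the temporary, not the object.
  std::optional<int> scalarType() const {
    return scalar_type_;
  }

  // The fixed code path goes through an explicit setter instead,
  // analogous to calling setType() on the node's output.
  void setScalarType(int t) {
    scalar_type_ = t;
  }

  std::optional<int> scalar_type_;
};

int main() {
  FakeTensorType out;

  // Writes into the temporary returned by scalarType(): the object
  // itself is untouched -- the failure mode this commit fixes.
  out.scalarType() = 42;
  assert(!out.scalarType().has_value());

  // Writing through the setter actually persists the update.
  out.setScalarType(42);
  assert(out.scalarType() == std::optional<int>(42));
  return 0;
}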

