@@ -1973,6 +1973,57 @@ def t(x):
1973
1973
x = x .to ("cuda:1" )
1974
1974
jit_o = t_jit (x )
1975
1975
1976
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(
    torch.version.cuda is None or int(torch.version.cuda.split(".")[0]) < 11,
    "CUDA >= 11.0 required for graphs")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                 "Requires fusion optimization pass to be effective")
def test_graph_rng(self):
    """Check that CUDA graph replay keeps RNG state in step with ungraphed runs.

    A scripted function containing two dropout calls is run three times in a
    row without graphs (the control). The same three-step chain is then run
    as (graphed replay) -> (ungraphed call) -> (graphed replay), starting
    from the same seed and input. The final outputs must match, which only
    happens if each replay consumes/advances the RNG state the same way an
    ungraphed call does.

    Skipped when CUDA is unavailable, when the CUDA toolkit is older than
    11.0 (graph capture is not supported there), or when the profiling
    executor is not active.
    """
    self.assertTrue(torch._C._jit_nvfuser_enabled())
    size = 10000
    a = torch.randn((size,), device="cuda", dtype=torch.float)

    def t(x):
        o = x + 1.0
        o = torch.nn.functional.dropout(o, p=0.1)
        o = o + 1.0
        o = torch.nn.functional.dropout(o, p=0.1)
        return o

    t_jit = torch.jit.script(t)

    # Warm-up runs before inspecting the graph; presumably these let the
    # profiling executor specialize and fuse the function.
    for _ in range(3):
        t_jit(a)

    self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1)

    # Control (jitted, ungraphed)
    torch.cuda.manual_seed(5)
    eager_out = a.clone()
    for _ in range(3):
        eager_out = t_jit(eager_out)

    graph_in = a.clone()
    g = torch.cuda._Graph()
    # Capture happens on a side stream: it first waits for work queued on
    # the current stream, and the current stream waits for it afterwards.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        # Same seed as the control so both sequences start from the same
        # RNG state.
        torch.cuda.manual_seed(5)
        g.capture_begin()
        graph_out = t_jit(graph_in)
        g.capture_end()
    torch.cuda.current_stream().wait_stream(s)
    # g is now a jitted, graphed version of t.

    # Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence.
    # The ops in the overall sequence should be the same as Control.
    g.replay()
    # graph_out is now filled with g's result. Use it as ungraphed input.
    out = t_jit(graph_out)
    # Feed the ungraphed result back through the graph's fixed input buffer.
    graph_in.copy_(out)
    g.replay()

    # If replay() updated RNG state correctly, graph_out should now equal eager_out
    self.assertEqual(graph_out, eager_out)
2026
+
1976
2027
class TestPassManagerCudaFuser (JitTestCase ):
1977
2028
1978
2029
@unittest .skipIf (not RUN_CUDA , "requires CUDA" )
0 commit comments