diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py
index 084a4bd0cbd2..1ca6c213843e 100644
--- a/test/test_jit_cuda_fuser.py
+++ b/test/test_jit_cuda_fuser.py
@@ -1973,6 +1973,57 @@ def t(x):
             x = x.to("cuda:1")
             jit_o = t_jit(x)
 
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_graph_rng(self):
+        self.assertTrue(torch._C._jit_nvfuser_enabled())
+        size = 10000
+        a = torch.randn((size,), device="cuda", dtype=torch.float)
+
+        def t(x):
+            o = x + 1.0
+            o = torch.nn.functional.dropout(o, p=0.1)
+            o = o + 1.0
+            o = torch.nn.functional.dropout(o, p=0.1)
+            return o
+
+        t_jit = torch.jit.script(t)
+
+        for _ in range(3):
+            t_jit(a)
+
+        self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1)
+
+        # Control (jitted, ungraphed)
+        torch.cuda.manual_seed(5)
+        eager_out = a.clone()
+        for _ in range(3):
+            eager_out = t_jit(eager_out)
+
+        graph_in = a.clone()
+        g = torch.cuda._Graph()
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            torch.cuda.manual_seed(5)
+            g.capture_begin()
+            graph_out = t_jit(graph_in)
+            g.capture_end()
+        torch.cuda.current_stream().wait_stream(s)
+        # g is now a jitted, graphed version of t.
+
+        # Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence.
+        # The ops in the overall sequence should be the same as Control.
+        g.replay()
+        # graph_out is now filled with g's result. Use it as ungraphed input.
+        out = t_jit(graph_out)
+        graph_in.copy_(out)
+        g.replay()
+
+        # If replay() updated RNG state correctly, graph_out should now equal eager_out
+        self.assertEqual(graph_out, eager_out)
+
 
 class TestPassManagerCudaFuser(JitTestCase):
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp
index 40974ecbf797..dcc0e2e55d8e 100644
--- a/torch/csrc/jit/codegen/cuda/codegen.cpp
+++ b/torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -93,7 +93,7 @@ class CudaKernelGenerator : private kir::IrVisitor {
 
     // Kernels generating random numbers take extra (seed, offset) arguments
     if (kernel_summary.is_stochastic) {
-      code_ << ", unsigned long long seed, unsigned long long offset";
+      code_ << ", at::PhiloxCudaState philox_args";
     }
 
     code_ << ") ";
@@ -106,7 +106,11 @@ class CudaKernelGenerator : private kir::IrVisitor {
     // Random number generator (optional)
     if (kernel_summary.is_stochastic) {
       indent() << "const int idx = blockIdx.x*blockDim.x + threadIdx.x;\n";
-      indent() << "Philox rnd(seed, idx, offset);\n";
+      indent() << "auto offset = philox_args.captured_ ?\n";
+      indent()
+          << "  static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
+      indent() << "  philox_args.offset_.val;\n";
+      indent() << "Philox rnd(philox_args.seed_, idx, offset);\n";
     }
 
     // Do we have any dynamic shared memory buffers?
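The codegen change above is the heart of the fix: instead of baking a (seed, offset) pair into the kernel arguments, the generated kernel now receives the whole Philox state and decides at run time whether the offset is a plain value (eager launch) or a pointer into the generator's state plus a per-kernel increment (graph capture). Below is a minimal, self-contained sketch of that branch; `PhiloxState` and `rng_kernel` are stand-ins invented for this sketch, not PyTorch's actual `at::PhiloxCudaState` definition.

```cuda
#include <cstdint>
#include <cstdio>

// Stand-in for at::PhiloxCudaState; field names mirror the diff above,
// but this struct is invented for the sketch, not copied from ATen.
struct PhiloxState {
  uint64_t seed_;
  union Payload {
    uint64_t val;   // eager launch: offset is baked in as a value
    uint64_t* ptr;  // graph capture: offset is read through a pointer
  } offset_;
  uint32_t offset_intragraph_;  // this kernel's slice of one replay's offset
  bool captured_;
};

__global__ void rng_kernel(PhiloxState philox_args) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // The same selection the generated kernel performs: dereferencing at
  // replay time is what lets each g.replay() see fresh RNG state.
  uint64_t offset = philox_args.captured_
      ? static_cast<uint64_t>(*(philox_args.offset_.ptr) +
                              philox_args.offset_intragraph_)
      : philox_args.offset_.val;
  if (idx == 0) {
    printf("seed=%llu offset=%llu\n",
           (unsigned long long)philox_args.seed_,
           (unsigned long long)offset);
  }
}

int main() {
  PhiloxState state{};  // zero-init, then fill the eager-mode fields
  state.seed_ = 5;
  state.offset_.val = 4;
  state.captured_ = false;
  rng_kernel<<<1, 32>>>(state);
  cudaDeviceSynchronize();
  return 0;
}
```

Passing the union by value is what makes one compiled kernel serve both modes: the host fills either `offset_.val` or `offset_.ptr` and the device picks at launch time.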
diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp
index b2dc41100751..b0ad6749c396 100644
--- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp
+++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp
@@ -78,8 +78,8 @@ void KernelArgumentHolder::push(const IValue& val) {
       " Tried to create argument to send to a fused kernel, but got a non-scalar type.");
 }
 
-void KernelArgumentHolder::push(const uint64_t& val) {
-  arguments_.push_back(std::make_unique<ULongArg>(val));
+void KernelArgumentHolder::push(const at::PhiloxCudaState& val) {
+  arguments_.push_back(std::make_unique<PhiloxCudaStateArg>(val));
 }
 
 // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
@@ -115,17 +115,16 @@ void KernelArgumentHolder::push(const std::vector<at::Tensor>& tensors) {
 }
 
 void KernelArgumentHolder::appendPhiloxRNGSeed(uint64_t rand_offset) {
-  std::pair<uint64_t, uint64_t> philox_engine_inputs;
+  at::PhiloxCudaState philox_engine_inputs;
   auto gen = at::cuda::detail::getDefaultCUDAGenerator();
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen.mutex());
     philox_engine_inputs =
-        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
+        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(
             rand_offset);
   }
-  push(philox_engine_inputs.first);
-  push(philox_engine_inputs.second);
+  push(philox_engine_inputs);
 }
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h
index fbecd9b7ec0b..7c43f950bb50 100644
--- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h
+++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <ATen/CUDAGeneratorImpl.h>
 #include
 #include
 #include
@@ -53,10 +54,9 @@ struct ArgAbstract {
   virtual void* arg() = 0;
 };
 
-// Explicitly for philox seed, not a supported type by any other mechanism
-struct ULongArg : public ArgAbstract {
-  uint64_t val_;
-  explicit ULongArg(uint64_t _val) : val_(_val){};
+struct PhiloxCudaStateArg : public ArgAbstract {
+  at::PhiloxCudaState val_;
+  PhiloxCudaStateArg(at::PhiloxCudaState _val) : val_(_val){};
   void* arg() {
     return &val_;
   }
@@ -155,7 +155,7 @@ class KernelArgumentHolder {
 
   // Push a scalar or integer to the arguments
   void push(const IValue& val);
-  void push(const uint64_t& val);
+  void push(const at::PhiloxCudaState& val);
 
   // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
   // in the buffer
diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp
index 843de0d11199..c6455266f7fe 100644
--- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp
@@ -12,6 +12,7 @@
 #include
 #include
+#include <nvfuser_resources/PhiloxCudaStateRaw.h>
 #include
 #include
 #include
@@ -43,6 +44,7 @@ std::string kernelPreamble() {
   ss << nvfuser_resources::grid_reduction_cu;
   ss << nvfuser_resources::broadcast_cu;
   ss << nvfuser_resources::welford_cu;
+  ss << nvfuser_resources::PhiloxCudaStateRaw_cu;
 
   return ss.str();
 }
diff --git a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
index 4a3964de6192..bbea2656ef9a 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
@@ -1,4 +1,3 @@
-
 class Philox {
  public:
   __device__ Philox(
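On the host side, the holder stops pushing two `uint64_t` scalars and pushes the whole state struct as a single by-value kernel argument; the seed/offset pair alone can no longer describe the captured case, where the offset lives behind a pointer. The sketch below condenses the `ArgAbstract`/`PhiloxCudaStateArg` pattern from the diff into a runnable toy; the vector-of-`void*` convention is an assumption modeled on `cuLaunchKernel`-style launches, not the holder's exact code.

```cuda
#include <cstdint>
#include <memory>
#include <vector>

// Condensed stand-in for at::PhiloxCudaState (same toy struct as the
// previous sketch).
struct PhiloxState {
  uint64_t seed_;
  union { uint64_t val; uint64_t* ptr; } offset_;
  uint32_t offset_intragraph_;
  bool captured_;
};

// The holder's pattern: each argument hides behind ArgAbstract, and arg()
// returns the address a driver-API launch consumes.
struct ArgAbstract {
  virtual ~ArgAbstract() = default;
  virtual void* arg() = 0;
};

struct PhiloxStateArg : public ArgAbstract {
  PhiloxState val_;
  explicit PhiloxStateArg(PhiloxState v) : val_(v) {}
  void* arg() override {
    return &val_;  // the whole struct travels to the kernel by value
  }
};

int main() {
  std::vector<std::unique_ptr<ArgAbstract>> arguments;

  PhiloxState state{};
  state.seed_ = 5;
  state.offset_.val = 0;
  state.captured_ = false;
  arguments.push_back(std::make_unique<PhiloxStateArg>(state));

  // One void* per argument is what the eventual kernel launch consumes.
  // Replacing two uint64_t pushes with one struct push keeps this list
  // identical whether or not the launch happens under graph capture.
  std::vector<void*> kernel_args;
  for (auto& a : arguments) {
    kernel_args.push_back(a->arg());
  }
  return 0;
}
```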
diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensor.cu b/torch/csrc/jit/codegen/cuda/runtime/tensor.cu
index e19e77e4f62f..06c352aa8669 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/tensor.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/tensor.cu
@@ -2,7 +2,9 @@
 typedef unsigned char uint8_t;
 typedef signed char int8_t;
 typedef short int int16_t;
+typedef unsigned int uint32_t;
 typedef long long int int64_t;
+typedef unsigned long long int uint64_t;
 
 template <typename T, int N>
 struct Tensor {
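The new typedefs exist because nvfuser kernels are compiled by NVRTC without standard headers, so the `uint32_t`/`uint64_t` field types that the Philox state struct now brings into the kernel signature must be spelled out by hand. For reference, here is roughly the capture/replay flow that `torch.cuda._Graph` drives in the test above, written as a hedged sketch against the raw CUDA runtime graph API (CUDA 11-era `cudaGraphInstantiate` signature); the kernel is a placeholder and error handling is omitted.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

__global__ void step(float* x) {
  x[threadIdx.x] += 1.0f;  // placeholder for the fused dropout kernel
}

int main() {
  float* buf;
  cudaMalloc(&buf, 32 * sizeof(float));
  cudaMemset(buf, 0, 32 * sizeof(float));

  // Capture happens on a non-default stream, which is why the test above
  // creates torch.cuda.Stream() and synchronizes around it.
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  step<<<1, 32, 0, stream>>>(buf);  // recorded into the graph, not executed
  cudaStreamEndCapture(stream, &graph);

  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

  // Each launch replays the recorded work. For RNG kernels, the captured
  // offset pointer in PhiloxCudaState is what makes consecutive replays
  // draw fresh random numbers instead of repeating the captured ones.
  cudaGraphLaunch(exec, stream);
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);

  float host[32];
  cudaMemcpy(host, buf, sizeof(host), cudaMemcpyDeviceToHost);
  printf("buf[0] after two replays: %f\n", host[0]);  // expect 2.0

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(buf);
  return 0;
}
```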