
Commit 14bd01e

Author: mcarilli
[CUDA graphs] [JIT] Capture-safe RNG in nvfuser (#593)
Eager-mode RNG kernels needed some minor changes to interact safely with CUDA graphs. This PR extends those changes to the kernels generated by nvfuser.
1 parent 146c1a4 commit 14bd01e

File tree

7 files changed, +71 -14 lines changed

test/test_jit_cuda_fuser.py

Lines changed: 51 additions & 0 deletions
@@ -1973,6 +1973,57 @@ def t(x):
         x = x.to("cuda:1")
         jit_o = t_jit(x)
 
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_graph_rng(self):
+        self.assertTrue(torch._C._jit_nvfuser_enabled())
+        size = 10000
+        a = torch.randn((size,), device="cuda", dtype=torch.float)
+
+        def t(x):
+            o = x + 1.0
+            o = torch.nn.functional.dropout(o, p=0.1)
+            o = o + 1.0
+            o = torch.nn.functional.dropout(o, p=0.1)
+            return o
+
+        t_jit = torch.jit.script(t)
+
+        for _ in range(3):
+            t_jit(a)
+
+        self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1)
+
+        # Control (jitted, ungraphed)
+        torch.cuda.manual_seed(5)
+        eager_out = a.clone()
+        for _ in range(3):
+            eager_out = t_jit(eager_out)
+
+        graph_in = a.clone()
+        g = torch.cuda._Graph()
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            torch.cuda.manual_seed(5)
+            g.capture_begin()
+            graph_out = t_jit(graph_in)
+            g.capture_end()
+        torch.cuda.current_stream().wait_stream(s)
+        # g is now a jitted, graphed version of t.
+
+        # Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence.
+        # The ops in the overall sequence should be the same as Control.
+        g.replay()
+        # graph_out is now filled with g's result. Use it as ungraphed input.
+        out = t_jit(graph_out)
+        graph_in.copy_(out)
+        g.replay()
+
+        # If replay() updated RNG state correctly, graph_out should now equal eager_out
+        self.assertEqual(graph_out, eager_out)
+
 class TestPassManagerCudaFuser(JitTestCase):
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 6 additions & 2 deletions
@@ -93,7 +93,7 @@ class CudaKernelGenerator : private kir::IrVisitor {
 
   // Kernels generating random numbers take extra (seed, offset) arguments
   if (kernel_summary.is_stochastic) {
-    code_ << ", unsigned long long seed, unsigned long long offset";
+    code_ << ", at::PhiloxCudaState philox_args";
   }
 
   code_ << ") ";
@@ -106,7 +106,11 @@ class CudaKernelGenerator : private kir::IrVisitor {
   // Random number generator (optional)
   if (kernel_summary.is_stochastic) {
     indent() << "const int idx = blockIdx.x*blockDim.x + threadIdx.x;\n";
-    indent() << "Philox rnd(seed, idx, offset);\n";
+    indent() << "auto offset = philox_args.captured_ ?\n";
+    indent()
+        << " static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
+    indent() << " philox_args.offset_.val;\n";
+    indent() << "Philox rnd(philox_args.seed_, idx, offset);\n";
   }
 
   // Do we have any dynamic shared memory buffers?
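Reassembled from the emitted strings above, the preamble of a stochastic fused kernel now looks roughly like the following (the kernel name and tensor parameters are placeholders; exact spacing is whatever the generator's indent() produces):

    __global__ void fused_kernel(/* tensor arguments... */, at::PhiloxCudaState philox_args) {
      const int idx = blockIdx.x*blockDim.x + threadIdx.x;
      auto offset = philox_args.captured_ ?
          static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :
          philox_args.offset_.val;
      Philox rnd(philox_args.seed_, idx, offset);
      // ...generated body draws its random numbers from rnd...
    }

In eager mode captured_ is false and the offset is the plain value carried in the argument; under graph capture the kernel instead reads the generator's current offset from device memory on each replay and adds its per-launch intragraph offset, so replays do not reuse the random numbers that happened to be current at capture time.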

torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp

Lines changed: 5 additions & 6 deletions
@@ -78,8 +78,8 @@ void KernelArgumentHolder::push(const IValue& val) {
       " Tried to create argument to send to a fused kernel, but got a non-scalar type.");
 }
 
-void KernelArgumentHolder::push(const uint64_t& val) {
-  arguments_.push_back(std::make_unique<ULongArg>(val));
+void KernelArgumentHolder::push(const at::PhiloxCudaState& val) {
+  arguments_.push_back(std::make_unique<PhiloxCudaStateArg>(val));
 }
 
 // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
@@ -115,17 +115,16 @@ void KernelArgumentHolder::push(const std::vector<at::Tensor>& tensors) {
 }
 
 void KernelArgumentHolder::appendPhiloxRNGSeed(uint64_t rand_offset) {
-  std::pair<uint64_t, uint64_t> philox_engine_inputs;
+  at::PhiloxCudaState philox_engine_inputs;
   auto gen = at::cuda::detail::getDefaultCUDAGenerator();
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen.mutex());
    philox_engine_inputs =
-        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
+        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(
            rand_offset);
   }
-  push(philox_engine_inputs.first);
-  push(philox_engine_inputs.second);
+  push(philox_engine_inputs);
 }
 
 } // namespace cuda
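On the host side, the argument holder now asks the default CUDA generator for an at::PhiloxCudaState instead of a concrete (seed, offset) pair and forwards it to the kernel as a single struct argument. A minimal sketch of the pattern, assuming a PyTorch tree of this vintage (the free function acquirePhiloxArgs is illustrative; the ATen calls are the ones used in the diff above):

    #include <ATen/CUDAGeneratorImpl.h>

    #include <cstdint>
    #include <mutex>

    at::PhiloxCudaState acquirePhiloxArgs(uint64_t rand_offset) {
      auto gen = at::cuda::detail::getDefaultCUDAGenerator();
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      // philox_cuda_state() reserves rand_offset counter values just as
      // philox_engine_inputs() did, but while a CUDA graph capture is underway it
      // records a pointer to the generator's offset (plus an intragraph offset)
      // rather than a frozen value, which is what makes the argument capture-safe.
      return at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(
          rand_offset);
    }

Because PhiloxCudaStateArg::arg() returns a pointer to the stored struct, the whole state travels to the kernel by value through the usual flattened launch-argument buffer.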

torch/csrc/jit/codegen/cuda/executor_kernel_arg.h

Lines changed: 5 additions & 5 deletions
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <ATen/CUDAGeneratorImpl.h>
 #include <ATen/core/ivalue.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/ir/ir.h>
@@ -53,10 +54,9 @@ struct ArgAbstract {
   virtual void* arg() = 0;
 };
 
-// Explicitly for philox seed, not a supported type by any other mechanism
-struct ULongArg : public ArgAbstract {
-  uint64_t val_;
-  explicit ULongArg(uint64_t _val) : val_(_val){};
+struct PhiloxCudaStateArg : public ArgAbstract {
+  at::PhiloxCudaState val_;
+  PhiloxCudaStateArg(at::PhiloxCudaState _val) : val_(_val){};
   void* arg() {
     return &val_;
   }
@@ -155,7 +155,7 @@ class KernelArgumentHolder {
   // Push a scalar or integer to the arguments
   void push(const IValue& val);
 
-  void push(const uint64_t& val);
+  void push(const at::PhiloxCudaState& val);
 
   // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
   // in the buffer

torch/csrc/jit/codegen/cuda/executor_utils.cpp

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
 #include <torch/csrc/jit/resource_guard.h>
 
+#include <nvfuser_resources/PhiloxCudaStateRaw.h>
 #include <nvfuser_resources/block_reduction.h>
 #include <nvfuser_resources/broadcast.h>
 #include <nvfuser_resources/fp16_support.h>
@@ -43,6 +44,7 @@ std::string kernelPreamble() {
   ss << nvfuser_resources::grid_reduction_cu;
   ss << nvfuser_resources::broadcast_cu;
   ss << nvfuser_resources::welford_cu;
+  ss << nvfuser_resources::PhiloxCudaStateRaw_cu;
 
   return ss.str();
 }

torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-
 class Philox {
  public:
   __device__ Philox(

torch/csrc/jit/codegen/cuda/runtime/tensor.cu

Lines changed: 2 additions & 0 deletions
@@ -2,7 +2,9 @@
 typedef unsigned char uint8_t;
 typedef signed char int8_t;
 typedef short int int16_t;
+typedef unsigned int uint32_t;
 typedef long long int int64_t;
+typedef unsigned long long int uint64_t;
 
 template <typename T, int N>
 struct Tensor {
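The fused kernels are compiled at runtime with NVRTC and cannot pull in the regular ATen headers, which is why the kernel preamble now inlines a standalone PhiloxCudaStateRaw resource (executor_utils.cpp above) and why tensor.cu grows uint32_t/uint64_t typedefs for it to use. The raw header itself is not part of this diff; a hypothetical free-standing definition, consistent with the fields the generated code dereferences (seed_, offset_.val, offset_.ptr, offset_intragraph_, captured_), might look like:

    // Hypothetical sketch only -- the real PhiloxCudaStateRaw.h is not shown in this commit.
    namespace at {

    struct PhiloxCudaState {
      uint64_t seed_;
      union Payload {
        uint64_t val; // eager mode: the reserved offset as a plain value
        int64_t* ptr; // graph capture: the generator's offset, read at replay time
      } offset_;
      uint32_t offset_intragraph_; // this launch's offset within the captured graph
      bool captured_;
    };

    } // namespace at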
