Merged
31 commits
d892f7a  nvfuser diffs (Jan 7, 2021)
823ad0e  Copying in Philox utility definitions (Jan 7, 2021)
dbbd9a8  give PhiloxCudaState and unpack their own files for nvfuser codegen (Jan 8, 2021)
6fd3407  try with cuh (Jan 8, 2021)
7db88ae  forgot executor_utils.cpp (Jan 8, 2021)
3e58cd0  CMakeLists.txt (Jan 8, 2021)
82cf408  cargo culting other resource files (Jan 8, 2021)
2dbbf60  presumably in-kernel usages should look at local namespace? (Jan 8, 2021)
778a866  CMAKE_CURRENT_SOURCE_DIR (Jan 8, 2021)
9bae448  fix path (Jan 8, 2021)
a297eb1  namespaces in raw headers (Jan 8, 2021)
76f5809  final change (Jan 9, 2021)
8477146  Struct for seed and offset. Can't put the logic in getters in Philox… (Jan 11, 2021)
fd0d07e  Comment why no pragma once in raw headers (Jan 11, 2021)
31f1c78  Remove CUDAGeneratorImpl.h from codegen.cpp and remove ULongArg (Jan 11, 2021)
8584a92  forgot to remove push ULongArg (Jan 12, 2021)
efb56b5  Resolving conflict (Jan 27, 2021)
48d61d8  should have accepted TORCH_CUDA_CPP_API for conflicting diff (Jan 27, 2021)
ed813dc  Merge remote-tracking branch 'csarofeen/20_12_3_devel' into graphable… (Jan 28, 2021)
18ad987  Removing eager mode changes (Feb 2, 2021)
170c26e  Merge remote-tracking branch 'csarofeen/20_12_3_devel' into graphable… (Feb 2, 2021)
5c3e4e4  Return to original in-kernel api (Mar 3, 2021)
8699704  Merge remote-tracking branch 'csarofeen/20_12_3_devel' into graphable… (Mar 17, 2021)
daee448  Test passes, but i don't see fusions in profile (Mar 18, 2021)
e7345d7  Warmup calls and remove for loop in script?? (Mar 18, 2021)
280b694  Manually unpack philox_args to avoid std::stuff in at::cuda::philox::… (Mar 18, 2021)
fac66c6  Full test passes! (Mar 19, 2021)
1333182  Clean up test (Mar 19, 2021)
fe4f221  remove nvtx (Mar 19, 2021)
e881d65  Merge remote-tracking branch 'csarofeen/20_12_3_devel' into HEAD (jjsjann123, Mar 25, 2021)
a934e19  clang-format (jjsjann123, Mar 25, 2021)
51 changes: 51 additions & 0 deletions test/test_jit_cuda_fuser.py
@@ -1973,6 +1973,57 @@ def t(x):
        x = x.to("cuda:1")
        jit_o = t_jit(x)

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_graph_rng(self):
        self.assertTrue(torch._C._jit_nvfuser_enabled())
        size = 10000
        a = torch.randn((size,), device="cuda", dtype=torch.float)

        def t(x):
            o = x + 1.0
            o = torch.nn.functional.dropout(o, p=0.1)
            o = o + 1.0
            o = torch.nn.functional.dropout(o, p=0.1)
            return o

        t_jit = torch.jit.script(t)

        for _ in range(3):
            t_jit(a)

        self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1)

        # Control (jitted, ungraphed)
        torch.cuda.manual_seed(5)
        eager_out = a.clone()
        for _ in range(3):
            eager_out = t_jit(eager_out)

        graph_in = a.clone()
        g = torch.cuda._Graph()
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            torch.cuda.manual_seed(5)
            g.capture_begin()
            graph_out = t_jit(graph_in)
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)
        # g is now a jitted, graphed version of t.

        # Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence.
        # The ops in the overall sequence should be the same as Control.
        g.replay()
        # graph_out is now filled with g's result. Use it as ungraphed input.
        out = t_jit(graph_out)
        graph_in.copy_(out)
        g.replay()

        # If replay() updated RNG state correctly, graph_out should now equal eager_out
        self.assertEqual(graph_out, eager_out)

class TestPassManagerCudaFuser(JitTestCase):

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
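The interleaving in this test is the point: each g.replay() must advance the default CUDA generator's philox offset exactly as an ordinary launch would, otherwise the ungraphed call sandwiched between the two replays would draw from the wrong point in the random stream and graph_out could not match the all-ungraphed control.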
8 changes: 6 additions & 2 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -93,7 +93,7 @@ class CudaKernelGenerator : private kir::IrVisitor {

  // Kernels generating random numbers take extra (seed, offset) arguments
  if (kernel_summary.is_stochastic) {
-    code_ << ", unsigned long long seed, unsigned long long offset";
+    code_ << ", at::PhiloxCudaState philox_args";
  }

  code_ << ") ";
@@ -106,7 +106,11 @@
  // Random number generator (optional)
  if (kernel_summary.is_stochastic) {
    indent() << "const int idx = blockIdx.x*blockDim.x + threadIdx.x;\n";
-    indent() << "Philox rnd(seed, idx, offset);\n";
+    indent() << "auto offset = philox_args.captured_ ?\n";
+    indent()
+        << " static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
+    indent() << " philox_args.offset_.val;\n";
+    indent() << "Philox rnd(philox_args.seed_, idx, offset);\n";
  }

  // Do we have any dynamic shared memory buffers?
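Put together, the generated kernel for a stochastic fusion now takes the whole Philox state by value and resolves the offset at run time, so a captured launch re-reads the generator's offset on every graph replay. A sketch of the emitted preamble after this change (the kernel name and tensor parameters are placeholders; the body is assembled from the strings emitted above):

__global__ void kernel1(Tensor<float, 1> T0, Tensor<float, 1> T1, at::PhiloxCudaState philox_args) {
  const int idx = blockIdx.x*blockDim.x + threadIdx.x;
  // Captured launches read the offset from device memory (plus this launch's
  // intragraph increment); ordinary launches use the immediate value.
  auto offset = philox_args.captured_ ?
      static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :
      philox_args.offset_.val;
  Philox rnd(philox_args.seed_, idx, offset);
  // ... generated arithmetic that consumes rnd ...
}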
11 changes: 5 additions & 6 deletions torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp
@@ -78,8 +78,8 @@ void KernelArgumentHolder::push(const IValue& val) {
      " Tried to create argument to send to a fused kernel, but got a non-scalar type.");
}

-void KernelArgumentHolder::push(const uint64_t& val) {
-  arguments_.push_back(std::make_unique<ULongArg>(val));
+void KernelArgumentHolder::push(const at::PhiloxCudaState& val) {
+  arguments_.push_back(std::make_unique<PhiloxCudaStateArg>(val));
}

// Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
Expand Down Expand Up @@ -115,17 +115,16 @@ void KernelArgumentHolder::push(const std::vector<at::Tensor>& tensors) {
}

void KernelArgumentHolder::appendPhiloxRNGSeed(uint64_t rand_offset) {
std::pair<uint64_t, uint64_t> philox_engine_inputs;
at::PhiloxCudaState philox_engine_inputs;
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
philox_engine_inputs =
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(
rand_offset);
}
push(philox_engine_inputs.first);
push(philox_engine_inputs.second);
push(philox_engine_inputs);
}

} // namespace cuda
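Host side, the holder now asks the generator for a single at::PhiloxCudaState and pushes it as one argument instead of two unsigned longs. A hypothetical call site using only the methods visible in this diff (inputs, kernel_summary, and rand_offset are illustrative names):

KernelArgumentHolder args;
for (const c10::IValue& input : inputs) {
  args.push(input); // tensors and scalars go in as IValues
}
if (kernel_summary.is_stochastic) {
  // Reserves rand_offset philox counters on the default CUDA generator and pushes
  // one at::PhiloxCudaState, matching the philox_args parameter generated above.
  args.appendPhiloxRNGSeed(rand_offset);
}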
10 changes: 5 additions & 5 deletions torch/csrc/jit/codegen/cuda/executor_kernel_arg.h
@@ -1,5 +1,6 @@
#pragma once

+#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
@@ -53,10 +54,9 @@ struct ArgAbstract {
  virtual void* arg() = 0;
};

-// Explicitly for philox seed, not a supported type by any other mechanism
-struct ULongArg : public ArgAbstract {
-  uint64_t val_;
-  explicit ULongArg(uint64_t _val) : val_(_val){};
+struct PhiloxCudaStateArg : public ArgAbstract {
+  at::PhiloxCudaState val_;
+  PhiloxCudaStateArg(at::PhiloxCudaState _val) : val_(_val){};
  void* arg() {
    return &val_;
  }
@@ -155,7 +155,7 @@ class KernelArgumentHolder {
  // Push a scalar or integer to the arguments
  void push(const IValue& val);

-  void push(const uint64_t& val);
+  void push(const at::PhiloxCudaState& val);

  // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
  // in the buffer
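Since arg() hands back a pointer to the stored struct, the launch machinery copies the entire at::PhiloxCudaState (seed, offset payload, intragraph offset, and the captured flag) into the kernel's philox_args parameter by value, which is what the signature generated in codegen.cpp expects.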
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/executor_utils.cpp
@@ -12,6 +12,7 @@
#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
#include <torch/csrc/jit/resource_guard.h>

+#include <nvfuser_resources/PhiloxCudaStateRaw.h>
#include <nvfuser_resources/block_reduction.h>
#include <nvfuser_resources/broadcast.h>
#include <nvfuser_resources/fp16_support.h>
@@ -43,6 +44,7 @@ std::string kernelPreamble() {
  ss << nvfuser_resources::grid_reduction_cu;
  ss << nvfuser_resources::broadcast_cu;
  ss << nvfuser_resources::welford_cu;
+  ss << nvfuser_resources::PhiloxCudaStateRaw_cu;

  return ss.str();
}
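The new resource makes at::PhiloxCudaState visible inside the runtime-compiled kernel preamble without pulling in ATen headers. The raw header's contents are not shown in this diff; the following is only a sketch of the layout implied by the fields the generated code reads (the real PhiloxCudaStateRaw.h may differ):

namespace at {

struct PhiloxCudaState {
  uint64_t seed_;              // Philox key
  union Payload {
    uint64_t val;              // immediate offset for ordinary launches
    int64_t* ptr;              // device pointer consulted when captured_ is true
  } offset_;
  uint32_t offset_intragraph_; // this launch's offset within a captured graph
  bool captured_;              // set when the state was produced during graph capture
};

} // namespace at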
1 change: 0 additions & 1 deletion torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
@@ -1,4 +1,3 @@

Review thread on the removed line:
Collaborator: why was this line removed?
Author: no reason, it was random whitespace at the top of the file.
class Philox {
public:
__device__ Philox(
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/tensor.cu
@@ -2,7 +2,9 @@
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef short int int16_t;
+typedef unsigned int uint32_t;
typedef long long int int64_t;
+typedef unsigned long long int uint64_t;

template <typename T, int N>
struct Tensor {
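These typedefs are needed because the runtime .cu sources are compiled from raw strings without standard headers; presumably the Philox state handling added above is the first code in the preamble to use uint32_t and uint64_t, so the fixed-width names have to be spelled out by hand next to the ones already defined here.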