Skip to content

Commit

Permalink
Dropout prob extremal patch (#1804)
Browse files Browse the repository at this point in the history
Fixes #1799

1. Updates rand_like by changing output==1 to 0 via `where`;
2. Patches codegen float output.
  • Loading branch information
jjsjann123 committed Jul 7, 2022
1 parent 282c429 commit 037a75a
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 24 deletions.
43 changes: 27 additions & 16 deletions test/test_jit_cuda_fuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def is_pre_volta():

TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported()

TEST_LARGE_TENSOR = RUN_NVFUSER
if RUN_NVFUSER:
torch.ones(1).cuda() # initialize cuda context
TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9

class CudaFuserTestOptions():
def __init__(self):
self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
Expand Down Expand Up @@ -183,23 +188,27 @@ def tearDown(self):
self.cuda_fuser_options.restore()
super(TestCudaFuser, self).tearDown()

def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1):
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1, check_runs=1):
seed = 123
torch.cuda.manual_seed_all(seed)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
o = op(*args)

if type(jit_o) is torch.Tensor:
jit_o = [jit_o, ]
o = [o, ]
for i in range(check_runs):
torch.cuda.manual_seed_all(seed + i)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(seed + i)
o = op(*args)

if type(jit_o) is torch.Tensor:
jit_o = [jit_o, ]
o = [o, ]

for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
if check_stride:
self.assertEqual(oo.stride(), jit_oo.stride())

for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
if check_stride:
self.assertEqual(oo.stride(), jit_oo.stride())
self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True)

def _run_training_helper(self, jit_op, op, grads, *args):
Expand Down Expand Up @@ -2562,13 +2571,14 @@ def t(x: torch.Tensor, p: float, train: bool):

self._run_helper(t_jit, t, x, 0.15, False)

@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_train_nograd_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([10, 4, 8], dtype=dtype, device=device)
x = torch.randn([64, 128, 1024], dtype=dtype, device=device)

def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
Expand All @@ -2577,7 +2587,8 @@ def t(x: torch.Tensor, p: float, train: bool):

t_jit = torch.jit.script(t)

self._run_helper(t_jit, t, x, 0.0, True)
self._run_helper(t_jit, t, x, 0.0, True, check_runs=20)
self._run_helper(t_jit, t, x, 1.0, True, check_runs=20)

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
Expand Down
25 changes: 24 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ TensorView* unaryOp(
}

NVFUSER_DEFINE_UNARY_OP(set, Set)
NVFUSER_DEFINE_UNARY_OP(randlike, RandLike)
NVFUSER_DEFINE_UNARY_OP(ceil, Ceil)
NVFUSER_DEFINE_UNARY_OP(floor, Floor)
NVFUSER_DEFINE_UNARY_OP(frac, Frac)
Expand All @@ -469,6 +468,30 @@ NVFUSER_DEFINE_UNARY_OP(silu, Silu)
NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
#undef NVFUSER_DEFINE_UNARY_OP

Val* randlike(Val* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
"input must have floating point type, but got ",
v->dtype());
auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
return where(
eq(rand_vals, IrBuilder::create<Double>(1.0)),
IrBuilder::create<Double>(0.0),
rand_vals);
}

TensorView* randlike(TensorView* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
"input must have floating point type, but got ",
v->dtype());
auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
return where(
eq(rand_vals, IrBuilder::create<Double>(1.0)),
IrBuilder::create<Double>(0.0),
rand_vals);
}

Val* bitwise_not(Val* v) {
TORCH_CHECK(
isIntegralType(v->dtype()) || v->dtype() == DataType::Bool,
Expand Down
19 changes: 12 additions & 7 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,15 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}

private:
explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {}
explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {
initStringStreamFormat(code_);
}

void initStringStreamFormat(std::stringstream& ss) {
const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
ss.imbue(std::locale("C"));
ss << std::scientific << std::setprecision(digits);
}

// Generates the kernel function declaration
void genDeclaration(const std::string& kernel_name) {
Expand Down Expand Up @@ -358,6 +366,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {

std::string gen(const Statement* stmt) {
std::stringstream tmp_code;
initStringStreamFormat(tmp_code);
std::swap(tmp_code, code_);
OptOutConstDispatch::handle(stmt);
std::swap(tmp_code, code_);
Expand Down Expand Up @@ -419,9 +428,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
} else if (std::isnan(val)) {
code_ << "NAN";
} else {
const int digits =
std::numeric_limits<Double::ScalarType>::max_digits10;
code_ << std::setprecision(digits) << val;
code_ << val;
}
} else {
code_ << varName(d);
Expand Down Expand Up @@ -454,9 +461,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
if (def != nullptr && !has_alloc) {
code_ << "(" << gen(def) << ")";
} else if (c->isConst()) {
const int digits = std::numeric_limits<double>::max_digits10;
code_ << "std::complex<double>" << std::setprecision(digits)
<< *c->value();
code_ << "std::complex<double>" << *c->value();
} else {
code_ << varName(c);
}
Expand Down

0 comments on commit 037a75a

Please sign in to comment.