Dropout prob extremal patch #1804

Merged · 15 commits · Jul 7, 2022
43 changes: 27 additions & 16 deletions test/test_jit_cuda_fuser.py
@@ -82,6 +82,11 @@ def is_pre_volta():

TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported()

TEST_LARGE_TENSOR = RUN_NVFUSER
if RUN_NVFUSER:
torch.ones(1).cuda() # initialize cuda context
TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9

class CudaFuserTestOptions():
def __init__(self):
self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
@@ -183,23 +188,27 @@ def tearDown(self):
self.cuda_fuser_options.restore()
super(TestCudaFuser, self).tearDown()

def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1):
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1, check_runs=1):
seed = 123
torch.cuda.manual_seed_all(seed)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
o = op(*args)

if type(jit_o) is torch.Tensor:
jit_o = [jit_o, ]
o = [o, ]
for i in range(check_runs):
torch.cuda.manual_seed_all(seed + i)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(seed + i)
o = op(*args)

if type(jit_o) is torch.Tensor:
jit_o = [jit_o, ]
o = [o, ]

for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
if check_stride:
self.assertEqual(oo.stride(), jit_oo.stride())

for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
if check_stride:
self.assertEqual(oo.stride(), jit_oo.stride())
self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True)
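
Because the removed and added lines of `_run_helper` are interleaved in the diff above, here is a minimal, self-contained sketch of the pattern the new `check_runs` parameter introduces: each iteration reseeds the CUDA RNG identically before the scripted and the eager call, so ops that consume random numbers (such as dropout) can be compared element-wise over several runs. `op` and `jit_op` stand in for the helper's arguments; this is an illustration (tensor outputs only), not the test code itself.

```python
import torch

def compare_over_runs(jit_op, op, *args, check_runs=1, seed=123):
    # Reseeding both paths identically per run makes RNG-consuming ops reproducible.
    for i in range(check_runs):
        torch.cuda.manual_seed_all(seed + i)
        jit_o = jit_op(*args)
        torch.cuda.manual_seed_all(seed + i)
        o = op(*args)
        assert torch.allclose(o, jit_o)
```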

def _run_training_helper(self, jit_op, op, grads, *args):
@@ -2562,13 +2571,14 @@ def t(x: torch.Tensor, p: float, train: bool):

self._run_helper(t_jit, t, x, 0.15, False)

@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_train_nograd_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([10, 4, 8], dtype=dtype, device=device)
x = torch.randn([64, 128, 1024], dtype=dtype, device=device)

def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
@@ -2577,7 +2587,8 @@ def t(x: torch.Tensor, p: float, train: bool):

t_jit = torch.jit.script(t)

self._run_helper(t_jit, t, x, 0.0, True)
self._run_helper(t_jit, t, x, 0.0, True, check_runs=20)
self._run_helper(t_jit, t, x, 1.0, True, check_runs=20)
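
The test above now uses a much larger input (64 × 128 × 1024 elements, which is why the `TEST_LARGE_TENSOR` skip guard was added) and repeats the comparison 20 times, so a mishandled boundary value in the random mask is very likely to surface. A hedged sketch of the property being verified, assuming a CUDA device is available (`t` mirrors the test's function):

```python
import torch

def t(x: torch.Tensor, p: float, train: bool):
    o = torch.nn.functional.dropout(x, p, training=train)
    return o

x = torch.randn(64, 128, 1024, device="cuda")

# p == 0.0: dropout in training mode must be the identity; no element is dropped.
assert torch.equal(t(x, 0.0, True), x)

# p == 1.0: every element must be dropped; the output is all zeros.
assert torch.count_nonzero(t(x, 1.0, True)) == 0
```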

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
25 changes: 24 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.cpp
@@ -458,7 +458,6 @@ TensorView* unaryOp(
}

NVFUSER_DEFINE_UNARY_OP(set, Set)
NVFUSER_DEFINE_UNARY_OP(randlike, RandLike)
NVFUSER_DEFINE_UNARY_OP(ceil, Ceil)
NVFUSER_DEFINE_UNARY_OP(floor, Floor)
NVFUSER_DEFINE_UNARY_OP(frac, Frac)
@@ -469,6 +468,30 @@ NVFUSER_DEFINE_UNARY_OP(silu, Silu)
NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
#undef NVFUSER_DEFINE_UNARY_OP

Val* randlike(Val* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
"input must have floating point type, but got ",
v->dtype());
auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
return where(
eq(rand_vals, IrBuilder::create<Double>(1.0)),
IrBuilder::create<Double>(0.0),
rand_vals);
}

TensorView* randlike(TensorView* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
"input must have floating point type, but got ",
v->dtype());
auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
return where(
eq(rand_vals, IrBuilder::create<Double>(1.0)),
IrBuilder::create<Double>(0.0),
rand_vals);
}

Val* bitwise_not(Val* v) {
TORCH_CHECK(
isIntegralType(v->dtype()) || v->dtype() == DataType::Bool,
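
The new `randlike` overloads added above in arith.cpp replace the plain `NVFUSER_DEFINE_UNARY_OP(randlike, RandLike)` expansion: the raw random values are passed through `where(eq(rand, 1.0), 0.0, rand)`, presumably because the underlying uniform generator can return exactly 1.0. Folding that value back to 0.0 keeps the samples in [0, 1), so a dropout keep-mask of the form `rand < 1 - p` keeps everything at p == 0 and drops everything at p == 1. A rough Python analogue of the remap (`torch.rand` is already [0, 1); the 1.0 case is injected by hand for illustration):

```python
import torch

r = torch.rand(8)
r[0] = 1.0  # simulate a generator that can emit exactly 1.0

# Mirror of where(eq(rand, 1.0), 0.0, rand): fold 1.0 back into [0, 1).
r = torch.where(r == 1.0, torch.zeros_like(r), r)

assert bool((r < 1.0).all())      # keep-mask at p == 0 now keeps every element
assert not bool((r < 0.0).any())  # keep-mask at p == 1 still drops every element
```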
19 changes: 12 additions & 7 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -170,7 +170,15 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}

private:
explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {}
explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {
initStringStreamFormat(code_);
}

void initStringStreamFormat(std::stringstream& ss) {
const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
ss.imbue(std::locale("C"));
ss << std::scientific << std::setprecision(digits);
}

// Generates the kernel function declaration
void genDeclaration(const std::string& kernel_name) {
@@ -358,6 +366,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {

std::string gen(const Statement* stmt) {
std::stringstream tmp_code;
initStringStreamFormat(tmp_code);
std::swap(tmp_code, code_);
OptOutConstDispatch::handle(stmt);
std::swap(tmp_code, code_);
@@ -419,9 +428,7 @@
} else if (std::isnan(val)) {
code_ << "NAN";
} else {
const int digits =
std::numeric_limits<Double::ScalarType>::max_digits10;
code_ << std::setprecision(digits) << val;
code_ << val;
}
} else {
code_ << varName(d);
@@ -454,9 +461,7 @@
if (def != nullptr && !has_alloc) {
code_ << "(" << gen(def) << ")";
} else if (c->isConst()) {
const int digits = std::numeric_limits<double>::max_digits10;
code_ << "std::complex<double>" << std::setprecision(digits)
<< *c->value();
code_ << "std::complex<double>" << *c->value();
} else {
code_ << varName(c);
}
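
The codegen.cpp changes above configure the code-emitting `std::stringstream` once, in the constructor and in `gen()`: the classic "C" locale (so a host locale that uses a decimal comma cannot leak into the generated CUDA source), scientific notation (presumably so that a constant such as 2.0 is emitted as a floating-point literal rather than the integer `2`), and `max_digits10` precision (so every double literal round-trips to exactly the original value). This replaces the per-call-site `std::setprecision`, which applied only at those call sites. A small Python illustration of the round-trip issue; the format widths are analogues of the stream settings, not the generator's actual output:

```python
val = 0.1 + 0.2                  # 0.30000000000000004, not exactly 0.3

short_literal = f"{val:.6g}"     # like a default-precision stream: '0.3'
full_literal = f"{val:.17g}"     # ~max_digits10 for an IEEE-754 double

assert float(short_literal) != val  # the short literal denotes a different value
assert float(full_literal) == val   # the full-precision literal round-trips exactly
```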