Skip to content

Commit

Permalink
Merge remote-tracking branch 'csarofeen/devel' into HEAD
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 committed Aug 6, 2022
2 parents 5b51849 + 1617373 commit dfe02f3
Show file tree
Hide file tree
Showing 111 changed files with 6,153 additions and 4,273 deletions.
36 changes: 18 additions & 18 deletions benchmarks/cpp/nvfuser/bert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ static void MagicScheduler_DivMaxSoftDropFwd(
std::vector<at::Tensor> cg_outputs;

auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, *norm_params);

FusionExecutor fe;
fe.compileFusion(&fusion);
Expand All @@ -143,7 +143,7 @@ static void MagicScheduler_DivMaxSoftDropFwd(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion({t0, t1}, norm_params.value().lparams);
cg_outputs = fe.runFusion({t0, t1}, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
Expand Down Expand Up @@ -193,8 +193,8 @@ static void MagicScheduler_DivMaxSoftDropBwd(
std::vector<at::Tensor> cg_outputs;

auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, *norm_params);

FusionExecutor fe;
fe.compileFusion(&fusion);
Expand All @@ -203,7 +203,7 @@ static void MagicScheduler_DivMaxSoftDropBwd(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params.value().lparams);
cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
Expand Down Expand Up @@ -308,8 +308,8 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd(
std::vector<at::Tensor> cg_outputs;

auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, *norm_params);

FusionExecutor fe;
fe.compileFusion(&fusion);
Expand All @@ -319,7 +319,7 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
Expand Down Expand Up @@ -423,8 +423,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1(
std::vector<at::Tensor> cg_outputs;

auto norm_params = getReductionHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
scheduleReduction(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
scheduleReduction(&fusion, *norm_params);

FusionExecutor fe;
fe.compileFusion(&fusion);
Expand All @@ -434,7 +434,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
clearL2Cache();
cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
Expand Down Expand Up @@ -534,8 +534,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2(
std::vector<at::Tensor> cg_outputs;

auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, *norm_params);

FusionExecutor fe;
fe.compileFusion(&fusion);
Expand All @@ -545,7 +545,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
Expand Down Expand Up @@ -625,8 +625,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3(
std::vector<at::Tensor> cg_outputs;

auto norm_params = getReductionHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
scheduleReduction(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
scheduleReduction(&fusion, *norm_params);

FusionExecutor fe;
fe.compileFusion(&fusion);
Expand All @@ -636,7 +636,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/cpp/nvfuser/broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ static void NvFuserScheduler_Broadcast(

auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;
TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
auto params = toString(compile_log.pointwise_params.value());
auto params = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

benchmark_state.SetLabel(params + lparams);
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/cpp/nvfuser/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ static void NvFuserScheduler_Reduction(

auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;
TORCH_INTERNAL_ASSERT(compile_log.reduction_params.has_value());
auto rparams = toString(compile_log.reduction_params.value());
auto rparams = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

benchmark_state.SetLabel(rparams + lparams);
Expand Down
6 changes: 2 additions & 4 deletions benchmarks/cpp/nvfuser/scale_bias_relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ static void NvFuserScheduler_SBR(

auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;
TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
auto params = toString(compile_log.pointwise_params.value());
auto params = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

benchmark_state.SetLabel(params + lparams);
Expand Down Expand Up @@ -238,8 +237,7 @@ static void NvFuserScheduler_SBR_Norm(

auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;
TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
auto params = toString(compile_log.pointwise_params.value());
auto params = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

benchmark_state.SetLabel(params + lparams);
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/nvfuser/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ static void setupTranspose(
return (is_transpose) ? transpose(tv, axes.first, axes.second) : tv;
};

auto input1 = makeContigTensor(num_dims);
auto input2 = makeContigTensor(num_dims);
auto input1 = makeContigTensor(num_dims, dtype);
auto input2 = makeContigTensor(num_dims, dtype);
fusion->addInput(input1);
fusion->addInput(input2);

Expand Down
30 changes: 18 additions & 12 deletions benchmarks/cpp/nvfuser/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ std::string toString(PointwiseParams params) {
return ss.str();
}

std::string toString(const std::shared_ptr<HeuristicParams>& params) {
auto rparams = std::dynamic_pointer_cast<ReductionParams>(params);
if (rparams) {
return toString(*rparams);
}
auto pparams = std::dynamic_pointer_cast<PointwiseParams>(params);
if (pparams) {
return toString(*pparams);
}
TORCH_INTERNAL_ASSERT(
false,
"Unknown heuristic parameter type. Did you just added a new heuristic parameter type but forget to update here?");
}

std::string toString(LaunchParams lparams) {
std::stringstream ss;
lparams.toString();
Expand Down Expand Up @@ -123,9 +137,7 @@ TensorView* makeContigTensor(size_t ndims, DataType dtype) {
.build();
}

TensorView* makeConcreteTensor(
std::vector<int64_t> shape,
DataType dtype) {
// Builds a TensorView through TensorViewBuilder, configured with the given
// shape and data type.
TensorView* makeConcreteTensor(std::vector<int64_t> shape, DataType dtype) {
  TensorView* const tensor =
      TensorViewBuilder().shape(shape).dtype(dtype).build();
  return tensor;
}

Expand Down Expand Up @@ -157,15 +169,9 @@ void runBenchmarkIterations(
auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;

if (compile_log.reduction_params.has_value()) {
auto rparams = toString(compile_log.reduction_params.value());
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(rparams + lparams);
} else if (compile_log.pointwise_params.has_value()){
auto pparams = toString(compile_log.pointwise_params.value());
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(pparams + lparams);
}
auto params = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(params + lparams);

executor_instance->setMeasureKernelTimeFlag(true);

Expand Down
1 change: 1 addition & 0 deletions benchmarks/cpp/nvfuser/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ TensorView* makeContigConcreteTensor(

// Helpers that render scheduler heuristic parameters and launch parameters
// as strings. The benchmarks concatenate the results to build the label
// passed to benchmark_state.SetLabel().
std::string toString(ReductionParams rparams);
std::string toString(PointwiseParams params);
// Overload for a type-erased heuristic parameter object; its definition
// dispatches on the dynamic type (ReductionParams or PointwiseParams) and
// asserts on any other type.
std::string toString(const std::shared_ptr<HeuristicParams>& params);
std::string toString(LaunchParams lparams);

// Run benchmark iterations with provided inputs. If not segmented, report
Expand Down
2 changes: 2 additions & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -717,8 +717,10 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/register_interface.cpp",
"torch/csrc/jit/codegen/cuda/root_domain_map.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",
Expand Down
23 changes: 23 additions & 0 deletions test/test_jit_cuda_fuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ def test_unary_ops(self):
torch.isreal,
torch.nn.functional.softplus,
torch.nn.functional.gelu,
torch.nn.functional.leaky_relu,
torch.relu,
torch.sigmoid,
torch.bitwise_not,
Expand Down Expand Up @@ -4938,6 +4939,28 @@ def t2(x0, x1, x2):
t2_jit = torch.jit.script(t2)
self._run_helper(t2_jit, t2, x0, x1, x2, check_stride=True)

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                 "Requires fusion optimization pass to be effective")
def test_type_inference(self):
    # Three same-shaped CUDA inputs; the second and third are derived from
    # the first via rand_like.
    lhs = torch.randn(10, 128, device="cuda")
    mid = torch.rand_like(lhs)
    rhs = torch.rand_like(lhs)

    def t(x0, x1, x2, flag: bool = True):
        x3 = 2.0 * x0
        x4 = 2.0 * x1
        x5 = 2.0 * x2
        if flag:
            return torch.stack([x3, x4, x5], dim=-1)
        # The fall-through path is not exercised during profiling, so it
        # has to rely on type inference driven by the profiling information.
        return x0 + x1 + x2

    scripted = torch.jit.script(t)
    self._run_helper(scripted, t, lhs, mid, rhs, check_stride=True)


class TestEnableDisableCudaFuser(JitTestCase):
def setUp(self):
Expand Down
7 changes: 1 addition & 6 deletions torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ NVFUSER_DEFINE_UNARY_OP(relu, Relu)
NVFUSER_DEFINE_UNARY_OP(round, Round)
NVFUSER_DEFINE_UNARY_OP(silu, Silu)
NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
NVFUSER_DEFINE_UNARY_OP(print, Print)
#undef NVFUSER_DEFINE_UNARY_OP

Val* randlike(Val* v) {
Expand Down Expand Up @@ -1430,12 +1431,6 @@ WelfordResult::WelfordResult(
TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(n->definition()));
}

// Applies rFactor over `axes` to the output tensor of the WelfordOp that
// defines `avg`, passing the three result tensors (avg, var_sum, n)
// together, and wraps the three returned tensors in a new WelfordResult.
WelfordResult WelfordResult::rFactor(const std::vector<int>& axes) {
  auto welford_op = avg->definition()->as<WelfordOp>();
  auto welford_out = welford_op->out()->as<TensorView>();
  auto factored =
      welford_out->rFactor(axes, std::vector<TensorView*>{avg, var_sum, n});
  return WelfordResult{factored.at(0), factored.at(1), factored.at(2)};
}

// COMPOUND OPERATIONS

// add_alpha
Expand Down
5 changes: 3 additions & 2 deletions torch/csrc/jit/codegen/cuda/arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ class TORCH_CUDA_CU_API WelfordResult {
TensorView* in_avg,
TensorView* in_var_sum,
TensorView* in_n);

WelfordResult rFactor(const std::vector<int>& axes);
};

//! Welford operator on specified axes. This is currently the only scan op with
Expand Down Expand Up @@ -253,6 +251,9 @@ TORCH_CUDA_CU_API TensorView* isposinf(TensorView*);
// isreal
TORCH_CUDA_CU_API Val* isreal(Val*);
TORCH_CUDA_CU_API TensorView* isreal(TensorView*);
// print
TORCH_CUDA_CU_API Val* print(Val*);
TORCH_CUDA_CU_API TensorView* print(TensorView*);

// Broadcasts inp based on bool vector. Size of broadcast bool vector should be
// the number of dims desired in the broadcasted tensor. This vector should be
Expand Down
Loading

0 comments on commit dfe02f3

Please sign in to comment.