Separate kernel compilation API from kernel execution API (#1914)
1. Mostly mechanical changes that refactor our stack to use KernelArgumentHolder instead of at::Tensor/IValue directly:
Note: we still hold a ref-counted at::Tensor inside the kernel argument holder for tensor entries, simply because we want to forward it in the case of an aliased output. This is quite unsatisfying, but properly stripping the framework Tensor from the codegen stack requires a sizable refactor to abstract away memory ownership and allocation. That is left for future PRs.
2. Separate compilation of kernels from their execution, currently via FusionExecutorCache::compileFusion and FusionExecutorCache::runFusionWithInputs. Note that the compilation API is still experimental: we currently kick off compilation on a separate thread, and this part still needs to be exposed to and integrated into our Python API.
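For context, a minimal sketch of the intended two-step call pattern follows. This is not code from this commit: the method names come from the description above, while the header path, exact signatures, and synchronization behavior are assumptions.

// Sketch only: compile first, run later, per the description above.
// compileFusion is the name used in this PR description; its signature and
// the header path below are assumptions, not the finalized API.
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h> // FusionExecutorCache (assumed location)

#include <vector>

using torch::jit::fuser::cuda::FusionExecutorCache;

std::vector<at::Tensor> compileThenRun(
    FusionExecutorCache& fec,
    const std::vector<c10::IValue>& inputs) {
  // Step 1: trigger compilation only. Per the description, this work is
  // currently kicked off on a separate thread, so it may return before the
  // kernel is ready.
  fec.compileFusion(inputs); // experimental compilation API
  // Step 2: execute. The compiled kernel is reused; presumably execution
  // waits if compilation is still in flight.
  return fec.runFusionWithInputs(inputs);
}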

TODO for follow-up PRs:
- trivially forward inputs to outputs
- output inference should switch from meta tensors to fake tensors in order to preserve the device
- segmented fusions should/could be compiled in parallel, since we can infer outputs without a compiled kernel
- inputs_id_lookup should be refactored into KernelArgumentHolder, since we currently use args to pass inputs around
- index mode is currently per fusion, which is not necessary; it could be refactored to be per segmented fusion instead
- kernel input binding should also bind CPU scalars of int type, since their runtime values can be used in shape inference; more generally, CPU scalar dtypes should also be checked during validation
- the high water mark could be refactored to use the occupancy API after compilation, so we do not recompile when we don't have to
jjsjann123 authored Aug 26, 2022
1 parent b34e3b9 commit b247dcf
Showing 22 changed files with 1,251 additions and 542 deletions.
14 changes: 10 additions & 4 deletions benchmarks/cpp/nvfuser/heuristic_lookup.cpp
@@ -99,12 +99,15 @@ static void LayerNormBackward_HeuristicLookup(

  auto runtime = getLayerBackwardNormRuntime(
      std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);

+  KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);
+
  TORCH_INTERNAL_ASSERT(
-      runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
+      runtime->getMaybeHeuristicsFor(args).has_value());

  for (auto _ : benchmark_state) {
    // Setup (not included in the measurement)
-    runtime->getMaybeHeuristicsFor(aten_inputs);
+    runtime->getMaybeHeuristicsFor(args);
  }
}

@@ -152,12 +155,15 @@ static void LayerNormForward_HeuristicLookup(

  auto runtime = getLayerForwardNormRuntime(
      std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);

+  KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);
+
  TORCH_INTERNAL_ASSERT(
-      runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
+      runtime->getMaybeHeuristicsFor(args).has_value());

  for (auto _ : benchmark_state) {
    // Setup (not included in the measurement)
-    runtime->getMaybeHeuristicsFor(aten_inputs);
+    runtime->getMaybeHeuristicsFor(args);
  }
}

9 changes: 7 additions & 2 deletions benchmarks/cpp/nvfuser/shape_inference.cpp
@@ -100,8 +100,11 @@ void LayerNormBackward_ShapeInference_Base(

  auto runtime = getLayerBackwardNormRuntime(
      std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);

+  KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);
+
  TORCH_INTERNAL_ASSERT(
-      runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
+      runtime->getMaybeHeuristicsFor(args).has_value());

  fec->profile(true);
  fec->disableKernelLaunch();
@@ -172,8 +175,10 @@ void LayerNormForward_ShapeInferenceBase(
  auto runtime = getLayerForwardNormRuntime(
      std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);

+  KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);
+
  TORCH_INTERNAL_ASSERT(
-      runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
+      runtime->getMaybeHeuristicsFor(args).has_value());

  fec->profile(true);
  fec->disableKernelLaunch();
58 changes: 43 additions & 15 deletions torch/csrc/jit/codegen/cuda/evaluator_common.cpp
@@ -383,19 +383,19 @@ KernelPrecomputedIntegers::KernelPrecomputedIntegers(kir::Kernel* kernel) {
  initializeIntegerMachine();
}

+// TODO: put this to base class
void KernelPrecomputedIntegers::bindTensorMetaData(
    TensorView* tv,
-    const at::Tensor& at_tensor) {
-  std::vector<std::pair<Val*, int64_t>> ret;
+    const TensorArgAbstract* tensor_arg_abstract) {
  const auto root_domain =
      TensorDomain::noReductions(tv->domain()->getMaybeRFactorDomain());
  TORCH_INTERNAL_ASSERT(
-      at_tensor.ndimension() == static_cast<int>(root_domain.size()),
+      tensor_arg_abstract->getRank() == static_cast<int>(root_domain.size()),
      "Something went wrong configuring launch. Inputs do not match.");

  for (const auto dim : c10::irange(root_domain.size())) {
    auto extent = root_domain[dim]->extent();
-    auto value = at_tensor.sizes()[dim];
+    auto value = tensor_arg_abstract->getSize(dim);
    bindValue(extent->evaluatorIndex(), value);
  }
}
@@ -434,22 +434,37 @@ void KernelPrecomputedIntegers::initializeNamedScalars() {
  }
}

+// TODO: merge this one with above.
void KernelPrecomputedIntegers::bindKernelInputs(
    kir::Kernel* kernel,
-    const at::ArrayRef<IValue>& aten_inputs) {
+    const KernelArgumentHolder& args) {
  if (hasValidValues()) {
    invalidate();
  }

  const auto& inputs = kernel->inputs();
+  TORCH_INTERNAL_ASSERT(
+      args.size() == inputs.size(), "kernel inputs size does not match args");

  for (const auto i : c10::irange(inputs.size())) {
+    auto arg = args[i];
    const auto input = inputs[i];
    if (auto tensor_input = dynamic_cast<TensorView*>(input)) {
-      const auto aten_tensor = aten_inputs[i].toTensor();
-      bindTensorMetaData(tensor_input, aten_tensor);
+      if (const auto& tensor_arg_abstract =
+              dynamic_cast<const TensorArgAbstract*>(arg)) {
+        bindTensorMetaData(tensor_input, tensor_arg_abstract);
+      } else {
+        // TODO: cpu scalar of int type should be bound as scalar int as well
+        TORCH_CHECK(
+            arg->isType(ArgType::CpuScalarTensor),
+            "binding input to TensorView expects input arg to be of tensor type");
+      }
    } else if (input->isScalar() && input->dtype() == DataType::Int) {
-      bindValue(input->evaluatorIndex(), aten_inputs[i].toInt());
+      TORCH_CHECK(
+          arg->isType(ArgType::Long),
+          "binding input to integer type expects input arg to be a scalar of Long type");
+      precomputedIntegersBaseType::bindValue(
+          input->evaluatorIndex(), *static_cast<const int64_t*>(arg->arg()));
    }
  }
}
@@ -489,38 +504,51 @@ FusionPrecomputedIntegers::FusionPrecomputedIntegers(Fusion* fusion)
  initializeIntegerMachine();
}

+// TODO: put this to base class
void FusionPrecomputedIntegers::bindTensorMetaData(
    TensorView* tv,
-    const at::Tensor& at_tensor) {
+    const TensorArgAbstract* tensor_arg_abstract) {
  const auto root_domain =
      TensorDomain::noReductions(tv->getMaybeRFactorDomain());
  TORCH_INTERNAL_ASSERT(
-      at_tensor.ndimension() == static_cast<int>(root_domain.size()),
+      tensor_arg_abstract->getRank() == static_cast<int>(root_domain.size()),
      "Something went wrong configuring launch. Inputs do not match.");

  for (const auto dim : c10::irange(root_domain.size())) {
    auto extent = root_domain[dim]->extent();
-    auto value = at_tensor.sizes()[dim];
+    auto value = tensor_arg_abstract->getSize(dim);
    precomputedIntegersBaseType::bindValue(extent->evaluatorIndex(), value);
  }
}

void FusionPrecomputedIntegers::bindFusionInputs(
-    const at::ArrayRef<IValue>& aten_inputs) {
+    const KernelArgumentHolder& args) {
  if (hasValidValues()) {
    precomputedIntegersBaseType::invalidate();
  }

  const auto& inputs = fusion_->inputs();
+  TORCH_INTERNAL_ASSERT(
+      args.size() == inputs.size(), "kernel inputs size does not match args");

  for (const auto i : c10::irange(inputs.size())) {
    const auto input = inputs[i];
+    const ArgAbstract* arg = args[i];
    if (auto tensor_input = dynamic_cast<TensorView*>(input)) {
-      const auto aten_tensor = aten_inputs[i].toTensor();
-      bindTensorMetaData(tensor_input, aten_tensor);
+      if (const auto& tensor_arg_abstract =
+              dynamic_cast<const TensorArgAbstract*>(arg)) {
+        bindTensorMetaData(tensor_input, tensor_arg_abstract);
+      } else {
+        TORCH_CHECK(
+            arg->isType(ArgType::CpuScalarTensor),
+            "binding input to TensorView expects input arg to be of tensor type");
+      }
    } else if (input->isScalar() && input->getDataType() == DataType::Int) {
+      TORCH_CHECK(
+          arg->isType(ArgType::Long),
+          "binding input to integer type expects input arg to be a scalar of Long type");
      precomputedIntegersBaseType::bindValue(
-          input->evaluatorIndex(), aten_inputs[i].toInt());
+          input->evaluatorIndex(), *static_cast<const int64_t*>(arg->arg()));
    }
  }
}
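The two bindTensorMetaData overloads above now read tensor metadata through the abstract argument interface (getRank() / getSize(dim)) instead of querying at::Tensor::ndimension() / sizes() directly. The sketch below illustrates the idea behind that decoupling; ShapeSource and AtenShapeSource are hypothetical stand-ins for the role TensorArgAbstract plays here, not classes from this commit.

// Hypothetical illustration: the evaluator only needs rank and per-dimension
// sizes, so it can bind extents through a small interface instead of holding
// an at::Tensor.
#include <ATen/ATen.h>
#include <cstdint>
#include <vector>

struct ShapeSource {
  virtual ~ShapeSource() = default;
  virtual int64_t rank() const = 0;            // plays the role of getRank()
  virtual int64_t size(int64_t dim) const = 0; // plays the role of getSize(dim)
};

// One possible backing that snapshots sizes from an at::Tensor. Binding code
// written against ShapeSource never touches the tensor itself, which is what
// lets the codegen stack drop its direct at::Tensor dependency over time.
struct AtenShapeSource final : ShapeSource {
  explicit AtenShapeSource(const at::Tensor& t) : sizes_(t.sizes().vec()) {}
  int64_t rank() const override {
    return static_cast<int64_t>(sizes_.size());
  }
  int64_t size(int64_t dim) const override {
    return sizes_.at(static_cast<size_t>(dim));
  }

 private:
  std::vector<int64_t> sizes_;
};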
14 changes: 8 additions & 6 deletions torch/csrc/jit/codegen/cuda/evaluator_common.h
@@ -279,10 +279,12 @@ class FusionPrecomputedIntegers
  FusionPrecomputedIntegers(Fusion* fusion);

  //! Bind concrete values from fusion runtime inputs
-  void bindFusionInputs(const at::ArrayRef<IValue>& aten_inputs);
+  void bindFusionInputs(const KernelArgumentHolder& args);

 private:
-  void bindTensorMetaData(TensorView* tv, const at::Tensor& at_tensor);
+  void bindTensorMetaData(
+      TensorView* tv,
+      const TensorArgAbstract* tensor_arg_abstract);

 private:
  Fusion* fusion_ = nullptr;
@@ -302,9 +304,7 @@ class KernelPrecomputedIntegers
  KernelPrecomputedIntegers(kir::Kernel* kernel);

  //! Bind concrete values from fusion runtime inputs
-  void bindKernelInputs(
-      kir::Kernel* kernel,
-      const at::ArrayRef<IValue>& aten_inputs);
+  void bindKernelInputs(kir::Kernel* kernel, const KernelArgumentHolder& args);

  //! Bind concrete values from launch constraints
  void bindParallelExtents(
@@ -317,7 +317,9 @@
  void bindConcreteParallelTypeValue(ParallelType pt, int64_t value);

 private:
-  void bindTensorMetaData(TensorView* tv, const at::Tensor& at_tensor);
+  void bindTensorMetaData(
+      TensorView* tv,
+      const TensorArgAbstract* tensor_arg_abstract);

  //! Iterate through all the named scalars corresponding
  //! to thread sizes and pre-group them by their parallel
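With the signatures above, a call site is expected to pack the runtime inputs into a KernelArgumentHolder once and hand it to the binding routine. A rough sketch follows, assuming the KernelArgumentHolder::createKernelArgumentHolder factory seen in the benchmark diffs earlier in this commit; the wrapper function itself is illustrative, and nvfuser includes are omitted.

// Illustrative call site only; not taken from this commit.
void bindEvaluatorInputs(
    kir::Kernel* kernel,
    KernelPrecomputedIntegers& precomputed,
    const at::ArrayRef<c10::IValue>& aten_inputs) {
  KernelArgumentHolder args =
      KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);
  // Binds tensor extents (via TensorArgAbstract) and Long scalars so that
  // expression evaluation can proceed without touching at::Tensor directly.
  precomputed.bindKernelInputs(kernel, args);
}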
