40 changes: 39 additions & 1 deletion test/test_jit_cuda_fuser.py
@@ -43,7 +43,8 @@

if 'PYTORCH_NVFUSER_ENABLE' not in os.environ:
    os.environ['PYTORCH_NVFUSER_ENABLE'] = ""
os.environ['PYTORCH_NVFUSER_ENABLE'] = 'linear_decomposition,conv_decomposition,' + os.environ['PYTORCH_NVFUSER_ENABLE']
os.environ['PYTORCH_NVFUSER_ENABLE'] = 'linear_decomposition,conv_decomposition,graph_op_fusion,' + \
    os.environ['PYTORCH_NVFUSER_ENABLE']
if 'PYTORCH_NVFUSER_DISABLE' not in os.environ:
    os.environ['PYTORCH_NVFUSER_DISABLE'] = ""
os.environ['PYTORCH_NVFUSER_DISABLE'] = 'fallback,fma,' + os.environ['PYTORCH_NVFUSER_DISABLE']
@@ -4152,6 +4153,43 @@ def t(x):
        t_jit = torch.jit.script(t)
        self._run_helper(t_jit, t, x)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_index_select_fusion(self):
        lookup_size = 68
        feat_dim = 128
        num_elements = 355984
        lookup_tv = torch.rand(lookup_size, feat_dim, dtype=torch.float, device="cuda")
        indices_tv = torch.randint(0, lookup_size, (num_elements,), device="cuda").to(dtype=torch.int)
        sbf = torch.rand(num_elements, feat_dim, dtype=torch.float, device="cuda")

        def t(x_kj, idx_kj, sbf):
            sbf_res = torch.index_select(x_kj, 0, idx_kj) * sbf
            sbf_res = sbf_res + 17
            return sbf_res
        t_jit = torch.jit.script(t)
        self._run_helper(t_jit, t, lookup_tv, indices_tv, sbf)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
    def test_index_select_runtime_dim(self):
        lookup_size = 68
        feat_dim = 128
        num_elements = 355984
        dim = torch.tensor(0, device='cuda').to(dtype=torch.int)
        lookup_tv = torch.rand(lookup_size, feat_dim, dtype=torch.float, device="cuda")
        indices_tv = torch.randint(0, lookup_size, (num_elements,), device="cuda").to(dtype=torch.long)
        sbf = torch.rand(num_elements, feat_dim, dtype=torch.float, device="cuda")

        def t(x_kj: torch.Tensor, idx_kj: torch.Tensor, sbf: torch.Tensor, dim: int):
            sbf_res = torch.index_select(x_kj, dim, idx_kj) * sbf
            sbf_res = sbf_res + 17
            return sbf_res
        t_jit = torch.jit.script(t)
        self._run_helper(t_jit, t, lookup_tv, indices_tv, sbf, dim)

    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                     "Requires fusion optimization pass to be effective")
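As a quick way to see the new rule fire outside the unit-test harness, the profiled graph can be inspected for a CUDA fusion group. A minimal sketch, assuming graph_op_fusion is already present in PYTORCH_NVFUSER_ENABLE (see utils.cpp below) and relying on the prim::CudaFusionGuard node that FUSION_GUARD refers to elsewhere in this file:

import torch

def t(x_kj, idx_kj, sbf):
    return torch.index_select(x_kj, 0, idx_kj) * sbf + 17

t_jit = torch.jit.script(t)
x = torch.rand(68, 128, device="cuda")
idx = torch.randint(0, 68, (1024,), device="cuda")
sbf = torch.rand(1024, 128, device="cuda")

# A few warm-up runs let the profiling executor specialize and run the fusion pass.
for _ in range(3):
    t_jit(x, idx, sbf)

# If the index_select parse rule was applied, the specialized graph should show a
# prim::CudaFusionGuard / prim::CudaFusionGroup wrapping the fused region.
print(t_jit.graph_for(x, idx, sbf))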
47 changes: 47 additions & 0 deletions torch/csrc/jit/codegen/cuda/parser.cpp
@@ -1513,6 +1513,37 @@ class IrParser {
            nullptr);
      }
    }
    {
      if (isOptionEnabled(EnableOption::GraphOp)) {
        auto ptr_op = getOperatorForLiteral(
            "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor");
        REGISTER_PARSE_RULE(
            ptr_op,
            {
              MemoryFormat format;
              std::list<Val*> list_val;
              std::tie(format, list_val) = getPWFormatValues(
                  c10::nullopt,
                  value_map[node->inputs()[0]->unique()],
                  value_map[node->inputs()[2]->unique()]);
              auto input = list_val.front();
              list_val.pop_front();
              auto dim_value = constant_as<int>(node->input(1));
              TORCH_INTERNAL_ASSERT(
                  dim_value.has_value(), "dim in index_select is not valid");
              auto index = list_val.front();
              list_val.pop_front();
              Val* out = index_select(
                  input->as<TensorView>(),
                  dim_value.value(),
                  index->as<TensorView>());
              value_map.emplace(
                  node->output()->unique(), ValueHolder(out, format));
            },
            isInputNonSizeZeroTensor,
            nullptr);
      }
    }

    {
      auto ptr_op = getOperatorForLiteral(
@@ -4301,6 +4332,22 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
    return true;
  }

  static auto index_select_schema =
      getOperatorForLiteral(
          "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor")
          ->schema();
  if (node->matches(index_select_schema)) {
    switch (offset) {
      // argument 1: index_select dim;
      case 1:
        profileInt(pr, node, offset);
        break;
      default:
        return false;
    }
    return true;
  }

  static auto batch_norm_impl_index_schema =
      getOperatorForLiteral(
          "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)")
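Two details of the rule above are worth spelling out: the dim argument is read with constant_as<int>, so it must be (or be specialized into) a constant in the TorchScript graph, and a dim passed in as a runtime argument is handled by the profileInt() call registered in insertProfileIValue. A rough, assumed illustration of the graph node the rule matches when dim is a literal:

import torch

@torch.jit.script
def f(x: torch.Tensor, idx: torch.Tensor):
    # The literal 0 shows up as prim::Constant[value=0] feeding
    # aten::index_select(Tensor self, int dim, Tensor index) -> Tensor.
    return torch.index_select(x, 0, idx)

print(f.graph)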
32 changes: 32 additions & 0 deletions torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h
@@ -33,6 +33,7 @@ enum class RecordType {
  VarianceMeanOp,
  ViewOp,
  PermuteOp,
  IndexSelectOp,
};

//! RecordFunctor is the base class record for operations recorded by
@@ -1266,6 +1267,37 @@ struct ReductionOpRecord : RecordFunctor {
  Nvf::DataType dtype_;
};

struct IndexSelectOpRecord : RecordFunctor {
  IndexSelectOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      int64_t dim)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "index_select",
            RecordType::IndexSelectOp),
        dim_(dim) {}
  virtual ~IndexSelectOpRecord() = default;
  virtual RecordFunctor* clone() final {
    return new IndexSelectOpRecord(*this);
  }

  void operator()(FusionDefinition& fd) final {
    auto arg1 =
        fd.getFusionState(args_.at(0).index)->template as<Nvf::TensorView>();
    auto arg3 =
        fd.getFusionState(args_.at(1).index)->template as<Nvf::TensorView>();

    Nvf::Val* output = Nvf::index_select(arg1, dim_, arg3);
    fd.setFusionState(outputs_.at(0).index, output);
  }

 private:
  //! Dimension to select.
  int64_t dim_;
};

//! Specialized Record Functor for recording FusionDefinition input scalars.

struct ScalarRecord : RecordFunctor {
24 changes: 24 additions & 0 deletions torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp
@@ -1236,6 +1236,30 @@ void initNvFuserPythonBindings(PyObject* module) {
py::arg("original_shape"),
py::arg("dim"),
py::return_value_policy::reference);

  nvf_ops.def(
      "index_select",
      [](nvfuser::FusionDefinition::Operators& self,
         nvfuser::Tensor arg1,
         int64_t dim,
         nvfuser::Tensor arg3) -> nvfuser::Tensor {
        FUSER_PERF_SCOPE("Operators.index_select");
        nvfuser::FusionDefinition* fd = self.fusion_definition;
        nvfuser::Tensor output = fd->defineTensor();
        fd->defineRecord(new nvfuser::IndexSelectOpRecord(
            {
                fd->recordingState(arg1()),
                fd->recordingState(arg3()),
            },
            {fd->recordingState(output())},
            dim));
        return output;
      },
      py::arg("arg1"),
      py::arg("dim"),
      py::arg("arg3"),
      py::return_value_policy::reference);

  nvf_ops.def(
      "view",
      [](nvfuser::FusionDefinition::Operators& self,
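For context, a hedged sketch of how the new binding could be driven from the nvfuser Python frontend. Only ops.index_select comes from this change; the Fusion/FusionDefinition entry points, define_tensor, add_output, and execute are assumed from the existing frontend tests, and the (tensor, dim, index) argument order follows the pybind signature above:

import torch
# Assumed import path for the existing nvfuser Python frontend.
from torch._C._nvfuser import DataType, Fusion, FusionDefinition

fusion = Fusion()
with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(2)                # lookup table [rows, feat]
    t1 = fd.define_tensor(1, DataType.Int)  # indices along dim 0
    t2 = fd.ops.index_select(t0, 0, t1)     # the op added by this change
    fd.add_output(t2)

lookup = torch.rand(68, 128, device="cuda")
idx = torch.randint(0, 68, (1024,), device="cuda")
out = fusion.execute([lookup, idx])[0]      # expected shape: [1024, 128]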
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/type_inference.cpp
@@ -458,6 +458,7 @@ class NaiveTypePropagator {
      case prim::t_copy:
      case prim::transpose_copy:
      case prim::unsqueeze_copy:
      case aten::index_select:
      case prim::view_copy: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        copyScalarTypeAndDeviceToOutput(out_type, node);
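Routing aten::index_select through copyScalarTypeAndDeviceToOutput matches the eager semantics: the output takes its dtype and device from self, while the index tensor only determines which rows are picked. A small eager-mode check of that assumption:

import torch

x = torch.rand(8, 4, device="cuda", dtype=torch.half)
idx = torch.tensor([0, 3, 5], device="cuda")  # int64 indices
out = torch.index_select(x, 0, idx)
assert out.dtype == torch.half and out.device == x.device
assert out.shape == (3, 4)  # rows selected by idx, feature dim unchanged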
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/utils.cpp
@@ -207,6 +207,8 @@ auto parseEnableOptions() {
      options_map[EnableOption::LinearDecomposition] = true;
    } else if (token == "conv_decomposition") {
      options_map[EnableOption::ConvDecomposition] = true;
    } else if (token == "graph_op_fusion") {
      options_map[EnableOption::GraphOp] = true;
    } else {
      TORCH_CHECK(
          false,
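The new token is read from the comma-separated PYTORCH_NVFUSER_ENABLE list, so index_select fusion stays opt-in. A minimal sketch of opting in from a user script; the variable is set before importing torch, mirroring how the test file prepends it at module import time:

import os

# Prepend rather than overwrite, mirroring the setup in test_jit_cuda_fuser.py.
os.environ["PYTORCH_NVFUSER_ENABLE"] = (
    "graph_op_fusion," + os.environ.get("PYTORCH_NVFUSER_ENABLE", ""))

import torch  # import after the option is set so the fuser sees it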
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/utils.h
@@ -97,6 +97,7 @@ enum class EnableOption {
  KernelProfile, //! Enable intra-kernel performance profiling
  LinearDecomposition, //! Enable linear-bias decomposition
  ConvDecomposition, //! Enable conv-bias decomposition
  GraphOp, //! Enable graph ops (index_select / gather / scatter)
};

TORCH_CUDA_CU_API bool isOptionEnabled(EnableOption option);