
Conversation


@azazhu azazhu commented Nov 15, 2022

This PR is the frontend implementation of index_select(lookup_tv, dim, index_tv) fusion:

  • supports arbitrary rank
  • supports arbitrary dim values
  • index_tv can become a consumer;
  • lookup_tv cannot become a consumer.

This PR depends on index_select frontend PR #2183

The op definition is here: https://pytorch.org/docs/stable/generated/torch.index_select.html
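For reference, the op gathers slices of lookup_tv along dim at the positions given by index_tv. A minimal sketch of those semantics, using NumPy's take as a stand-in for the ATen implementation (names here are illustrative, not from the PR):

```python
import numpy as np

def index_select(lookup, dim, index):
    # Gather slices of `lookup` along axis `dim` at the positions in the
    # 1-D integer array `index`, mirroring torch.index_select semantics.
    return np.take(lookup, index, axis=dim)

lookup = np.arange(12).reshape(3, 4)
out = index_select(lookup, 0, np.array([2, 0]))
# out picks rows 2 and 0: [[8, 9, 10, 11], [0, 1, 2, 3]]
```

Because dim is an ordinary argument, the same gather works along any axis of an arbitrary-rank input, which is what the two "arbitrary rank / arbitrary dim" bullets above promise.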

@zasdfgbnm zasdfgbnm requested a review from jjsjann123 November 15, 2022 07:38
…ent out call index_select func as it is in backend PR
@azazhu azazhu requested a review from jjsjann123 November 17, 2022 00:49

@jjsjann123 jjsjann123 left a comment

Looks about right.

There seem to be some unrelated changes under torch/csrc/jit/passes/; let's revert those.

"aten::logit(Tensor self, float? eps=None) -> Tensor",
"aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
"aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
"aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",

Nope, you don't want it here. This is the legacy fuser, not nvfuser.

case aten::log10:
case aten::frac:
case aten::lerp:
case aten::index_select:

Why do we need index_select here to remove the guard?

dim_value.has_value(), "dim in index_select is not valid");
auto index = list_val.front();
list_val.pop_front();
Val* out = nullptr;

Nitpick: can we just put an assert(false) here with a "not implemented" message? I suspect this will otherwise give us a segfault that's going to be confusing if index_select is somehow accidentally enabled.

auto arg3 =
fd.getFusionState(args_.at(1).index)->template as<Nvf::TensorView>();

Nvf::Val* output = nullptr;

ditto.

@azazhu azazhu force-pushed the fw/idx_sel_frontend branch from d033d72 to c0228b7 Compare November 18, 2022 02:32
@azazhu azazhu requested a review from jjsjann123 November 22, 2022 01:19
indies_tv = torch.randint(0, lookup_size, (num_elements,), dtype=torch.float, device="cuda").to(dtype=torch.int)
sbf = torch.rand(num_elements, feat_dim, dtype=torch.float, device="cuda")

def failure(x_kj, idx_kj, sbf):

We should be able to support runtime scalar input, e.g. https://github.com/pytorch/pytorch/blob/391b593ca262432ccba1939f7448275cfd4f62e6/test/test_jit_cuda_fuser.py#L1106-L1110

runtime_dim shouldn't be a failure if the scalar is passed in properly.

def func(x_kj, idx_kj, sbf, dim: int):
    sbf_res = torch.index_select(x_kj, dim, idx_kj) * sbf
    sbf_res = sbf_res + 17
    return sbf_res

And you would also need to run it like self._run_helper(t_jit, t, lookup_tv, indies_tv, sbf, 0)
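Expanded into a runnable sketch, the suggested runtime-scalar test might look like the following. This assumes the usual torch.jit scripting flow from test_jit_cuda_fuser.py; the _run_helper call is only indicated in a comment since it needs the test-class harness:

```python
import torch

def func(x_kj, idx_kj, sbf, dim: int):
    # `dim` is a runtime scalar argument rather than a constant baked
    # into the graph, so one scripted function covers any dimension.
    sbf_res = torch.index_select(x_kj, dim, idx_kj) * sbf
    sbf_res = sbf_res + 17
    return sbf_res

t_jit = torch.jit.script(func)
# In the test class this would then be exercised as:
# self._run_helper(t_jit, func, lookup_tv, indies_tv, sbf, 0)
```

Passing dim positionally at call time (the trailing 0 above) is what makes it a runtime scalar input instead of a specialized constant.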


@jjsjann123 jjsjann123 left a comment


LGTM. Minor comments on a proper test for the runtime scalar; the test still isn't quite right, but the code logic looks about right. I'm stamping it.

@azazhu azazhu force-pushed the fw/idx_sel_frontend branch from eaa1fbe to eadcb1c Compare November 24, 2022 11:37
@azazhu azazhu merged commit 18c788f into csarofeen:devel Nov 28, 2022