[Relax][PyTorch] CrossEntropyLoss #17863


Merged: 24 commits, merged May 8, 2025

Commits (24):
bcb6998  tests for add'l modules (hugolatendresse, Apr 19, 2025)
751121b  no sort (hugolatendresse, Apr 19, 2025)
175ac34  cross entropy test passes (hugolatendresse, Apr 19, 2025)
03d04be  cleanup (hugolatendresse, Apr 19, 2025)
cf77abc  fix expand (hugolatendresse, Apr 19, 2025)
aad6ae1  merge main (hugolatendresse, Apr 21, 2025)
f503d6a  remove new e2e tests (hugolatendresse, Apr 21, 2025)
d583016  remove new e2e tests (hugolatendresse, Apr 21, 2025)
0e4ca8d  convert e2e test to unit test (hugolatendresse, Apr 21, 2025)
320ad6b  unit test (hugolatendresse, Apr 21, 2025)
3f8247c  restore tests (hugolatendresse, Apr 21, 2025)
b504820  Merge branch 'main' of https://github.com/apache/tvm into cross_entropy (hugolatendresse, Apr 21, 2025)
e7d6b97  move (hugolatendresse, Apr 28, 2025)
9b2d998  add new tests (hugolatendresse, Apr 28, 2025)
8897b28  add new tests from 17862 (hugolatendresse, Apr 28, 2025)
2a96153  whitespace (hugolatendresse, Apr 28, 2025)
d600982  merge main (hugolatendresse, May 4, 2025)
a60a862  print statemetns (hugolatendresse, May 4, 2025)
ea1ad55  Merge branch 'main' of https://github.com/apache/tvm into move_e2e_ni… (hugolatendresse, May 4, 2025)
fb85449  all tests pass (hugolatendresse, May 4, 2025)
22ede51  cleanup - all tests still pass (hugolatendresse, May 4, 2025)
8ca3434  Merge branch 'move_e2e_nightly' into cross_entropy3 (hugolatendresse, May 4, 2025)
396d447  cleanup. All nightly tests pass (hugolatendresse, May 4, 2025)
f92ce3e  Merge pull request #20 from hugolatendresse/cross_entropy3 (hugolatendresse, May 4, 2025)
17 changes: 17 additions & 0 deletions python/tvm/dlight/gpu/general_reduction.py
@@ -61,6 +61,23 @@ def apply( # pylint: disable=too-many-locals
         # Align the number of block iters of the last block.
         num_last_block_iter = len(block_infos[-1].dom_kind())
         if num_last_block_iter < len(dom_kind):
+            # If the last block is a scalar value, there is nothing left to
+            # tile/parallelise, and `iters` is an empty tuple.
+            # Add a unit thread loop so the final write happens inside a valid
+            # GPU thread environment.
+            if num_last_block_iter == 0:
+                # Put every block (both the running reductions and the final
+                # scalar write) inside a trivial GPU thread. The very first block
+                # gets a `blockIdx.x` wrapper so that kernels still have a unique
+                # block scope.
+                for i, info in enumerate(block_infos):
+                    loop_rv = sch.add_unit_loop(info.block_rv)
+                    if i == 0:
+                        sch.bind(loop_rv, "blockIdx.x")
+                    else:
+                        sch.bind(loop_rv, "threadIdx.x")
+
+                return sch

             def f_layout_mapping(*iters):
                 analyzer = arith.Analyzer()
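
For context, the zero-iterator case arises when a reduction's final block writes a scalar, which is exactly the output shape of nll_loss under "mean" or "sum" reduction. Below is a hedged sketch of how this dlight rule is typically applied through the default schedule pipeline; the function name and the CUDA target are illustrative assumptions, not code from this PR.

```python
# Hedged sketch, not code from this PR: applying the dlight GeneralReduction
# rule to a module whose reduction ends in a scalar write (e.g. a "mean"- or
# "sum"-reduced loss).
import tvm
from tvm import dlight as dl


def schedule_reduction_for_gpu(mod: tvm.IRModule) -> tvm.IRModule:
    """Run the default dlight schedule so reduction kernels get GPU bindings."""
    with tvm.target.Target("cuda"):
        # With the branch added above, a trailing block that has no spatial
        # iterators is wrapped in unit loops bound to blockIdx.x / threadIdx.x
        # instead of being left without a thread environment.
        return dl.ApplyDefaultSchedule(dl.gpu.GeneralReduction())(mod)
```
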
19 changes: 19 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -776,6 +776,25 @@ def _conv3d(self, node: fx.Node) -> relax.Var:
                 groups=groups,
             )

+    def _cross_entropy_loss(
+        self,
+        preds: relax.Expr,
+        targets: relax.Expr,
+        weights: Optional[relax.Expr],
+        reduction: str,
+        ignore_index: int,
+    ) -> relax.Expr:
+        log_probs = relax.op.nn.log_softmax(preds)
+        return self.block_builder.emit(
+            relax.op.nn.nll_loss(
+                log_probs,
+                targets,
+                weights,
+                reduction,
+                ignore_index,
+            )
+        )
+
     def _einsum(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
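
The new shared helper lowers cross-entropy through the standard decomposition cross_entropy(x, y) = nll_loss(log_softmax(x), y), the same identity PyTorch itself satisfies. A small, self-contained sanity check of that identity (shapes and the reduction mode are illustrative):

```python
# Sanity check of the decomposition emitted by _cross_entropy_loss; not part of the PR.
import torch
import torch.nn.functional as F

preds = torch.randn(8, 10)            # (batch, num_classes) logits
targets = torch.randint(0, 10, (8,))  # int64 class indices

ce = F.cross_entropy(preds, targets, reduction="mean")
nll = F.nll_loss(F.log_softmax(preds, dim=-1), targets, reduction="mean")
assert torch.allclose(ce, nll)
```
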

11 changes: 10 additions & 1 deletion python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -66,7 +66,7 @@ def _reciprocal(self, node: fx.Node) -> relax.Var:

     ########## Neural Network ##########

-    def _batch_norm(self, node: fx.Node, training) -> relax.Var:
+    def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
         import numpy as np

         x = self.env[node.args[0]]
@@ -113,6 +113,14 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
         training = False
         return self._batch_norm(node, training)

+    def _cross_entropy_default(self, node: fx.Node) -> relax.Expr:
+        preds = self.env[node.args[0]]
+        targets = self.env[node.args[1]]
+        weight = self.env.get(node.args[2], None) if len(node.args) > 2 else None
+        reduction = node.kwargs.get("reduction", "mean")
+        ignore_index = node.kwargs.get("ignore_index", -100)
+        return self._cross_entropy_loss(preds, targets, weight, reduction, ignore_index)
+
     def _group_norm(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         num_groups = node.args[1]
@@ -399,6 +407,7 @@ def create_convert_map(
             "conv1d.default": self._conv1d,
             "conv2d.default": self._conv2d,
             "conv3d.default": self._conv3d,
+            "cross_entropy_loss.default": self._cross_entropy_default,
             "einsum.default": self._einsum,
             "embedding.default": lambda node: self._embedding_impl(
                 self.env[node.args[1]], self.env[node.args[0]]
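
With cross_entropy_loss.default registered in the convert map, a module that calls nn.CrossEntropyLoss can be brought in through the ExportedProgram frontend. A hedged end-to-end sketch follows; the wrapper module, shapes, and variable names are illustrative, not a test from this PR.

```python
# Illustrative usage of the new converter via torch.export; not a test from the PR.
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program


class CEModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, logits, targets):
        return self.loss(logits, targets)


example_args = (torch.randn(4, 10), torch.randint(0, 10, (4,)))
# The exported graph contains a cross_entropy_loss.default call, which the
# convert map above dispatches to _cross_entropy_default.
exported = export(CEModel(), example_args)
mod = from_exported_program(exported)
```
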
17 changes: 7 additions & 10 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -308,12 +308,7 @@ def _cross_entropy(self, node: fx.Node) -> relax.Expr:
         weights = self.env.get(node.kwargs["weight"], None)
         reduction = node.kwargs["reduction"]
         ignore_index = node.kwargs["ignore_index"]
-
-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-            )
-        )
+        return self._cross_entropy_loss(preds, targets, weights, reduction, ignore_index)

     def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
         preds = self.env[node.args[0]]
@@ -330,10 +325,12 @@ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
         reduction = module.reduction
         ignore_index = module.ignore_index

-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-            )
+        return self._cross_entropy_loss(
+            preds,
+            targets,
+            weights,
+            reduction,
+            ignore_index,
         )

     def _embedding_module(self, node: fx.Node) -> relax.Var:
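
Both legacy FX entry points, the functional call handled by _cross_entropy and the nn.CrossEntropyLoss module handled by _cross_entropy_module, now funnel into the shared _cross_entropy_loss helper. A hedged sketch of the module path through the FX frontend (names, shapes, and dtypes are illustrative, not a test from this PR):

```python
# Illustrative usage of the refactored FX path; not a test from the PR.
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx


class CEModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits, targets):
        return self.loss(logits, targets)


graph_model = fx.symbolic_trace(CEModel())
# (shape, dtype) pairs for the two graph inputs: logits and class indices.
input_info = [((4, 10), "float32"), ((4,), "int64")]
# nn.CrossEntropyLoss is a leaf module in FX tracing, so the graph holds a
# call_module node that _cross_entropy_module translates via the shared helper.
mod = from_fx(graph_model, input_info)
```
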