[Transform] Legalize some operators (#101)
This PR adds legalizations for the following operators (a usage sketch follows below):
- `relax.nn.log_softmax`
- `relax.nn.cross_entropy_without_logits`
- `relax.nn.cross_entropy_with_logits`
- `relax.exp`
SiriusNEO authored and MasterJH5574 committed Feb 8, 2023
1 parent b4aba65 commit 57fee16
Showing 3 changed files with 719 additions and 0 deletions.
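A minimal usage sketch of the new legalizations, assuming the rules in this file are applied through a `relax.transform.LegalizeOps()` pass (name inferred from the file path) and that the op spellings below are available in TVMScript on this branch; verify both against the actual API:

import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
        gv = R.nn.log_softmax(x, axis=-1)
        return gv

# Lower the high-level op into a TE/TIR PrimFunc invoked via call_tir.
After = relax.transform.LegalizeOps()(Before)
After.show()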
55 changes: 55 additions & 0 deletions python/tvm/relax/transform/legalize_ops.py
@@ -516,6 +516,38 @@ def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr:
    return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis)


def _nn_log_softmax(bb: BlockBuilder, call: Call) -> Expr:
    return bb.call_te(topi.nn.log_softmax, call.args[0], call.attrs.axis)
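For intuition, a NumPy reference of what `topi.nn.log_softmax` computes along `axis` (a sketch only, not the TOPI schedule):

import numpy as np

def log_softmax_ref(x, axis=-1):
    # subtract the max for numerical stability, then apply log-sum-exp
    x_max = np.max(x, axis=axis, keepdims=True)
    shifted = x - x_max
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))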


def _nn_cross_entropy_without_logits(bb: BlockBuilder, call: Call) -> Expr:
    def te_cross_entropy_without_logits(x, y):
        if len(x.shape) > 1:
            return -topi.sum(topi.log(x) * y) / x.shape[0]
        return -topi.sum(topi.log(x) * y)

    return bb.call_te(
        te_cross_entropy_without_logits,
        call.args[0],
        call.args[1],
        primfunc_name_hint="cross_entropy_without_logits",
    )
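The TE compute above treats the first input as probabilities; a NumPy sketch of the same math for reference:

import numpy as np

def cross_entropy_without_logits_ref(x, y):
    # x: predicted probabilities, y: target distribution
    loss = -np.sum(np.log(x) * y)
    # averaged over the batch dimension when x is multi-dimensional
    return loss / x.shape[0] if x.ndim > 1 else loss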


def _nn_cross_entropy_with_logits(bb: BlockBuilder, call: Call) -> Expr:
    def te_cross_entropy_with_logits(x, y):
        if len(x.shape) > 1:
            return -topi.sum(x * y) / x.shape[0]
        return -topi.sum(x * y)

    return bb.call_te(
        te_cross_entropy_with_logits,
        call.args[0],
        call.args[1],
        primfunc_name_hint="cross_entropy_with_logits",
    )
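Here no `log` is applied because the first input already carries the logit term; the corresponding NumPy sketch:

import numpy as np

def cross_entropy_with_logits_ref(x, y):
    loss = -np.sum(x * y)
    return loss / x.shape[0] if x.ndim > 1 else loss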


def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr:
    return bb.call_te(
        topi.nn.batch_norm,
@@ -547,6 +579,24 @@ def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
    return call


def _nn_nll_loss(bb: BlockBuilder, call: Call) -> Expr:
    if len(call.args) == 2:
        # TODO(relax-team): handle the optional argument `weight` of NLLLoss
        logging.info(
            "Cannot legalize NLLLoss for now, because we do not know how to set "
            "the default value of the optional argument 'weight'."
        )
        return call
    return bb.call_te(
        topi.nn.nll_loss,
        call.args[0],
        call.args[1],
        call.args[2],
        reduction=call.attrs.reduction,
        ignore_index=call.attrs.ignore_index,
    )
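A simplified NumPy reference for the 2-D case of `topi.nn.nll_loss` (predictions are log-probabilities, targets are class indices; `ignore_index` handling is omitted here for brevity):

import numpy as np

def nll_loss_ref(pred, target, weight, reduction="mean"):
    # pred: (N, C) log-probabilities, target: (N,) class indices, weight: (C,)
    w = weight[target]
    # pick each row's log-probability at its target class, scaled by class weight
    loss = -pred[np.arange(pred.shape[0]), target] * w
    if reduction == "mean":
        return loss.sum() / w.sum()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"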


##################### Image #####################


@@ -578,6 +628,7 @@ def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr:
    # Arithmetic and comparison
    "relax.cos": _unary(topi.cos),
    "relax.log": _unary(topi.log),
    "relax.exp": _unary(topi.exp),
    "relax.negative": _unary(topi.negative),
    "relax.sigmoid": _unary(topi.sigmoid),
    "relax.sin": _unary(topi.sin),
@@ -642,9 +693,13 @@ def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr:
"relax.nn.gelu": _nn_gelu,
"relax.nn.silu": _nn_silu,
"relax.nn.softmax": _nn_softmax,
"relax.nn.log_softmax": _nn_log_softmax,
"relax.nn.cross_entropy_without_logits": _nn_cross_entropy_without_logits,
"relax.nn.cross_entropy_with_logits": _nn_cross_entropy_with_logits,
"relax.nn.batch_norm": _nn_batch_norm,
"relax.nn.layer_norm": _nn_layer_norm,
"relax.nn.dropout": _nn_dropout,
"relax.nn.nll_loss": _nn_nll_loss,
# Image
"relax.image.resize2d": _image_resize2d,
# Todo(relax-team): Introduce cumsum for GPT-2
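Presumably the map above is consulted once per `relax.Call` during the pass, roughly as in this hypothetical sketch (`op_legalize_map` and `legalize_call` are illustrative names, not identifiers from this file):

def legalize_call(bb: BlockBuilder, call: Call, op_legalize_map) -> Expr:
    legalize_fn = op_legalize_map.get(call.op.name)
    # ops without a registered rule are left untouched
    return legalize_fn(bb, call) if legalize_fn is not None else call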
