50 changes: 10 additions & 40 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -417,42 +417,6 @@ def _rsub(self, node: fx.Node) -> relax.Var:

return self.block_builder.emit(relax.op.subtract(rhs, lhs))

########## Linear Algebra ##########

def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:

args = self.retrieve_args(node)

data = args[0]
# Default ord=2 if not supplied
ord_val = args[1] if len(args) > 1 else 2.0
dim = args[2] if len(args) > 2 else None
keepdim = args[3] if len(args) > 3 else False

# If ord_val is a Python float/int, wrap it in a Relax const
# so that it matches data's dtype.
dtype = data.struct_info.dtype
ord_expr = (
ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype)
)
# Reciprocal
reci_expr = (
relax.op.divide(relax.const(1.0, dtype), ord_expr)
if isinstance(ord_val, relax.Expr)
else relax.const(1.0 / float(ord_val), dtype)
)

# abs(data)
abs_data = self.block_builder.emit(relax.op.abs(data))
# abs_data^ord
abs_data_pow = self.block_builder.emit(relax.op.power(abs_data, ord_expr))
# sum over dim
reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim, keepdims=keepdim))
# (sum(...))^(1/ord)
norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr))

return norm_val

########## Neural Network ##########

def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
@@ -980,16 +944,22 @@ def _norm(self, node: fx.Node) -> relax.Var:
elif order == "fro":
return self.block_builder.emit(
relax.op.sqrt(
relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims),
relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims)
)
)
else:
reci_order = relax.const(1 / order, dtype=dtype)
order = relax.const(order, dtype=dtype)
ord_expr = (
order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype)
)
reci_order = (
relax.op.divide(relax.const(1.0, dtype), ord_expr)
if isinstance(order, relax.Expr)
else relax.const(1.0 / order, dtype=dtype)
)
return self.block_builder.emit(
relax.op.power(
relax.op.sum(
relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims
relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims
),
reci_order,
)
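
Note (not part of the diff): the reworked `_norm` path emits the textbook vector p-norm, abs -> power(p) -> sum -> power(1/p), i.e. ||x||_p = (sum_i |x_i|^p)^(1/p). A minimal sanity check of that formula against plain PyTorch; the ord values and tolerance below are illustrative assumptions, not taken from the patch:

import torch

x = torch.randn(1, 3, 5, 3, dtype=torch.float32)
for ord_val in (2.0, 1.0, 0.5, -4.0):
    # Reference result straight from PyTorch.
    ref = torch.linalg.vector_norm(x, ord=ord_val)
    # The abs -> power -> sum -> power sequence the handler lowers to.
    manual = torch.sum(torch.abs(x) ** ord_val) ** (1.0 / ord_val)
    assert torch.allclose(ref, manual, rtol=1e-5)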
@@ -369,7 +369,7 @@ def create_convert_map(
"__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor),
"__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor),
# linear algebra
"linalg_vector_norm.default": self._linalg_vector_norm,
"linalg_vector_norm.default": self._norm,
# neural network
"_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional,
"_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
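
Note (not part of the diff): with `linalg_vector_norm.default` now routed to the shared `_norm` handler, an end-to-end import should hit the same lowering. A hedged usage sketch; the module, shapes, and expectation that `torch.linalg.vector_norm` traces to this op are assumptions for illustration:

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program


class VectorNorm(torch.nn.Module):
    def forward(self, x):
        # Expected to trace to linalg_vector_norm.default, which the
        # convert map now dispatches to _norm.
        return torch.linalg.vector_norm(x, ord=2)


example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
exported = export(VectorNorm(), example_args)
mod = from_exported_program(exported)
mod.show()  # the Relax IR should contain the abs -> power -> sum -> power chain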
112 changes: 112 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -4379,6 +4379,118 @@ def main(
verify_model(Narrow(), example_args, {}, Expected)


def test_norm():
class Norm(Module):
def __init__(self, p, dim=None, keepdim=False):
super().__init__()
self.p = p
self.dim = dim
self.keepdim = keepdim

def forward(self, x):
return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)

@tvm.script.ir_module
class Expected1:
@R.function
def main(
inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0), axis=None, keepdims=False)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
R.output(gv)
return gv

@tvm.script.ir_module
class Expected2:
@R.function
def main(
inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0), axis=None, keepdims=False)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
R.output(gv)
return gv

@tvm.script.ir_module
class Expected3:
@R.function
def main(
inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(2, "float32"))
lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, "float32"))
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
R.output(gv)
return gv

@tvm.script.ir_module
class Expected4:
@R.function
def main(
inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(1.0, "float32"))
lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, "float32"))
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
R.output(gv)
return gv

@tvm.script.ir_module
class Expected5:
@R.function
def main(
inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4.0, "float32"))
lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(
lv2, R.const(-0.25, "float32")
)
gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
R.output(gv)
return gv

@tvm.script.ir_module
class Expected6:
@R.function
def main(
inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(0.5, "float32"))
lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(lv2, R.const(2.0, "float32"))
gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
R.output(gv)
return gv

norms = [
((float("inf"), None, False), Expected1),
((float("-inf"), None, False), Expected2),
((float(2), None, False), Expected3),
((float(1.0), None, False), Expected4),
((float(-4), None, True), Expected5),
((float(0.5), None, True), Expected6),
]

example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)

for (p, dim, keepdim), expected in norms:
verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected)


def test_eye():
class Eye1(Module):
def forward(self, input):
19 changes: 8 additions & 11 deletions tests/python/relax/test_frontend_from_fx.py
@@ -4938,19 +4938,16 @@ def main(
return gv

norms = [
(float("inf"), None, False),
(float("-inf"), None, False),
(float(2), None, False),
(float(1.0), None, False),
(float(-4), None, True),
(float(0.5), None, True),
("fro", None, False),
((float("inf"), None, False), Expected1),
((float("-inf"), None, False), Expected2),
((float(2), None, False), Expected3),
((float(1.0), None, False), Expected4),
((float(-4), None, True), Expected5),
((float(0.5), None, True), Expected6),
(("fro", None, False), Expected7),
]

for norm, expected in zip(
norms, [Expected1, Expected2, Expected3, Expected4, Expected5, Expected6, Expected7]
):
p, dim, keepdim = norm
for (p, dim, keepdim), expected in norms:
verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)

