[PaddlePaddle Hackathon 4][Frontend][Paddle]Add tile/mish/stack/unstack/silu/softshrink/where op for paddle frontend #14160

Merged: 9 commits, Mar 8, 2023
108 changes: 108 additions & 0 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -1084,6 +1084,21 @@ def convert_meshgrid(g, op, block):
        g.add_node(op.output("Out")[i], out)


def convert_mish(g, op, block):
    """Operator converter for mish."""

    x = g.get_node(op.input("X")[0])
    dtype = infer_type(x).checked_type.dtype
    threshold = _expr.const(op.attr("threshold"), dtype=dtype)
    # softplus(x) = log(1 + exp(x)); for x > threshold, softplus(x) ~= x,
    # so pass x through directly to keep exp(x) from overflowing
    # (Paddle's numerical-stability guard).
    exp = _op.exp(x)
    add = _op.add(exp, _expr.const(1.0, dtype))
    log = _op.log(add)
    softplus = _op.where(x > threshold, x, log)
    # mish(x) = x * tanh(softplus(x))
    tanh = _op.tanh(softplus)
    out = _op.multiply(x, tanh)
    g.add_node(op.output("Out")[0], out)
Contributor:

[image: the mish formula]

I think we only need to implement this formula; there is no need to handle the case beta * x > threshold, which is just a strategy to keep the computation numerically stable.

Contributor Author:

[image: the mish API documentation]
I referred to the API doc. The softplus inside the mish function does not take a beta parameter, and the mish op's attributes (op.attr) contain only threshold.

Contributor Author:
I modified the implementation so that it only handles the case x <= threshold (with beta = 1.0).
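
For reference, a minimal NumPy sketch of the converted computation, assuming beta = 1.0 and Paddle's default threshold of 20.0 (the helper name and sample values are illustrative, not part of this PR):

```python
import numpy as np

def mish_reference(x, threshold=20.0):
    # illustrative reference, not TVM/Paddle code
    # softplus(x) = log(1 + e^x); above the threshold, pass x through
    # directly so exp() cannot overflow (mirrors the converter's where)
    softplus = np.where(x > threshold, x, np.log1p(np.exp(np.minimum(x, threshold))))
    return x * np.tanh(softplus)

print(mish_reference(np.array([-5.0, 0.0, 5.0, 23.1], dtype="float32")))
```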



def convert_mul(g, op, block):
    """Operator converter for mul."""

@@ -1785,6 +1800,14 @@ def convert_shape(g, op, block):
    g.add_node(op.output("Out")[0], out)


def convert_silu(g, op, block):
    """Operator converter for silu."""

    x = g.get_node(op.input("X")[0])
    # silu(x) = x * sigmoid(x)
    out = _op.multiply(x, _op.sigmoid(x))
    g.add_node(op.output("Out")[0], out)


def convert_size(g, op, block):
    """Operator converter for size."""

@@ -1950,6 +1973,19 @@ def convert_softsign(g, op, block):
    g.add_node(op.output("Out")[0], out)


def convert_softshrink(g, op, block):
    """Operator converter for softshrink."""

    x = g.get_node(op.input("X")[0])
    dtype = infer_type(x).checked_type.dtype
    threshold = _expr.const(op.attr("lambda"), dtype=dtype)
    zeros = _op.zeros_like(x)
    # softshrink(x) = x - threshold if x > threshold,
    #                 x + threshold if x < -threshold, 0 otherwise
    out = _op.where(x < -threshold, x + threshold, zeros) + _op.where(
        x > threshold, x - threshold, zeros
    )
    g.add_node(op.output("Out")[0], out)
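
A minimal NumPy sketch of the same two-sided where decomposition, checked against the piecewise definition (helper name and values are illustrative, not part of the diff):

```python
import numpy as np

def softshrink_reference(x, lam=0.5):
    # illustrative reference, not TVM/Paddle code
    # x - lam if x > lam; x + lam if x < -lam; 0 otherwise
    zeros = np.zeros_like(x)
    return np.where(x < -lam, x + lam, zeros) + np.where(x > lam, x - lam, zeros)

print(softshrink_reference(np.array([-0.9, -0.2, 0.1, 0.8], dtype="float32")))
# -> [-0.4  0.   0.   0.3]
```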


def convert_split(g, op, block):
    """Operator converter for split."""

@@ -1994,6 +2030,18 @@ def convert_split(g, op, block):
        g.add_node(op.output("Out")[i], out_i)


def convert_stack(g, op, block):
    """Operator converter for stack."""

    x = op.input("X")
    all_inputs = []
    for inp in x:
        all_inputs.append(g.get_node(inp))
    axis = op.attr("axis")
    out = _op.stack(all_inputs, axis)
    g.add_node(op.output("Y")[0], out)


def convert_square(g, op, block):
    """Operator converter for square."""

@@ -2025,6 +2073,37 @@ def convert_swish(g, op, block):
    g.add_node(op.output("Out")[0], out)


def convert_tile(g, op, block):
    """Operator converter for tile."""

    x = g.get_node(op.input("X")[0])
    # repeat_times can come from three places: a RepeatTimes tensor input,
    # a list of scalar repeat_times_tensor inputs, or a static attribute.
    if op.input("RepeatTimes"):
        reps = g.get_node(op.input("RepeatTimes")[0])
        reps, inferred = try_infer_value(reps, g.get_params())
        if inferred:
            reps = reps.tolist()
    elif op.input("repeat_times_tensor"):
        reps = []
        for rep_value in op.input("repeat_times_tensor"):
            rep_value = g.get_node(rep_value).astype("int32")
            reps.append(rep_value)
        reps = _op.concatenate(reps, axis=0)
        reps, inferred = try_infer_value(reps, g.get_params())
        if inferred:
            reps = reps.tolist()
    else:
        reps = op.attr("repeat_times")
        inferred = True

    if not inferred:
        msg = 'Value {} in attribute "repeat_times" of operator Tile is not valid.'
        raise tvm.error.OpAttributeInvalid(msg.format(reps))

    op_func = get_relay_op(op.type)
    out = op_func(x, reps=reps)
    g.add_node(op.output("Out")[0], out)
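
Whichever source repeat_times comes from, the repetition semantics follow numpy.tile, including (per the Paddle docs) its rank-promotion rules when reps and the input have different ranks. A quick shape check with illustrative values:

```python
import numpy as np

x = np.ones((2, 3), dtype="float32")
print(np.tile(x, [2, 3]).shape)           # (4, 9): per-axis repetition
print(np.tile(x, [4]).shape)              # (2, 12): short reps padded with leading 1s
print(np.tile(x, [2, 3, 4, 1, 5]).shape)  # (2, 3, 4, 2, 15): x promoted with leading 1s
```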


def convert_topk(g, op, block):
    """Operator converter for topk."""

@@ -2074,6 +2153,28 @@ def convert_unsqueeze(g, op, block):
    g.add_node(op.output("Out")[0], x)


def convert_unstack(g, op, block):
    """Operator converter for unstack."""

    x = g.get_node(op.input("X")[0])
    axis = op.attr("axis")
    # one output per slice along `axis`: split, then squeeze the axis away
    indices_or_sections = len(op.output("Y"))
    outs = _op.split(x, indices_or_sections=indices_or_sections, axis=axis)
    for i, out in enumerate(outs):
        out = _op.squeeze(out, axis=[axis])
        g.add_node(op.output("Y")[i], out)
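
Relay has no single unstack op, so the converter lowers it to a split followed by a squeeze. A minimal NumPy sketch of the same decomposition (illustrative only):

```python
import numpy as np

x = np.arange(24, dtype="float32").reshape(2, 3, 4)
axis = 1
# split into x.shape[axis] sections along axis, then drop the size-1 axis
parts = [np.squeeze(p, axis=axis) for p in np.split(x, x.shape[axis], axis=axis)]
print(len(parts), parts[0].shape)  # 3 (2, 4)
```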


def convert_where(g, op, block):
    """Operator converter for where."""

    condition = g.get_node(op.input("Condition")[0])
    x = g.get_node(op.input("X")[0])
    y = g.get_node(op.input("Y")[0])
    out = _op.where(condition, x, y)
    g.add_node(op.output("Out")[0], out)


def convert_where_index(g, op, block):
    """Operator converter for where_index."""

@@ -2166,6 +2267,7 @@ def convert_where_index(g, op, block):
    "matmul": convert_matmul,
    "matmul_v2": convert_matmul,
    "meshgrid": convert_meshgrid,
    "mish": convert_mish,
    "mul": convert_mul,
    "mv": convert_mv,
    "nearest_interp_v2": convert_interpolate,
@@ -2201,24 +2303,30 @@ def convert_where_index(g, op, block):
    "shape": convert_shape,
    "sigmoid": convert_unary_op,
    "sign": convert_unary_op,
    "silu": convert_silu,
    "sin": convert_unary_op,
    "sinh": convert_unary_op,
    "size": convert_size,
    "slice": convert_slice,
    "softmax": convert_softmax,
    "softplus": convert_softplus,
    "softsign": convert_softsign,
    "softshrink": convert_softshrink,
    "split": convert_split,
    "stack": convert_stack,
    "strided_slice": convert_slice,
    "sqrt": convert_unary_op,
    "square": convert_square,
    "squeeze2": convert_squeeze,
    "swish": convert_swish,
    "tan": convert_unary_op,
    "tanh": convert_unary_op,
    "tile": convert_tile,
    "top_k_v2": convert_topk,
    "transpose2": convert_transpose,
    "unsqueeze2": convert_unsqueeze,
    "unstack": convert_unstack,
    "where": convert_where,
    "where_index": convert_where_index,
}

185 changes: 185 additions & 0 deletions tests/python/frontend/paddlepaddle/test_forward.py
@@ -1784,5 +1784,190 @@ def where_index_1(inputs):
    verify_model(where_index_1, input_data=input_data, use_vm=True)


@tvm.testing.uses_gpu
def test_forward_stack():
    class Stack1(nn.Layer):
        @paddle.jit.to_static
        def forward(self, input0, input1, input2):
            return paddle.stack([input0, input1, input2], axis=-1)

    class Stack2(nn.Layer):
        @paddle.jit.to_static
        def forward(self, input0, input1, input2):
            return paddle.stack([input0, input1, input2], axis=1)

    class Stack3(nn.Layer):
        @paddle.jit.to_static
        def forward(self, input0, input1, input2):
            return paddle.stack([input0, input1, input2], axis=2)

    input_shapes = [[2, 3], [5, 10, 11], [3, 4, 5, 6]]
    for input_shape in input_shapes:
        input_data_0 = paddle.randn(shape=input_shape, dtype="float32")
        input_data_1 = paddle.randn(shape=input_shape, dtype="float32")
        input_data_2 = paddle.randn(shape=input_shape, dtype="float32")
        verify_model(Stack1(), [input_data_0, input_data_1, input_data_2])
        verify_model(Stack2(), [input_data_0, input_data_1, input_data_2])
        verify_model(Stack3(), [input_data_0, input_data_1, input_data_2])


@tvm.testing.uses_gpu
def test_forward_unstack():
    class UnStack1(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.unstack(inputs, axis=-1)

    class UnStack2(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.unstack(inputs, axis=1)

    class UnStack3(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.unstack(inputs, axis=0)

    input_shapes = [[2, 3], [5, 10, 11], [3, 4, 5, 6], [1, 3, 4, 1, 1]]
    for input_shape in input_shapes:
        input_data = paddle.randn(shape=input_shape, dtype="float32")
        verify_model(UnStack1(), input_data)
        verify_model(UnStack2(), input_data)
        verify_model(UnStack3(), input_data)


@tvm.testing.uses_gpu
def test_forward_silu():
    class Silu(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return nn.functional.silu(inputs)

    input_shapes = [[10], [2, 3], [5, 10, 11], [3, 4, 5, 6]]
    for input_shape in input_shapes:
        input_data = paddle.randn(shape=input_shape, dtype="float32")
        verify_model(Silu(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_softshrink():
    @paddle.jit.to_static
    def Softshrink1(input):
        return nn.functional.softshrink(input, threshold=0.0)

    @paddle.jit.to_static
    def Softshrink2(input):
        return nn.functional.softshrink(input, threshold=0.5)

    @paddle.jit.to_static
    def Softshrink3(input):
        return nn.functional.softshrink(input, threshold=1.0)

    x = paddle.to_tensor([-0.9, -0.2, 0.1, 0.8])
    verify_model(Softshrink2, x)

    input_shapes = [[10], [2, 3], [5, 10, 11], [3, 4, 5, 6]]
    for input_shape in input_shapes:
        input_data = paddle.randn(shape=input_shape, dtype="float32")
        verify_model(Softshrink1, input_data=input_data)
        verify_model(Softshrink2, input_data=input_data)
        verify_model(Softshrink3, input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_where():
    @paddle.jit.to_static
    def where1(x, y):
        return paddle.where(x > 1, x, y)

    @paddle.jit.to_static
    def where2(x, y):
        return paddle.where(x > y, x, y)

    x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2])
    y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])
    verify_model(where1, [x, y])

    input_shapes = [[10], [2, 3], [5, 10, 11], [3, 4, 5, 6]]
    for input_shape in input_shapes:
        x = paddle.randn(shape=input_shape, dtype="float32")
        y = paddle.randn(shape=input_shape, dtype="float32")
        verify_model(where1, [x, y])
        verify_model(where2, [x, y])


@tvm.testing.uses_gpu
def test_forward_tile():
    class Tile1(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.tile(inputs, repeat_times=[10])

    class Tile2(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.tile(inputs, repeat_times=[2, 3])

    class Tile3(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.tile(inputs, repeat_times=[1, 2, 3])

    class Tile4(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.tile(inputs, repeat_times=[2, 3, 4, 1, 5])

    # Tile5 and Tile6 exercise the tensor-valued repeat_times paths
    # (RepeatTimes and repeat_times_tensor) in the converter
    class Tile5(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            reps = paddle.to_tensor([3, 2])
            reps = paddle.cast(reps, "int32")
            return paddle.tile(inputs, repeat_times=reps)

    class Tile6(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            rep_0 = paddle.to_tensor([3])
            rep_1 = paddle.to_tensor([2])
            rep_0 = paddle.cast(rep_0, "int32")
            rep_1 = paddle.cast(rep_1, "int32")
            return paddle.tile(inputs, repeat_times=[rep_0, rep_1])

    input_shapes = [
        [10],
        [2, 3],
        [3, 4, 5],
        [5, 3, 1, 4],
        [1, 3, 1, 6, 7],
    ]
    for input_shape in input_shapes:
        input_data = paddle.randn(shape=input_shape, dtype="float32")
        verify_model(Tile1(), input_data=input_data)
        verify_model(Tile2(), input_data=input_data)
        verify_model(Tile3(), input_data=input_data)
        verify_model(Tile4(), input_data=input_data)
        verify_model(Tile5(), input_data=input_data)
        verify_model(Tile6(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_mish():
    class Mish(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return nn.functional.mish(inputs)

    input_shapes = [[10], [2, 3], [5, 10, 11], [3, 4, 5, 6]]
    for input_shape in input_shapes:
        input_data = paddle.randn(shape=input_shape, dtype="float32")
        verify_model(Mish(), input_data=input_data)
        # shift the inputs above the default threshold (20.0) to cover the
        # numerically stable pass-through branch
        input_data += 20.0
        verify_model(Mish(), input_data=input_data)

    input_data = paddle.to_tensor([-5.0, 0.0, 5.0, 23.1, 20.0])
    verify_model(Mish(), input_data=input_data)


if __name__ == "__main__":
    tvm.testing.main()