Merge pull request apache#26 from heliqi/paddle_frontend
Paddle frontend
jiangjiajun authored Sep 10, 2021
2 parents 986914f + 82d0307 commit 2d2217b
Showing 2 changed files with 198 additions and 158 deletions.
191 changes: 125 additions & 66 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -480,13 +480,19 @@ def convert_elementwise_op(g, op, block):
"""Operator converter for all the elementwise operators."""

op_map = {
"elementwise_div": lambda x, y: x / y,
"elementwise_add": lambda x, y: x + y,
"elementwise_mul": lambda x, y: x * y,
"elementwise_sub": lambda x, y: x - y,
"elementwise_mod": _op.mod,
"elementwise_pow": _op.power,
"elementwise_floordiv": _op.floor_divide
"elementwise_div": "divide",
"elementwise_add": "add",
"elementwise_mul": "multiply",
"elementwise_sub": "subtract",
"elementwise_mod": "mod",
"elementwise_pow": "power",
"elementwise_floordiv": "floor_divide",
"floor_mod": "floor_mod",
"equal": "equal",
"greater_than": "greater",
"less_equal": "less_equal",
"less_than": "less",
"not_equal": "not_equal",
}
op_func = op_map[op.type]
ipt0 = g.get_node(op.input("X")[0])
@@ -499,19 +505,11 @@ def convert_elementwise_op(g, op, block):
axis = axis + len(ipt0_shape)
if axis != len(ipt0_shape) - 1:
ipt1 = _op.expand_dims(ipt1, axis=axis, num_newaxis=(len(ipt0_shape) - axis - 1))
+ op_func = get_relay_op(op_func)
out = op_func(ipt0, ipt1)
g.add_node(op.output("Out")[0], out)
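As an aside (not part of the diff): with the string-keyed table above, get_relay_op from tvm.relay.frontend.common resolves each name to the corresponding Relay call, replacing the old per-op lambdas. A minimal sketch, with assumed toy inputs:

from tvm import relay
from tvm.relay.frontend.common import get_relay_op

op_map = {"elementwise_add": "add", "less_than": "less"}  # excerpt of the new table
x = relay.var("x", shape=(2, 3))
y = relay.var("y", shape=(2, 3))
out = get_relay_op(op_map["elementwise_add"])(x, y)  # resolves to relay.add(x, y)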


- def convert_equal(g, op, block):
- """Operator converter for equal."""
-
- x = g.get_node(op.input("X")[0])
- y = g.get_node(op.input("Y")[0])
- out = _op.equal(x, y)
- g.add_node(op.output("Out")[0], out)


def convert_expand(g, op, block):
"""Operator converter for expand."""

@@ -602,10 +600,12 @@ def convert_fill_constant_batch_size_like(g, op, block):
dtype = block.var(op.output("Out")[0]).dtype
dtype = str(dtype).strip().split(".")[1]
input_shape = shape_of(x)
- batch = _op.strided_slice(input_shape, begin=[input_dim_idx], end=[input_dim_idx+1]).astype("int32")
+ batch = _op.strided_slice(input_shape, begin=[input_dim_idx], end=[input_dim_idx + 1]).astype(
+ "int32"
+ )
shape_before = shape[:output_dim_idx]
shape_before = _expr.const(shape_before, dtype="int32")
- shape_after = shape[output_dim_idx+1:]
+ shape_after = shape[output_dim_idx + 1 :]
shape_after = _expr.const(shape_after, dtype="int32")

out_shape = _op.concatenate([shape_before, batch, shape_after], axis=0)
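For intuition (not part of the diff), a NumPy analogue of how fill_constant_batch_size_like assembles the output shape: the extent at input_dim_idx is sliced out of the input's shape and spliced into the requested shape at output_dim_idx. The values below are assumed:

import numpy as np

input_shape = np.array([8, 3, 224, 224])  # stands in for shape_of(x)
shape = np.array([1, 128])                # the op's "shape" attribute
input_dim_idx, output_dim_idx = 0, 0
batch = input_shape[input_dim_idx : input_dim_idx + 1]
out_shape = np.concatenate([shape[:output_dim_idx], batch, shape[output_dim_idx + 1 :]])
print(out_shape)  # [  8 128]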
@@ -746,15 +746,6 @@ def convert_leaky_relu(g, op, block):
g.add_node(op.output("Out")[0], out)


- def convert_less_than(g, op, block):
- """Operator converter for less_than."""
-
- x = g.get_node(op.input("X")[0])
- y = g.get_node(op.input("Y")[0])
- out = _op.less(x, y)
- g.add_node(op.output("Out")[0], out)


def convert_lookup_table(g, op, block):
"""Operator converter for lookup_table_v2."""

@@ -946,15 +937,6 @@ def convert_mul(g, op, block):
g.add_node(op.output("Out")[0], out)


- def convert_not_equal(g, op, block):
- """Operator converter for not_equal."""
-
- x = g.get_node(op.input("X")[0])
- y = g.get_node(op.input("Y")[0])
- out = _op.not_equal(x, y)
- g.add_node(op.output("Out")[0], out)


def convert_pool2d(g, op, block):
"""Operator converter for pool2d."""

@@ -1051,7 +1033,7 @@ def convert_pow(g, op, block):
out = _op.power(x, factor)
g.add_node(op.output("Out")[0], out)


def convert_range(g, op, block):
"""Operator converter for range."""

@@ -1297,36 +1279,61 @@ def convert_shape(g, op, block):
def convert_slice(g, op, block):
"""Operator converter for slice."""

- def parameter_process(starts, ends, axes):
- new_axes = []
- new_starts = []
- new_ends = []
- pop_index = 0
- for i in range(max(axes) + 1):
- new_axes.append(i)
- if i in axes:
- new_starts.append(starts[pop_index])
- new_ends.append(ends[pop_index])
- pop_index += 1
- else:
- new_starts.append(0)
- new_ends.append(np.iinfo(np.int32).max)
- return new_starts, new_ends, new_axes

data = g.get_node(op.input("Input")[0])
- starts = op.attr("starts")
- ends = op.attr("ends")
+ dims = len(block.var(op.input("Input")[0]).shape)
+ dtype = "int64"

axes = op.attr("axes")
+ axes = _op.const(axes)
decrease_axis = op.attr("decrease_axis")
- if isinstance(starts, int):
- starts = [starts]
- if isinstance(ends, int):
- ends = [ends]
- if isinstance(axes, int):
- axes = [axes]
if isinstance(decrease_axis, int):
decrease_axis = [decrease_axis]
- starts, ends, axes = parameter_process(starts, ends, axes)

+ starts = op.input("StartsTensor")
+ if starts:
+ starts = g.get_node(starts[0])
+ elif op.input("StartsTensorList"):
+ starts = []
+ for start_index in op.input("StartsTensorList"):
+ start_index = g.get_node(start_index)
+ if not isinstance(start_index, _expr.Expr):
+ start_index = _expr.const(start_index, dtype=dtype)
+ else:
+ start_index = start_index.astype(dtype)
+ starts.append(start_index)
+ starts = _op.concatenate(starts, axis=0)
+ else:
+ starts = op.attr("starts")
+ starts = _expr.const(starts)
+ if isinstance(starts, _expr.Expr):
+ starts = _op.scatter(
+ _op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype),
+ axes,
+ starts,
+ axis=0,
+ )
+
+ data_shape = shape_of(data)
+ ends = op.input("EndsTensor")
+ if ends:
+ ends = g.get_node(ends[0])
+ elif op.input("EndsTensorList"):
+ ends = []
+ data_shape = data_shape.astype(dtype)
+ for end_index in op.input("EndsTensorList"):
+ end_index = g.get_node(end_index)
+ if not isinstance(end_index, _expr.Expr):
+ end_index = _expr.const(end_index, dtype=dtype)
+ else:
+ end_index = end_index.astype(dtype)
+ ends.append(end_index)
+ ends = _op.concatenate(ends, axis=0)
+ else:
+ ends = op.attr("ends")
+ ends = _expr.const(ends)
+ if isinstance(ends, _expr.Expr):
+ ends = _op.scatter(data_shape, axes, ends, axis=0)

out = _op.strided_slice(data, begin=starts, end=ends)
if decrease_axis:
out = _op.squeeze(out, axis=decrease_axis)
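For intuition (outside the diff), a NumPy analogue of the scatter step in the rewritten convert_slice: the per-axis starts/ends are scattered into full-rank begin/end vectors (zeros for begin, the data shape for end) before _op.strided_slice is called. Toy values assumed:

import numpy as np

shape = np.array([4, 10, 6], dtype="int64")  # stands in for shape_of(data)
axes, starts, ends = [1], [2], [5]
begin = np.zeros(len(shape), dtype="int64")
end = shape.copy()
begin[axes] = starts  # -> [0 2 0]
end[axes] = ends      # -> [4 5 6]
print(begin, end)     # equivalent to data[:, 2:5, :]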
@@ -1347,6 +1354,52 @@ def convert_softmax(g, op, block):
g.add_node(op.output("Out")[0], out)


+ def convert_split(g, op, block):
+ """Operator converter for split."""
+
+ x = g.get_node(op.input("X")[0])
+ axis = op.input("AxisTensor")
+ if axis:
+ axis = g.get_node(axis[0])
+ axis = infer_value(axis, g.get_params()).numpy().tolist()[0]
+ else:
+ axis = op.attr("axis")
+
+ sections = op.input("SectionsTensorList")
+ if sections:
+ tmp_section = []
+ for i in sections:
+ i = g.get_node(i)
+ i = infer_value(i, g.get_params()).numpy().tolist()
+ tmp_section.extend(i)
+ sections = tmp_section
+ else:
+ sections = op.attr("sections")
+ if sections:
+ indices = []
+ split_index = 0
+ for i in sections[:-1]:
+ if i == -1:
+ input_shape = infer_shape(x)[axis]
+ i = input_shape - np.sum(sections) - 1
+ split_index += i
+ indices.append(split_index)
+ else:
+ indices = op.attr("num")
+
+ out = _op.split(x, indices, axis)
+ for i, out_i in enumerate(out):
+ g.add_node(op.output("Out")[i], out_i)
+
+
+ def convert_square(g, op, block):
+ """Operator converter for square."""
+
+ x = g.get_node(op.input("X")[0])
+ out = _op.multiply(x, x)
+ g.add_node(op.output("Out")[0], out)
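A small worked example (assumed values, not from the diff) of how convert_split above turns Paddle's sections attribute into the index list that _op.split expects; a -1 section absorbs whatever remains of the axis:

import numpy as np

sections, axis_len = [2, -1, 3], 9  # splitting an axis of length 9
indices, running = [], 0
for s in sections[:-1]:
    if s == -1:
        s = axis_len - np.sum(sections) - 1  # remainder formula used in the diff
    running += s
    indices.append(running)
print(indices)  # [2, 6] -> pieces of length 2, 4 and 3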


def convert_squeeze(g, op, block):
"""Operator converter for squeeze2."""

Expand Down Expand Up @@ -1375,7 +1428,7 @@ def convert_topk(g, op, block):

g.add_node(op.output("Out")[0], outs[0])
g.add_node(op.output("Indices")[0], outs[1])


def convert_stack(g, op, block):
"""Operator converter for stack."""
@@ -1465,7 +1518,7 @@ def convert_unsqueeze(g, op, block):
"elementwise_mod": convert_elementwise_op,
"elementwise_pow": convert_elementwise_op,
"elementwise_floordiv": convert_elementwise_op,
"equal": convert_equal,
"equal": convert_elementwise_op,
"exp": convert_unary_op,
"expand_v2": convert_expand,
"feed": convert_feed,
@@ -1474,16 +1527,19 @@ def convert_unsqueeze(g, op, block):
"fill_constant_batch_size_like": convert_fill_constant_batch_size_like,
"flatten_contiguous_range": convert_flatten,
"floor": convert_unary_op,
"floor_mod": convert_elementwise_op,
"gather": convert_gather,
"gather_nd": convert_gather_nd,
"gelu": convert_gelu,
"greater_than": convert_elementwise_op,
"hard_sigmoid": convert_hard_sigmoid,
"hard_swish": convert_hard_swish,
"isinf": convert_unary_op,
"isinf_v2": convert_unary_op,
"layer_norm": convert_layer_norm,
"leaky_relu": convert_leaky_relu,
"less_than": convert_less_than,
"less_equal": convert_elementwise_op,
"less_than": convert_elementwise_op,
"lookup_table": convert_lookup_table,
"lookup_table_v2": convert_lookup_table,
"log": convert_unary_op,
@@ -1494,7 +1550,7 @@ def convert_unsqueeze(g, op, block):
"matmul_v2": convert_matmul,
"mul": convert_mul,
"nearest_interp_v2": convert_interpolate,
"not_equal": convert_not_equal,
"not_equal": convert_elementwise_op,
"pool2d": convert_pool2d,
"pad1d": convert_padding,
"pad2d": convert_padding,
@@ -1511,12 +1567,15 @@ def convert_unsqueeze(g, op, block):
"relu": convert_unary_op,
"reshape2": convert_reshape,
"rnn": convert_rnn,
"rsqrt": convert_unary_op,
"scale": convert_scale,
"shape": convert_shape,
"sigmoid": convert_unary_op,
"sin": convert_unary_op,
"slice": convert_slice,
"softmax": convert_softmax,
"split": convert_split,
"square": convert_square,
"squeeze2": convert_squeeze,
"stack": convert_stack,
"tan": convert_unary_op,
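For context (not part of this commit), a rough, hypothetical sketch of how a table like _convert_map above is typically consumed: each op in a Paddle block is dispatched to its converter by op type. The function name and loop below are illustrative assumptions, not the file's actual driver code:

def convert_block(g, block):
    # hypothetical driver loop, for illustration only
    for op in block.ops:
        if op.type in ("feed", "fetch"):
            continue
        convert_func = _convert_map[op.type]  # e.g. "split" -> convert_split
        convert_func(g, op, block)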