Skip to content

Commit

Permalink
resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Nov 3, 2020
1 parent cf0cc3f commit 93ed8e0
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def compute_resize(attrs, inputs, out_type):

reg.register_injective_schedule("image.resize")


@script
def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis):
out = output_tensor((4,), "int64")
Expand All @@ -50,6 +51,7 @@ def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, c
out[channel_axis] = image_shape[channel_axis]
return out


@reg.register_shape_func("image.resize", False)
def resize_shape_func(attrs, inputs, _):
"""
Expand All @@ -67,9 +69,15 @@ def resize_shape_func(attrs, inputs, _):
if letter == "C":
channel_axis = i
size = get_const_tuple(attrs.size)
return [_resize_shape_func(inputs[0], convert(size), convert(batch_axis),
convert(height_axis), convert(width_axis),
convert(channel_axis))]
return [
_resize_shape_func(inputs[0],
convert(size),
convert(batch_axis),
convert(height_axis),
convert(width_axis),
convert(channel_axis)
)
]


@reg.register_compute("image.resize3d")
Expand Down Expand Up @@ -163,6 +171,7 @@ def compute_affine_grid(attrs, inputs, out_dtype):

reg.register_injective_schedule("image.affine_grid")


@script
def _affine_grid_func(data, target_shape):
out = output_tensor((4,), "int64")
Expand All @@ -172,6 +181,7 @@ def _affine_grid_func(data, target_shape):
out[3] = int64(target_shape[1])
return out


@reg.register_shape_func("image.affine_grid", False)
def affine_grid_func(attrs, inputs, _):
"""
Expand All @@ -191,6 +201,7 @@ def compute_grid_sample(attrs, inputs, out_dtype):

reg.register_injective_schedule("image.grid_sample")


@script
def _grid_sample_func(data, grid):
out = output_tensor((4,), "int64")
Expand All @@ -200,6 +211,7 @@ def _grid_sample_func(data, grid):
out[3] = int64(grid[3])
return out


@reg.register_shape_func("image.grid_sample", False)
def grid_sample_func(attrs, inputs, _):
"""
Expand Down

0 comments on commit 93ed8e0

Please sign in to comment.