Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2047,11 +2047,44 @@ StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) {
return data_sinfo;
}

InferLayoutOutput InferLayoutFlip(
    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
    const VarLayoutMap& var_layout_map) {
  // flip has no layout preference of its own, so a user-supplied desired
  // layout for this op is rejected up front.
  ICHECK(NoDesiredLayout(call, desired_layouts));

  const auto* flip_attrs = call->attrs.as<FlipAttrs>();
  ICHECK(flip_attrs != nullptr) << "Invalid Call";
  const auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
  ICHECK(input_sinfo != nullptr) << "Invalid Call";
  ICHECK(!input_sinfo->IsUnknownNdim()) << "Only support static ndim for now";

  const int rank = input_sinfo->ndim;
  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);

  // A layout whose ndim differs from its primal ndim is sub-indexed
  // (e.g. packed forms); fall back to the initial layout in that case,
  // since the single flip axis only maps cleanly through primal axes.
  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
    layout = LayoutDecision(InitialLayout(rank));
  }

  // Normalize a negative axis before translating it through the layout.
  int axis = flip_attrs->axis.IntValue();
  if (axis < 0) {
    axis += rank;
  }

  // Locate where the original axis landed in the propagated layout.
  const int mapped_axis = FindAxis(layout->layout, axis);
  ICHECK_GE(mapped_axis, 0) << "Failed to find transformed axis";

  // Rewrite the attrs so the op flips the same logical dimension in the
  // new layout.
  ObjectPtr<FlipAttrs> updated_attrs = ffi::make_object<FlipAttrs>(*flip_attrs);
  updated_attrs->axis = Integer(mapped_axis);

  return InferLayoutOutput({layout}, {layout}, Attrs(updated_attrs));
}

// Global registration for relax.flip: one tensor input, struct info inferred
// by InferStructInfoFlip, layout propagation handled by InferLayoutFlip, and
// marked pure (no side effects).
TVM_REGISTER_OP("relax.flip")
    .set_attrs_type<FlipAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlip)
    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutFlip)
    .set_attr<Bool>("FPurity", Bool(true));

/* relax.gather_elements */
Expand Down
44 changes: 44 additions & 0 deletions tests/python/relax/test_transform_convert_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5283,5 +5283,49 @@ def main(
verify(Input, Expected)


def test_conv2d_flip():
    """ConvertLayout must remap flip's axis when conv2d output moves to NHWC.

    Input flips axis 1 (the channel axis of the NCHW conv2d output); after the
    pass rewrites conv2d to NHWC, the channel axis lands at position 3, so the
    expected module flips axis 3 and permutes back to NCHW at the boundary.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.flip(gv, axis=1)
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                # axis 1 under NCHW (channels) maps to axis 3 under NHWC
                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.flip(gv, axis=3)
                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv2)
            return gv2

    verify(Input, Expected)


# Run every test in this file via TVM's pytest wrapper when executed directly.
if __name__ == "__main__":
    tvm.testing.main()
Loading