Skip to content

Commit

Permalink
[Unity] Relax op: set (#13990)
Browse files Browse the repository at this point in the history
This PR is about the high-level tensor computation operators in Relax.

This PR includes the set operators.

Co-authored-by: Prakalp Srivastava <prakalp@octoml.ai>
  • Loading branch information
2 people authored and tqchen committed Mar 13, 2023
1 parent 54810c3 commit 1653505
Show file tree
Hide file tree
Showing 9 changed files with 1,244 additions and 0 deletions.
62 changes: 62 additions & 0 deletions include/tvm/relax/attrs/set.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/set.h
* \brief Attributes for set operators.
*/
#ifndef TVM_RELAX_ATTRS_SET_H_
#define TVM_RELAX_ATTRS_SET_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in unique operator */
struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
bool sorted;
bool return_index;
bool return_inverse;
bool return_counts;
Optional<Integer> axis;

TVM_DECLARE_ATTRS(UniqueAttrs, "relax.attrs.UniqueAttrs") {
TVM_ATTR_FIELD(sorted).describe(
"Whether to sort the unique elements in ascending order before returning as output.");
TVM_ATTR_FIELD(return_index)
.describe(
"Whether to return an additional tensor with indices for where elements in the unique "
"tensor come from the original input.");
TVM_ATTR_FIELD(return_inverse)
.describe(
"Whether to return an additional tensor with indices for where elements in the "
"original input ended up in the returned unique list.");
TVM_ATTR_FIELD(return_counts)
.describe("Whether to return an additional tensor with counts of each unique elements");
TVM_ATTR_FIELD(axis).describe(
"The dimension to apply unique. If it is NullOpt, the unique values of the flattened input "
"is are returned.");
}
}; // struct UniqueAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_SET_H_
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@
from .index import *
from .manipulate import *
from .op_attrs import *
from .set import *
from . import builtin
from . import memory
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ class TakeAttrs(Attrs):
@tvm._ffi.register_object("relax.attrs.StridedSliceAttrs")
class StridedSliceAttrs(Attrs):
"""Attributes used in strided_slice operator"""


@tvm._ffi.register_object("relax.attrs.UniqueAttrs")
class UniqueAttrs(Attrs):
"""Attributes used for the unique operator"""
101 changes: 101 additions & 0 deletions python/tvm/relax/op/set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument
"""Set operators."""
from typing import Optional

import numpy as np # type: ignore
import tvm

from . import _ffi_api
from ..expr import Expr


def unique(
x: Expr,
sorted: bool = True,
return_index: bool = False,
return_inverse: bool = False,
return_counts: bool = False,
axis: Optional[int] = None,
) -> Expr:
"""Find the unique elements in a given tensor.
In addition, it optionally returns
- the indices of the input tensor that give the unique values;
- the indices of the unique tensor that reconstruct the input tensor;
- the number of times each unique value comes up in the input tensor.
Parameters
----------
x : relax.Expr
The input tensor.
sorted : bool
Whether to sort the unique elements in ascending order before
returning as output.
return_index : bool
Whether to return an additional tensor with indices for where elements in
the unique tensor come from the original input.
return_inverse : bool
Whether to return an additional tensor with indices for where elements in
the original input ended up in the returned unique list.
return_counts : bool
Whether to return an additional tensor with counts of each unique elements.
axis : Optional
The dimension to apply unique.
If not specified, the unique values of the flattened input are returned.
Returns
-------
ret : relax.Expr
The created relax call with
"""

return _ffi_api.unique( # type: ignore
x, sorted, return_index, return_inverse, return_counts, axis
)


@tvm.register_func("relax.run.unique")
def numpy_unique(
x: tvm.nd.array,
sorted: int,
return_index: int,
return_inverse: int,
return_counts: int,
axis: Optional[int],
) -> tvm.nd.array:
"""Returns the unique elements of the input tensor.
Uses numpy.unique to compute unique elements.
"""
import builtins

# TODO(prakalp): add support for returning a tuple when return_inverse or return_counts is True
if bool(return_index) or bool(return_inverse) or bool(return_counts):
raise NotImplementedError("missing support return_inverse or return_counts set to true")
x_numpy = x.numpy()
# TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci.
output_sorted_numpy, indices = np.unique(x_numpy, return_index=True)
if sorted:
return tvm.nd.array(output_sorted_numpy)
output_numpy = [x_numpy.flatten()[index] for index in builtins.sorted(indices, reverse=True)]
return tvm.nd.array(output_numpy)
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
shape_of,
strided_slice,
take,
unique,
)
from tvm.relax.struct_info import StructInfo
from tvm.relax.utils import args_converter
Expand Down Expand Up @@ -434,4 +435,5 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"strided_slice",
"take",
"tuple",
"unique",
]
103 changes: 103 additions & 0 deletions src/relax/op/tensor/set.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file set.cc
* \brief Relax set operators.
*/

#include "set.h"

#include <utility>
#include <vector>

namespace tvm {
namespace relax {

/* relax.unique */
TVM_REGISTER_NODE_TYPE(UniqueAttrs);

Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool return_counts,
Optional<Integer> axis) {
ObjectPtr<UniqueAttrs> attrs = make_object<UniqueAttrs>();
attrs->sorted = sorted;
attrs->return_index = return_index;
attrs->return_inverse = return_inverse;
attrs->return_counts = return_counts;
attrs->axis = std::move(axis);

static const Op& op = Op::Get("relax.unique");
return Call(op, {std::move(x)}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique);

StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<UniqueAttrs>();
if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
// Normalize the axis for sanity check purpose.
NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value);
}

int n_int_return = static_cast<int>(attrs->return_index) +
static_cast<int>(attrs->return_inverse) +
static_cast<int>(attrs->return_counts);

std::vector<StructInfo> output_sinfo;
output_sinfo.reserve(1 + n_int_return);

// unique values
if (data_sinfo->ndim == 0) {
output_sinfo.push_back(
TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), data_sinfo->dtype));
} else if (attrs->axis.defined()) {
output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim));
} else {
output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1));
}

// index, reverse and counts
TensorStructInfo int_return{nullptr};
if (data_sinfo->ndim == 0) {
int_return =
TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), DataType::Int(64));
} else {
int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1);
}
for (int i = 0; i < n_int_return; ++i) {
output_sinfo.push_back(int_return);
}

if (output_sinfo.size() == 1) {
return output_sinfo[0];
} else {
return TupleStructInfo(output_sinfo);
}
}

TVM_REGISTER_OP("relax.unique")
.set_attrs_type<UniqueAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoUnique)
.set_attr<FCallPacked>("FCallPacked", "relax.run.unique");

} // namespace relax
} // namespace tvm
40 changes: 40 additions & 0 deletions src/relax/op/tensor/set.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. Sex The NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. Sex The License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file set.h
* \brief The functions to make Relax set operator calls.
*/
#ifndef TVM_RELAX_OP_TENSOR_SET_H_
#define TVM_RELAX_OP_TENSOR_SET_H_

#include <tvm/relax/attrs/set.h>

#include "../op_common.h"

namespace tvm {
namespace relax {

Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool return_counts,
Optional<Integer> axis);

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_OP_TENSOR_SET_H_
Loading

0 comments on commit 1653505

Please sign in to comment.