Skip to content

Commit

Permalink
[RELAY][DYN] Dynamic broadcast_to, zeros, ones (apache#6007)
Browse files Browse the repository at this point in the history
* Dynamic BroadcastTo

* fixed lint!

* add test_one_hot() back

* add one_hot registration back

* Dynamic BroadcastTo

* fixed lint!

* add one_hot registration back

* fixed lint.. again

* fixed lint

* lint

* responding to comments

* skipping cuda in dynamic test

* skipping cuda in dynamic test

* fixed i386 test and GPU test

* lint

* starting ones and zeros

* fixed dynamic ones and zeros, wrote dyn ones and zeros test

* added static version of zeros, ones and added a check for size of types to static BroadCastToRel

* added dynamic to static pass for zeros and ones, dynamic test and dynamic to static test

* removed op_str in dyn to static pass test

* fixed lint

* fix lint hopefully

* removed import const

* removed import that was actually used

* copy all attributes from broadcast_to, ones, zeros, full

* responding to comments

* fixed build error

* finishing rebase

* fix lint

Co-authored-by: Lily Orth-Smith <lorthsmith@Lilys-MacBook-Pro.local>
  • Loading branch information
2 people authored and trevor-m committed Sep 3, 2020
1 parent a495537 commit e4d578d
Show file tree
Hide file tree
Showing 16 changed files with 384 additions and 76 deletions.
2 changes: 0 additions & 2 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ def __call__(self, args, attrs, type_args):
attrs = {}
if self.operator in (op.strided_slice,):
x = self.operator(*args)
elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
x = self.operator(*args, dtype=attrs["dtype"])
else:
x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
if isinstance(x, expr.TupleWrapper):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
# zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type):
assert len(inputs) == 1
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 0.0)]

register_broadcast_schedule("zeros")
Expand All @@ -109,7 +109,7 @@ def zeros_like_compute(attrs, inputs, output_type):
# ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type):
assert len(inputs) == 1
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 1.0)]

register_broadcast_schedule("ones")
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/dyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@

from . import _algorithm
from . import _transform
from . import _tensor
46 changes: 46 additions & 0 deletions python/tvm/relay/op/dyn/_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name, unused-argument, len-as-condition
"""Backend compiler related feature registration for dynamic ops"""

import topi

from ..op import register_shape_func, register_compute
from ..op import register_broadcast_schedule
from ..op import register_pattern, OpPattern
from .._tensor import full_shape_func, no_data_full_shape_func

# ones
@register_compute("dyn.ones")
def ones_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.full(output_type.shape, output_type.dtype, 1.0)]

register_broadcast_schedule("dyn.ones")
register_pattern("dyn.ones", OpPattern.ELEMWISE)

@register_compute("dyn.zeros")
def zeros_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.full(output_type.shape, output_type.dtype, 0.0)]

register_broadcast_schedule("dyn.zeros")
register_pattern("dyn.zeros", OpPattern.ELEMWISE)

register_shape_func("dyn.broadcast_to", True, full_shape_func)
register_shape_func("dyn.ones", True, no_data_full_shape_func)
register_shape_func("dyn.zeros", True, no_data_full_shape_func)
1 change: 1 addition & 0 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tvm.te.hybrid import script
from .. import op as _reg

_reg.register_broadcast_schedule("dyn.broadcast_to")
_reg.register_injective_schedule("dyn.reshape")
_reg.register_broadcast_schedule("dyn.tile")

Expand Down
15 changes: 12 additions & 3 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tvm.runtime import TVMContext as _TVMContext

from . import _make
from ..expr import Tuple, const
from .dyn import _make as _dyn_make
from ..expr import Tuple, Expr


# We create a wrapper function for each operator in the
Expand Down Expand Up @@ -939,8 +940,12 @@ def zeros(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Expr):
return _dyn_make.zeros(shape, dtype)
if isinstance(shape, int):
shape = [shape]
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
shape = list(shape)
return _make.zeros(shape, dtype)


Expand Down Expand Up @@ -976,8 +981,12 @@ def ones(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Expr):
return _dyn_make.ones(shape, dtype)
if isinstance(shape, int):
shape = [shape]
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
shape = list(shape)
return _make.ones(shape, dtype)


Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,12 @@ def broadcast_to(data, shape):
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, Expr):
return _dyn_make.broadcast_to(data, shape)
if isinstance(shape, int):
shape = [shape]
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
shape = list(shape)
return _make.broadcast_to(data, shape)

def broadcast_to_like(data, broadcast_type):
Expand Down
109 changes: 109 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,22 @@
*/
#include "transform.h"

#include <topi/broadcast.h>
#include <topi/transform.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/registry.h>

#include <utility>
#include <vector>

namespace tvm {
namespace relay {
namespace dyn {

/* relay.dyn.reshape */

bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, newshape, result]
Expand Down Expand Up @@ -195,6 +198,112 @@ RELAY_REGISTER_OP("dyn.tile")
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// broadcast_to operator
bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types = [data_type, broadcast_shape_type, ret_type]
CHECK_EQ(types.size(), 3);

const auto* target_shape = types[1].as<TensorTypeNode>();
DataType out_dtype = types[0].as<TensorTypeNode>()->dtype;
// rank must be static
const IntImmNode* rank = target_shape->shape[0].as<IntImmNode>();
CHECK(rank) << "Target shape must have static rank"; // rank must be static even in dyn pass
// could add support for dyn rank in futures

std::vector<IndexExpr> oshape;
for (int i = 0; i < rank->value; ++i) {
oshape.push_back(Any());
}

reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}

Expr MakeBroadCastTo(Expr data, Expr shape) {
static const Op& op = Op::Get("dyn.broadcast_to");
auto attrs = make_object<InitOpAttrs>();
return Call(op, {data, shape}, Attrs(attrs), {});
}

Array<te::Tensor> BroadCastToCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
return {topi::broadcast_to(inputs[0], out_ttype->shape)};
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.broadcast_to").set_body_typed(MakeBroadCastTo);

RELAY_REGISTER_OP("dyn.broadcast_to")
.describe(R"code(Broadcast the first input to match the shape argument.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(4)
.add_type_rel("DynamicBroadCastTo", BroadCastToRel)
.set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

// zeros and ones operator
bool InitOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types = [zeros_shape, ret_type]
CHECK_EQ(types.size(), 2);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
const auto* fill_shape = types[0].as<TensorTypeNode>();
DataType out_dtype = param->dtype;

const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
CHECK(shape_shape) << "Parameter shape must have static rank";

std::vector<IndexExpr> oshape;
for (int i = 0; i < shape_shape->value; ++i) {
oshape.push_back(Any());
}

reporter->Assign(types[1], TensorType(oshape, out_dtype));
return true;
}

Expr MakeZeros(Expr shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("dyn.zeros");
return Call(op, {shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.zeros").set_body_typed(MakeZeros);

RELAY_REGISTER_OP("dyn.zeros")
.describe(R"code(Fill array with zeros.
)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
.set_num_inputs(1)
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("DynamicInitOp", InitOpRel);

Expr MakeOnes(Expr shape, DataType dtype) {
auto attrs = make_object<InitOpAttrs>();
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("dyn.ones");
return Call(op, {shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.ones").set_body_typed(MakeOnes);

RELAY_REGISTER_OP("dyn.ones")
.describe(R"code(Fill array with ones.
)code" TVM_ADD_FILELINE)
.set_attrs_type<InitOpAttrs>()
.set_num_inputs(1)
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(3)
.add_type_rel("DynamicInitOp", InitOpRel);

} // namespace dyn
} // namespace relay
} // namespace tvm
6 changes: 3 additions & 3 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
namespace tvm {
namespace relay {

Expr MakeBroadCastTo(Expr data, Expr shape);
Expr MakeBroadCastTo(Expr data, Array<Integer> shape);

Expr MakeCast(Expr data, DataType dtype);

Expand All @@ -52,7 +52,7 @@ Expr MakeFull(Expr fill_value, Expr shape, DataType dtype);

Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);

Expr MakeOnes(Expr shape, DataType dtype);
Expr MakeOnes(Array<Integer> shape, DataType dtype);

Expr MakePad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value, String pad_mode);

Expand All @@ -76,7 +76,7 @@ Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataT

Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude);

Expr MakeZeros(Expr shape, DataType dtype);
Expr MakeZeros(Array<Integer> shape, DataType dtype);

} // namespace relay
} // namespace tvm
Expand Down
Loading

0 comments on commit e4d578d

Please sign in to comment.