[Relay][Topi] Use SimplifyInference for L2 Normalization. #4795

Merged (1 commit) on Jan 31, 2020
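
This PR removes the dedicated TOPI compute and schedules for nn.l2_normalize and instead expands the op into existing elementwise and reduction ops inside the SimplifyInference pass, the same way batch_norm, layer_norm and instance_norm are already handled. A minimal sketch of observing the rewrite from Python (not part of the PR; assumes a recent TVM Python API, and eps/axis are example values):

```python
import tvm
from tvm import relay

# Build a small function containing nn.l2_normalize.
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
func = relay.Function([x], relay.nn.l2_normalize(x, eps=1e-12, axis=[1]))
mod = tvm.IRModule.from_expr(func)

# SimplifyInference needs type information, so run InferType first.
mod = relay.transform.InferType()(mod)
mod = relay.transform.SimplifyInference()(mod)

# The printed module should no longer contain nn.l2_normalize; the call is
# expanded into multiply / sum / maximum / sqrt / divide.
print(mod)
```

Because SimplifyInference runs as part of the normal relay.build pipeline, no backend sees an nn.l2_normalize call after this change, which is why the TOPI compute, the CUDA/ROCm schedules, and the op registration below can all be deleted.
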
16 changes: 0 additions & 16 deletions python/tvm/relay/op/nn/_nn.py
@@ -611,22 +611,6 @@ def schedule_lrn(attrs, outs, target):
reg.register_pattern("nn.lrn", OpPattern.OPAQUE)


# l2_normalize
@reg.register_compute("nn.l2_normalize")
def compute_l2_normalize(attrs, inputs, out_dtype, target):
"""Compute definition of l2 normalize"""
return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)]


@reg.register_schedule("nn.l2_normalize")
def schedule_l2_normalize(attrs, outs, target):
"""Schedule definition of l2 normalize"""
with target:
return topi.generic.schedule_l2_normalize(outs)


reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)

# upsampling
reg.register_schedule("nn.upsampling", reg.schedule_injective)

5 changes: 5 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -397,6 +397,11 @@ inline Expr Divide(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}

inline Expr Maximum(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("maximum");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}

inline Expr ZerosLike(Expr e) {
static const Op& op = Op::Get("zeros_like");
return CallNode::make(op, {e});
27 changes: 22 additions & 5 deletions src/relay/pass/simplify_inference.cc
@@ -124,13 +124,26 @@ Expr InstanceNormToInferUnpack(const Attrs attrs,
return out;
}

Expr L2NormToInferUnpack(const Attrs attrs, Expr data) {
const auto param = attrs.as<L2NormalizeAttrs>();
CHECK(param);

Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast<float>(param->eps));

Expr sqr = Multiply(data, data);
Expr sum = Maximum(Sum(sqr, param->axis, true, false), epsilon);
Expr sqrt = Sqrt(sum);
return Divide(data, sqrt);
}

class InferenceSimplifier : public ExprMutator {
public:
InferenceSimplifier()
: batch_norm_op_(Op::Get("nn.batch_norm")),
dropout_op_(Op::Get("nn.dropout")),
instance_norm_op_(Op::Get("nn.instance_norm")),
layer_norm_op_(Op::Get("nn.layer_norm")) {}
layer_norm_op_(Op::Get("nn.layer_norm")),
l2_norm_op_(Op::Get("nn.l2_normalize")) {}

Expr VisitExpr_(const TupleGetItemNode* n) final {
Expr new_e = ExprMutator::VisitExpr_(n);
@@ -155,12 +168,15 @@ class InferenceSimplifier : public ExprMutator {
ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
} else if (n->op == layer_norm_op_) {
const auto* call = new_n.as<CallNode>();
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
n->args[0]->checked_type());
} else if (n->op == instance_norm_op_) {
const auto* call = new_n.as<CallNode>();
return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
n->args[0]->checked_type());
} else if (n->op == l2_norm_op_) {
const auto* call = new_n.as<CallNode>();
return L2NormToInferUnpack(call->attrs, call->args[0]);
}
return new_n;
}
@@ -173,6 +189,7 @@ class InferenceSimplifier : public ExprMutator {
const Op& dropout_op_;
const Op& instance_norm_op_;
const Op& layer_norm_op_;
const Op& l2_norm_op_;
std::unordered_map<Expr, Type, ObjectHash, ObjectEqual> ty_map_;
};

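
The new L2NormToInferUnpack above rewrites nn.l2_normalize(data) as data / sqrt(max(sum(data * data, axis, keepdims=True), eps)), where eps guards against division by zero for all-zero slices. For illustration only (not code from the PR), the same computation in NumPy, with parameter names mirroring L2NormalizeAttrs:

```python
import numpy as np

def l2_normalize_ref(data, eps, axis):
    """NumPy reference for the decomposition built by L2NormToInferUnpack."""
    sqr = data * data
    ssum = np.maximum(sqr.sum(axis=tuple(axis), keepdims=True), eps)
    return data / np.sqrt(ssum)
```
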
49 changes: 0 additions & 49 deletions topi/include/topi/cuda/normalization.h
@@ -71,55 +71,6 @@ inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
return s;
}

/*!
* \brief Create a CUDA schedule for L2 normalization
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
Schedule s = create_schedule(out_ops);

std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_injective(op->tag) || op->tag == "l2_normalize") {
if (!detail::contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
}
} else if (op->tag == "comm_reduce") {
ScheduleReduce(target, op, s, false);
for (auto tensor : op->InputTensors()) {
traverse(tensor->op);
}
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};

traverse(outs[0]->op);
int num_thread = 64;
Tensor l2_normalize = outs[0];
IterVar block_x = tvm::te::thread_axis(Range(), "blockIdx.x");
IterVar thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
IterVar xto, xti;
s[l2_normalize].split_by_nparts(l2_normalize->op.as<ComputeOpNode>()->axis[1],
num_thread, &xto, &xti);
s[l2_normalize].bind(l2_normalize->op.as<ComputeOpNode>()->axis[0], block_x);
s[l2_normalize].bind(xto, thread_x);
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_NORMALIZATION_H_
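
With the hand-written CUDA schedule above removed, the expanded graph is covered by the stock reduction and injective schedules. A sketch of compiling for a GPU target (assumes a CUDA-enabled TVM build and a recent Python API; this mirrors the earlier sketch and is not part of the PR):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
mod = tvm.IRModule.from_expr(
    relay.Function([x], relay.nn.l2_normalize(x, eps=1e-12, axis=[1])))

# SimplifyInference runs inside relay.build, so by the time scheduling happens
# only sum/elementwise ops remain and no schedule_l2_normalize is needed.
with tvm.transform.PassContext(opt_level=3):
    cuda_lib = relay.build(mod, target="cuda")
```
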
72 changes: 0 additions & 72 deletions topi/include/topi/nn/l2_normalize.h

This file was deleted.

11 changes: 0 additions & 11 deletions topi/include/topi/rocm/normalization.h
@@ -44,17 +44,6 @@ inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_lrn(target, outs);
}

/*!
* \brief Create a rocm schedule for L2 Normalization
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_l2_normalize(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_NORMALIZATION_H_
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/__init__.py
@@ -31,7 +31,7 @@
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_adaptive_pool
from .nn import schedule_lrn, schedule_l2_normalize
from .nn import schedule_lrn
from .batch_matmul import schedule_batch_matmul
from .vision import *
from . import ssd
19 changes: 0 additions & 19 deletions topi/python/topi/cuda/nn.py
@@ -40,22 +40,3 @@ def schedule_lrn(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_lrn(cpp_target, outs)

@generic.schedule_l2_normalize.register(["cuda"])
def schedule_l2_normalize(outs):
"""Schedule for L2 normalize

Parameters
----------
outs: Array of Tensor
The computation graph description of L2 normalize
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_l2_normalize(cpp_target, outs)
18 changes: 0 additions & 18 deletions topi/python/topi/generic/nn.py
@@ -649,24 +649,6 @@ def schedule_lrn(outs):
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)

@tvm.target.generic_func
def schedule_l2_normalize(outs):
"""Schedule for l2 normalize

Parameters
----------
outs: Array of Tensor
The computation graph description of l2 normalize
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)

@tvm.target.generic_func
def schedule_sparse_dense(outs):
1 change: 0 additions & 1 deletion topi/python/topi/nn/__init__.py
@@ -38,7 +38,6 @@
from .local_response_norm import *
from .bitserial_conv2d import *
from .bitserial_dense import *
from .l2_normalize import *
from .batch_matmul import *
from .sparse import *
from .pad import *
45 changes: 0 additions & 45 deletions topi/python/topi/nn/l2_normalize.py

This file was deleted.

6 changes: 0 additions & 6 deletions topi/python/topi/rocm/nn.py
@@ -26,9 +26,3 @@ def schedule_lrn(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.rocm.schedule_lrn(cpp_target, outs)

@generic.schedule_l2_normalize.register(["rocm", "gpu"])
def schedule_l2_normalize(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.rocm.schedule_l2_normalize(cpp_target, outs)
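
With the CUDA and ROCm wrappers above removed, correctness rests entirely on the decomposition produced by SimplifyInference. A hypothetical end-to-end check against the NumPy reference (assumes the graph_executor API of recent TVM; the PR's actual tests may be structured differently):

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

x = relay.var("x", shape=(1, 3, 8, 8), dtype="float32")
func = relay.Function([x], relay.nn.l2_normalize(x, eps=1e-12, axis=[1]))
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(tvm.IRModule.from_expr(func), target="llvm")

data = np.random.uniform(size=(1, 3, 8, 8)).astype("float32")
rt = graph_executor.GraphModule(lib["default"](tvm.cpu()))
rt.set_input("x", data)
rt.run()

# Compare against data / sqrt(max(sum(data * data, axis=1, keepdims=True), eps)).
ref = data / np.sqrt(np.maximum((data * data).sum(axis=1, keepdims=True), 1e-12))
np.testing.assert_allclose(rt.get_output(0).numpy(), ref, rtol=1e-5)
```
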