Skip to content

Commit

Permalink
format change, remove [-+] in format RE
Browse files Browse the repository at this point in the history
  • Loading branch information
Menooker committed May 12, 2020
1 parent 38f98ed commit 81c344a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 24 deletions.
36 changes: 19 additions & 17 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

logger = logging.getLogger('strategy')

_NCHWc_matcher = re.compile("^NCHW[-+]?[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[-+]?[0-9]+i[-+]?[0-9]+o$")
_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$")

@schedule_injective.register("cpu")
def schedule_injective_cpu(attrs, outs, target):
Expand Down Expand Up @@ -88,13 +88,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
raise ValueError("dilation should be positive value")

if groups == 1:
if layout.startswith("NCHW"):
if layout != "NCHW":
# check if layout is NCHWxc
assert _NCHWc_matcher.match(layout)
assert _OIHWio_matcher.match(kernel_layout)
else:
assert kernel_layout == "OIHW"
def add_implementation_nchw():
if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_nchw_int8),
Expand All @@ -105,6 +99,12 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.x86.conv2d_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
name="conv2d_nchw.x86")
if layout == "NCHW":
assert kernel_layout == "OIHW"
add_implementation_nchw()
elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
add_implementation_nchw()
elif layout == "NHWC":
assert kernel_layout == "HWIO"
logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
Expand All @@ -122,14 +122,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
else:
raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout.startswith("NCHW"):
if layout != "NCHW":
# check if layout is NCHWxc
assert _NCHWc_matcher.match(layout)
assert _OIHWio_matcher.match(kernel_layout)
else:
assert kernel_layout == "OIHW"
channel_multiplier = get_const_tuple(inputs[1].shape)[1]
def add_implementation_depthwise_nchw(channel_multiplier):
if channel_multiplier == 1 and dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw),
Expand All @@ -142,6 +135,15 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.generic")
if layout == "NCHW":
assert kernel_layout == "OIHW"
channel_multiplier = get_const_tuple(inputs[1].shape)[1]
add_implementation_depthwise_nchw(channel_multiplier)
elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
kernel_shape = get_const_tuple(inputs[1].shape)
channel_multiplier = kernel_shape[1] * kernel_shape[4]
add_implementation_depthwise_nchw(channel_multiplier)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.")
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/tensor/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace tvm {
namespace relay {

extern Expr MakeReshape(Expr data,
Array<Integer> newshape);
Array<Integer> newshape);

template <typename AttrType>
bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand Down
10 changes: 6 additions & 4 deletions src/relay/transforms/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ static bool IsIntInArray(const Array<Integer>& axis, int v) {
}

static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
const Array<Integer>& axis) {
const Array<Integer>& axis) {
Array<Integer> arr;
for (size_t i = 0; i < shape.size(); i++) {
if (IsIntInArray(axis, i)) {
Expand All @@ -337,7 +337,7 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,

// if only one axis, use expand dim. Else, use reshape
static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
const Array<Integer>& axis) {
const Array<Integer>& axis) {
if (axis.size() > 1) {
return ReshapeToMatchAxis(scale, shape, axis);
} else {
Expand Down Expand Up @@ -407,8 +407,9 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
Expr scale = ReshapeOrExpandToMatchAxis(
slhs->scale, tlhs->shape, slhs->axes);
if (!scale.defined())
if (!scale.defined()) {
return Expr();
}
Expr rhs = Divide(new_args[1], scale);
rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
rnode->scale = slhs->scale;
Expand All @@ -418,8 +419,9 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
Expr scale = ReshapeOrExpandToMatchAxis(
srhs->scale, trhs->shape, srhs->axes);
if (!scale.defined())
if (!scale.defined()) {
return Expr();
}
Expr lhs = Divide(new_args[0], scale);
rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
rnode->scale = srhs->scale;
Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@

logger = logging.getLogger('topi')

_NCHWc_matcher = re.compile("^NCHW[-+]?[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[-+]?[0-9]+i[-+]?[0-9]+o$")
_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$")

@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
Expand Down

0 comments on commit 81c344a

Please sign in to comment.