Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] FIFO buffer op, to accelerate sequence modeling with dilated convolutions #4039

Merged
merged 6 commits into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,15 @@ struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
TVM_DECLARE_ATTRS(SparseTransposeAttrs, "relay.attrs.SparseTransposeAttrs") {}
};

/*! \brief Attributes for FIFO buffer operator */
struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {
  int axis;  // axis along which the buffer is shifted

  TVM_DECLARE_ATTRS(FIFOBufferAttrs, "relay.attrs.FIFOBufferAttrs") {
    // describe() added for consistency with the other attrs nodes in this
    // header; surfaces in the generated op documentation.
    TVM_ATTR_FIELD(axis)
      .describe("The axis along which the data and the buffer are concatenated.")
      .set_default(0);
  }
};

/*! \brief Attributes for upsampling operator */
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
int scale;
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,12 @@ def _mx_one_hot(inputs, attrs):
return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)


def _mx_contrib_fifo_buffer(inputs, attrs):
    """Convert MXNet contrib_fifo_buffer to relay nn.fifo_buffer."""
    return _op.nn.fifo_buffer(*inputs, axis=attrs.get_int('axis'))


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -1189,6 +1195,7 @@ def _mx_one_hot(inputs, attrs):
# TODO(tvm-tvm): support all operators.
#
# "broadcast_to",
"contrib_fifo_buffer" : _mx_contrib_fifo_buffer,
zhiics marked this conversation as resolved.
Show resolved Hide resolved
}

# set identity list
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ def schedule_dense(attrs, outputs, target):
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.fifo_buffer")
def compute_fifo_buffer(attrs, inputs, out_type, target):
    """Compute definition of fifo_buffer."""
    buffering_axis = attrs.get_int("axis")
    return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=buffering_axis)]

@reg.register_schedule("nn.fifo_buffer")
def schedule_fifo_buffer(attrs, outputs, target):
    """Schedule definition of fifo_buffer (injective)."""
    with target:
        return topi.generic.schedule_injective(outputs)


# Opaque: the op is scheduled on its own, never fused into neighbours.
reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)


# batch_matmul
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,36 @@ def dense(data, weight, units=None, out_dtype=""):
return _make.dense(data, weight, units, out_dtype)


def fifo_buffer(data, buffer, axis):
    """FIFO buffer op, useful to compute over sliding-window input.

    Compute equivalent of

    .. code-block:: python

        concat(buffer, data, axis=axis)
        .slice_axis(axis=axis,
                    begin=data.shape[axis],
                    end=data.shape[axis]+buffer.shape[axis])

    Useful for

    * Encoding explicit re-use of computation in convolution ops operated
      on a sliding window input
    * Implementing a FIFO queue to cache intermediate results, e.g. as in
      Fast WaveNet

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data
    buffer : tvm.relay.Expr
        Previous value of the FIFO buffer
    axis : int
        Specify which axis should be used for buffering

    Returns
    -------
    result : tvm.relay.Expr
        Updated value for the buffer
    """
    return _make.fifo_buffer(data, buffer, axis)


def relu(data):
"""Rectified linear unit.

Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class DenseAttrs(Attrs):
"""Attributes for nn.dense"""


@register_relay_attr_node
class FIFOBufferAttrs(Attrs):
    """Attributes for nn.fifo_buffer (see FIFOBufferAttrs in relay/attrs/nn.h)."""


@register_relay_attr_node
class UpSamplingAttrs(Attrs):
"""Attributes for nn.upsampling"""
Expand Down
67 changes: 67 additions & 0 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,73 @@ RELAY_REGISTER_OP("nn.bias_add")
});


// relay.nn.fifo_buffer
TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs);

// Builds a CallNode invoking nn.fifo_buffer with the given buffering axis.
Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) {
  static const Op& fifo_buffer_op = Op::Get("nn.fifo_buffer");
  auto param = make_node<FIFOBufferAttrs>();
  param->axis = axis;
  return CallNode::make(fifo_buffer_op, {input, buffer}, Attrs(param), {});
}

// Type relation for nn.fifo_buffer.
// types = {input, buffer, output}; deduces output = buffer's type.
bool FIFOBufferRel(const Array<Type>& types,
                   int num_inputs,
                   const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* input = types[0].as<TensorTypeNode>();
  const auto* buffer = types[1].as<TensorTypeNode>();
  const FIFOBufferAttrs* param = attrs.as<FIFOBufferAttrs>();
  if (input == nullptr || buffer == nullptr) {
    // Argument types not yet inferred; ask the solver to retry later.
    return false;
  }
  CHECK(param != nullptr);
  CHECK_EQ(input->shape.size(), buffer->shape.size());

  // Normalize a negative axis to its non-negative equivalent.
  const size_t buffer_axis
    = static_cast<size_t>(param->axis < 0 ? static_cast<int>(buffer->shape.size()) + param->axis
                                          : param->axis);

  reporter->Assert(buffer_axis < buffer->shape.size());
  // All dimensions except the buffering axis must match exactly.
  for (size_t i = 0; i < buffer->shape.size(); ++i) {
    if (i != buffer_axis) {
      reporter->AssertEQ(input->shape[i], buffer->shape[i]);
    }
  }
  // NOTE(review): strict '<' here, but the TOPI implementation asserts
  // data.shape[axis] <= buffer.shape[axis] — confirm which bound is intended.
  reporter->Assert(input->shape[buffer_axis] < buffer->shape[buffer_axis]);

  // Output shape/dtype are identical to the buffer's.
  Array<tvm::Expr> oshape = buffer->shape;

  reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype));
  return true;
}

// Expose the constructor to the Python frontend (relay.op.nn._make.fifo_buffer).
TVM_REGISTER_API("relay.op.nn._make.fifo_buffer")
.set_body_typed(MakeFIFOBuffer);

// Register the operator itself: two tensor inputs, FIFOBuffer type relation.
RELAY_REGISTER_OP("nn.fifo_buffer")
.describe(R"code(FIFO buffer
Compute equivalent of

```
concat(buffer, data, axis=axis) \
.slice_axis(axis=axis, begin=data.shape[axis], end=data.shape[axis]+buffer.shape[axis])
```

Useful for
* Encoding explicit re-use of computation in convolution ops operated on a sliding window input
* Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.FIFOBufferAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "Latest input")
.add_argument("buffer", "Tensor",
              "Buffer storing latest [length_buffer] inputs")
.set_support_level(3)
.add_type_rel("FIFOBuffer", FIFOBufferRel);


// relay.nn.dense
TVM_REGISTER_NODE_TYPE(DenseAttrs);

Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from .batch_matmul import *
from .sparse import *
from .pad import *
from .fifo_buffer import *
127 changes: 127 additions & 0 deletions topi/python/topi/nn/fifo_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""FIFO buffer op"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from ..transform import concatenate, strided_slice

@tvm.tag_scope(tag=tag.INJECTIVE+",fifo_buffer")
def fifo_buffer(data, buffer, axis):
    """
    FIFO buffer op, useful to compute over sliding-window input.

    Computes the equivalent of

    .. code-block:: python

        concat(buffer, data, axis=axis)
        .slice_axis(axis=axis,
                    begin=data.shape[axis],
                    end=data.shape[axis] + buffer.shape[axis])

    Useful for

    * Encoding explicit re-use of computation in convolution ops operated
      on a sliding window input
    * Implementing a FIFO queue to cache intermediate results, e.g. as in
      Fast WaveNet

    Parameters
    ----------
    data : tvm.Tensor
        The input data
    buffer : tvm.Tensor
        Previous value of the FIFO buffer
    axis : int
        Axis along which buffering occurs; a negative value counts from
        the end, mirroring the Relay type relation

    Returns
    -------
    result : tvm.Tensor
        Updated value for the buffer
    """
    assert len(data.shape) == len(buffer.shape), \
        'buffer and data must have same number of dimensions, ' + \
        'buffer.shape = {}, data.shape = {}'.format(buffer.shape, data.shape)
    assert len(buffer.shape) >= 1, 'Zero-dimension tensor not supported'
    # Accept negative axes for consistency with the Relay FIFOBufferRel,
    # which normalizes them before reaching this compute.
    if axis < 0:
        axis += len(buffer.shape)
    assert 0 <= axis < len(buffer.shape), 'buffer axis out of range'
    for i in range(len(data.shape)):
        if i == axis:
            assert int(str(data.shape[i])) <= int(str(buffer.shape[i]))
        else:
            assert int(str(data.shape[i])) == int(str(buffer.shape[i]))

    buflen = buffer.shape[axis]
    data_size = data.shape[axis]

    # Explicitly write out the formula up to 4D, and use a concat+slice combo
    # for 5D and higher.  Each element either comes from the old buffer
    # (shifted forward by data_size) or from the newly appended data.
    if len(buffer.shape) == 1:
        return tvm.compute(buffer.shape,
                           lambda i:
                           tvm.if_then_else(i < buflen - data_size,
                                            buffer[i + data_size],
                                            data[i - buflen + data_size]),
                           name='new_buffer')
    elif len(buffer.shape) == 2:
        if axis == 0:
            return tvm.compute(buffer.shape,
                               lambda i, j:
                               tvm.if_then_else(i < buflen - data_size,
                                                buffer[i + data_size, j],
                                                data[i - buflen + data_size, j]),
                               name='new_buffer')
        if axis == 1:
            return tvm.compute(buffer.shape,
                               lambda i, j:
                               tvm.if_then_else(j < buflen - data_size,
                                                buffer[i, j + data_size],
                                                data[i, j - buflen + data_size]),
                               name='new_buffer')
        # Unreachable: axis was validated above; kept as a defensive check.
        assert False, 'Invalid value for axis; it should be less than {}'.format(len(buffer.shape))
    elif len(buffer.shape) == 3:
        if axis == 0:
            return tvm.compute(buffer.shape,
                               lambda i, j, k:
                               tvm.if_then_else(i < buflen - data_size,
                                                buffer[i + data_size, j, k],
                                                data[i - buflen + data_size, j, k]),
                               name='new_buffer')
        if axis == 1:
            return tvm.compute(buffer.shape,
                               lambda i, j, k:
                               tvm.if_then_else(j < buflen - data_size,
                                                buffer[i, j + data_size, k],
                                                data[i, j - buflen + data_size, k]),
                               name='new_buffer')
        if axis == 2:
            return tvm.compute(buffer.shape,
                               lambda i, j, k:
                               tvm.if_then_else(k < buflen - data_size,
                                                buffer[i, j, k + data_size],
                                                data[i, j, k - buflen + data_size]),
                               name='new_buffer')
        # Unreachable: axis was validated above; kept as a defensive check.
        assert False, 'Invalid value for axis; it should be less than {}'.format(len(buffer.shape))
    elif len(buffer.shape) == 4:
        if axis == 0:
            return tvm.compute(buffer.shape,
                               lambda i, j, k, l:
                               tvm.if_then_else(i < buflen - data_size,
                                                buffer[i + data_size, j, k, l],
                                                data[i - buflen + data_size, j, k, l]),
                               name='new_buffer')
        if axis == 1:
            return tvm.compute(buffer.shape,
                               lambda i, j, k, l:
                               tvm.if_then_else(j < buflen - data_size,
                                                buffer[i, j + data_size, k, l],
                                                data[i, j - buflen + data_size, k, l]),
                               name='new_buffer')
        if axis == 2:
            return tvm.compute(buffer.shape,
                               lambda i, j, k, l:
                               tvm.if_then_else(k < buflen - data_size,
                                                buffer[i, j, k + data_size, l],
                                                data[i, j, k - buflen + data_size, l]),
                               name='new_buffer')
        if axis == 3:
            return tvm.compute(buffer.shape,
                               lambda i, j, k, l:
                               tvm.if_then_else(l < buflen - data_size,
                                                buffer[i, j, k, l + data_size],
                                                data[i, j, k, l - buflen + data_size]),
                               name='new_buffer')
        # Unreachable: axis was validated above; kept as a defensive check.
        assert False, 'Invalid value for axis; it should be less than {}'.format(len(buffer.shape))
    else:
        # Implement FIFO buffer as combination of concat and slice:
        # take the trailing buflen elements of concat(buffer, data).
        begin = [0] * len(buffer.shape)
        begin[axis] = data.shape[axis]
        end = list(buffer.shape[:])
        end[axis] += data.shape[axis]
        return strided_slice(concatenate((buffer, data), axis=axis), begin=begin, end=end)
Loading