[TOPI] Use cblas for dense and batch_matmul when "cblas" is in the target libraries #3787

Merged · 7 commits · Aug 21, 2019
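In short, this PR routes topi.nn.dense and topi.nn.batch_matmul to the cblas contrib implementations (cblas.matmul / cblas.batch_matmul) whenever the active target lists cblas among its libraries, and keeps the existing TOPI computes otherwise. A hedged usage sketch, assuming a TVM build with a BLAS library enabled; the shapes and variable names below are illustrative and not part of the diff:

import tvm
import topi

# x: [batch, M, K], y: [batch, N, K] -> out: [batch, M, N]
x = tvm.placeholder((8, 32, 64), name="x", dtype="float32")
y = tvm.placeholder((8, 48, 64), name="y", dtype="float32")

target = tvm.target.create("llvm -libs=cblas")
with target:
    out = topi.nn.batch_matmul(x, y)               # dispatches to cblas.batch_matmul
    s = topi.generic.schedule_batch_matmul([out])  # extern schedule for the library call
func = tvm.build(s, [x, y, out], target=target)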
3 changes: 2 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -73,7 +73,8 @@ def schedule_dense(attrs, outputs, target):
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
"""Compute definition of batch_matmul"""
return [topi.nn.batch_matmul(inputs[0], inputs[1])]
with target:
return [topi.nn.batch_matmul(inputs[0], inputs[1])]


@reg.register_schedule("nn.batch_matmul")
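The `with target:` wrapper in the relay compute matters because the new x86 overrides query tvm.target.current_target() to decide whether to take the cblas path; without an active target context that query would not see the user's target. A minimal sketch of the behaviour being relied on (illustrative, not from this PR):

import tvm

with tvm.target.create("llvm -libs=cblas"):
    cur = tvm.target.current_target()
    assert "cblas" in cur.libs   # the check performed by the new compute and schedule paths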
25 changes: 22 additions & 3 deletions topi/python/topi/nn/batch_matmul.py
@@ -20,8 +20,7 @@
import tvm
from ..util import get_const_tuple


def batch_matmul(x, y):
def batch_matmul_default(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.

@@ -30,7 +29,7 @@ def batch_matmul(x, y):
x : tvm.Tensor
3-D with shape [batch, M, K]

y : tvm.TEnsor
y : tvm.Tensor
3-D with shape [batch, N, K]

Returns
@@ -49,3 +48,23 @@ def batch_matmul(x, y):
return tvm.compute((batch, M, N),
lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
tag='batch_matmul')

@tvm.target.generic_func
def batch_matmul(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.

Parameters
----------
x : tvm.Tensor
3-D with shape [batch, M, K]

y : tvm.Tensor
3-D with shape [batch, N, K]

Returns
-------
output : tvm.Tensor
3-D with shape [batch, M, N]
"""
return batch_matmul_default(x, y)
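The wrapper above is a tvm.target.generic_func, so backends can register target-specific overrides (the x86 one in the next file) while everything else falls back to the default compute. A minimal sketch of that dispatch mechanism, with illustrative names not taken from this PR:

import tvm

@tvm.target.generic_func
def my_matmul(x, y):
    # fallback used when no override matches the current target's keys
    return "default"

@my_matmul.register(["cpu"])
def my_matmul_cpu(x, y):
    # chosen whenever the active target carries the "cpu" key (e.g. "llvm")
    return "cpu override"

with tvm.target.create("llvm -libs=cblas"):
    print(my_matmul(None, None))   # -> "cpu override"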
28 changes: 28 additions & 0 deletions topi/python/topi/x86/batch_matmul.py
@@ -18,10 +18,34 @@
"""x86 batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
from tvm.contrib import cblas

from topi.nn import batch_matmul, batch_matmul_default
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

@batch_matmul.register(["cpu"])
def batch_matmul_x86(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.

Parameters
----------
x : tvm.Tensor
3-D with shape [batch, M, K]

y : tvm.Tensor
3-D with shape [batch, N, K]

Returns
-------
output : tvm.Tensor
3-D with shape [batch, M, N]
"""
target = tvm.target.current_target()
if "cblas" in target.libs:
return cblas.batch_matmul(x, y, False, True)
return batch_matmul_default(x, y)

@generic.schedule_batch_matmul.register(["cpu"])
def schedule_batch_matmul(outs):
@@ -38,6 +62,10 @@ def schedule_batch_matmul(outs):
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target()
if "cblas" in target.libs:
return generic.schedule_extern(outs)

s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
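For reference, the (False, True) flags passed to cblas.batch_matmul are (transa, transb): y is laid out as [batch, N, K] and is consumed transposed, matching the default TOPI compute. A rough NumPy equivalent of the result (illustrative only):

import numpy as np

def batch_matmul_ref(x, y):
    # x: [batch, M, K], y: [batch, N, K] -> [batch, M, N]
    return np.einsum("bmk,bnk->bmn", x, y)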
107 changes: 62 additions & 45 deletions topi/python/topi/x86/dense.py
@@ -20,6 +20,7 @@
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas

from .util import get_fp32_len
from .. import generic, tag, nn
@@ -40,29 +41,33 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
# Declare dense compute with packing weight into cache-friendly layout
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_y", batch, num_outputs=3)
cfg.define_split("tile_x", out_dim, num_outputs=3)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_pack_config(cfg, batch, out_dim, in_dim)

packw_bn = cfg["tile_x"].size[-1]
packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
packw = tvm.compute(packw_shape,
lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")

k = tvm.reduce_axis((0, in_dim), name="k")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(
data[y, k].astype(out_dtype) *
packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
axis=k),
tag="dense_pack")
target = tvm.target.current_target()
if "cblas" in target.libs:
C = cblas.matmul(data, weight, False, True)
else:
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_y", batch, num_outputs=3)
cfg.define_split("tile_x", out_dim, num_outputs=3)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_pack_config(cfg, batch, out_dim, in_dim)

packw_bn = cfg["tile_x"].size[-1]
packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
packw = tvm.compute(packw_shape,
lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")

k = tvm.reduce_axis((0, in_dim), name="k")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(
data[y, k].astype(out_dtype) *
packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
axis=k),
tag="dense_pack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
@@ -72,28 +77,32 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
# Declare dense compute without packing weight
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_x", out_dim, num_outputs=2)
cfg.define_split("tile_y", batch, num_outputs=2)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_nopack_config(cfg, batch, out_dim, in_dim)

vec = cfg["tile_k"].size[-1]
k = tvm.reduce_axis((0, in_dim // vec), "k")
CC = tvm.compute((batch, out_dim, vec),
lambda z, y, x: tvm.sum(
data[z, k * vec + x].astype(out_dtype) *
weight[y, k * vec + x].astype(out_dtype), axis=k))

kk = tvm.reduce_axis((0, vec), "kk")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
tag="dense_nopack")
target = tvm.target.current_target()
if "cblas" in target.libs:
C = cblas.matmul(data, weight, False, True)
else:
if out_dtype is None:
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_x", out_dim, num_outputs=2)
cfg.define_split("tile_y", batch, num_outputs=2)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_nopack_config(cfg, batch, out_dim, in_dim)

vec = cfg["tile_k"].size[-1]
k = tvm.reduce_axis((0, in_dim // vec), "k")
CC = tvm.compute((batch, out_dim, vec),
lambda z, y, x: tvm.sum(
data[z, k * vec + x].astype(out_dtype) *
weight[y, k * vec + x].astype(out_dtype), axis=k))

kk = tvm.reduce_axis((0, vec), "kk")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
tag="dense_nopack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
@@ -116,6 +125,10 @@ def _callback(op):

@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
def _schedule_dense_pack(cfg, outs):
target = tvm.target.current_target()
if "cblas" in target.libs:
return generic.schedule_extern(outs)

s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
@@ -127,6 +140,10 @@ def _callback(op):

@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
def _schedule_dense_nopack(cfg, outs):
target = tvm.target.current_target()
if "cblas" in target.libs:
return generic.schedule_extern(outs)

s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
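A hedged end-to-end sketch of exercising the dense cblas path through Relay, assuming a TVM build with USE_BLAS enabled (API as of the 0.6-era codebase this PR targets; shapes are illustrative):

import tvm
from tvm import relay

data = relay.var("data", shape=(16, 128), dtype="float32")
weight = relay.var("weight", shape=(64, 128), dtype="float32")
func = relay.Function([data, weight], relay.nn.dense(data, weight))

# "cblas" in the target libs routes the dense computes to cblas.matmul
# and the dense schedules to generic.schedule_extern.
target = "llvm -libs=cblas"
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(relay.Module.from_expr(func), target=target)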