[TOPI] Use cblas for dense and batch_matmul when "cblas" is in the target libraries (apache#3787)

* Support cblas library in dense

* Start to add support for generic batch_matmul compute

* Add x86 override for batch_matmul

* Fix linting

* Reset file

* Fix typos

* Dummy change to re-trigger CI
soiferj authored and wweic committed Sep 16, 2019
1 parent 53f9017 commit e6ef631
Showing 4 changed files with 114 additions and 50 deletions.
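
All of the changes below key off the target's library list. A minimal usage sketch of the new batch_matmul path (the "llvm -libs=cblas" target string and the shapes are illustrative; actually running the extern kernel also needs a TVM build with a BLAS library enabled):

import tvm
import topi

# Illustrative target string; "cblas" appearing in target.libs is what
# triggers the overrides added in this commit.
target = tvm.target.create("llvm -libs=cblas")

with target:
    x = tvm.placeholder((4, 128, 64), name="x")    # [batch, M, K]
    y = tvm.placeholder((4, 96, 64), name="y")     # [batch, N, K]
    out = topi.nn.batch_matmul(x, y)               # dispatches to the x86 override below
    s = topi.generic.schedule_batch_matmul([out])  # extern schedule on the cblas path
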
3 changes: 2 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -73,7 +73,8 @@ def schedule_dense(attrs, outputs, target):
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
"""Compute definition of batch_matmul"""
-    return [topi.nn.batch_matmul(inputs[0], inputs[1])]
+    with target:
+        return [topi.nn.batch_matmul(inputs[0], inputs[1])]


@reg.register_schedule("nn.batch_matmul")
25 changes: 22 additions & 3 deletions topi/python/topi/nn/batch_matmul.py
@@ -20,8 +20,7 @@
import tvm
from ..util import get_const_tuple


-def batch_matmul(x, y):
+def batch_matmul_default(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
@@ -30,7 +29,7 @@ def batch_matmul(x, y):
x : tvm.Tensor
3-D with shape [batch, M, K]
-    y : tvm.TEnsor
+    y : tvm.Tensor
3-D with shape [batch, N, K]
Returns
@@ -49,3 +48,23 @@ def batch_matmul(x, y):
return tvm.compute((batch, M, N),
lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
tag='batch_matmul')

@tvm.target.generic_func
def batch_matmul(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Parameters
----------
x : tvm.Tensor
3-D with shape [batch, M, K]
y : tvm.Tensor
3-D with shape [batch, N, K]
Returns
-------
output : tvm.Tensor
3-D with shape [batch, M, N]
"""
return batch_matmul_default(x, y)
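
Turning batch_matmul into a tvm.target.generic_func is what lets the x86 file below override it per target, and it is also why the Relay compute above is now wrapped in `with target:`; dispatch consults tvm.target.current_target(). A toy sketch of the pattern, with illustrative names that are not part of this commit:

import tvm

@tvm.target.generic_func
def my_op(x):
    return "generic implementation"

@my_op.register(["cpu"])      # chosen for any target whose keys include "cpu", e.g. llvm
def my_op_cpu(x):
    return "cpu override"

with tvm.target.create("llvm"):
    print(my_op(None))        # -> "cpu override"
print(my_op(None))            # -> "generic implementation" (no target in scope)
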
29 changes: 28 additions & 1 deletion topi/python/topi/x86/batch_matmul.py
@@ -18,10 +18,33 @@
"""x86 batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm

from tvm.contrib import cblas
from topi.nn import batch_matmul, batch_matmul_default
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

@batch_matmul.register(["cpu"])
def batch_matmul_x86(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Parameters
----------
x : tvm.Tensor
3-D with shape [batch, M, K]
y : tvm.Tensor
3-D with shape [batch, N, K]
Returns
-------
output : tvm.Tensor
3-D with shape [batch, M, N]
"""
target = tvm.target.current_target()
if "cblas" in target.libs:
return cblas.batch_matmul(x, y, False, True)
return batch_matmul_default(x, y)

@generic.schedule_batch_matmul.register(["cpu"])
def schedule_batch_matmul(outs):
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target()
if "cblas" in target.libs:
return generic.schedule_extern(outs)

s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
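
For reference, a minimal sketch of the extern kernel the override hands off to when "cblas" is in target.libs. The positional flags mirror the call above (transa=False, transb=True, since y is laid out as [batch, N, K]); running the result requires a TVM runtime built against a BLAS implementation.

import tvm
from tvm.contrib import cblas

x = tvm.placeholder((8, 32, 16), name="x")    # [batch, M, K]
y = tvm.placeholder((8, 24, 16), name="y")    # [batch, N, K]
out = cblas.batch_matmul(x, y, False, True)   # extern op, shape (8, 32, 24)
s = tvm.create_schedule(out.op)               # extern ops need no tiling; TOPI uses schedule_extern
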
107 changes: 62 additions & 45 deletions topi/python/topi/x86/dense.py
@@ -20,6 +20,7 @@
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas

from .util import get_fp32_len
from .. import generic, tag, nn
@@ -40,29 +41,33 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
# Declare dense compute with packing weight into cache-friendly layout
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_y", batch, num_outputs=3)
-    cfg.define_split("tile_x", out_dim, num_outputs=3)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_pack_config(cfg, batch, out_dim, in_dim)
-
-    packw_bn = cfg["tile_x"].size[-1]
-    packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
-    packw = tvm.compute(packw_shape,
-                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
-
-    k = tvm.reduce_axis((0, in_dim), name="k")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(
-                        data[y, k].astype(out_dtype) *
-                        packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
-                        axis=k),
-                    tag="dense_pack")
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        if out_dtype is None:
+            out_dtype = data.dtype
+        batch, in_dim = get_const_tuple(data.shape)
+        out_dim, _ = get_const_tuple(weight.shape)
+        # create tuning space
+        cfg.define_split("tile_y", batch, num_outputs=3)
+        cfg.define_split("tile_x", out_dim, num_outputs=3)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_pack_config(cfg, batch, out_dim, in_dim)
+
+        packw_bn = cfg["tile_x"].size[-1]
+        packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
+        packw = tvm.compute(packw_shape,
+                            lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
+
+        k = tvm.reduce_axis((0, in_dim), name="k")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(
+                            data[y, k].astype(out_dtype) *
+                            packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
+                            axis=k),
+                        tag="dense_pack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
@@ -72,28 +77,32 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
# Declare dense compute without packing weight
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    # create tuning space
-    cfg.define_split("tile_x", out_dim, num_outputs=2)
-    cfg.define_split("tile_y", batch, num_outputs=2)
-    cfg.define_split("tile_k", in_dim, num_outputs=2)
-    if cfg.is_fallback:
-        _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
-
-    vec = cfg["tile_k"].size[-1]
-    k = tvm.reduce_axis((0, in_dim // vec), "k")
-    CC = tvm.compute((batch, out_dim, vec),
-                     lambda z, y, x: tvm.sum(
-                         data[z, k * vec + x].astype(out_dtype) *
-                         weight[y, k * vec + x].astype(out_dtype), axis=k))
-
-    kk = tvm.reduce_axis((0, vec), "kk")
-    C = tvm.compute((batch, out_dim),
-                    lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
-                    tag="dense_nopack")
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        if out_dtype is None:
+            out_dtype = data.dtype
+        batch, in_dim = get_const_tuple(data.shape)
+        out_dim, _ = get_const_tuple(weight.shape)
+        # create tuning space
+        cfg.define_split("tile_x", out_dim, num_outputs=2)
+        cfg.define_split("tile_y", batch, num_outputs=2)
+        cfg.define_split("tile_k", in_dim, num_outputs=2)
+        if cfg.is_fallback:
+            _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
+
+        vec = cfg["tile_k"].size[-1]
+        k = tvm.reduce_axis((0, in_dim // vec), "k")
+        CC = tvm.compute((batch, out_dim, vec),
+                         lambda z, y, x: tvm.sum(
+                             data[z, k * vec + x].astype(out_dtype) *
+                             weight[y, k * vec + x].astype(out_dtype), axis=k))
+
+        kk = tvm.reduce_axis((0, vec), "kk")
+        C = tvm.compute((batch, out_dim),
+                        lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
+                        tag="dense_nopack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
tag=tag.BROADCAST)
@@ -116,6 +125,10 @@ def _callback(op):

@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
def _schedule_dense_pack(cfg, outs):
target = tvm.target.current_target()
if "cblas" in target.libs:
return generic.schedule_extern(outs)

s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
@@ -127,6 +140,10 @@ def _callback(op):

@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
def _schedule_dense_nopack(cfg, outs):
target = tvm.target.current_target()
if "cblas" in target.libs:
return generic.schedule_extern(outs)

s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
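
At the Relay level, a hedged end-to-end sketch of how nn.dense picks up the cblas compute through the target's -libs list (Relay API of this era; names such as relay.Module may differ in later releases, and TVM must be built with BLAS support so the packed cblas functions exist at run time):

import tvm
from tvm import relay

x = relay.var("x", shape=(16, 256), dtype="float32")
w = relay.var("w", shape=(512, 256), dtype="float32")
func = relay.Function([x, w], relay.nn.dense(x, w))

# "cblas" in -libs makes the dense compute above emit cblas.matmul and the
# dense schedules fall back to schedule_extern.
graph, lib, params = relay.build(relay.Module.from_expr(func),
                                 target="llvm -libs=cblas")
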
