Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Jul 19, 2020
2 parents 4726cb2 + a4bb631 commit c49ebb1
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 4 deletions.
111 changes: 111 additions & 0 deletions extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Meng-Hao Guo <guomenghao1997@gmail.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************


// cublas_batched_matmul_op.cc
#include "var.h"

#include "cublas_batched_matmul_op.h"
#include "cublas_warper.h"

using namespace std;

namespace jittor {

#ifndef JIT

static auto make_cublas_batched_matmul = get_op_info("cublas_batched_matmul")
.get_constructor<VarPtr, Var*, Var*, bool, bool>();

CublasBatchedMatmulOp::CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
: a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
// TODO: support int8 * int8
ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same";
// TODO: support diffrent input type
ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "type of two inputs should be the same";
c = create_output(nullptr, a->dtype());
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
}


VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) {
// a [b,n,m] b [b,m,k], c[b,n,k]
// c = a*b
if (v_index == 0) {
// da = dc*b^T
return make_cublas_batched_matmul(dout, b, trans_a^0, trans_b^1);
} else {
// db = a^T*dc
return make_cublas_batched_matmul(a, dout, trans_a^1, trans_b^0);
}
}

void CublasBatchedMatmulOp::infer_shape(){
ASSERTop(a->shape.size(),==,3);
ASSERTop(b->shape.size(),==,3);

int batch_size = a->shape[0], n = a->shape[1], m = a->shape[2];
int batch_size_ = b->shape[0], m_ = b->shape[1], k = b->shape[2];

ASSERTop(batch_size,==,batch_size_);
if (trans_a) {
swap(n, m);
}
if (trans_b) {
swap(m_, k);
}
ASSERTop(m,==,m_);

c->set_shape({batch_size, n, k});
}

void CublasBatchedMatmulOp::jit_prepare() {
add_jit_define("T", a->dtype());
add_jit_define("Trans_a", trans_a ? "T" : "N");
add_jit_define("Trans_b", trans_b ? "T" : "N");
add_jit_define("op", a->dtype().dsize() == 4 ? "S" : "D");
}

#else // JIT
#ifdef JIT_cuda
#pragma clang diagnostic ignored "-Wtautological-compare"
void CublasBatchedMatmulOp::jit_run() {
cublasHandle_t& handle_ = cublas_handle;
const T alpha = 1.0f;
const T beta = 0.0f;

const auto& as = a->shape;
const auto& bs = b->shape;
auto batch_size = as[0];
auto n = as[1];
auto m = as[2];
auto k = bs[2];
if ('@Trans_a'=='T') {
n = as[2];
m = as[1];
}
if ('@Trans_b'=='T') {
k = bs[1];
}
// a: [b,n,m], b: [b,m,k], c: [b,n,k]
checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(), k, k * n,
batch_size));
}
#endif
#endif // JIT

} // jittor


30 changes: 30 additions & 0 deletions extern/cuda/cublas/ops/cublas_batched_matmul_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Meng-Hao Guo <guomenghao1997@gmail.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************


// cublas_batched_matmul_op.h
#pragma once
#include "op.h"
#include "ops/op_register.h"
#include "var.h"

namespace jittor {

struct CublasBatchedMatmulOp : Op {
Var* a, * b, * c;
bool trans_a, trans_b;
CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b);

const char* name() const override { return "cublas_batched_matmul"; }
void infer_shape() override;
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
DECLARE_jit_run;
};

} // jittor
8 changes: 4 additions & 4 deletions python/jittor/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ def matmul_transpose(a, b):
'''
assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-1]

if jt.flags.use_cuda:
jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
shape = list(a.shape)[:-1] + list(b.shape)
a = a.broadcast(shape, [len(shape)-2])
b = b.broadcast(shape)
return (a*b).sum(len(shape)-1)


def bmm(a, b):
''' batch matrix multiply,
shape of input a is [batch, n, m],
Expand All @@ -46,11 +46,11 @@ def bmm(a, b):
a = jt.random((batch, n, m))
b = jt.random((batch, m, k))
c = nn.bmm(a, b)
'''
assert len(a.shape) >= 2 and len(b.shape) >= 2
assert a.shape[-1] == b.shape[-2]

if jt.flags.use_cuda:
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
shape = list(a.shape) + [b.shape[-1]]
a = a.broadcast(shape, [len(shape)-1])
b = b.broadcast(shape, [len(shape)-3])
Expand Down
44 changes: 44 additions & 0 deletions python/jittor/test/test_bmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import nn
import unittest
import numpy as np

class TestBMM(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "No cuda found")
def test_bmm_cuda(self):
def check(batch, n, m, k):
def calc(use_cuda, a, b, mask):
jt.flags.use_cuda = use_cuda
a = jt.array(a)
b = jt.array(b)
mask = jt.array(mask)
c = nn.bmm(a, b)
da, db = jt.grad(c*mask, [a, b])
return c.data, da.data, db.data
mask = np.random.rand(batch, n, k).astype("float32")
a = np.random.rand(batch, n, m).astype("float32")
b = np.random.rand(batch, m, k).astype("float32")
a1,a2,a3 = calc(0, a, b, mask)
b1,b2,b3 = calc(1, a, b, mask)
assert np.allclose(a1, b1)
assert np.allclose(a2, b2)
assert np.allclose(a3, b3)
check(10,3,4,5)
check(10,8,8,8)
check(10,8,1,8)
check(10,8,8,1)
check(10,1,8,8)
check(1,7,8,8)


if __name__ == "__main__":
unittest.main()

0 comments on commit c49ebb1

Please sign in to comment.