Merge branch 'gmh' of https://github.com/Jittor/jittor
Showing 4 changed files with 189 additions and 4 deletions.
@@ -0,0 +1,111 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
//     Meng-Hao Guo <guomenghao1997@gmail.com>
//     Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

// cublas_batched_matmul_op.cc
#include "var.h"

#include "cublas_batched_matmul_op.h"
#include "cublas_warper.h"

using namespace std;

namespace jittor {

#ifndef JIT

static auto make_cublas_batched_matmul = get_op_info("cublas_batched_matmul")
    .get_constructor<VarPtr, Var*, Var*, bool, bool>();

CublasBatchedMatmulOp::CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
    : a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
    // TODO: support int8 * int8
    ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "both inputs must be float type";
    // TODO: support different input types
    ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "types of the two inputs should be the same";
    c = create_output(nullptr, a->dtype());
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
}

VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    // a [b,n,m], b [b,m,k], c [b,n,k]
    // c = a*b
    if (v_index == 0) {
        // da = dc*b^T
        return make_cublas_batched_matmul(dout, b, trans_a^0, trans_b^1);
    } else {
        // db = a^T*dc
        return make_cublas_batched_matmul(a, dout, trans_a^1, trans_b^0);
    }
}

void CublasBatchedMatmulOp::infer_shape() {
    ASSERTop(a->shape.size(),==,3);
    ASSERTop(b->shape.size(),==,3);

    int batch_size = a->shape[0], n = a->shape[1], m = a->shape[2];
    int batch_size_ = b->shape[0], m_ = b->shape[1], k = b->shape[2];

    ASSERTop(batch_size,==,batch_size_);
    if (trans_a) {
        swap(n, m);
    }
    if (trans_b) {
        swap(m_, k);
    }
    ASSERTop(m,==,m_);

    c->set_shape({batch_size, n, k});
}

void CublasBatchedMatmulOp::jit_prepare() {
    add_jit_define("T", a->dtype());
    add_jit_define("Trans_a", trans_a ? "T" : "N");
    add_jit_define("Trans_b", trans_b ? "T" : "N");
    add_jit_define("op", a->dtype().dsize() == 4 ? "S" : "D");
}

#else // JIT
#ifdef JIT_cuda
#pragma clang diagnostic ignored "-Wtautological-compare"
void CublasBatchedMatmulOp::jit_run() {
    cublasHandle_t& handle_ = cublas_handle;
    const T alpha = 1.0f;
    const T beta = 0.0f;

    const auto& as = a->shape;
    const auto& bs = b->shape;
    auto batch_size = as[0];
    auto n = as[1];
    auto m = as[2];
    auto k = bs[2];
    if ('@Trans_a'=='T') {
        n = as[2];
        m = as[1];
    }
    if ('@Trans_b'=='T') {
        k = bs[1];
    }
    // a: [b,n,m], b: [b,m,k], c: [b,n,k]
    checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
        CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
        k, n, m, &alpha,
        b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
        a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
        c->ptr<T>(), k, k * n,
        batch_size));
}
#endif
#endif // JIT

} // jittor
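Note on the operand order in jit_run: cuBLAS assumes column-major storage while Jittor buffers are row-major, so the call passes b before a and swaps the transpose flags. Computing b^T * a^T in column-major leaves memory that, read back row-major, is exactly c = a * b; this is also why the leading dimensions are k for b, m for a, and k for c in the non-transposed case (the row lengths of the row-major buffers). A minimal NumPy sketch of that identity (illustration only, not part of the commit):

import numpy as np

n, m, k = 3, 4, 5
a = np.random.rand(n, m).astype(np.float32)   # row-major (n, m)
b = np.random.rand(m, k).astype(np.float32)   # row-major (m, k)

# A column-major routine reading the same buffers sees a^T and b^T.
a_col = a.ravel().reshape(m, n, order='F')    # equals a.T
b_col = b.ravel().reshape(k, m, order='F')    # equals b.T

c_col = b_col @ a_col                         # (k, n) product in column-major terms
c_row = c_col.ravel(order='F').reshape(n, k)  # reinterpret that memory as row-major

assert np.allclose(c_row, a @ b)              # same bytes as the desired c = a @ b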
@@ -0,0 +1,30 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
//     Meng-Hao Guo <guomenghao1997@gmail.com>
//     Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

// cublas_batched_matmul_op.h
#pragma once
#include "op.h"
#include "ops/op_register.h"
#include "var.h"

namespace jittor {

struct CublasBatchedMatmulOp : Op {
    Var* a, * b, * c;
    bool trans_a, trans_b;
    CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b);

    const char* name() const override { return "cublas_batched_matmul"; }
    void infer_shape() override;
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};

} // jittor
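The grad() declared here is implemented above with the usual matmul vector-Jacobian product: for c = a*b with no transposes, da = dc * b^T and db = a^T * dc, which is what the constructor calls in grad() return when trans_a and trans_b are false. A quick NumPy finite-difference check of those formulas for the plain batched case (illustration only, the array names are ours):

import numpy as np

rng = np.random.default_rng(0)
batch, n, m, k = 2, 3, 4, 5
a = rng.standard_normal((batch, n, m))
b = rng.standard_normal((batch, m, k))
dc = rng.standard_normal((batch, n, k))        # upstream gradient w.r.t. c

# VJP formulas used for the non-transposed case
da = dc @ b.transpose(0, 2, 1)                 # da = dc * b^T
db = a.transpose(0, 2, 1) @ dc                 # db = a^T * dc

# Finite-difference check on a single entry of a
eps = 1e-6
idx = (0, 1, 2)
a2 = a.copy(); a2[idx] += eps
numeric = ((a2 @ b - a @ b) * dc).sum() / eps
assert np.isclose(numeric, da[idx], rtol=1e-4)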
@@ -0,0 +1,44 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import nn
import unittest
import numpy as np

class TestBMM(unittest.TestCase):
    @unittest.skipIf(not jt.has_cuda, "No cuda found")
    def test_bmm_cuda(self):
        def check(batch, n, m, k):
            def calc(use_cuda, a, b, mask):
                jt.flags.use_cuda = use_cuda
                a = jt.array(a)
                b = jt.array(b)
                mask = jt.array(mask)
                c = nn.bmm(a, b)
                da, db = jt.grad(c*mask, [a, b])
                return c.data, da.data, db.data
            mask = np.random.rand(batch, n, k).astype("float32")
            a = np.random.rand(batch, n, m).astype("float32")
            b = np.random.rand(batch, m, k).astype("float32")
            a1, a2, a3 = calc(0, a, b, mask)
            b1, b2, b3 = calc(1, a, b, mask)
            assert np.allclose(a1, b1)
            assert np.allclose(a2, b2)
            assert np.allclose(a3, b3)
        check(10, 3, 4, 5)
        check(10, 8, 8, 8)
        check(10, 8, 1, 8)
        check(10, 8, 8, 1)
        check(10, 1, 8, 8)
        check(1, 7, 8, 8)

if __name__ == "__main__":
    unittest.main()
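The test drives the new op through nn.bmm, comparing CPU and CUDA results and gradients. Outside the test harness the call pattern is the same; a minimal usage sketch (assumes a CUDA build of Jittor, shapes are arbitrary):

import jittor as jt
from jittor import nn

jt.flags.use_cuda = 1            # route execution through the CUDA backend
a = jt.random((16, 32, 64))      # [batch, n, m]
b = jt.random((16, 64, 8))       # [batch, m, k]
c = nn.bmm(a, b)                 # [batch, n, k]; per this commit, backed by cublas_batched_matmul on GPU
print(c.shape)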