[EinsumOp] Make EinsumOp support bfloat16. (PaddlePaddle#43085)
* change einsum_v2 to be the default and add a new flag: FLAG_einsum_opt=1|0

* make EinsumOp support bf16

* add unittest for BF16

* add condition for test_BF16

* fix bugs

* fix
2742195759 authored and fuyou765 committed Jun 7, 2022
1 parent 566a5c8 commit 342fa47
Showing 6 changed files with 44 additions and 16 deletions.
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/eigen/broadcast.cc
@@ -11,6 +11,7 @@
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::DefaultDevice, T, Rank> {
 template struct FUNCTOR<Eigen::DefaultDevice, T, 6>
 INSTANTIATION(EigenBroadcast, bool);
 INSTANTIATION(EigenBroadcast, dtype::float16);
+INSTANTIATION(EigenBroadcast, dtype::bfloat16);
 INSTANTIATION(EigenBroadcast, float);
 INSTANTIATION(EigenBroadcast, double);
 INSTANTIATION(EigenBroadcast, int);
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/eigen/broadcast.cu
@@ -11,6 +11,7 @@
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

@@ -73,6 +74,7 @@ struct EigenBroadcastGrad<Eigen::GpuDevice, T, Rank> {
 template struct FUNCTOR<Eigen::GpuDevice, T, 6>
 INSTANTIATION(EigenBroadcast, bool);
 INSTANTIATION(EigenBroadcast, dtype::float16);
+INSTANTIATION(EigenBroadcast, dtype::bfloat16);
 INSTANTIATION(EigenBroadcast, float);
 INSTANTIATION(EigenBroadcast, double);
 INSTANTIATION(EigenBroadcast, int);
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/einsum_grad_kernel.cu
@@ -24,4 +24,5 @@ PD_REGISTER_KERNEL(einsum_grad,
 phi::EinsumGradKernel,
 float,
 double,
-phi::dtype::float16) {}
+phi::dtype::float16,
+phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/tile_kernel.cu
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile,
 double,
 int,
 int64_t,
-phi::dtype::float16) {}
+phi::dtype::float16,
+phi::dtype::bfloat16) {}
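The Eigen broadcast and tile kernels above are instantiated for bfloat16 because einsum relies on them when one operand has to be tiled or broadcast against another (notably in the backward pass). The following is a minimal, hypothetical sketch of such a broadcasting call, not part of the diff; it assumes a CUDA 11+ build of Paddle where the bfloat16 GPU kernels are available:

import numpy as np
import paddle

# Guard: bfloat16 GPU kernels are only expected on CUDA 11 or newer builds.
if paddle.is_compiled_with_cuda() and int(paddle.version.cuda().split('.')[0]) >= 11:
    paddle.set_device("gpu")
    # "ij,j->ij" broadcasts B across the rows of A, the kind of pattern these
    # broadcast/tile kernels serve.
    A = paddle.to_tensor(np.ones((3, 4), dtype="float32")).astype(paddle.bfloat16)
    B = paddle.to_tensor(np.arange(4, dtype="float32")).astype(paddle.bfloat16)
    C = paddle.einsum("ij,j->ij", A, B)
    print(C.astype("float32").numpy())  # upcast to float32 for inspection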
32 changes: 18 additions & 14 deletions paddle/phi/kernels/impl/einsum_grad_impl.h
@@ -197,20 +197,24 @@ void EinsumGradKernel(const Context& dev_ctx,
 // release the cache tensor dTC to save memory right now. they are useless
 // now.
 cache.clear();
-*(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
-                                                   labeltype,
-                                                   labelshape,
-                                                   broadcast_dims,
-                                                   ellipsis_dims[0],
-                                                   ops[0],
-                                                   dA);
-*(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
-                                                   labeltype,
-                                                   labelshape,
-                                                   broadcast_dims,
-                                                   ellipsis_dims[1],
-                                                   ops[1],
-                                                   dB);
+if (x_grad[0]) {
+  *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
+                                                     labeltype,
+                                                     labelshape,
+                                                     broadcast_dims,
+                                                     ellipsis_dims[0],
+                                                     ops[0],
+                                                     dA);
+}
+if (x_grad[1]) {
+  *(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
+                                                     labeltype,
+                                                     labelshape,
+                                                     broadcast_dims,
+                                                     ellipsis_dims[1],
+                                                     ops[1],
+                                                     dB);
+}
 }
 }
 } // namespace phi
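The new null checks cover the case where only one operand needs a gradient: the other entry of x_grad is then a null pointer, and PerformTileAndReduction must be skipped for it. A minimal sketch of that situation from the Python side, not part of the diff:

import numpy as np
import paddle

A = paddle.to_tensor(np.array([1.0, 2.0]), dtype="float32", stop_gradient=False)
B = paddle.to_tensor(np.array([2.0, 3.0]), dtype="float32", stop_gradient=True)
C = paddle.einsum("i,i->", A, B)  # C = sum_i A_i * B_i
C.backward()
print(A.grad)  # dC/dA = B = [2., 3.]
print(B.grad)  # None: no gradient is requested for B, so its x_grad slot stays unset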
18 changes: 18 additions & 0 deletions python/paddle/fluid/tests/unittests/test_einsum_v2.py
@@ -478,5 +478,23 @@ def test_shape(self):
         self.assertEqual(C.shape, (-1, 384))
 
 
+class TestBF16(unittest.TestCase):
+    """
+    EinsumOp supports the bfloat16 type; this unit test checks that the result is correct.
+    """
+
+    def test_shape(self):
+        cuda_major = paddle.version.cuda().split('.')[0].strip()
+        if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11:
+            """ MatmulKernel supports bfloat16 only when the CUDA major version is >= 11.
+            """
+            A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16)
+            A = A.cuda()
+            B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16)
+            B = B.cuda()
+            C = paddle.einsum('i,i->', A, B)
+            self.assertEqual(C.item(), 8.0)
+
+
 if __name__ == "__main__":
     unittest.main()
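One follow-up note, not part of the diff: bfloat16 carries only about 8 bits of significand precision, so checks richer than the exact dot product above typically upcast the result to float32 and compare against a float32 reference with a loose tolerance. A hypothetical sketch:

import numpy as np
import paddle

if paddle.is_compiled_with_cuda() and int(paddle.version.cuda().split('.')[0]) >= 11:
    x = np.random.rand(4, 5).astype("float32")
    y = np.random.rand(5, 6).astype("float32")
    ref = np.einsum("ik,kj->ij", x, y)  # float32 reference result
    A = paddle.to_tensor(x).astype(paddle.bfloat16).cuda()
    B = paddle.to_tensor(y).astype(paddle.bfloat16).cuda()
    out = paddle.einsum("ik,kj->ij", A, B).astype("float32").numpy()
    np.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)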
