Skip to content

Commit

Permalink
support ConcatBnRelu for BFloat16 (#647)
Browse files Browse the repository at this point in the history
* support ConcatBnRelu for BFloat16

* modify some details of ConcatBnRelu
  • Loading branch information
jiayisunx authored Mar 29, 2022
1 parent 11a982e commit cad3f82
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 154 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,24 @@ at::Tensor concat_bn_relu_kernel_impl(
}
#if defined(CPU_CAPABILITY_AVX512)
if (tensor_check) {
at::Tensor output = at::empty(
output_dim,
a[0].options()
.dtype(at::kFloat)
.memory_format(a[0].suggest_memory_format()));
torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast<
float>(a, bn_scale, bn_beta, output);
at::Tensor output;
if (a[0].scalar_type() == at::kBFloat16) {
output = at::empty(
output_dim,
a[0].options()
.dtype(at::kBFloat16)
.memory_format(a[0].suggest_memory_format()));
torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast<
at::BFloat16>(a, bn_scale, bn_beta, output);
} else {
output = at::empty(
output_dim,
a[0].options()
.dtype(at::kFloat)
.memory_format(a[0].suggest_memory_format()));
torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast<
float>(a, bn_scale, bn_beta, output);
}
return output;
}
#endif
Expand Down
91 changes: 72 additions & 19 deletions intel_extension_for_pytorch/csrc/cpu/vec512/concat_bn_relu.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
#include <limits>
#include "utils.h"

// use float as accumulation type for BFloat16
template <typename scalar_t>
struct AccType {
using type = scalar_t;
};
template <>
struct AccType<BFloat16> {
using type = float;
};

namespace torch_ipex {
namespace cpu {
namespace kernel {
Expand All @@ -17,33 +27,75 @@ namespace vec512 {

using Tensor = at::Tensor;

template <typename T>
void _concat_bn_relu_kernel_channels_last(
template <typename T, typename ACC_T>
static void _concat_bn_relu_kernel_channels_last(
const std::vector<const T*>& in_ptr,
const std::vector<int64_t>& in_ch,
T* out_ptr,
const T* scale_ptr,
const T* beta_ptr,
const ACC_T* scale_ptr,
const ACC_T* beta_ptr,
int64_t total_size_except_channels,
int64_t ci,
int64_t co) {
int64_t i = 0, j = 0, k = 0;
auto zero = _mm512_set1_ps(0.0);
#ifdef _OPENMP
#if (_OPENMP >= 201307)
#pragma omp parallel for simd schedule( \
static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
#else
#pragma omp parallel for schedule( \
static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
#endif
#endif
for (i = 0; i < total_size_except_channels; ++i) {
for (j = 0; j < in_ptr.size(); ++j) {
auto concat_in_ptr = in_ptr[j] + i * in_ch[j + 1] - (i + 1) * in_ch[j];
for (k = in_ch[j]; k < in_ch[j + 1]; k += 16) {
auto in = _mm512_loadu_ps(concat_in_ptr + k);
auto beta = _mm512_loadu_ps(beta_ptr + k);
auto scale = _mm512_loadu_ps(scale_ptr + k);
auto bn_out = _mm512_add_ps(beta, _mm512_mul_ps(scale, in));
auto out = _mm512_max_ps(zero, bn_out);
_mm512_storeu_ps(out_ptr + i * co + k, out);
}
}
}
}

template <>
void _concat_bn_relu_kernel_channels_last<at::BFloat16, float>(
const std::vector<const at::BFloat16*>& in_ptr,
const std::vector<int64_t>& in_ch,
at::BFloat16* out_ptr,
const float* scale_ptr,
const float* beta_ptr,
int64_t total_size_except_channels,
int64_t ci,
int64_t co) {
int64_t i = 0, j = 0, k = 0;
auto zero = _mm512_set1_ps(0.0);
#pragma omp parallel for collapse(2)
#ifdef _OPENMP
#if (_OPENMP >= 201307)
#pragma omp parallel for simd schedule( \
static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
#else
#pragma omp parallel for schedule( \
static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
#endif
#endif
for (i = 0; i < total_size_except_channels; ++i) {
for (j = 0; j < in_ptr.size(); ++j) {
auto concat_in_ptr = in_ptr[j] + i * in_ch[j + 1] - (i + 1) * in_ch[j];
for (k = in_ch[j]; k < in_ch[j + 1]; k += 16) {
_mm512_store_ps(
out_ptr + i * co + k,
_mm512_max_ps(
zero,
_mm512_add_ps(
_mm512_load_ps(beta_ptr + k),
_mm512_mul_ps(
_mm512_load_ps(scale_ptr + k),
_mm512_load_ps(
in_ptr[j] + i * (in_ch[j + 1] - in_ch[j]) + k -
in_ch[j])))));
auto in =
cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(concat_in_ptr + k)));
auto beta = _mm512_loadu_ps(beta_ptr + k);
auto scale = _mm512_loadu_ps(scale_ptr + k);
auto bn_out = _mm512_add_ps(beta, _mm512_mul_ps(scale, in));
auto out = _mm512_max_ps(zero, bn_out);
_mm256_storeu_si256(
(__m256i*)(out_ptr + i * co + k), cvt_fp32_to_bf16(out));
}
}
}
Expand All @@ -57,6 +109,7 @@ void ConcatBnReluKernelImpl_ChannelsLast(
const Tensor& scale,
const Tensor& beta,
Tensor& output) {
using ACC_T = typename AccType<T>::type;
int64_t list_length = a.size();
int64_t total_size_except_channels = 1;
std::vector<const T*> input_ptr(list_length);
Expand All @@ -74,11 +127,11 @@ void ConcatBnReluKernelImpl_ChannelsLast(
total_size_except_channels *= a[0].size(i);
}

const T* scale_data = scale.data_ptr<T>();
const T* beta_data = beta.data_ptr<T>();
const ACC_T* scale_data = scale.data_ptr<ACC_T>();
const ACC_T* beta_data = beta.data_ptr<ACC_T>();
T* output_data = output.data_ptr<T>();

_concat_bn_relu_kernel_channels_last<T>(
_concat_bn_relu_kernel_channels_last<T, ACC_T>(
input_ptr,
input_channels,
output_data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,8 @@ void FuseConcatBnRelu(std::shared_ptr<Graph>& graph) {
auto tensor1 = listConstruct->input(0)->type()->cast<TensorType>();
auto check_type_channelsize = [](c10::TensorType tensor) {
return (
tensor.scalarType().value() == at::kFloat &&
(tensor.scalarType().value() == at::kFloat ||
tensor.scalarType().value() == at::kBFloat16) &&
tensor.sizes()[1].value() % 16 == 0 && is_channelslast(tensor));
};
// Check if the dimension of the first tensor is either 4 or 5.
Expand All @@ -562,6 +563,15 @@ void FuseConcatBnRelu(std::shared_ptr<Graph>& graph) {
}
}
}
// Check that the BN weight is of fp32 datatype.
auto bn_node = node->input(0)->node();
if (bn_node->namedInput("weight")
->type()
->cast<TensorType>()
->scalarType()
.value() != at::kFloat) {
return false;
}
return true;
};

Expand Down
171 changes: 44 additions & 127 deletions tests/cpu/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,33 +701,14 @@ def forward(self, x, y, z):
x = x + y + z
return self.layernorm(x)

class ConcatBnRelu2d(torch.nn.Module):
def __init__(self):
super(ConcatBnRelu2d, self).__init__()
self.bn = torch.nn.BatchNorm2d(96)
self.relu = torch.nn.ReLU()
def forward(self, x1, x2, x3):
x = torch.cat((x1, x2, x3), dim = 1)
x = self.bn(x)
return self.relu(x)

class ConcatBnRelu2d_v1(torch.nn.Module):
def __init__(self):
super(ConcatBnRelu2d_v1, self).__init__()
self.bn = torch.nn.BatchNorm2d(32)
self.relu = torch.nn.ReLU()
def forward(self, x1, x2, x3):
x = torch.cat((x1, x2, x3), dim = 2)
x = self.bn(x)
return self.relu(x)

class ConcatBnRelu3d(torch.nn.Module):
def __init__(self):
super(ConcatBnRelu3d, self).__init__()
self.bn = torch.nn.BatchNorm3d(96)
class ConcatBnRelu(torch.nn.Module):
def __init__(self, dim, cat_dim, in_channels, **kwargs):
super(ConcatBnRelu, self).__init__()
self.bn = bn_module[dim](in_channels)
self.relu = torch.nn.ReLU()
self.cat_dim = cat_dim
def forward(self, x1, x2, x3):
x = torch.cat((x1, x2, x3), dim = 1)
x = torch.cat((x1, x2, x3), dim = self.cat_dim)
x = self.bn(x)
return self.relu(x)

Expand Down Expand Up @@ -1010,114 +991,50 @@ def test_add_layernorm(self):
self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))

def test_concat_bn_relu(self):
a1 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
a2 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
a3 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
model = ipex.optimize(model, dtype=torch.bfloat16, level='O0')
with torch.no_grad():
jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
jit_model = torch.jit.freeze(jit_model)
#warmup run
for _ in range(2):
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)

a1 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
a2 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
a3 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
model = ConcatBnRelu2d_v1().eval().to(memory_format=torch.channels_last)
model = ipex.optimize(model, dtype=torch.float32, level='O0')
with torch.no_grad():
jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
jit_model = torch.jit.freeze(jit_model)
#warmup run
for _ in range(2):
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)

model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
model = ipex.optimize(model, dtype=torch.float32, level='O0')
with torch.no_grad():
jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
jit_model = torch.jit.freeze(jit_model)
#warmup run
for _ in range(2):
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)

a1 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
a2 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
a3 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
with torch.no_grad():
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)

a1 = torch.randn(1, 16, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
a2 = torch.randn(1, 48, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
a3 = torch.randn(1, 32, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
with torch.no_grad():
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)
batch_size = 3
image_size = 16
options = itertools.product([2, 3], [[32, 32, 32], [60, 60, 60], [17, 27, 32], [16, 32, 48]], [torch.float32, torch.bfloat16], ['O0', 'O1'], [True, False])
for dim, channels, dtype, level, use_channels_last in options:
input_size = [
[batch_size, channels[0], image_size, image_size],
[batch_size, channels[1], image_size, image_size],
[batch_size, channels[2], image_size, image_size]
]
if dim == 3:
for i in range(3):
input_size[i].append(image_size)
a1 = torch.randn(input_size[0], dtype=dtype)
a2 = torch.randn(input_size[1], dtype=dtype)
a3 = torch.randn(input_size[2], dtype=dtype)
a = [a1, a2, a3]

a1 = torch.randn(1, 17, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
a2 = torch.randn(1, 47, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
a3 = torch.randn(1, 32, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
with torch.no_grad():
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)
in_channels = sum(channels)
model = ConcatBnRelu(dim, 1, in_channels).eval()

a1 = torch.randn(1, 32, 13, 24, dtype=torch.float)
a2 = torch.randn(1, 32, 13, 24, dtype=torch.float)
a3 = torch.randn(1, 32, 13, 24, dtype=torch.float)
with torch.no_grad():
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)
if use_channels_last:
suggest_memory_format = torch.channels_last if dim == 2 else torch.channels_last_3d
for i in range(3):
a[i] = a[i].to(memory_format=suggest_memory_format)
model = model.to(memory_format=suggest_memory_format)

a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
model = ConcatBnRelu3d().eval().to(memory_format=torch.channels_last_3d)
model = ipex.optimize(model, dtype=torch.float32, level='O0')
with torch.no_grad():
jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
jit_model = torch.jit.freeze(jit_model)
#warmup run
for _ in range(2):
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)
model = ipex.optimize(model, dtype=dtype, level=level)

a1 = torch.randn(1, 16, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
a2 = torch.randn(1, 48, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
a3 = torch.randn(1, 32, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
with torch.no_grad():
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)
with torch.cpu.amp.autocast(enabled=True if dtype == torch.bfloat16 else False), torch.no_grad():
result = model(a[0], a[1], a[2])
trace_model = torch.jit.trace(model, (a[0], a[1], a[2])).eval()
trace_model = torch.jit.freeze(trace_model)

a1 = torch.randn(1, 17, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
a2 = torch.randn(1, 47, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
with torch.no_grad():
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)
tresult = trace_model(a[0], a[1], a[2])
trace_graph = trace_model.graph_for(a[0], a[1], a[2])

a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
with torch.no_grad():
jit_res = jit_model(a1, a2, a3)
ori_res = model(a1, a2, a3)
self.assertEqual(jit_res, ori_res)
self.assertEqual(result, tresult)
self.assertEqual(tresult.dtype, dtype)
if use_channels_last:
self.assertTrue(tresult.is_contiguous(memory_format=suggest_memory_format))
if use_channels_last and a1.size(1) % 16 == 0 and a2.size(1) % 16 == 0 and a3.size(1) % 16 == 0 :
self.assertTrue(any(n.kind() == "ipex::concat_bn_relu" for n in trace_graph.nodes()))
else:
self.assertTrue(all(n.kind() != "ipex::concat_bn_relu" for n in trace_graph.nodes()))

def test_mha_scores_calculation(self):
def _check_match_mha(trace_model, mat1, mat2, bias, node = "ipex::mha_scores_calc"):
Expand Down

0 comments on commit cad3f82

Please sign in to comment.