diff --git a/intel_extension_for_pytorch/csrc/aten/cpu/kernels/ConcatBnReluKrnl.cpp b/intel_extension_for_pytorch/csrc/aten/cpu/kernels/ConcatBnReluKrnl.cpp index fe8a66657..4c42b5e5a 100644 --- a/intel_extension_for_pytorch/csrc/aten/cpu/kernels/ConcatBnReluKrnl.cpp +++ b/intel_extension_for_pytorch/csrc/aten/cpu/kernels/ConcatBnReluKrnl.cpp @@ -53,13 +53,24 @@ at::Tensor concat_bn_relu_kernel_impl( } #if defined(CPU_CAPABILITY_AVX512) if (tensor_check) { - at::Tensor output = at::empty( - output_dim, - a[0].options() - .dtype(at::kFloat) - .memory_format(a[0].suggest_memory_format())); - torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast< - float>(a, bn_scale, bn_beta, output); + at::Tensor output; + if (a[0].scalar_type() == at::kBFloat16) { + output = at::empty( + output_dim, + a[0].options() + .dtype(at::kBFloat16) + .memory_format(a[0].suggest_memory_format())); + torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast< + at::BFloat16>(a, bn_scale, bn_beta, output); + } else { + output = at::empty( + output_dim, + a[0].options() + .dtype(at::kFloat) + .memory_format(a[0].suggest_memory_format())); + torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast< + float>(a, bn_scale, bn_beta, output); + } return output; } #endif diff --git a/intel_extension_for_pytorch/csrc/cpu/vec512/concat_bn_relu.h b/intel_extension_for_pytorch/csrc/cpu/vec512/concat_bn_relu.h index 4d1d4b1fa..7f181a05b 100644 --- a/intel_extension_for_pytorch/csrc/cpu/vec512/concat_bn_relu.h +++ b/intel_extension_for_pytorch/csrc/cpu/vec512/concat_bn_relu.h @@ -9,6 +9,16 @@ #include #include "utils.h" +// use float as accumulation type for BFloat16 +template +struct AccType { + using type = scalar_t; +}; +template <> +struct AccType { + using type = float; +}; + namespace torch_ipex { namespace cpu { namespace kernel { @@ -17,33 +27,75 @@ namespace vec512 { using Tensor = at::Tensor; -template -void _concat_bn_relu_kernel_channels_last( +template +static void _concat_bn_relu_kernel_channels_last( const std::vector& in_ptr, const std::vector& in_ch, T* out_ptr, - const T* scale_ptr, - const T* beta_ptr, + const ACC_T* scale_ptr, + const ACC_T* beta_ptr, + int64_t total_size_except_channels, + int64_t ci, + int64_t co) { + int64_t i = 0, j = 0, k = 0; + auto zero = _mm512_set1_ps(0.0); +#ifdef _OPENMP +#if (_OPENMP >= 201307) +#pragma omp parallel for simd schedule( \ + static) if (omp_get_max_threads() > 1 && !omp_in_parallel()) +#else +#pragma omp parallel for schedule( \ + static) if (omp_get_max_threads() > 1 && !omp_in_parallel()) +#endif +#endif + for (i = 0; i < total_size_except_channels; ++i) { + for (j = 0; j < in_ptr.size(); ++j) { + auto concat_in_ptr = in_ptr[j] + i * in_ch[j + 1] - (i + 1) * in_ch[j]; + for (k = in_ch[j]; k < in_ch[j + 1]; k += 16) { + auto in = _mm512_loadu_ps(concat_in_ptr + k); + auto beta = _mm512_loadu_ps(beta_ptr + k); + auto scale = _mm512_loadu_ps(scale_ptr + k); + auto bn_out = _mm512_add_ps(beta, _mm512_mul_ps(scale, in)); + auto out = _mm512_max_ps(zero, bn_out); + _mm512_storeu_ps(out_ptr + i * co + k, out); + } + } + } +} + +template <> +void _concat_bn_relu_kernel_channels_last( + const std::vector& in_ptr, + const std::vector& in_ch, + at::BFloat16* out_ptr, + const float* scale_ptr, + const float* beta_ptr, int64_t total_size_except_channels, int64_t ci, int64_t co) { int64_t i = 0, j = 0, k = 0; auto zero = _mm512_set1_ps(0.0); -#pragma omp parallel for collapse(2) +#ifdef _OPENMP +#if (_OPENMP >= 201307) +#pragma omp parallel for simd schedule( \ + static) if (omp_get_max_threads() > 1 && !omp_in_parallel()) +#else +#pragma omp parallel for schedule( \ + static) if (omp_get_max_threads() > 1 && !omp_in_parallel()) +#endif +#endif for (i = 0; i < total_size_except_channels; ++i) { for (j = 0; j < in_ptr.size(); ++j) { + auto concat_in_ptr = in_ptr[j] + i * in_ch[j + 1] - (i + 1) * in_ch[j]; for (k = in_ch[j]; k < in_ch[j + 1]; k += 16) { - _mm512_store_ps( - out_ptr + i * co + k, - _mm512_max_ps( - zero, - _mm512_add_ps( - _mm512_load_ps(beta_ptr + k), - _mm512_mul_ps( - _mm512_load_ps(scale_ptr + k), - _mm512_load_ps( - in_ptr[j] + i * (in_ch[j + 1] - in_ch[j]) + k - - in_ch[j]))))); + auto in = + cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(concat_in_ptr + k))); + auto beta = _mm512_loadu_ps(beta_ptr + k); + auto scale = _mm512_loadu_ps(scale_ptr + k); + auto bn_out = _mm512_add_ps(beta, _mm512_mul_ps(scale, in)); + auto out = _mm512_max_ps(zero, bn_out); + _mm256_storeu_si256( + (__m256i*)(out_ptr + i * co + k), cvt_fp32_to_bf16(out)); } } } @@ -57,6 +109,7 @@ void ConcatBnReluKernelImpl_ChannelsLast( const Tensor& scale, const Tensor& beta, Tensor& output) { + using ACC_T = typename AccType::type; int64_t list_length = a.size(); int64_t total_size_except_channels = 1; std::vector input_ptr(list_length); @@ -74,11 +127,11 @@ void ConcatBnReluKernelImpl_ChannelsLast( total_size_except_channels *= a[0].size(i); } - const T* scale_data = scale.data_ptr(); - const T* beta_data = beta.data_ptr(); + const ACC_T* scale_data = scale.data_ptr(); + const ACC_T* beta_data = beta.data_ptr(); T* output_data = output.data_ptr(); - _concat_bn_relu_kernel_channels_last( + _concat_bn_relu_kernel_channels_last( input_ptr, input_channels, output_data, diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp b/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp index 2fe264d7b..9bc837e46 100644 --- a/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp +++ b/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp @@ -536,7 +536,8 @@ void FuseConcatBnRelu(std::shared_ptr& graph) { auto tensor1 = listConstruct->input(0)->type()->cast(); auto check_type_channelsize = [](c10::TensorType tensor) { return ( - tensor.scalarType().value() == at::kFloat && + (tensor.scalarType().value() == at::kFloat || + tensor.scalarType().value() == at::kBFloat16) && tensor.sizes()[1].value() % 16 == 0 && is_channelslast(tensor)); }; // Check if the dimension of the first tensor is either 4 or 5. @@ -562,6 +563,15 @@ void FuseConcatBnRelu(std::shared_ptr& graph) { } } } + // Check if the BN weights is fp32 datatype. + auto bn_node = node->input(0)->node(); + if (bn_node->namedInput("weight") + ->type() + ->cast() + ->scalarType() + .value() != at::kFloat) { + return false; + } return true; }; diff --git a/tests/cpu/test_jit.py b/tests/cpu/test_jit.py index 99555ea81..003435b81 100644 --- a/tests/cpu/test_jit.py +++ b/tests/cpu/test_jit.py @@ -701,33 +701,14 @@ def forward(self, x, y, z): x = x + y + z return self.layernorm(x) -class ConcatBnRelu2d(torch.nn.Module): - def __init__(self): - super(ConcatBnRelu2d, self).__init__() - self.bn = torch.nn.BatchNorm2d(96) - self.relu = torch.nn.ReLU() - def forward(self, x1, x2, x3): - x = torch.cat((x1, x2, x3), dim = 1) - x = self.bn(x) - return self.relu(x) - -class ConcatBnRelu2d_v1(torch.nn.Module): - def __init__(self): - super(ConcatBnRelu2d_v1, self).__init__() - self.bn = torch.nn.BatchNorm2d(32) - self.relu = torch.nn.ReLU() - def forward(self, x1, x2, x3): - x = torch.cat((x1, x2, x3), dim = 2) - x = self.bn(x) - return self.relu(x) - -class ConcatBnRelu3d(torch.nn.Module): - def __init__(self): - super(ConcatBnRelu3d, self).__init__() - self.bn = torch.nn.BatchNorm3d(96) +class ConcatBnRelu(torch.nn.Module): + def __init__(self, dim, cat_dim, in_channels, **kwargs): + super(ConcatBnRelu, self).__init__() + self.bn = bn_module[dim](in_channels) self.relu = torch.nn.ReLU() + self.cat_dim = cat_dim def forward(self, x1, x2, x3): - x = torch.cat((x1, x2, x3), dim = 1) + x = torch.cat((x1, x2, x3), dim = self.cat_dim) x = self.bn(x) return self.relu(x) @@ -1010,114 +991,50 @@ def test_add_layernorm(self): self.assertTrue(any(n.kind() == node for n in trace_graph.nodes())) def test_concat_bn_relu(self): - a1 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last) - a2 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last) - a3 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last) - model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last) - model = ipex.optimize(model, dtype=torch.bfloat16, level='O0') - with torch.no_grad(): - jit_model = torch.jit.trace(model, (a1, a2, a3)).eval() - jit_model = torch.jit.freeze(jit_model) -#warmup run - for _ in range(2): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) - - a1 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last) - a2 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last) - a3 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last) - model = ConcatBnRelu2d_v1().eval().to(memory_format=torch.channels_last) - model = ipex.optimize(model, dtype=torch.float32, level='O0') - with torch.no_grad(): - jit_model = torch.jit.trace(model, (a1, a2, a3)).eval() - jit_model = torch.jit.freeze(jit_model) -#warmup run - for _ in range(2): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) - - model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last) - model = ipex.optimize(model, dtype=torch.float32, level='O0') - with torch.no_grad(): - jit_model = torch.jit.trace(model, (a1, a2, a3)).eval() - jit_model = torch.jit.freeze(jit_model) -#warmup run - for _ in range(2): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) - - a1 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last) - a2 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last) - a3 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last) - with torch.no_grad(): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) - - a1 = torch.randn(1, 16, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last) - a2 = torch.randn(1, 48, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last) - a3 = torch.randn(1, 32, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last) - with torch.no_grad(): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) + batch_size = 3 + image_size = 16 + options = itertools.product([2, 3], [[32, 32, 32], [60, 60, 60], [17, 27, 32], [16, 32, 48]], [torch.float32, torch.bfloat16], ['O0', 'O1'], [True, False]) + for dim, channels, dtype, level, use_channels_last in options: + input_size = [ + [batch_size, channels[0], image_size, image_size], + [batch_size, channels[1], image_size, image_size], + [batch_size, channels[2], image_size, image_size] + ] + if dim == 3: + for i in range(3): + input_size[i].append(image_size) + a1 = torch.randn(input_size[0], dtype=dtype) + a2 = torch.randn(input_size[1], dtype=dtype) + a3 = torch.randn(input_size[2], dtype=dtype) + a = [a1, a2, a3] - a1 = torch.randn(1, 17, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last) - a2 = torch.randn(1, 47, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last) - a3 = torch.randn(1, 32, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last) - with torch.no_grad(): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) + in_channels = sum(channels) + model = ConcatBnRelu(dim, 1, in_channels).eval() - a1 = torch.randn(1, 32, 13, 24, dtype=torch.float) - a2 = torch.randn(1, 32, 13, 24, dtype=torch.float) - a3 = torch.randn(1, 32, 13, 24, dtype=torch.float) - with torch.no_grad(): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) + if use_channels_last: + suggest_memory_format = torch.channels_last if dim == 2 else torch.channels_last_3d + for i in range(3): + a[i] = a[i].to(memory_format=suggest_memory_format) + model = model.to(memory_format=suggest_memory_format) - a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d) - a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d) - a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d) - model = ConcatBnRelu3d().eval().to(memory_format=torch.channels_last_3d) - model = ipex.optimize(model, dtype=torch.float32, level='O0') - with torch.no_grad(): - jit_model = torch.jit.trace(model, (a1, a2, a3)).eval() - jit_model = torch.jit.freeze(jit_model) -#warmup run - for _ in range(2): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) + model = ipex.optimize(model, dtype=dtype, level=level) - a1 = torch.randn(1, 16, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d) - a2 = torch.randn(1, 48, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d) - a3 = torch.randn(1, 32, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d) - with torch.no_grad(): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) + with torch.cpu.amp.autocast(enabled=True if dtype == torch.bfloat16 else False), torch.no_grad(): + result = model(a[0], a[1], a[2]) + trace_model = torch.jit.trace(model, (a[0], a[1], a[2])).eval() + trace_model = torch.jit.freeze(trace_model) - a1 = torch.randn(1, 17, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d) - a2 = torch.randn(1, 47, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d) - a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d) - with torch.no_grad(): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) + tresult = trace_model(a[0], a[1], a[2]) + trace_graph = trace_model.graph_for(a[0], a[1], a[2]) - a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float) - a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float) - a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float) - with torch.no_grad(): - jit_res = jit_model(a1, a2, a3) - ori_res = model(a1, a2, a3) - self.assertEqual(jit_res, ori_res) + self.assertEqual(result, tresult) + self.assertEqual(tresult.dtype, dtype) + if use_channels_last: + self.assertTrue(tresult.is_contiguous(memory_format=suggest_memory_format)) + if use_channels_last and a1.size(1) % 16 == 0 and a2.size(1) % 16 == 0 and a3.size(1) % 16 == 0 : + self.assertTrue(any(n.kind() == "ipex::concat_bn_relu" for n in trace_graph.nodes())) + else: + self.assertTrue(all(n.kind() != "ipex::concat_bn_relu" for n in trace_graph.nodes())) def test_mha_scores_calculation(self): def _check_match_mha(trace_model, mat1, mat2, bias, node = "ipex::mha_scores_calc"):