diff --git a/include/small/SoftMaxLayer.hpp b/include/small/SoftMaxLayer.hpp
index f056770..6790f5c 100644
--- a/include/small/SoftMaxLayer.hpp
+++ b/include/small/SoftMaxLayer.hpp
@@ -70,4 +70,55 @@ class SoftMaxLayer : public Layer<BufferT>
     }
 };
 
+//****************************************************************************
+template <typename BufferT>
+class LogSoftMaxLayer : public Layer<BufferT>
+{
+public:
+    typedef typename BufferT::value_type value_type;
+
+    LogSoftMaxLayer(shape_type const &input_shape)
+        : Layer<BufferT>(input_shape)   // input_shape == output_shape
+    {
+#if defined(DEBUG_LAYERS)
+        auto const &output_shape(this->output_shape());
+        std::cerr << "LogSoftMax(batches:" << output_shape[BATCH]
+                  << ",chans:" << output_shape[CHANNEL]
+                  << ",img:" << output_shape[HEIGHT]
+                  << "x" << output_shape[WIDTH]
+                  << ")" << std::endl;
+#endif
+    }
+
+    virtual ~LogSoftMaxLayer() {}
+
+    virtual void compute_output(
+        std::vector<Tensor<BufferT> const *> input,
+        Tensor<BufferT>*                     output) const
+    {
+        if ((input.size() != 1) || (input[0]->shape() != this->output_shape()))
+        {
+            throw std::invalid_argument(
+                "LogSoftMaxLayer::compute_output() ERROR: "
+                "incorrect input buffer shape.");
+        }
+
+        if (output->capacity() < this->output_size())
+        {
+            throw std::invalid_argument(
+                "LogSoftMaxLayer::compute_output() ERROR: "
+                "insufficient output buffer space.");
+        }
+
+        auto const &output_shape(this->output_shape());
+
+        small::LogSoftMax(output_shape[CHANNEL],
+                          output_shape[HEIGHT], output_shape[WIDTH],
+                          input[0]->buffer(),
+                          output->buffer());
+
+        output->set_shape(output_shape);
+    }
+};
+
 }
diff --git a/include/small/float_detail/abstract_layer.hpp b/include/small/float_detail/abstract_layer.hpp
index a84bf5b..4e18134 100644
--- a/include/small/float_detail/abstract_layer.hpp
+++ b/include/small/float_detail/abstract_layer.hpp
@@ -78,7 +78,7 @@ void abstract_layer( /// @todo add B (batch size) param?
     ScalarT const *I_buf = I->data(); //__restrict__ ?
     ScalarT const *F_buf = nullptr;
-    if constexpr (op_type == OP_CONV || op_type == OP_LEAKY_RELU || op_type == OP_MUL)  // if (F != nullptr)
+    if constexpr (op_type == OP_CONV || op_type == OP_LEAKY_RELU || op_type == OP_MUL || op_type == OP_EWISE_ADD_SCALAR)  // if (F != nullptr)
     {
         F_buf = F->data();
     }
@@ -288,7 +288,7 @@ void abstract_layer( /// @todo add B (batch size) param?
     // if leaky relu, the weight pointer does not change with the group id
     ScalarT const *F_group;
-    if constexpr ((op_type == OP_LEAKY_RELU) || (op_type == OP_MUL))
+    if constexpr ((op_type == OP_LEAKY_RELU) || (op_type == OP_MUL) || (op_type == OP_EWISE_ADD_SCALAR))
     {
         F_group = F_buf;
     }
@@ -315,7 +315,7 @@ void abstract_layer( /// @todo add B (batch size) param?
// Loop over input channel reduction for (index_t i = 0; i < (F_c / _F_cb); i++) { - bool first = rewrite_output && (i == 0); + bool first = (rewrite_output || op_type == OP_EWISE_ADD_SCALAR) && (i == 0); ScalarT const *I_channel_block_input = I_channel_block_output + i * (I_h * I_w * _F_cb * _G_b); diff --git a/include/small/float_detail/abstract_op.hpp b/include/small/float_detail/abstract_op.hpp index 4d4570d..bb0c1ce 100644 --- a/include/small/float_detail/abstract_op.hpp +++ b/include/small/float_detail/abstract_op.hpp @@ -46,14 +46,19 @@ namespace float_detail { \ FLOAT_ACCUM_TILE_C(step, a_cur, O_wb, C_ob); \ } \ + else if constexpr (op_type == OP_EWISE_ADD_SCALAR) \ + { \ + float scalar = b_cur[0]; \ + FLOAT_EWISE_ADD_SCALAR_TILE_C(scalar, O_wb, C_ob); \ + } \ else if constexpr (op_type == OP_MUL) \ { \ float drop_out_rate = b_cur[0]; \ - FLOAT_DIV_TILE_C(drop_out_rate, O_wb, C_ob) \ + FLOAT_DIV_TILE_C(drop_out_rate, O_wb, C_ob); \ } \ else if constexpr (op_type == OP_EXP) \ { \ - FLOAT_EXP_TILE_C(step, a_cur, O_wb, C_ob) \ + FLOAT_EXP_TILE_C(step, a_cur, O_wb, C_ob); \ } //**************************************************************************** @@ -80,7 +85,12 @@ namespace float_detail else if constexpr (op_type == OP_ADD || op_type == OP_AVERAGE_POOL) \ { \ FLOAT_ACCUM_END_C(step, a_cur, c_cur, W_elements, C_ob); \ - } \ + } \ + else if constexpr (op_type == OP_EWISE_ADD_SCALAR) \ + { \ + float scalar = b_cur[0]; \ + FLOAT_EWISE_ADD_SCALAR_END_C(c_cur, scalar, W_elements, C_ob); \ + } \ else if constexpr (op_type == OP_MUL) \ { \ float drop_out_rate = b_cur[0]; \ @@ -105,6 +115,11 @@ namespace float_detail { \ FLOAT_ACCUM_TILE_C(step, b_cur, O_wb, C_ob); \ } \ + else if constexpr (op_type == OP_EWISE_ADD_SCALAR) \ + { \ + float scalar = b_cur[0]; \ + FLOAT_EWISE_ADD_SCALAR_TILE_C(scalar, O_wb, C_ob); \ + } \ else if constexpr (op_type == OP_MUL) \ { \ float drop_out_rate = b_cur[0]; \ @@ -129,6 +144,11 @@ namespace float_detail { \ FLOAT_ACCUM_END_C(step, b_cur, c_cur, W_elements, C_ob); \ } \ + else if constexpr (op_type == OP_EWISE_ADD_SCALAR) \ + { \ + float scalar = b_cur[0]; \ + FLOAT_EWISE_ADD_SCALAR_END_C(c_cur, scalar, W_elements, C_ob); \ + } \ else if constexpr (op_type == OP_MUL) \ { \ float drop_out_rate = b_cur[0]; \ diff --git a/include/small/float_detail/kernel.hpp b/include/small/float_detail/kernel.hpp index 78170b6..9eff116 100644 --- a/include/small/float_detail/kernel.hpp +++ b/include/small/float_detail/kernel.hpp @@ -65,19 +65,26 @@ void inline kernel( if (first) { FLOAT_ZERO_TILE_C(_O_wb, _C_ob); - if (op_type == OP_MAX_POOL || op_type == OP_MUL) + if constexpr(op_type == OP_MAX_POOL || op_type == OP_MUL || op_type == OP_EWISE_ADD_SCALAR) { /// @note using platform C_ob FLOAT_LOAD_TILE_C_strided(I, step, _O_wb, FLOAT_C_ob); } - else if (op_type == OP_UPSAMPLE) + else if constexpr(op_type == OP_UPSAMPLE) { FLOAT_LOAD_TILE_C_upsample(I, _stride, _C_ib, _O_wb, _C_ob); } } else { - FLOAT_LOAD_TILE_C(O, _O_wb, _C_ob); + if constexpr(op_type == OP_EWISE_ADD_SCALAR) + { + FLOAT_LOAD_TILE_C_strided(I, step, _O_wb, FLOAT_C_ob); + } + else + { + FLOAT_LOAD_TILE_C(O, _O_wb, _C_ob); + } if constexpr (op_type == OP_UPSAMPLE) { FLOAT_ACCUM_TILE_C_upsample(I, _stride, _C_ib, _O_wb, _C_ob); diff --git a/include/small/float_detail/kernel_right.hpp b/include/small/float_detail/kernel_right.hpp index 02946af..5654e2c 100644 --- a/include/small/float_detail/kernel_right.hpp +++ b/include/small/float_detail/kernel_right.hpp @@ -68,7 +68,7 @@ void inline kernel_right( { 
         FLOAT_ZERO_END_C(O_w_left, _C_ob);
-        if ( (op_type == OP_MUL)|| (op_type == OP_MAX_POOL && H_lb == 0 && H_ub == 0))
+        if ( (op_type == OP_MUL)|| (op_type == OP_EWISE_ADD_SCALAR) || (op_type == OP_MAX_POOL && H_lb == 0 && H_ub == 0))
         {
             FLOAT_LOAD_END_C_strided(I, step, O_w_left, _C_ob);
         }
@@ -84,7 +84,14 @@ void inline kernel_right(
         {
             FLOAT_ZERO_END_C(O_w_left, _C_ob);
         }
-        FLOAT_LOAD_END_C(O, O_w_left, _C_ob);
+        if constexpr(op_type == OP_EWISE_ADD_SCALAR)
+        {
+            FLOAT_LOAD_END_C_strided(I, step, O_w_left, _C_ob);
+        }
+        else
+        {
+            FLOAT_LOAD_END_C(O, O_w_left, _C_ob);
+        }
         if constexpr (op_type == OP_UPSAMPLE)
         {
             FLOAT_ACCUM_END_C_upsample(I, _stride, _C_ib, O_w_left, _C_ob);
@@ -140,7 +147,7 @@ void inline kernel_right(
         //@note padding should always be 'v' for pointwise operations,
        //       so this code path should not be used
-        if (op_type == OP_MUL)
+        if (op_type == OP_MUL || op_type == OP_EWISE_ADD_SCALAR)
         {
             FLOAT_LOAD_END_C_strided(I_ptr, step, r_pad_el, _C_ob);
         }
diff --git a/include/small/float_detail/kernel_right_1D.hpp b/include/small/float_detail/kernel_right_1D.hpp
index e2a5eae..5d6260a 100644
--- a/include/small/float_detail/kernel_right_1D.hpp
+++ b/include/small/float_detail/kernel_right_1D.hpp
@@ -68,7 +68,7 @@ void inline kernel_right_1D(
     {
         FLOAT_ZERO_END_C(O_w_left, _C_ob);
-        if ((op_type == OP_MUL) || (op_type == OP_MAX_POOL))  // && H_lb == 0 && H_ub == 0))
+        if ((op_type == OP_MUL) || (op_type == OP_EWISE_ADD_SCALAR)|| (op_type == OP_MAX_POOL))  // && H_lb == 0 && H_ub == 0))
         {
             FLOAT_LOAD_END_C_strided(I, step, O_w_left, _C_ob);
         }
@@ -140,7 +140,7 @@ void inline kernel_right_1D(
        //@note padding should always be 'v' for pointwise operations,
        //      so this code path should not be used
-        if (op_type == OP_MUL)
+        if constexpr(op_type == OP_MUL || op_type == OP_EWISE_ADD_SCALAR)
        {
            FLOAT_LOAD_END_C_strided(I_ptr, step, r_pad_el, _C_ob);
        }
diff --git a/include/small/interface_abstract.hpp b/include/small/interface_abstract.hpp
index 3e92aec..22bff3e 100644
--- a/include/small/interface_abstract.hpp
+++ b/include/small/interface_abstract.hpp
@@ -1833,6 +1833,71 @@ void SoftMax(int input_channels,
 }
 #endif
 
+//============================================================================
+#if defined(SMALL_HAS_FLOAT_SUPPORT)
+template <class BufferT,
+          std::enable_if_t<
+              std::is_same<FloatBuffer, BufferT>::value, bool> = true>
+void LogSoftMax(int input_channels,
+                int input_height, int input_width,
+                BufferT const &input_buf,
+                BufferT       &output_buf)
+{
+#if defined(RECORD_CALLS)
+    std::cout << "LogSoftMax(chans:" << input_channels
+              << ",img:" << input_height << "x" << input_width
+              << ",I,O)\n";
+#endif
+
+    if (input_channels % FLOAT_C_ib == 0)
+    {
+        // LogSoftMax(x) = x - log(sum(exp(x))): a point-wise exponential, a global sum, then a point-wise add of the scalar -log(sum) to the input.
+
+        // point-wise exponent
+        float_detail::abstract_layer<
+            FloatBuffer, FLOAT_C_ob, 1, 1, FLOAT_W_ob, 1, 1, OP_EXP, 0, 1>(
+            input_channels, // Output Channel Grouping
+            1,              // Output Channels per group
+            1,
+            input_height, input_width,
+            1, 1,
+            0, 0, 0, 0,
+            &input_buf, (FloatBuffer *)nullptr, &output_buf);
+
+        // global sum
+        FloatBuffer softmax_norm_buf(1);
+        float_detail::abstract_layer<
+            FloatBuffer, 1, 1, FLOAT_C_ob, FLOAT_W_ob, 1, FLOAT_C_ob, OP_ADD, 3, 1>(
+            1,              // Output Channel Grouping
+            1,              // Output Channels per group
+            input_channels,
+            input_height, input_width,
+            input_height, input_width,
+            0, 0, 0, 0,
+            &output_buf, (FloatBuffer *)nullptr, &softmax_norm_buf);
+
+        softmax_norm_buf.data()[0] = -std::log(softmax_norm_buf.data()[0]);
+
+        // element-wise shift
+        float_detail::abstract_layer<
+            FloatBuffer, FLOAT_C_ob, 1, 1, FLOAT_W_ob, 1, 1, OP_EWISE_ADD_SCALAR, 0, 0>(
+            input_channels, // Output Channel Grouping
+            1,              // Output Channels per group
+            1,
+            input_height, input_width,
+            1, 1,
+            0, 0, 0, 0,
+            &input_buf, &softmax_norm_buf, &output_buf);
+
+    }
+    else
+    {
+        throw std::invalid_argument(
+            "LogSoftMax ERROR: in_channels unsupported.");
+    }
+}
+#endif
+
 //****************************************************************************
 //****************************************************************************
 // nearest neighbor upsampling
diff --git a/include/small/models/Resnet1D.hpp b/include/small/models/Resnet1D.hpp
new file mode 100644
index 0000000..e69de29
diff --git a/include/small/op_type.hpp b/include/small/op_type.hpp
index 7015754..934b422 100644
--- a/include/small/op_type.hpp
+++ b/include/small/op_type.hpp
@@ -27,6 +27,7 @@ enum OpType
     OP_MUL = 6,
     OP_UPSAMPLE = 7,    // 'u'
     OP_EXP = 8,
+    OP_EWISE_ADD_SCALAR = 9,
     OP_NONE = -1
 };
diff --git a/include/small/platforms/arm/intrinsics.h b/include/small/platforms/arm/intrinsics.h
index 6a40b0c..6332490 100644
--- a/include/small/platforms/arm/intrinsics.h
+++ b/include/small/platforms/arm/intrinsics.h
@@ -714,6 +714,66 @@ else\
 {\
 }
 #endif
 
+//****************************************************************************
+// Broadcast addition kernels
+//****************************************************************************
+
+#define FLOAT_EWISE_ADD_SCALAR_TILE_C(scalar, W_ob, C_ob) \
+    float32x4_t av;               \
+    av = vld1q_dup_f32(&scalar);  \
+    c_0_0 = vaddq_f32(c_0_0, av); \
+    c_0_1 = vaddq_f32(c_0_1, av); \
+    c_0_2 = vaddq_f32(c_0_2, av); \
+    c_0_3 = vaddq_f32(c_0_3, av); \
+    c_1_0 = vaddq_f32(c_1_0, av); \
+    c_1_1 = vaddq_f32(c_1_1, av); \
+    c_1_2 = vaddq_f32(c_1_2, av); \
+    c_1_3 = vaddq_f32(c_1_3, av); \
+    c_2_0 = vaddq_f32(c_2_0, av); \
+    c_2_1 = vaddq_f32(c_2_1, av); \
+    c_2_2 = vaddq_f32(c_2_2, av); \
+    c_2_3 = vaddq_f32(c_2_3, av); \
+    c_3_0 = vaddq_f32(c_3_0, av); \
+    c_3_1 = vaddq_f32(c_3_1, av); \
+    c_3_2 = vaddq_f32(c_3_2, av); \
+    c_3_3 = vaddq_f32(c_3_3, av); \
+    c_4_0 = vaddq_f32(c_4_0, av); \
+    c_4_1 = vaddq_f32(c_4_1, av); \
+    c_4_2 = vaddq_f32(c_4_2, av); \
+    c_4_3 = vaddq_f32(c_4_3, av); \
+    c_5_0 = vaddq_f32(c_5_0, av); \
+    c_5_1 = vaddq_f32(c_5_1, av); \
+    c_5_2 = vaddq_f32(c_5_2, av); \
+    c_5_3 = vaddq_f32(c_5_3, av);
+
+#if FLOAT_SIMD_EPILOGUE == 1
+#define FLOAT_EWISE_ADD_SCALAR_END_C(c_cur, scalar, W_last, C_ob) \
+    float *c_pixel = c_cur;                    \
+    for (uint32_t kk = 0; kk < W_last; kk++)   \
+    {                                          \
+        float *c_channel = c_pixel;            \
+        for (uint32_t jj = 0; jj < C_ob; jj++) \
+        {                                      \
+            *(c_channel) += scalar;            \
+            c_channel++;                       \
+        }                                      \
+        c_pixel += C_ob;                       \
+    }
+#else
+#define FLOAT_EWISE_ADD_SCALAR_END_C(c_cur, scalar, W_last, C_ob) \
+    float32x4_t av;                            \
+    av = vld1q_dup_f32(&scalar);               \
+    float32x4_t *c_pixel = c_cur;              \
+    for (uint32_t kk = 0; kk < W_last; kk++)   \
+    {                                          \
+        for (uint32_t jj = 0; jj < C_ob / FLOAT_SIMD; jj++)              \
+        {                                                                \
+            c_pixel[(kk) * (C_ob / FLOAT_SIMD) + jj] =                   \
+                vaddq_f32(c_pixel[(kk) * (C_ob / FLOAT_SIMD) + jj], av); \
+        }                                                                \
+    }
+#endif
+
 //****************************************************************************
 // Softmax (Ewise exponentiation)
 //****************************************************************************
diff --git a/include/small/platforms/reference/intrinsics_float.h b/include/small/platforms/reference/intrinsics_float.h
index b18fee5..d12e91f 100644
--- a/include/small/platforms/reference/intrinsics_float.h
+++ b/include/small/platforms/reference/intrinsics_float.h
@@ -557,6 +557,32 @@ namespace small
c_pixel += C_ob; \ } +#define FLOAT_EWISE_ADD_SCALAR_TILE_C(scalar, W_ob, C_ob) \ + float *c_pixel = c_tile; \ + for (uint32_t kk = 0; kk < W_ob; kk++) \ + { \ + float *c_channel = c_pixel; \ + for (uint32_t jj = 0; jj < C_ob; jj++) \ + { \ + *(c_channel) += scalar; \ + c_channel++; \ + } \ + c_pixel += C_ob; \ + } + +#define FLOAT_EWISE_ADD_SCALAR_END_C(c_cur, scalar, W_last, C_ob) \ + float *c_pixel = c_cur; \ + for (uint32_t kk = 0; kk < W_last; kk++) \ + { \ + float *c_channel = c_pixel; \ + for (uint32_t jj = 0; jj < C_ob; jj++) \ + { \ + *(c_channel) += scalar; \ + c_channel++; \ + } \ + c_pixel += C_ob; \ + } + //**************************************************************************** // Accumulate upsampling //**************************************************************************** diff --git a/include/small/platforms/zen2/intrinsics.h b/include/small/platforms/zen2/intrinsics.h index 932a0dc..96aaa17 100644 --- a/include/small/platforms/zen2/intrinsics.h +++ b/include/small/platforms/zen2/intrinsics.h @@ -904,6 +904,50 @@ for (uint32_t kk = 0; kk < W_last; kk++) \ c_pixel += (C_ob/FLOAT_SIMD); \ } #endif + +//**************************************************************************** +// Broadcast Addition kernels +//**************************************************************************** + +#define FLOAT_EWISE_ADD_SCALAR_TILE_C(scalar, W_ob, C_ob) \ + b0 = _mm256_broadcast_ss(&scalar); \ + c0 = _mm256_add_ps(b0, c0); \ + c1 = _mm256_add_ps(b0, c1); \ + c2 = _mm256_add_ps(b0, c2); \ + c3 = _mm256_add_ps(b0, c3); \ + c4 = _mm256_add_ps(b0, c4); \ + c5 = _mm256_add_ps(b0, c5); \ + c6 = _mm256_add_ps(b0, c6); \ + c7 = _mm256_add_ps(b0, c7); \ + c8 = _mm256_add_ps(b0, c8); \ + c9 = _mm256_add_ps(b0, c9); \ + c10 = _mm256_add_ps(b0, c10); \ + c11 = _mm256_add_ps(b0, c11); + +#if FLOAT_SIMD_EPILOGUE == 1 +#define FLOAT_EWISE_ADD_SCALAR_END_C(c_cur, scalar, W_last, C_ob) \ + float *c_pixel = c_cur; \ + for (uint32_t kk = 0; kk < W_last; kk++) \ + { \ + float *c_channel = c_pixel; \ + for (uint32_t jj = 0; jj < C_ob; jj++) \ + { \ + *(c_channel) += scalar; \ + c_channel++; \ + } \ + c_pixel += C_ob; \ + } +#elif FLOAT_SIMD_EPILOGUE == 8 +#define FLOAT_EWISE_ADD_SCALAR_END_C(c_cur, scalar, W_last, C_ob) \ + b_0 = _mm256_broadcast_ss(&scalar); \ + __m256 *c_pixel = c_cur; \ + for (uint32_t kk = 0; kk < W_last; kk++) \ + { \ + c_pixel[0] = _mm256_add_ps(b_0, c_pixel[0]);\ + c_pixel[1] = _mm256_add_ps(b_0, c_pixel[1]);\ + c_pixel += (C_ob/FLOAT_SIMD); \ + } +#endif //**************************************************************************** // Accumulate upsampling //**************************************************************************** diff --git a/test/test_softmax.cpp b/test/test_softmax.cpp index ba7f61b..1b9e520 100644 --- a/test/test_softmax.cpp +++ b/test/test_softmax.cpp @@ -530,6 +530,513 @@ void measure_softmax_performance(void) } } +//**************************************************************************** +// Generate logsoftmax output regression data from unpack MaxPool input data. 
+//**************************************************************************** +template +bool compute_logsoftmax_output(LayerParams const ¶ms) +{ + /// @todo add smart pointer to buffers + // Read input data + std::string in_fname = + get_pathname(data_dir, "in", "pool", + params, + params.C_i*params.H*params.W); + std::cout << "\nlogsoftmax: input file = " << in_fname << std::endl; + + BufferT input_dc = read_inputs(in_fname); + TEST_ASSERT(input_dc.size() == params.C_i*params.H*params.W); + + // Read output regression data + size_t Ho(small::compute_output_dim( + params.H, 1,1, params.p)); + size_t Wo(small::compute_output_dim( + params.W, 1,1, params.p)); + std::cerr << "Output image dims: " << Ho << ", " << Wo << std::endl; + assert(Ho == params.H && Wo == params.W); + std::string out_fname = + get_pathname(data_dir, "out", "logsoftmax", + params, + params.C_i*params.H*params.W); + std::cout << "logsoftmax: output file= " << out_fname << std::endl; + + BufferT output_dc_answers(params.C_i*Ho*Wo); + + uint8_t t_pad=0, b_pad=0, l_pad=0, r_pad=0; + if (params.p == small::PADDING_F) + { + small::calc_padding(params.H, 1,1, t_pad, b_pad); + small::calc_padding(params.W, 1,1, l_pad, r_pad); + } + + std::cout << "Padding t,b,l,r: " << (int)t_pad << "," << (int)b_pad + << "," << (int)l_pad << "," << (int)r_pad << std::endl; + + + // small::LogSoftMax(1, 1,1, + // t_pad, b_pad, l_pad, r_pad, + // params.C_i, params.H, params.W, + // packed_input_dc, packed_output_dc); + + // Compute logsoftmax outputs + size_t num_outputs = 0; + + //Ewise Exponent and sum + float sum = 0; + for (size_t c = 0; c < params.C_i * params.H * params.W; ++c) + { + + output_dc_answers[c] = std::exp(input_dc[c]); + sum += output_dc_answers[c]; + num_outputs++; + + } + sum = std::log(sum); + for (size_t c = 0; c < params.C_i * params.H * params.W; ++c) + { + output_dc_answers[c] = input_dc[c] - sum; + } + + std::cerr << "num_outputs = " << num_outputs << std::endl; + std::cerr << "..should be = " << (params.C_i*Ho*Wo) << std::endl; + TEST_CHECK(num_outputs == params.C_i*Ho*Wo); + write_outputs(out_fname, output_dc_answers, num_outputs); + + return true; +} + +//**************************************************************************** +void test_compute_logsoftmax_output(void) +{ + std::vector params = + { + {16, 3, 3, 3, 2, small::PADDING_V, 0}, //Ci,Hi,Wi,k,s,p,Co + {16, 3, 13, 3, 2, small::PADDING_V, 0}, + + {16, 30, 30, 3, 2, small::PADDING_V, 0}, + {96, 30, 30, 3, 2, small::PADDING_V, 0}, + {96, 3, 13, 3, 2, small::PADDING_V, 0}, + + // {16, 3, 3, 3, 2, small::PADDING_F, 0}, //Ci,Hi,Wi,k,s,p,Co + // {16, 3, 13, 3, 2, small::PADDING_F, 0}, + // {96, 30, 30, 3, 2, small::PADDING_F, 0}, + // {96, 3, 13, 3, 2, small::PADDING_F, 0} + }; + for (LayerParams const &p: params) + { +#if defined(QUANTIZED) + TEST_CHECK(true == compute_logsoftmax_output(p)); +#else + TEST_CHECK(true == compute_logsoftmax_output(p)); +#endif + } +} + +//**************************************************************************** +template +bool run_logsoftmax_config(LayerParams const ¶ms) +{ + /// @todo add smart pointer to buffers + // Read input data + std::string in_fname = + get_pathname(data_dir, "in", "pool", + params, + params.C_i*params.H*params.W); + std::cout << "\nlogsoftmax: input file = " << in_fname << std::endl; + + BufferT input_dc = read_inputs(in_fname); + TEST_ASSERT(input_dc.size() == params.C_i*params.H*params.W); + + // Pack input data + BufferT packed_input_dc(input_dc.size()); + small::pack_buffer(input_dc, + 
small::INPUT, + 1U, params.C_i, params.H, params.W, + BufferT::C_ib, BufferT::C_ob, + packed_input_dc); + + // Read output regression data + size_t Ho(small::compute_output_dim( + params.H, 1, 1, params.p)); + size_t Wo(small::compute_output_dim( + params.W, 1, 1, params.p)); + std::cerr << "Output image dims: " << Ho << ", " << Wo << std::endl; + std::string out_fname = + get_pathname(data_dir, "out", "logsoftmax", + params, + params.C_i*Ho*Wo); + std::cout << "logsoftmax: output file= " << out_fname << std::endl; + + BufferT output_dc_answers = read_inputs(out_fname); + TEST_ASSERT(output_dc_answers.size() == params.C_i*Ho*Wo); + + // Pack output answer data + BufferT packed_output_dc_answers(output_dc_answers.size()); + small::pack_buffer(output_dc_answers, + small::OUTPUT, + 1U, params.C_i, Ho, Wo, + BufferT::C_ib, BufferT::C_ob, + packed_output_dc_answers); + + // Allocate output buffer +#if defined(QUANTIZED) + BufferT packed_output_dc(output_dc_answers.size()*4); /// @todo HACK hardcoded. +#else + BufferT packed_output_dc(output_dc_answers.size()); +#endif + + uint8_t t_pad=0, b_pad=0, l_pad=0, r_pad=0; + if (params.p == small::PADDING_F) + { + small::calc_padding(params.H, 1, 1, t_pad, b_pad); + small::calc_padding(params.W, 1, 1, l_pad, r_pad); + } + + // Compute layer + small::LogSoftMax(params.C_i, params.H, params.W, + packed_input_dc, packed_output_dc); + + // Check answer + bool passing = true; + for (size_t ix = 0; ix < packed_output_dc_answers.size(); ++ix) + { + //if (packed_output_dc[ix] != packed_output_dc_answers[ix]) + // printf("%f %f\n", packed_output_dc[ix], packed_output_dc_answers[ix]); + if (!almost_equal(packed_output_dc[ix], packed_output_dc_answers[ix])) + { + passing = false; + + std::cout << "FAIL: logsoftmax_out(" << ix << ")--> " + << std::setw(12) << std::setprecision(10) + << packed_output_dc[ix] << "(computed) != " + << std::setw(12) << std::setprecision(10) + << packed_output_dc_answers[ix] + << ", ratio = " << packed_output_dc[ix]/packed_output_dc_answers[ix] + << std::endl; + } + } + + if (passing) std::cerr << "Test PASSED\n"; + return passing; +} + +//**************************************************************************** + +template +bool run_logsoftmax_layer_config(LayerParams const ¶ms) +{ + // todo: adapt to logsoftmax + + /// @todo add smart pointer to buffers + //========================================================================= + small::shape_type input_shape({1UL, params.C_i, params.H, params.W}); + size_t input_size = params.C_i*params.H*params.W; + small::LogSoftMaxLayer logsoftmax_layer(input_shape); + //========================================================================= + + // Read input data + std::string in_fname = + get_pathname(data_dir, "in", "pool", + params, + input_size); + std::cout << "\nlogsoftmax: input file = " << in_fname << std::endl; + + // Allocate the input buffer + BufferT input_dc = read_inputs(in_fname); + + TEST_ASSERT(input_dc.size() == input_size); + + // Pack input data + BufferT packed_input_dc(input_dc.size()); + small::pack_buffer(input_dc, + small::INPUT, + 1U, params.C_i, params.H, params.W, + BufferT::C_ib, BufferT::C_ob, + packed_input_dc); + + small::Tensor packed_input_tensor( + input_shape, + std::move(packed_input_dc)); + + // Read output regression data + auto output_shape(logsoftmax_layer.output_shape()); + size_t output_buffer_size(logsoftmax_layer.output_size()); + + std::cerr << "Output image dims: " + << output_shape[small::HEIGHT] << "x" << output_shape[small::WIDTH] + << 
std::endl; + std::string out_fname = + get_pathname(data_dir, "out", "logsoftmax", + params, + output_buffer_size); + std::cout << "logsoftmax: output file= " << out_fname << std::endl; + + BufferT output_dc_answers = read_inputs(out_fname); + TEST_ASSERT(output_dc_answers.size() == output_buffer_size); + + // Pack output answer data + BufferT packed_output_dc_answers(output_dc_answers.size()); + small::pack_buffer(output_dc_answers, + small::OUTPUT, + 1U, output_shape[small::CHANNEL], + output_shape[small::HEIGHT], output_shape[small::WIDTH], + BufferT::C_ib, BufferT::C_ob, + packed_output_dc_answers); + + // Allocate output buffer +#if defined(QUANTIZED) + BufferT packed_output_dc(output_dc_answers.size()*4); /// @todo HACK hardcoded. +#else + BufferT packed_output_dc(output_dc_answers.size()); +#endif + small::Tensor packed_output_tensor(output_shape, + std::move(packed_output_dc)); + + // Compute layer + logsoftmax_layer.compute_output({&packed_input_tensor}, &packed_output_tensor); + + // Check answer + bool passing = true; + BufferT &buf(packed_output_tensor.buffer()); + for (size_t ix = 0; ix < packed_output_tensor.size(); ++ix) + { + //if (buf[ix] != packed_output_dc_answers[ix]) + if (!almost_equal(buf[ix], packed_output_dc_answers[ix])) + { + passing = false; + std::cout << "FAIL: logsoftmax_out(" << ix << ")--> " + << std::setw(12) << std::setprecision(10) + << buf[ix] << "(computed) != " + << std::setw(12) << std::setprecision(10) + << packed_output_dc_answers[ix] + << std::endl; + } + } + + if (passing) std::cerr << "Test PASSED\n"; + return passing; +} + +//**************************************************************************** +void test_logsoftmax_regression_data(void) +{ + std::vector params = + { + {16, 3, 3, 3, 2, small::PADDING_V, 0}, //Ci,Hi,Wi,k,s,p,Co + {16, 3, 13, 3, 2, small::PADDING_V, 0}, + + {16, 30, 30, 3, 2, small::PADDING_V, 0}, + {96, 30, 30, 3, 2, small::PADDING_V, 0}, + {96, 3, 13, 3, 2, small::PADDING_V, 0}, + + // {16, 3, 3, 3, 2, small::PADDING_F, 0}, //Ci,Hi,Wi,k,s,p,Co + // {16, 3, 13, 3, 2, small::PADDING_F, 0}, + // {96, 30, 30, 3, 2, small::PADDING_F, 0}, + // {96, 3, 13, 3, 2, small::PADDING_F, 0} + }; + for (LayerParams const &p: params) + { +#if defined(QUANTIZED) + TEST_CHECK(true == run_logsoftmax_config(p)); +#else + TEST_CHECK(true == run_logsoftmax_config(p)); +#endif + } +} + +//**************************************************************************** +void test_logsoftmax_layer_regression_data(void) +{ + std::vector params = + { + {16, 3, 3, 3, 2, small::PADDING_V, 0}, //Ci,Hi,Wi,k,s,p,Co + {16, 3, 13, 3, 2, small::PADDING_V, 0}, + + {16, 30, 30, 3, 2, small::PADDING_V, 0}, + {96, 30, 30, 3, 2, small::PADDING_V, 0}, + {96, 3, 13, 3, 2, small::PADDING_V, 0}, + + // {16, 3, 3, 3, 2, small::PADDING_F, 0}, //Ci,Hi,Wi,k,s,p,Co + // {16, 3, 13, 3, 2, small::PADDING_F, 0}, + // {96, 30, 30, 3, 2, small::PADDING_F, 0}, + // {96, 3, 13, 3, 2, small::PADDING_F, 0} + }; + for (LayerParams const &p: params) + { +#if defined(QUANTIZED) + TEST_CHECK(true == run_logsoftmax_layer_config(p)); +#else + TEST_CHECK(true == run_logsoftmax_layer_config(p)); +#endif + } +} + +//**************************************************************************** +void measure_logsoftmax_performance(void) +{ + // To do: adapt to logsoftmax + + // C_i,Hi,Wi,k,s,p,C_o + std::vector params = + { + { 16, 48, 48, 3, 1, small::PADDING_F, 16}, + { 32, 24, 24, 3, 1, small::PADDING_F, 32}, + + { 32, 48, 48, 3, 1, small::PADDING_F, 32}, + { 64, 24, 24, 3, 1, 
small::PADDING_F, 64}, + { 128, 12, 12, 3, 1, small::PADDING_F, 128}, + + { 16, 48, 48, 3, 1, small::PADDING_F, 32}, + { 32, 24, 24, 3, 1, small::PADDING_F, 64}, + { 64, 12, 12, 3, 1, small::PADDING_F, 128}, + { 128, 6, 6, 3, 1, small::PADDING_F, 256}, + + { 128, 24, 24, 3, 1, small::PADDING_F, 128}, + { 256, 12, 12, 3, 1, small::PADDING_F, 256}, + + { 512, 12, 12, 3, 1, small::PADDING_F, 512}, + {1024, 6, 6, 3, 1, small::PADDING_F, 1024}, + + { 32, 208, 208, 3, 1, small::PADDING_F, 64}, + { 64, 104, 104, 3, 1, small::PADDING_F, 128}, + { 128, 52, 52, 3, 1, small::PADDING_F, 256}, + { 256, 26, 26, 3, 1, small::PADDING_F, 512}, + { 512, 13, 13, 3, 1, small::PADDING_F, 1024} + }; + + uint32_t const num_threads[] = {1, 2, 4}; + char const *str_num_threads[] = {"1", "2", "4"}; + uint32_t const num_runs(100); + small::Timer t; + +#if defined(QUANTIZED) + std::string type("quint8"); + using Buffer = small::QUInt8Buffer; +#else + std::string type("float"); + using Buffer = small::FloatBuffer; +#endif + + printf("\nlogsoftmax2D(%s) func.\n", type.c_str()); + printf("\tC_i\tH\tW\tk\ts\tnthd\truns\tt_min\tt_max\tt_avg\n"); + + for (LayerParams const &p: params) + { + size_t Ho(small::compute_output_dim(p.H, p.k, p.s, p.p)); + size_t Wo(small::compute_output_dim(p.W, p.k, p.s, p.p)); + + uint8_t t_pad=0, b_pad=0, l_pad=0, r_pad=0; + if (p.p == small::PADDING_F) + { + small::calc_padding(p.H, p.k, p.s, t_pad, b_pad); + small::calc_padding(p.W, p.k, p.s, l_pad, r_pad); + } + + size_t num_input_elts(p.C_i*p.H*p.W); + size_t num_output_elts(p.C_i*Ho*Wo); + + Buffer input_dc(num_input_elts); + Buffer output_dc(num_output_elts); + small::init(input_dc, num_input_elts); + + for (size_t ix = 0; ix < 3; ++ix) + { + setenv("OMP_NUM_THREADS", str_num_threads[ix], 1); + //std::string ont = std::getenv("OMP_NUM_THREADS"); // read it back + //auto nt = atol(ont.c_str()); + + double tx(0.); + double min_t = std::numeric_limits::max(); + double max_t = 0.; + + // Warmup + small::LogSoftMax( + p.C_i, p.H, p.W, + input_dc, output_dc); + + for (size_t iy = 0; iy < num_runs; ++iy) + { + t.start(); + small::LogSoftMax( + p.C_i, p.H, p.W, + input_dc, output_dc); + t.stop(); + double ts = t.elapsed(); + tx += ts; + min_t = std::min(min_t, ts); + max_t = std::max(max_t, ts); + } + + printf("function\t%d\t%d\t%d\t%d\t%d\t%d\t%d\t%lf\t%lf\t%lf\n", + p.C_i, p.H, p.W, p.k, p.s, + num_threads[ix], num_runs, + min_t, max_t, (tx/num_runs)); + } + } + + printf("\nlogsoftmax2D(%s) class\n", type.c_str()); + printf("\tC_i\tH\tW\tk\ts\tnthd\truns\tt_min\tt_max\tt_avg\n"); + + for (LayerParams const &p: params) + { + size_t Ho(small::compute_output_dim(p.H, p.k, p.s, p.p)); + size_t Wo(small::compute_output_dim(p.W, p.k, p.s, p.p)); + + uint8_t t_pad=0, b_pad=0, l_pad=0, r_pad=0; + if (p.p == small::PADDING_F) + { + small::calc_padding(p.H, p.k, p.s, t_pad, b_pad); + small::calc_padding(p.W, p.k, p.s, l_pad, r_pad); + } + + size_t num_input_elts(p.C_i*p.H*p.W); + size_t num_output_elts(p.C_i*Ho*Wo); + small::shape_type input_shape({1UL, p.C_i, p.H, p.W}); + + small::Tensor input_dc(input_shape); + small::init(input_dc.buffer(), num_input_elts); + std::vector*> inputs; + inputs.push_back(&input_dc); + + small::Tensor output_dc(num_output_elts); + std::vector*> outputs; + outputs.push_back(&output_dc); + + small::SoftMaxLayer + softmax_layer(input_shape); + + for (size_t ix = 0; ix < 3; ++ix) + { + setenv("OMP_NUM_THREADS", str_num_threads[ix], 1); + //std::string ont = std::getenv("OMP_NUM_THREADS"); + //auto nt = atol(ont.c_str()); + + 
double tx(0.); + double min_t = std::numeric_limits::max(); + double max_t = 0.; + + // Warm up + softmax_layer.compute_output({&input_dc}, &output_dc); + + for (size_t iy = 0; iy < num_runs; ++iy) + { + t.start(); + softmax_layer.compute_output({&input_dc}, &output_dc); + t.stop(); + double ts = t.elapsed(); + tx += ts; + min_t = std::min(min_t, ts); + max_t = std::max(max_t, ts); + } + + printf("class \t%d\t%d\t%d\t%d\t%d\t%d\t%d\t%lf\t%lf\t%lf\n", + p.C_i, p.H, p.W, p.k, p.s, + num_threads[ix], num_runs, + min_t, max_t, (tx/num_runs)); + } + } +} + //**************************************************************************** //**************************************************************************** TEST_LIST = { @@ -537,5 +1044,9 @@ TEST_LIST = { {"softmax_regression_data", test_softmax_regression_data}, {"softmax_layer_regression_data", test_softmax_layer_regression_data}, // {"softmax_performance", measure_softmax_performance}, + {"compute_output_logsoftmax", test_compute_logsoftmax_output}, + {"logsoftmax_regression_data", test_logsoftmax_regression_data}, + {"logsoftmax_layer_regression_data", test_logsoftmax_layer_regression_data}, + // {"logsoftmax_performance", measure_logsoftmax_performance}, {NULL, NULL} };
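
The new small::LogSoftMax entry point and the expected-output computation in compute_logsoftmax_output() both evaluate log-softmax as three passes over the buffer: a point-wise exponential, a global sum, and a point-wise add of the scalar -log(sum) back onto the original input. A minimal scalar sketch of that math, independent of SMaLL's packed buffers and kernels (the function name and the plain std::vector interface are illustrative only, not part of the library):

#include <cmath>
#include <cstddef>
#include <vector>

// out[i] = x[i] - log(sum_j exp(x[j])), computed the same way the kernel
// pipeline does it: exp pass, global sum, then add the scalar -log(sum).
std::vector<float> log_softmax_reference(std::vector<float> const &x)
{
    std::vector<float> out(x.size());

    // Pass 1 (OP_EXP): point-wise exponential into the output buffer.
    for (std::size_t i = 0; i < x.size(); ++i)
        out[i] = std::exp(x[i]);

    // Pass 2 (OP_ADD, global reduction): a single sum, then fold in the log.
    float sum = 0.0f;
    for (float v : out) sum += v;
    float const neg_log_sum = -std::log(sum);

    // Pass 3 (OP_EWISE_ADD_SCALAR): re-read the *input* and add the scalar.
    for (std::size_t i = 0; i < x.size(); ++i)
        out[i] = x[i] + neg_log_sum;

    return out;
}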
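
The layer wrapper follows the same calling pattern as SoftMaxLayer and is exercised by run_logsoftmax_layer_config() above. A minimal usage sketch, assuming float support and the Tensor/shape_type conventions used in test_softmax.cpp (the shape values and the function name are illustrative):

// Assumes the SMaLL headers used by test_softmax.cpp are already included.
void logsoftmax_layer_example()
{
    using Buffer = small::FloatBuffer;

    // {batches, channels, height, width}; C_i must be a multiple of FLOAT_C_ib.
    small::shape_type input_shape({1UL, 16UL, 3UL, 3UL});
    size_t const num_elts = 16UL*3UL*3UL;

    small::Tensor<Buffer> input(input_shape);
    small::init(input.buffer(), num_elts);              // fill with test values

    small::LogSoftMaxLayer<Buffer> layer(input_shape);  // output_shape == input_shape
    small::Tensor<Buffer> output(layer.output_size());

    // out = in - log(sum(exp(in))) over the whole packed input
    layer.compute_output({&input}, &output);
}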
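
Ports to other platforms need definitions of the two new macros; the reference-platform versions define the contract. A scalar model of what any FLOAT_EWISE_ADD_SCALAR_TILE_C / FLOAT_EWISE_ADD_SCALAR_END_C implementation must compute over an accumulator tile (the free-function form is illustrative; the real macros operate on the kernel's register tile or c_tile pointer):

// Broadcast-add a scalar to every element of a row-major W_ob x C_ob tile.
// TILE_C covers a full tile of W_ob pixels; END_C does the same for the
// W_last leftover pixels at the right edge.
static inline void ewise_add_scalar_tile(float *c_tile, float scalar,
                                         unsigned W_ob, unsigned C_ob)
{
    for (unsigned kk = 0; kk < W_ob; ++kk)       // output pixels
    {
        for (unsigned jj = 0; jj < C_ob; ++jj)   // channels within the block
        {
            c_tile[kk * C_ob + jj] += scalar;
        }
    }
}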