diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index c08ea5eb1df1..aad770120201 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -172,6 +172,16 @@ if(BUILD_FOR_HEXAGON) list(APPEND TVM_RUNTIME_LINKER_LIBS -Wl,--whole-archive ${USE_HEXAGON_SDK}/libs/qhl/prebuilt/hexagon_toolv84_v68/libqhmath.a -Wl,--no-whole-archive) endif() + + # Hand-written ops + file_glob_append(RUNTIME_HEXAGON_SRCS + "${TVMRT_SOURCE_DIR}/hexagon/ops/*.cc" + ) + + set_source_files_properties( + "${TVMRT_SOURCE_DIR}/hexagon/ops/conv2d_fp16_hvx.cc" + PROPERTIES COMPILE_FLAGS "-mhvx" + ) endif() if(USE_HEXAGON_RPC) diff --git a/include/tvm/runtime/hexagon/ops/conv2d.h b/include/tvm/runtime/hexagon/ops/conv2d.h new file mode 100644 index 000000000000..d759149727e8 --- /dev/null +++ b/include/tvm/runtime/hexagon/ops/conv2d.h @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include + +#ifndef TVM_RUNTIME_HEXAGON_OPS_CONV2D_H_ +#define TVM_RUNTIME_HEXAGON_OPS_CONV2D_H_ + +namespace tvm { +namespace runtime { +namespace hexagon { +static constexpr auto hexagon_device = DLDevice{static_cast(kDLHexagon), 0}; + +// Standalone DLTensor: the standalone-ness means that this object owns the shape +// (as opposed to a DLTensor). +template +class SDLTensor : public DLTensor { + public: + SDLTensor(void* data_ptr, DLDataType data_type, void* data_space, const int64_t* data_dims) + : SDLTensor(data_ptr, data_type, data_space) { + for (size_t i = 0; i < NDIM; ++i) dims[i] = data_dims[i]; + } + + SDLTensor(void* data_ptr, DLDataType data_type, void* data_space, + std::initializer_list data_dims) + : SDLTensor(data_ptr, data_type, data_space, data_dims.begin()) {} + + void* GetDataSpace() const { return data_space; } + + private: + /** + * @brief Construct SDLTensor + * + * @param data_ptr Either points to the same memory as data_space or an array of pointers to the + * start of each chunk of weight. Since weights can be of varying sizes, this array could contain + * the pointer to each chunk of memory + * @param data_type data type of the elements in Tensor + * @param data_space is meant to store the pointer returned from AllocDataSpace and can be freed + * by passing it to FreeDataSpace + */ + SDLTensor(void* data_ptr, DLDataType data_type, void* data_space) : data_space(data_space) { + data = data_ptr; + device = hexagon_device; + ndim = NDIM; + dtype = data_type; + shape = dims; + strides = nullptr; + byte_offset = 0; + } + + void* data_space = nullptr; + int64_t dims[NDIM]; +}; + +inline void* to_ptr(uintptr_t v) { return reinterpret_cast(v); } + +inline uintptr_t to_uint(void* ptr) { return reinterpret_cast(ptr); } + +constexpr int xyc_to_sm_16b(int y, int x, int c) { + // Map y,x,c coordinates within a block to the offset (in 16-bit elements) + // from the beginning of the block in spatial-major layout. + // 10-bit spatial mask: yyyxcccccx + assert(y >= 0 && x >= 0 && c >= 0); + return y << 7 | (x & 2) << 5 | c << 1 | (x & 1); +} + +constexpr int hwio_to_sm_16b(int width, int y, int x, int i, int o) { + // Map y,x,i,o coordinates within a chunk (assuming the origin at the + // top-left spatial corner) to the offset (in 16-bit elements) from the + // beginning of the chunk in spatial-major layout. + // Spatial mask: p..piiiioooooi, where p..p are position bits. + assert(width >= 1); + assert(y >= 0 && x >= 0 && i >= 0 && o >= 0); + int p = y * width + (width - 1 - x); + return p << 10 | (i & 0x1e) << 5 | o << 1 | (i & 1); +} + +inline constexpr int round_up(int v, int p2) { return (v + p2 - 1) & -p2; } + +// Returns the block address at the given index +// Assumptions +// - The data type of tensor is fp16 +// - There is only one batch, and hence n==0 +inline uintptr_t nhwc_at(const DLTensor& a, int n, int y, int x, int c) { + if (y < 0 || y >= a.shape[1]) return uintptr_t(0); + auto p = static_cast(a.data); + assert(n == 0); + return p[y * a.shape[2] * a.shape[3] + x * a.shape[3] + c]; +} + +// Returns the address of the chunk stored at given index +// Assumptions +// - The data type of tensor is fp16 +inline uintptr_t hwio_at(const DLTensor& f, int y, int x, int i, int o) { + auto p = static_cast(f.data); + return p[y * f.shape[1] * f.shape[2] * f.shape[3] + x * f.shape[2] * f.shape[3] + i * f.shape[3] + + o]; +} + +/** + * @brief Function to "blockize" the flat input data + * The term "blockize" is used to mention that the data is stored in non-contiguous blocks + * + * The input is mapped into the below mentioned layout (notation similar to index map used for + * transform layout): + * + * lambda n, h, w, c: n, h//8, w//4, c//32, AXIS_SEPARATOR, h%8, (w%4)//2, c%32, w%2 + * + * where AXIS_SEPARATOR represents split up in the physical layout + * + * @param out Pre-allocated output memory pointer + * @param inp_flat Flat input data pointer + * @param height + * @param width + * @param depth + */ +void blockize_hwc_16b(void* out, void* inp_flat, int height, int width, int depth); + +/** + * @brief Convert back from non-contguous layout to a flat layout + * + * @param out_flat Pre-allocated output memory pointer + * @param inp Blockized input data pointer + * @param height + * @param width + * @param depth + */ +void deblockize_hwc_16b(void* out_flat, void* inp, int height, int width, int depth); + +/** + * @brief Convert the layout of weights from flat to "chunked". The term chunked is explained below: + * + * Weights are packed into the below mentioned layout (notation similar to index map): + * Since weights cannot be exactly represented into a index map notation, the + * base split up is mentioned below with a few gotchas + * + * lambda h, w, i, o: h//8, w//4, o//32, i//32, h%8, w%4, (i%32)//2, o%32, i%2 + * + * The gotchas are: + * - (w%4) is actually stored in the right to left order, as in 3,2,1,0 instead of 0,1,2,3 + * - The h%8 and (w%4) dimensions are not padded up, leading to chunks of different sizes + * (thereby the name "chunked" instead of packed) + * - The thinnest chunk of width is stored first. For example, if a kernel is 5x5, the first + * chunk along the width has size 1 (representing index 0) and then next one has size 4 + * representing indices (1,2,3,4) + * + * @param out_ptr Base pointer table to be filled with the list of pointers to the first addresses + * of the "chunked" weights + * @param out_ptr_size The number of chunks + * @param out Pointer to pre-allocated output memory + * @param inp Pointer to flat input data + * @param height + * @param width + * @param idepth + * @param odepth + */ +void chunkify_hwio_16b(void** out_ptr, int out_ptr_size, void* out, void* inp, int height, + int width, int idepth, int odepth); + +SDLTensor<4> prepare_nhwc(tvm::runtime::DeviceAPI* device_api, const DLTensor* nhwc_flat, + bool copy_data); + +int calculate_num_weight_chunks(int64_t* shape_hwio); + +SDLTensor<4> prepare_hwio(tvm::runtime::DeviceAPI* device_api, const DLTensor* hwio_flat, + int num_chunks, void** ptr_table); + +template +void release(tvm::runtime::DeviceAPI* device_api, const SDLTensor& tensor) { + if (auto* data_space = tensor.GetDataSpace()) { + device_api->FreeDataSpace(hexagon_device, data_space); + } +} + +} // namespace hexagon +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_HEXAGON_OPS_CONV2D_H_ diff --git a/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc b/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc new file mode 100644 index 000000000000..cf4dc43c6515 --- /dev/null +++ b/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "tvm/runtime/hexagon/ops/conv2d.h" + +// Current limitations: +// - N in NHWC must be 1 +// - dilated convolutions are not supported +// - Bias is not accepted +// - Optional "relu" is not performed + +// Packed arguments: +// 0: DLTensor activations (NHWC) +// 1: DLTensor weights (HWIO) +// 2: int offset_top +// 3: int offset_left +// 4: int stride_h +// 5: int stride_w +// 6: DLTensor output (NHWC) +extern "C" int conv2d_packed_fp16(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, + int out_code, void* res_handle); + +namespace tvm { +namespace runtime { +namespace hexagon { + +/** + * @brief Returns the pointer to the element within the given block + * assuming fp16 type and speicific layout as mentioned in blockize_hwc_16b. + * All the below params are explained with the same layout assumption + * + * @param block_out_y y-index of block + * @param block_out_x x-index of block + * @param block_out_c c-index of block + * @param yi height-offset within the block + * @param xio outer width offset within the block + * @param ci channel offset within the block + * @param xii inner width offset within the block + * @param block base DLTensor + * + * @return The pointer to the element within the given block + */ +static inline uint16_t* getElementPtr(int block_out_y, int block_out_x, int block_out_c, int yi, + int xio, int ci, int xii, const DLTensor& tensor) { + auto block_ptr = nhwc_at(tensor, 0, block_out_y, block_out_x, block_out_c); + auto block_offset = yi * 128 + xio * 64 + ci * 2 + xii; + auto first_element_ptr = reinterpret_cast(block_ptr); + return first_element_ptr + block_offset; +} + +/** + * @brief Compute 2 vectors with ones in the even and odd lanes + * + * Output vectors are: + * vector 1 = [0xFFFF,0x0000,0xFFFFF,0x0000,...,0xFFFF,0x0000] + * vector lanes = [ 0 , 2 , 3 , 4 ,..., 62 , 63 ] + * + * vector 2 = [0x0000,0xFFFF,0x0000,0xFFFFF,...,0xFFFF,0x0000] + * vector lanes = [ 0 , 2 , 3 , 4 ,..., 62 , 63 ] + * + * @return Return the 2 vectors + */ +inline std::pair getOddEvenOnes() { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_Vh_vsplat_R(0xFFFF); + + HVX_Vector v1e = Q6_Vh_vshuffe_VhVh(v0, v1); + HVX_Vector v1o = Q6_V_vnot_V(v1e); + return {v1e, v1o}; +} + +/** + * @brief Return the input vector filled with the 2 channel elements(which is the 1st and 3rd + * element) from base_ptr filled up 32 times to get 64 elements + * + * 1. It's generated by first creating 2 vectors "splatted" with the 2 required elements + * 2. Then we andd it with vectors containing all ones (0xFFFF) in the even and odd lanes + * 3. Finally those 2 vectors are OR'ed together + * + * @param base_ptr pointer to the first of the 2 channel elements to be filled + * + * @return input vector + */ +inline HVX_Vector getInputVector(uint16_t* base_ptr) { + HVX_Vector v1 = Q6_Vh_vsplat_R(base_ptr[0]); + HVX_Vector v2 = Q6_Vh_vsplat_R(base_ptr[2]); + + auto oddEvenOnes = getOddEvenOnes(); + auto v1e = oddEvenOnes.first; + auto v1o = oddEvenOnes.second; + + HVX_Vector v_even_vals = Q6_V_vand_VV(v1, v1e); + HVX_Vector v_odd_vals = Q6_V_vand_VV(v2, v1o); + + return Q6_V_vor_VV(v_even_vals, v_odd_vals); +} + +/** + * @brief Return the Output vector which contains the 32 output channels in the even lanes + * + * The output vector is commputed as: + * 1. vector multiply(vmpy) of input and weights + * 2. Rotate the vector right by 1 element and add with the first vector to add the 2 input channels + * 3. Then convert the results back from qfloat16 to IEEE half-precision float + * 4. The added values are in even lanes, so zero out the odd lanes by anding with ones in even + * lanes and return + * + * @param act_vec Input activations vector + * @param wgt_vec Weights vector + * + * @return output vector with 32 output channels even lanes + */ +inline HVX_Vector computeOuputVector(HVX_Vector act_vec, HVX_Vector wgt_vec) { + HVX_Vector v_res = Q6_Vqf16_vmpy_VhfVhf(act_vec, wgt_vec); // result is in qf16 + HVX_Vector v_rot = Q6_V_vror_VR(v_res, 2); + HVX_Vector v_reduced = Q6_Vqf16_vadd_Vqf16Vqf16(v_res, v_rot); + HVX_Vector v_hf = Q6_Vhf_equals_Vqf16(v_reduced); + HVX_Vector v1e = getOddEvenOnes().first; + HVX_Vector v_reduced_even_lanes = Q6_V_vand_VV(v_hf, v1e); + return v_reduced_even_lanes; +} + +static int round_down(int v, int base) { return v - (v % base); } + +/** + * @brief Compute the convolution of inputs from cr_act, and weights from + * cr_filt to update the output to cr_out. The goal is to have an efficient + * HVX implementation + * + * Assumptions: + * ----------- + * - This implementation right now assumes that the dilation is 1 + * - there is zero padding or the input was already pre-padded. + * - block specific spatial padding is only expected at the end and hence + * pad_top and pad_left are not yet used + * - Relu activation is not used + * - Bias add is not done + * + * @param cr_out blockized output tensor with zeros already filled in + * @param cr_act blockized activations + * @param cr_filt Chunkified weights as returned from output of prepare_hwio + * @param out_shape Original output shape of the tensor before blockization + * @param act_shape Original input shape + * @param bias_flat Flat bias values and are not used right now + * TODO (quic-sanirudh) Add support for bias add + * @param filt_shape Original filter shape + * @param pad_shape Pad top and pad left shape + * @param relu Whether to apply relu after convolution, not done right now + * TODO (quic-sanirudh) Add support for relu activation + * @param zero_block A block filled with zeros + * + * @return + */ +void conv_layer_fp16_hvx(DLTensor& cr_out, const DLTensor& cr_act, // NOLINT(*) + const DLTensor& cr_filt, const DLTensor& out_shape, + const DLTensor& act_shape, const DLTensor& bias_flat, + const DLTensor& filt_shape, const DLTensor& pad_shape, bool relu, + int stride_h, int stride_w, uintptr_t zero_block) { + int64_t filt_height = filt_shape.shape[0]; + int64_t filt_width = filt_shape.shape[1]; + int64_t filt_idepth = filt_shape.shape[2]; + + int pad_top = pad_shape.shape[0]; + int pad_left = pad_shape.shape[1]; + LOG_INFO << "filt_height=" << filt_height << ", filt_width=" << filt_width + << ", filt_idepth=" << filt_idepth << ", pad_top=" << pad_top + << ", pad_left=" << pad_left << "\n"; + + ICHECK_LT(pad_top, 8) << "pad_top offset cannot be >= 8"; + ICHECK_LT(pad_left, 4) << "pad_left offset cannot be >= 4"; + + int a_height = cr_act.shape[1]; + int a_width = cr_act.shape[2]; + int a_depth = cr_act.shape[3]; + + int w_height = cr_filt.shape[0]; + int w_width = cr_filt.shape[1]; + + int o_depth = cr_out.shape[3]; + int b_depth = bias_flat.shape[0]; + + int o_height = cr_out.shape[1]; + int o_width = cr_out.shape[2]; + + int out_height = out_shape.shape[1]; + int out_width = out_shape.shape[2]; + + LOG_INFO << "a: 1x" << a_height << "x" << a_width << "x" << a_depth << ", w: " << w_height << "x" + << w_width << "x" << static_cast(cr_filt.shape[2]) << "x" + << static_cast(cr_filt.shape[3]) << ", o: 1x" << o_height << "x" << o_width << "x" + << o_depth << ", b: " << b_depth << ", out_shape: " << out_height << "x" << out_width + << "\n"; + + ICHECK_EQ(a_depth, cr_filt.shape[2]) << "input depth should match weights input channels"; + ICHECK_EQ(o_depth, cr_filt.shape[3]) << "output depth should match the weights output channel"; + + int rd = round_down(filt_width, 4); + int wgt_chunk_thin_width = filt_width - rd; + + /* + * Compute the output vector of either 1 or 2 elements along the width and max 32 elements along + * the depth to constitue a maximum of 64 elements + * + * The weights are loaded directly in the order they're stored, which results + * in 2 input channels and 32 output channels + * + * Weights vector illustration: + * ------- ------ ------------ + * weights_vec = [0-0,0-1,1-0,1-1,2-0,2-1,3-0,3-1,4-0,4-1,...,31-0,31-1] -> This is the + * vector representation of weights, where the elements are represented as + * "out_channel-input_channel" + * + * + * Same 2 input channels have to be multiplied across all output channels in the weights. + * + * Activations vector would thus be: + * ----------- ------ ----- ---- -- + * act_vec = [i0,i1,i0,i1,i0,i1,...,i0,i1] - 2 elements of the input channels broadcasted 32 times + * to fill 64 elements of the vector + * + * + * Thus the computation is just a vmpy(act_vec,weights_vec) followed by a some rearrangement to + * add every pair of 16b lanes in the vector to reduce along the input channels + * + * This result is added to the result of the next pair of input channels all the way until we + * have reduced across the entire input channels. + * + * Then the same vector is added to the results of the following elements along the width and + * height to finally get 32 elements representing 32 output channels. + * + * Since the output block also has the 8h2w32c2w format, the 32 elements of the next element + * along the width is also added into the the same vector such that the first 32 channel elements + * occupy the even lanes and the next 32 occupy the odd lanes to form a single 64-element vector + * which is then stored + */ + auto computeConv = [filt_height, filt_width, wgt_chunk_thin_width, filt_idepth, stride_h, + stride_w, &cr_out, &cr_act, &cr_filt](int out_act_y, int out_act_x, int out_c, + int h, int wo, bool skip_wi_1 = false) { + auto out_element_ptr = getElementPtr(out_act_y, out_act_x, out_c, h, wo, 0, 0, cr_out); + + LOG_INFO << "out_act_y: " << out_act_y << ", out_act_x: " << out_act_x << ", out_c: " << out_c + << ", h: " << h << ", wo: " << wo << " out_element_ptr: " << out_element_ptr; + + HVX_Vector* out_vector = reinterpret_cast(out_element_ptr); + HVX_Vector existing_out_vec = *out_vector; + + for (int fh = 0; fh < filt_height; ++fh) { + for (int fw = 0; fw < filt_width; ++fw) { + int fch = fh / 8; + int fcw = 0; + if (fw >= wgt_chunk_thin_width) { + fcw = (fw - wgt_chunk_thin_width) / 4 + 1; + } + int fx = (fw < wgt_chunk_thin_width) ? fw : ((fw - wgt_chunk_thin_width) % 4); + int fy = fh % 8; + for (int c = 0; c < round_up(filt_idepth, 2); c += 2) { + int out_act_cc = c / 32; + int ci = c % 32; + auto wgt_chunk = hwio_at(cr_filt, fch, fcw, out_act_cc, out_c); + + // Find weight chunk offset ptr + int max_x = (fcw == 0) ? wgt_chunk_thin_width : 4; + + int wi = 0; + + int out_width_idx = out_act_x * 4 + wo * 2 + wi; + int act_width_access_idx = out_width_idx * stride_w + fw; + int true_out_act_x = act_width_access_idx / 4; + int true_wo = (act_width_access_idx % 4) / 2; + int true_wi = act_width_access_idx % 2; + + int out_height_idx = out_act_y * 8 + h; + int act_height_access_idx = out_height_idx * stride_h + fh; + int true_out_act_y = act_height_access_idx / 8; + int true_h = act_height_access_idx % 8; + + int act_channel_idx = out_act_cc * 32 + ci; + + auto act_element_ptr = getElementPtr(true_out_act_y, true_out_act_x, out_act_cc, true_h, + true_wo, ci, true_wi, cr_act); + HVX_Vector act_vec = getInputVector(act_element_ptr); + + auto wgt_chunk_offset = hwio_to_sm_16b(max_x, fy, fx, ci, 0); + auto base_chunk_ptr = reinterpret_cast(wgt_chunk); + auto chunk_ptr = base_chunk_ptr + wgt_chunk_offset; + + LOG_INFO << "act: 0x" << act_height_access_idx << "x" << act_width_access_idx << "x" + << act_channel_idx << ", wgt: " << fh << "x" << fw << "x" << act_channel_idx + << "x" << out_c * 32 << ", out: 0x" << out_height_idx << "x" << out_width_idx + << "x" << out_c * 32 << ", wgt_chunk_offset: " << wgt_chunk_offset; + + const HVX_Vector* weights_vec_ptr = reinterpret_cast(chunk_ptr); + HVX_Vector weights_vec = *weights_vec_ptr; + + HVX_Vector reduced_vec_even_elements = computeOuputVector(act_vec, weights_vec); + + if (!skip_wi_1) { + wi = 1; + + out_width_idx = out_act_x * 4 + wo * 2 + wi; + act_width_access_idx = out_width_idx * stride_w + fw; + true_out_act_x = act_width_access_idx / 4; + true_wo = (act_width_access_idx % 4) / 2; + true_wi = act_width_access_idx % 2; + + act_element_ptr = getElementPtr(true_out_act_y, true_out_act_x, out_act_cc, true_h, + true_wo, ci, true_wi, cr_act); + act_vec = getInputVector(act_element_ptr); + + LOG_INFO << "act: 0x" << act_height_access_idx << "x" << act_width_access_idx << "x" + << act_channel_idx << ", wgt: " << fh << "x" << fw << "x" << act_channel_idx + << "x" << out_c * 32 << ", out: 0x" << out_height_idx << "x" << out_width_idx + << "x" << out_c * 32 << ", wgt_chunk_offset: " << wgt_chunk_offset; + + HVX_Vector reduced_vec_odd_elements = computeOuputVector(act_vec, weights_vec); + reduced_vec_odd_elements = Q6_V_vror_VR(reduced_vec_odd_elements, -2); + HVX_Vector out_final = Q6_V_vor_VV(reduced_vec_even_elements, reduced_vec_odd_elements); + + HVX_Vector out_vec_qf16 = Q6_Vqf16_vadd_VhfVhf(out_final, existing_out_vec); + existing_out_vec = Q6_Vhf_equals_Vqf16(out_vec_qf16); + } else { + HVX_Vector out_vec_qf16 = + Q6_Vqf16_vadd_VhfVhf(reduced_vec_even_elements, existing_out_vec); + existing_out_vec = Q6_Vhf_equals_Vqf16(out_vec_qf16); + } + } + } + } + *out_vector = existing_out_vec; + }; + + auto computeFullWidth = [&computeConv](int out_y, int out_x, int out_c, int h) { + for (int wo = 0; wo < 2; ++wo) { + computeConv(out_y, out_x, out_c, h, wo); + } + }; + + auto computePartialWidth = [out_width, o_width, &computeConv](int out_y, int out_c, int h) { + int out_x = o_width - 1; + int wo = 0; + for (; wo < (out_width % 4) / 2; ++wo) { + computeConv(out_y, out_x, out_c, h, wo); + } + + if (out_width % 2) { + computeConv(out_y, out_x, out_c, h, wo, true /* skip_wi_1 */); + } + }; + + for (int out_c = 0; out_c < cr_filt.shape[3]; ++out_c) { + for (int out_act_y = 0; out_act_y < out_height / 8; ++out_act_y) { + int out_y = out_act_y; + for (int out_act_x = 0; out_act_x < out_width / 4; ++out_act_x) { + int out_x = out_act_x; + for (int h = 0; h < 8; ++h) { + computeFullWidth(out_y, out_x, out_c, h); + } + } + + for (int h = 0; h < 8; ++h) { + computePartialWidth(out_y, out_c, h); + } + } + + int out_y = o_height - 1; + for (int h = 0; h < out_height % 8; ++h) { + for (int out_act_x = 0; out_act_x < out_width / 4; ++out_act_x) { + int out_x = out_act_x; + computeFullWidth(out_y, out_x, out_c, h); + } + computePartialWidth(out_y, out_c, h); + } + } +} +} // namespace hexagon +} // namespace runtime +} // namespace tvm + +int conv2d_packed_fp16(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, + int out_code, void* res_handle) { + namespace hexagonrt = tvm::runtime::hexagon; + ICHECK_EQ(num_args, 7) << "Unexpected number of arguments"; + ICHECK_EQ(type_codes[0], kTVMDLTensorHandle) + << "First argument is expected to be the input tensor"; // Input activations + ICHECK_EQ(type_codes[1], kTVMDLTensorHandle) + << "Second argument is expected to be the weights tensor"; // Weights + ICHECK_EQ(type_codes[2], kDLInt) + << "Third argument is expected to be the pad_top offset"; // pad_top offset + ICHECK_EQ(type_codes[3], kDLInt) + << "Fourth argument is expected to be the pad_left offset"; // pad_left offset + ICHECK_EQ(type_codes[4], kDLInt) << "Fifth argument is expected to be the stride_h"; // stride_h + ICHECK_EQ(type_codes[5], kDLInt) << "Sixth argument is expected to be the stride_w"; // stride_w + ICHECK_EQ(type_codes[6], kTVMDLTensorHandle) + << "Seventh argument is expected to be the output tensor"; // output + + auto* act_flat = static_cast(args[0].v_handle); + auto* wgt_flat = static_cast(args[1].v_handle); + auto* out_flat = static_cast(args[6].v_handle); + + // Temporary assertion until multiple batches are supported + ICHECK_EQ(act_flat->shape[0], 1) << "Input batch size more than 1 is not supported yet"; + + // Temporary assertion until multiple batches are supported + ICHECK_EQ(out_flat->shape[0], 1) << "Output batch size more than 1 is not supported yet"; + + int pad_top = args[2].v_int64; + int pad_left = args[3].v_int64; + int stride_h = args[4].v_int64; + int stride_w = args[5].v_int64; + + LOG_INFO << "act.shape=" << act_flat->shape[0] << "x" << act_flat->shape[1] << "x" + << act_flat->shape[2] << "x" << act_flat->shape[3] + << ", wgt.shape=" << wgt_flat->shape[0] << "x" << wgt_flat->shape[1] << "x" + << wgt_flat->shape[2] << "x" << wgt_flat->shape[3] << ", pad_top=" << pad_top + << ", pad_left=" << pad_left; + + auto* device_api = tvm::runtime::DeviceAPI::Get(hexagonrt::hexagon_device, false); + ICHECK(device_api != nullptr); + tvm::runtime::String vtcm_scope = "global.vtcm"; + + auto act_vtcm = hexagonrt::prepare_nhwc(device_api, act_flat, /*copy_data=*/true); + + ICHECK_NE(wgt_flat->shape[0], 0) << "Weights height should not be zero"; + ICHECK_NE(wgt_flat->shape[1], 0) << "Weights width should not be zero"; + ICHECK_NE(wgt_flat->shape[2], 0) << "Weights input channels should not be zero"; + ICHECK_NE(wgt_flat->shape[3], 0) << "Weights output channels should not be zero"; + int num_wgt_chunks = hexagonrt::calculate_num_weight_chunks(wgt_flat->shape); + LOG_INFO << "num_wgt_chunks: " << num_wgt_chunks; + auto wgt_ptr_table = + reinterpret_cast(__builtin_alloca(num_wgt_chunks * sizeof(uintptr_t))); + auto wgt_vtcm = hexagonrt::prepare_hwio(device_api, wgt_flat, num_wgt_chunks, wgt_ptr_table); + + auto out_vtcm = hexagonrt::prepare_nhwc(device_api, out_flat, /*copy_data=*/false); + + // Prepare zero_block + int64_t block_nbytes = 2048; + void* zero_block = device_api->AllocDataSpace(hexagonrt::hexagon_device, 1, &block_nbytes, + tvm::runtime::DataType::UInt(8), vtcm_scope); + memset(zero_block, 0, 2048); + + // FIXME: Setting bias to zero_block: this works for up to 256 output channels. + auto bias_flat = + hexagonrt::SDLTensor<1>(zero_block, wgt_flat->dtype, zero_block, &wgt_flat->shape[3]); + auto act_shape = hexagonrt::SDLTensor<4>(nullptr, act_flat->dtype, nullptr, act_flat->shape); + auto filt_shape = hexagonrt::SDLTensor<4>(nullptr, wgt_flat->dtype, nullptr, wgt_flat->shape); + auto pad_shape = hexagonrt::SDLTensor<2>(nullptr, act_flat->dtype, nullptr, {pad_top, pad_left}); + auto out_shape = hexagonrt::SDLTensor<4>(nullptr, out_flat->dtype, nullptr, out_flat->shape); + bool relu = false; + + hexagonrt::conv_layer_fp16_hvx(out_vtcm, act_vtcm, wgt_vtcm, out_shape, act_shape, bias_flat, + filt_shape, pad_shape, relu, stride_h, stride_w, + hexagonrt::to_uint(zero_block)); + + hexagonrt::deblockize_hwc_16b(out_flat->data, out_vtcm.data, out_flat->shape[1], + out_flat->shape[2], out_flat->shape[3]); + + device_api->FreeDataSpace(hexagonrt::hexagon_device, zero_block); + hexagonrt::release(device_api, out_vtcm); + hexagonrt::release(device_api, wgt_vtcm); + hexagonrt::release(device_api, act_vtcm); + + return 0; +} diff --git a/src/runtime/hexagon/ops/conv_utils.cc b/src/runtime/hexagon/ops/conv_utils.cc new file mode 100644 index 000000000000..e1ec1e17277d --- /dev/null +++ b/src/runtime/hexagon/ops/conv_utils.cc @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "tvm/runtime/hexagon/ops/conv2d.h" + +namespace tvm { +namespace runtime { +namespace hexagon { + +/** + * @brief Function to "blockize" the flat input data + * The term "blockize" is used to mention that the data is stored in non-contiguous blocks + * + * The input is mapped into the below mentioned layout (notation similar to index map used for + * transform layout): + * + * lambda n, h, w, c: n, h//8, w//4, c//32, AXIS_SEPARATOR, h%8, (w%4)//2, c%32, w%2 + * + * where AXIS_SEPARATOR represents split up in the physical layout + * + * @param out Pre-allocated output memory pointer + * @param inp_flat Flat input data pointer + * @param height + * @param width + * @param depth + */ +void blockize_hwc_16b(void* out, void* inp_flat, int height, int width, int depth) { + auto inp_data = static_cast(inp_flat); + auto out_data = static_cast(out); + const int stride_x = depth; + const int stride_y = stride_x * width; + + for (int cy = 0; cy < height; cy += 8) { + for (int cx = 0; cx < width; cx += 4) { + for (int cc = 0; cc < depth; cc += 32) { + auto block = reinterpret_cast(*out_data++); + int max_y = std::min(8, height - cy); + int max_x = std::min(4, width - cx); + int max_c = std::min(32, depth - cc); + for (int y = 0; y < max_y; ++y) { + for (int x = 0; x < max_x; ++x) { + for (int c = 0; c < max_c; ++c) { + block[xyc_to_sm_16b(y, x, c)] = + inp_data[(cy + y) * stride_y + (cx + x) * stride_x + (cc + c)]; + } + for (int c = max_c; c < 32; ++c) block[xyc_to_sm_16b(y, x, c)] = 0; + } + for (int x = max_x; x < 4; ++x) { + for (int c = 0; c < 32; ++c) block[xyc_to_sm_16b(y, x, c)] = 0; + } + } + + for (int y = max_y; y < 8; ++y) + for (int x = 0; x < 4; ++x) + for (int c = 0; c < 32; ++c) block[xyc_to_sm_16b(y, x, c)] = 0; + } // cc + } // cx + } // cy +} + +/** + * @brief Convert back from non-contguous layout to a flat layout + * + * @param out_flat Pre-allocated output memory pointer + * @param inp Blockized input data pointer + * @param height + * @param width + * @param depth + */ +void deblockize_hwc_16b(void* out_flat, void* inp, int height, int width, int depth) { + uintptr_t* inp_data = static_cast(inp); + uint16_t* out_data = static_cast(out_flat); + const int stride_x = depth; + const int stride_y = stride_x * width; + + for (int cy = 0; cy < height; cy += 8) { + for (int cx = 0; cx < width; cx += 4) { + for (int cc = 0; cc < depth; cc += 32) { + auto block = reinterpret_cast(*inp_data); + int max_y = std::min(8, height - cy); + int max_x = std::min(4, width - cx); + int max_c = std::min(32, depth - cc); + for (int y = 0; y < max_y; ++y) { + for (int x = 0; x < max_x; ++x) { + for (int c = 0; c < max_c; ++c) { + out_data[(cy + y) * stride_y + (cx + x) * stride_x + (cc + c)] = + block[xyc_to_sm_16b(y, x, c)]; + } + } + } + + inp_data++; + } + } + } +} + +/** + * @brief Convert the layout of weights from flat to "chunked". The term chunked is explained below: + * + * Weights are packed into the below mentioned layout (notation similar to index map): + * Since weights cannot be exactly represented into a index map notation, the + * base split up is mentioned below with a few gotchas + * + * lambda h, w, i, o: h//8, w//4, o//32, i//32, h%8, w%4, (i%32)//2, o%32, i%2 + * + * The gotchas are: + * - (w%4) is actually stored in the right to left order, as in 3,2,1,0 instead of 0,1,2,3 + * - The h%8 and (w%4) dimensions are not padded up, leading to chunks of different sizes + * (thereby the name "chunked" instead of packed) + * - The thinnest chunk of width is stored first. For example, if a kernel is 5x5, the first + * chunk along the width has size 1 (representing index 0) and then next one has size 4 + * representing indices (1,2,3,4) + * + * @param out_ptr Base pointer table to be filled with the list of pointers to the first addresses + * of the "chunked" weights + * @param out_ptr_size The number of chunks + * @param out Pointer to pre-allocated output memory + * @param inp Pointer to flat input data + * @param height + * @param width + * @param idepth + * @param odepth + */ +void chunkify_hwio_16b(void** out_ptr, int out_ptr_size, void* out, void* inp, int height, + int width, int idepth, int odepth) { + auto inp_data = static_cast(inp); + auto out_data = static_cast(out); + const int stride_i = odepth; + const int stride_x = stride_i * idepth; + const int stride_y = stride_x * width; + + for (int cy = 0; cy < height; cy += 8) { + // In the chunkified tensor, the chunks are ordered in increasing + // x order, but they start from the thin one. + for (int cx = width - round_up(width, 4); cx < width; cx += 4) { + int cx0 = std::max(0, cx); + for (int ci = 0; ci < idepth; ci += 32) { + for (int co = 0; co < odepth; co += 32) { + int max_y = std::min(8, height - cy); + int max_x = std::min(4, cx + 4 - cx0); + int max_i = std::min(32, idepth - ci); + int max_o = std::min(32, odepth - co); + + auto chunk = reinterpret_cast(out_data); + for (int y = 0; y < max_y; ++y) { + for (int x = max_x - 1; x >= 0; --x) { + for (int i = 0; i < max_i; ++i) { + for (int o = 0; o < max_o; ++o) { + chunk[hwio_to_sm_16b(max_x, y, x, i, o)] = + inp_data[(cy + y) * stride_y + (cx0 + x) * stride_x + (ci + i) * stride_i + + (co + o)]; + } + for (int o = max_o; o < 32; ++o) chunk[hwio_to_sm_16b(max_x, y, x, i, o)] = 0; + } + for (int i = max_i; i < 32; ++i) + for (int o = 0; o < 32; ++o) chunk[hwio_to_sm_16b(max_x, y, x, i, o)] = 0; + } + } + + *out_ptr++ = chunk; + out_data += max_y * max_x * 32 * 32; + out_ptr_size--; + assert(out_ptr_size >= 0); + } + } + } + } +} + +SDLTensor<4> prepare_nhwc(tvm::runtime::DeviceAPI* device_api, const DLTensor* nhwc_flat, + bool copy_data) { + tvm::runtime::String vtcm_scope = "global.vtcm"; + + // Allocate blocks for activations. We will use the block pointers + // directly from the allocated area. + int n = nhwc_flat->shape[0]; + int h = round_up(nhwc_flat->shape[1], 8); + int w = round_up(nhwc_flat->shape[2], 4); + int c = round_up(nhwc_flat->shape[3], 32); + int64_t shape_2d[2] = {(n * h * w * c) / (8 * 4 * 32), 8 * 4 * 32}; + void* nhwc_vtcm = + device_api->AllocDataSpace(hexagon_device, 2, shape_2d, nhwc_flat->dtype, vtcm_scope); + if (copy_data) { + blockize_hwc_16b(nhwc_vtcm, nhwc_flat->data, nhwc_flat->shape[1], nhwc_flat->shape[2], + nhwc_flat->shape[3]); + } + + return SDLTensor<4>(nhwc_vtcm, nhwc_flat->dtype, nhwc_vtcm, {n, h / 8, w / 4, c / 32}); +} + +SDLTensor<4> prepare_hwio(tvm::runtime::DeviceAPI* device_api, const DLTensor* hwio_flat, + int num_chunks, void** ptr_table) { + tvm::runtime::String vtcm_scope = "global.vtcm"; + + // Allocate one block for filter data. We will need to create our own + // pointer table. The reason is that filter chunks cannot be padded + // height- or width-wise, so filter chunks may have different sizes. + // A filter chunk is a block of size HxWx32x32, where H, W are at most + // height and width of a block respectively. + int h = hwio_flat->shape[0]; + int w = hwio_flat->shape[1]; + int i = round_up(hwio_flat->shape[2], 32); + int o = round_up(hwio_flat->shape[3], 32); + int64_t shape_1d[] = {h * w * i * o}; + void* hwio_vtcm = + device_api->AllocDataSpace(hexagon_device, 1, shape_1d, hwio_flat->dtype, vtcm_scope); + + chunkify_hwio_16b(ptr_table, num_chunks, hwio_vtcm, hwio_flat->data, hwio_flat->shape[0], + hwio_flat->shape[1], hwio_flat->shape[2], hwio_flat->shape[3]); + + return SDLTensor<4>(ptr_table, hwio_flat->dtype, hwio_vtcm, + {round_up(h, 8) / 8, round_up(w, 4) / 4, i / 32, o / 32}); +} + +int calculate_num_weight_chunks(int64_t* shape_hwio) { + int h = round_up(shape_hwio[0], 8); + int w = round_up(shape_hwio[1], 4); + int i = round_up(shape_hwio[2], 32); + int o = round_up(shape_hwio[3], 32); + + return (h * w * i * o) / (8 * 4 * 32 * 32); +} + +} // namespace hexagon +} // namespace runtime +} // namespace tvm diff --git a/tests/cpp-runtime/hexagon/hexagon_fp16_utils_tests.cc b/tests/cpp-runtime/hexagon/hexagon_fp16_utils_tests.cc new file mode 100644 index 000000000000..3b922fa6c2a8 --- /dev/null +++ b/tests/cpp-runtime/hexagon/hexagon_fp16_utils_tests.cc @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "tvm/runtime/hexagon/ops/conv2d.h" + +using namespace tvm::runtime::hexagon; + +class HexagonUtilsTest : public ::testing::Test { + public: + void SetUp() override { + vtcm_scope = "global.vtcm"; + device_api = tvm::runtime::DeviceAPI::Get(hexagon_device, false); + float16.code = kDLFloat; + float16.bits = 16; + float16.lanes = 1; + } + + void setupTensor(std::tuple shape) { + auto [s1, s2, s3, s4] = shape; + tensor_shape[0] = s1; + tensor_shape[1] = s2; + tensor_shape[2] = s3; + tensor_shape[3] = s4; + int64_t shape_1d[1] = {s1 * s2 * s3 * s4}; + + flat_mem = device_api->AllocDataSpace(hexagon_device, 1, shape_1d, float16, vtcm_scope); + flat_mem_data = static_cast(flat_mem); + fill_vals(flat_mem_data, shape_1d[0]); + + flat_tensor.data = flat_mem; + flat_tensor.device = hexagon_device; + flat_tensor.ndim = 4; + flat_tensor.dtype = float16; + flat_tensor.shape = tensor_shape; + flat_tensor.strides = nullptr; + flat_tensor.byte_offset = 0; + } + + void TearDownTensor() { + if (flat_tensor.data) device_api->FreeDataSpace(hexagon_device, flat_mem); + } + + static void fill_vals(uint16_t* arr, int size) { + // Testing with uint16 instead of float16 as generating random float16 is not easy within c++ + uint16_t max = UINT16_MAX; + srand(std::time(0)); + for (int i = 0; i < size; ++i) { + arr[i] = static_cast(std::rand() % max); + } + } + + static int flattened_idx(int nn, int hh, int ww, int cc, int64_t* shape) { + int h = shape[1]; + int w = shape[2]; + int c = shape[3]; + return cc + c * (ww + w * (hh + h * (nn))); + } + + DLTensor flat_tensor; + void* flat_mem; + uint16_t* flat_mem_data; + tvm::runtime::DeviceAPI* device_api; + tvm::runtime::String vtcm_scope; + DLDataType float16; + int64_t tensor_shape[4]; +}; + +// Parameterized test fixture with 4 params representing n, h, w, c +class HexagonUtilsActivationsBlockizeTest + : public HexagonUtilsTest, + public ::testing::WithParamInterface, std::tuple>> {}; + +// TODO (quic-sanirudh): See if we can test with random generated indices +INSTANTIATE_TEST_SUITE_P( + BlockizeDeblockizeTestFixtures, HexagonUtilsActivationsBlockizeTest, + ::testing::Combine(::testing::Values(std::make_tuple(1, 14, 7, 60)), + ::testing::Values(std::make_tuple(0, 0, 0, 0), // first element + std::make_tuple(0, 7, 3, 31), // last element + // Remaining are random element tests + std::make_tuple(0, 13, 6, 59), + std::make_tuple(0, 0, 0, 32), std::make_tuple(0, 0, 4, 32), + std::make_tuple(0, 2, 3, 4), std::make_tuple(0, 5, 6, 7), + std::make_tuple(0, 10, 4, 12))), + [](const ::testing::TestParamInfo& info) { + // Can use info.param here to generate the test suffix + auto indices = std::get<1>(info.param); + int h = std::get<1>(indices); + int w = std::get<2>(indices); + int c = std::get<3>(indices); + // Generate test name as "hwc0x0x0" if the indices of hwc are 0,0,0 + std::string name = + "hwc" + std::to_string(h) + "x" + std::to_string(w) + "x" + std::to_string(c); + return name; + }); + +TEST_F(HexagonUtilsActivationsBlockizeTest, prepare_nhwc) { + auto shape = std::make_tuple(1, 14, 7, 60); + auto [n, h, w, c] = shape; + setupTensor(shape); + + // // copy_data is set to false here as there's a separate test for blockize when copy_data + // becomes true + auto blocked_tensor = prepare_nhwc(device_api, &flat_tensor, /*copy_data=*/false); + + EXPECT_EQ(blocked_tensor.shape[0], n); + EXPECT_EQ(blocked_tensor.shape[1], round_up(h, 8) / 8); + EXPECT_EQ(blocked_tensor.shape[2], round_up(w, 4) / 4); + EXPECT_EQ(blocked_tensor.shape[3], round_up(c, 32) / 32); + + TearDownTensor(); + release(device_api, blocked_tensor); +} + +TEST_P(HexagonUtilsActivationsBlockizeTest, blockize_hwc_16b) { + auto shape_tuple = std::get<0>(GetParam()); + setupTensor(shape_tuple); + auto [n, h, w, c] = shape_tuple; + int64_t shape[] = {n, h, w, c}; + + int h_rounded = round_up(h, 8); + int w_rounded = round_up(w, 4); + int c_rounded = round_up(c, 32); + int64_t shape_2d[2] = {(n * h_rounded * w_rounded * c_rounded) / (8 * 4 * 32), 8 * 4 * 32}; + + void* blocked_mem = device_api->AllocDataSpace(hexagon_device, 2, shape_2d, float16, vtcm_scope); + int64_t blocked_shape[] = {n, h_rounded / 8, w_rounded / 4, c_rounded / 32}; + blockize_hwc_16b(blocked_mem, flat_mem, h, w, c); + + std::function flatten = + HexagonUtilsActivationsBlockizeTest::flattened_idx; + + auto getBlockedElem = [&blocked_shape, blocked_mem, flatten](int nn, int hh, int ww, int cc) { + auto* blocks = static_cast(blocked_mem); + int blockIdx = flatten(nn, hh / 8, ww / 4, cc / 32, blocked_shape); + uint16_t* block = reinterpret_cast(blocks[blockIdx]); + return block[xyc_to_sm_16b(hh % 8, ww % 4, cc % 32)]; + }; + + auto [nn, hh, ww, cc] = std::get<1>(GetParam()); + + EXPECT_EQ(flat_mem_data[flattened_idx(nn, hh, ww, cc, shape)], getBlockedElem(nn, hh, ww, cc)); + + TearDownTensor(); + device_api->FreeDataSpace(hexagon_device, blocked_mem); +} + +TEST_P(HexagonUtilsActivationsBlockizeTest, deblockize_hwc_16b) { + auto shape_tuple = std::get<0>(GetParam()); + setupTensor(shape_tuple); + auto [n, h, w, c] = shape_tuple; + int64_t shape[] = {n, h, w, c}; + int64_t shape_1d[1] = {n * h * w * c}; + + int h_rounded = round_up(h, 8); + int w_rounded = round_up(w, 4); + int c_rounded = round_up(c, 32); + int64_t shape_2d[2] = {(n * h_rounded * w_rounded * c_rounded) / (8 * 4 * 32), 8 * 4 * 32}; + + void* blocked_mem = device_api->AllocDataSpace(hexagon_device, 2, shape_2d, float16, vtcm_scope); + blockize_hwc_16b(blocked_mem, flat_mem, h, w, c); + + void* deblocked_flat_mem = + device_api->AllocDataSpace(hexagon_device, 1, shape_1d, float16, vtcm_scope); + deblockize_hwc_16b(deblocked_flat_mem, blocked_mem, h, w, c); + auto* deblocked_flat_mem_data = static_cast(deblocked_flat_mem); + + auto [nn, hh, ww, cc] = std::get<1>(GetParam()); + + auto idx = flattened_idx(nn, hh, ww, cc, shape); + EXPECT_EQ(flat_mem_data[idx], deblocked_flat_mem_data[idx]); + + TearDownTensor(); + device_api->FreeDataSpace(hexagon_device, blocked_mem); + device_api->FreeDataSpace(hexagon_device, deblocked_flat_mem); +} + +class HexagonUtilsWeightsChunkifyTest + : public HexagonUtilsTest, + public ::testing::WithParamInterface, std::tuple>> {}; + +INSTANTIATE_TEST_SUITE_P( + ChunkifyDechunkifyTests, HexagonUtilsWeightsChunkifyTest, + ::testing::Combine(::testing::Values(std::make_tuple(3, 3, 40, 40)), + ::testing::Values(std::make_tuple(0, 0, 0, 0), // first element + std::make_tuple(2, 2, 39, 39), // Last element + // Remaining are random element tests + std::make_tuple(1, 1, 28, 33), + std::make_tuple(1, 2, 8, 38), + std::make_tuple(1, 0, 12, 15), + std::make_tuple(2, 1, 9, 22), std::make_tuple(0, 2, 6, 7), + std::make_tuple(1, 2, 3, 4))), + [](const ::testing::TestParamInfo& info) { + // Can use info.param here to generate the test suffix + auto indices = std::get<1>(info.param); + int h = std::get<0>(indices); + int w = std::get<1>(indices); + int i = std::get<2>(indices); + int o = std::get<3>(indices); + // Generate test name as "hwc0x0x0" if the indices of hwc are 0,0,0 + std::string name = "hwio" + std::to_string(h) + std::to_string(w) + "x" + std::to_string(i) + + "x" + std::to_string(o); + return name; + }); + +TEST_F(HexagonUtilsWeightsChunkifyTest, calculate_num_weight_chunks) { + int64_t shape[] = {3, 3, 40, 40}; + int num_wgt_chunks = calculate_num_weight_chunks(shape); + EXPECT_EQ(num_wgt_chunks, 4); +} + +TEST_F(HexagonUtilsWeightsChunkifyTest, prepare_hwio) { + int64_t shape[] = {3, 3, 40, 40}; + auto [h, w, i, o] = shape; + auto shape_tuple = std::make_tuple(h, w, i, o); + setupTensor(shape_tuple); + + // copy_data is set to false here as there's a separate test for blockize when copy_data becomes + // true + auto num_wgt_chunks = calculate_num_weight_chunks(shape); + auto wgt_ptr_table = + reinterpret_cast(__builtin_alloca(num_wgt_chunks * sizeof(uintptr_t))); + auto chunked_tensor = prepare_hwio(device_api, &flat_tensor, num_wgt_chunks, wgt_ptr_table); + + EXPECT_EQ(chunked_tensor.shape[0], round_up(h, 8) / 8); + EXPECT_EQ(chunked_tensor.shape[1], round_up(w, 4) / 4); + EXPECT_EQ(chunked_tensor.shape[2], round_up(i, 32) / 32); + EXPECT_EQ(chunked_tensor.shape[3], round_up(o, 32) / 32); + + release(device_api, chunked_tensor); + TearDownTensor(); +} + +TEST_P(HexagonUtilsWeightsChunkifyTest, chunkify_hwio_16b) { + auto [shape_tuple, indices] = GetParam(); + auto [h, w, i, o] = shape_tuple; + setupTensor(shape_tuple); + int64_t shape[] = {h, w, i, o}; + + auto num_wgt_chunks = calculate_num_weight_chunks(shape); + auto wgt_ptr_table = + reinterpret_cast(__builtin_alloca(num_wgt_chunks * sizeof(uintptr_t))); + auto chunked_tensor = prepare_hwio(device_api, &flat_tensor, num_wgt_chunks, wgt_ptr_table); + + int rd = w - (w % 4); // round down by 4 for width + int thin_w = w - rd; + + auto getChunkedElem = [thin_w, chunked_tensor](int hh, int ww, int ii, int oo) { + int fcw = 0; + if (ww >= thin_w) { + fcw = (ww - thin_w) / 4 + 1; + ww = (ww - thin_w) % 4; + } + auto chunk = hwio_at(chunked_tensor, hh / 8, fcw, ii / 32, oo / 32); + auto chunk_uint16 = reinterpret_cast(chunk); + return chunk_uint16[hwio_to_sm_16b(thin_w, hh % 8, ww, ii % 32, oo % 32)]; + }; + + auto [hh, ww, ii, oo] = indices; + + EXPECT_EQ(flat_mem_data[flattened_idx(hh, ww, ii, oo, shape)], getChunkedElem(hh, ww, ii, oo)); + release(device_api, chunked_tensor); +} diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_fp16_intrin.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_fp16_intrin.py new file mode 100644 index 000000000000..e8efdb369590 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_fp16_intrin.py @@ -0,0 +1,248 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Test conv2d HVX intrinsic implementation""" + +import numpy as np + +import tvm +import tvm.contrib.hexagon +from tvm.topi.testing import conv2d_nhwc_python + + +def build_conv2d(target): + """Build and the return the conv2d module that calls the intrinsic implementation""" + act_n, act_h, act_w, act_c = ( + tvm.te.var("act_n"), + tvm.te.var("act_h"), + tvm.te.var("act_w"), + tvm.te.var("act_c"), + ) + filt_h, filt_w, filt_o = tvm.te.var("filt_h"), tvm.te.var("fw"), tvm.te.var("filt_o") + off_l, off_t = tvm.te.var("off_l"), tvm.te.var("off_t") + stride_h, stride_w = tvm.te.var("stride_h"), tvm.te.var("stride_w") + + act_flat = tvm.te.placeholder( + shape=(act_n, act_h, act_w, act_c), dtype="float16", name="act_flat" + ) + wgt_flat = tvm.te.placeholder( + shape=(filt_h, filt_w, act_c, filt_o), dtype="float16", name="wgt_flat" + ) + + out_flat = tvm.te.extern( + shape=(act_n, (act_h - filt_h) // stride_h + 1, (act_w - filt_w) // stride_w + 1, filt_o), + inputs=[act_flat, wgt_flat], + fcompute=lambda ins, outs: tvm.tir.call_cpacked( + "conv2d_packed_fp16", # Function from TVM runtime + ins[0], + ins[1], + off_t, + off_l, + stride_h, + stride_w, + outs[0], + tvm.runtime.const(0), # resource_handle (unused) + ), + dtype="float16", + ) + + s = tvm.te.create_schedule(out_flat.op) + + func_name = "extern_conv" + with tvm.transform.PassContext(opt_level=3): + module = tvm.build( + s, + [act_flat, wgt_flat, off_t, off_l, stride_h, stride_w, out_flat], + target=target, + name=func_name, + ) + + return module + + +shape_parameters = [ + ( + (1, 8, 4, 3), + (3, 3, 3, 3), + (1, 1), + ), + ( + (1, 10, 14, 3), + (3, 3, 3, 3), + (1, 1), + ), + ( + (1, 14, 6, 3), + (3, 3, 3, 3), + (1, 1), + ), + ( + (1, 14, 6, 3), + (3, 3, 3, 64), + (1, 1), + ), + ( + (1, 14, 6, 3), + (5, 5, 3, 3), + (1, 1), + ), + ( + (1, 8, 8, 3), + (2, 2, 3, 3), + (1, 1), + ), + ( + (1, 14, 6, 64), + (3, 3, 64, 3), + (1, 1), + ), + ( + (1, 4, 4, 40), + (3, 3, 40, 3), + (1, 1), + ), + ( + (1, 4, 4, 3), + (3, 3, 3, 3), + (1, 1), + ), + ( + (1, 5, 5, 3), + (3, 3, 3, 3), + (1, 1), + ), + ( + (1, 6, 6, 3), + (3, 3, 3, 3), + (1, 1), + ), + ( + (1, 7, 7, 3), + (3, 3, 3, 3), + (1, 1), + ), + ( + (1, 8, 8, 3), + (3, 3, 3, 3), + (1, 1), + ), + ( + (1, 8, 8, 3), + (5, 5, 3, 3), + (1, 1), + ), + ( + (1, 8, 8, 64), + (2, 2, 64, 64), + (1, 1), + ), + ( + (1, 8, 4, 3), + (3, 3, 3, 3), + (2, 2), + ), + ( + (1, 14, 6, 3), + (3, 3, 3, 64), + (2, 2), + ), + ( + (1, 14, 6, 3), + (5, 5, 3, 3), + (2, 2), + ), + ( + (1, 8, 8, 3), + (2, 2, 3, 3), + (2, 2), + ), +] + + +def gen_config(params): + """Utility function to generate useful ids for shape_parameters""" + + dims = lambda vals: "x".join(map(str, vals)) + + config = {} + for param in params: + act_shape, wgt_shape, inp_stride = param + name = f"nhwc{dims(act_shape)}-hwio{dims(wgt_shape)}-stride{dims(inp_stride)}" + config[name] = param + + return config + + +class TestConv2dIntrin: + """Test Conv2d Intrin class""" + + config = gen_config(shape_parameters) + act_shape, wgt_shape, inp_stride = tvm.testing.parameters(*config.values(), ids=config.keys()) + inp_offset = tvm.testing.parameter((0, 0), ids=["offset0x0"]) + + @tvm.testing.requires_hexagon + def test_conv2d(self, act_shape, wgt_shape, inp_stride, inp_offset, hexagon_session): + """Test conv2d intrinsic implementation""" + assert act_shape[3] == wgt_shape[2] + + target_hexagon = tvm.target.hexagon("v69") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + # Currently, input offset does not affect the output shape + def get_out_shape(ash, wsh, inp_stride): + assert ash[3] == wsh[2] + osh = ( + ash[0], + (ash[1] - wsh[0]) // inp_stride[0] + 1, + (ash[2] - wsh[1]) // inp_stride[1] + 1, + wsh[3], + ) + assert tvm.tir.all([x > 0 for x in osh]) + return osh + + act = np.random.rand(*act_shape).astype("float16") + wgt = np.random.rand(*wgt_shape).astype("float16") + + module = build_conv2d(target) + + mod = hexagon_session.load_module(module) + output = tvm.nd.array( + np.zeros(get_out_shape(act_shape, wgt_shape, inp_stride), dtype="float16"), + device=hexagon_session.device, + ) + mod( + tvm.nd.array(act, device=hexagon_session.device), + tvm.nd.array(wgt, device=hexagon_session.device), + inp_offset[0], # off_t + inp_offset[1], # off_l + inp_stride[0], # stride_height + inp_stride[1], # stride_width + output, + ) + + out = output.numpy() + + # Generate reference output and compare: + ref_out = conv2d_nhwc_python( + act.astype("float32"), wgt.astype("float32"), stride=inp_stride, padding="VALID" + ).astype("float16") + + tvm.testing.assert_allclose(out, ref_out, rtol=5e-2, atol=5e-2) + + +if __name__ == "__main__": + tvm.testing.main()