@@ -1312,6 +1312,111 @@ aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
13121312#ifdef __cplusplus
13131313}
13141314#endif
1315+
1316+ static void ggml_cann_im2col_2d_post_process (ggml_backend_cann_context& ctx,
1317+ ggml_tensor* dst,
1318+ ggml_tensor* src1,
1319+ aclTensor* tmp_cast_tensor,
1320+ aclTensor* tmp_im2col_tensor) {
1321+ // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
1322+ int64_t dst_ne[] = {dst->ne [0 ], dst->ne [1 ] * dst->ne [2 ], dst->ne [3 ]};
1323+ size_t dst_nb[] = {dst->nb [0 ], dst->nb [1 ], dst->nb [3 ]};
1324+ aclTensor* acl_dst =
1325+ ggml_cann_create_tensor (dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1 );
1326+
1327+ int64_t permute_dim[] = {0 , 2 , 1 };
1328+ if (src1->type != dst->type ) {
1329+ aclnn_permute (ctx, tmp_cast_tensor, acl_dst, permute_dim, 3 );
1330+ } else {
1331+ aclnn_permute (ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3 );
1332+ }
1333+
1334+ // release
1335+ ACL_CHECK (aclDestroyTensor (acl_dst));
1336+ }
1337+
// Restore a 1D im2col result from the 2D aclnnIm2col intermediate.
// aclnnIm2col only supports 2D, so the caller ran a 2D im2col into an
// oversized temporary (scaled by n_bytes_factor); this routine permutes that
// intermediate and then copies the valid rows into dst with async
// device-to-device memcpys so dst ends up in the 1D im2col layout.
//
// ctx               - backend context providing the memory pool and stream.
// dst               - destination ggml tensor receiving the 1D result.
// src1              - input tensor; its type is compared with dst's to pick
//                     which intermediate tensor holds the im2col data.
// tmp_cast_tensor   - dst-typed copy of the im2col result (valid only when
//                     src1->type != dst->type).
// tmp_im2col_tensor - im2col result in src1's type.
// im2col_op_params  - packed operator parameters; unpacked below.
static void ggml_cann_im2col_1d_post_process(
    ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1,
    aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor,
    const std::vector<int64_t>& im2col_op_params) {
    // get params
    const int64_t KH = im2col_op_params[0];  // kernel height
    const int64_t KW = im2col_op_params[1];  // kernel width
    const int64_t IW = im2col_op_params[2];  // input width
    const int64_t IC = im2col_op_params[3];  // input channels
    const int64_t N = im2col_op_params[4];   // batch size
    const int64_t OH = im2col_op_params[5];  // output height
    const int64_t OW = im2col_op_params[6];  // output width
    const int64_t s0 = im2col_op_params[7];  // stride (W dimension)
    const int64_t p0 = im2col_op_params[8];  // padding (W dimension)
    const int64_t d0 = im2col_op_params[9];  // dilation (W dimension)
    // over-allocation factor for the temporaries (3 in the 1D path)
    const int64_t n_bytes_factor = im2col_op_params[10];

    // Permute: [N, IC * KH * KW, OW * OH] ->
    // [N, OW * OH * n_bytes_factor, IC * KH * KW]
    aclTensor* tmp_permute_tensor = nullptr;
    ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool());
    tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
    void* tmp_permute_buffer = tmp_permute_allocator.get();

    int64_t tmp_permute_ne[] = {IC * KH * KW, OW * OH * n_bytes_factor, N};
    size_t tmp_permute_nb[GGML_MAX_DIMS - 1];
    // Contiguous ND strides built from the element size.
    tmp_permute_nb[0] = ggml_type_size(dst->type);
    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
        tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
    }

    tmp_permute_tensor = ggml_cann_create_tensor(
        tmp_permute_buffer, ggml_cann_type_mapping(dst->type),
        ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb,
        GGML_MAX_DIMS - 1, ACL_FORMAT_ND);

    // If a dtype cast happened, the casted copy holds the im2col result;
    // otherwise permute the raw im2col output directly.
    int64_t permute_dim[] = {0, 2, 1};
    if (src1->type != dst->type) {
        aclnn_permute(ctx, tmp_cast_tensor, tmp_permute_tensor, permute_dim, 3);
    } else {
        aclnn_permute(ctx, tmp_im2col_tensor, tmp_permute_tensor, permute_dim,
                      3);
    }

    // number of times the kernel moves in W dimension
    const int n_step_w = (IW + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1;
    size_t offset;
    void *cur_dst_buffer = dst->data, *cur_permute_buffer = tmp_permute_buffer;

    // memory copy with offset to restore 1D im2col from 2d
    if (IC > 1) {
        // Multi-channel case: the permuted buffer interleaves channels, so
        // copy KH*KW elements per channel per kernel step.
        // NOTE(review): the initial `offset` skips one [IC*KH*KW*n_step_w]
        // slab of the permuted buffer — presumably the valid data begins
        // there; confirm against the aclnnIm2col output layout.
        offset = IC * KH * KW * n_step_w * ggml_type_size(dst->type);
        size_t size_cpy = KH * KW * ggml_type_size(dst->type);

        for (int c = 0; c < IC; c++) {
            // Source: channel c's rows inside the (interleaved) permuted
            // buffer. Destination: channel c's contiguous region in dst.
            cur_permute_buffer = (char*)tmp_permute_buffer + offset +
                                 KH * KW * c * ggml_type_size(dst->type);
            cur_dst_buffer = (char*)dst->data +
                             c * KH * KW * n_step_w * ggml_type_size(dst->type);

            for (int i = 0; i < n_step_w; i++) {
                // Async device-to-device copy on the context's stream; the
                // source advances by IC*KH*KW (skipping other channels),
                // the destination by KH*KW (densely packed).
                ACL_CHECK(aclrtMemcpyAsync(
                    cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
                cur_dst_buffer =
                    (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
                cur_permute_buffer = (char*)cur_permute_buffer +
                                     KH * KW * IC * ggml_type_size(dst->type);
            }
        }
    } else {
        // Single-channel case: one contiguous copy of the valid region.
        offset = KH * KW * n_step_w *
                 ggml_type_size(dst->type);  // equal to ggml_nbytes(dst)
        ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
                                   (char*)tmp_permute_buffer + offset, offset,
                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
    }

    // release
    ACL_CHECK(aclDestroyTensor(tmp_permute_tensor));
}
1419+
13151420void ggml_cann_im2col (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
13161421 ggml_tensor* src0 = dst->src [0 ]; // kernel
13171422 ggml_tensor* src1 = dst->src [1 ]; // input
@@ -1320,31 +1425,36 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
13201425 GGML_ASSERT (src1->type == GGML_TYPE_F32);
13211426 GGML_ASSERT (dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
13221427
1428+ GGML_TENSOR_BINARY_OP_LOCALS;
1429+
1430+ // aclnnIm2col only works on 2D. set s1, p1, d1 to 1 to perform 2D
1431+ // im2col and do post-processing to restore it to 1D.
1432+ const bool is_2D = ((const int32_t *)(dst->op_params ))[6 ] == 1 ;
13231433 const int32_t s0 = ((const int32_t *)(dst->op_params ))[0 ];
1324- const int32_t s1 = ((const int32_t *)(dst->op_params ))[1 ];
1434+ const int32_t s1 = is_2D ? ((const int32_t *)(dst->op_params ))[1 ] : 1 ;
13251435 const int32_t p0 = ((const int32_t *)(dst->op_params ))[2 ];
1326- const int32_t p1 = ((const int32_t *)(dst->op_params ))[3 ];
1436+ const int32_t p1 = is_2D ? ((const int32_t *)(dst->op_params ))[3 ] : 1 ;
13271437 const int32_t d0 = ((const int32_t *)(dst->op_params ))[4 ];
1328- const int32_t d1 = ((const int32_t *)(dst->op_params ))[5 ];
1329- const bool is_2D = ((const int32_t *)(dst->op_params ))[6 ] == 1 ;
1330-
1331- GGML_TENSOR_BINARY_OP_LOCALS;
1332-
1333- const int64_t N = is_2D ? ne13 : ne12;
1334- const int64_t IC = is_2D ? ne12 : ne11;
1438+ const int32_t d1 = is_2D ? ((const int32_t *)(dst->op_params ))[5 ] : 1 ;
13351439
1336- const int64_t KH = is_2D ? ne01 : 1 ;
1440+ const int64_t N = ne13;
1441+ const int64_t IC = ne12;
1442+ const int64_t KH = ne01;
13371443 const int64_t KW = ne00;
1444+ const int64_t IW = ne10;
13381445
13391446 const int64_t OH = is_2D ? ne2 : 1 ;
13401447 const int64_t OW = ne1;
13411448
13421449 GGML_ASSERT (nb00 == sizeof (ggml_fp16_t ));
13431450 GGML_ASSERT (nb10 == sizeof (float ));
13441451
1345- // im2col: [N,C,H,W] -> [N, IC * KH * KW, OW * OH]
1452+ // memory allocated increased to 3x when is_2D == false
1453+ const int64_t n_bytes_factor = is_2D ? 1 : 3 ;
1454+
1455+ // im2col: [N,C,H,W] -> [N, IC * KH * KW, OW * OH * n_bytes_factor]
13461456 aclTensor* acl_src1 = ggml_cann_create_tensor (src1);
1347- int64_t tmp_im2col_ne[] = {OW * OH, IC * KH * KW, N};
1457+ int64_t tmp_im2col_ne[] = {OW * OH * n_bytes_factor , IC * KH * KW, N};
13481458 size_t tmp_im2col_nb[GGML_MAX_DIMS - 1 ];
13491459
13501460 tmp_im2col_nb[0 ] = ggml_type_size (src1->type );
@@ -1356,8 +1466,10 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
13561466 // If dst is f16, tmp_buffer is f32, we need alloc src.typesize *
13571467 // dst.elemcount.
13581468 ggml_cann_pool_alloc im2col_allocator (
1359- ctx.pool (), ggml_nelements (dst) * ggml_element_size (src1));
1469+ ctx.pool (),
1470+ ggml_nelements (dst) * ggml_element_size (src1) * n_bytes_factor);
13601471 void * tmp_im2col_buffer = im2col_allocator.get ();
1472+
13611473 aclTensor* tmp_im2col_tensor = ggml_cann_create_tensor (
13621474 tmp_im2col_buffer, ggml_cann_type_mapping (src1->type ),
13631475 ggml_type_size (src1->type ), tmp_im2col_ne, tmp_im2col_nb,
@@ -1380,8 +1492,9 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
13801492 paddings, strides, tmp_im2col_tensor,
13811493 &workspaceSize, &executor));
13821494
1495+ ggml_cann_pool_alloc workspace_allocator (ctx.pool ());
13831496 if (workspaceSize > 0 ) {
1384- ggml_cann_pool_alloc workspace_allocator (ctx. pool (), workspaceSize);
1497+ workspace_allocator. alloc ( workspaceSize);
13851498 workspaceAddr = workspace_allocator.get ();
13861499 }
13871500
@@ -1391,9 +1504,10 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
13911504 // Cast if dst is f16.
13921505 aclTensor* tmp_cast_tensor = nullptr ;
13931506 ggml_cann_pool_alloc tmp_cast_allocator (ctx.pool ());
1507+ void * tmp_cast_buffer = nullptr ;
13941508 if (src1->type != dst->type ) {
1395- tmp_cast_allocator.alloc (ggml_nbytes (dst));
1396- void * tmp_cast_buffer = tmp_cast_allocator.get ();
1509+ tmp_cast_allocator.alloc (ggml_nbytes (dst) * n_bytes_factor );
1510+ tmp_cast_buffer = tmp_cast_allocator.get ();
13971511 size_t temp_cast_nb[GGML_MAX_DIMS - 1 ];
13981512 temp_cast_nb[0 ] = ggml_type_size (dst->type );
13991513 for (int i = 1 ; i < GGML_MAX_DIMS - 1 ; i++) {
@@ -1408,24 +1522,21 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
14081522 ggml_cann_type_mapping (dst->type ));
14091523 }
14101524
1411- // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
1412- int64_t dst_ne[] = {dst->ne [0 ], dst->ne [1 ] * dst->ne [2 ], dst->ne [3 ]};
1413- size_t dst_nb[] = {dst->nb [0 ], dst->nb [1 ], dst->nb [3 ]};
1414- aclTensor* acl_dst =
1415- ggml_cann_create_tensor (dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1 );
1416-
1417- int64_t permute_dim[] = {0 , 2 , 1 };
1418- if (src1->type != dst->type ) {
1419- aclnn_permute (ctx, tmp_cast_tensor, acl_dst, permute_dim, 3 );
1525+ // post-processing
1526+ if (is_2D) {
1527+ ggml_cann_im2col_2d_post_process (ctx, dst, src1, tmp_cast_tensor,
1528+ tmp_im2col_tensor);
14201529 } else {
1421- aclnn_permute (ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3 );
1530+ std::vector<int64_t > im2col_op_params = {
1531+ KH, KW, IW, IC, N, OH, OW, s0, p0, d0, n_bytes_factor};
1532+ ggml_cann_im2col_1d_post_process (ctx, dst, src1, tmp_cast_tensor,
1533+ tmp_im2col_tensor, im2col_op_params);
14221534 }
14231535
14241536 // release
14251537 ACL_CHECK (aclDestroyTensor (acl_src1));
14261538 ACL_CHECK (aclDestroyTensor (tmp_im2col_tensor));
14271539 ACL_CHECK (aclDestroyTensor (tmp_cast_tensor));
1428- ACL_CHECK (aclDestroyTensor (acl_dst));
14291540 ACL_CHECK (aclDestroyIntArray (kernel_size));
14301541 ACL_CHECK (aclDestroyIntArray (dilations));
14311542 ACL_CHECK (aclDestroyIntArray (paddings));
0 commit comments