@@ -1374,7 +1374,10 @@ struct ggml_compute_state {
13741374
// fill the first n elements of x with the constant value v (int8 variant)
inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
// fill the first n elements of x with the constant value v (int16 variant)
inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

// fill the first n elements of x with the constant value v (int32 variant)
inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
// copy the first n int32 elements from x into y (regions must not overlap)
inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
1380+ 
13781381inline  static  void  ggml_vec_set_f16 (const  int  n , ggml_fp16_t  *  x , const  int32_t  v ) { for  (int  i  =  0 ; i  <  n ; ++ i ) x [i ] =  v ; }
13791382inline  static  void  ggml_vec_set_bf16 (const  int  n , ggml_bf16_t  *  x , const  ggml_bf16_t  v ) { for  (int  i  =  0 ; i  <  n ; ++ i ) x [i ] =  v ; }
// element-wise z[i] = x[i] + y[i] over the first n floats
inline static void ggml_vec_add_f32(const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
@@ -8248,6 +8251,77 @@ static void ggml_compute_forward_set_f32(
82488251    }
82498252}
82508253
8254+ static  void  ggml_compute_forward_set_i32 (
8255+         const  struct  ggml_compute_params  *  params ,
8256+         struct  ggml_tensor  *  dst ) {
8257+ 
8258+     const  struct  ggml_tensor  *  src0  =  dst -> src [0 ];
8259+     const  struct  ggml_tensor  *  src1  =  dst -> src [1 ];
8260+ 
8261+     GGML_ASSERT (ggml_are_same_shape (src0 , dst ));
8262+     GGML_ASSERT (ggml_is_contiguous (dst ) &&  ggml_is_contiguous (src0 ));
8263+ 
8264+     // view src0 and dst with these strides and data offset inbytes during set 
8265+     // nb0 is implicitly element_size because src0 and dst are contiguous 
8266+     size_t  nb1      =  ((int32_t  * ) dst -> op_params )[0 ];
8267+     size_t  nb2      =  ((int32_t  * ) dst -> op_params )[1 ];
8268+     size_t  nb3      =  ((int32_t  * ) dst -> op_params )[2 ];
8269+     size_t  offset   =  ((int32_t  * ) dst -> op_params )[3 ];
8270+     bool    inplace  =  (bool ) ((int32_t  * ) dst -> op_params )[4 ];
8271+ 
8272+     if  (!inplace ) {
8273+         if  (params -> ith  ==  0 ) {
8274+             // memcpy needs to be synchronized across threads to avoid race conditions. 
8275+             // => do it in INIT phase 
8276+             memcpy (
8277+                 ((char  * )  dst -> data ),
8278+                 ((char  * ) src0 -> data ),
8279+                 ggml_nbytes (dst ));
8280+         }
8281+         ggml_barrier (params -> threadpool );
8282+     }
8283+ 
8284+     const  int  ith  =  params -> ith ;
8285+     const  int  nth  =  params -> nth ;
8286+ 
8287+     const  int  nr  =  ggml_nrows (src1 );
8288+     const  int  nc  =  src1 -> ne [0 ];
8289+ 
8290+     GGML_TENSOR_LOCALS (int64_t , ne1 , src1 , ne )
8291+     GGML_TENSOR_LOCALS (size_t ,  nb1 , src1 , nb )
8292+ 
8293+     // src0 and dst as viewed during set 
8294+     const  size_t  nb0  =  ggml_element_size (src0 );
8295+ 
8296+     const  int  im0  =  (ne10  ==  0  ? 0  : ne10 - 1 );
8297+     const  int  im1  =  (ne11  ==  0  ? 0  : ne11 - 1 );
8298+     const  int  im2  =  (ne12  ==  0  ? 0  : ne12 - 1 );
8299+     const  int  im3  =  (ne13  ==  0  ? 0  : ne13 - 1 );
8300+ 
8301+     GGML_ASSERT (offset  +  im0 * nb0   +  im1 * nb1   +  im2 * nb2   +  im3 * nb3   <= ggml_nbytes (dst ));
8302+ 
8303+     GGML_ASSERT (nb10  ==  sizeof (int32_t ));
8304+ 
8305+     // rows per thread 
8306+     const  int  dr  =  (nr  +  nth  -  1 )/nth ;
8307+ 
8308+     // row range for this thread 
8309+     const  int  ir0  =  dr * ith ;
8310+     const  int  ir1  =  MIN (ir0  +  dr , nr );
8311+ 
8312+     for  (int  ir  =  ir0 ; ir  <  ir1 ; ++ ir ) {
8313+         // src0 and dst are viewed with shape of src1 and offset 
8314+         // => same indices 
8315+         const  int  i3  =  ir /(ne12 * ne11 );
8316+         const  int  i2  =  (ir  -  i3 * ne12 * ne11 )/ne11 ;
8317+         const  int  i1  =  (ir  -  i3 * ne12 * ne11  -  i2 * ne11 );
8318+ 
8319+         ggml_vec_cpy_i32 (nc ,
8320+                 (int32_t  * ) ((char  * )  dst -> data  +  i3 * nb3   +  i2 * nb2   +  i1 * nb1   +  offset ),
8321+                 (int32_t  * ) ((char  * ) src1 -> data  +  i3 * nb13  +  i2 * nb12  +  i1 * nb11 ));
8322+     }
8323+ }
8324+ 
82518325static  void  ggml_compute_forward_set (
82528326        const  struct  ggml_compute_params  *  params ,
82538327        struct  ggml_tensor  *  dst ) {
@@ -8259,6 +8333,10 @@ static void ggml_compute_forward_set(
82598333            {
82608334                ggml_compute_forward_set_f32 (params , dst );
82618335            } break ;
8336+         case  GGML_TYPE_I32 :
8337+             {
8338+                 ggml_compute_forward_set_i32 (params , dst );
8339+             } break ;
82628340        case  GGML_TYPE_F16 :
82638341        case  GGML_TYPE_BF16 :
82648342        case  GGML_TYPE_Q4_0 :
0 commit comments