@@ -515,6 +515,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
515
515
quantize_row_q4_0_reference (x , y , k );
516
516
}
517
517
518
+
518
519
void quantize_row_q4_1_reference (const float * restrict x , block_q4_1 * restrict y , int k ) {
519
520
const int qk = QK4_1 ;
520
521
@@ -3039,6 +3040,197 @@ size_t quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int
3039
3040
return nrow * row_size ;
3040
3041
}
3041
3042
3043
+ static void quantize_row_q4_0_impl (const float * restrict x , block_q4_0 * restrict y , int n_per_row , const float * quant_weights ) {
3044
+ static_assert (QK4_0 == 32 , "QK4_0 must be 32" );
3045
+
3046
+ if (!quant_weights ) {
3047
+ quantize_row_q4_0_reference (x , y , n_per_row );
3048
+ return ;
3049
+ }
3050
+
3051
+ float weight [QK4_0 ];
3052
+ int8_t L [QK4_0 ];
3053
+
3054
+ float sum_x2 = 0 ;
3055
+ for (int j = 0 ; j < n_per_row ; ++ j ) sum_x2 += x [j ]* x [j ];
3056
+ float sigma2 = sum_x2 /n_per_row ;
3057
+
3058
+ const int nb = n_per_row /QK4_0 ;
3059
+ for (int ib = 0 ; ib < nb ; ++ ib ) {
3060
+ const float * xb = x + QK4_0 * ib ;
3061
+ const float * qw = quant_weights + QK4_0 * ib ;
3062
+ for (int j = 0 ; j < QK4_0 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
3063
+ float d = make_qx_quants (QK4_0 , 8 , xb , L , 1 , weight );
3064
+ y [ib ].d = GGML_FP32_TO_FP16 (d );
3065
+ for (int j = 0 ; j < 16 ; ++ j ) {
3066
+ y [ib ].qs [j ] = L [j ] | (L [j + 16 ] << 4 );
3067
+ }
3068
+ }
3069
+ }
3070
+
3071
+ size_t quantize_q4_0 (const float * src , void * dst , int nrow , int n_per_row , int64_t * hist , const float * quant_weights ) {
3072
+ if (!quant_weights ) {
3073
+ return ggml_quantize_q4_0 (src , dst , nrow * n_per_row , n_per_row , hist );
3074
+ }
3075
+ int row_size = ggml_row_size (GGML_TYPE_Q4_0 , n_per_row );
3076
+ char * qrow = (char * )dst ;
3077
+ for (int row = 0 ; row < nrow ; ++ row ) {
3078
+ quantize_row_q4_0_impl (src , (block_q4_0 * )qrow , n_per_row , quant_weights );
3079
+ src += n_per_row ;
3080
+ qrow += row_size ;
3081
+ }
3082
+ return nrow * row_size ;
3083
+ }
3084
+
3085
+ static void quantize_row_q4_1_impl (const float * restrict x , block_q4_1 * restrict y , int n_per_row , const float * quant_weights ) {
3086
+ static_assert (QK4_1 == 32 , "QK4_1 must be 32" );
3087
+
3088
+ if (!quant_weights ) {
3089
+ quantize_row_q4_1_reference (x , y , n_per_row );
3090
+ return ;
3091
+ }
3092
+
3093
+ float weight [QK4_1 ];
3094
+ uint8_t L [QK4_1 ], Laux [QK4_1 ];
3095
+
3096
+ float sum_x2 = 0 ;
3097
+ for (int j = 0 ; j < n_per_row ; ++ j ) sum_x2 += x [j ]* x [j ];
3098
+ float sigma2 = sum_x2 /n_per_row ;
3099
+
3100
+ const int nb = n_per_row /QK4_1 ;
3101
+ for (int ib = 0 ; ib < nb ; ++ ib ) {
3102
+ const float * xb = x + QK4_1 * ib ;
3103
+ const float * qw = quant_weights + QK4_1 * ib ;
3104
+ for (int j = 0 ; j < QK4_1 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
3105
+ float min ;
3106
+ float d = make_qkx3_quants (QK4_1 , 15 , xb , weight , L , & min , Laux , -0.9f , 0.05f , 36 , false);
3107
+ y [ib ].d = GGML_FP32_TO_FP16 (d );
3108
+ y [ib ].m = GGML_FP32_TO_FP16 (- min );
3109
+ for (int j = 0 ; j < 16 ; ++ j ) {
3110
+ y [ib ].qs [j ] = L [j ] | (L [j + 16 ] << 4 );
3111
+ }
3112
+ }
3113
+ }
3114
+
3115
+ size_t quantize_q4_1 (const float * src , void * dst , int nrow , int n_per_row , int64_t * hist , const float * quant_weights ) {
3116
+ if (!quant_weights ) {
3117
+ return ggml_quantize_q4_1 (src , dst , nrow * n_per_row , n_per_row , hist );
3118
+ }
3119
+ int row_size = ggml_row_size (GGML_TYPE_Q4_1 , n_per_row );
3120
+ char * qrow = (char * )dst ;
3121
+ for (int row = 0 ; row < nrow ; ++ row ) {
3122
+ quantize_row_q4_1_impl (src , (block_q4_1 * )qrow , n_per_row , quant_weights );
3123
+ src += n_per_row ;
3124
+ qrow += row_size ;
3125
+ }
3126
+ return nrow * row_size ;
3127
+ }
3128
+
3129
+ static void quantize_row_q5_0_impl (const float * restrict x , block_q5_0 * restrict y , int n_per_row , const float * quant_weights ) {
3130
+ static_assert (QK5_0 == 32 , "QK5_0 must be 32" );
3131
+
3132
+ if (!quant_weights ) {
3133
+ quantize_row_q5_0_reference (x , y , n_per_row );
3134
+ return ;
3135
+ }
3136
+
3137
+ float weight [QK5_0 ];
3138
+ int8_t L [QK5_0 ];
3139
+
3140
+ float sum_x2 = 0 ;
3141
+ for (int j = 0 ; j < n_per_row ; ++ j ) sum_x2 += x [j ]* x [j ];
3142
+ float sigma2 = sum_x2 /n_per_row ;
3143
+
3144
+ const int nb = n_per_row /QK5_0 ;
3145
+ for (int ib = 0 ; ib < nb ; ++ ib ) {
3146
+ const float * xb = x + QK5_0 * ib ;
3147
+ const float * qw = quant_weights + QK5_0 * ib ;
3148
+ for (int j = 0 ; j < QK5_0 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
3149
+ float d = make_qx_quants (QK5_0 , 16 , xb , L , 1 , weight );
3150
+ y [ib ].d = GGML_FP32_TO_FP16 (d );
3151
+
3152
+ uint32_t qh = 0 ;
3153
+
3154
+ for (int j = 0 ; j < 16 ; ++ j ) {
3155
+ const uint8_t xi0 = L [j ];
3156
+ const uint8_t xi1 = L [j + 16 ];
3157
+ y [ib ].qs [j ] = (xi0 & 0x0F ) | ((xi1 & 0x0F ) << 4 );
3158
+
3159
+ // get the 5-th bit and store it in qh at the right position
3160
+ qh |= ((xi0 & 0x10u ) >> 4 ) << (j + 0 );
3161
+ qh |= ((xi1 & 0x10u ) >> 4 ) << (j + QK5_0 /2 );
3162
+ }
3163
+
3164
+ memcpy (& y [ib ].qh , & qh , sizeof (qh ));
3165
+ }
3166
+ }
3167
+
3168
+ size_t quantize_q5_0 (const float * src , void * dst , int nrow , int n_per_row , int64_t * hist , const float * quant_weights ) {
3169
+ if (!quant_weights ) {
3170
+ return ggml_quantize_q5_0 (src , dst , nrow * n_per_row , n_per_row , hist );
3171
+ }
3172
+ int row_size = ggml_row_size (GGML_TYPE_Q5_0 , n_per_row );
3173
+ char * qrow = (char * )dst ;
3174
+ for (int row = 0 ; row < nrow ; ++ row ) {
3175
+ quantize_row_q5_0_impl (src , (block_q5_0 * )qrow , n_per_row , quant_weights );
3176
+ src += n_per_row ;
3177
+ qrow += row_size ;
3178
+ }
3179
+ return nrow * row_size ;
3180
+ }
3181
+
3182
+ static void quantize_row_q5_1_impl (const float * restrict x , block_q5_1 * restrict y , int n_per_row , const float * quant_weights ) {
3183
+ static_assert (QK5_1 == 32 , "QK5_1 must be 32" );
3184
+
3185
+ if (!quant_weights ) {
3186
+ quantize_row_q5_1_reference (x , y , n_per_row );
3187
+ return ;
3188
+ }
3189
+
3190
+ float weight [QK5_1 ];
3191
+ uint8_t L [QK5_1 ], Laux [QK5_1 ];
3192
+
3193
+ float sum_x2 = 0 ;
3194
+ for (int j = 0 ; j < n_per_row ; ++ j ) sum_x2 += x [j ]* x [j ];
3195
+ float sigma2 = sum_x2 /n_per_row ;
3196
+
3197
+ const int nb = n_per_row /QK5_1 ;
3198
+ for (int ib = 0 ; ib < nb ; ++ ib ) {
3199
+ const float * xb = x + QK5_1 * ib ;
3200
+ const float * qw = quant_weights + QK5_1 * ib ;
3201
+ for (int j = 0 ; j < QK5_1 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
3202
+ float min ;
3203
+ float d = make_qkx3_quants (QK5_1 , 31 , xb , weight , L , & min , Laux , -0.9f , 0.05f , 36 , false);
3204
+ y [ib ].d = GGML_FP32_TO_FP16 (d );
3205
+ y [ib ].m = GGML_FP32_TO_FP16 (- min );
3206
+
3207
+ uint32_t qh = 0 ;
3208
+ for (int j = 0 ; j < 16 ; ++ j ) {
3209
+ const uint8_t xi0 = L [j ];
3210
+ const uint8_t xi1 = L [j + 16 ];
3211
+ y [ib ].qs [j ] = (xi0 & 0x0F ) | ((xi1 & 0x0F ) << 4 );
3212
+ // get the 5-th bit and store it in qh at the right position
3213
+ qh |= ((xi0 & 0x10u ) >> 4 ) << (j + 0 );
3214
+ qh |= ((xi1 & 0x10u ) >> 4 ) << (j + QK5_0 /2 );
3215
+ }
3216
+ memcpy (& y [ib ].qh , & qh , sizeof (qh ));
3217
+ }
3218
+ }
3219
+
3220
+ size_t quantize_q5_1 (const float * src , void * dst , int nrow , int n_per_row , int64_t * hist , const float * quant_weights ) {
3221
+ if (!quant_weights ) {
3222
+ return ggml_quantize_q5_1 (src , dst , nrow * n_per_row , n_per_row , hist );
3223
+ }
3224
+ int row_size = ggml_row_size (GGML_TYPE_Q5_1 , n_per_row );
3225
+ char * qrow = (char * )dst ;
3226
+ for (int row = 0 ; row < nrow ; ++ row ) {
3227
+ quantize_row_q5_1_impl (src , (block_q5_1 * )qrow , n_per_row , quant_weights );
3228
+ src += n_per_row ;
3229
+ qrow += row_size ;
3230
+ }
3231
+ return nrow * row_size ;
3232
+ }
3233
+
3042
3234
// ====================== "True" 2-bit (de)-quantization
3043
3235
3044
3236
static const uint64_t iq2xxs_grid [256 ] = {
0 commit comments