@@ -9018,18 +9018,20 @@ static void ggml_compute_forward_rms_norm_f32(
9018
9018
GGML_ASSERT (ggml_are_same_shape (src0 , dst ));
9019
9019
9020
9020
if (params -> type == GGML_TASK_INIT || params -> type == GGML_TASK_FINALIZE ) {
9021
+ atomic_store (params -> aic , 0 );
9022
+
9021
9023
return ;
9022
9024
}
9023
9025
9024
9026
GGML_ASSERT (src0 -> nb [0 ] == sizeof (float ));
9025
9027
9026
- const int ith = params -> ith ;
9028
+ const int ith = params -> ith ; UNUSED ( ith );
9027
9029
const int nth = params -> nth ;
9028
9030
9029
9031
const int64_t ne00 = src0 -> ne [0 ];
9030
9032
const int64_t ne01 = src0 -> ne [1 ];
9031
9033
const int64_t ne02 = src0 -> ne [2 ];
9032
- const int64_t ne03 = src0 -> ne [3 ];
9034
+ const int64_t ne03 = src0 -> ne [3 ]; UNUSED ( ne03 );
9033
9035
9034
9036
const size_t nb01 = src0 -> nb [1 ];
9035
9037
const size_t nb02 = src0 -> nb [2 ];
@@ -9041,30 +9043,45 @@ static void ggml_compute_forward_rms_norm_f32(
9041
9043
9042
9044
const float eps = 1e-6f ; // TODO: make this a parameter
9043
9045
9044
- // TODO: optimize
9045
- for (int64_t i03 = 0 ; i03 < ne03 ; i03 ++ ) {
9046
- for (int64_t i02 = 0 ; i02 < ne02 ; i02 ++ ) {
9047
- for (int64_t i01 = ith ; i01 < ne01 ; i01 += nth ) {
9048
- const float * x = (float * ) ((char * ) src0 -> data + i01 * nb01 + i02 * nb02 + i03 * nb03 );
9049
-
9050
- ggml_float sum = 0.0 ;
9051
- for (int64_t i00 = 0 ; i00 < ne00 ; i00 ++ ) {
9052
- sum += (ggml_float )(x [i00 ] * x [i00 ]);
9053
- }
9046
+ const int nr = ggml_nrows (src0 );
9047
+ const int dr = (nr + 8 * nth - 1 )/(8 * nth );
9054
9048
9055
- float mean = sum /ne00 ;
9049
+ while (true) {
9050
+ const int ir0 = atomic_fetch_add (params -> aic , dr );
9056
9051
9057
- float * y = (float * ) ((char * ) dst -> data + i01 * nb1 + i02 * nb2 + i03 * nb3 );
9052
+ for (int ir = ir0 ; ir < ir0 + dr ; ++ ir ) {
9053
+ if (ir >= nr ) {
9054
+ break ;
9055
+ }
9058
9056
9059
- memcpy ( y , x , ne00 * sizeof ( float ));
9060
- // for ( int i00 = 0; i00 < ne00; i00++) {
9061
- // y[i00] = x[i00] ;
9062
- // }
9057
+ // src0 indices
9058
+ const int i03 = ir /( ne02 * ne01 );
9059
+ const int i02 = ( ir - i03 * ne02 * ne01 )/ ne01 ;
9060
+ const int i01 = ( ir - i03 * ne02 * ne01 - i02 * ne01 );
9063
9061
9064
- const float scale = 1.0f / sqrtf ( mean + eps );
9062
+ const float * x = ( float * ) (( char * ) src0 -> data + i01 * nb01 + i02 * nb02 + i03 * nb03 );
9065
9063
9066
- ggml_vec_scale_f32 (ne00 , y , scale );
9064
+ ggml_float sum = 0.0 ;
9065
+ for (int64_t i00 = 0 ; i00 < ne00 ; i00 ++ ) {
9066
+ sum += (ggml_float )(x [i00 ] * x [i00 ]);
9067
9067
}
9068
+
9069
+ float mean = sum /ne00 ;
9070
+
9071
+ float * y = (float * ) ((char * ) dst -> data + i01 * nb1 + i02 * nb2 + i03 * nb3 );
9072
+
9073
+ memcpy (y , x , ne00 * sizeof (float ));
9074
+ // for (int i00 = 0; i00 < ne00; i00++) {
9075
+ // y[i00] = x[i00];
9076
+ // }
9077
+
9078
+ const float scale = 1.0f /sqrtf (mean + eps );
9079
+
9080
+ ggml_vec_scale_f32 (ne00 , y , scale );
9081
+ }
9082
+
9083
+ if (ir0 + dr >= nr ) {
9084
+ break ;
9068
9085
}
9069
9086
}
9070
9087
}
@@ -9739,11 +9756,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
9739
9756
const int nb2 = dst -> nb [2 ];
9740
9757
const int nb3 = dst -> nb [3 ];
9741
9758
9742
- const int ith = params -> ith ;
9759
+ const int ith = params -> ith ; UNUSED ( ith );
9743
9760
const int nth = params -> nth ;
9744
9761
9745
- UNUSED (ith );
9746
-
9747
9762
GGML_ASSERT (ne02 == ne12 );
9748
9763
GGML_ASSERT (ne03 == ne13 );
9749
9764
GGML_ASSERT (ne2 == ne12 );
0 commit comments