cinn(test): add rmsnorm subgraph symbolic test #61317

Merged: 1 commit into PaddlePaddle:develop on Jan 30, 2024

Conversation

@6clc (Contributor) commented on Jan 29, 2024

PR types

Others

PR changes

Others

Description

Pcard-78120
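
This PR adds a symbolic (dynamic-shape) subgraph test for RMSNorm under CINN. RMSNorm computes y = x * rsqrt(mean(x^2, axis=-1) + eps) * weight over the last (hidden) dimension. A minimal sketch of such a subgraph in dygraph form (assumed for illustration; the exact test code added by this PR is not reproduced here):

import paddle

class RMSNorm(paddle.nn.Layer):
    def __init__(self, hidden_size=768, eps=1e-6):
        super().__init__()
        self.weight = paddle.create_parameter(
            shape=[hidden_size],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Constant(1.0),
        )
        self.eps = eps

    def forward(self, x):
        # x: [batch, 2048, 768]; the batch dim is treated as symbolic (-1) below
        variance = x.pow(2).mean(axis=-1, keepdim=True)
        x = paddle.rsqrt(variance + self.eps) * x
        return x * self.weight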

Graph fusion IR (dumped before the lower_cinn_fusion_op pass):

===-----------------------------------------------------------------------===
        IRPrinting on builtin.module before lower_cinn_fusion_op pass
===-----------------------------------------------------------------------===
{
 (%0) = "builtin.parameter" () {is_persisable:[true],parameter_name:"parameter_0",stop_gradient:[false]} : () -> pd_op.tensor<768xf32>
 (%1) = "pd_op.data" () {dtype:(pd_op.DataType)float32,name:"_jst.0.hidden_states.0",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[7,2048,768],stop_gradient:[false]} : () -> pd_op.tensor<-1x2048x768xf32>
 (%2) = cinn_op.fusion () -> pd_op.tensor<7x2048x768xf32> {
 (%3) = "pd_op.full" () {dtype:(pd_op.DataType)float32,place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[7,2048,768],stop_gradient:[true],value:(Float)2} : () -> pd_op.tensor<-1x2048x768xf32>
 (%4) = "pd_op.elementwise_pow" (%1, %3) {stop_gradient:[false]} : (pd_op.tensor<-1x2048x768xf32>, pd_op.tensor<-1x2048x768xf32>) -> pd_op.tensor<-1x2048x768xf32>
 (%5) = "cinn_op.reduce_sum" (%4) {dim:[(Int64)-1],keep_dim:true,stop_gradient:[false]} : (pd_op.tensor<-1x2048x768xf32>) -> pd_op.tensor<-1x2048x1xf32>
 (%6) = "cinn_op.scale" (%5) {bias:(Float)0,bias_after_scale:true,scale:(Float)0.00130208,stop_gradient:[false]} : (pd_op.tensor<-1x2048x1xf32>) -> pd_op.tensor<-1x2048x1xf32>
 (%7) = "pd_op.full" () {dtype:(pd_op.DataType)float32,place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[7,2048,1],stop_gradient:[true],value:(Float)-0.5} : () -> pd_op.tensor<-1x2048x1xf32>
 (%8) = "cinn_op.scale" (%6) {bias:(Float)1e-06,bias_after_scale:true,scale:(Float)1,stop_gradient:[false]} : (pd_op.tensor<-1x2048x1xf32>) -> pd_op.tensor<-1x2048x1xf32>
 (%9) = "pd_op.elementwise_pow" (%8, %7) {stop_gradient:[false]} : (pd_op.tensor<-1x2048x1xf32>, pd_op.tensor<-1x2048x1xf32>) -> pd_op.tensor<-1x2048x1xf32>
 (%10) = "cinn_op.broadcast" (%9) {broadcast_axes:[(Int64)0,(Int64)1,(Int64)2],out_shape:[(Int64)7,(Int64)2048,(Int64)768],stop_gradient:[false]} : (pd_op.tensor<-1x2048x1xf32>) -> pd_op.tensor<-1x2048x768xf32>
 (%11) = "cinn_op.broadcast" (%0) {broadcast_axes:[(Int64)2],out_shape:[(Int64)7,(Int64)2048,(Int64)768],stop_gradient:[false]} : (pd_op.tensor<768xf32>) -> pd_op.tensor<-1x2048x768xf32>
 (%12) = "pd_op.multiply" (%10, %1) {stop_gradient:[false]} : (pd_op.tensor<-1x2048x768xf32>, pd_op.tensor<-1x2048x768xf32>) -> pd_op.tensor<-1x2048x768xf32>
 (%13) = "pd_op.multiply" (%12, %11) {stop_gradient:[false]} : (pd_op.tensor<-1x2048x768xf32>, pd_op.tensor<-1x2048x768xf32>) -> pd_op.tensor<-1x2048x768xf32>
 () = "cf.yield" (%13) {} : (pd_op.tensor<-1x2048x768xf32>) ->
 }
 () = "builtin.shadow_output" (%2) {output_name:"output_0"} : (pd_op.tensor<7x2048x768xf32>) ->
}
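
Reading the fused ops back into the RMSNorm formula: pd_op.elementwise_pow with exponent 2 squares the input, cinn_op.reduce_sum reduces over the last axis, the first cinn_op.scale multiplies by 0.00130208 ≈ 1/768 (turning the sum into a mean), the second cinn_op.scale adds the epsilon 1e-06, pd_op.elementwise_pow with exponent -0.5 is the reciprocal square root, and the two cinn_op.broadcast ops plus the two pd_op.multiply ops apply the normalization factor and the learned weight. A rough NumPy reference of the same computation (a sketch for illustration, not code from this PR):

import numpy as np

def rmsnorm_ref(x, w, eps=1e-6):
    # x: [batch, 2048, 768] (batch is the symbolic -1 dim), w: [768]
    var = np.power(x, 2.0).sum(axis=-1, keepdims=True) * (1.0 / 768.0)  # reduce_sum + scale
    inv = np.power(var + eps, -0.5)                                     # scale(bias=eps) + pow(-0.5)
    return (inv * x) * w                                                # broadcasts + two multiplies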

Generated CUDA C:

__global__
void __launch_bounds__(256) fn_fill_constant_pow_reduce_sum_scale_fill_constant_0_scale_0_pow_0_broadcast_to_broadcast_to_0_elementwise_mul_elementwise_mul_0__COND__FPA__FPA__FPA__FPA_7llMULS0_BPA_GE1024_BPA_AND_FPA__FPA_7llMULS0_BPA_LT2147483647_BPA__BPA_AND_FPA__FPA_768llGE256_BPA_AND_FPA_768llLT2147483647_BPA__BPA__BPA___kernel(const float* __restrict__ var_0, const float* __restrict__ var_8, float* __restrict__ var_11, int64_t S0)
{
  float *_var_2_rf_temp_buffer = new float[ min((((7ll * S0) / 1024) + (1ll + ((((7ll * S0) / 1024) / S0) * min((((7ll * S0) / 1024) + 1ll), S0)))), (S0 + ((((7ll * S0) / 1024) / S0) * min((((7ll * S0) / 1024) + 1ll), S0)))) ];
  float *_var_2_temp_buffer = new float[ min((((7ll * S0) / 1024) + (1ll + ((((7ll * S0) / 1024) / S0) * min((((7ll * S0) / 1024) + 1ll), S0)))), (S0 + ((((7ll * S0) / 1024) / S0) * min((((7ll * S0) / 1024) + 1ll), S0)))) ];
  __shared__ float shm32__fp32_reduce [ 32 ];
  float* var_2 = _var_2_temp_buffer;
  float* var_2__reduce_init = _var_2_temp_buffer;
  float* var_2_rf = _var_2_rf_temp_buffer;
  float* var_2_rf__reduce_init = _var_2_rf_temp_buffer;
  if (((int)blockIdx.y < 2ll)) {
    if (((int)blockIdx.x < 512ll)) {
      for (int32_t i_j_k_fused_34 = 0; i_j_k_fused_34 < (((7ll * S0) / 1024) + 1ll); i_j_k_fused_34 += 1) {
        if ((((((7ll * S0) / 1024) * (int)blockIdx.x) + ((512ll * (int)blockIdx.y) + ((((7ll * S0) / 1024) * (512ll * (int)blockIdx.y)) + ((int)blockIdx.x + i_j_k_fused_34)))) < (7ll * S0))) {
          var_2__reduce_init[(((i_j_k_fused_34 / S0) * min((1ll + (7ll * S0)), S0)) + (i_j_k_fused_34 % S0))] = 0.00000000f;
          if (((int)threadIdx.x < 256ll)) {
            var_2_rf__reduce_init[(((i_j_k_fused_34 / S0) * min((1ll + (7ll * S0)), S0)) + (i_j_k_fused_34 % S0))] = 0.00000000f;
            for (int32_t reduce_k_0_3 = 0; reduce_k_0_3 < 3; reduce_k_0_3 += 1) {
              var_2_rf[(((i_j_k_fused_34 / S0) * min((1ll + (7ll * S0)), S0)) + (i_j_k_fused_34 % S0))] = (var_2_rf[(((i_j_k_fused_34 / S0) * min((1ll + (7ll * S0)), S0)) + (i_j_k_fused_34 % S0))] + cinn_nvgpu_pow_fp32(var_0[(((((((7ll * S0) / 1024) * (int)blockIdx.x) + ((512ll * (int)blockIdx.y) + ((((7ll * S0) / 1024) * (512ll * (int)blockIdx.y)) + ((int)blockIdx.x + i_j_k_fused_34)))) / S0) * (768ll * S0)) + ((768ll * (((((7ll * S0) / 1024) * (int)blockIdx.x) + ((512ll * (int)blockIdx.y) + ((((7ll * S0) / 1024) * (512ll * (int)blockIdx.y)) + ((int)blockIdx.x + i_j_k_fused_34)))) % S0)) + ((3ll * (int)threadIdx.x) + reduce_k_0_3)))], 2.00000000f));
            };
            var_2[(((i_j_k_fused_34 / S0) * min((1ll + (7ll * S0)), S0)) + (i_j_k_fused_34 % S0))] = cinn_block_reduce_sum_fp32_internal_shm(var_2_rf[(((i_j_k_fused_34 / S0) * min((1ll + (7ll * S0)), S0)) + (i_j_k_fused_34 % S0))], shm32__fp32_reduce);
          };
        };
      };
    };
  };
  if (((int)blockIdx.y < 2ll)) {
    if (((int)blockIdx.x < 512ll)) {
      for (int32_t i_j_k_fused_17_i_j_k_fused_18_fused_1 = 0; i_j_k_fused_17_i_j_k_fused_18_fused_1 < (((7ll * S0) / 1024) + 1ll); i_j_k_fused_17_i_j_k_fused_18_fused_1 += 1) {
        if ((((((7ll * S0) / 1024) * (int)blockIdx.x) + ((512ll * (int)blockIdx.y) + ((((7ll * S0) / 1024) * (512ll * (int)blockIdx.y)) + ((int)blockIdx.x + i_j_k_fused_17_i_j_k_fused_18_fused_1)))) < (7ll * S0))) {
          if (((int)threadIdx.x < 256)) {
            for (int32_t i_j_k_fused_19_0 = 0; i_j_k_fused_19_0 < 3; i_j_k_fused_19_0 += 1) {
              var_11[(((((((7ll * S0) / 1024) * (int)blockIdx.x) + ((512ll * (int)blockIdx.y) + ((((7ll * S0) / 1024) * (512ll * (int)blockIdx.y)) + ((int)blockIdx.x + i_j_k_fused_17_i_j_k_fused_18_fused_1)))) / S0) * (768ll * S0)) + ((768ll * (((((7ll * S0) / 1024) * (int)blockIdx.x) + ((512ll * (int)blockIdx.y) + ((((7ll * S0) / 1024) * (512ll * (int)blockIdx.y)) + ((int)blockIdx.x + i_j_k_fused_17_i_j_k_fused_18_fused_1)))) % S0)) + ((3ll * (int)threadIdx.x) + i_j_k_fused_19_0)))] = (cinn_nvgpu_rsqrt_fp32((9.99999997e-07f + (0.00130208337f * var_2[(((i_j_k_fused_17_i_j_k_fused_18_fused_1 / S0) * min((1ll + (7ll * S0)), S0)) + (i_j_k_fused_17_i_j_k_fused_18_fused_1 % S0))]))) * (var_0[(((((((7ll * S0) / 1024) * (int)blockIdx.x) + ((512ll * (int)blockIdx.y) + ((((7ll * S0) / 1024) * (512ll * (int)blockIdx.y)) + ((int)blockIdx.x + i_j_k_fused_17_i_j_k_fused_18_fused_1)))) / S0) * (768ll * S0)) + ((768ll * (((((7ll * S0) / 1024) * (int)blockIdx.x) + ((512ll * (int)blockIdx.y) + ((((7ll * S0) / 1024) * (512ll * (int)blockIdx.y)) + ((int)blockIdx.x + i_j_k_fused_17_i_j_k_fused_18_fused_1)))) % S0)) + ((3ll * (int)threadIdx.x) + i_j_k_fused_19_0)))] * var_8[((3ll * (int)threadIdx.x) + i_j_k_fused_19_0)]));
            };
          };
        };
      };
    };
  };
  delete [] _var_2_rf_temp_buffer;
  delete [] _var_2_temp_buffer;
}
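
In the generated kernel, the -1 batch dimension from the IR becomes the runtime symbol S0, passed as the trailing int64_t argument, and the mangled kernel name encodes the shape constraints under which this schedule applies (7*S0 >= 1024, 7*S0 < 2147483647, 768 >= 256, 768 < 2147483647). The first loop nest block-reduces the squared input into var_2 via cinn_block_reduce_sum_fp32_internal_shm; the second loop nest then computes rsqrt(1e-06 + 0.00130208 * sum) * x * w, matching the IR above. A dynamic-shape test around such a subgraph could be driven roughly as follows (a hypothetical harness; the actual test in this PR may rely on the repository's own CINN test utilities and flags):

import numpy as np
import paddle

net = RMSNorm()  # from the sketch above
x_spec = paddle.static.InputSpec(shape=[None, 2048, 768], dtype="float32")  # symbolic batch dim
static_net = paddle.jit.to_static(net, input_spec=[x_spec])  # run with CINN enabled via the usual flags

x_np = np.random.randn(7, 2048, 768).astype("float32")
out = static_net(paddle.to_tensor(x_np))
np.testing.assert_allclose(out.numpy(), rmsnorm_ref(x_np, net.weight.numpy()),
                           rtol=1e-5, atol=1e-5)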

@BiynXu (Contributor) left a comment:
LGTM

@BiynXu merged commit 3d0601e into PaddlePaddle:develop on Jan 30, 2024
30 checks passed