Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Silent wrong result for RNG + transpose non-square tile #1926

Closed
zasdfgbnm opened this issue Aug 24, 2022 · 4 comments · Fixed by #1924
Closed

Silent wrong result for RNG + transpose non-square tile #1926

zasdfgbnm opened this issue Aug 24, 2022 · 4 comments · Fixed by #1924

Comments

@zasdfgbnm
Copy link
Collaborator

zasdfgbnm commented Aug 24, 2022

🐛 Describe the bug

// Repro: a random tensor shaped like the {5, 1} input is added to a {5, 5}
// input, and the fusion is scheduled with the transpose scheduler using a
// non-square tile (tile_size1 = 8, tile_size2 = 4). Both inputs are all zeros,
// so the output is expected to contain only the random values, repeated along
// dimension 1 (presumably randlike(tv0) keeps tv0's {5, 1} shape -- confirm).
TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  auto fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  // tv0 has a size-1 second dimension; tv1 is the full 5x5 tensor.
  TensorView* tv0 = makeConcreteTensor({5, 1});
  TensorView* tv1 = makeConcreteTensor({5, 5});
  fusion->addInput(tv0);
  fusion->addInput(tv1);
  // tv4 = tv0 + (tv1 + rand), where rand is generated from tv0.
  auto tv2 = randlike(tv0);
  auto tv3 = add(tv1, tv2);
  auto tv4 = add(tv0, tv3);
  fusion->addOutput(tv4);

  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
  // Zero inputs, so any non-uniform row in the output comes from the kernel
  // itself rather than from the input data.
  at::Tensor t0 = at::zeros({5, 1}, options);
  at::Tensor t1 = at::zeros({5, 5}, options);

  // Force a non-square tile; these sizes appear as the "by factor 4" and
  // "by factor 8" splits in the IR dump that follows in the report.
  TransposeParams heuristics;
  heuristics.tile_size1 = 8;
  heuristics.tile_size2 = 4;
  scheduleTranspose(fusion, heuristics);

  FusionExecutor fe;
  fe.compileFusion(fusion, {t0, t1});
  auto cg_outputs = fe.runFusion({t0, t1});
  auto out = cg_outputs[0];

  // Printed so the wrong values are visible in the report.
  std::cout << out << std::endl;

  // Columns 1..4 must each equal column 0, i.e. every row holds one repeated
  // value.
  TORCH_CHECK((out.select(1, 0) == out.select(1, 1)).all().item<bool>());
  TORCH_CHECK((out.select(1, 0) == out.select(1, 2)).all().item<bool>());
  TORCH_CHECK((out.select(1, 0) == out.select(1, 3)).all().item<bool>());
  TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item<bool>());
}

Output:

 0.8388  0.8388  0.8388  0.8388  0.8388
 0.2753  0.2753  0.2753  0.2753  0.2753
 0.1871  0.1871  0.1871  0.1871  0.1871
 0.7694  0.7694  0.7694  0.7694  0.7694
 0.1788  0.0000  0.0000  0.0000  0.1788
[ CUDAFloatType{5,5} ]

Fusion:

%kernel {
T5_s[ iblockIdx.x63{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iUS64{1}, iUR80{( ceilDiv(( ceilDiv(( 8 * 4 ), 1) ), 32) )}, ithreadIdx.x81{32}, iS79{1} ] ca_pos( 2 )
   = T0_g[ iS70{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iS71{1}, iS75{( ceilDiv(( ceilDiv(( 8 * 4 ), 1) ), 32) )}, iS76{32}, iS74{1} ];
T6_l[ iblockIdx.x49{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS50{1}, iUR125{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x126{32}, iS124{1} ] ca_pos( 2 )
   = T1_g[ iS56{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iS57{1}, iS130{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, iS131{32}, iS129{1} ];
T2_l[ iblockIdx.x42{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iUS43{1}, iS120{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x121{32}, iS119{1} ] ca_pos( 5 )
   = rng_uniform();
T3_l[ iblockIdx.x35{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS36{1}, iS115{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x116{32}, iS114{1} ] ca_pos( 5 ) produce_pos( 5)
   = T6_l[ iblockIdx.x49{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS50{1}, iUR125{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x126{32}, iS124{1} ] ca_pos( 2 )
   + T2_l[ iblockIdx.x42{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iUS43{1}, iS120{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x121{32}, iS119{1} ] ca_pos( 5 );
T7_l[ iblockIdx.x28{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS29{1}, iS110{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x111{32}, iS109{1} ] ca_pos( 2 ) produce_pos( 5)
   = T5_s[ iblockIdx.x63{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iUS64{1}, iUR80{( ceilDiv(( ceilDiv(( 8 * 4 ), 1) ), 32) )}, ithreadIdx.x81{32}, iS79{1} ] ca_pos( 2 )
   + T3_l[ iblockIdx.x35{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS36{1}, iS115{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x116{32}, iS114{1} ] ca_pos( 5 ) produce_pos( 5);
T4_g[ iblockIdx.x21{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS22{1}, iUR105{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x106{32}, iS104{1} ] ca_pos( 2 ) produce_pos( 2)
   = T7_l[ iblockIdx.x28{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS29{1}, iS110{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x111{32}, iS109{1} ] ca_pos( 2 ) produce_pos( 5);

TransformPrinter : 
T0_g[ iS70{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iS71{1}, iS75{( ceilDiv(( ceilDiv(( 8 * 4 ), 1) ), 32) )}, iS76{32}, iS74{1} ]
 root domain : (iS0{5},bS1{1})
  Split: iS0{5} by factor 4 -> iS65{( ceilDiv(5, 4) )}, iS66{4}, start offset: 0, stop offset: 0
  Split: bS1{1} by factor 8 -> bS67{( ceilDiv(1, 8) )}, bS68{8}, start offset: 0, stop offset: 0
  Merge: iS65{( ceilDiv(5, 4) )} and bS67{( ceilDiv(1, 8) )} -> iS69{( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) )}
  Split: iS69{( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) )} by factor 1 -> iS70{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iS71{1}, start offset: 0, stop offset: 0
  Merge: bS68{8} and iS66{4} -> iS72{( 8 * 4 )}
  Split: iS72{( 8 * 4 )} by factor 1 -> iS73{( ceilDiv(( 8 * 4 ), 1) )}, iS74{1}, start offset: 0, stop offset: 0
  Split: iS73{( ceilDiv(( 8 * 4 ), 1) )} by factor 32 -> iS75{( ceilDiv(( ceilDiv(( 8 * 4 ), 1) ), 32) )}, iS76{32}, start offset: 0, stop offset: 0
T5_s[ iblockIdx.x63{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iUS64{1}, iUR80{( ceilDiv(( ceilDiv(( 8 * 4 ), 1) ), 32) )}, ithreadIdx.x81{32}, iS79{1} ] ca_pos( 2 )
 root domain : (iS10{5},bS11{1})
  Split: iS10{5} by factor 4 -> iS58{( ceilDiv(5, 4) )}, iS59{4}, start offset: 0, stop offset: 0
  Split: bS11{1} by factor 8 -> bS60{( ceilDiv(1, 8) )}, bS61{8}, start offset: 0, stop offset: 0
  Merge: iS58{( ceilDiv(5, 4) )} and bS60{( ceilDiv(1, 8) )} -> iS62{( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) )}
  Split: iS62{( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) )} by factor 1 -> iblockIdx.x63{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iUS64{1}, start offset: 0, stop offset: 0
  Merge: bS61{8} and iS59{4} -> iS77{( 8 * 4 )}
  Split: iS77{( 8 * 4 )} by factor 1 -> iS78{( ceilDiv(( 8 * 4 ), 1) )}, iS79{1}, start offset: 0, stop offset: 0
  Split: iS78{( ceilDiv(( 8 * 4 ), 1) )} by factor 32 -> iUR80{( ceilDiv(( ceilDiv(( 8 * 4 ), 1) ), 32) )}, ithreadIdx.x81{32}, start offset: 0, stop offset: 0
T1_g[ iS56{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iS57{1}, iS130{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, iS131{32}, iS129{1} ]
 root domain : (iS2{5},iS3{5})
  Split: iS2{5} by factor 4 -> iS51{( ceilDiv(5, 4) )}, iS52{4}, start offset: 0, stop offset: 0
  Split: iS3{5} by factor 8 -> iS53{( ceilDiv(5, 8) )}, iS54{8}, start offset: 0, stop offset: 0
  Merge: iS51{( ceilDiv(5, 4) )} and iS53{( ceilDiv(5, 8) )} -> iS55{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )}
  Split: iS55{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )} by factor 1 -> iS56{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iS57{1}, start offset: 0, stop offset: 0
  Merge: iS52{4} and iS54{8} -> iS127{( 4 * 8 )}
  Split: iS127{( 4 * 8 )} by factor 1 -> iS128{( ceilDiv(( 4 * 8 ), 1) )}, iS129{1}, start offset: 0, stop offset: 0
  Split: iS128{( ceilDiv(( 4 * 8 ), 1) )} by factor 32 -> iS130{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, iS131{32}, start offset: 0, stop offset: 0
T6_l[ iblockIdx.x49{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS50{1}, iUR125{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x126{32}, iS124{1} ] ca_pos( 2 )
 root domain : (iS12{5},iS13{5})
  Split: iS12{5} by factor 4 -> iS44{( ceilDiv(5, 4) )}, iS45{4}, start offset: 0, stop offset: 0
  Split: iS13{5} by factor 8 -> iS46{( ceilDiv(5, 8) )}, iS47{8}, start offset: 0, stop offset: 0
  Merge: iS44{( ceilDiv(5, 4) )} and iS46{( ceilDiv(5, 8) )} -> iS48{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )}
  Split: iS48{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )} by factor 1 -> iblockIdx.x49{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS50{1}, start offset: 0, stop offset: 0
  Merge: iS45{4} and iS47{8} -> iS122{( 4 * 8 )}
  Split: iS122{( 4 * 8 )} by factor 1 -> iS123{( ceilDiv(( 4 * 8 ), 1) )}, iS124{1}, start offset: 0, stop offset: 0
  Split: iS123{( ceilDiv(( 4 * 8 ), 1) )} by factor 32 -> iUR125{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x126{32}, start offset: 0, stop offset: 0
T2_l[ iblockIdx.x42{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iUS43{1}, iS120{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x121{32}, iS119{1} ] ca_pos( 5 )
 root domain : (iS4{5},bS5{1})
  Split: iS4{5} by factor 4 -> iS37{( ceilDiv(5, 4) )}, iS38{4}, start offset: 0, stop offset: 0
  Split: bS5{1} by factor 8 -> bS39{( ceilDiv(1, 8) )}, bS40{8}, start offset: 0, stop offset: 0
  Merge: iS37{( ceilDiv(5, 4) )} and bS39{( ceilDiv(1, 8) )} -> iS41{( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) )}
  Split: iS41{( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) )} by factor 1 -> iblockIdx.x42{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(1, 8) ) ), 1) )}, iUS43{1}, start offset: 0, stop offset: 0
  Merge: iS38{4} and bS40{8} -> iS117{( 4 * 8 )}
  Split: iS117{( 4 * 8 )} by factor 1 -> iS118{( ceilDiv(( 4 * 8 ), 1) )}, iS119{1}, start offset: 0, stop offset: 0
  Split: iS118{( ceilDiv(( 4 * 8 ), 1) )} by factor 32 -> iS120{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x121{32}, start offset: 0, stop offset: 0
T3_l[ iblockIdx.x35{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS36{1}, iS115{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x116{32}, iS114{1} ] ca_pos( 5 ) produce_pos( 5)
 root domain : (iS6{5},iS7{5})
  Split: iS6{5} by factor 4 -> iS30{( ceilDiv(5, 4) )}, iS31{4}, start offset: 0, stop offset: 0
  Split: iS7{5} by factor 8 -> iS32{( ceilDiv(5, 8) )}, iS33{8}, start offset: 0, stop offset: 0
  Merge: iS30{( ceilDiv(5, 4) )} and iS32{( ceilDiv(5, 8) )} -> iS34{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )}
  Split: iS34{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )} by factor 1 -> iblockIdx.x35{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS36{1}, start offset: 0, stop offset: 0
  Merge: iS31{4} and iS33{8} -> iS112{( 4 * 8 )}
  Split: iS112{( 4 * 8 )} by factor 1 -> iS113{( ceilDiv(( 4 * 8 ), 1) )}, iS114{1}, start offset: 0, stop offset: 0
  Split: iS113{( ceilDiv(( 4 * 8 ), 1) )} by factor 32 -> iS115{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x116{32}, start offset: 0, stop offset: 0
T7_l[ iblockIdx.x28{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS29{1}, iS110{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x111{32}, iS109{1} ] ca_pos( 2 ) produce_pos( 5)
 root domain : (iS8{5},iS9{5})
  Split: iS8{5} by factor 4 -> iS23{( ceilDiv(5, 4) )}, iS24{4}, start offset: 0, stop offset: 0
  Split: iS9{5} by factor 8 -> iS25{( ceilDiv(5, 8) )}, iS26{8}, start offset: 0, stop offset: 0
  Merge: iS23{( ceilDiv(5, 4) )} and iS25{( ceilDiv(5, 8) )} -> iS27{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )}
  Split: iS27{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )} by factor 1 -> iblockIdx.x28{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS29{1}, start offset: 0, stop offset: 0
  Merge: iS24{4} and iS26{8} -> iS107{( 4 * 8 )}
  Split: iS107{( 4 * 8 )} by factor 1 -> iS108{( ceilDiv(( 4 * 8 ), 1) )}, iS109{1}, start offset: 0, stop offset: 0
  Split: iS108{( ceilDiv(( 4 * 8 ), 1) )} by factor 32 -> iS110{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x111{32}, start offset: 0, stop offset: 0
T4_g[ iblockIdx.x21{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS22{1}, iUR105{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x106{32}, iS104{1} ] ca_pos( 2 ) produce_pos( 2)
 root domain : (iS14{5},iS15{5})
  Split: iS14{5} by factor 4 -> iS18{( ceilDiv(5, 4) )}, iS19{4}, start offset: 0, stop offset: 0
  Split: iS15{5} by factor 8 -> iS16{( ceilDiv(5, 8) )}, iS17{8}, start offset: 0, stop offset: 0
  Merge: iS18{( ceilDiv(5, 4) )} and iS16{( ceilDiv(5, 8) )} -> iS20{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )}
  Split: iS20{( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) )} by factor 1 -> iblockIdx.x21{( ceilDiv(( ( ceilDiv(5, 4) ) * ( ceilDiv(5, 8) ) ), 1) )}, iUS22{1}, start offset: 0, stop offset: 0
  Merge: iS19{4} and iS17{8} -> iS102{( 4 * 8 )}
  Split: iS102{( 4 * 8 )} by factor 1 -> iS103{( ceilDiv(( 4 * 8 ), 1) )}, iS104{1}, start offset: 0, stop offset: 0
  Split: iS103{( ceilDiv(( 4 * 8 ), 1) )} by factor 32 -> iUR105{( ceilDiv(( ceilDiv(( 4 * 8 ), 1) ), 32) )}, ithreadIdx.x106{32}, start offset: 0, stop offset: 0
}

CUDA:

// Generated kernel for the fusion above, quoted as emitted. T0 and T1 are the
// inputs and T4 is the output; T5 is a shared-memory buffer and T6/T7/T2/T3
// are per-thread local arrays. Everything after the shared-memory setup runs
// under the single `if` predicate below.
__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 2> T4, at::PhiloxCudaState philox_args) {
  auto philox_offset = philox_args.captured_ ?
    static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :
    philox_args.offset_.val;
  uint4 rng_result;
  // -1 means "no Philox result cached yet"; see the comparison in the RNG loop.
  nvfuser_index_t rng_subseq = -1;
  nvfuser_index_t rng_offset = -1;
  alignas(16) extern __shared__ char array[];
  unsigned smem_offset = 0;
  NVFUSER_DEFINE_MAGIC_ZERO
  smem_offset = alignBufferSize(smem_offset, 16);
  // T5 is carved out of the dynamic shared-memory array.
  float* T5 = reinterpret_cast<float*>(array + smem_offset);
  smem_offset += ((((ceilDiv((ceilDiv((8 * 4), 1)), 32)) * 32) * 1) * sizeof(float));
  // One predicate guards all loads, the RNG, and the store. With
  // tid = threadIdx.x it is the AND of three bounds checks:
  //   (1) (blockIdx.x / ceilDiv(5, 8)) * 4 + tid / 8 < 5   row index used for T1 and T4
  //   (2) (blockIdx.x % ceilDiv(5, 8)) * 8 + tid % 8 < 5   column index used for T1 and T4
  //   (3) (blockIdx.x / ceilDiv(5, 8)) * 4 + tid % 4 < 5   row index used only for the T0 -> T5 load
  // Assuming ceilDiv is ceiling division (so ceilDiv(5, 8) == 1) and 32 threads
  // per block as in the IR above: for blockIdx.x == 1, (1) needs tid < 8,
  // (2) needs tid % 8 < 5, and (3) needs tid % 4 == 0, so only tid 0 and tid 4
  // pass. Those threads store columns 0 and 4 of row 4, which matches the
  // partially written last row in the reported output.
  if (((((((((nvfuser_index_t)blockIdx.x) / (ceilDiv(5, 8))) * 4) + (((((ceilDiv((ceilDiv((4 * 8), 1)), 32)) - 1) * 32) + ((nvfuser_index_t)threadIdx.x)) / 8)) < 5) && ((((((nvfuser_index_t)blockIdx.x) % (ceilDiv(5, 8))) * 8) + (((((ceilDiv((ceilDiv((4 * 8), 1)), 32)) - 1) * 32) + ((nvfuser_index_t)threadIdx.x)) % 8)) < 5)) && ((((((nvfuser_index_t)blockIdx.x) / (ceilDiv(5, 8))) * 4) + (((((ceilDiv((ceilDiv((8 * 4), 1)), 32)) - 1) * 32) + ((nvfuser_index_t)threadIdx.x)) % 4)) < 5))) {
    float T6[((ceilDiv((ceilDiv((4 * 8), 1)), 32)) * 1)];
    // Zero-initialize T6 before it is filled from T1.
    #pragma unroll
    for(nvfuser_index_t i103 = 0; i103 < (ceilDiv((ceilDiv((4 * 8), 1)), 32)); ++i103) {
      T6[i103] = 0;
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    // Load T1 into T6: row offset is tid / 8, column offset is tid % 8.
    #pragma unroll
    for(nvfuser_index_t i103 = 0; i103 < (ceilDiv((ceilDiv((4 * 8), 1)), 32)); ++i103) {
      T6[i103]
         = T1[((((((nvfuser_index_t)blockIdx.x) / (ceilDiv(5, 8))) * 4) + ((((i103 + nvfuser_zero) * 32) + ((nvfuser_index_t)threadIdx.x)) / 8)) * T1.stride[0]) + ((((((nvfuser_index_t)blockIdx.x) % (ceilDiv(5, 8))) * 8) + ((((i103 + nvfuser_zero) * 32) + ((nvfuser_index_t)threadIdx.x)) % 8)) * T1.stride[1])];
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    // Load T0 into shared T5. Both the T5 slot and the T0 row use tid % 4,
    // whereas every other access in this kernel derives its row from tid / 8.
    #pragma unroll
    for(nvfuser_index_t i96 = 0; i96 < (ceilDiv((ceilDiv((8 * 4), 1)), 32)); ++i96) {
      T5[(((i96 * 32) + ((nvfuser_index_t)threadIdx.x)) % 4)]
         = T0[((((((nvfuser_index_t)blockIdx.x) / (ceilDiv(5, 8))) * 4) + ((((i96 + nvfuser_zero) * 32) + ((nvfuser_index_t)threadIdx.x)) % 4)) * T0.stride[0])];
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    float T7[((ceilDiv((ceilDiv((4 * 8), 1)), 32)) * 1)];
    // Barrier between the T5 writes above and the T5 reads below. Note that it
    // sits inside the predicated region, so only threads that passed the
    // predicate reach it.
    __barrier_sync(0);
    #pragma unroll
    for(nvfuser_index_t i105 = 0; i105 < (ceilDiv((ceilDiv((4 * 8), 1)), 32)); ++i105) {
      float T2[1];
      // The RNG is indexed by the row only: row / 4 selects the Philox
      // subsequence and row % 4 selects the component of the uint4 result.
      nvfuser_index_t rng_subseq305 = ((((((nvfuser_index_t)blockIdx.x) / (ceilDiv(5, 8))) * 4) + ((((i105 + nvfuser_zero) * 32) + ((nvfuser_index_t)threadIdx.x)) / 8))) / 4;
      nvfuser_index_t rng_component305 = ((((((nvfuser_index_t)blockIdx.x) / (ceilDiv(5, 8))) * 4) + ((((i105 + nvfuser_zero) * 32) + ((nvfuser_index_t)threadIdx.x)) / 8))) % 4;
      nvfuser_index_t rng_offset305 = 0;
      // Recompute the Philox result only when the (subsequence, offset) pair
      // differs from the cached one.
      if (rng_subseq != rng_subseq305 || rng_offset != rng_offset305) {
        rng_result = philox(philox_args.seed_, rng_subseq305, philox_offset / 4 + rng_offset305);
        rng_subseq = rng_subseq305;
        rng_offset = rng_offset305;
      }
      T2[0] = rng_uniformf(rng_result, rng_component305);
      float T3[1];
      T3[0]
        = T6[i105]
        + T2[0];
      // T5 is read at slot tid / 8 here, while it was written at slot tid % 4.
      T7[i105]
        = T5[(((i105 * 32) + ((nvfuser_index_t)threadIdx.x)) / 8)]
        + T3[0];
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    // Store T7 to the output using the same row (tid / 8) and column (tid % 8)
    // decomposition as the T1 load; the row stride is the literal 5.
    #pragma unroll
    for(nvfuser_index_t i108 = 0; i108 < (ceilDiv((ceilDiv((4 * 8), 1)), 32)); ++i108) {
      T4[((((((nvfuser_index_t)blockIdx.x) / (ceilDiv(5, 8))) * 4) + ((((i108 + nvfuser_zero) * 32) + ((nvfuser_index_t)threadIdx.x)) / 8)) * 5) + (((((nvfuser_index_t)blockIdx.x) % (ceilDiv(5, 8))) * 8) + ((((i108 + nvfuser_zero) * 32) + ((nvfuser_index_t)threadIdx.x)) % 8))]
         = T7[i108];
    }
    NVFUSER_UPDATE_MAGIC_ZERO
  }
}

Versions

TOT devel

@zasdfgbnm
Copy link
Collaborator Author

cc: @csarofeen

@zasdfgbnm
Copy link
Collaborator Author

Looks like T6 should not share the same predicate with other tensors

@zasdfgbnm
Copy link
Collaborator Author

Minimum repro:

// Minimal repro without RNG: a plain copy tv0 -> tv1 (shared memory) -> tv2,
// where tv1 and tv2 are tiled the same way but merge the two tile axes in a
// different order before being bound to threadIdx.x. The output is validated
// against the input itself.
TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  auto fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion->addInput(tv0);
  auto tv1 = set(tv0);
  auto tv2 = set(tv1);
  fusion->addOutput(tv2);

  tv1->setMemoryType(MemoryType::Shared);
  // Identical tiling on both tensors: split the two dimensions by 4 and 8,
  // merge the outer parts, and bind them to blockIdx.x with an Unswitch axis.
  for (auto tv : {tv1, tv2}) {
    tv->split(0, 4);
    tv->reorder({{1, -1}});
    tv->split(1, 8);
    tv->merge(0);
    tv->split(0, 1);
    tv->axis(0)->parallelize(ParallelType::BIDx);
    tv->axis(1)->parallelize(ParallelType::Unswitch);
  }
  // The only scheduling difference: tv2 swaps its two tile axes before merging
  // them, tv1 does not (presumably giving the two tensors transposed
  // thread-to-element mappings within a tile -- confirm).
  tv1->merge(2);
  tv2->reorder({{2, 3}});
  tv2->merge(2);
  for (auto tv : {tv1, tv2}) {
    tv->axis(-1)->parallelize(ParallelType::TIDx);
  }

  InlinePropagator propagator(tv2, -1, ComputeAtMode::MostInlined);
  MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);

  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
  // 5x5 input: neither extent is a multiple of its split factor (4 and 8).
  at::Tensor t0 = at::randn({5, 5}, options);

  FusionExecutor fe;
  fe.compileFusion(fusion, {t0});
  auto cg_outputs = fe.runFusion({t0});
  auto out = cg_outputs[0];

  // The fusion is a pure copy, so t0 is passed as the expected output.
  testValidate(fusion, {out}, {t0}, {t0}, __LINE__, __FILE__);
}

@zasdfgbnm
Copy link
Collaborator Author

The following is a quick hack that "fixes" the issue:

// Stop-gap workaround: unconditionally answer "no", whatever loop is passed
// in. The argument is deliberately ignored.
bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) {
  static_cast<void>(fl);
  constexpr bool kCanOmit = false;
  return kCanOmit;
}

I will take a deeper look to find a better canOmitElseClause and write a PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant