
Issue with non-even division of block count and the else path of an element-wise broadcast kernel? #1785

Closed
kevinstephano opened this issue Jun 30, 2022 · 5 comments · Fixed by #1787

@kevinstephano
Collaborator

🐛 Describe the bug

There seems to be a broadcast issue in our TOT code, as the following code produces a max absolute difference of 80 in a given element.

Example Code:

import torch

class Fusion(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, a, b):
        out = torch.mul(x.unsqueeze(-1), a)
        out = out + b
        return out

x = torch.randn(1024, 192, 3, device='cuda')
a = torch.randn(3, 128, device='cuda')
b = torch.randn(3, 128, device='cuda')

model = Fusion()
jit_model = torch.jit.script(model)

with torch.jit.fuser('fuser2'):
    for _ in range(5):
        out_ref = model(x, a, b)
        out_jit = jit_model(x, a, b)

print(out_ref.allclose(out_jit))
print(torch.max(torch.abs(out_ref - out_jit)))

In the NGC 22.05 container, unsqueeze() is fused and the blocking is different. In 22.05, the number of blocks corresponds to T2's outer two dimensions, 1024 * 192. In TOT, the product of T1's outer two dimensions, 1024 * 192 (note 65536 * 3 == 1024 * 192), does not divide evenly by the grid limit of 65535 blocks. I am guessing something is wrong in the else path as you step through the non-vectorized loads of the remainder of T1. I didn't see any obvious differences on the if-then path.
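A small sketch (using the assumed values from this repro) of the serial-loop arithmetic described above: the product of the outer dimensions is covered in gridDim.y chunks of at most 65535 blocks, and the final chunk is ragged, which is what drives the kernel into its predicated else path.

```python
# ceilDiv as it appears in the Fusion IR dumps below.
def ceil_div(a, b):
    return (a + b - 1) // b

outer = 1024 * 192        # outer two dimensions: 196608 == 65536 * 3
grid_y_max = 65535        # gridDim.y chunk size used by the scheduler

iters = ceil_div(outer, grid_y_max)            # serial loop trip count
remainder = outer - (iters - 1) * grid_y_max   # valid blocks in the last iteration
print(iters, remainder)   # 4 3
```

So the serial loop runs 4 times, and on the last iteration only 3 of the 65535 blockIdx.y values are in range; every other block falls into the else path.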

[DUMP profiling_graph_executor_impl.cpp:683] with prim::CudaFusionGroup_0 = graph(%1 : Float(3, 128, strides=[128, 1], requires_grad=0, device=cuda:0),
[DUMP profiling_graph_executor_impl.cpp:683]       %5 : Float(3, 128, strides=[128, 1], requires_grad=0, device=cuda:0),
[DUMP profiling_graph_executor_impl.cpp:683]       %7 : Float(1024, 192, 3, strides=[576, 3, 1], requires_grad=0, device=cuda:0)):
[DUMP profiling_graph_executor_impl.cpp:683]   %2 : int = prim::Constant[value=1]()
[DUMP profiling_graph_executor_impl.cpp:683]   %8 : int = prim::Constant[value=-1]() # izzy.py:8:37
[DUMP profiling_graph_executor_impl.cpp:683]   %9 : Float(1024, 192, 3, 1, strides=[576, 3, 1, 1], requires_grad=0, device=cuda:0) = prim::unsqueeze_copy(%7, %8)
[DUMP profiling_graph_executor_impl.cpp:683]   %out.1 : Float(1024, 192, 3, 128, strides=[73728, 384, 128, 1], requires_grad=0, device=cuda:0) = aten::mul(%9, %5) # izzy.py:8:15
[DUMP profiling_graph_executor_impl.cpp:683]   %out.5 : Float(1024, 192, 3, 128, strides=[73728, 384, 128, 1], requires_grad=0, device=cuda:0) = aten::add(%out.1, %1, %2) # izzy.py:9:15
[DUMP profiling_graph_executor_impl.cpp:683]   return (%out.5)

Fusion IR:

Inputs:
  T0_g[ iblockIdx.y106{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, iS105{4}, ithreadIdx.x107{blockDim.x} ], float
  T1_g[ iS118{( ceilDiv(( ceilDiv(( i3 * i4 ), 4) ), blockDim.x) )}, iS117{4}, iS119{blockDim.x} ], float
  T2_g[ iS110{( ceilDiv(( ceilDiv(i7, 4) ), blockDim.x) )}, iblockIdx.x113{( ceilDiv(( i5 * i6 ), 1) )}, iUS114{1}, iS109{4}, iS111{blockDim.x} ], float
Outputs:
  T7_g[ iblockIdx.y42{( ceilDiv(( ceilDiv(( i7 * i4 ), 4) ), blockDim.x) )}, iblockIdx.x44{( ceilDiv(( i5 * i6 ), 1) )}, iUS45{1}, iV41{4}, ithreadIdx.x43{blockDim.x} ] produce_pos( 3), float

%kernel_math {
T10_l[ iblockIdx.y93{( ceilDiv(( ceilDiv(i7, 4) ), blockDim.x) )}, iblockIdx.x96{( ceilDiv(( i5 * i6 ), 1) )}, iUS97{1}, iS92{4}, ithreadIdx.x94{blockDim.x} ] ca_pos( 3 )
   = T2_g[ iS110{( ceilDiv(( ceilDiv(i7, 4) ), blockDim.x) )}, iblockIdx.x113{( ceilDiv(( i5 * i6 ), 1) )}, iUS114{1}, iS109{4}, iS111{blockDim.x} ];
T3_l[ iblockIdx.y73{( ceilDiv(( ceilDiv(( i7 * 1 ), 4) ), blockDim.x) )}, iblockIdx.x76{( ceilDiv(( i5 * i6 ), 1) )}, iUS77{1}, iS72{4}, ithreadIdx.x74{blockDim.x} ] ca_pos( 5 ) produce_pos( 3) = broadcast( T10_l[ iblockIdx.y93{( ceilDiv(( ceilDiv(i7, 4) ), blockDim.x) )}, iblockIdx.x96{( ceilDiv(( i5 * i6 ), 1) )}, iUS97{1}, iS92{4}, ithreadIdx.x94{blockDim.x} ] ca_pos( 3 ) )
T9_l[ iblockIdx.y101{( ceilDiv(( ceilDiv(( i3 * i4 ), 4) ), blockDim.x) )}, iV100{4}, ithreadIdx.x102{blockDim.x} ] ca_pos( 1 )
   = T1_g[ iS118{( ceilDiv(( ceilDiv(( i3 * i4 ), 4) ), blockDim.x) )}, iS117{4}, iS119{blockDim.x} ];
T4_l[ iblockIdx.y81{( ceilDiv(( ceilDiv(( i3 * i4 ), 4) ), blockDim.x) )}, bblockIdx.x84{( ceilDiv(( 1 * 1 ), 1) )}, bUS85{1}, iS80{4}, ithreadIdx.x82{blockDim.x} ] ca_pos( 5 ) produce_pos( 1) = broadcast( T9_l[ iblockIdx.y101{( ceilDiv(( ceilDiv(( i3 * i4 ), 4) ), blockDim.x) )}, iV100{4}, ithreadIdx.x102{blockDim.x} ] ca_pos( 1 ) )
T5_l[ iblockIdx.y57{( ceilDiv(( ceilDiv(( i7 * i4 ), 4) ), blockDim.x) )}, iblockIdx.x60{( ceilDiv(( i5 * i6 ), 1) )}, iUS61{1}, iS56{4}, ithreadIdx.x58{blockDim.x} ] ca_pos( 5 ) produce_pos( 5)
   = T3_l[ iblockIdx.y73{( ceilDiv(( ceilDiv(( i7 * 1 ), 4) ), blockDim.x) )}, iblockIdx.x76{( ceilDiv(( i5 * i6 ), 1) )}, iUS77{1}, iS72{4}, ithreadIdx.x74{blockDim.x} ] ca_pos( 5 ) produce_pos( 3)
   * T4_l[ iblockIdx.y81{( ceilDiv(( ceilDiv(( i3 * i4 ), 4) ), blockDim.x) )}, bblockIdx.x84{( ceilDiv(( 1 * 1 ), 1) )}, bUS85{1}, iS80{4}, ithreadIdx.x82{blockDim.x} ] ca_pos( 5 ) produce_pos( 1);
T8_l[ iblockIdx.y89{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, iV88{4}, ithreadIdx.x90{blockDim.x} ] ca_pos( 1 )
   = T0_g[ iblockIdx.y106{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, iS105{4}, ithreadIdx.x107{blockDim.x} ];
T6_l[ iblockIdx.y65{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, bblockIdx.x68{( ceilDiv(( 1 * 1 ), 1) )}, bUS69{1}, iS64{4}, ithreadIdx.x66{blockDim.x} ] ca_pos( 5 ) produce_pos( 1) = broadcast( T8_l[ iblockIdx.y89{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, iV88{4}, ithreadIdx.x90{blockDim.x} ] ca_pos( 1 ) )
T11_l[ iblockIdx.y49{( ceilDiv(( ceilDiv(( i7 * i4 ), 4) ), blockDim.x) )}, iblockIdx.x52{( ceilDiv(( i5 * i6 ), 1) )}, iUS53{1}, iS48{4}, ithreadIdx.x50{blockDim.x} ] ca_pos( 3 ) produce_pos( 5)
   = T5_l[ iblockIdx.y57{( ceilDiv(( ceilDiv(( i7 * i4 ), 4) ), blockDim.x) )}, iblockIdx.x60{( ceilDiv(( i5 * i6 ), 1) )}, iUS61{1}, iS56{4}, ithreadIdx.x58{blockDim.x} ] ca_pos( 5 ) produce_pos( 5)
   + T6_l[ iblockIdx.y65{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, bblockIdx.x68{( ceilDiv(( 1 * 1 ), 1) )}, bUS69{1}, iS64{4}, ithreadIdx.x66{blockDim.x} ] ca_pos( 5 ) produce_pos( 1);
T7_g[ iblockIdx.y42{( ceilDiv(( ceilDiv(( i7 * i4 ), 4) ), blockDim.x) )}, iblockIdx.x44{( ceilDiv(( i5 * i6 ), 1) )}, iUS45{1}, iV41{4}, ithreadIdx.x43{blockDim.x} ] produce_pos( 3)
   = T11_l[ iblockIdx.y49{( ceilDiv(( ceilDiv(( i7 * i4 ), 4) ), blockDim.x) )}, iblockIdx.x52{( ceilDiv(( i5 * i6 ), 1) )}, iUS53{1}, iS48{4}, ithreadIdx.x50{blockDim.x} ] ca_pos( 3 ) produce_pos( 5);
} 

Launch Params:
Grid(196608, 1, 1) Block(96, 1, 1)

Kernel:

__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 3> T2, Tensor<float, 4> T7) {
  NVFUSER_DEFINE_MAGIC_ZERO
  int64_t i249;
  i249 = (((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4) + 3;
  int64_t i111;
  i111 = ((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4;
  Array<float, 4, 4> T8;
  T8.set(0);
  if (((((nvfuser_index_t)blockIdx.y) < (ceilDiv((ceilDiv((T0.size[0] * T0.size[1]), 4)), ((nvfuser_index_t)blockDim.x)))) && (i249 < (T0.size[0] * T0.size[1])))) {
    loadGlobalToLocal<float, 4, false>(&T8[0],  &T0[i111]);
  }
  Array<float, 4, 4> T9;
  T9.set(0);
  if (((((nvfuser_index_t)blockIdx.y) < (ceilDiv((ceilDiv((T0.size[0] * T0.size[1]), 4)), ((nvfuser_index_t)blockDim.x)))) && (i249 < (T0.size[0] * T0.size[1])))) {
    loadGlobalToLocal<float, 4, false>(&T9[0],  &T1[i111]);
  }
  if ((((((nvfuser_index_t)blockIdx.y) < (ceilDiv((ceilDiv((T0.size[0] * T0.size[1]), 4)), ((nvfuser_index_t)blockDim.x)))) && ((((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4) + 3) / T0.size[1]) < T0.size[0])) && (i249 < (T0.size[0] * T0.size[1])))) {
    float T10[4];
    #pragma unroll
    for(nvfuser_index_t i100 = 0; i100 < 4; ++i100) {
      T10[i100] = 0;
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    #pragma unroll
    for(nvfuser_index_t i88 = 0; i88 < 4; ++i88) {
      T10[i88]
         = T2[(((nvfuser_index_t)blockIdx.x) * T0.size[0]) + (((((((nvfuser_index_t)blockIdx.y) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4) + (i88 + nvfuser_zero)) / T0.size[1])];
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    // Alias Allocation - register
    auto& T11 = T9;
    #pragma unroll
    for(nvfuser_index_t i86 = 0; i86 < 4; ++i86) {
      float T4[1];
      T4[0]
         = T9[i86];
      float T3[1];
      T3[0]
         = T10[i86];
      float T5[1];
      T5[0]
        = T3[0]
        * T4[0];
      float T6[1];
      T6[0]
         = T8[i86];
      T11[i86]
        = T5[0]
        + T6[0];
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    loadLocalToGlobal<float, 4, false>( &T7[(((nvfuser_index_t)blockIdx.x) * (T0.size[1] * T0.size[0])) + i111], &T11[0]);
  } else {

For TOT:

graph:

[DUMP graph_fuser.cpp:2502] with prim::CudaFusionGroup_0 = graph(%1 : Float(3, 128, strides=[128, 1], requires_grad=0, device=cuda:0),
[DUMP graph_fuser.cpp:2502]       %4 : Float(1024, 192, 3, 1, strides=[576, 3, 1, 1], requires_grad=0, device=cuda:0),
[DUMP graph_fuser.cpp:2502]       %5 : Float(3, 128, strides=[128, 1], requires_grad=0, device=cuda:0)):
[DUMP graph_fuser.cpp:2502]   %2 : int = prim::Constant[value=1]()
[DUMP graph_fuser.cpp:2502]   %out.1 : Float(1024, 192, 3, 128, strides=[73728, 384, 128, 1], requires_grad=0, device=cuda:0) = aten::mul(%4, %5) # izzy.py:8:15
[DUMP graph_fuser.cpp:2502]   %out.5 : Float(1024, 192, 3, 128, strides=[73728, 384, 128, 1], requires_grad=0, device=cuda:0) = aten::add(%out.1, %1, %2) # izzy.py:9:15
[DUMP graph_fuser.cpp:2502]   return (%out.5)

Fusion IR:

Inputs:                        
  T0_g[ iS124{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, iS123{4}, iS125{blockDim.x} ], float
  T1_g[ iS99{( ceilDiv(( ceilDiv(( i5 * 1 ), 4) ), blockDim.x) )}, iS104{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iS105{65535}, iS103{1}, iS98{4}, iS100{blockDim.x} ], float
  T2_g[ iS114{( ceilDiv(( ceilDiv(( i7 * i8 ), 4) ), blockDim.x) )}, iS113{4}, iS115{blockDim.x} ], float
Outputs:                                                                                                                                                                                                                                             
  T6_g[ iblockIdx.x40{( ceilDiv(( ceilDiv(( i5 * i8 ), 4) ), blockDim.x) )}, iS44{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iblockIdx.y45{65535}, iUS43{1}, iV39{4}, ithreadIdx.x41{blockDim.x} ] produce_pos( 4), float
                                                                                                                          
%kernel_math {                                
T8_l[ iblockIdx.x89{( ceilDiv(( ceilDiv(( i5 * 1 ), 4) ), blockDim.x) )}, iS94{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iblockIdx.y95{65535}, iUS93{1}, iS88{4}, ithreadIdx.x90{blockDim.x} ] ca_pos( 4 )
   = T1_g[ iS99{( ceilDiv(( ceilDiv(( i5 * 1 ), 4) ), blockDim.x) )}, iS104{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iS105{65535}, iS103{1}, iS98{4}, iS100{blockDim.x} ];
T9_l[ iblockIdx.x109{( ceilDiv(( ceilDiv(( i7 * i8 ), 4) ), blockDim.x) )}, iV108{4}, ithreadIdx.x110{blockDim.x} ] ca_pos( 1 )
   = T2_g[ iS114{( ceilDiv(( ceilDiv(( i7 * i8 ), 4) ), blockDim.x) )}, iS113{4}, iS115{blockDim.x} ];
T3_l[ iblockIdx.x79{( ceilDiv(( ceilDiv(( i7 * i8 ), 4) ), blockDim.x) )}, bS84{( ceilDiv(( ceilDiv(( 1 * 1 ), 1) ), 65535) )}, bblockIdx.y85{65535}, bUS83{1}, iS78{4}, ithreadIdx.x80{blockDim.x} ] ca_pos( 6 ) produce_pos( 1)
   = broadcast( T9_l[ iblockIdx.x109{( ceilDiv(( ceilDiv(( i7 * i8 ), 4) ), blockDim.x) )}, iV108{4}, ithreadIdx.x110{blockDim.x} ] ca_pos( 1 ) )
T4_l[ iblockIdx.x69{( ceilDiv(( ceilDiv(( i5 * i8 ), 4) ), blockDim.x) )}, iS74{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iblockIdx.y75{65535}, iUS73{1}, iS68{4}, ithreadIdx.x70{blockDim.x} ] ca_pos( 6 ) produce_pos( 6)
   = T8_l[ iblockIdx.x89{( ceilDiv(( ceilDiv(( i5 * 1 ), 4) ), blockDim.x) )}, iS94{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iblockIdx.y95{65535}, iUS93{1}, iS88{4}, ithreadIdx.x90{blockDim.x} ] ca_pos( 4 )
   * T3_l[ iblockIdx.x79{( ceilDiv(( ceilDiv(( i7 * i8 ), 4) ), blockDim.x) )}, bS84{( ceilDiv(( ceilDiv(( 1 * 1 ), 1) ), 65535) )}, bblockIdx.y85{65535}, bUS83{1}, iS78{4}, ithreadIdx.x80{blockDim.x} ] ca_pos( 6 ) produce_pos( 1);
T7_l[ iblockIdx.x119{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, iV118{4}, ithreadIdx.x120{blockDim.x} ] ca_pos( 1 )
   = T0_g[ iS124{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, iS123{4}, iS125{blockDim.x} ];
T5_l[ iblockIdx.x59{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, bS64{( ceilDiv(( ceilDiv(( 1 * 1 ), 1) ), 65535) )}, bblockIdx.y65{65535}, bUS63{1}, iS58{4}, ithreadIdx.x60{blockDim.x} ] ca_pos( 6 ) produce_pos( 1)
   = broadcast( T7_l[ iblockIdx.x119{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, iV118{4}, ithreadIdx.x120{blockDim.x} ] ca_pos( 1 ) )
T10_l[ iblockIdx.x49{( ceilDiv(( ceilDiv(( i5 * i8 ), 4) ), blockDim.x) )}, iS54{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iblockIdx.y55{65535}, iUS53{1}, iS48{4}, ithreadIdx.x50{blockDim.x} ] ca_pos( 4 ) produce_pos( 6)
   = T4_l[ iblockIdx.x69{( ceilDiv(( ceilDiv(( i5 * i8 ), 4) ), blockDim.x) )}, iS74{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iblockIdx.y75{65535}, iUS73{1}, iS68{4}, ithreadIdx.x70{blockDim.x} ] ca_pos( 6 ) produce_pos( 6)
   + T5_l[ iblockIdx.x59{( ceilDiv(( ceilDiv(( i0 * i2 ), 4) ), blockDim.x) )}, bS64{( ceilDiv(( ceilDiv(( 1 * 1 ), 1) ), 65535) )}, bblockIdx.y65{65535}, bUS63{1}, iS58{4}, ithreadIdx.x60{blockDim.x} ] ca_pos( 6 ) produce_pos( 1);
T6_g[ iblockIdx.x40{( ceilDiv(( ceilDiv(( i5 * i8 ), 4) ), blockDim.x) )}, iS44{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iblockIdx.y45{65535}, iUS43{1}, iV39{4}, ithreadIdx.x41{blockDim.x} ] produce_pos( 4)
   = T10_l[ iblockIdx.x49{( ceilDiv(( ceilDiv(( i5 * i8 ), 4) ), blockDim.x) )}, iS54{( ceilDiv(( ceilDiv(( i3 * i4 ), 1) ), 65535) )}, iblockIdx.y55{65535}, iUS53{1}, iS48{4}, ithreadIdx.x50{blockDim.x} ] ca_pos( 4 ) produce_pos( 6);
}               

Launch Params:
Grid(1, 65535, 1) Block(96, 1, 1)

Kernel:

__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 4> T1, Tensor<float, 2> T2, Tensor<float, 4> T6) {                                                                                                                                                                                                                                                                                                                                                                                             
  NVFUSER_DEFINE_MAGIC_ZERO                                                                                                                                                                                                                                                                                                                                                                                                                                                                               
  int i284;                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               
  i284 = (((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4) + 3;                                                                                                                                                                                                                                                                                                                                                                                    
  int i116;                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               
  i116 = ((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4;                                                                                                                                                                                                                                                                                                                                                                                          
  Array<float, 4, 4> T7;                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
  T7.set(0);                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
  if ((i284 < (T0.size[0] * T0.size[1]))) {                                                                                                                                                                                                                                                                                                                                                                                                                                                               
    loadGlobalToLocal<float, 4, false>(&T7[0],  &T0[i116]);                                                                                                                                                                                                                                                                                                                                                                                                                                               
  }                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
  Array<float, 4, 4> T9;                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
  T9.set(0);                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
  if ((i284 < (T0.size[0] * T0.size[1]))) {                                                                                                                                                                                                                                                                                                                                                                                                                                                               
    loadGlobalToLocal<float, 4, false>(&T9[0],  &T2[i116]);
  }
  #pragma unroll 1
  for(nvfuser_index_t i101 = 0; i101 < (ceilDiv((ceilDiv((T1.size[0] * T1.size[1]), 1)), 65535)); ++i101) {
    int i144;
    i144 = (i101 * 65535) + ((nvfuser_index_t)blockIdx.y);
    if (((((((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4) + 3) / T0.size[1]) < T0.size[0]) && (i144 < (T1.size[0] * T1.size[1]))) && (i284 < (T0.size[0] * T0.size[1])))) {
      float T8[4];
      #pragma unroll
      for(nvfuser_index_t i91 = 0; i91 < 4; ++i91) {
        T8[i91] = 0;
      }
      NVFUSER_UPDATE_MAGIC_ZERO
      #pragma unroll
      for(nvfuser_index_t i91 = 0; i91 < 4; ++i91) {
        T8[i91]
           = T1[(i144 * T0.size[0]) + (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4) + (i91 + nvfuser_zero)) / T0.size[1])];
      }
      NVFUSER_UPDATE_MAGIC_ZERO
      // Alias Allocation - register
      auto& T10 = T9;
      #pragma unroll
      for(nvfuser_index_t i100 = 0; i100 < 4; ++i100) {
        float T3[1];
        T3[0]
           = T9[i100];
        float T4[1];
        T4[0]
          = T8[i100]
          * T3[0];
        float T5[1];
        T5[0]
           = T7[i100];
        T10[i100]
          = T4[0]
          + T5[0];
      }
      NVFUSER_UPDATE_MAGIC_ZERO
      loadLocalToGlobal<float, 4, false>( &T6[(i144 * (T0.size[1] * T0.size[0])) + i116], &T10[0]);
    } else {                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
      float T8[4];                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
      #pragma unroll                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      
      for(nvfuser_index_t i91 = 0; i91 < 4; ++i91) {                                                                                                                                                                                                                                                                                                                                                                                                                                                      
        T8[i91] = 0;                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      
      }                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
      NVFUSER_UPDATE_MAGIC_ZERO                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
      #pragma unroll                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      
      for(nvfuser_index_t i91 = 0; i91 < 4; ++i91) {                                                                                                                                                                                                                                                                                                                                                                                                                                                      
        int i205;                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
        i205 = ((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4) + (i91 + nvfuser_zero)) / T0.size[1];                                                                                                                                                                                                                                                                                                                                            
        if (((i205 < T0.size[0]) && (i144 < (T1.size[0] * T1.size[1])))) {                                                                                                                                                                                                                                                                                                                                                                                                                                
          T8[i91]
             = T1[(i144 * T0.size[0]) + i205];
        }
      }
      NVFUSER_UPDATE_MAGIC_ZERO
      // Alias Allocation - register
      auto& T10 = T9;
      #pragma unroll
      for(nvfuser_index_t i100 = 0; i100 < 4; ++i100) {
        float T3[1];
        T3[0]
           = T9[i100];
        float T4[1];
        T4[0]
          = T8[i100]
          * T3[0];
        float T5[1];
        T5[0]
           = T7[i100];
        T10[i100]
          = T4[0]
          + T5[0];
      }
      NVFUSER_UPDATE_MAGIC_ZERO
      if (((i284 < (T0.size[0] * T0.size[1])) && (i144 < (T1.size[0] * T1.size[1])))) {
        loadLocalToGlobal<float, 4, false>( &T6[(i144 * (T0.size[1] * T0.size[0])) + i116], &T10[0]);
      }
    }   

Versions

TOT

@kevinstephano
Collaborator Author

kevinstephano commented Jun 30, 2022

As a note, if I make the product of the x tensor's outer dimensions a multiple of 65,535, the problem goes away.

This is an example:

x = torch.randn(4369, 15, 3, device='cuda')
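Checking the arithmetic behind this workaround (shapes taken from the two repros above): 4369 * 15 is exactly 65535, so the serial loop divides evenly and the remainder path is never taken.

```python
# Remainder of the outer-dimension product against the 65535-block grid chunk.
original = 1024 * 192    # 196608 -> remainder 3, ragged last iteration
workaround = 4369 * 15   # 65535  -> divides evenly, no else-path iteration
print(original % 65535, workaround % 65535)   # 3 0
```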

@shmsong

shmsong commented Jun 30, 2022

This is a register aliasing problem. Fixing it now.

@kevinstephano
Collaborator Author

I would be curious to know which part of the kernel was in error if you don't mind pointing it out.

@shmsong

shmsong commented Jun 30, 2022

// Alias Allocation - register
auto& T10 = T9;

^^^
This shouldn't be happening, since the lifetime of T9 extends across the serial loop:

for(nvfuser_index_t i101 = 0; i101 < (ceilDiv((ceilDiv((T1.size[0] * T1.size[1]), 1)), 65535)); ++i101) {
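A minimal Python analogue (not the generated CUDA; the names mirror the kernel's temporaries, and the values are made up) of why this aliasing is unsafe: T9 is loaded from global memory once before the serial loop and read on every iteration, so writing the per-iteration output through an alias of T9 clobbers values that later iterations still need.

```python
def run(alias_output):
    T9 = [1.0, 2.0, 3.0, 4.0]        # loaded once, reused every serial iteration
    T7 = [0.5, 0.5, 0.5, 0.5]        # other loop-invariant operand
    out = []
    for it in range(2):              # the serial loop over grid-y chunks
        T8 = [float(10 ** (it + 1))] * 4   # per-iteration load
        # The bug: T10 aliases T9's registers instead of fresh storage.
        T10 = T9 if alias_output else [0.0] * 4
        for i in range(4):
            T10[i] = T8[i] * T9[i] + T7[i]
        out.append(list(T10))
    return out

good = run(alias_output=False)
bad = run(alias_output=True)
print(good[1])   # [100.5, 200.5, 300.5, 400.5] -- uses the original T9
print(bad[1])    # [1050.5, 2050.5, 3050.5, 4050.5] -- uses clobbered T9
```

The first iteration agrees in both cases; only iterations after the aliased write diverge, which matches a wrong answer appearing only when the serial loop runs more than once.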

@shmsong

shmsong commented Jun 30, 2022

To unblock, I'm going with a quicker fix that blanket-disables a wider range of scenarios. Will improve the analysis precision in follow-ups.
