Some generated cuda kernel's input's shape is 0 #389

Closed
VincentXWD opened this issue Dec 9, 2023 · 0 comments

Hello, I noticed that one of the generated CUDA kernels has an input whose shape is empty ([]). I wonder why this happens and whether it is a bug. Thanks.

Here is the Hidet Python model definition code:

import torch
import torch._dynamo
from torch import nn
import hidet 
import math


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
        
class SelfAttention(nn.Module):
    def __init__(self, num_attention_heads, input_size, hidden_size, attention_probs_dropout_prob, hidden_dropout_prob):
        super(SelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = hidden_size

        self.query = nn.Linear(input_size, self.all_head_size)
        self.key = nn.Linear(input_size, self.all_head_size)
        self.value = nn.Linear(input_size, self.all_head_size)

        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, input_tensor):
        mixed_query_layer = self.query(input_tensor)
        mixed_key_layer = self.key(input_tensor)
        mixed_value_layer = self.value(input_tensor)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
        # [batch_size heads seq_len seq_len] scores
        # [batch_size 1 1 seq_len]

        # attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # Fixme
        attention_probs = self.attn_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        hidden_states = self.dense(context_layer)
        hidden_states = self.out_dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states


hidet.option.cache_dir('./outs/cache')
model = SelfAttention(num_attention_heads = 12, input_size = 768, hidden_size = 768, attention_probs_dropout_prob = 0.5, hidden_dropout_prob = 0.5).cuda().eval()
x = torch.rand(1, 128, 768).cuda()
# print(model)
model_opt = torch.compile(model, backend='hidet')  
y = model_opt(x)
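
For reference, a quick eager-mode comparison along these lines can confirm that the compiled output is still numerically correct (dropout is a no-op in eval mode, so the two runs are comparable):

with torch.no_grad():
    y_ref = model(x)  # eager-mode reference
print(torch.allclose(y, y_ref, atol=1e-3, rtol=1e-3))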

Here is the 12th kernel's meta.json:

{
  "name": "fused_subtract_pow",
  "symbols": [],
  "inputs": [
    {
      "device": "cuda",
      "dtype": "float32",
      "shape": []
    },
    {
      "device": "cuda",
      "dtype": "float32",
      "shape": [
        1,
        128,
        768
      ]
    },
    {
      "device": "cuda",
      "dtype": "float32",
      "shape": [
        1,
        128,
        1
      ]
    }
  ],
  "outputs": [
    {
      "device": "cuda",
      "dtype": "float32",
      "shape": [
        1,
        128,
        768
      ]
    }
  ],
  "target": "cuda",
  "num_candidates": 1,
  "hidet_version": "0.3.1.dev"
}
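
If I read this correctly, "shape": [] denotes a rank-0 (scalar) tensor rather than a zero-element tensor; in plain PyTorch the distinction looks like this:

import torch

t = torch.tensor(2.0)  # a 0-dim (scalar) tensor
print(t.shape)         # torch.Size([]) -- empty shape, i.e. rank 0
print(t.numel())       # 1 -- it still holds one element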

Here is the generated kernel:

#include <stdint.h>
#include <hidet/runtime/symbols.h>
#include <hidet/runtime/memory_planner.h>
#include <hidet/runtime/cpu/context.h>
#include <hidet/runtime/cuda/complex.h>
#include <hidet/runtime/cuda/context.h>
#include <hidet/runtime/logging.h>


static __global__ void __launch_bounds__(512) hidet_fused_compute_z(float * __restrict__ x, float * __restrict__ y, float * __restrict__ y_1, float * __restrict__ z) {
  z[((((((int)blockIdx.x * 512) + (int)threadIdx.x) / 768) * 768) + ((((int)blockIdx.x * 512) + (int)threadIdx.x) % 768))] = powf((x[((((((int)blockIdx.x * 512) + (int)threadIdx.x) / 768) * 768) + ((((int)blockIdx.x * 512) + (int)threadIdx.x) % 768))] - y[((((int)blockIdx.x * 512) + (int)threadIdx.x) / 768)]), y_1[0]);
}

DLL void hidet_get_input_shape(int32_t idx, int32_t * __restrict__ dims) {
  if (idx == 0) {
  } 
  if (idx == 1) {
    dims[0] = 1;
    dims[1] = 128;
    dims[2] = 768;
  } 
  if (idx == 2) {
    dims[0] = 1;
    dims[1] = 128;
    dims[2] = 1;
  } 
}

DLL void hidet_get_output_shape(int32_t idx, int32_t * __restrict__ dims) {
  if (idx == 0) {
    dims[0] = 1;
    dims[1] = 128;
    dims[2] = 768;
  } 
}

DLL void hidet_launch_0(float * __restrict__ y, float * __restrict__ x, float * __restrict__ y_1, float * __restrict__ z) {
  hidet_fused_compute_z<<<dim3(192, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)get_cuda_stream()>>>(x, y_1, y, z);
  {cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) LOG(ERROR) << "CUDA error: " << cudaGetErrorString(err) << "\n";}
}
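
For context, this fused kernel appears to correspond to the s = (x - u).pow(2) line in LayerNorm.forward, with the rank-0 input supplying the exponent (read as y_1[0] inside the powf call). A rough PyTorch equivalent of what the kernel computes (my own sketch; the variable names are made up):

import torch

x = torch.rand(1, 128, 768, device='cuda')   # matches the [1, 128, 768] input
u = x.mean(-1, keepdim=True)                 # matches the [1, 128, 1] input
exponent = torch.tensor(2.0, device='cuda')  # matches the rank-0 ([]) input
z = (x - u).pow(exponent)                    # matches the [1, 128, 768] output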