Getting embedding of a sequence #89

Closed · CorvusVaine opened this issue May 24, 2024 · 2 comments

@CorvusVaine

Hello, I am trying to get the output embeddings of my dataset from DNABERT-2 and then use them with another model, using the following code:

import os
import transformers
import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = transformers.AutoModel.from_pretrained(
    "zhihan1996/DNABERT-2-117M",
    trust_remote_code=True
)
model.config.use_cache = False
model.config.pretraining_tp = 1
model.eval()

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "zhihan1996/DNABERT-2-117M",
    model_max_length=42,
    padding_side="right",
    use_fast=True,
    trust_remote_code=True,
    truncation=True,
    padding='max_length',
    max_length=40
)

seqs = ["ATCTAGCTAGACGTTACGCTACGCATGTACGTACGCTCAGTAGCATGCTAGCT", "CGTAGGTCGTCTAGCTGATCAGTACGCATGCATAGCTAGCTGCATCGTAGCATCGATGATCGATCGATGATGC"]
model.to(device)
inputs = tokenizer(seqs, padding='max_length', truncation=True, max_length=40, return_tensors='pt')
inputs = {key: value.to(device) for key, value in inputs.items()}
with torch.no_grad():
    outputs = model(**inputs)

When I run this code, I get the following error:

AssertionError Traceback (most recent call last)
Cell In[11], line 9
7 print(inputs)
8 with torch.no_grad():
----> 9 outputs = model(**inputs)
10 print(outputs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:609, in BertModel.forward(self, input_ids, token_type_ids, attention_mask, position_ids, output_all_encoded_layers, masked_tokens_mask, **kwargs)
606 first_col_mask[:, 0] = True
607 subset_mask = masked_tokens_mask | first_col_mask
--> 609 encoder_outputs = self.encoder(
610 embedding_output,
611 attention_mask,
612 output_all_encoded_layers=output_all_encoded_layers,
613 subset_mask=subset_mask)
615 if masked_tokens_mask is None:
616 sequence_output = encoder_outputs[-1]

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:447, in BertEncoder.forward(self, hidden_states, attention_mask, output_all_encoded_layers, subset_mask)
445 if subset_mask is None:
446 for layer_module in self.layer:
--> 447 hidden_states = layer_module(hidden_states,
448 cu_seqlens,
449 seqlen,
450 None,
451 indices,
452 attn_mask=attention_mask,
453 bias=alibi_attn_mask)
454 if output_all_encoded_layers:
455 all_encoder_layers.append(hidden_states)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:328, in BertLayer.forward(self, hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias)
306 def forward(
307 self,
308 hidden_states: torch.Tensor,
(...)
314 bias: Optional[torch.Tensor] = None,
315 ) -> torch.Tensor:
316 """Forward pass for a BERT layer, including both attention and MLP.
317
318 Args:
(...)
326 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
327 """
--> 328 attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
329 subset_idx, indices, attn_mask, bias)
330 layer_output = self.mlp(attention_output)
331 return layer_output

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:241, in BertUnpadAttention.forward(self, input_tensor, cu_seqlens, max_s, subset_idx, indices, attn_mask, bias)
219 def forward(
220 self,
221 input_tensor: torch.Tensor,
(...)
227 bias: Optional[torch.Tensor] = None,
228 ) -> torch.Tensor:
229 """Forward pass for scaled self-attention without padding.
230
231 Arguments:
(...)
239 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
240 """
--> 241 self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
242 attn_mask, bias)
243 if subset_idx is not None:
244 return self.output(index_first_axis(self_output, subset_idx),
245 index_first_axis(input_tensor, subset_idx))

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:182, in BertUnpadSelfAttention.forward(self, hidden_states, cu_seqlens, max_seqlen_in_batch, indices, attn_mask, bias)
180 bias_dtype = bias.dtype
181 bias = bias.to(torch.float16)
--> 182 attention = flash_attn_qkvpacked_func(qkv, bias)
183 attention = attention.to(orig_dtype)
184 bias = bias.to(bias_dtype)

File ~/.local/lib/python3.10/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
503 if not torch._C._are_functorch_transforms_active():
504 # See NOTE: [functorch vjp and autograd interaction]
505 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506 return super().apply(*args, **kwargs) # type: ignore[misc]
508 if cls.setup_context == _SingleLevelFunction.setup_context:
509 raise RuntimeError(
510 'In order to use an autograd.Function with functorch transforms '
511 '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
512 'staticmethod. For more details, please see '
513 'https://pytorch.org/docs/master/notes/extending.func.html')

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/flash_attn_triton.py:1021, in _FlashAttnQKVPackedFunc.forward(ctx, qkv, bias, causal, softmax_scale)
1019 if qkv.stride(-1) != 1:
1020 qkv = qkv.contiguous()
-> 1021 o, lse, ctx.softmax_scale = _flash_attn_forward(
1022 qkv[:, :, 0],
1023 qkv[:, :, 1],
1024 qkv[:, :, 2],
1025 bias=bias,
1026 causal=causal,
1027 softmax_scale=softmax_scale)
1028 ctx.save_for_backward(qkv, o, lse, bias)
1029 ctx.causal = causal

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/flash_attn_triton.py:781, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
778 assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
779 assert q.dtype in [torch.float16,
780 torch.bfloat16], 'Only support fp16 and bf16'
--> 781 assert q.is_cuda and k.is_cuda and v.is_cuda
782 softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
784 has_bias = bias is not None

AssertionError:

I guess this error is due to the structure of the network, which may require the data to be fed in differently.
Could you tell me how I can simply get the output embeddings from DNABERT-2, please?
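
For what it is worth, the failing assertion (assert q.is_cuda and k.is_cuda and v.is_cuda) only fires when the tensors reaching the Triton flash-attention kernel are on the CPU, which suggests that device resolved to "cpu" in the snippet above. One possible (unconfirmed) cause is that CUDA_VISIBLE_DEVICES="3" hides every GPU on a machine with fewer than four of them. A quick check, sketched under that assumption:

import os
import torch

# Assumption: if the machine exposes fewer than 4 GPUs, index "3" hides them all
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
print(torch.cuda.is_available())  # False -> "cuda" falls back to "cpu" in the snippet above
print(torch.cuda.device_count())  # 0 when no visible GPU remains

# The Triton flash-attention path in DNABERT-2's remote code only accepts CUDA fp16/bf16
# tensors, so a CPU forward pass trips the q.is_cuda assertion seen in the traceback.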

@Zhihan1996 (Collaborator)

Please try this:

import os
import numpy as np
import transformers
import torch
import torch.utils.data as util_data
import torch.nn as nn
import tqdm
import argparse
from sklearn.preprocessing import normalize



def calculate_llm_embedding(dna_sequences, model_name_or_path, model_max_length=400, batch_size=25):
    # reorder the sequences by length
    # process sequences with similar lengths in the same batch can greatly speed up the computation
    lengths = [len(seq) for seq in dna_sequences]
    idx = np.argsort(lengths)
    dna_sequences = [dna_sequences[i] for i in idx]
    
    tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name_or_path,
            cache_dir=None,
            model_max_length=model_max_length,
            padding_side="left",
            use_fast=True,
            trust_remote_code=True,
        )


    model = transformers.AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16, 
            attn_implementation="flash_attention_2",
        )
    

    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        model = nn.DataParallel(model)
        
    model.to("cuda")


    train_loader = util_data.DataLoader(dna_sequences, batch_size=batch_size*n_gpu, shuffle=False, num_workers=2*n_gpu)
    for j, batch in enumerate(tqdm.tqdm(train_loader)):
        with torch.no_grad():
            token_feat = tokenizer.batch_encode_plus(
                    batch, 
                    max_length=model_max_length, 
                    return_tensors='pt', 
                    padding='longest', 
                    truncation=True
                )
            input_ids = token_feat['input_ids'].cuda()
            attention_mask = token_feat['attention_mask'].cuda()
            model_output = model.forward(input_ids=input_ids, attention_mask=attention_mask)[0].detach().cpu()
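            # Masked mean pooling: zero out padding positions, then average the remaining
            # token embeddings so each sequence is reduced to a single vector.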
                
            attention_mask = attention_mask.unsqueeze(-1).detach().cpu()
            embedding = torch.sum(model_output*attention_mask, dim=1) / torch.sum(attention_mask, dim=1)
            
            if j==0:
                embeddings = embedding
            else:
                embeddings = torch.cat((embeddings, embedding), dim=0)

    embeddings = np.array(embeddings.detach().float().cpu())
    
    # reorder the embeddings
    embeddings = embeddings[np.argsort(idx)]

    return embeddings
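
For reference, a minimal usage sketch; the 768 in the expected output shape assumes the hidden size of the DNABERT-2-117M checkpoint:

dna_sequences = [
    "ATCTAGCTAGACGTTACGCTACGCATGTACGTACGCTCAGTAGCATGCTAGCT",
    "CGTAGGTCGTCTAGCTGATCAGTACGCATGCATAGCTAGCTGCATCGTAGCATCGATGATCGATCGATGATGC",
]
embeddings = calculate_llm_embedding(dna_sequences, "zhihan1996/DNABERT-2-117M")
print(embeddings.shape)  # expected: (2, 768), one mean-pooled vector per input sequence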

@CorvusVaine (Author)

It appears to me that AutoModel.from_pretrained() does not accept attn_implementation as a parameter, although the config does, so I got an error with your code. I tried changing it a bit by loading the model as follows:

config = transformers.AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config.attn_implementation = "flash_attention_2"
model = transformers.AutoModel.from_config(
    config=config,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)

and it seems to load fine.
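
One caveat with this approach: AutoModel.from_config builds the architecture with freshly initialized weights and does not load the pretrained checkpoint, so embeddings computed this way would come from an untrained model. A sketch that keeps the pretrained weights while still overriding the config might look like the following (whether the custom DNABERT-2 config actually honors attn_implementation is an assumption):

config = transformers.AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config.attn_implementation = "flash_attention_2"  # assumption: the remote code reads this field
model = transformers.AutoModel.from_pretrained(
    model_name_or_path,
    config=config,               # pass the modified config but keep the pretrained weights
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)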

However, I still run into two issues in the forward pass:

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
0%| | 0/1 [00:00<?, ?it/s]

RuntimeError Traceback (most recent call last)
Cell In[47], line 1
----> 1 a= calculate_llm_embedding(["ATCGACTGTGCATAGCCTAGATAGCTACGTACGCTCAGCTGACTGATGCTACAGCGT","ATCGACTGTGCATAGCCTAGATAGCTACGTACGCTCAGCTGACTGATGCTACAGCGT","ATCGACTGTGCATAGCCTAGATAGCTACGTACGCTCAGCTGACTGATGCTACAGCGT","ATCGACTGTGCATAGCCTAGATAGCTACGTACGCTCAGCTGACTGATGCTACAGCGT"],"zhihan1996/DNABERT-2-117M")

Cell In[46], line 45, in calculate_llm_embedding(dna_sequences, model_name_or_path, model_max_length, batch_size)
43 input_ids = token_feat['input_ids'].cuda()
44 attention_mask = token_feat['attention_mask'].cuda()
---> 45 model_output = model.forward(input_ids=input_ids, attention_mask=attention_mask)[0].detach().cpu()
47 attention_mask = attention_mask.unsqueeze(-1).detach().cpu()
48 embedding = torch.sum(model_output*attention_mask, dim=1) / torch.sum(attention_mask, dim=1)

File ~/.local/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:171, in DataParallel.forward(self, *inputs, **kwargs)
169 return self.module(*inputs[0], **kwargs[0])
170 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 171 outputs = self.parallel_apply(replicas, inputs, kwargs)
172 return self.gather(outputs, self.output_device)

File ~/.local/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:181, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
180 def parallel_apply(self, replicas, inputs, kwargs):
--> 181 return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File ~/.local/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:89, in parallel_apply(modules, inputs, kwargs_tup, devices)
87 output = results[i]
88 if isinstance(output, ExceptionWrapper):
---> 89 output.reraise()
90 outputs.append(output)
91 return outputs

File ~/.local/lib/python3.10/site-packages/torch/_utils.py:643, in ExceptionWrapper.reraise(self)
639 exception = self.exc_type(msg)
640 except TypeError:
641 # If the exception takes multiple arguments, don't try to
642 # instantiate since we don't know how to
--> 643 raise RuntimeError(msg) from None
644 raise exception

RuntimeError: Caught CompilationError in replica 0 on device 0.
Original Traceback (most recent call last):
File "", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-83ca8b715a9dc5f32dc1110973485f64-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-3d2aedeb40d6d81c66a42791e268f98b-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 937, in build_triton_ir
generator.visit(fn.parse())
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 183, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/opt/conda/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
has_ret = self.visit_compound_statement(node.body)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
self.last_ret_type = self.visit(stmt)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 678, in visit_For
self.visit_compound_statement(node.body)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
self.last_ret_type = self.visit(stmt)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 319, in visit_AugAssign
self.visit(assign)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 301, in visit_Assign
values = self.visit(node.value)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 339, in visit_BinOp
rhs = self.visit(node.right)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 797, in visit_Call
return fn(*args, _builder=self.builder, **kws)
File "/home/groy/.local/lib/python3.10/site-packages/triton/impl/base.py", line 22, in wrapper
return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
output = module(*input, **kwargs)
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 609, in forward
encoder_outputs = self.encoder(
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 447, in forward
hidden_states = layer_module(hidden_states,
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 328, in forward
attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 241, in forward
self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 186, in forward
attention = flash_attn_qkvpacked_func(qkv, bias)
File "/home/groy/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/flash_attn_triton.py", line 1021, in forward
o, lse, ctx.softmax_scale = _flash_attn_forward(
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/flash_attn_triton.py", line 826, in _flash_attn_forward
_fwd_kernel[grid]( # type: ignore
File "/home/groy/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 90, in run
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
File "/home/groy/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 199, in run
return self.fn.run(*args, **kwargs)
File "", line 41, in _fwd_kernel
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 1621, in compile
next_module = compile(module)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 1550, in
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 962, in ast_to_ttir
mod, _ = build_triton_ir(fn, signature, specialization, constants)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 942, in build_triton_ir
raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 114:24:
def _fwd_kernel(
Q,
K,
V,
Bias,
Out,
Lse,
TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_bb,
stride_bh,
stride_bm,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
# off_b = tl.program_id(1)
# off_h = tl.program_id(2)
# off_hb = off_b * nheads + off_h
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Initialize pointers to Q, K, V
# Adding parenthesis around indexing might use int32 math instead of int64 math?
# triton-lang/triton#741
# I'm seeing a tiny bit of difference (5-7us)
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (
offs_m[:, None] * stride_qm + offs_d[None, :])
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (
offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (
offs_n[:, None] * stride_vn + offs_d[None, :])
if BIAS_TYPE == 'vector':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
elif BIAS_TYPE == 'matrix':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (
offs_m[:, None] * stride_bm + offs_n[None, :])
else:
raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
# initialize pointer to m and l
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(q_ptrs,
mask=(offs_m[:, None] < seqlen_q) &
(offs_d[None, :] < headdim),
other=0.0)
# loop over k, v and update accumulator
end_n = seqlen_k if not IS_CAUSAL else tl.minimum(
(start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=offs_d[None, :] < headdim,
other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) &
(offs_d[None, :] < headdim),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
^

Thank you for your help
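
For context on the final CompilationError: tl.dot(..., trans_b=True) was removed from the Triton API in newer releases, so the bundled flash_attn_triton.py kernel only compiles against the older Triton version it was written for. One commonly reported workaround (not verified in this thread) is to uninstall triton entirely, which should make the model fall back to a standard PyTorch attention path; under that assumption, a minimal mean-pooled embedding extraction can stay as simple as:

# Sketch under the assumptions above (triton not installed, CPU or GPU both fine).
import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = transformers.AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model.eval()

seq = "ATCTAGCTAGACGTTACGCTACGCATGTACGTACGCTCAGTAGCATGCTAGCT"
inputs = tokenizer(seq, return_tensors="pt")
with torch.no_grad():
    hidden_states = model(**inputs)[0]      # (1, sequence_length, hidden_size)
mean_embedding = hidden_states.mean(dim=1)  # (1, hidden_size): mean pooling over tokens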
