Easy backwards compatibility fix #41

Open
michael-heinrich opened this issue May 6, 2024 · 4 comments
@michael-heinrich

Your version of transformers forces LlamaFlashAttention2 in the constructor of LlamaDecoderLayer in transformers/models/llama/modeling_llama.py, which requires Ampere or newer GPUs to work. Just by using the old LlamaAttention class instead of LlamaFlashAttention2 here, I could make the video inference demo run on an ancient GTX 1060 (even if it's very slow).
The current main branch of transformers uses a mechanism to decide which attention implementation is the best compatible one for this purpose.
If you don't want to backport that, you could use very simple logic to decide which class to use here. Something like this:

import torch                      # already imported in modeling_llama.py; shown for completeness
import torch.nn as nn


def is_at_least_ampere():
    """Return True only if every visible CUDA GPU is Ampere (compute capability 8.x) or newer."""
    if not torch.cuda.is_available():
        # Without CUDA, FlashAttention2 cannot be used anyway
        return False

    # Loop over each GPU
    for i in range(torch.cuda.device_count()):
        gpu_properties = torch.cuda.get_device_properties(i)

        # Compute capability is reported as major.minor; Ampere starts at major version 8,
        # so comparing the major version is enough (and avoids parsing it as a float)
        if gpu_properties.major < 8:
            return False

    # All GPUs are Ampere or newer
    return True


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        ampere_or_newer = is_at_least_ampere()
        self.self_attn = (
            LlamaFlashAttention2(config=config) if ampere_or_newer else LlamaAttention(config=config)
        )
        self.mlp = LlamaMLP(config)

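For reference, newer transformers releases expose this choice directly at load time, so a workaround along these lines may also work without patching modeling_llama.py at all (a sketch, assuming transformers >= 4.36 and a placeholder checkpoint path):

# Sketch: ask transformers for the plain ("eager") LlamaAttention path at load
# time instead of FlashAttention2; works on pre-Ampere GPUs, assuming a
# transformers version that accepts the attn_implementation argument.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/checkpoint",            # placeholder, not a specific VILA checkpoint
    torch_dtype=torch.float16,
    attn_implementation="eager",     # plain attention instead of FlashAttention2
)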

@Efficient-Large-Language-Model
Contributor

I see. Thanks! Could you submit a PR?

@rahulthakur319

Thanks for this smart fix. @michael-heinrich: do you see a similar workaround to provide backwards compatibility for awq/kernels as well, so VILA can run using AWQ & TinyChat?

This would help run the video inference demo faster on ancient GPUs ;-)

Refer: https://github.com/mit-han-lab/llm-awq/

At present, the installation of awq/kernels fails with the following error:

Feature '.m16n8k16' requires .target sm_80 or higher
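
For context, the quoted error means the kernels emit tensor-core PTX that requires the sm_80 (Ampere) target, so the build failure is expected on older cards. A quick illustrative check of what SM target your GPUs report (a GTX 1060 reports sm_61, below the required sm_80):

# Illustrative check: print the SM target of each visible GPU; the awq kernels
# need sm_80 (Ampere) or newer according to the build error above, while a
# GTX 1060 reports sm_61.
import torch

for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    print(f"GPU {i}: sm_{major}{minor}")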

@michael-heinrich
Author

> Do you see a similar workaround to provide backwards compatibility for awq/kernels as well, so VILA can run using AWQ & TinyChat? At present, the installation of awq/kernels fails with: Feature '.m16n8k16' requires .target sm_80 or higher

I will definitely look into it. Last night I already spent a few hours trying to get the AWQ quants running, but no luck so far. Judging from the source code and documentation of the transformers library, it appears to have AWQ support built in, and with a few changes to the HF repo I could partially load the AWQ checkpoint using the video inference demo. However, in the end the shapes of the tensors did not match. Maybe it's still possible to load it like this, but I was not sure that's even remotely the right direction.
Transformers also allows quantizing a model at load time using bitsandbytes. That might work on an older card, but it would not have the accuracy of an AWQ quant.
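
For completeness, a load-time bitsandbytes quantization call looks roughly like this (a sketch with a placeholder checkpoint path, assuming bitsandbytes supports the card):

# Sketch: 4-bit quantization while loading, via bitsandbytes, as a lower-accuracy
# alternative to an AWQ quant.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4 bit on load
    bnb_4bit_compute_dtype=torch.float16,  # run compute in fp16
)

model = AutoModelForCausalLM.from_pretrained(
    "path/to/checkpoint",                  # placeholder, not a specific VILA checkpoint
    quantization_config=bnb_config,
)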

@vedernikovphoto

> Your version of transformers forces LlamaFlashAttention2 in the constructor of LlamaDecoderLayer in transformers/models/llama/modeling_llama.py, which requires Ampere or newer GPUs to work. [...] If you don't want to backport that, you could use very simple logic to decide which class to use here.

Hi!

Thanks for sharing your code. Have you changed anything apart from that? I am encountering an issue when running inference on the Llama-3-VILA1.5-8B model. The error message I receive is:

RuntimeError: FlashAttention only supports Ampere GPUs or newer.

I am using a V100 GPU, which is not an Ampere GPU. Could you please provide guidance on how to disable Flash Attention for this model, and whether any other steps are needed besides those you have already described? Thanks.
