offload_dir error @ PeftModel.from_pretrained() when loading LoRA FLAN-T5 XL in default 13GB RAM Colab #136

Closed
minimaxir opened this issue Feb 26, 2023 · 7 comments · Fixed by #248

minimaxir commented Feb 26, 2023

Notebook for reproduction: https://colab.research.google.com/drive/1eWep-uJUEBVIM3FMxS08LaX5gzRZxQOv?usp=sharing

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-17ad1218bc80> in <module>
     14 
     15 # Load the Lora model
---> 16 model = PeftModel.from_pretrained(model, peft_model_id)

1 frames
/usr/local/lib/python3.8/dist-packages/peft/peft_model.py in from_pretrained(cls, model, model_id, **kwargs)
    175                     model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
    176                 )
--> 177             model = dispatch_model(model, device_map=device_map)
    178             hook = AlignDevicesHook(io_same_device=True)
    179             if model.peft_config.peft_type == PeftType.LORA:

/usr/local/lib/python3.8/dist-packages/accelerate/big_modeling.py in dispatch_model(model, device_map, main_device, state_dict, offload_dir, offload_index, offload_buffers, preload_module_classes)
    289     disk_modules = [name for name, device in device_map.items() if device == "disk"]
    290     if offload_dir is None and offload_index is None and len(disk_modules) > 0:
--> 291         raise ValueError(
    292             "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
    293             f"need to be offloaded: {', '.join(disk_modules)}."

ValueError: We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules need to be offloaded: base_model.model.decoder.block.10, base_model.model.decoder.block.11, base_model.model.decoder.block.12, base_model.model.decoder.block.13, base_model.model.decoder.block.14, base_model.model.decoder.block.15, base_model.model.decoder.block.16, base_model.model.decoder.block.17, base_model.model.decoder.block.18, base_model.model.decoder.block.19, base_model.model.decoder.block.20, base_model.model.decoder.block.21, base_model.model.decoder.block.22, base_model.model.decoder.block.23, base_model.model.decoder.final_layer_norm, base_model.model.decoder.dropout, base_model.model.lm_head.

It doesn't reproduce on a VM with more RAM, so accelerate is likely offloading to disk (system RAM is at 8.6 / 12.7 GB before it hits that line).

If there's another way to get a LoRA'd FLAN-T5 XL to load within the default Colab VM, it would be appreciated! (Although baking the LoRA weights into the base model would technically work too.)
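As a sketch of that second option (baking/merging the LoRA weights into the base model), recent peft versions expose merge_and_unload(); this is only an illustration of the idea, run on a machine with enough memory, and the adapter id below is a placeholder:

import torch
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel

peft_model_id = "your-username/flan-t5-xl-lora"  # placeholder adapter id

# Load the base model and attach the LoRA adapter.
base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl", torch_dtype=torch.float16)
peft_model = PeftModel.from_pretrained(base_model, peft_model_id)

# Merge ("bake") the LoRA weights into the base weights, drop the adapter wrappers,
# and save a plain checkpoint that can later be loaded without peft.
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("flan-t5-xl-merged")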


vakandi commented Mar 28, 2023

Did it happen to you even though you had enough RAM on your system (swap area excluded)?


rasmi commented Mar 29, 2023

See also #225.

I am running into the same issue with Alpaca-LoRA 7B. The error appears during loading at 13.3GB (with 15GB of total GPU memory allocated in my instance). The issue does not occur on a 'Premium' Colab GPU instance, which allocates 40GB of GPU memory.

import torch
from transformers import LlamaForCausalLM
from peft import PeftModel

model = LlamaForCausalLM.from_pretrained(
    'decapoda-research/llama-7b-hf',
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map='auto'
)
model = PeftModel.from_pretrained(
    model,
    'tloen/alpaca-lora-7b',
    torch_dtype=torch.float16,
    force_download=True
)

I got it 'working' (my system RAM usage was increasing) by adding `offload_folder="."` to `LlamaForCausalLM.from_pretrained` and `offload_dir=""` to `PeftModel.from_pretrained`, but I have been unable to reproduce this approach since then.

@mberman84

> Did it happen to you even though you had enough RAM on your system (swap area excluded)?

I have 24 GB on my MacBook Air M2.

@sergiocasero

Same issue on my side, trying to run it in my local env.

32 GB RAM + a 3080, and it says: We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules need to be offloaded.

Tried @rasmi's modification, but no luck:

My code:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

# Assumed to be defined elsewhere in the script.
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_alpaca(load_8bit=False, base_model="decapoda-research/llama-7b-hf", lora_weights="tloen/alpaca-lora-7b", prompt_template=""):
    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    print(device)

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
        offload_folder="."
    )
    model = PeftModel.from_pretrained(
        model,
        lora_weights,
        torch_dtype=torch.float16,
        offload_dir=""
    )

    return model, tokenizer

@younesbelkada (Contributor)

@sergiocasero, can you try the solution provided in #248?
I can confirm I was able to run @minimaxir's notebook without any issue; however, you need to specify `offload_dir` when calling `PeftModel.from_pretrained`.
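For reference, the call would look roughly like this (a minimal sketch; the adapter id and offload path below are placeholders, not values from this issue):

from peft import PeftModel

peft_model_id = "your-username/your-lora-adapter"  # placeholder adapter id

# Assuming `model` is the already-loaded base model; offload_dir gives accelerate
# a folder for the submodules that the device_map places on "disk".
model = PeftModel.from_pretrained(model, peft_model_id, offload_dir="offload")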

@sergiocasero

Yes!!! Thanks @younesbelkada

@younesbelkada (Contributor)

Hi @sergiocasero

Just a heads up! #257 will be merged; after that PR lands, you will need to use `offload_folder` instead of `offload_dir`.
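After that change, the same call would be written along these lines (sketch; the adapter id and path remain placeholders):

# Post-#257: the keyword argument is offload_folder instead of offload_dir.
model = PeftModel.from_pretrained(model, peft_model_id, offload_folder="offload")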
