KTOTrainer Memory Leakage #2268

Open
Isaaclgz opened this issue Oct 24, 2024 · 3 comments
Labels
🐛 bug Something isn't working 🏋 KTO Related to KTO

Comments

@Isaaclgz

System Info

I've been running some experiments on KTO using LoRA and noticed a large disparity between peak allocated and peak reserved memory, which I suspect indicates a memory leak.

To my understanding, when using LoRA with KTOTrainer, the target model is also used as the reference model for computing the reward and the approximate KL divergence, by temporarily disabling the adapter weights. I'm not sure exactly how gradients and optimizer states differ between KTO and SFT, but intuitively they should be roughly the same size, since only the LoRA parameters are trained in both cases. So, given that we aren't loading a separate reference model, memory usage should not be this high. To illustrate what I mean by the reference-model trick, here is a rough sketch (not TRL's actual implementation, just my understanding; it assumes the policy is a PEFT PeftModel, whose disable_adapter() context manager temporarily deactivates the LoRA weights; the function name and batch argument are only for illustration):
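
import torch

def reference_forward(peft_model, batch):
    # Reuse the policy as the reference model: with the LoRA adapter
    # disabled we recover the frozen base model, so no second copy of
    # the 8B weights has to live on the GPU.
    with torch.no_grad(), peft_model.disable_adapter():
        return peft_model(**batch).logits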

For reference, I am running KTO with the following parameters on a variant of Llama-3-8B-Instruct ('aisingapore/llama3-8b-cpt-sea-lionv2.1-instruct'):

training_args = KTOConfig(
    num_train_epochs=2,
    per_device_train_batch_size=4,
    remove_unused_columns=False,
    gradient_accumulation_steps=4,
    learning_rate=5e-7,
    evaluation_strategy="steps",
    beta=0.5,
    max_length=1000,
    gradient_checkpointing=True,
    bf16=True,
    save_total_limit=0,
    undesirable_weight=1.0,
    max_prompt_length=800,
    max_completion_length=200,
    report_to="none",
)

Tracking memory usage with the torch.cuda utilities, for KTO I see:

[screenshot: torch.cuda memory readings during KTO training, showing a large gap between peak reserved and peak allocated memory]

but for SFT on the same dataset with the same batch size, etc.:

[screenshot: torch.cuda memory readings during SFT training]

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM
from trl import KTOConfig, KTOTrainer

# (tokenizer and dataset are prepared elsewhere and omitted here)

def torch_print_gpu_memory():
    print(f"Memory allocated: {torch.cuda.memory_allocated() / 1024 ** 3:.4f} GB")
    print(f"Memory reserved: {torch.cuda.memory_reserved() / 1024 ** 3:.4f} GB")
    print(f"Peak memory allocated: {torch.cuda.max_memory_allocated() / 1024 ** 3:.4f} GB")
    print(f"Peak memory reserved: {torch.cuda.max_memory_reserved() / 1024 ** 3:.4f} GB")

print('Before loading base model')
torch_print_gpu_memory()
print('\n----------\n')

model = AutoModelForCausalLM.from_pretrained(
    'aisingapore/llama3-8b-cpt-sea-lionv2.1-instruct',
    trust_remote_code=True,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

print('\n----------\n')
print('After loading base model')
torch_print_gpu_memory()
print('\n----------\n')

lora_config = LoraConfig(
    r=128,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules='all-linear',
)

training_args = KTOConfig(
    num_train_epochs=2,
    per_device_train_batch_size=4,
    remove_unused_columns=False,
    gradient_accumulation_steps=4,
    learning_rate=5e-7,
    evaluation_strategy="steps",
    beta=0.5,
    max_length=1000,
    gradient_checkpointing=True,
    bf16=True,
    save_total_limit=0,
    undesirable_weight=1.0,
    max_prompt_length=800,
    max_completion_length=200,
    report_to="none",
)

print('Before KTOTrainer Instantiation')
torch_print_gpu_memory()
print('\n----------\n')

trainer = KTOTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset,
    peft_config=lora_config,
)

print('\n----------\n')
print('After KTOTrainer Instantiation')
torch_print_gpu_memory()
print('\n----------\n')

print('\n----------\n')
print('Before KTO Train')
torch_print_gpu_memory()
print('\n----------\n')

trainer.train()

print('\n----------\n')
print('After KTO Train')
torch_print_gpu_memory()
print('\n----------\n')
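
Note that the max_memory_* counters are cumulative across the whole run. If it helps, here is a small sketch of how the peak for each phase could be isolated (torch.cuda.reset_peak_memory_stats() is the standard PyTorch API for clearing the peak counters; the helper name is only for illustration and reuses torch_print_gpu_memory from the script above):

def print_and_reset_peaks(tag):
    # Print the snapshot for this phase, then clear the peak counters so
    # the next phase's peak is measured in isolation.
    print(f'[{tag}]')
    torch_print_gpu_memory()
    torch.cuda.reset_peak_memory_stats()

# e.g. call print_and_reset_peaks('after KTOTrainer instantiation') between phases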

Expected behavior

Peak reserved memory when using KTOTrainer with PEFT should be significantly closer to peak allocated memory, as it is for SFT on the same data.
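
In case it is useful for debugging, a sketch of how more allocator detail could be captured: torch.cuda.memory_summary() is the built-in caching-allocator report, and PYTORCH_CUDA_ALLOC_CONF is the documented environment variable for tuning allocator behaviour (the script name below is just a placeholder):

# After training (or at any checkpoint), dump the caching-allocator stats
# to see how much of the reserved memory is sitting in unused blocks.
print(torch.cuda.memory_summary())

# Allocator behaviour can also be adjusted before launching the script, e.g.:
#   PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python train_kto.py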

@qgallouedec qgallouedec added 🐛 bug Something isn't working 🏋 KTO Related to KTO labels Oct 25, 2024
@qgallouedec
Member

Thanks for reporting, please share your system info

@Isaaclgz
Author

> Thanks for reporting, please share your system info

Thanks for looking into this!

System:
Debian 11
Python 3.10
1xA100-80GB
Nvidia driver 550.90.07, CUDA 12.4
(running on a GCP Compute Engine instance based on the c0-deeplearning-common-cu123-v20240922-debian-11-py310 image)

Env:
torch==2.4.0
transformers==4.44.0
trl==0.11.3
flash-attn==2.6.3
accelerate==1.0.1

@chenyang399

Is there any chance that we can run the KTO script on a 24 GB GPU?
