
DeepSpeed ZeRO-2 produces negative KL divergence #506

Closed
lewtun opened this issue Jul 7, 2023 · 6 comments
@lewtun
Member

lewtun commented Jul 7, 2023

Hello, while testing out the DeepSpeed ZeRO-2 plugin in the sentiment example for gpt2, I noticed that the KL divergence starts out negative. This suggests the model parameters of the reference and active model are being sharded in a peculiar manner that produces a mismatch in the log probs.
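For context on why this is suspicious: with the default kl_penalty="kl", the logged divergence is built from the per-token difference model_logp - ref_logp (see the kl_penalty help string in the script below), so at step 0, when the active and reference models still share identical weights, it should sit around zero. A minimal sketch of that estimator (illustrative only, not TRL's exact implementation):

# kl_sketch.py -- illustrative only, not TRL's exact implementation
import torch

def per_token_kl(model_logprobs: torch.Tensor, ref_logprobs: torch.Tensor) -> torch.Tensor:
    # Default 'kl' penalty: log p_active(token) - log p_ref(token).
    # Averaged over sampled tokens this approximates KL(active || ref),
    # which should be ~0 at step 0 when both models share the same weights.
    return model_logprobs - ref_logprobs

# Identical log-probs (identical models, identical forward conditions) give a zero estimate.
logp = torch.tensor([-2.3, -0.7, -1.1])
print(per_token_kl(logp, logp.clone()).mean())  # tensor(0.)

A persistently negative value therefore points to the two forward passes running under different conditions (dropout mode, precision, sharding) rather than to genuinely different policies.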

Below is a screenshot from WandB which shows the pure DDP baseline in teal vs the DeepSpeed ZeRO-2 curve in purple:

[Screenshot from WandB: KL divergence curves for the DDP baseline (teal) vs. DeepSpeed ZeRO-2 (purple)]

Code to reproduce

I ran this on a machine with 2 x A100 (80GB) GPUs, but that's overkill for this example :)

Accelerate config

# config.yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Script

# gpt2_sentiment.py
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, pipeline

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
from accelerate import Accelerator
from accelerate.utils import DummyOptim


tqdm.pandas()

@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default="lvwerra/gpt2-imdb", metadata={"help": "the model name"})
    log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    mini_batch_size: Optional[int] = field(default=128, metadata={"help": "the PPO minibatch size"})
    batch_size: Optional[int] = field(default=128, metadata={"help": "the batch size"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
    kl_penalty: Optional[str] = field(
        default="kl",
        metadata={
            "help": "kl penalty options: 'kl': model_logp - ref_logp,  'abs': abs(kl) and 'mse': mean squared error mse(kl)."
        },
    )
    target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
    seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

accelerator = Accelerator()

config = PPOConfig(
    model_name=script_args.model_name,
    learning_rate=script_args.learning_rate,
    log_with=script_args.log_with,
    mini_batch_size=script_args.mini_batch_size,
    batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    early_stopping=script_args.early_stopping,
    target_kl=script_args.target_kl,
    kl_penalty=script_args.kl_penalty,
    seed=script_args.seed,
)

sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}

def build_dataset(config, dataset_name="imdb", input_min_text_length=2, input_max_text_length=8):
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(config)


def collator(data):
    # Collate a list of example dicts into a dict of lists (no padding;
    # PPOTrainer works with variable-length query tensors).
    return dict((key, [d[key] for d in data]) for key in data[0])


# set seed before initializing value head for deterministic eval
set_seed(config.seed)

# Now let's build the model, the reference model, and the tokenizer.
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token

# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
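# If ZeRO-3 parameter partitioning is enabled, build the reward pipeline with it
# temporarily disabled so the distilbert reward model is loaded with full,
# un-sharded weights.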
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
    with ds_plugin.zero3_init_context_manager(enable=False):
        sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
else:
    sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 32,
}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    # Get response from gpt2
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

Run with

accelerate launch --config_file config.yaml gpt2_sentiment.py --log_with="wandb"

Env

- `transformers` version: 4.31.0.dev0
- `trl` version: trl @ git+https://github.com/lvwerra/trl.git@bbc7eeb29c7de42c93e11579676ecf7078fe88aa
- Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
- Python version: 3.10.11
- Huggingface_hub version: 0.15.1
- Safetensors version: 0.3.1
- PyTorch version (GPU?): 2.0.1+cu118 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: <fill in>
lewtun changed the title from "DeepSpeed ZeRO-3 produces negative KL divergence" to "DeepSpeed ZeRO-2 produces negative KL divergence" on Aug 1, 2023
@vwxyzjn
Contributor

vwxyzjn commented Aug 18, 2023

I was not able to reproduce the issue; the related bugs were probably already fixed on the latest main.


The report is here https://wandb.ai/costa-huang/trl/reports/deepspeed-test--Vmlldzo1MTc3NDcw

And here is one of the runs: https://wandb.ai/costa-huang/trl/runs/viz5drqj/logs?workspace=user-costa-huang, and its logs seem to indicate deepspeed is running as expected.

@lewtun
Member Author

lewtun commented Aug 21, 2023

Thanks a lot for diving into this! I'm still getting a negative KL even after bumping trl to main - can you share the accelerate and transformers dependencies you're using?

@lewtun
Member Author

lewtun commented Aug 21, 2023

Ah, if you look closely at the KL divergence of your run (https://wandb.ai/costa-huang/trl/runs/viz5drqj?workspace=user-lewtun), you can see that it is indeed still slightly negative:

[Screenshot from WandB: KL divergence curve from the run, dipping slightly below zero]

Since step 0 should be a direct match between the reference & active models, it would make sense to see if we can understand why DeepSpeed is causing this difference. One possibility is that DeepSpeed is putting the active model in train mode (e.g. with dropout) while the reference model is in eval mode (no dropout).
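
One quick way to test that hypothesis (a hedged sketch, not part of the original report; it assumes the `model`, `ref_model`, and `tokenizer` objects from the reproduction script above): put both models in eval mode and compare their summed log-probs on a fixed prompt before the first PPO step. With identical weights and dropout disabled, the gap should be ~0.

# Illustrative sketch only -- assumes `model`, `ref_model`, `tokenizer` from the script above.
import torch

def sequence_logprob(wrapped_model, input_ids):
    # `pretrained_model` is the underlying transformers model inside the value-head wrapper.
    with torch.no_grad():
        logits = wrapped_model.pretrained_model(input_ids).logits[:, :-1, :]
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    targets = input_ids[:, 1:].unsqueeze(-1)
    return logprobs.gather(-1, targets).squeeze(-1).sum(-1)

model.eval()      # disable dropout in the active model
ref_model.eval()  # and in the reference model

input_ids = tokenizer("The movie was surprisingly", return_tensors="pt").input_ids
input_ids = input_ids.to(next(model.parameters()).device)
gap = sequence_logprob(model, input_ids) - sequence_logprob(ref_model, input_ids)
print(f"step-0 log-prob gap: {gap.item():.6f}")  # ~0 expected if the two forwards match

If the gap is already non-zero here, the mismatch comes from how the two models are prepared (mode, precision), not from the PPO update itself.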

@vwxyzjn
Contributor

vwxyzjn commented Aug 21, 2023

Thanks a lot for diving into this! I'm still getting a negative KL even after bumping trl to main - can you share the accelerate and transformers dependencies you're using?

https://wandb.ai/costa-huang/trl/runs/viz5drqj/files/requirements.txt has all dependencies. I used your accelerate config in the issue description.


Since step 0 should be a direct match between the reference & active models, it would make sense to see if we can understand why DeepSpeed is causing this difference. One possibility is that DeepSpeed is putting the active model in train mode (e.g. with dropout) while the reference model is in eval mode (no dropout).

Interesting. Thanks for bringing up this point. I will look into it!

@lvwerra
Member

lvwerra commented Aug 21, 2023

An explanation could be that the model is not in eval mode? In that case you could have a little bit of noise even if the models are identical.

@lewtun
Member Author

lewtun commented Sep 13, 2023

Closed by #758 (the root cause of the issue was using bf16 mixed precision without properly initialising the reference model with deepspeed)
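
For anyone hitting the same symptom: the idea behind the fix is to run the frozen reference model through DeepSpeed as well, so that its forward pass uses the same bf16 settings as the prepared active model. A rough sketch of that idea (my own illustration with a hand-written `ds_config`, not the actual code from #758):

import deepspeed

# Minimal, hypothetical DeepSpeed config mirroring the bf16 training setup above.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 0},  # the frozen reference model needs no sharding
}

# Wrap the frozen reference model in a DeepSpeed engine so its forward pass runs
# under the same bf16 settings as the DeepSpeed-prepared active model.
ref_model_engine, *_ = deepspeed.initialize(model=ref_model, config=ds_config)
ref_model_engine.eval()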

lewtun closed this as completed on Sep 13, 2023