Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

profiling ops on xpu #2249

Merged
merged 4 commits into from
Jan 24, 2025
Merged

profiling ops on xpu #2249

merged 4 commits into from
Jan 24, 2025

Conversation

songhappy
Copy link
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature

Please link to any issues this PR addresses.
https://jira.devtools.intel.com/browse/IPB-2875

Changelog

What are the changes made in this PR?
Added 'xpu' in _profiler.py and modify cuda related memory profiler to cuda only in driver scripts in receipts directory.
*

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • manually run any new or modified recipes with sufficient proof of correctness
    Steps:
  1. download Llama-3.2-3B-Instruct model
  2. modify recipes/configs/llama3_2/3B_full_single_device.yaml, change "device: xpu", "profiler.enabled:True"
  3. tune run full_finetune_single_device --config recipes/configs/llama3_2/3B_full_single_device.yaml
  4. see profiling results under /tmp/full-llama3.2-finetune/profiling_results
{
  "schemaVersion": 1,
  "deviceProperties": [
  ],
  "with_flops": 1,
  "record_shapes": 1,
  "profile_memory": 1,
  "with_stack": 1,
  "trace_id": "1DB69D6280304432870B411620032DC3",
  "traceEvents": [
  {
    "ph": "X", "cat": "cpu_op", "name": "aten::conv2d", "pid": 1341316, "tid": 1
341316,
    "ts": 731713074715.025, "dur": 95265.387,
    "args": {
      "External id": 1,"Record function id": 0, "Concrete Inputs": ["", "", "",
"[2, 2]", "[3, 3]", "[1, 1]", "1"], "Input type": ["float", "float", "", "Scalar
List", "ScalarList", "ScalarList", "Scalar"], "Input Strides": [[150528, 50176,
224, 1], [147, 49, 7, 1], [], [], [], [], []], "Input Dims": [[32, 3, 224, 224],
 [64, 3, 7, 7], [], [], [], [], []], "Ev Idx": 0
    }
  },

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API

Copy link

pytorch-bot bot commented Jan 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2249

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit da30adb with merge base d7afc40 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 10, 2025
@songhappy
Copy link
Contributor Author

@SalmanMohammadi Could you please review and approve it?

@felipemello1
Copy link
Contributor

hey @songhappy , thanks for the PR!

Just two questions before i approve it:

  1. i saw that you added and self._device.type == "cuda" to some recipes, but not all. Is it because the other recipes already have it, or did we forget some?

  2. In your testing, i dont think that you set profiler.profile_memory=True. Do you think its worth checking this option to make sure its working for xpu?

@songhappy
Copy link
Contributor Author

@felipemello1 thanks a lot for reviewing it.

  1. Yes, all of them should have and self._device.type == "cuda", I forgot one, and added it now.
  2. modified testcases to cover profiler.profile_memory=True on xpu too.
  3. I have run recipe while setting profiler.profile_memory=True in the configeration file and copied some profiling logging when add this PR, it is not necessary to set profiler.profile_memory=True as default.

@felipemello1
Copy link
Contributor

Awesome, thanks @songhappy , do you mind just adding to https://github.com/pytorch/torchtune/blob/main/recipes/ppo_full_finetune_single_device.py?

I will merge after it

@songhappy
Copy link
Contributor Author

@felipemello1 added in ppo. One question, have you tried smaller models other than 7b to run https://github.com/pytorch/torchtune/blob/main/recipes/ppo_full_finetune_single_device.py? like 3b llama or similar models. I want to try smaller models of PPO. Please guide

@SalmanMohammadi
Copy link
Collaborator

Hey @songhappy. Try this config out which I used for testing with a tiny Llama2 model. You'll need to download the corresponding models at the top of the config. I haven't had a chance to test it in a little while so let me know if there's any issues with it.

# Config for single device RLHF full finetuning using PPO in ppo_full_finetune_single_device.py
# using a Mistral 7B model.
#
# This config has been tested on an A100 80GB.
# This config uses hyperparameters based on small set of experiments and information
# available from existing implementations.
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download smohammadi/tinyllama_rm_sentiment_1b --output-dir /tmp/tinyllama_rm_sentiment_1b/
#   tune download TinyLlama/TinyLlama_v1.1 --output-dir /tmp/TinyLlama_v1.1/ 
#
# You'll also need to ensure that {output_dir} exists beforehand, as checkpoints for policy and value models are saved in sub-folders.
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
#   pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
#   tune run ppo_full_finetune_single_device --config <config_dir>/1B_full_ppo_low_memory_single_device.yaml

output_dir: /tmp/torchtune/llama2_1B_full_ppo_low_memory # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path:  ./target/dummy/tokenizer.model
  max_seq_len: 512


# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: trl-internal-testing/sentiment-trl-style
  split: train
  column: prompt
  add_eos: False


policy_model:
  _component_: torchtune.models.llama2.llama2
  vocab_size: 32000
  num_layers: 22
  num_heads: 32
  num_kv_heads: 4
  embed_dim: 2048
  max_seq_len: 2048
  intermediate_dim: 5632
  attn_dropout: 0.0
  norm_eps: 1e-5

reward_and_value_model:
  _component_: torchtune.models.llama2.llama2_classifier
  num_classes: 1
  vocab_size: 32000
  num_layers: 22
  num_heads: 32
  num_kv_heads: 4
  embed_dim: 2048
  max_seq_len: 2048
  intermediate_dim: 5632
  attn_dropout: 0.0
  norm_eps: 1e-5

# checkpointer for the policy model - update this if resuming from checkpoint
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/TinyLlama_v1.1/ 
  checkpoint_files: [
      "pytorch_model.bin",
  ]
  # this is the only place where you should update `recipe_checkpoint` if resuming training
  recipe_checkpoint: null
  output_dir: ${output_dir}/policy
  model_type: LLAMA2

# this should be setup identically to the policy model checkpointer at the start of training
# ensure `checkpoint_files` always points to the original policy weights, even if resuming training
ref_policy_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/TinyLlama_v1.1/ 
  checkpoint_files: [
      "pytorch_model.bin",
  ]
  # this is the only place where you should update `recipe_checkpoint` if resuming training
  recipe_checkpoint: null
  output_dir: ${output_dir}/policy
  model_type: LLAMA2

# checkpointer for the value model - update `checkpoint_files` if resuming from checkpoint
# since this model will be identical to the reward model it's helpful to initialise this
# from the trained reward model weights
value_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir:  /tmp/tinyllama_rm_sentiment_1b/
  # only `checkpoint_files` need to be updated if resuming training
  checkpoint_files: [
      "model.safetensors"
  ]
  output_dir: ${output_dir}/value
  model_type: REWARD

# checkpointer for the reward model, ensure `checkpoint_files`
# always points to the original reward model weights, even if resuming training
reward_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir:  /tmp/tinyllama_rm_sentiment_1b/
  # only `checkpoint_files` need to be updated if resuming training
  checkpoint_files: [
      "model.safetensors"
  ]
  output_dir: ${output_dir}/value
  model_type: REWARD

resume_from_checkpoint: False
seed: null
shuffle: True

# Training env
device: cuda

# Training arguments
batch_size: 16
num_steps: 1000
ppo_epochs: 1
ppo_batch_size: 16
gradient_accumulation_steps: 1  # Use to increase effective batch size

# Memory management and performance
compile: True  # torch.compile the model + loss, True increases speed + decreases memory
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 3e-6
optimizer_in_bwd: True  # True saves memory. Requires gradient_accumulation_steps=1
log_peak_memory_stats: True
enable_activation_checkpointing: True # True reduces memory
enable_kv_cache: True

# Reduced precision
dtype: bf16

# batch size for forward pass during generation
forward_batch_size: 16
max_generated_tokens: 58
temperature: 0.7
top_k: null

# parameter for penalising generations shorter than `min_response_length`
min_response_length: 18
# parameter for penalising generations without a stop token
penalise_no_eos: True
# scalar penalty to apply when penalising
reward_penalty: -3

# tokens to consider as "end of sequence" tokens
stop_token_ids: [
  29889
]
whiten_rewards: False

# GAE hyperparameters
gamma: 1
lmbda: 0.95

# PPO hyperparameters
loss:
  _component_: torchtune.rlhf.loss.PPOLoss
  epsilon: 0.2
  value_coeff: 0.1
  value_clip_range: 0.2
kl_coeff: 0.01

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs

log_every_n_steps: 1

profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  #Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  #`torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  #trace options passed to `torch.profiler.profile`
  profile_memory: True
  with_stack: False
  record_shapes: False
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 3
  num_cycles: 1

@felipemello1 felipemello1 merged commit 5764650 into pytorch:main Jan 24, 2025
17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants