Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Enable Intel XPU support #839

Merged
merged 34 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f505e00
enable xpu support
abhilash1910 Oct 6, 2023
227f487
fix bug
abhilash1910 Oct 6, 2023
8c644fc
Merge branch 'huggingface:main' into sycl
abhilash1910 Oct 6, 2023
ef44042
review commits
abhilash1910 Oct 6, 2023
ae42341
fix style
abhilash1910 Oct 9, 2023
b80112d
Merge branch 'huggingface:main' into sycl
abhilash1910 Oct 9, 2023
973bfe4
add xou decorator
abhilash1910 Oct 9, 2023
e72d3c2
refactor review commit
abhilash1910 Oct 9, 2023
295f87d
Merge branch 'huggingface:main' into sycl
abhilash1910 Oct 9, 2023
99f82eb
fix test
abhilash1910 Oct 10, 2023
1752ae4
review commit
abhilash1910 Oct 11, 2023
008c2a6
fix test
abhilash1910 Oct 12, 2023
d6f399b
Update benchmark.yml (#856)
lvwerra Oct 11, 2023
e72cc90
Standardise example scripts (#842)
lewtun Oct 11, 2023
3fce988
Fix version check in import_utils.py (#853)
adampauls Oct 11, 2023
9f4d177
dont use get_peft_model if model is already peft (#857)
abhishekkrthakur Oct 11, 2023
06dfd8e
merge conflict
abhilash1910 Oct 12, 2023
5e923b7
add xou decorator
abhilash1910 Oct 9, 2023
308d7bf
resolve
abhilash1910 Oct 12, 2023
2c8343a
resolves
abhilash1910 Oct 12, 2023
5805670
upstream
abhilash1910 Oct 12, 2023
0e54bc6
resolve conflicts
abhilash1910 Oct 12, 2023
bea855f
Merge pull request #1 from abhilash1910/main
abhilash1910 Oct 12, 2023
aa518c3
Merge branch 'main' into sycl
abhilash1910 Oct 12, 2023
2c3f999
refactor and precommit
abhilash1910 Oct 12, 2023
b0c8b52
Merge pull request #3 from abhilash1910/main
abhilash1910 Oct 12, 2023
a579504
Merge branch 'main' into sycl
abhilash1910 Oct 12, 2023
fdc0896
fix new tests
abhilash1910 Oct 12, 2023
2200089
Merge branch 'huggingface:main' into sycl
abhilash1910 Oct 12, 2023
98956ac
Merge branch 'huggingface:main' into sycl
abhilash1910 Oct 17, 2023
9a962fc
Merge branch 'huggingface:main' into sycl
abhilash1910 Oct 17, 2023
d6e9e4f
add device mapping for xpu
abhilash1910 Oct 17, 2023
aac92f8
Merge branch 'huggingface:main' into sycl
abhilash1910 Oct 20, 2023
8b77310
Merge branch 'huggingface:main' into sycl
abhilash1910 Oct 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from accelerate.utils import is_xpu_available

from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset
Expand Down Expand Up @@ -208,7 +209,10 @@ def create_datasets(tokenizer, args):

# Free memory for merging weights
del base_model
torch.cuda.empty_cache()
if is_xpu_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()

model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()
Expand Down
2 changes: 1 addition & 1 deletion examples/research_projects/tools/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def solution():
log_with="wandb",
tracker_project_name="trl-gsm8k",
remove_unused_columns=False,
optimize_cuda_cache=True,
optimize_device_cache=True,
)

ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
Expand Down
2 changes: 1 addition & 1 deletion examples/research_projects/tools/triviaqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class ScriptArguments:
ppo_epochs=args.ppo_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps,
seed=args.seed,
optimize_cuda_cache=True,
optimize_device_cache=True,
)
ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)
dataset = load_dataset("trivia_qa", "rc", split="train")
Expand Down
12 changes: 9 additions & 3 deletions examples/research_projects/toxicity/scripts/evaluate-toxicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate.utils import is_xpu_available

toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
ds = load_dataset("OxAISH-AL-LLM/wiki_toxic", split="test")
Expand Down Expand Up @@ -50,7 +50,10 @@
output_file = args.output_file
max_new_tokens = args.max_new_tokens
context_length = args.context_length
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
if is_xpu_available():
device = torch.xpu.current_device()
else:
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

# consider only toxic prompts
ds = ds.filter(lambda x: x["label"] == 1)
Expand Down Expand Up @@ -116,7 +119,10 @@
print(f"Model: {model_id} - Mean: {mean} - Std: {std}")

model = None
torch.cuda.empty_cache()
if is_xpu_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()

# close file
file.close()
2 changes: 1 addition & 1 deletion examples/scripts/multi_adapter_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def collator(data):
batch_size=8,
mini_batch_size=2,
gradient_accumulation_steps=2,
optimize_cuda_cache=True,
optimize_device_cache=True,
)

ppo_trainer = PPOTrainer(
Expand Down
6 changes: 5 additions & 1 deletion examples/scripts/sentiment_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
import tyro
from accelerate import Accelerator
from accelerate.utils import is_xpu_available
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
Expand Down Expand Up @@ -153,7 +154,10 @@ def collator(data):
# to the same device as the PPOTrainer.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
if is_xpu_available():
device = "xpu:0"
else:
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
task, model_name = args.ppo_config.reward_model.split(":")
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
Expand Down
4 changes: 3 additions & 1 deletion examples/scripts/stable_diffusion_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPModel, CLIPProcessor
from accelerate.utils import is_xpu_available

from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline

Expand Down Expand Up @@ -82,7 +83,8 @@ def aesthetic_scorer(hub_model_id, model_filename):
model_id=hub_model_id,
model_filename=model_filename,
dtype=torch.float32,
).cuda()
)
scorer = scorer.xpu() if is_xpu_available() or scorer.cuda()

def _fn(images, prompts, metadata):
images = (images * 255).round().clamp(0, 255).to(torch.uint8)
Expand Down
24 changes: 17 additions & 7 deletions trl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import top_k_top_p_filtering

from accelerate.utils import is_xpu_available

try:
from collections.abc import Mapping
Expand Down Expand Up @@ -240,7 +240,10 @@ def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if is_xpu_available():
torch.xpu.manual_seed_all(seed)
else:
torch.cuda.manual_seed_all(seed)


class LengthSampler:
Expand All @@ -257,12 +260,19 @@ def __call__(self):

class PPODecorators(object):
optimize_cuda_cache = False
optimize_xpu_cache = False

@classmethod
@contextmanager
def empty_cuda_cache(cls):
def empty_device_cache(cls):
yield
if cls.optimize_cuda_cache and torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
gc.collect()
if is_xpu_available():
if cls.optimize_xpu_cache and torch.xpu.is_available():
gc.collect()
torch.xpu.empty_cache()
gc.collect()
else:
if cls.optimize_cuda_cache and torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
gc.collect()
3 changes: 2 additions & 1 deletion trl/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import is_xpu_available
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, HFValidationError, LocalEntryNotFoundError
from transformers import PreTrainedModel
Expand Down Expand Up @@ -328,7 +329,7 @@ def _get_current_device(cls):
The current device.
"""
dummy_accelerator = Accelerator()
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
return dummy_accelerator.local_process_index if (torch.cuda.is_available() or is_xpu_available()) else "cpu"

@classmethod
def _split_kwargs(cls, kwargs):
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/ppo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class PPOConfig:
"""Maximum gradient norm for gradient clipping"""
seed: int = 0
"""Seed value for random generations"""
optimize_cuda_cache: bool = False
optimize_device_cache: bool = False
"""Optimize CUDA cache for slightly more memory-efficient training"""
early_stopping: bool = False
"""Whether to stop the PPO optimization loop early is the KL too high"""
Expand Down
15 changes: 9 additions & 6 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, is_deepspeed_available
from accelerate.utils import ProjectConfiguration, is_deepspeed_available, is_xpu_available
from datasets import Dataset
from huggingface_hub import whoami
from packaging import version
Expand Down Expand Up @@ -341,9 +341,12 @@ def __init__(
if not getattr(self.model, "is_sequential_parallel", False):
self.current_device = self.accelerator.device
else:
self.current_device = torch.device("cuda:0")
if is_xpu_available():
self.current_device = torch.device("xpu:0")
else:
self.current_device = torch.device("cuda:0")

PPODecorators.optimize_cuda_cache = self.config.optimize_cuda_cache
PPODecorators.optimize_device_cache = self.config.optimize_device_cache

self.running = RunningMoments(self.accelerator)

Expand Down Expand Up @@ -576,7 +579,7 @@ def _step_safety_checker(

return queries, responses, scores, masks

@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def step(
self,
queries: List[torch.LongTensor],
Expand Down Expand Up @@ -909,7 +912,7 @@ def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data

@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: PreTrainedModelWrapper,
Expand Down Expand Up @@ -1000,7 +1003,7 @@ def batched_forward_pass(
torch.cat(all_masks)[:, :-1],
)

@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def train_minibatch(
self,
old_logprobs: torch.FloatTensor,
Expand Down