[Feature] Enable Intel XPU support (huggingface#839)
* enable xpu support

* fix bug

* review commits

* fix style

* add xpu decorator

* refactor review commit

* fix test

* review commit

* fix test

* Update benchmark.yml (huggingface#856)

* Standardise example scripts (huggingface#842)

* Standardise example scripts

* fix plotting script

* Rename run_xxx to xxx

* Fix doc

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>

* Fix version check in import_utils.py (huggingface#853)

* don't use get_peft_model if model is already peft (huggingface#857)

* merge conflict

* add xpu decorator

* resolve

* resolves

* upstream

* refactor and precommit

* fix new tests

* add device mapping for xpu

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Adam Pauls <adpauls@gmail.com>
Co-authored-by: abhishek thakur <1183441+abhishekkrthakur@users.noreply.github.com>
6 people authored and Andrew Lapp committed May 10, 2024
1 parent 52c3d36 commit 4219c86
Showing 14 changed files with 119 additions and 29 deletions.
@@ -12,6 +12,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments

from trl import SFTTrainer
from trl.import_utils import is_xpu_available
from trl.trainer import ConstantLengthDataset


@@ -198,7 +199,10 @@ def create_datasets(tokenizer, args):

# Free memory for merging weights
del base_model
torch.cuda.empty_cache()
if is_xpu_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()

model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()
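Note: the change above replaces a hard-coded `torch.cuda.empty_cache()` with a branch on `is_xpu_available()`. A minimal device-agnostic helper in the same spirit might look like this (a sketch, not part of the commit):

```python
# Sketch only: free cached allocator memory on whichever accelerator is present.
import gc

import torch

from trl.import_utils import is_xpu_available


def empty_device_cache() -> None:
    gc.collect()
    if is_xpu_available():
        torch.xpu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
```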
12 changes: 10 additions & 2 deletions examples/research_projects/toxicity/scripts/evaluate-toxicity.py
@@ -8,6 +8,8 @@
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.import_utils import is_xpu_available


toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
ds = load_dataset("OxAISH-AL-LLM/wiki_toxic", split="test")
@@ -50,7 +52,10 @@
output_file = args.output_file
max_new_tokens = args.max_new_tokens
context_length = args.context_length
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
if is_xpu_available():
device = torch.xpu.current_device()
else:
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

# consider only toxic prompts
ds = ds.filter(lambda x: x["label"] == 1)
@@ -116,7 +121,10 @@
print(f"Model: {model_id} - Mean: {mean} - Std: {std}")

model = None
torch.cuda.empty_cache()
if is_xpu_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()

# close file
file.close()
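The device selection and cache clearing above follow the same XPU-first pattern. Factored into a helper, it might read (illustrative only, not part of the diff):

```python
# Sketch only: return the active accelerator device, falling back to CPU.
import torch

from trl.import_utils import is_xpu_available


def current_device():
    if is_xpu_available():
        return torch.xpu.current_device()
    if torch.cuda.is_available():
        return torch.cuda.current_device()
    return "cpu"
```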
4 changes: 3 additions & 1 deletion examples/scripts/ddpo.py
@@ -25,6 +25,7 @@
from transformers import CLIPModel, CLIPProcessor

from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
from trl.import_utils import is_xpu_available


@dataclass
@@ -119,7 +120,8 @@ def aesthetic_scorer(hub_model_id, model_filename):
model_id=hub_model_id,
model_filename=model_filename,
dtype=torch.float32,
).cuda()
)
scorer = scorer.xpu() if is_xpu_available() else scorer.cuda()

def _fn(images, prompts, metadata):
images = (images * 255).round().clamp(0, 255).to(torch.uint8)
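The scorer is now built on CPU and then moved to whichever accelerator is available. The same idea for an arbitrary module could be sketched as follows (`nn.Module.xpu()` assumes a PyTorch build with XPU support, e.g. via intel_extension_for_pytorch):

```python
# Sketch only: move a module to XPU when present, otherwise to CUDA.
import torch.nn as nn

from trl.import_utils import is_xpu_available


def to_accelerator(module: nn.Module) -> nn.Module:
    return module.xpu() if is_xpu_available() else module.cuda()
```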
6 changes: 5 additions & 1 deletion examples/scripts/ppo.py
@@ -25,6 +25,7 @@

from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
from trl.import_utils import is_xpu_available


tqdm.pandas()
@@ -154,7 +155,10 @@ def collator(data):
# to the same device as the PPOTrainer.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
if is_xpu_available():
device = "xpu:0"
else:
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
task, model_name = args.ppo_config.reward_model.split(":")
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
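In the single-process case the reward pipeline now receives the device string `"xpu:0"` instead of CUDA index 0. A hedged sketch of how that device is consumed (the model id is an assumption, not taken from this diff):

```python
# Sketch only: build the sentiment/reward pipeline on the selected device.
import torch
from transformers import pipeline

from trl.import_utils import is_xpu_available

if is_xpu_available():
    device = "xpu:0"
else:
    device = 0 if torch.cuda.is_available() else "cpu"  # integer index avoids the `pipeline` bug noted above

# "lvwerra/distilbert-imdb" is an illustrative model id.
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```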
4 changes: 2 additions & 2 deletions examples/scripts/ppo_multi_adapter.py
@@ -21,7 +21,7 @@
from tqdm import tqdm
from transformers import BitsAndBytesConfig, HfArgumentParser, LlamaTokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, is_xpu_available
from trl.core import LengthSampler


@@ -82,7 +82,7 @@ def tokenize(example):
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
script_args.model_name,
device_map={"": 0},
device_map={"": "xpu:0"} if is_xpu_available() else {"": 0},
peft_config=lora_config,
quantization_config=nf4_config,
reward_adapter=script_args.rm_adapter,
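The `device_map` now pins the whole model to `"xpu:0"` when an XPU is present. Loading with that mapping might look like this (the model id is illustrative; the other kwargs from the script are omitted):

```python
# Sketch only: pin a model to the first available accelerator device.
from trl import AutoModelForCausalLMWithValueHead, is_xpu_available

device_map = {"": "xpu:0"} if is_xpu_available() else {"": 0}
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "huggyllama/llama-7b",  # illustrative model id, not from this diff
    device_map=device_map,
)
```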
8 changes: 6 additions & 2 deletions examples/scripts/reward_modeling.py
@@ -22,7 +22,7 @@
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig

from trl import RewardConfig, RewardTrainer
from trl import RewardConfig, RewardTrainer, is_xpu_available


tqdm.pandas()
@@ -83,7 +83,11 @@ class ScriptArguments:
elif args.load_in_8bit or args.load_in_4bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
# Copy the model to each device
device_map = {"": Accelerator().local_process_index}
device_map = (
{"": f"xpu:{Accelerator().local_process_index}"}
if is_xpu_available()
else {"": Accelerator().local_process_index}
)
else:
device_map = None
quantization_config = None
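For quantized loading, each process places its copy of the model on its own local device, and the XPU branch needs an explicit `"xpu:N"` string. Factored out (a sketch, not part of the commit):

```python
# Sketch only: per-process device_map for quantized multi-device training.
from accelerate import Accelerator

from trl import is_xpu_available


def local_device_map() -> dict:
    index = Accelerator().local_process_index
    return {"": f"xpu:{index}"} if is_xpu_available() else {"": index}
```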
8 changes: 6 additions & 2 deletions examples/scripts/sft.py
@@ -22,7 +22,7 @@
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import SFTTrainer
from trl import SFTTrainer, is_xpu_available


tqdm.pandas()
@@ -77,7 +77,11 @@ class ScriptArguments:
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
)
# Copy the model to each device
device_map = {"": Accelerator().local_process_index}
device_map = (
{"": f"xpu:{Accelerator().local_process_index}"}
if is_xpu_available()
else {"": Accelerator().local_process_index}
)
torch_dtype = torch.bfloat16
else:
device_map = None
11 changes: 10 additions & 1 deletion tests/testing_utils.py
@@ -15,7 +15,7 @@

import torch

from trl import is_peft_available, is_wandb_available
from trl import is_peft_available, is_wandb_available, is_xpu_available


def require_peft(test_case):
@@ -64,3 +64,12 @@ def require_torch_multi_gpu(test_case):
if torch.cuda.device_count() < 2:
test_case = unittest.skip("test requires multiple GPUs")(test_case)
return test_case


def require_torch_multi_xpu(test_case):
"""
Decorator marking a test that requires multiple XPUs. Skips the test if there aren't enough XPUs.
"""
if is_xpu_available() and torch.xpu.device_count() < 2:
test_case = unittest.skip("test requires multiple XPUs")(test_case)
return test_case
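Illustrative usage of the new decorator (the test class and import path are assumptions, not part of the diff):

```python
# Sketch only: skip a test unless at least two XPU devices are visible.
import unittest

from tests.testing_utils import require_torch_multi_xpu


class MultiXPUTests(unittest.TestCase):
    @require_torch_multi_xpu
    def test_data_parallel_generation(self):
        ...  # would exercise a two-XPU code path
```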
2 changes: 1 addition & 1 deletion trl/__init__.py
@@ -5,7 +5,7 @@
from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import is_diffusers_available, is_peft_available, is_wandb_available
from .import_utils import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
25 changes: 18 additions & 7 deletions trl/core.py
@@ -23,6 +23,8 @@
from torch.nn.utils.rnn import pad_sequence
from transformers import top_k_top_p_filtering

from .import_utils import is_xpu_available


try:
from collections.abc import Mapping
@@ -241,7 +243,10 @@ def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if is_xpu_available():
torch.xpu.manual_seed_all(seed)
else:
torch.cuda.manual_seed_all(seed)


class LengthSampler:
@@ -257,16 +262,22 @@ def __call__(self):


class PPODecorators(object):
optimize_cuda_cache = False
optimize_device_cache = False

@classmethod
@contextmanager
def empty_cuda_cache(cls):
def empty_device_cache(cls):
yield
if cls.optimize_cuda_cache and torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
gc.collect()
if is_xpu_available():
if cls.optimize_device_cache and torch.xpu.is_available():
gc.collect()
torch.xpu.empty_cache()
gc.collect()
else:
if cls.optimize_device_cache and torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
gc.collect()


def randn_tensor(
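`empty_cuda_cache` becomes `empty_device_cache` and branches on the backend. Because it is built with `@contextmanager`, it can be applied as a decorator or used as a context manager, e.g. (illustrative, not part of the commit):

```python
# Sketch only: clear the accelerator cache around a memory-heavy step.
from trl.core import PPODecorators

PPODecorators.optimize_device_cache = True


@PPODecorators.empty_device_cache()
def training_step():
    ...  # forward/backward pass allocating large activations


# or, equivalently:
with PPODecorators.empty_device_cache():
    ...
```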
28 changes: 28 additions & 0 deletions trl/import_utils.py
@@ -25,6 +25,18 @@ def is_peft_available() -> bool:
return importlib.util.find_spec("peft") is not None


def is_accelerate_greater_20_0() -> bool:
if _is_python_greater_3_8:
from importlib.metadata import version

accelerate_version = version("accelerate")
else:
import pkg_resources

accelerate_version = pkg_resources.get_distribution("accelerate").version
return accelerate_version >= "0.20.0"


def is_transformers_greater_than(version: str) -> bool:
_transformers_version = importlib.metadata.version("transformers")
return _transformers_version > version
@@ -60,3 +72,19 @@ def is_rich_available() -> bool:

def is_wandb_available() -> bool:
return importlib.util.find_spec("wandb") is not None


def is_xpu_available() -> bool:
if is_accelerate_greater_20_0():
import accelerate

return accelerate.utils.is_xpu_available()
else:
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
return False
try:
import torch

return hasattr(torch, "xpu") and torch.xpu.is_available()
except RuntimeError:
return False
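One caveat with `is_accelerate_greater_20_0` as written: comparing version strings lexicographically can misorder releases (as strings, `"0.9.0" >= "0.20.0"` is true). A hedged alternative using `packaging` (an assumed dependency, Python 3.8+):

```python
# Sketch only: robust version comparison instead of raw string comparison.
from importlib.metadata import version as installed_version

from packaging import version


def is_accelerate_greater_20_0() -> bool:
    return version.parse(installed_version("accelerate")) >= version.parse("0.20.0")
```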
7 changes: 5 additions & 2 deletions trl/models/modeling_base.py
@@ -23,7 +23,7 @@
from huggingface_hub.utils import EntryNotFoundError, HFValidationError, LocalEntryNotFoundError
from transformers import PreTrainedModel

from ..import_utils import is_peft_available, is_transformers_greater_than
from ..import_utils import is_peft_available, is_transformers_greater_than, is_xpu_available


if is_peft_available():
@@ -333,7 +333,10 @@ def _get_current_device(cls):
The current device.
"""
dummy_accelerator = Accelerator()
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
if is_xpu_available():
return f"xpu:{dummy_accelerator.local_process_index}"
else:
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"

@classmethod
def _split_kwargs(cls, kwargs):
12 changes: 11 additions & 1 deletion trl/trainer/ppo_config.py
@@ -105,7 +105,9 @@ class PPOConfig:
max_grad_norm: Optional[float] = None
"""Maximum gradient norm for gradient clipping"""
optimize_cuda_cache: bool = False
"""Optimize CUDA cache for slightly more memory-efficient training"""
"""DEPRECATED: use `optimize_device_cache` instead, which does the same thing."""
optimize_device_cache: Optional[bool] = False
"""Optimize device cache for slightly more memory-efficient training"""
early_stopping: bool = False
"""Whether to stop the PPO optimization loop early is the KL too high"""
target_kl: float = 1
@@ -135,6 +137,14 @@ class PPOConfig:
global_batch_size: tyro.conf.Suppress[int] = None
"""TO BE FILLED In RUNTIME: the effective `batch_size` across all processes"""

if optimize_cuda_cache is not None:
warnings.warn(
"The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead."
)
optimize_device_cache = optimize_cuda_cache
else:
optimize_device_cache = False

def __post_init__(self):
if self.forward_batch_size is not None:
warnings.warn(
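The deprecation shim above sits in the class body; a sketch of equivalent per-instance handling inside `__post_init__` (an assumption about placement, not what this commit does):

```python
# Sketch only: map the deprecated flag onto the new one per instance.
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleConfig:
    optimize_cuda_cache: Optional[bool] = None  # deprecated alias
    optimize_device_cache: bool = False

    def __post_init__(self):
        if self.optimize_cuda_cache is not None:
            warnings.warn(
                "`optimize_cuda_cache` is deprecated, use `optimize_device_cache` instead."
            )
            self.optimize_device_cache = self.optimize_cuda_cache
```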
15 changes: 9 additions & 6 deletions trl/trainer/ppo_trainer.py
@@ -51,7 +51,7 @@
stack_dicts,
stats_to_np,
)
from ..import_utils import is_torch_greater_2_0
from ..import_utils import is_torch_greater_2_0, is_xpu_available
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments

@@ -341,9 +341,12 @@ def __init__(
if not getattr(self.model, "is_sequential_parallel", False):
self.current_device = self.accelerator.device
else:
self.current_device = torch.device("cuda:0")
if is_xpu_available():
self.current_device = torch.device("xpu:0")
else:
self.current_device = torch.device("cuda:0")

PPODecorators.optimize_cuda_cache = self.config.optimize_cuda_cache
PPODecorators.optimize_device_cache = self.config.optimize_device_cache

self.running = RunningMoments(self.accelerator)

@@ -576,7 +579,7 @@ def _step_safety_checker(

return queries, responses, scores, masks

@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def step(
self,
queries: List[torch.LongTensor],
@@ -909,7 +912,7 @@ def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data

@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: PreTrainedModelWrapper,
@@ -1000,7 +1003,7 @@ def batched_forward_pass(
torch.cat(all_masks)[:, :-1],
)

@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def train_minibatch(
self,
old_logprobs: torch.FloatTensor,