Add tensor parallel support to T5 via NxD #697

Merged
merged 29 commits on Oct 24, 2024
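Summary (from the diff below): this PR adds tensor parallel (TP) support for T5 export through neuronx-distributed (NxD). Because NxD's `parallel_model_trace` spawns worker processes that must each construct their own model shard, the export path now accepts a checkpoint location (`model_name_or_path`) in place of an instantiated `PreTrainedModel`, and the T5 attention and feed-forward blocks are swapped for parallel implementations before tracing.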
1 change: 1 addition & 0 deletions optimum/exporters/neuron/__main__.py
@@ -456,6 +456,7 @@ def _get_submodels_and_neuron_configs_for_encoder_decoder(
input_shapes=input_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
model_name_or_path=model_name_or_path,
)
output_model_names = {
ENCODER_NAME: os.path.join(ENCODER_NAME, NEURON_FILE_NAME),
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/config.py
@@ -169,7 +169,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen
def load_pretrained_with_parallel_attn(self, model, ckpt_path):
# Parallel implementation of Attention modules.
import neuronx_distributed
from t5_model_layers import ParallelSelfAttention, ParallelFF, ParallelCrossAttention
from .t5_model_layers import ParallelSelfAttention, ParallelFF, ParallelCrossAttention

for index, block in enumerate(model.decoder.block):
if index == 0:
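For orientation, here is a minimal, self-contained sketch of what the parallel attention swap achieves (illustrative only; the real `ParallelSelfAttention`, `ParallelFF`, and `ParallelCrossAttention` in `t5_model_layers` build on NxD's parallel linear layers and collectives): the Q/K/V projections are column-sharded so each rank owns a slice of the attention heads, and the output projection is row-sharded, with an all-reduce combining the partial results across ranks.

import torch
import torch.nn as nn

class ShardedSelfAttention(nn.Module):
    """Toy, single-process stand-in for one rank's share of the heads."""

    def __init__(self, d_model: int, n_heads: int, world_size: int):
        super().__init__()
        assert n_heads % world_size == 0 and d_model % n_heads == 0
        self.n_local_heads = n_heads // world_size
        self.d_head = d_model // n_heads
        local_dim = self.n_local_heads * self.d_head
        # Column-parallel Q/K/V: this rank projects onto its own head slice.
        self.q = nn.Linear(d_model, local_dim, bias=False)
        self.k = nn.Linear(d_model, local_dim, bias=False)
        self.v = nn.Linear(d_model, local_dim, bias=False)
        # Row-parallel output: in real TP the partial outputs of all ranks
        # are summed with an all-reduce after this projection.
        self.o = nn.Linear(local_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        split = (b, t, self.n_local_heads, self.d_head)
        q = self.q(x).view(split).transpose(1, 2)
        k = self.k(x).view(split).transpose(1, 2)
        v = self.v(x).view(split).transpose(1, 2)
        # T5 notably omits the 1/sqrt(d) attention scaling; kept generic here.
        scores = torch.softmax(q @ k.transpose(-1, -2), dim=-1)
        out = (scores @ v).transpose(1, 2).reshape(b, t, -1)
        return self.o(out)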
38 changes: 20 additions & 18 deletions optimum/exporters/neuron/convert.py
@@ -21,6 +21,7 @@

import numpy as np
import torch
from transformers import PreTrainedModel

from ...exporters.error_utils import OutputMatchError, ShapeError
from ...neuron.utils import (
@@ -43,11 +44,10 @@
is_sentence_transformers_available,
logging,
)
from .config import TextSeq2SeqNeuronConfig


if TYPE_CHECKING:
from transformers import PreTrainedModel

from .base import NeuronDefaultConfig

if is_neuron_available():
@@ -369,7 +369,7 @@ def export_models(

start_time = time.time()
neuron_inputs, neuron_outputs = export(
model=submodel,
model_or_path=submodel,
config=sub_neuron_config,
output=output_path,
compiler_workdir=compiler_workdir,
@@ -432,7 +432,7 @@ def export_models(


def export(
model: "PreTrainedModel",
model_or_path: Union["PreTrainedModel", str, Path],
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
@@ -445,7 +445,7 @@ def export(
) -> Tuple[List[str], List[str]]:
if is_neuron_available():
return export_neuron(
model=model,
model=model_or_path,
config=config,
output=output,
compiler_workdir=compiler_workdir,
@@ -457,7 +457,7 @@
)
elif is_neuronx_available():
return export_neuronx(
model=model,
model_or_path=model_or_path,
config=config,
output=output,
compiler_workdir=compiler_workdir,
@@ -473,7 +473,7 @@


def export_neuronx(
model: "PreTrainedModel",
model_or_path: Union["PreTrainedModel", str, Path],
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
@@ -486,8 +486,8 @@
Exports a PyTorch model to a serialized TorchScript module compiled by neuronx-cc compiler.

Args:
model ([`PreTrainedModel`]):
The model to export.
model_or_path (Union["PreTrainedModel", str, Path]):
The model to export, or its location (used when applying tensor parallelism, since the model then needs to be loaded within the tracing function).
config ([`~exporter.NeuronDefaultConfig`]):
The Neuron configuration associated with the exported model.
output (`Path`):
@@ -514,17 +514,19 @@
if isinstance(compiler_workdir, Path):
compiler_workdir = compiler_workdir.as_posix()

if hasattr(model, "config"):
model.config.return_dict = True
model.config.torchscript = True
model.eval()
if hasattr(model_or_path, "config"):
model_or_path.config.return_dict = True
model_or_path.config.torchscript = True
if isinstance(model_or_path, PreTrainedModel):
model_or_path.eval()

# Check if we need to override certain configuration items
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)
if isinstance(model_or_path, PreTrainedModel):
setattr(model_or_path.config, override_config_key, override_config_value)

# Prepare dummy inputs for tracing
input_shapes = {}
@@ -538,12 +540,12 @@
# Prepare the model (or, for TP, the factory function) to trace
aliases = {}
tp_degree = config.tp_degree
if hasattr(model, "config") and getattr(model.config, "is_encoder_decoder", False):
checked_model = config.patch_model_for_export(model, **input_shapes)
if hasattr(model_or_path, "config") and isinstance(config, TextSeq2SeqNeuronConfig):
checked_model = config.patch_model_for_export(model_or_path, **input_shapes)
if tp_degree == 1:
aliases = config.generate_io_aliases(checked_model)
else:
checked_model = config.patch_model_for_export(model, dummy_inputs)
checked_model = config.patch_model_for_export(model_or_path, dummy_inputs)

# Construct compiler configurations
if auto_cast is not None:
@@ -595,7 +597,7 @@ def export_neuronx(
improve_stable_diffusion_loading(config, neuron_model)
torch.jit.save(neuron_model, output)

del model
del model_or_path
del checked_model
del dummy_inputs
del neuron_model
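Condensing the behavioral change in this file into a sketch (assumption: these are the two cases callers care about): configuration tweaks and `eval()` only apply when a live model is passed in, while a bare checkpoint location flows through untouched and is loaded later inside the tracing workers.

from pathlib import Path
from typing import Union

from transformers import PreTrainedModel

def normalize_for_trace(model_or_path: Union[PreTrainedModel, str, Path]):
    # A live model is put into inference mode and made traceable; a plain
    # location is deferred to the TP tracing workers, which load it themselves.
    if isinstance(model_or_path, PreTrainedModel):
        model_or_path.config.return_dict = True
        model_or_path.config.torchscript = True
        model_or_path.eval()
    return model_or_path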
89 changes: 45 additions & 44 deletions optimum/exporters/neuron/model_configs.py
@@ -59,6 +59,7 @@
T5EncoderWrapper,
UnetNeuronWrapper,
)
from transformers import T5ForConditionalGeneration


if TYPE_CHECKING:
@@ -795,31 +796,31 @@ class T5EncoderNeuronConfig(TextSeq2SeqNeuronConfig):
def is_decoder(self) -> bool:
return False

def patch_model_for_export(self, model, device="xla", **kwargs):
def patch_model_for_export(self, model_or_path, device="xla", **kwargs):
num_beams = kwargs.pop("num_beams", 1)
sequence_length = kwargs.pop("sequence_length", None)
batch_size = kwargs.pop("batch_size", None)

if self.tp_degree > 1:
return self.patch_model_for_parallel_export(model, sequence_length, batch_size, num_beams, device)
# `torch.nn.Module` objects are not picklable, so the model needs to be loaded within the function.
return partial(self.get_parallel_encoder_func, model_or_path, sequence_length, batch_size, num_beams, device, self.tp_degree)
else:
return self.CUSTOM_MODEL_WRAPPER(model, sequence_length=sequence_length, batch_size=batch_size, num_beams=num_beams, device=device, tp_degree=self.tp_degree)
return self.CUSTOM_MODEL_WRAPPER(model_or_path, sequence_length=sequence_length, batch_size=batch_size, num_beams=num_beams, device=device, tp_degree=self.tp_degree)

def patch_model_for_parallel_export(self, model, sequence_length, batch_size, num_beams, device):
def get_parallel_encoder_func(self, model_name_or_path, sequence_length, batch_size, num_beams, device, tp_degree):
"""Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a dictionary of states."""
def get_wrapped_encoder(model, sequence_length, batch_size, num_beams, device, tp_degree):
model.config.use_cache = True
parallizer = ParallelizersManager.parallelizer_for_model(model)
with parallizer.saved_model_in_temporary_directory(model) as ckpt_path:
parallel_model = self.load_pretrained_with_parallel_attn(model, ckpt_path)
# using parallizer
# parallizer = ParallelizersManager.parallelizer_for_model(model)
# model = parallizer.parallelize(model)
encoder = self.CUSTOM_MODEL_WRAPPER(parallel_model, sequence_length=sequence_length, batch_size=batch_size, num_beams=num_beams, device=device, tp_degree=tp_degree)
encoder.eval()
aliases = self.generate_io_aliases(encoder)
return encoder, aliases
return partial(get_wrapped_encoder, model, sequence_length, batch_size, num_beams, device, self.tp_degree)
model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, torch_dtype="auto")
model.config.use_cache = True
parallizer = ParallelizersManager.parallelizer_for_model(model)
with parallizer.saved_model_in_temporary_directory(model) as ckpt_path:
parallel_model = self.load_pretrained_with_parallel_attn(model, ckpt_path)
# using parallizer
# parallizer = ParallelizersManager.parallelizer_for_model(model)
# model = parallizer.parallelize(model)
encoder = self.CUSTOM_MODEL_WRAPPER(parallel_model, sequence_length=sequence_length, batch_size=batch_size, num_beams=num_beams, device=device, tp_degree=tp_degree)
encoder.eval()
aliases = self.generate_io_aliases(encoder)
return encoder, aliases

def generate_io_aliases(self, encoder=None):
if self.tp_degree > 1:
@@ -882,38 +883,38 @@ def patch_model_for_export(self, model, device="xla", **kwargs):
"num_beams": num_beams,
"output_hidden_states": self.output_hidden_states,
"output_attentions": self.output_attentions,
"device": device
"device": device,
"tp_degree": self.tp_degree,
}
if self.tp_degree > 1:
return self.patch_model_for_parallel_export(**trace_args)
return partial(self.get_parallel_decoder_func, model, batch_size, sequence_length, num_beams, self.output_hidden_states, self.output_attentions, device, self.tp_degree)
else:
return self.CUSTOM_MODEL_WRAPPER(**trace_args)

def patch_model_for_parallel_export(self, model, batch_size, sequence_length, num_beams, output_hidden_states, output_attentions, device):
def get_parallel_decoder_func(self, model, batch_size, sequence_length, num_beams, output_hidden_states, output_attentions, device, tp_degree):
"""Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a dictionary of states."""
def get_wrapped_decoder(model, batch_size, sequence_length, num_beams, output_hidden_states, output_attentions, device, tp_degree):
model.config.use_cache = True
parallizer = ParallelizersManager.parallelizer_for_model(model)
with parallizer.saved_model_in_temporary_directory(model) as ckpt_path:
parallel_model = self.load_pretrained_with_parallel_attn(model, ckpt_path)
# using parallizer
# parallizer = ParallelizersManager.parallelizer_for_model(model)
# model = parallizer.parallelize(model)

decoder = self.CUSTOM_MODEL_WRAPPER(
parallel_model,
batch_size=batch_size,
sequence_length=sequence_length,
num_beams=num_beams,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
device=device,
tp_degree=tp_degree,
)
decoder.eval()
aliases = self.generate_io_aliases(decoder)
return decoder, aliases
return partial(get_wrapped_decoder, model, batch_size, sequence_length, num_beams, output_hidden_states, output_attentions, device, self.tp_degree)
model.config.use_cache = True
# parallizer = ParallelizersManager.parallelizer_for_model(model)
# with parallizer.saved_model_in_temporary_directory(model) as ckpt_path:
# parallel_model = self.load_pretrained_with_parallel_attn(model, ckpt_path)
# # using parallizer
# # parallizer = ParallelizersManager.parallelizer_for_model(model)
# # model = parallizer.parallelize(model)

parallel_model = model
decoder = self.CUSTOM_MODEL_WRAPPER(
parallel_model,
batch_size=batch_size,
sequence_length=sequence_length,
num_beams=num_beams,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
device=device,
tp_degree=tp_degree,
)
decoder.eval()
aliases = self.generate_io_aliases(decoder)
return decoder, aliases

def generate_io_aliases(self, decoder):
num_outputs_from_trace = 3 if decoder.num_beams > 1 else 1
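The heart of the TP branches above is a factory-function pattern; here is a sketch under the contract stated in the docstrings (`parallel_model_trace` takes a function returning a model and its state aliases, since `torch.nn.Module` objects cannot be pickled across worker processes; everything else below is illustrative):

from functools import partial

from transformers import T5ForConditionalGeneration

def get_wrapped_model(model_name_or_path, sequence_length, batch_size):
    # Runs inside each tracing worker: the model is loaded here, not in the
    # parent process, because only picklable objects cross the boundary.
    model = T5ForConditionalGeneration.from_pretrained(model_name_or_path)
    model.config.use_cache = True
    model.eval()
    aliases = {}  # input/output aliases for in-place buffers (e.g. the KV cache)
    return model, aliases

# Only the picklable partial (a path plus plain arguments) is handed to
# `parallel_model_trace`:
model_factory = partial(get_wrapped_model, "t5-small", 128, 1)

Note that in this commit only the encoder actually performs the parallel attention swap (`load_pretrained_with_parallel_attn`); the decoder factory still returns the unsharded model, with the swap left commented out.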
20 changes: 18 additions & 2 deletions optimum/exporters/neuron/utils.py
@@ -16,6 +16,7 @@

import copy
import os
from pathlib import Path
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

@@ -471,6 +472,7 @@ def get_encoder_decoder_models_for_export(
dynamic_batch_size: Optional[bool] = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
model_name_or_path: Optional[Union[str, Path]] = None,
) -> Dict[str, Tuple["PreTrainedModel", "NeuronDefaultConfig"]]:
"""
Returns the components of an encoder-decoder model and their subsequent neuron configs.
@@ -493,6 +495,8 @@
Whether or not for the traced model to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, defaults to `False`):
Whether or not for the traced model to return the hidden states of all layers.
model_name_or_path (`Optional[Union[str, Path]]`, defaults to `None`):
The location from which the model is loaded; this is needed for tensor parallelism, since the model must be loaded within the tracing API.

Returns:
`Dict[str, Tuple["PreTrainedModel", "NeuronDefaultConfig"]]`: A Dict containing the model and
@@ -516,7 +520,13 @@
tensor_parallel_size=tensor_parallel_size,
**input_shapes,
)
models_for_export[ENCODER_NAME] = (model, encoder_neuron_config)
if tensor_parallel_size > 1:
    if not model_name_or_path:
        raise ValueError(
            "`model_name_or_path` must be specified when tensor parallelism is enabled, "
            f"but it is currently {model_name_or_path}."
        )
    models_for_export[ENCODER_NAME] = (model_name_or_path, encoder_neuron_config)
else:
    models_for_export[ENCODER_NAME] = (model, encoder_neuron_config)

# Decoder
model_type = getattr(model.config, "model_type") + "-decoder"
@@ -536,6 +546,12 @@
output_hidden_states=output_hidden_states,
**input_shapes,
)
models_for_export[DECODER_NAME] = (model, decoder_neuron_config)
if tensor_parallel_size > 1:
    if not model_name_or_path:
        raise ValueError(
            "`model_name_or_path` must be specified when tensor parallelism is enabled, "
            f"but it is currently {model_name_or_path}."
        )
    models_for_export[DECODER_NAME] = (model_name_or_path, decoder_neuron_config)
else:
    models_for_export[DECODER_NAME] = (model, decoder_neuron_config)

return models_for_export
Loading