Merged
Changes from all commits
77 changes: 77 additions & 0 deletions docs/recipes/generate.md
@@ -0,0 +1,77 @@
Collaborator comment on this file: This won't be published because of @249. I think the problem is missing variables in

    FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"

(like those in

    FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"

).

---
title: How to Generate with a Fast-LLM Model
---

Fast-LLM models support `generate` and `forward` operations through Hugging Face–compatible wrappers.

⚠️ Limitations:

- No support for `cache`, `past_key_values`, `labels`, `attention` outputs, or `inputs_embeds`
- `position_ids` are ignored and reconstructed from the attention mask (see the sketch after this list)
- **model-parallel** and **sequence-data-parallel** generation is **not** supported
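
Fast-LLM rebuilds position IDs from the attention mask on its own. As a rough illustration of what that reconstruction typically looks like for a left-padded batch (a sketch of the usual approach, not necessarily Fast-LLM's exact code path):

```python
import torch

# Left-padded batch: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Count real tokens along the sequence; padded slots get a dummy position of 0.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
print(position_ids)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```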

---

### 🔧 Generating Text from a Fast-LLM Model

Below is a step-by-step example of how to generate text using a Fast-LLM model checkpoint from Hugging Face Hub.

```python
# Import dependencies
import huggingface_hub
from transformers import AutoTokenizer
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig
from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM

# Specify model and configuration
model = "HuggingFaceTB/SmolLM2-135M-Instruct"
checkpoint_format = LlamaGPTHuggingfaceCheckpointFormat
max_new_tokens = 50

# Download model checkpoint from the Hugging Face Hub to a local directory
model_path = huggingface_hub.snapshot_download(repo_id=model, local_dir="/tmp")

# Load tokenizer from the downloaded model
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Optional: updates to Fast-LLM config before loading the model
updates = {
("base_model", "transformer", "use_flash_attention"): True,
("distributed", "training_dtype"): "bf16"
}

# Load the model from the checkpoint with the given configuration
model = HuggingfaceGPTModelForCausalLM.from_pretrained(
CheckpointLoadConfig(
path=model_path,
format=checkpoint_format,
model_weights=True,
),
updates,
)

# Example input messages formatted for chat-style generation
messages = [
{"role": "user", "content": "What is gravity?"},
{"role": "user", "content": "Who is the president of EU?"},
]

# Convert messages into model input format using chat template
input_text = [tokenizer.apply_chat_template([el], tokenize=False) for el in messages]

# Prepare tokenized input for the model
tokenizer.padding_side = "left" # Important for correct padding
inputs = tokenizer(input_text, padding="longest", return_tensors="pt").to("cuda")

# Generate text using the model
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=False)

# Decode and display outputs
outputs = [tokenizer.decode(el, skip_special_tokens=True) for el in outputs]

print("--------------------------------------------------------------------")
for el in outputs:
print(el)
print("--------------------------------------------------------------------")
```
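
The same wrapper also supports `forward`. Continuing from the objects created above, a minimal sketch of a single forward pass (the output follows the usual `CausalLMOutputWithPast` shape; treat the exact option handling as an assumption of this sketch):

```python
import torch

# Reuses `model` and `inputs` from the generation example above.
with torch.no_grad():
    out = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        return_dict=True,
        use_cache=False,
    )

# Next-token logits at the last position of each sequence.
next_token_logits = out.logits[:, -1, :]
print(next_token_logits.shape)  # (batch_size, vocab_size)
```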
17 changes: 17 additions & 0 deletions fast_llm/engine/inference/config.py
@@ -1,3 +1,4 @@
import copy
import logging
import os
import pathlib
@@ -36,6 +37,22 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool =
finally:
transformers.configuration_utils.CONFIG_NAME = _backup

def __deepcopy__(self, memo):
# Hugging Face's `PreTrainedModel` deep-copies the config when
# `generate` is used. However, `fast_llm_config` cannot be deep-copied
# if the world size is greater than 1, as it holds references to
# process groups. Therefore, we copy it by reference instead.
cls = self.__class__
copied = cls.__new__(cls)
memo[id(self)] = copied
for k, v in self.__dict__.items():
if k == "fast_llm_config":
setattr(copied, k, v) # Keep the same reference
else:
setattr(copied, k, copy.deepcopy(v, memo))
return copied

@classmethod
def _get_config_dict(
cls, pretrained_model_name_or_path: str | os.PathLike | CheckpointLoadMetadataConfig, **kwargs
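
The `__deepcopy__` override above keeps `fast_llm_config` shared by reference while deep-copying everything else. A standalone sketch of the same pattern, with illustrative names only (not Fast-LLM code):

```python
import copy


class ConfigWrapper:
    """Toy stand-in for a config whose `shared` attribute must not be deep-copied."""

    def __init__(self, shared, payload):
        self.shared = shared    # e.g. process-group handles: copy by reference only
        self.payload = payload  # ordinary state: safe to deep-copy

    def __deepcopy__(self, memo):
        cls = self.__class__
        copied = cls.__new__(cls)
        memo[id(self)] = copied
        for k, v in self.__dict__.items():
            if k == "shared":
                setattr(copied, k, v)  # keep the same reference
            else:
                setattr(copied, k, copy.deepcopy(v, memo))
        return copied


original = ConfigWrapper(shared=object(), payload={"dtype": "bf16"})
clone = copy.deepcopy(original)
assert clone.shared is original.shared        # shared by reference
assert clone.payload is not original.payload  # independent copy
```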
74 changes: 58 additions & 16 deletions fast_llm/engine/inference/huggingface.py
@@ -2,13 +2,17 @@
import pathlib
import typing

import torch
import transformers.generation.utils
import transformers.modeling_outputs

from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
from fast_llm.engine.inference.config import HuggingfaceModelConfig
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.config import StageMode
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.engine.schedule.runner import ScheduleRunner
from fast_llm.utils import Assert


class HuggingfacePreTrainedModel(transformers.PreTrainedModel):
@@ -20,21 +24,36 @@ class HuggingfacePreTrainedModel(transformers.PreTrainedModel):
# _supports_cache_class = False
# _tied_weights_keys = []

def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, **kwargs):
def __init__(
self,
fast_llm_model: FastLLMModel,
config: HuggingfaceModelConfig | None = None,
runner: ScheduleRunner | None = None,
**kwargs,
):
if config is None:
config = self.config_class(fast_llm_model.config)

assert self.runner_class.model_class.config_class is config.model_config_class
assert config.fast_llm_config is fast_llm_model.config
assert isinstance(config, self.config_class)

super().__init__(config, **kwargs)

self._inference_runner = self.runner_class(fast_llm_model)
if not fast_llm_model.is_setup:
fast_llm_model.setup(mode=StageMode.inference)
self._inference_runner = self.runner_class(fast_llm_model, runner)

# A model can be created via `from_pretrained`, which sets it up through the current HF
# wrapper API, or passed in as an already set-up model; do not accept a model that is not set up.
assert fast_llm_model.is_setup

# We only support data parallel for now
Assert.eq(fast_llm_model.distributed.config.model_parallel, 1)
Assert.eq(fast_llm_model.distributed.config.sequence_data_parallel, 1)

self._inference_runner.setup()

# Transformers needs to be able to inspect the base model.
self.fast_llm_base_model = fast_llm_model.base_model
# TODO: Support distributed models?
assert fast_llm_model.config.distributed.world_size == 1

with transformers.modeling_utils.no_init_weights():
self.post_init()
@@ -43,8 +62,12 @@ def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel,
def from_pretrained(
cls,
pretrained_model_name_or_path: str | os.PathLike | CheckpointLoadConfig,
*,
mode: StageMode = StageMode.inference,
*updates: dict[str | tuple[str, ...], typing.Any],
optimizer_state_names: tuple[str, ...] | None = None,
# setup: bool = True,
mode: StageMode = StageMode.training,
use_cpu: bool = False,
stage_filter: set | None = None,
**kwargs,
) -> typing.Self:
# Pretrained config.
@@ -54,18 +77,37 @@
format=FastLLMCheckpointFormat,
)

updates = {}
torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is not None:
updates[("distributed", "training_dtype")] = torch_dtype

# Create the model
# Always set up the model and create the distributed instance internally for now
fast_llm_model = cls.runner_class.model_class.from_pretrained(
pretrained_model_name_or_path, updates, mode=mode
pretrained_model_name_or_path,
*updates,
optimizer_state_names=optimizer_state_names,
setup=True,
mode=mode,
use_cpu=use_cpu,
stage_filter=stage_filter,
)
config = cls.config_class(fast_llm_model.config)

return cls(config, fast_llm_model, **kwargs)
return cls(fast_llm_model, **kwargs)

def _init_weights(self, module) -> None:
raise NotImplementedError(module)


class HuggingfaceBaseModelForCausalLM(HuggingfacePreTrainedModel, transformers.generation.utils.GenerationMixin):
def forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
past_key_values=None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
# Meant to be overridden in derived classes
raise NotImplementedError()
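
The reworked `__init__` above accepts an already set-up `FastLLMModel` and, optionally, an existing `ScheduleRunner`, so the wrapper can be built around a live model (for example inside a training loop) rather than only via `from_pretrained`. A hedged sketch of such a call, assuming `fast_llm_model`, `schedule_runner`, `input_ids`, and `attention_mask` already exist and are set up (these variable names are placeholders, not part of the API), and that the model runs data-parallel only:

```python
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM

# Assumptions: `fast_llm_model` is a set-up GPT FastLLMModel with model_parallel == 1 and
# sequence_data_parallel == 1; `schedule_runner` is an already set-up ScheduleRunner;
# `input_ids` and `attention_mask` are tokenized inputs on the right device.
hf_model = HuggingfaceGPTModelForCausalLM(
    fast_llm_model,          # wrapped by reference; no weight copy
    runner=schedule_runner,  # optional: omit to let the wrapper create its own runner
)

# The wrapper then behaves like any Hugging Face causal-LM model for generation.
outputs = hf_model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=20, use_cache=False)
```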
39 changes: 29 additions & 10 deletions fast_llm/engine/inference/runner.py
@@ -7,27 +7,42 @@
from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig
from fast_llm.engine.schedule.runner import ScheduleRunner
from fast_llm.engine.schedule.schedule import Schedule
from fast_llm.utils import Assert


class InferenceRunner(abc.ABC):
model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel
batch_config_class: typing.ClassVar[type[BatchConfig]] = BatchConfig

def __init__(self, fast_llm_model: FastLLMModel):
def __init__(
self,
fast_llm_model: FastLLMModel,
runner: ScheduleRunner | None = None,
):
assert isinstance(fast_llm_model, self.model_class)
self._fast_llm_model = fast_llm_model
# We only need a basic schedule and don't care about dimensions.
self._schedule_config = ScheduleConfig()
# TODO: Sort things out.

with NoAutoValidate():
self._batch_config = self.batch_config_class()
self._batch_config.setup(self._fast_llm_model.config.distributed)
self._batch_config.validate()
self._runner = ScheduleRunner(
config=self._schedule_config,
multi_stage=self._fast_llm_model,
distributed_config=self._fast_llm_model.config.distributed,
)

if runner is None:
# We only need a basic schedule and don't care about dimensions.
self._schedule_config = ScheduleConfig()
# TODO: Sort things out.

self._runner = ScheduleRunner(
config=self._schedule_config,
multi_stage=self._fast_llm_model,
distributed_config=self._fast_llm_model.config.distributed,
)
else:
self._schedule_config = runner.config
self._runner = runner
# An external runner passed from the training loop must already be set up
assert runner._is_setup

# TODO: Random state? (Distributed.set_step)
self._schedule = Schedule(
multi_stage=self._fast_llm_model,
@@ -42,7 +57,11 @@ def fast_llm_model(self) -> FastLLMModel:
return self._fast_llm_model

def setup(self):
self._runner.setup(self._fast_llm_model.distributed)
if not self._runner._is_setup:
self._runner.setup(self._fast_llm_model.distributed)
else:
# An external runner was passed; check that it uses the same Distributed instance as the model
Assert.is_(self._runner._distributed, self._fast_llm_model.distributed)

def forward(
self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False
4 changes: 2 additions & 2 deletions fast_llm/engine/multi_stage/config.py
@@ -30,7 +30,7 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel
from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel

logger = logging.getLogger(__name__)
@@ -247,7 +247,7 @@ def get_model_class(cls) -> type["FastLLMModel"]:
raise NotImplementedError

@classmethod
def get_huggingface_model_class(cls) -> type["HuggingfacePreTrainedModel"]:
def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]:
raise NotImplementedError

@classmethod
16 changes: 16 additions & 0 deletions fast_llm/engine/multi_stage/stage.py
@@ -13,6 +13,9 @@
from fast_llm.tensor import ParameterMeta, TensorMeta, accumulate_gradient
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
pass

logger = logging.getLogger(__name__)


@@ -111,6 +114,19 @@ def forward(
metrics,
)
self._log_layer_forward(output, kwargs, i)

# TODO: Very slow and memory-consuming; only use for debugging for now.
# TODO: Decide if and how we want to properly return HF Transformers-style
# details from forward.
if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]:
# Last layer does not provide output
if output is not None:
meta = self._meta_outputs[i]
output_global, _ = meta.local_to_global(output.detach(), distributed=self._distributed)
kwargs["hidden_states"][self._layer_range[i]] = {
"layer_type": type(layer).__name__,
"tensor": output_global,
}
return None if output is None else output.detach(), (input_, output)

def backward(
16 changes: 16 additions & 0 deletions fast_llm/layers/language_model/head.py
@@ -180,6 +180,22 @@ def _forward_backward(
with torch.enable_grad():
ln_output = self.final_norm(input_)

if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]:
# The last hidden layer's output is returned normalized in the HF Transformers-style output, at least for Llama-style models.
# So, if needed, we gather the data after normalization and set it as the output of the previous layer.
dims = list(kwargs[TransformerKwargs.hidden_dims])
sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first])
dims[sequence_index] = (
TensorDim(
TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor
)
if self._sequence_parallel_logits
else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size)
)
meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype)
hidden_state, _ = meta.local_to_global(ln_output.detach(), distributed=self._tensor_space.distributed)
kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state

grad_output = kwargs[TransformerKwargs.grad_output] / (
self._group_size if self._sequence_parallel_logits else 1
)
5 changes: 4 additions & 1 deletion fast_llm/layers/transformer/preprocessing.py
@@ -239,7 +239,10 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
]
if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None:
seq_ids = torch.stack(
[torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths]
[
torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)])
for sample_lens in sequence_lengths
]
)
document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device)
kwargs[TransformerKwargs.attention_mask] = (
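
The change above swaps per-token position indices (`torch.arange`) for per-document indices (`torch.full`), so the equality test that follows compares document membership instead of positions when building the attention mask for packed sequences. A small self-contained check of the difference on a toy packed row (illustrative only):

```python
import torch

sample_lens = [3, 2]  # one packed row holding a 3-token and a 2-token document

# Old construction: position indices restart per document -> [0, 1, 2, 0, 1]
old_ids = torch.cat([torch.arange(x) for x in sample_lens])

# New construction: document indices -> [0, 0, 0, 1, 1]
new_ids = torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)])

# A pair of tokens may attend to each other only if their ids match.
old_mask = old_ids[None, :] == old_ids[:, None]
new_mask = new_ids[None, :] == new_ids[:, None]

# With position indices, token 0 of document 0 wrongly "matches" token 0 of document 1;
# with document indices, attention stays within each document's block.
print(old_mask[0])  # tensor([ True, False, False,  True, False])
print(new_mask[0])  # tensor([ True,  True,  True, False, False])
```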
2 changes: 1 addition & 1 deletion fast_llm/models/custom/config.py
@@ -35,7 +35,7 @@ def get_model_class(cls) -> type["CustomModel"]:
return CustomModel

@classmethod
def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]:
def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]:
from fast_llm.models.custom.huggingface import HuggingfaceCustomModelForCausalLM

return HuggingfaceCustomModelForCausalLM
2 changes: 1 addition & 1 deletion fast_llm/models/gpt/config.py
@@ -147,7 +147,7 @@ def get_model_class(cls) -> type["GPTModel"]:
return GPTModel

@classmethod
def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]:
def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]:
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM

return HuggingfaceGPTModelForCausalLM