Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generate callback to work with precision context #322

Merged
merged 4 commits into from
Jun 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions llmfoundry/callbacks/generate_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
import wandb
from composer.core import Callback, State
from composer.core import Callback, State, get_precision_context
from composer.loggers import Logger, WandBLogger
from composer.utils import dist, ensure_tuple
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
Expand Down Expand Up @@ -72,15 +72,16 @@ def generate(self, state: State, logger: Logger):
# dummy forward call needed for FSDP to work consistently
dummy_input = torch.tensor([[0]], dtype=torch.long)
dummy_input = device.tensor_to_device(dummy_input)
with torch.no_grad():
_ = model.model(input_ids=dummy_input) # type: ignore

output_token_ids = model.model.generate( # type: ignore
input_ids=tokenized_input['input_ids'],
attention_mask=tokenized_input['attention_mask'],
synced_gpus=True,
**self.generate_kwargs,
)
with get_precision_context(state.precision):
with torch.no_grad():
_ = model.model(input_ids=dummy_input) # type: ignore

output_token_ids = model.model.generate( # type: ignore
input_ids=tokenized_input['input_ids'],
attention_mask=tokenized_input['attention_mask'],
synced_gpus=True,
**self.generate_kwargs,
)

if dist.get_global_rank() == 0:
if self.wandb_logger is not None:
Expand Down