Compile model+loss for LoRA single device recipe #1296

Merged: 6 commits merged on Aug 10, 2024

Changes from all commits
51 changes: 26 additions & 25 deletions recipes/lora_finetune_single_device.py
@@ -384,7 +384,9 @@ def _setup_model(
         if compile_model:
             log.info("Compiling model with torch.compile...")
             backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
-            model.compile(backend=backend)
+            self._loss_step_original = self._loss_step
+            self._loss_step = torch.compile(self._loss_step, backend=backend)
+
         if self._device.type == "cuda":
             memory_stats = utils.get_memory_stats(device=self._device)
             utils.log_memory_stats(memory_stats)
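For context, this hunk swaps compiling the model object for compiling the whole loss step (forward pass plus loss), so the loss computation is captured in the compiled graph as well. Below is a minimal, self-contained sketch of the same pattern; the SimpleRecipe class and the toy embedding model are illustrative, not part of torchtune:

```python
# Minimal sketch of the pattern above: compile the whole loss step
# (forward + loss) rather than just the model. Names here (SimpleRecipe,
# the toy embedding model) are illustrative, not from torchtune.
import os

import torch
import torch.nn as nn


class SimpleRecipe:
    def __init__(self, model: nn.Module, loss_fn: nn.Module, compile_model: bool):
        self._model = model
        self._loss_fn = loss_fn
        if compile_model:
            backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
            # Keep a handle to the eager version, then swap in the compiled one.
            self._loss_step_original = self._loss_step
            self._loss_step = torch.compile(self._loss_step, backend=backend)

    def _loss_step(self, batch: dict) -> torch.Tensor:
        logits = self._model(batch["tokens"])  # [b, s, v]
        # CrossEntropyLoss expects the class dimension at index 1.
        return self._loss_fn(logits.transpose(1, 2), batch["labels"])


# Usage with a tiny model. The first call triggers compilation, so it
# needs a working backend (e.g. inductor) on the machine running it.
model = nn.Embedding(128, 16)  # token ids [b, s] -> "logits" [b, s, 16]
recipe = SimpleRecipe(model, nn.CrossEntropyLoss(), compile_model=True)
batch = {
    "tokens": torch.randint(0, 128, (2, 8)),
    "labels": torch.randint(0, 16, (2, 8)),
}
loss = recipe._loss_step(batch)
```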
@@ -526,6 +528,26 @@ def save_checkpoint(self, epoch: int) -> None:
             adapter_only=self._save_adapter_weights_only,
         )

+    def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+        # Both are shape [b, s]
+        tokens, labels = batch["tokens"], batch["labels"]
+        # Get the attention mask and position ids from the dataset if they
+        # exist. Currently, only sample packing in PackedDataset returns these.
+        mask = batch.get("mask", None)  # shape [b, s, s]
+        input_pos = batch.get("input_pos", None)  # shape [b, s]
+
+        logits = self._model(tokens, mask=mask, input_pos=input_pos)
+        # Shift so that tokens < n predict n
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+        logits = logits.transpose(1, 2)
+        # Compute loss
+        loss = self._loss_fn(logits, labels)
+        # Free logits, otherwise it peaks backward memory
+        del logits
+
+        return loss
+
     def train(self) -> None:
         """
         The core training loop.
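The shift and transpose in the new _loss_step line up next-token targets with the logits and put the class (vocab) dimension where nn.CrossEntropyLoss expects it. A small shape walk-through with illustrative sizes:

```python
# Shape walk-through of the shift + transpose in _loss_step, with toy sizes.
# b = batch, s = sequence length, v = vocab size (illustrative values).
import torch
import torch.nn as nn

b, s, v = 2, 5, 11
logits = torch.randn(b, s, v)         # model output: [b, s, v]
labels = torch.randint(0, v, (b, s))  # targets:      [b, s]

# Tokens < n predict token n: drop the last logit and the first label.
shifted_logits = logits[..., :-1, :].contiguous()  # [b, s-1, v]
shifted_labels = labels[..., 1:].contiguous()      # [b, s-1]

# CrossEntropyLoss expects the class dimension at index 1: [b, v, s-1].
loss = nn.CrossEntropyLoss()(shifted_logits.transpose(1, 2), shifted_labels)
print(loss.shape)  # torch.Size([]) -- a scalar
```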
@@ -557,31 +579,10 @@ def train(self) -> None:
                 ):
                     break

-                # Both are shape [b, s]
-                tokens, labels = batch["tokens"], batch["labels"]
-                # Get the attention mask and position ids from the dataset if they
-                # exist. Currently, only sample packing in PackedDataset returns these.
-                mask = batch.get("mask", None)  # shape [b, s, s]
-                input_pos = batch.get("input_pos", None)  # shape [b, s]
-
-                tokens = tokens.to(self._device)
-                num_tokens += tokens.numel()
-                labels = labels.to(self._device)
-                mask = mask.to(self._device) if mask is not None else None
-                input_pos = (
-                    input_pos.to(self._device) if input_pos is not None else None
-                )
-
-                logits = self._model(tokens, mask=mask, input_pos=input_pos)
-                # Shift so that tokens < n predict n
-                logits = logits[..., :-1, :].contiguous()
-                labels = labels[..., 1:].contiguous()
-                logits = logits.transpose(1, 2)
-                # Compute loss
-                loss = self._loss_fn(logits, labels)
-                # Free logits, otherwise it peaks backward memory
-                del logits
-
+                batch = {k: v.to(self._device) for k, v in batch.items()}
Contributor:
My concern is that the batch may contain keys that are not necessary on the GPU. Can we keep the .to(device) like it was before? Or did you have some other motivation for it?

Contributor (Author):
I tried putting the old .to(device) code inside self._loss_step(), but torch.compile() shows some warnings. They should be harmless, though; I will double-check.

Another option is to keep the old .to(device) outside self._loss_step(). Then it is a bit clunkier to pass tokens, labels, mask, and input_pos to self._loss_step() instead of just batch (or we could modify the batch object in place). Let me know which one you prefer, and I will change to that.

Contributor (@felipemello1, Aug 9, 2024):
I think it should be fine to keep it as you did. I don't think our batches output extra keys, so the scenario I raised would be an exception.
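For reference, a rough sketch of the alternative discussed in this thread: a selective .to(device) that moves only the keys the model consumes. The key names mirror the recipe's batch ("tokens", "labels", "mask", "input_pos"); the helper name and the extra "meta" key are hypothetical:

```python
# Sketch of the alternative raised in the review: move only known keys to
# the device instead of every entry in the batch dict. Illustrative only.
from typing import Dict

import torch

_DEVICE_KEYS = ("tokens", "labels", "mask", "input_pos")


def batch_to_device(
    batch: Dict[str, torch.Tensor], device: torch.device
) -> Dict[str, torch.Tensor]:
    """Move only the keys the model consumes; leave any extra keys on CPU."""
    return {k: (v.to(device) if k in _DEVICE_KEYS else v) for k, v in batch.items()}


# Usage with a toy batch; the hypothetical "meta" entry stays on CPU.
batch = {
    "tokens": torch.randint(0, 100, (2, 8)),
    "labels": torch.randint(0, 100, (2, 8)),
    "meta": torch.zeros(2),
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = batch_to_device(batch, device)
```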

+                num_tokens += batch["tokens"].numel()

+                loss = self._loss_step(batch)
                 loss = loss / self._gradient_accumulation_steps
                 running_loss += loss
                 loss.backward()
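The loop tail above scales the loss by the gradient-accumulation count before calling backward(). A compact, runnable sketch of that pattern; the optimizer step every N micro-batches is assumed here, since that part of the recipe sits below the truncated diff:

```python
# Minimal gradient-accumulation sketch matching the loop tail above.
# The optimizer handling every N micro-batches is an assumption here
# (that part of the recipe is not shown in this diff).
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
grad_accum_steps = 4

optimizer.zero_grad()
for idx in range(16):
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = loss_fn(model(x), y)
    # Scale so the accumulated gradients match one large batch.
    loss = loss / grad_accum_steps
    loss.backward()
    if (idx + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```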