Skip to content

Commit

Permalink
Faster & memory-efficient logprobs calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus committed Dec 2, 2023
1 parent c26d450 commit 0b1da65
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
8 changes: 4 additions & 4 deletions trlx/models/modeling_nemo_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from trlx.data.ppo_types import PPORLBatch
from trlx.models.modeling_ppo import PPOConfig
from trlx.utils import to_device, tree_map
from trlx.utils.modeling import logprobs_of_labels, whiten
from trlx.utils.modeling import logprobs_of_next_labels, whiten

# Track a per dp rank RNG to sample different rollouts
# per dp rank
Expand Down Expand Up @@ -993,7 +993,7 @@ def loss_func(model_output):
start = batch.query_tensors.shape[1]
end = start + response_length

label_logprobs = logprobs_of_labels(logits[:, :-1, :], inputs[:, 1:])
label_logprobs = logprobs_of_next_labels(logits, inputs)
label_logprobs = label_logprobs[:, start:end]

advantages, returns = self.ppo_config.get_advantages_and_returns(
Expand Down Expand Up @@ -1079,11 +1079,11 @@ def ppo_postprocess(model_output):
# to save memory

if run_policy_model and compute_logprobs:
logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
logprobs = logprobs_of_next_labels(logits, tokens)
return logprobs, dict(logprobs=logprobs, values=values)

if run_reference_model and compute_logprobs:
ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], tokens[:, 1:])
ref_logprobs = logprobs_of_next_labels(ref_logits, tokens)
return ref_logprobs, dict(ref_logprobs=ref_logprobs)

return logits, {"logits": logits, "values": values, "ref_logits": ref_logits}
Expand Down
14 changes: 7 additions & 7 deletions trlx/trainer/accelerate_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from trlx.trainer import register_trainer
from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer
from trlx.utils import Clock, infinite_dataloader
from trlx.utils.modeling import RunningMoments, gather_dict, logprobs_of_labels
from trlx.utils.modeling import RunningMoments, gather_dict, logprobs_of_next_labels

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -163,7 +163,7 @@ def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:

logits = outputs.logits
values_pred = outputs.value
logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:])
logprobs = logprobs_of_next_labels(logits, decoder_input_ids)
mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
start = 0
end = start + response_length
Expand All @@ -181,7 +181,7 @@ def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:
logits = outputs.logits
values_pred = outputs.value
values_pred = values_pred[:, :-1]
logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
logprobs = logprobs_of_next_labels(logits, tokens)

start = query_tensors.shape[1] - 1
end = start + response_length
Expand Down Expand Up @@ -438,12 +438,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
ref_logits = ref_logits.to(device)

if self.config.model.model_arch_type == "seq2seq":
logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
logprobs = logprobs_of_next_labels(logits, sample_outputs)
ref_logprobs = logprobs_of_next_labels(ref_logits, sample_outputs)
else:
# NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled
logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
logprobs = logprobs_of_next_labels(logits, all_tokens)
ref_logprobs = logprobs_of_next_labels(ref_logits, all_tokens)

n_samples: int = samples.shape[0]

Expand Down
7 changes: 7 additions & 0 deletions trlx/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ def logprobs_of_labels(logits, labels):
return logprobs_labels.squeeze(-1)


def logprobs_of_next_labels(logits, labels):
"""Log probabilities of the next labels, optimized for memory and speed"""
logits = logits.view(labels.numel(), -1)
shift_labels = F.pad(labels[:, 1:], (0, 1, 0, 0), value=-100).view(-1)
return -F.cross_entropy(logits, shift_labels, reduction="none").view(labels.shape)[:, :-1]


def flatten_dict(
d: Union[dict, MutableMapping],
parent_key: str = "",
Expand Down

0 comments on commit 0b1da65

Please sign in to comment.