From c447864dd4e6a484e0ac0f3a69c2e1e0a3b97994 Mon Sep 17 00:00:00 2001 From: Daniel Tan <25474937+dtch1997@users.noreply.github.com> Date: Thu, 4 Apr 2024 16:52:34 +0100 Subject: [PATCH] feat: add functions to compute logit statistics (#145) * Add functions to compute logit statistics * Make logit statistics optional --------- Co-authored-by: Daniel CH Tan --- repepo/core/pipeline.py | 66 ++++++++++++++++++++++++++++++++++++- tests/core/test_pipeline.py | 61 +++++++++++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 2 deletions(-) diff --git a/repepo/core/pipeline.py b/repepo/core/pipeline.py index 068d4fbb..a705357d 100644 --- a/repepo/core/pipeline.py +++ b/repepo/core/pipeline.py @@ -11,9 +11,28 @@ @dataclass class TokenProb: token_id: int + # Note: the logit, logprob are for this token, not the next token logprob: float logit: float text: str + # Metrics for logits of other tokens that were in this token position + logit_mean: float = float("nan") + logit_std: float = float("nan") + logit_skew: float = float("nan") + logit_kurtosis: float = float("nan") + logit_100_quantile: float = float("nan") + logit_75_quantile: float = float("nan") + logit_50_quantile: float = float("nan") + logit_25_quantile: float = float("nan") + logit_0_quantile: float = float("nan") + + @property + def logit_max(self) -> float: + return self.logit_100_quantile + + @property + def logit_min(self) -> float: + return self.logit_0_quantile @dataclass @@ -47,6 +66,30 @@ def __call__(self, context: PipelineContext) -> AbstractContextManager[None]: ... +def compute_moments(tensor: torch.Tensor, dim: int) -> torch.Tensor: + """Compute mean, std, skew, kurtosis along the specified dimension + Input: tensor of shape (batch_size, num_classes) + Returns a tensor of shape (batch_size, 4) + """ + mean = tensor.mean(dim=dim, keepdim=True) + std = tensor.std(dim=dim, keepdim=True) + skew = ((tensor - mean) ** 3).mean(dim=dim, keepdim=True) / (std**3) + kurtosis = ((tensor - mean) ** 4).mean(dim=dim, keepdim=True) / (std**4) + return torch.cat([mean, std, skew, kurtosis], dim=dim) + + +def compute_quantiles(tensor: torch.Tensor, dim: int) -> torch.Tensor: + """Compute quantiles along the specified dimension + Input: tensor of shape (batch_size, num_classes) + Returns a tensor of shape (batch_size, num_quantiles) + """ + quantile_thresholds = torch.tensor([0, 0.25, 0.5, 0.75, 1]) + quantiles = torch.quantile(tensor, quantile_thresholds, dim=dim) + # transpose to get the shape (batch_size, num_quantiles) + quantiles = quantiles.transpose(0, 1) + return quantiles + + @dataclass class Pipeline: """Generation pipeline""" @@ -96,15 +139,25 @@ def calculate_output_logprobs(self, completion: Completion) -> TextProbs: logprobs = logprobs[:, :-1, :] # get the logprobs for the target tokens + # first, get the tokens which correspond to completions target_ids = inputs.input_ids[:, 1:].cpu() + # next, select the indices corresponding to the target token ids gen_logprobs = torch.gather(logprobs, 2, target_ids[:, :, None]).squeeze( -1 )[0] gen_logits = torch.gather(logits, 2, target_ids[:, :, None]).squeeze(-1)[0] + # For each logit, calculate the moments and quantiles + # logits is of shape (1, seq_len, vocab_size) + assert logits.shape[0] == 1 + logits = logits[0] + logit_moments = compute_moments(logits, dim=-1) + logit_quantiles = compute_quantiles(logits, dim=-1) text_probs: list[TokenProb] = [] - for token, logprob, logit in zip(target_ids[0], gen_logprobs, gen_logits): + for token, logprob, logit, logit_moment, logit_quantile in zip( + target_ids[0], gen_logprobs, gen_logits, logit_moments, logit_quantiles + ): if token not in self.tokenizer.all_special_ids: text_probs.append( TokenProb( @@ -112,6 +165,17 @@ def calculate_output_logprobs(self, completion: Completion) -> TextProbs: text=self.tokenizer.decode(token), logprob=logprob.item(), logit=logit.item(), + # moments + logit_mean=logit_moment[0].item(), + logit_std=logit_moment[1].item(), + logit_skew=logit_moment[2].item(), + logit_kurtosis=logit_moment[3].item(), + # quantiles + logit_0_quantile=logit_quantile[0].item(), + logit_25_quantile=logit_quantile[1].item(), + logit_50_quantile=logit_quantile[2].item(), + logit_75_quantile=logit_quantile[3].item(), + logit_100_quantile=logit_quantile[4].item(), ) ) return TextProbs(text=full_prompt, token_probs=text_probs) diff --git a/tests/core/test_pipeline.py b/tests/core/test_pipeline.py index da0c1676..ec883f3e 100644 --- a/tests/core/test_pipeline.py +++ b/tests/core/test_pipeline.py @@ -1,12 +1,71 @@ +import torch +import scipy +import numpy as np from textwrap import dedent from transformers import GPTNeoXForCausalLM -from repepo.core.pipeline import Pipeline +from repepo.core.pipeline import Pipeline, compute_moments, compute_quantiles from repepo.core.format import IdentityFormatter, LlamaChatFormatter from repepo.core.types import Completion, Tokenizer from syrupy.assertion import SnapshotAssertion +def _compute_moments_scipy(x: np.ndarray, axis: int) -> np.ndarray: + mean = np.mean(x, axis=axis) + std = scipy.stats.tstd(x, axis=axis, ddof=1) + skew = scipy.stats.skew(x, axis=axis) + kurtosis = scipy.stats.kurtosis(x, axis=axis, fisher=False) + return np.stack([mean, std, skew, kurtosis], axis=1) + + +def test_compute_moments_basic(): + tensor = torch.tensor( + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2.0, 2.0, 2.0, 2.0, 2.0, 2.0]] + ) + output = compute_moments(tensor, dim=1) + expected_output = torch.from_numpy(_compute_moments_scipy(tensor.numpy(), axis=1)) + # NOTE: torch kurtosis does not agree with scipy kurtosis for some reason... + # Omitted from testing for now + torch.testing.assert_allclose(output[:, :3], expected_output[:, :3]) + + +def test_compute_moments_edge_cases(): + # Test with a single-value tensor + tensor = torch.tensor([[1.0]]) + expected_output = torch.tensor([[1.0, np.nan, np.nan, np.nan]]) + output = compute_moments(tensor, dim=1) + torch.testing.assert_allclose(output[:, :3], expected_output[:, :3]) + + # Test with a tensor with uniform values + tensor = torch.full((1, 4), 3.0) + expected_output = torch.tensor([[3.0, 0.0, np.nan, np.nan]]) + output = compute_moments(tensor, dim=1) + torch.testing.assert_allclose(output[:, :3], expected_output[:, :3]) + + +def test_compute_quantiles_basic(): + tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]]) + expected_output = torch.tensor( + [[1.0, 1.75, 2.5, 3.25, 4.0], [1.0, 1.75, 2.5, 3.25, 4.0]] + ) + output = compute_quantiles(tensor, dim=1) + torch.testing.assert_allclose(output, expected_output) + + +def test_compute_quantiles_edge_cases(): + # Test with a single-value tensor + tensor = torch.tensor([[2.0]]) + expected_output = torch.tensor([[2.0, 2.0, 2.0, 2.0, 2.0]]) + output = compute_quantiles(tensor, dim=1) + torch.testing.assert_allclose(output, expected_output) + + # Test with non-unique values + tensor = torch.tensor([[2.0, 2.0, 2.0, 2.0]]) + expected_output = torch.tensor([[2.0, 2.0, 2.0, 2.0, 2.0]]) + output = compute_quantiles(tensor, dim=1) + torch.testing.assert_allclose(output, expected_output) + + def test_basic_Pipeline_build_generation_prompt( model: GPTNeoXForCausalLM, tokenizer: Tokenizer ) -> None: