Skip to content

Commit

Permalink
feat: add functions to compute logit statistics (#145)
Browse files Browse the repository at this point in the history
* Add functions to compute logit statistics

* Make logit statistics optional

---------

Co-authored-by: Daniel CH Tan <dtch1997@users.noreply.github.com>
  • Loading branch information
dtch1997 and dtch1997 authored Apr 4, 2024
1 parent 037bea3 commit c447864
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 2 deletions.
66 changes: 65 additions & 1 deletion repepo/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,28 @@
@dataclass
class TokenProb:
token_id: int
# Note: the logit, logprob are for this token, not the next token
logprob: float
logit: float
text: str
# Metrics for logits of other tokens that were in this token position
logit_mean: float = float("nan")
logit_std: float = float("nan")
logit_skew: float = float("nan")
logit_kurtosis: float = float("nan")
logit_100_quantile: float = float("nan")
logit_75_quantile: float = float("nan")
logit_50_quantile: float = float("nan")
logit_25_quantile: float = float("nan")
logit_0_quantile: float = float("nan")

@property
def logit_max(self) -> float:
return self.logit_100_quantile

@property
def logit_min(self) -> float:
return self.logit_0_quantile


@dataclass
Expand Down Expand Up @@ -47,6 +66,30 @@ def __call__(self, context: PipelineContext) -> AbstractContextManager[None]:
...


def compute_moments(tensor: torch.Tensor, dim: int) -> torch.Tensor:
"""Compute mean, std, skew, kurtosis along the specified dimension
Input: tensor of shape (batch_size, num_classes)
Returns a tensor of shape (batch_size, 4)
"""
mean = tensor.mean(dim=dim, keepdim=True)
std = tensor.std(dim=dim, keepdim=True)
skew = ((tensor - mean) ** 3).mean(dim=dim, keepdim=True) / (std**3)
kurtosis = ((tensor - mean) ** 4).mean(dim=dim, keepdim=True) / (std**4)
return torch.cat([mean, std, skew, kurtosis], dim=dim)


def compute_quantiles(tensor: torch.Tensor, dim: int) -> torch.Tensor:
"""Compute quantiles along the specified dimension
Input: tensor of shape (batch_size, num_classes)
Returns a tensor of shape (batch_size, num_quantiles)
"""
quantile_thresholds = torch.tensor([0, 0.25, 0.5, 0.75, 1])
quantiles = torch.quantile(tensor, quantile_thresholds, dim=dim)
# transpose to get the shape (batch_size, num_quantiles)
quantiles = quantiles.transpose(0, 1)
return quantiles


@dataclass
class Pipeline:
"""Generation pipeline"""
Expand Down Expand Up @@ -96,22 +139,43 @@ def calculate_output_logprobs(self, completion: Completion) -> TextProbs:
logprobs = logprobs[:, :-1, :]

# get the logprobs for the target tokens
# first, get the tokens which correspond to completions
target_ids = inputs.input_ids[:, 1:].cpu()
# next, select the indices corresponding to the target token ids
gen_logprobs = torch.gather(logprobs, 2, target_ids[:, :, None]).squeeze(
-1
)[0]
gen_logits = torch.gather(logits, 2, target_ids[:, :, None]).squeeze(-1)[0]

# For each logit, calculate the moments and quantiles
# logits is of shape (1, seq_len, vocab_size)
assert logits.shape[0] == 1
logits = logits[0]
logit_moments = compute_moments(logits, dim=-1)
logit_quantiles = compute_quantiles(logits, dim=-1)
text_probs: list[TokenProb] = []

for token, logprob, logit in zip(target_ids[0], gen_logprobs, gen_logits):
for token, logprob, logit, logit_moment, logit_quantile in zip(
target_ids[0], gen_logprobs, gen_logits, logit_moments, logit_quantiles
):
if token not in self.tokenizer.all_special_ids:
text_probs.append(
TokenProb(
token_id=token.item(),
text=self.tokenizer.decode(token),
logprob=logprob.item(),
logit=logit.item(),
# moments
logit_mean=logit_moment[0].item(),
logit_std=logit_moment[1].item(),
logit_skew=logit_moment[2].item(),
logit_kurtosis=logit_moment[3].item(),
# quantiles
logit_0_quantile=logit_quantile[0].item(),
logit_25_quantile=logit_quantile[1].item(),
logit_50_quantile=logit_quantile[2].item(),
logit_75_quantile=logit_quantile[3].item(),
logit_100_quantile=logit_quantile[4].item(),
)
)
return TextProbs(text=full_prompt, token_probs=text_probs)
Expand Down
61 changes: 60 additions & 1 deletion tests/core/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,71 @@
import torch
import scipy
import numpy as np
from textwrap import dedent
from transformers import GPTNeoXForCausalLM

from repepo.core.pipeline import Pipeline
from repepo.core.pipeline import Pipeline, compute_moments, compute_quantiles
from repepo.core.format import IdentityFormatter, LlamaChatFormatter
from repepo.core.types import Completion, Tokenizer
from syrupy.assertion import SnapshotAssertion


def _compute_moments_scipy(x: np.ndarray, axis: int) -> np.ndarray:
mean = np.mean(x, axis=axis)
std = scipy.stats.tstd(x, axis=axis, ddof=1)
skew = scipy.stats.skew(x, axis=axis)
kurtosis = scipy.stats.kurtosis(x, axis=axis, fisher=False)
return np.stack([mean, std, skew, kurtosis], axis=1)


def test_compute_moments_basic():
tensor = torch.tensor(
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2.0, 2.0, 2.0, 2.0, 2.0, 2.0]]
)
output = compute_moments(tensor, dim=1)
expected_output = torch.from_numpy(_compute_moments_scipy(tensor.numpy(), axis=1))
# NOTE: torch kurtosis does not agree with scipy kurtosis for some reason...
# Omitted from testing for now
torch.testing.assert_allclose(output[:, :3], expected_output[:, :3])


def test_compute_moments_edge_cases():
# Test with a single-value tensor
tensor = torch.tensor([[1.0]])
expected_output = torch.tensor([[1.0, np.nan, np.nan, np.nan]])
output = compute_moments(tensor, dim=1)
torch.testing.assert_allclose(output[:, :3], expected_output[:, :3])

# Test with a tensor with uniform values
tensor = torch.full((1, 4), 3.0)
expected_output = torch.tensor([[3.0, 0.0, np.nan, np.nan]])
output = compute_moments(tensor, dim=1)
torch.testing.assert_allclose(output[:, :3], expected_output[:, :3])


def test_compute_quantiles_basic():
tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]])
expected_output = torch.tensor(
[[1.0, 1.75, 2.5, 3.25, 4.0], [1.0, 1.75, 2.5, 3.25, 4.0]]
)
output = compute_quantiles(tensor, dim=1)
torch.testing.assert_allclose(output, expected_output)


def test_compute_quantiles_edge_cases():
# Test with a single-value tensor
tensor = torch.tensor([[2.0]])
expected_output = torch.tensor([[2.0, 2.0, 2.0, 2.0, 2.0]])
output = compute_quantiles(tensor, dim=1)
torch.testing.assert_allclose(output, expected_output)

# Test with non-unique values
tensor = torch.tensor([[2.0, 2.0, 2.0, 2.0]])
expected_output = torch.tensor([[2.0, 2.0, 2.0, 2.0, 2.0]])
output = compute_quantiles(tensor, dim=1)
torch.testing.assert_allclose(output, expected_output)


def test_basic_Pipeline_build_generation_prompt(
model: GPTNeoXForCausalLM, tokenizer: Tokenizer
) -> None:
Expand Down

0 comments on commit c447864

Please sign in to comment.