Add optional gate activation histogram logging during eval #641

Open · wants to merge 9 commits into base: main
src/levanter/callbacks.py (2 additions, 0 deletions)
@@ -47,6 +47,8 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n
load_time = time.time() - time_in
total_load_time += load_time
loss = loss_fn(model, batch)
if isinstance(loss, tuple):
loss, _ = loss
total_loss += loss.item()
n += 1
loss_time = time.time() - time_in - load_time
src/levanter/eval.py (29 additions, 7 deletions)
@@ -15,6 +15,7 @@
from levanter.data import Dataset, ReplicatedBatchLoader
from levanter.logging import LoadingTimeTrackerIterator
from levanter.models.lm_model import LmExample, LmHeadModel
from levanter.tracker.histograms import NBINS
from levanter.trainer import StepInfo
from levanter.utils.stat_utils import RunningMean
from levanter.utils.tree_utils import inference_mode
@@ -34,6 +35,7 @@ class EvalResult:
tag_macro_losses: dict[str, float] # per tag average-per-token loss
tag_micro_losses: dict[str, float] # per tag total loss, for "parent" tags
total_eval_loading_time: float
extras: dict[str, float]


class DomainTaggedDataset(Dataset[tuple[T, hax.NamedArray]]):
@@ -123,6 +125,17 @@ def eval_callback(step: StepInfo):
_join_prefix(prefix, "loading_time"): result.total_eval_loading_time,
_join_prefix(prefix, "total_time"): time_fn(),
}
if (gate_hist := result.extras.get("gate_hist", None)) is not None:
Member commented:

so i think i'm gonna have a strong preference for

  1. extracting this block (and the part in the loop) into a class (sort of like runningmean)
  2. not actually checking the usage of it in taggedevaluator (or in the models) into main, but instead
  3. making a little guide on how to add it in, since it's something that people want to play with sometimes but kinda adds a bunch of noise

pos_idx = NBINS // 2 + 1
log_dict[_join_prefix(prefix, "gate_hist/all")] = np.array(gate_hist.sum(axis=0))
num_gt0 = gate_hist[:, pos_idx:].sum().item()
total = gate_hist.sum().item()
log_dict[_join_prefix(prefix, "gate_gt0/all")] = num_gt0 / total
for i in range(gate_hist.shape[0]):
log_dict[_join_prefix(prefix, f"gate_hist/layer{i+1}")] = np.array(gate_hist[i])
num_gt0 = gate_hist[i, pos_idx:].sum().item()
total = gate_hist[i].sum().item()
log_dict[_join_prefix(prefix, f"gate_gt0/layer{i+1}")] = num_gt0 / total

logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}")
for tag, loss in result.tag_macro_losses.items():
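A minimal sketch of the helper the reviewer suggests above, in the spirit of RunningMean. The class name GateHistAccumulator, its methods, and the nbins argument are hypothetical and not part of this PR; it assumes the same bin layout as the block above, where NBINS // 2 + 1 is the first bin strictly above zero. It is also a host-side (numpy) sketch: a drop-in for the jitted accumulation in TaggedEvaluator would need to be a pytree, like RunningMean.

```python
import numpy as np


class GateHistAccumulator:
    """Accumulates per-layer gate-activation histograms across eval batches."""

    def __init__(self, num_layers: int, nbins: int):
        self.counts = np.zeros((num_layers, nbins), dtype=np.int64)
        self.nbins = nbins

    def add(self, batch_hist) -> None:
        # batch_hist: [num_layers, nbins] integer counts from one eval batch
        self.counts += np.asarray(batch_hist)

    def to_log_dict(self, prefix: str) -> dict:
        pos_idx = self.nbins // 2 + 1  # first bin strictly above zero, as in the block above
        total = max(int(self.counts.sum()), 1)
        log = {
            f"{prefix}/gate_hist/all": self.counts.sum(axis=0),
            f"{prefix}/gate_gt0/all": int(self.counts[:, pos_idx:].sum()) / total,
        }
        for i, layer_counts in enumerate(self.counts):
            layer_total = max(int(layer_counts.sum()), 1)
            log[f"{prefix}/gate_hist/layer{i + 1}"] = layer_counts
            log[f"{prefix}/gate_gt0/layer{i + 1}"] = int(layer_counts[pos_idx:].sum()) / layer_total
        return log
```

With something along these lines, the histogram block above would reduce to a single log_dict.update(...) call, and the per-batch accumulation to one accum.add(...) call.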
@@ -185,12 +198,12 @@ def __init__(

@hax.named_jit(out_axis_resources=axis_mapping)
def accum_for_batch(
m: LmHeadModel, state: tuple[RunningMean, RunningMean], batch: LmExample, tags: hax.NamedArray
m: LmHeadModel, state: tuple[RunningMean, RunningMean, dict], batch: LmExample, tags: hax.NamedArray
):
m = inference_mode(m, True)
with hax.axis_mapping(axis_mapping):
total_mean, mean_per_tag = state
losses = m.compute_loss(batch, reduction=None, reduction_axis=())
total_mean, mean_per_tag, total_extras = state
losses, extras = m.compute_loss(batch, reduction=None, reduction_axis=())
mask = batch.loss_mask # [Batch, Token]
this_tokens = hax.einsum("->", mask)
this_loss = hax.einsum("->", losses, mask) # to scalar
@@ -203,23 +216,32 @@ def accum_for_batch(
safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
mean_per_tag = mean_per_tag.add(safe_mean, this_tokens_per_tag)

return mean, mean_per_tag
if extras:
for key in extras:
curr = total_extras.get(key, jnp.zeros_like(extras[key]))
total_extras[key] = extras[key] + curr
Member commented:

is summing always going to be the right reduction here?


return mean, mean_per_tag, total_extras

self.accum_for_batch = accum_for_batch

def evaluate(self, m: LmHeadModel):
total_loss = jnp.zeros(())
mean_losses_per_tag = hax.zeros(self.dataset.Tag, dtype=np.float32)

state = (RunningMean.zeros_like(total_loss), RunningMean.zeros_like(mean_losses_per_tag))
state: tuple[RunningMean, RunningMean, dict] = (
RunningMean.zeros_like(total_loss),
RunningMean.zeros_like(mean_losses_per_tag),
{},
)
state = hax.shard(state)

iterator = LoadingTimeTrackerIterator(self.loader)

for batch, tags in tqdm.tqdm(iterator, "eval"):
state = self.accum_for_batch(m, state, batch, tags)

total_loss, losses_per_tag = state
total_loss, losses_per_tag, extras = state

micro_avg_loss = total_loss.mean.item()
tag_avg_loss = losses_per_tag.mean
@@ -252,4 +274,4 @@ def evaluate(self, m: LmHeadModel):
tag_micro_loss[tag] = mean_loss_per_tag_cpu[index]
# no macro loss for the leaf tags

return EvalResult(micro_avg_loss, macro_avg_loss, tag_macro_loss, tag_micro_loss, iterator.total_time)
return EvalResult(micro_avg_loss, macro_avg_loss, tag_macro_loss, tag_micro_loss, iterator.total_time, extras)
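On the reviewer's question above about whether summing is always the right reduction: one option is to let each extras key declare its own reduction, with summing as the default. This is a hypothetical sketch, not part of this PR; histogram counts do sum correctly, but a key tracking, say, a per-batch maximum would not.

```python
import jax.numpy as jnp

# Hypothetical per-key reductions; summing stays the default.
EXTRA_REDUCTIONS = {
    "gate_hist": lambda acc, new: acc + new,             # counts: summing is correct
    "gate_max": lambda acc, new: jnp.maximum(acc, new),  # extrema: keep the max instead
}


def accumulate_extras(total_extras: dict, extras: dict) -> dict:
    out = dict(total_extras)
    for key, value in extras.items():
        if key in out:
            reduce_fn = EXTRA_REDUCTIONS.get(key, lambda acc, new: acc + new)
            out[key] = reduce_fn(out[key], value)
        else:
            out[key] = value
    return out
```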
src/levanter/main/viz_logprobs.py (1 addition, 1 deletion)
@@ -72,7 +72,7 @@ def main(config: VizGpt2Config):
def compute_log_probs(model: LmHeadModel, example: LmExample):
model = inference_mode(model, True)
model = mp.cast_to_compute(model)
logprobs = model.compute_loss(example, reduction=None)
logprobs, _ = model.compute_loss(example, reduction=None)
# roll forward to get the loss for each predicted token
logprobs = hax.roll(logprobs, 1, Pos)
return logprobs.rearrange((EvalBatch, Pos)).array
src/levanter/models/asr_model.py (2 additions, 2 deletions)
@@ -105,15 +105,15 @@ def compute_loss(
across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
reduced, and the result is a named array with axes (*batch axes, sequence_length).
"""
logits = self(example.audio, example.tokens, example.attn_mask, key=key)
logits, extras = self(example.audio, example.tokens, example.attn_mask, key=key)
logits = logits.astype(jnp.float32)
targets = hax.roll(example.tokens, -1, axis=self.Pos.name)
target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype)
loss = cross_entropy_loss(
logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
)

return loss
return loss, extras

@property
def vocab_size(self) -> int:
src/levanter/models/gemma.py (15 additions, 9 deletions)
@@ -30,6 +30,7 @@
from levanter.models.lm_model import LmConfig, LmHeadModel
from levanter.types import BlockFoldable
from levanter.utils.flop_utils import lm_flops_per_token
from levanter.utils.py_utils import cached_classproperty


silence_transformer_nag()
@@ -75,7 +76,6 @@ class GemmaConfig(HFCompatConfig):
vocab_size: int = 256_000
num_layers: int = 18
num_heads: int = 8
head_dim: int = 256
num_kv_heads: int = 1
attn_dropout = 0.0
norm_eps = 1e-6
@@ -106,10 +106,14 @@ class GemmaConfig(HFCompatConfig):
Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim))
HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads))

@property
def head_dim(self) -> int: return self.hidden_dim // self.num_heads

def __post_init__(self):
assert (
self.num_heads % self.num_kv_heads == 0
), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}."
assert (self.head_dim * self.num_heads) == self.hidden_dim, "head_dim * num_heads must equal hidden_dim."

def hf_checkpoint_converter(self) -> HFCheckpointConverter["GemmaConfig"]: # type: ignore
return HFCheckpointConverter(
@@ -129,7 +133,9 @@ def from_hf_config(cls, hf_config: HfConfig):
if hf_config.hidden_activation:
activation_function = hf_config.hidden_activation
else:
activation_function = hf_config.hidden_act
# This is the implementation in huggingface
# https://github.com/huggingface/transformers/blob/12b1620e615592fbf099d4ec44af7b9f2d1b48aa/src/transformers/models/gemma/modeling_gemma.py#L200
activation_function = "gelu_pytorch_tanh"
Member commented:

i swore we already did this


if activation_function == "gelu_pytorch_tanh":
activation_function = "gelu_new"
@@ -168,7 +174,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
num_hidden_layers=self.num_layers,
num_attention_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
head_dim=self.hidden_dim // self.num_heads,
head_dim=self.head_dim,
hidden_activation=(
"gelu_pytorch_tanh" if self.activation_function == "gelu_new" else self.activation_function
),
@@ -263,9 +269,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
# MLP and skip connection
residual = x
x = self.post_attention_layernorm(x)
mlp_output = self.mlp(x, key=k_mlp)
mlp_output, extras = self.mlp(x, key=k_mlp)
output = residual + mlp_output
return output
return output, extras


class GemmaTransformer(StateDictSerializationMixin, eqx.Module):
@@ -292,10 +298,10 @@ def init(config: GemmaConfig, *, key) -> "GemmaTransformer":
@named_call
def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray | AttentionMask], *, key) -> NamedArray:
keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None
x = self.layers.fold(x, mask=attn_mask, key=keys)
x, extras = self.layers.scan(x, mask=attn_mask, key=keys)
x = self.norm(x)

return x
return x, extras

def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
if isinstance(self.layers, Stacked):
@@ -358,9 +364,9 @@ def __call__(
The attn_mask from training pipeline may be an AttentionMask object instead of NamedArray
"""
x = self.embeddings.embed(input_ids)
x = self.transformer(x, attn_mask=attn_mask, key=key)
x, extras = self.transformer(x, attn_mask=attn_mask, key=key)
lm_logits = self.embeddings.unembed(x)
return lm_logits
return lm_logits, extras

def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[GemmaConfig]":
new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
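For context on the fold-to-scan change in the transformers in this PR: fold threads only the carry through the stacked layers, while scan additionally stacks each layer's second output along a leading layer axis, which is what gives gate_hist its [num_layers, nbins] shape. Below is a rough stand-in using plain jax.lax.scan rather than haliax's Stacked; the toy layer and shapes are illustrative only, not levanter code.

```python
import jax
import jax.numpy as jnp


def layer(x, _):
    hidden = x * 2.0  # stand-in for the MLP gate projection
    hist = jnp.histogram(hidden, bins=8, range=(-4.0, 4.0))[0]  # per-layer "extras"
    return hidden, hist  # (carry, per-layer output)


x0 = jnp.linspace(-1.0, 1.0, 16)
x_final, gate_hists = jax.lax.scan(layer, x0, None, length=4)
print(gate_hists.shape)  # (4, 8): one histogram per layer, like gate_hist above
```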
src/levanter/models/gpt2.py (1 addition, 0 deletions)
@@ -29,6 +29,7 @@
from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention
from levanter.models.lm_model import LmConfig
from levanter.utils.flop_utils import lm_flops_per_token
from levanter.utils.py_utils import cached_classproperty


silence_transformer_nag()
src/levanter/models/llama.py (24 additions, 9 deletions)
@@ -30,6 +30,8 @@
from levanter.models.lm_model import LmConfig, LmHeadModel
from levanter.types import BlockFoldable
from levanter.utils.flop_utils import lm_flops_per_token
from levanter.utils.py_utils import cached_classproperty
from levanter.tracker.histograms import get_bins, sharded_histogram


silence_transformer_nag()
@@ -78,6 +80,7 @@ class LlamaConfig(HFCompatConfig):
use_bias: bool = False
use_layer_norm_weight: bool = True
rope_scaling: Optional[dict] = None
measure_act_stats: bool = True

reference_checkpoint: str = "meta-llama/Llama-2-7b-hf"
tokenizer: Optional[str] = None
@@ -181,10 +184,17 @@ class LlamaMlp(eqx.Module, StateDictSerializationMixin):
up_proj: hnn.Linear # projection from Embed to Mlp
down_proj: hnn.Linear # projection from Mlp to Embed
act: Callable = eqx.static_field()
measure_act_stats: bool = False

@staticmethod
def init(
Embed: Axis, Mlp: Axis, activation_fn: Union[str, Callable], *, key, use_bias: bool = False
Embed: Axis,
Mlp: Axis,
activation_fn: Union[str, Callable],
*,
key,
use_bias: bool = False,
measure_act_stats=False,
) -> "LlamaMlp":
k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3)
gate_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias, out_first=True)
@@ -193,16 +203,20 @@ def init(
if isinstance(activation_fn, str):
activation_fn = ACT2FN[activation_fn]
act = activation_fn # type: ignore
return LlamaMlp(gate_proj, up_proj, down_proj, act)
get_bins() # initialize bins
Member commented:

rm?

return LlamaMlp(gate_proj, up_proj, down_proj, act, measure_act_stats)

@named_call
def __call__(self, x: NamedArray, *, key=None) -> NamedArray:
k_gate, k_up, k_down = maybe_rng_split(key, 3)
hidden_states = self.gate_proj(x, key=k_gate)
extras = {}
if self.measure_act_stats:
extras["gate_hist"] = sharded_histogram(hidden_states.array, bins=get_bins())
hidden_states = self.act(hidden_states)
hidden_states = hidden_states * self.up_proj(x, key=k_up)
outputs = self.down_proj(hidden_states, key=k_down)
return outputs
return outputs, extras

def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
# unflatten the linear layers of HF state_dict to match the shape of LlamaMlp
@@ -402,6 +416,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer":
config.activation_function,
key=k_mlp,
use_bias=config.use_bias,
measure_act_stats=config.measure_act_stats,
)
ln_1 = config.mk_LayerNorm(config.Embed)
ln_2 = config.mk_LayerNorm(config.Embed)
@@ -420,9 +435,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
# MLP and skip connection
residual = x
x = self.post_attention_layernorm(x)
mlp_output = self.mlp(x, key=k_mlp)
mlp_output, extras = self.mlp(x, key=k_mlp)
output = residual + mlp_output
return output
return output, extras


class LlamaTransformer(StateDictSerializationMixin, eqx.Module):
@@ -449,10 +464,10 @@ def init(config: LlamaConfig, *, key) -> "LlamaTransformer":
@named_call
def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray | AttentionMask], *, key) -> NamedArray:
keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None
x = self.layers.fold(x, mask=attn_mask, key=keys)
x, extras = self.layers.scan(x, mask=attn_mask, key=keys)
x = self.norm(x)

return x
return x, extras

def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
if isinstance(self.layers, Stacked):
@@ -544,9 +559,9 @@ def __call__(
"""
k_t, k_head = maybe_rng_split(key, 2)
x = self.embeddings.embed(input_ids)
x = self.transformer(x, attn_mask=attn_mask, key=k_t)
x, extras = self.transformer(x, attn_mask=attn_mask, key=k_t)
lm_logits = self.lm_head(x, key=k_head)
return lm_logits
return lm_logits, extras

def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
new_Vocab = self.Vocab.resize(new_size)
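A rough usage sketch for the new measure_act_stats flag; it assumes only the names introduced in this PR and is not a tested recipe.

```python
from levanter.models.llama import LlamaConfig

# measure_act_stats defaults to True in this PR; shown explicitly for clarity.
config = LlamaConfig(measure_act_stats=True)

# With the flag on, each LlamaMlp call returns (outputs, extras), where
# extras["gate_hist"] is a histogram of the gate_proj pre-activations.
# TaggedEvaluator sums these over eval batches, and eval_callback logs them
# as {prefix}/gate_hist/{all,layerN} and {prefix}/gate_gt0/{all,layerN}.
```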
src/levanter/models/lm_model.py (5 additions, 1 deletion)
@@ -127,6 +127,10 @@ def compute_loss(
reduced, and the result is a named array with axes (*batch axes, sequence_length).
"""
logits = self(example.tokens, example.attn_mask, key=key)
extras = None
if isinstance(logits, tuple):
assert len(logits) == 2
logits, extras = logits
# TODO: would be nice if we made the dtype configurable
logits = logits.astype(jnp.float32)
targets = hax.roll(example.tokens, -1, axis=self.Pos.name)
@@ -135,7 +139,7 @@
logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
)

return loss
return loss, extras

@property
def vocab_size(self) -> int:
src/levanter/models/mistral.py (3 additions, 2 deletions)
@@ -23,6 +23,7 @@
from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaTransformer
from levanter.models.lm_model import LmConfig, LmHeadModel
from levanter.utils.flop_utils import lm_flops_per_token
from levanter.utils.py_utils import cached_classproperty


silence_transformer_nag()
@@ -192,9 +193,9 @@ def __call__(
"""
k_t, k_head = maybe_rng_split(key, 2)
x = self.embeddings.embed(input_ids)
x = self.transformer(x, attn_mask=attn_mask, key=k_t)
x, extras = self.transformer(x, attn_mask=attn_mask, key=k_t)
lm_logits = self.lm_head(x, key=k_head)
return lm_logits
return lm_logits, extras

def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[MistralConfig]":
new_Vocab = self.Vocab.resize(new_size)