Merged
Commits (130)
552179c
Mixed distillation loss
jlamypoirier Jul 3, 2025
201417c
fix
jlamypoirier Jul 3, 2025
88d7b3d
misc
jlamypoirier Jul 3, 2025
7d9d5f8
wip
oleksost Jul 3, 2025
41b71b6
wip
oleksost Jul 3, 2025
1fe02b3
reverse kl and loss factors
oleksost Jul 4, 2025
7c09e67
hack to load model with attention_factor none
oleksost Jul 4, 2025
cabcf40
reverse kl etc.
oleksost Jul 4, 2025
1413c1e
removed notebook
oleksost Jul 4, 2025
f6d84cd
loss defs
oleksost Jul 4, 2025
892e2aa
nvm
oleksost Jul 4, 2025
6018a7e
addressed comments
oleksost Jul 9, 2025
c2509a5
Merge branch 'main' into mixed_distillation_loss
oleksost Jul 11, 2025
177155a
wip
oleksost Jul 11, 2025
8e5c5b0
Merge branch 'mixed_distillation_loss' into m1mamba
oleksost Jul 11, 2025
8d94390
bug
oleksost Jul 11, 2025
bd21b74
small bug
oleksost Jul 11, 2025
2012a23
Merge branch 'mixed_distillation_loss' of https://github.com/ServiceN…
oleksost Jul 11, 2025
4e83075
wip
oleksost Jul 11, 2025
59f42e3
wip
oleksost Jul 11, 2025
73fe122
wip
oleksost Jul 11, 2025
424bc94
wip
oleksost Jul 11, 2025
264bd50
wip
oleksost Jul 11, 2025
096ad63
missing hps in checkpoint
oleksost Jul 11, 2025
3579dbe
wip
oleksost Jul 11, 2025
656cae3
loss type
oleksost Jul 11, 2025
8146958
convertion m2
oleksost Jul 12, 2025
dbb5334
interative training
oleksost Jul 14, 2025
75c3a52
convertion with mil
oleksost Jul 14, 2025
235c018
save correct config
oleksost Jul 14, 2025
1575e90
save config
oleksost Jul 14, 2025
9e86c5e
click
oleksost Jul 14, 2025
7a546a5
comments
oleksost Jul 14, 2025
dedb838
seperate DistillationLossImpl
oleksost Jul 14, 2025
f5acafb
test
oleksost Jul 14, 2025
c9b65c0
Merge branch 'mixed_distillation_loss' into mamba2
oleksost Jul 14, 2025
ddd27eb
wip
oleksost Jul 14, 2025
5afaac2
make hybrid
oleksost Jul 15, 2025
9d3ef2e
nvm
oleksost Jul 15, 2025
0a21503
eval wrapper
oleksost Jul 15, 2025
446138c
modelling bug
oleksost Jul 15, 2025
0fec714
modeling
oleksost Jul 15, 2025
be5da95
clean modeling
oleksost Jul 16, 2025
879bed2
corrected modeling
oleksost Jul 16, 2025
73a70ec
Merge branch 'main' into mamba2
oleksost Jul 16, 2025
9a88157
removed if __name__
oleksost Jul 16, 2025
c7d21ee
clean
oleksost Jul 16, 2025
89ed59d
clean
oleksost Jul 16, 2025
d71494e
clean
oleksost Jul 16, 2025
0043ca2
convertion clean up
oleksost Jul 16, 2025
aa5dce3
import trys
oleksost Jul 16, 2025
7871390
clean up
oleksost Jul 16, 2025
e0a9d41
Merge branch 'main' into mamba2
oleksost Jul 16, 2025
c311e9b
notebook
oleksost Jul 16, 2025
eaca220
test
oleksost Jul 16, 2025
ceaa76c
wip
oleksost Jul 16, 2025
976b010
wip
oleksost Jul 16, 2025
0e4c562
comments
oleksost Jul 16, 2025
2db0b8f
comments
oleksost Jul 16, 2025
81f08bc
clean
oleksost Jul 16, 2025
88ac516
nvm
oleksost Jul 16, 2025
037ee29
rename dim name
oleksost Jul 16, 2025
5661fd2
clean
oleksost Jul 16, 2025
59366e6
minor
oleksost Jul 17, 2025
d71b10f
conv innit bug
oleksost Jul 18, 2025
0a0c0dd
wip
oleksost Jul 18, 2025
fb3db80
Merge branch 'mamba2' of https://github.com/ServiceNow/Fast-LLM into …
oleksost Jul 18, 2025
83ef232
Merge branch 'main' into mamba2
oleksost Jul 18, 2025
bb9095e
notebooks
oleksost Jul 18, 2025
c89816d
Support for forward, generate and lm_eval for ssm models (#327)
bigximik Jul 25, 2025
b605bd2
doc
jlamypoirier Jul 25, 2025
7ea6d47
Merge remote-tracking branch 'origin/tp_mamba' into mamba2
oleksost Jul 28, 2025
5eea938
fix
jlamypoirier Jul 28, 2025
0a3e2a7
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
2e6d082
fixes
jlamypoirier Jul 28, 2025
b6c8613
misc
jlamypoirier Jul 28, 2025
f0c04cf
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Jul 28, 2025
acdfab1
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
e536af9
Concatenated dim
jlamypoirier Jul 28, 2025
017f5cc
fixes
jlamypoirier Jul 28, 2025
93e4c94
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 28, 2025
c41efc2
doc
jlamypoirier Jul 28, 2025
0b8bd5d
cleanup
jlamypoirier Jul 28, 2025
c8756c7
modeling
oleksost Jul 29, 2025
87c7db8
TP
oleksost Jul 29, 2025
f20f41e
TP
oleksost Jul 29, 2025
8ff6291
TP
oleksost Jul 29, 2025
6bf06d6
fix
jlamypoirier Jul 29, 2025
2ddc3a7
fix
jlamypoirier Jul 29, 2025
c0f1597
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 29, 2025
a35b8d9
modelling
oleksost Jul 30, 2025
87236ac
wip
oleksost Jul 30, 2025
4385136
modeling
oleksost Jul 30, 2025
288a955
wip
oleksost Jul 30, 2025
cef7c15
fix
jlamypoirier Jul 30, 2025
e3c6773
sequence tensor parallel
oleksost Jul 30, 2025
551b6ea
distributed kl
oleksost Jul 30, 2025
f7f30e1
wip
oleksost Jul 30, 2025
442f1a8
reverse kl distributed
oleksost Jul 30, 2025
af2961e
roll back reverse kl
oleksost Jul 30, 2025
7359700
kl test
oleksost Jul 30, 2025
3d660d8
undo cache change
oleksost Jul 31, 2025
8709792
ce remove clamp, add log softmax distributed
oleksost Jul 31, 2025
fb72cc9
clean up
oleksost Jul 31, 2025
05daed7
clean
oleksost Jul 31, 2025
f841388
old modeling with new cache
oleksost Jul 31, 2025
a62cbfb
new TP modeling but no cache
oleksost Jul 31, 2025
1032355
new TP modeling but no cache
oleksost Jul 31, 2025
481f0c5
remmoved generation requirements
oleksost Aug 1, 2025
301d25b
remmoved generation requirements
oleksost Aug 1, 2025
dfd2451
removed clamping in the test
oleksost Aug 1, 2025
8db1e62
Merge branch 'tp_mamba' into mamba2
oleksost Aug 1, 2025
1d4eabc
final modeling
oleksost Aug 1, 2025
c261805
modeling removed cache
oleksost Aug 2, 2025
6859d17
wip
oleksost Aug 4, 2025
88a5e8b
sequence tensor parallel
oleksost Aug 4, 2025
58e3e95
wip
oleksost Aug 6, 2025
882197d
fix transformer version
oleksost Aug 6, 2025
967669e
undo transformer version fix
oleksost Aug 6, 2025
94e090d
loss masking for tp distillation
oleksost Aug 6, 2025
11bcf11
fix transformers version
oleksost Aug 7, 2025
d004caf
reverse kl loss mask
oleksost Aug 7, 2025
1b55fa8
clamping
oleksost Aug 7, 2025
1bb50e1
use inference-runner corresponding to reference model
RaymondLi0 Aug 7, 2025
a8b5f28
clean
oleksost Aug 7, 2025
6bdd684
clean + revert reverse kl masked mean
oleksost Aug 7, 2025
069737b
revert inference runner
oleksost Aug 7, 2025
c802243
Merge remote-tracking branch 'origin/raymond/ref_model_inference_runn…
oleksost Aug 7, 2025
906aadc
modelling
oleksost Aug 7, 2025
0012511
Merge branch 'hybrid_dev' into mamba2_to_be_merged
oleksost Aug 7, 2025
49 changes: 43 additions & 6 deletions fast_llm/engine/config_utils/tensor_space.py
@@ -66,13 +66,23 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self:
return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim)

def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor":
if self.parallel_group is not None:
if self.is_parallel:
from fast_llm.core.ops import gather_op

return gather_op(tensor, self.parallel_group, dim)
else:
return tensor

def local_to_global_partial(
self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1
) -> "torch.Tensor":
if self.is_parallel:
output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value)
output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim)
return output.flatten(dim, dim + 1)
else:
return tensor

def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
return (
tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank]
@@ -85,7 +95,7 @@ class CompositeTensorDim(TensorDim):
def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]):
parallel_dim = None
for dim, tensor_dim in enumerate(tensor_dims):
if tensor_dim.is_parallel:
if tensor_dim.parallel_dim is not None:
# TODO: Allow more than one parallel subdim?
assert parallel_dim is None
parallel_dim = tensor_dim.parallel_dim
Expand All @@ -111,6 +121,15 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor

return tensor.flatten(dim, dim + len(self._tensor_dims) - 1)

def local_to_global_partial(
self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1
) -> "torch.Tensor":
tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims])
for i, tensor_dim in enumerate(self._tensor_dims):
tensor = tensor_dim.local_to_global_partial(tensor, dim + i)

return tensor.flatten(dim, dim + len(self._tensor_dims) - 1)

def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims])
for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))):
@@ -157,6 +176,27 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor
else tensor
)

def local_to_global_partial(
self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1
) -> "torch.Tensor":
import torch

return (
torch.concatenate(
[
tensor_dim.local_to_global_partial(tensor_, dim)
for tensor_, tensor_dim in zip(
tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim),
self._tensor_dims,
strict=True,
)
],
dim,
)
if self.is_parallel
else tensor
)

def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
if self.is_parallel and expand:
raise NotImplementedError()
@@ -223,8 +263,5 @@ def add_tensor_dim(self, tensor_dim: TensorDim) -> None:
)
self._tensor_dims[tensor_dim.name] = tensor_dim

def get_tensor_dim(self, name: str) -> TensorDim:
def __getitem__(self, name: str) -> TensorDim:
return self._tensor_dims[name]

# TODO: Replace uses
__getitem__ = get_tensor_dim
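The new `local_to_global_partial` places a rank's local slice into a tensor of global size without a gather: every other rank's slot is filled with `fill_value`. Below is a minimal standalone sketch of the same logic, with the rank and group size passed in as plain integers instead of being read from a `parallel_dim`; the function name and example values are illustrative, not the Fast-LLM API.

```python
import torch

def local_to_global_partial_sketch(
    local: torch.Tensor, rank: int, world_size: int, dim: int = 0, fill_value: int = -1
) -> torch.Tensor:
    # Allocate a (world_size x local) block along `dim`, filled with `fill_value`.
    out = local.new_full((*local.shape[:dim], world_size, *local.shape[dim:]), fill_value)
    # Write this rank's slice into its slot; the other ranks' slots keep `fill_value`.
    out.narrow(dim, rank, 1).copy_(local.unsqueeze(dim))
    # Merge the inserted group dimension back into `dim`.
    return out.flatten(dim, dim + 1)

# Example: rank 1 of 3 holds a local slice of size 2 along dim 0.
print(local_to_global_partial_sketch(torch.tensor([10, 11]), rank=1, world_size=3))
# tensor([-1, -1, 10, 11, -1, -1])
```

The same file also replaces `get_tensor_dim(name)` with `__getitem__`, so later diffs in this PR index the tensor space directly, e.g. `tensor_space[TransformerDimNames.hidden]`.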
5 changes: 5 additions & 0 deletions fast_llm/engine/multi_stage/config.py
@@ -31,6 +31,7 @@

if typing.TYPE_CHECKING:
from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel

logger = logging.getLogger(__name__)
@@ -241,6 +242,10 @@ def get_checkpoint_handler_class(cls, format: type[CheckpointFormat] | str) -> t
def get_model_class(cls) -> type["FastLLMModel"]:
raise NotImplementedError

@classmethod
def get_inference_runner_class(cls) -> type["InferenceRunner"]:
raise NotImplementedError

@classmethod
def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]:
raise NotImplementedError
32 changes: 7 additions & 25 deletions fast_llm/engine/multi_stage/fsdp.py
@@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight(
where it is located in the shard if it exists, or -1 if it's not in the shard.
Used to determine the location of each entry in a different distributed configuration.
"""

# Create an empty index for the global parameter.
index = torch.full(
parameter_meta.global_shape,
-1,
dtype=torch.int64,
device=device,
)
# Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard
begin, end = self._get_parameter_range_in_shard(parameter_name)

buffer_index = parameter_meta.global_to_local(index, expand=True)
# Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible.
# In that case, we work with a separate tensor to be copied back into `buffer_index`.
try:
buffer_index_flat = buffer_index.view(-1)
is_view = True
except RuntimeError:
buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1)
is_view = False

# Copy the shard indices at their respective positions in the flat buffer index.
buffer_index_flat[
# Create an empty local index to hold the local shard indices.
buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device)

# Copy the shard indices at their respective positions in the buffer index.
buffer_index.flatten()[
self._index_buffer_to_param(
self._fsdp_dim.rank * self._shard_size, parameter_name
) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name)
].copy_(torch.arange(begin, end, dtype=torch.int64, device=device))

# If needed, copy the flat buffer index back into the index.
if not is_view:
buffer_index.copy_(buffer_index_flat.view_as(buffer_index))

return index
# Create a global index from the local one.
return parameter_meta.local_to_global_partial(buffer_index, -1)

def copy_shard_overlaps(
self,
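The rewritten `_get_parameter_shard_indices_in_full_weight` now builds the index map in local shape and lets `local_to_global_partial` lift it to global shape, replacing the earlier view/copy-back workaround. A rough sketch of the local step with made-up sizes (`local_numel`, `begin`, `end`, and the slice positions are illustrative, not taken from the code):

```python
import torch

# Hypothetical: the local parameter slice has 6 entries; entries [1, 4) of it
# live in this rank's FSDP shard at shard offsets [begin, end) = [10, 13).
local_numel, begin, end = 6, 10, 13

# Local index: -1 everywhere, shard offsets where this rank holds the value.
buffer_index = torch.full((local_numel,), -1, dtype=torch.int64)
buffer_index[1:4] = torch.arange(begin, end, dtype=torch.int64)
print(buffer_index)  # tensor([-1, 10, 11, 12, -1, -1])

# In the PR, this local index is then expanded to the parameter's global shape
# via `parameter_meta.local_to_global_partial(buffer_index, -1)`, so positions
# held by other ranks (or by no shard) report -1.
```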
5 changes: 0 additions & 5 deletions fast_llm/engine/training/config.py
@@ -32,7 +32,6 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator


@@ -403,10 +402,6 @@ def _setup(self):
def get_trainer_class(cls) -> type["Trainer"]:
raise NotImplementedError

@classmethod
def get_inference_runner_class(cls) -> type["InferenceRunner"]:
raise NotImplementedError

def _get_runnable(self) -> typing.Callable[[], None]:
from fast_llm.engine.distributed.distributed import Distributed

2 changes: 1 addition & 1 deletion fast_llm/engine/training/trainer.py
@@ -142,7 +142,7 @@ def __init__(self, config: TrainerConfig):
self._reference_models = {}
for name, reference_config in self._config.reference_models.items():
log_main_rank(f"Creating `{name} reference model...")
self._reference_models[name] = self._config.get_inference_runner_class()(
self._reference_models[name] = reference_config.model.get_inference_runner_class()(
reference_config.model.get_model_class()(reference_config.model)
)
self._multi_stage.base_model.add_reference_model(name, self._reference_models[name])
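In `trainer.py`, the inference runner for each reference model now comes from that model's own config (`reference_config.model.get_inference_runner_class()`) rather than from the trainer config, which matters once reference models can be SSM or hybrid rather than GPT-only. A toy sketch of the dispatch; the class names below are made up for illustration:

```python
# Minimal sketch of the dispatch change (illustrative class names only).
class InferenceRunner:
    def __init__(self, model):
        self.model = model


class SSMInferenceRunner(InferenceRunner):
    pass


class GPTModelConfig:
    @classmethod
    def get_inference_runner_class(cls):
        return InferenceRunner


class HybridSSMModelConfig:
    @classmethod
    def get_inference_runner_class(cls):
        return SSMInferenceRunner


# Before: the trainer config chose one runner class for every reference model.
# After: each reference model's config supplies its own runner, so a student
# can distill from an SSM/hybrid teacher that needs a different inference path.
for config in (GPTModelConfig, HybridSSMModelConfig):
    print(config.__name__, "->", config.get_inference_runner_class().__name__)
```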
95 changes: 84 additions & 11 deletions fast_llm/functional/cross_entropy.py
@@ -49,6 +49,19 @@ def _torch_cross_entropy_forward_backward(
return loss.detach_(), grad


def distributed_log_softmax(logits: torch.Tensor, group: ProcessGroup, dim: int = -1):
logits = logits.float()
local_max = logits.max(dim=dim, keepdim=True)[0]
all_reduce(local_max, op=ReduceOp.MAX, group=group)

logits_shifted = logits - local_max
exp_logits = torch.exp(logits_shifted)
sum_exp = exp_logits.sum(dim=dim, keepdim=True)
all_reduce(sum_exp, op=ReduceOp.SUM, group=group)

return logits_shifted - sum_exp.log() # log_softmax


@torch.compile
def _fused_softmax_base(
logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1
@@ -214,38 +227,88 @@ def cross_entropy_forward_backward(
)


def _torch_reverse_kl_forward_backward(
def _torch_reverse_kl_forward_backward_vocab_parallel(
logits: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
logits_scale_factor: float,
target_format: TargetFormat,
group: ProcessGroup | None = None,
teacher_softmax_temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Reverse KL using PyTorch's native kl_div function.
This is the TP version, used when we split across the vocab dimension.
This works with sequence-tensor-parallel (distributing over the sequence dimension) as well as the non-TP case.
In sequence-tensor-parallel, where we split along the sequence dimension, we compute the per-split loss and then average the losses.
"""
# TODO: merge into single function _torch_reverse_kl_forward_backward
Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])

# Compute log probabilities - let _fused_softmax handle scaling internally
# teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group)
# # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p)
# teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6
# teacher_log_probs = torch.log(teacher_probs)
teacher_log_probs = distributed_log_softmax(target, group=group)
batch_size = logits.shape[0]
with torch.enable_grad():
logits_ = logits.detach().requires_grad_(grad_output is not None)
student_log_probs = distributed_log_softmax(logits_, group=group)

# Reverse KL: input=teacher_log_probs, target=student_probs
if loss_mask is None:
loss = torch.nn.functional.kl_div(
teacher_log_probs, # input = log(p)
student_log_probs, # target = log(q)
reduction="sum",
log_target=True,
)
else:
# Apply loss mask - this requires some reshaping
raise NotImplementedError("Loss mask not implemented with TP for reverse KL , it must be doublechecked")
loss_per_sample = torch.nn.functional.kl_div(
teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1)
loss = (loss_per_sample * loss_mask).sum()

if group is not None and target_format != TargetFormat.labels:
all_reduce(loss, op=ReduceOp.SUM, group=group)
loss /= batch_size

if grad_output is not None:
loss.backward(torch.full_like(loss, grad_output))
grad = logits_.grad.to(logits.dtype)
else:
grad = None

return loss.detach_(), grad


def _torch_reverse_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
logits_scale_factor: float,
target_format: TargetFormat,
group: ProcessGroup | None = None,
teacher_softmax_temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Reverse KL using PyTorch's native kl_div function.
This works with sequence-tensor-parallel (distributing over the sequence dimension) as well as the non-TP case.
In sequence-tensor-parallel, where we split along the sequence dimension, we compute the per-split loss and then average the losses.
"""
Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format")
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])
# Scale target logits more carefully
scaled_target = target * (logits_scale_factor / teacher_softmax_temperature)
# Clamp to prevent extreme values that cause NaNs in log_softmax
scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0)

# Clamp to prevent extreme values before log_softmax
scaled_target = torch.clamp(scaled_target, min=-50, max=50)
teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)

# For reverse KL: KL(q||p) = Ξ£ q * log(q/p) = Ξ£ q * (log(q) - log(p))
@@ -256,9 +319,10 @@ def _torch_reverse_kl_forward_backward(
logits_ = logits.detach().requires_grad_(grad_output is not None)

scaled_logits = logits_ * logits_scale_factor
scaled_logits = torch.clamp(scaled_logits, min=-50, max=50)
# Clamp to prevent extreme values that cause NaNs in log_softmax
scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0)
student_log_probs = torch.log_softmax(scaled_logits, dim=-1)

# Reverse KL: input=teacher_log_probs, target=student_probs
if loss_mask is None:
loss = torch.nn.functional.kl_div(
@@ -279,6 +343,7 @@ def _torch_reverse_kl_forward_backward(
loss /= group.size()

if grad_output is not None:
# note, we never get here in TP over seq. dim.
loss.backward(torch.full_like(loss, grad_output))
grad = logits_.grad.to(logits.dtype)
else:
@@ -344,6 +409,14 @@ def reverse_kl_forward_backward(
Assert.eq(teacher_softmax_temperature, 1)
Assert.eq(logits_scale_factor, 1)
raise NotImplementedError("Vocab parallel reverse KL is not implemented yet.")
return _torch_reverse_kl_forward_backward_vocab_parallel(
logits,
target,
loss_mask,
grad_output,
target_format,
group,
)
else:
return _torch_reverse_kl_forward_backward(
logits,
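Two pieces of the new cross-entropy code are worth unpacking. `distributed_log_softmax` computes a numerically stable log-softmax when the vocabulary is sharded across ranks: each rank takes its local max and local sum of exponentials, all-reduces both, and returns `logits - max - log(sum_exp)`. The reverse-KL helpers then call `torch.nn.functional.kl_div` with `log_target=True` and the arguments ordered as `(teacher_log_probs, student_log_probs)`, which sums `q * (log q - log p)`, i.e. KL(q || p) with the student distribution q as the weighting. The single-process sketch below checks both identities; the vocab shards are simulated in-process, so the all-reduce in the PR is replaced by combining per-shard statistics directly.

```python
import torch

torch.manual_seed(0)
student = torch.randn(4, 10)   # stand-in student logits over the full vocab
teacher = torch.randn(4, 10)   # stand-in teacher logits

# (a) Vocab-parallel log-softmax: each "rank" holds a vocab slice and reduces
# its local max and local sum(exp); the all-reduce is simulated by combining
# the per-slice statistics directly.
def sharded_log_softmax(full_logits: torch.Tensor, shards: int = 2) -> torch.Tensor:
    pieces = full_logits.chunk(shards, dim=-1)
    global_max = torch.stack([p.max(dim=-1, keepdim=True)[0] for p in pieces]).amax(0)
    sum_exp = sum((p - global_max).exp().sum(dim=-1, keepdim=True) for p in pieces)
    return torch.cat([p - global_max - sum_exp.log() for p in pieces], dim=-1)

assert torch.allclose(sharded_log_softmax(student), torch.log_softmax(student, -1), atol=1e-5)

# (b) Reverse KL(q || p), q = student, p = teacher. With log_target=True,
# kl_div(input=log p, target=log q) sums exp(log q) * (log q - log p).
log_p = torch.log_softmax(teacher, dim=-1)   # teacher log-probs
log_q = torch.log_softmax(student, dim=-1)   # student log-probs
reverse_kl = torch.nn.functional.kl_div(log_p, log_q, reduction="sum", log_target=True)
manual = (log_q.exp() * (log_q - log_p)).sum()
assert torch.allclose(reverse_kl, manual, atol=1e-5)
print(reverse_kl.item())
```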
8 changes: 4 additions & 4 deletions fast_llm/layers/language_model/embedding.py
@@ -46,10 +46,10 @@ def __init__(
self._dropout_p = config.transformer.hidden_dropout
self._use_absolute_position_embeddings = config.use_absolute_position_embeddings

hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden)
vocab_dim = tensor_space.get_tensor_dim(
hidden_dim = tensor_space[TransformerDimNames.hidden]
vocab_dim = tensor_space[
LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
)
]

if self._parallel_embeddings:
self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size
@@ -66,7 +66,7 @@
)
if self._use_absolute_position_embeddings:
self.position_embeddings_weight = ParameterMeta.from_dims(
(tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim),
(tensor_space[LanguageModelDimNames.position_embed], hidden_dim),
init_method=init_normal_(
std=config.init_method_std_embed,
min_val=config.init_method_min_embed,
11 changes: 5 additions & 6 deletions fast_llm/layers/language_model/head.py
@@ -61,7 +61,7 @@ def __init__(
if self._cross_entropy_splits is not None and self._sequence_parallel:
assert not self._parallel_embeddings

hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
hidden_dim = self._tensor_space[TransformerDimNames.hidden]

self._loss_coefficient = (
config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0
@@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
if self._tie_word_embeddings or self._prediction_distance > 0:
return
# untie embedding weights
vocab_dim = self._tensor_space.get_tensor_dim(
vocab_dim = self._tensor_space[
LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
)
]
self.output_weights = ParameterMeta.from_dims(
(vocab_dim, hidden_dim),
init_method=init_normal_(
@@ -237,7 +237,6 @@ def _get_targets(
).flatten()
else:
lm_target = None

targets = (dpo_target, lm_target, distillation_target)
# If we do distillation, no need to split it here as it has already been split in the embedding layer!
# if we do CPT/language modeling, we need to split the targets here!
@@ -350,9 +349,9 @@ def _logits_cross_entropy_forward_backward(
logits_scale_factor=self._logits_scale_factor,
)
if self._debug_transformer and self._cross_entropy_splits is None:
vocab_dim = self._tensor_space.get_tensor_dim(
vocab_dim = self._tensor_space[
LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp
)
]
dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim]
sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first])
dims[sequence_index] = (