
Conversation

@jlamypoirier (Collaborator) commented Sep 26, 2025

✨ Description

A series of changes extracted from #369, centered around the base model interface and aimed at making it simpler, more intuitive (closer to plain PyTorch), and more flexible.

Note that this introduces additional breaking changes to checkpoints (model structure / parameter names) and configs (minor tweaks). This assumes that is acceptable given that we just broke everything anyway; it could be mitigated if needed.

Notable changes (details coming):

  • Relax the requirement that the base model be a flat sequence of layers. The model can now have any module structure; instead, it provides a list of layers to the Fast-LLM engine. This will help sort out complex models and make things cleaner.
  • Promote the decoder to an actual module (enabled by the above). This improves config/module parity and allows moving functionality (e.g. preprocessing) to the module side. Add kwarg namespaces to fix name-conflict issues in preprocessing.
  • Promote MTP to an independent dynamic head type. This was made necessary by the recent changes, as the existing workarounds were becoming unsustainable. The new format has its own block config and one block per head (now including the first one). It's a bit more complicated than before, but basically unavoidable.
  • Change the language model / GPT structure accordingly. Replace the flat list of submodules with appropriately named embeddings, decoder, and head. Use matching names for the associated config parameters.
  • Drop preprocessors entirely. Move model preprocessing to the associated modules. Move preference spans directly to GPT preprocessing (it's batch splitting and will need separate work). This way, preprocessing lives as close as possible to the code it preprocesses for.
  • Move get_loss_definitions to module instances (same reason).
  • Simplify the tied-weight interface. Tied weights are now defined and used like any other weight; the model provides the set of tied weights, and the Fast-LLM engine handles the rest.
  • Make MLPs and mixers respect the Block interface. No immediate benefit, but this will allow using the MLP directly for the vision adapter, and opens up new possibilities.
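To make the first bullet concrete, here is a minimal, torch-free sketch of the relaxed interface: the model keeps a natural module hierarchy, and the engine only asks for the flattened execution order via get_layers(). All class names here are hypothetical, not Fast-LLM's actual classes.

```python
class Embeddings:
    def get_layers(self):
        return [self]

class DecoderBlock:
    def get_layers(self):
        return [self]

class Decoder:
    def __init__(self, num_blocks: int):
        self.blocks = [DecoderBlock() for _ in range(num_blocks)]

    def get_layers(self):
        # Any internal structure is allowed; only the flattened order matters.
        return [layer for block in self.blocks for layer in block.get_layers()]

class Head:
    def get_layers(self):
        return [self]

class BaseModel:
    def __init__(self):
        # Appropriately named submodules instead of a flat list.
        self.embeddings = Embeddings()
        self.decoder = Decoder(num_blocks=2)
        self.head = Head()

    def get_layers(self):
        # The engine consumes this flat sequence; the module tree stays free-form.
        return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers()

layers = BaseModel().get_layers()
assert len(layers) == 4  # embeddings + 2 decoder blocks + head
```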

Also brought minor things from #369 to reduce diff.

@jlamypoirier jlamypoirier changed the title from "Promote mlps and mixers to blocks." to "Base model interface review" Oct 1, 2025
@tscholak (Collaborator) commented Oct 1, 2025

looks useful. can we merge?


def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
    self._rotary.preprocess(batch, kwargs)
    if not self._use_flash_attention:
Collaborator:

could we make this a config option, an enum? we select other implementations like that

Collaborator (Author):

It's already an option. We could turn it into an enum, but I don't see much reason to.

Collaborator (Author):

It would be useful if we could select the varlen version this way. That would need extra work because it's currently determined in BatchConfig.cross_document_attention and affects data processing, so I won't do it now, but it's something to consider for the future.

Collaborator:

yes, I think that would make sense. the way it's selected right now is very ad-hoc and not really controllable. there are also more sophisticated attention masks possible with flex attention, which we don't use yet but might.
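A minimal sketch of the enum-based selection being discussed, assuming a hypothetical AttentionImplementation enum and config field (not Fast-LLM's actual config API): one field that could later grow varlen and flex-attention variants, instead of a boolean use_flash_attention.

```python
import dataclasses
import enum

class AttentionImplementation(str, enum.Enum):
    BACKUP = "backup"                # pure-pytorch fallback path
    FLASH = "flash"                  # flash-attn kernel
    FLASH_VARLEN = "flash_varlen"    # varlen kernel (cross-document masking)

@dataclasses.dataclass
class AttentionConfig:
    # A single enum-valued option replaces the boolean flag; adding a new
    # implementation is just a new enum member, not a new flag.
    implementation: AttentionImplementation = AttentionImplementation.FLASH

config = AttentionConfig()
assert config.implementation is AttentionImplementation.FLASH
# Config parsing from a string value, as a yaml loader would do:
assert AttentionImplementation("flash_varlen") is AttentionImplementation.FLASH_VARLEN
```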

@jlamypoirier (Collaborator, Author):

> looks useful. can we merge?

Soon, but there are still multiple side issues to address first.

@jlamypoirier jlamypoirier marked this pull request as ready for review October 3, 2025 01:29
@tscholak (Collaborator) left a comment:

looks good and should be merged asap, but on this occasion I want to understand what loss definitions are used for. are they just for logging metrics?

        prediction_distance=prediction_distance,

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
    # TODO ====== Wrong ======
    return self.block.get_loss_definitions(count=count * self.prediction_heads) + self.head.get_loss_definitions(
Collaborator:

not sure what is wrong here but please fix

    name = f"{name}_{self._prediction_distance}"
    return name

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
Collaborator:

this is just for bookkeeping and logging?

        losses[self._distillation_loss_name].append(distillation_loss.detach())
    if self._config.distillation_model is not None and lm_loss is not None:
        losses[LanguageModelLossNames.distil_lm_loss].append(lm_loss.detach())
        losses[self._distillation_language_model_loss_name].append(lm_loss.detach())
Collaborator:

it would help to clarify the purposes of 1) the detached losses, 2) the loss definitions, and 3) the non-detached returned loss.

1) and 2) are related one-to-one, but I don't know whether 3) and 2) are.
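For illustration, a torch-free sketch of the pattern in question, with detach() simulated by storing plain floats; all names are hypothetical:

```python
def compute_losses(lm_loss: float, distillation_loss: float, factor: float):
    # 1) Detached copies, keyed by the names declared in the loss
    #    definitions (2) — collected purely for logging/metrics.
    losses = {
        "language_model_loss": [lm_loss],
        "distillation_loss": [distillation_loss],
    }
    # 3) The non-detached combined value: the only one that would
    #    actually flow into backward().
    total = lm_loss + factor * distillation_loss
    return total, losses

total, losses = compute_losses(lm_loss=2.0, distillation_loss=1.0, factor=0.5)
assert total == 2.5
assert losses["distillation_loss"] == [1.0]
```

The point of the sketch: 1) and 2) line up one-to-one (one logged entry per declared definition), while 3) is a weighted combination that need not correspond to any single definition.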

def get_layers(self) -> list["Layer"]:
    return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers()

def preprocess_meta(
Collaborator:

remove?

Collaborator (Author):

Not possible at the moment: it's no longer in the preprocessor/layer interface, but is still there for base models. Planning to revisit this with the batch-splitting mechanism, but leaving it for later.

@jlamypoirier (Collaborator, Author) commented Oct 3, 2025

> looks good and should be merged asap, but on this occasion I want to understand what loss definitions are used for. are they just for logging metrics?

Yes, loss definitions are for logging. They are clearly in dire need of improvement, but I'm not exactly sure how, and it's low priority, so I just did the bare minimum for now.

Will merge later today once I get the tests to pass (2 minor bugs remaining) and finish cleanup.
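As an illustration of the logging-only role of loss definitions, here is a hypothetical sketch of how per-module get_loss_definitions could compose, with a container like the MTP head scaling the count by the number of prediction heads (LossDef's real fields in Fast-LLM may differ):

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class LossDef:
    name: str   # key under which detached values are logged
    count: int  # expected number of logged values per step

class Block:
    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
        return [LossDef("load_balancing_loss", count)]

class Head:
    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
        return [LossDef("language_model_loss", count)]

class MultiTokenPrediction:
    def __init__(self, prediction_heads: int):
        self.block = Block()
        self.head = Head()
        self.prediction_heads = prediction_heads

    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
        # Each prediction head contributes one block and one head loss,
        # so the container multiplies the count before delegating.
        scaled = count * self.prediction_heads
        return self.block.get_loss_definitions(count=scaled) + self.head.get_loss_definitions(count=scaled)

defs = MultiTokenPrediction(prediction_heads=2).get_loss_definitions()
assert [(d.name, d.count) for d in defs] == [("load_balancing_loss", 2), ("language_model_loss", 2)]
```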

@jlamypoirier jlamypoirier merged commit bda052f into main Oct 3, 2025
4 checks passed
@jlamypoirier jlamypoirier deleted the jlp/mlp_block branch October 3, 2025 23:18
@jlamypoirier jlamypoirier mentioned this pull request Oct 6, 2025