Base model interface review #370
Conversation
looks useful. can we merge?
def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
    self._rotary.preprocess(batch, kwargs)
    if not self._use_flash_attention:
could we make this a config option, an enum? we select other implementations like that
It's already an option. We could turn it into an enum, but I don't see much reason to.
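For illustration, a minimal sketch of what an enum-valued config option for the attention implementation could look like; the enum name, its variants, and the dispatch targets are all assumptions, not names from the codebase:

```python
import enum


class AttentionImplementation(str, enum.Enum):
    # Hypothetical variants; the real codebase may name these differently.
    flash = "flash"
    fused = "fused"
    backup = "backup"


def select_attention(implementation: AttentionImplementation) -> str:
    # Dispatch on the enum value, mirroring how other implementations
    # are reportedly selected via config enums elsewhere in the project.
    table = {
        AttentionImplementation.flash: "_attn_flash",
        AttentionImplementation.fused: "_attn_fused",
        AttentionImplementation.backup: "_attn_backup",
    }
    return table[implementation]
```

A string-backed enum like this also keeps the YAML/JSON config readable, since "flash" round-trips directly to `AttentionImplementation.flash`.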
It would be useful if we could select the varlen version this way. That would need extra work because it's currently determined in BatchConfig.cross_document_attention and affects data processing, so I won't do it now, but it's something to consider for the future.
yes, I think that would make sense. the way it's selected right now is very ad-hoc and not really controllable. there are also more sophisticated attention masks possible with flex attention, which we don't use yet but might.
Soon, but there are still multiple side issues to address first.
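As a rough sketch of the cross-document masking being discussed: flex-attention-style APIs express masks as a predicate over query/key positions, which makes a varlen/document mask a one-liner. Everything below (function names, the doc_ids layout) is illustrative, not the project's actual code:

```python
def make_document_causal_mask(doc_ids: list[int]):
    # doc_ids[i] = index of the document that token i belongs to
    # in a packed sequence. A query may attend to a key iff the
    # pair is causal (kv <= q) and both tokens are in the same
    # document -- i.e. no cross-document attention.
    def mask_mod(q_idx: int, kv_idx: int) -> bool:
        return kv_idx <= q_idx and doc_ids[q_idx] == doc_ids[kv_idx]

    return mask_mod


# Two packed documents: tokens 0-1 are doc 0, tokens 2-4 are doc 1.
mask = make_document_causal_mask([0, 0, 1, 1, 1])
```

Here token 2 (doc 1) cannot attend back to token 1 (doc 0), even though the pair is causal, which is exactly the cross_document_attention=False behavior.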
looks good and should be merged asap, but on this occasion I want to understand what loss definitions are used for. are they just for logging metrics?
    prediction_distance=prediction_distance,

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
    # TODO ====== Wrong ======
    return self.block.get_loss_definitions(count=count * self.prediction_heads) + self.head.get_loss_definitions(
not sure what is wrong here but please fix
    name = f"{name}_{self._prediction_distance}"
    return name

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
this is just for bookkeeping and logging?
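A toy sketch of the bookkeeping role being asked about: loss definitions just declare which named losses a module will report, with a count used for aggregation. The LossDef fields and the specific loss names here are assumptions, not the project's actual definitions:

```python
import dataclasses


@dataclasses.dataclass
class LossDef:
    # Assumed fields; the real class in the codebase may differ.
    name: str
    count: int = 1


def get_loss_definitions(prediction_heads: int, count: int = 1) -> list[LossDef]:
    # A multi-head module scales the per-block count by the number of
    # prediction heads, then appends the head's own definitions,
    # mirroring the composition pattern in the hunk above.
    block_defs = [LossDef("lm_loss", count=count * prediction_heads)]
    head_defs = [LossDef("z_loss", count=count)]
    return block_defs + head_defs
```

Under this reading, nothing here touches the training objective; the definitions only tell the logging machinery which keys to expect and how many contributions to average over.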
    losses[self._distillation_loss_name].append(distillation_loss.detach())
    if self._config.distillation_model is not None and lm_loss is not None:
        losses[LanguageModelLossNames.distil_lm_loss].append(lm_loss.detach())
        losses[self._distillation_language_model_loss_name].append(lm_loss.detach())
it would help to clarify the purposes of the 1) detached losses, 2) loss definitions, and 3) non-detached returned loss.
1) and 2) are one-to-one related, but I don't know whether 3) and 2) are.
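A toy illustration of the distinction in question, using a stand-in tensor class (everything here is hypothetical, including the loss names): the copies stored in the losses dict are detached from the autograd graph and exist only for logging, while the single non-detached loss is the one gradients flow through.

```python
class FakeTensor:
    # Minimal stand-in for a torch.Tensor, just to show the detach pattern.
    def __init__(self, value: float, requires_grad: bool = True):
        self.value = value
        self.requires_grad = requires_grad

    def detach(self) -> "FakeTensor":
        # Detaching drops the (pretend) autograd graph.
        return FakeTensor(self.value, requires_grad=False)


def forward(losses: dict[str, list]) -> FakeTensor:
    distillation_loss = FakeTensor(0.5)
    lm_loss = FakeTensor(1.5)
    # 1) Detached copies are appended purely for logging/metrics.
    losses.setdefault("distillation_loss", []).append(distillation_loss.detach())
    losses.setdefault("distil_lm_loss", []).append(lm_loss.detach())
    # 3) The non-detached total is what backward() would differentiate.
    return FakeTensor(distillation_loss.value + lm_loss.value)


logged: dict[str, list] = {}
total = forward(logged)
```

In this sketch the logged entries and the loss definitions are indeed one-to-one (one dict key per definition), while the returned loss is a sum over a subset of them, so 3) and 2) need not correspond.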
def get_layers(self) -> list["Layer"]:
    return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers()

def preprocess_meta(
remove?
Not possible atm, it's no longer in the preprocessor/layer interface but is still there for base models. Planning to revisit this with the batch-splitting mechanism, but leaving for later.
Yes, loss definitions are for logging. They are clearly in dire need of improvement, but I'm not exactly sure how, and that's low priority, so I just did the bare minimum for now. Will merge later today once I get the tests to pass (2 minor bugs remaining) and finish cleanup.
✨ Description
A series of changes extracted from #369, centered around the base model interface and aimed at making it simpler, more intuitive (closer to plain PyTorch), and more flexible.
Note that this introduces additional breaking changes to checkpoints (model structure / parameter names) and configs (minor tweaks). I'm assuming that's ok given that we just broke everything anyway, but it could be mitigated if needed.
Notable changes (details coming):
- embedding, decoder, head. Use matching names for the associated config parameters.
- get_loss_definitions to module instances (same reason).
- Block interface. No immediate benefit, but will allow using the mlp directly for the vision adapter, and opens up new possibilities.

Also brought minor things from #369 to reduce diff.