-
-
Notifications
You must be signed in to change notification settings - Fork 871
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DPO cleanup #1126
DPO cleanup #1126
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome PR! I left a comment in case you see fit. Also, maybe it could be tackled in a different PR, but the preprocess
command could also be updated to allow checking rl
datasets:
+ if parsed_cfg.rl:
+ _ = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+ else:
+ _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
- _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
d5f97c3
to
c0a1553
Compare
|
||
def load(strategy, cfg): | ||
try: | ||
load_fn = strategy.split(".")[-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is most likely not correct. The strategy
includes underscores, not .
, such as intel_apply_chatml
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def load(strategy, cfg):
try:
load_fn = strategy.split("_")[-1]
#strategy = ".".join(strategy.split("_")[:-1])
LOG.info(load_fn)
LOG.info(strategy)
mod = importlib.import_module(f".{load_fn}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, strategy)
load_kwargs = {}
return func(cfg, **load_kwargs)
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(e)
return None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This works for me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the intention is the setting is something like
type: chatml.argilla
in which case it will load the argilla function from the axolotl.prompt_strategies.dpo.chatml
module.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @winglian 👋🏻 thanks. That makes sense. I will test it later today 👍🏻
Co-authored-by: Agus <agustin.piqueres@gmail.com>
Co-authored-by: Agus <agustin.piqueres@gmail.com>
Description
This PR cleans up some hardcoding, improves the integration with trl's DPOTrainer and adds support for dpo prompt_strategies.