-
-
Notifications
You must be signed in to change notification settings - Fork 871
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* cleanup dpo to be a little more extensible, add zephyr/nectar strategy * fix eos slash * support for eval split * fix kwargs * handle empty evals * don't load peft model for dpo * ensure dpo traning args gets bf16 for peft if applicable * fix duplicate kwargs for bf16 * make sure to respect the configured lr scheduler * supprt trainer callback to push config to wandb * set dataloader preload args * ensure that we are loading the lora when merging * Update src/axolotl/utils/data.py Co-authored-by: Agus <agustin.piqueres@gmail.com> * support local datasets for dpo Co-authored-by: Agus <agustin.piqueres@gmail.com> * chore: lint * dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names * add split to dpo tests * fix rebase/merging error * handle edge case w logging * use accelerator for dpo datasets so it doesn't break the logger * missing args * validate checkpoint is an adapter for now * log warning when dataset strategy is not loadable --------- Co-authored-by: Agus <agustin.piqueres@gmail.com>
- Loading branch information
Showing
10 changed files
with
440 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
""" | ||
module for DPO style dataset transform strategies | ||
""" | ||
|
||
import importlib | ||
import logging | ||
|
||
LOG = logging.getLogger("axolotl") | ||
|
||
|
||
def load(strategy, cfg): | ||
try: | ||
load_fn = strategy.split(".")[-1] | ||
strategy = ".".join(strategy.split(".")[:-1]) | ||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo") | ||
func = getattr(mod, load_fn) | ||
load_kwargs = {} | ||
return func(cfg, **load_kwargs) | ||
except Exception: # pylint: disable=broad-exception-caught | ||
LOG.warning(f"unable to load strategy {strategy}") | ||
return None |
Oops, something went wrong.