diff --git a/requirements.txt b/requirements.txt index 1e95b716e..4ef9f5fd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,3 +30,4 @@ scipy scikit-learn==1.2.2 pynvml art +wandb diff --git a/scripts/finetune.py b/scripts/finetune.py index b998edc79..ca72c7910 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -26,7 +26,7 @@ from axolotl.utils.distributed import is_main_process from axolotl.utils.models import load_tokenizer from axolotl.utils.tokenization import check_dataset_labels -from axolotl.utils.wandb import setup_wandb_env_vars +from axolotl.utils.wandb_ import setup_wandb_env_vars project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 99c7b147a..819360f1d 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -367,7 +367,7 @@ def on_evaluate( output_scores=False, ) - def logits_to_tokens(logits) -> str: + def logits_to_tokens(logits) -> torch.Tensor: probabilities = torch.softmax(logits, dim=-1) # Get the predicted token ids (the ones with the highest probability) predicted_token_ids = torch.argmax(probabilities, dim=-1) diff --git a/src/axolotl/utils/wandb.py b/src/axolotl/utils/wandb_.py similarity index 100% rename from src/axolotl/utils/wandb.py rename to src/axolotl/utils/wandb_.py