diff --git a/.mypy.ini b/.mypy.ini index c542178e0..a562a37bc 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -8,6 +8,9 @@ ignore_missing_imports = True [mypy-axolotl.monkeypatch.*] ignore_errors = True +[mypy-axolotl.utils.callbacks] +disable_error_code = attr-defined + [mypy-flash_attn.*] ignore_missing_imports = True diff --git a/requirements.txt b/requirements.txt index 1e95b716e..4ef9f5fd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,3 +30,4 @@ scipy scikit-learn==1.2.2 pynvml art +wandb diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 99c7b147a..d16b8aed6 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -11,7 +11,6 @@ import pandas as pd import torch import torch.distributed as dist -import wandb from datasets import load_dataset from optimum.bettertransformer import BetterTransformer from tqdm import tqdm @@ -25,6 +24,7 @@ ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy +import wandb from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.distributed import ( barrier, @@ -367,7 +367,7 @@ def on_evaluate( output_scores=False, ) - def logits_to_tokens(logits) -> str: + def logits_to_tokens(logits) -> torch.Tensor: probabilities = torch.softmax(logits, dim=-1) # Get the predicted token ids (the ones with the highest probability) predicted_token_ids = torch.argmax(probabilities, dim=-1)