diff --git a/scripts/alpaca_json_to_jsonl.py b/scripts/alpaca_json_to_jsonl.py index 61cb170ec..8ea1983fe 100644 --- a/scripts/alpaca_json_to_jsonl.py +++ b/scripts/alpaca_json_to_jsonl.py @@ -15,6 +15,9 @@ JsonToJsonlConverter, StdoutWriter, ) +from axolotl.logging_config import configure_logging + +configure_logging() # add src to the pythonpath so we don't need to pip install this project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) diff --git a/scripts/finetune.py b/scripts/finetune.py index a1c5b13b9..8696d3c9a 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -17,6 +17,7 @@ from optimum.bettertransformer import BetterTransformer from transformers import GenerationConfig, TextStreamer +from axolotl.logging_config import configure_logging from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer @@ -29,8 +30,10 @@ src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) +configure_logging() +LOG = logging.getLogger("axolotl.scripts") + -logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" @@ -212,7 +215,7 @@ def train( # load the tokenizer first tokenizer_config = cfg.tokenizer_config or cfg.base_model_config - logging.info(f"loading tokenizer... {tokenizer_config}") + LOG.info(f"loading tokenizer... {tokenizer_config}") tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg) if ( @@ -234,7 +237,7 @@ def train( eval_dataset = None if cfg.debug or "debug" in kwargs: - logging.info("check_dataset_labels...") + LOG.info("check_dataset_labels...") check_dataset_labels( train_dataset.select( [random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec @@ -243,11 +246,11 @@ def train( ) if prepare_ds_only: - logging.info("Finished preparing dataset. Exiting...") + LOG.info("Finished preparing dataset. 
Exiting...") return # Load the model and tokenizer - logging.info("loading model and peft_config...") + LOG.info("loading model and peft_config...") model, peft_config = load_model( cfg.base_model, cfg.base_model_config, @@ -258,17 +261,17 @@ def train( ) if "merge_lora" in kwargs and cfg.adapter is not None: - logging.info("running merge of LoRA with base model") + LOG.info("running merge of LoRA with base model") model = model.merge_and_unload() model.to(dtype=torch.float16) if cfg.local_rank == 0: - logging.info("saving merged model") + LOG.info("saving merged model") model.save_pretrained(str(Path(cfg.output_dir) / "merged")) return if cfg.inference: - logging.info("calling do_inference function") + LOG.info("calling do_inference function") prompter: Optional[str] = "AlpacaPrompter" if "prompter" in kwargs: if kwargs["prompter"] == "None": @@ -287,12 +290,12 @@ def train( model.config.use_cache = False if torch.__version__ >= "2" and sys.platform != "win32": - logging.info("Compiling torch model") + LOG.info("Compiling torch model") model = torch.compile(model) # go ahead and presave, so we have the adapter config available to inspect if peft_config: - logging.info(f"Pre-saving adapter config to {cfg.output_dir}") + LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") peft_config.save_pretrained(cfg.output_dir) # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model @@ -308,9 +311,9 @@ def terminate_handler(_, __, model): signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) ) - logging.info("Starting trainer...") + LOG.info("Starting trainer...") if cfg.group_by_length: - logging.info("hang tight... sorting dataset for group_by_length") + LOG.info("hang tight... sorting dataset for group_by_length") resume_from_checkpoint = cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: possible_checkpoints = [ @@ -322,7 +325,7 @@ def terminate_handler(_, __, model): key=lambda path: int(path.split("-")[-1]), ) resume_from_checkpoint = sorted_paths[-1] - logging.info( + LOG.info( f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}" ) @@ -336,7 +339,7 @@ def terminate_handler(_, __, model): else: trainer.train(resume_from_checkpoint=resume_from_checkpoint) - logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") + LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") # TODO do we need this fix? 
https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 5593a8dd3..911df8f50 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -14,6 +14,8 @@ # let's check to ensure we don't truncate an item in the middle, we'll use # the collators later on to pad the datasets +LOG = logging.getLogger("axolotl") + class TokenizedPromptDataset(IterableDataset): """ @@ -115,7 +117,7 @@ def __iter__(self): "attention_mask": attention_mask, } else: - logging.warning( + LOG.warning( f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" ) buffer = { diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py new file mode 100644 index 000000000..1df272d5c --- /dev/null +++ b/src/axolotl/logging_config.py @@ -0,0 +1,30 @@ +"""Logging configuration settings""" + +import os +import sys +from logging.config import dictConfig +from typing import Any, Dict + +DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { + "version": 1, + "formatters": { + "simple": { + "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s", + }, + }, + "filters": {}, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "simple", + "filters": [], + "stream": sys.stdout, + }, + }, + "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")}, +} + + +def configure_logging(): + """Configure with default logging""" + dictConfig(DEFAULT_LOGGING_CONFIG) diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py index 2a4cdbc36..24a98305f 100644 --- a/src/axolotl/monkeypatch/llama_landmark_attn.py +++ b/src/axolotl/monkeypatch/llama_landmark_attn.py @@ -53,7 +53,7 @@ replace_return_docstrings, ) -logger = logging.get_logger(__name__) +LOG = logging.getLogger("axolotl") _CONFIG_FOR_DOC = "LlamaConfig" @@ -862,7 +862,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: - logger.warning_once( + LOG.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index d38bc2beb..88208f6ec 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -11,6 +11,8 @@ tokenize_prompt_default, ) +LOG = logging.getLogger("axolotl") + IGNORE_TOKEN_ID = -100 @@ -64,7 +66,7 @@ def tokenize_prompt(self, prompt): *copy.deepcopy(res["input_ids"]) ][len(self.bot_prefix_token_ids) :] else: - logging.warning(f"unknown role in conversation: {role}") + LOG.warning(f"unknown role in conversation: {role}") res = defaultdict(lambda: []) # pylint: disable=duplicate-code diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 8216d73dd..fb6f39b0d 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -10,6 +10,8 @@ from axolotl.prompters import IGNORE_TOKEN_ID +LOG = logging.getLogger("axolotl") + IGNORE_INDEX = -100 LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec LLAMA_DEFAULT_EOS_TOKEN = "" # nosec @@ -384,7 +386,7 @@ def tokenize_prompt(self, prompt): # everything from this is masked out from the labels labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) else: - logging.warning(f"unhandled role: {part[0]}") + LOG.warning(f"unhandled role: {part[0]}") # pylint: disable=duplicate-code result, current_len = parse_tokenized_to_result( diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 715a227c8..a304bd137 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -5,6 +5,7 @@ from enum import Enum, auto from typing import Generator, List, Optional, Tuple, Union +LOG = logging.getLogger("axolotl") IGNORE_TOKEN_ID = -100 @@ -241,7 +242,7 @@ def get_prompt(self) -> Generator[Tuple[str, str], None, None]: if message: yield (role + ":", " " + message) else: - logging.warning(f"role with empty message: {role}") + LOG.warning(f"role with empty message: {role}") yield (role + ":", "") def copy(self): diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 8df1e4d38..ef732a8ad 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -35,6 +35,8 @@ SummarizeTLDRPrompter, ) +LOG = logging.getLogger("axolotl") + def load_tokenized_prepared_datasets( tokenizer, cfg, default_dataset_prepared_path @@ -73,17 +75,17 @@ def load_tokenized_prepared_datasets( if dataset: ... elif any(prepared_ds_path.glob("*")): - logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") + LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") dataset = load_from_disk(str(prepared_ds_path)) - logging.info("Prepared dataset loaded from disk...") + LOG.info("Prepared dataset loaded from disk...") else: - logging.info(f"Unable to find prepared dataset in {prepared_ds_path}") - logging.info("Loading raw datasets...") + LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") + LOG.info("Loading raw datasets...") if cfg.seed: seed = cfg.seed else: - logging.info("No seed provided, using default seed of 42") + LOG.info("No seed provided, using default seed of 42") seed = 42 datasets = [] @@ -255,25 +257,21 @@ def load_tokenized_prepared_datasets( suffix = "" if ":load_" in d.type: suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?" - logging.error( - f"unhandled prompt tokenization strategy: {d.type}. {suffix}" - ) + LOG.error(f"unhandled prompt tokenization strategy: {d.type}. 
{suffix}") raise ValueError( f"unhandled prompt tokenization strategy: {d.type} {suffix}" ) - logging.info("tokenizing, merging, and shuffling master dataset") + LOG.info("tokenizing, merging, and shuffling master dataset") samples: List[int] = [] for d in datasets: samples = samples + list(d) dataset = Dataset.from_list(samples).shuffle(seed=seed) if cfg.local_rank == 0: - logging.info( - f"Saving merged prepared dataset to disk... {prepared_ds_path}" - ) + LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") dataset.save_to_disk(prepared_ds_path) if cfg.push_dataset_to_hub: - logging.info( + LOG.info( f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" ) dataset.push_to_hub( @@ -324,7 +322,7 @@ def load_prepare_datasets( use_auth_token = cfg.hf_use_auth_token try: if cfg.push_dataset_to_hub: - logging.info( + LOG.info( f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}" ) dataset = load_dataset( @@ -338,13 +336,13 @@ def load_prepare_datasets( if dataset: ... elif any(prepared_ds_path.glob("*")): - logging.info( + LOG.info( f"Loading prepared packed dataset from disk at {prepared_ds_path}..." ) dataset = load_from_disk(str(prepared_ds_path)) - logging.info("Prepared packed dataset loaded from disk...") + LOG.info("Prepared packed dataset loaded from disk...") if cfg.push_dataset_to_hub: - logging.info( + LOG.info( f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" ) dataset.push_to_hub( @@ -363,9 +361,7 @@ def load_prepare_datasets( [dataset], seq_length=max_packed_sequence_len, ) - logging.info( - f"packing master dataset to len: {cfg.max_packed_sequence_len}" - ) + LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}") dataset = Dataset.from_list(list(constant_len_dataset)) # filter out bad data @@ -381,12 +377,12 @@ def load_prepare_datasets( ) if cfg.local_rank == 0: - logging.info( + LOG.info( f"Saving packed prepared dataset to disk... {prepared_ds_path}" ) dataset.save_to_disk(prepared_ds_path) if cfg.push_dataset_to_hub: - logging.info( + LOG.info( f"Saving packed prepared dataset with push_to_hub... 
{cfg.push_dataset_to_hub}/{ds_hash}" ) dataset.push_to_hub( @@ -399,7 +395,7 @@ def load_prepare_datasets( ) if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: - logging.info( + LOG.info( f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" ) dataset = dataset.shard( @@ -520,7 +516,7 @@ def encode_pretraining(tokenizer, max_tokens, examples): "attention_mask": [seq.tolist() for seq in new_attention_mask], } - logging.debug(len(ret["input_ids"])) + LOG.debug(len(ret["input_ids"])) return ret diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6849551bc..a88a3807e 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -23,6 +23,8 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN +LOG = logging.getLogger("axolotl") + if TYPE_CHECKING: from peft import PeftConfig # noqa: F401 @@ -50,10 +52,10 @@ def load_tokenizer( use_fast=use_fast, ) - logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") + LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") + LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") + LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") if tokenizer.__class__.__name__ in [ "LlamaTokenizer", @@ -92,21 +94,21 @@ def load_model( if cfg.device not in ["mps", "cpu"] and not cfg.inference: from axolotl.flash_attn import replace_llama_attn_with_flash_attn - logging.info("patching with flash attention") + LOG.info("patching with flash attention") replace_llama_attn_with_flash_attn() elif cfg.is_llama_derived_model and cfg.xformers_attention: from axolotl.monkeypatch.llama_attn_hijack_xformers import ( hijack_llama_attention, ) - logging.info("patching with xformers attention") + LOG.info("patching with xformers attention") hijack_llama_attention() elif cfg.is_llama_derived_model and cfg.sdp_attention: from axolotl.monkeypatch.llama_attn_hijack_xformers import ( hijack_llama_sdp_attention, ) - logging.info("patching with sdp attention") + LOG.info("patching with sdp attention") hijack_llama_sdp_attention() elif cfg.is_llama_derived_model and cfg.landmark_attention: from axolotl.monkeypatch.llama_landmark_attn import ( @@ -114,7 +116,7 @@ def load_model( patch_llama_with_landmark_attn, ) - logging.info("patching with landmark attention") + LOG.info("patching with landmark attention") patch_llama_with_landmark_attn() # Note: This might overwrite previous additional_special_tokens @@ -125,7 +127,7 @@ def load_model( replace_llama_rope_with_xpos_rope, ) - logging.info("patching with xpos rope") + LOG.info("patching with xpos rope") replace_llama_rope_with_xpos_rope() if cfg.bf16 or cfg.bfloat16: @@ -142,7 +144,7 @@ def load_model( replace_peft_model_with_int4_lora_model() except Exception as err: - logging.exception(err) + LOG.exception(err) raise err try: @@ -187,7 +189,7 @@ def load_model( if len(files) > 0: model_path = str(files[0]) else: - logging.warning( + LOG.warning( "unable to find a cached model file, this will likely fail..." 
) model_path = str(cache_model_path) @@ -266,14 +268,14 @@ def load_model( and cfg.sequence_len > config.max_seq_len ): config.max_seq_len = cfg.sequence_len - logging.warning(f"increasing context length to {cfg.sequence_len}") + LOG.warning(f"increasing context length to {cfg.sequence_len}") elif ( hasattr(config, "max_sequence_length") and config.max_sequence_length and cfg.sequence_len > config.max_sequence_length ): config.max_sequence_length = cfg.sequence_len - logging.warning(f"increasing context length to {cfg.sequence_len}") + LOG.warning(f"increasing context length to {cfg.sequence_len}") model = AutoModelForCausalLM.from_pretrained( base_model, config=config, @@ -285,10 +287,10 @@ def load_model( **model_kwargs, ) except Exception as err: # pylint: disable=broad-exception-caught - logging.error( + LOG.error( "Exception raised attempting to load model, retrying with AutoModelForCausalLM" ) - logging.exception(err) + LOG.exception(err) model = AutoModelForCausalLM.from_pretrained( base_model, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, @@ -307,7 +309,7 @@ def load_model( and model.config.max_position_embeddings and cfg.sequence_len >= model.config.max_position_embeddings ): - logging.warning( + LOG.warning( f"increasing model.config.max_position_embeddings to {cfg.sequence_len}" ) model.config.max_position_embeddings = cfg.sequence_len @@ -316,7 +318,7 @@ def load_model( (cfg.adapter == "lora" and load_in_8bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit) ): - logging.info("converting PEFT model w/ prepare_model_for_kbit_training") + LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") model = prepare_model_for_kbit_training( model, use_gradient_checkpointing=cfg.gradient_checkpointing ) @@ -328,7 +330,7 @@ def load_model( if cfg.gptq: # Scales to half - logging.info("Fitting 4bit scales and zeros to half") + LOG.info("Fitting 4bit scales and zeros to half") for _, module in model.named_modules(): if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str( type(module) @@ -354,7 +356,7 @@ def load_model( if param.requires_grad: requires_grad.append(f"{name}: {param.requires_grad}") if len(requires_grad) == 0: - logging.warning("there are no parameters that require gradient updates") + LOG.warning("there are no parameters that require gradient updates") model.config.use_cache = False if cfg.flash_optimum: @@ -388,7 +390,7 @@ def load_llama_adapter(model, cfg): ) if cfg.lora_model_dir: - logging.info("Loading pretained LORA") + LOG.info("Loading pretained LORA") model = PeftModel.from_pretrained( model, cfg.lora_model_dir, @@ -435,7 +437,7 @@ def load_lora(model, cfg): bits = 8 linear_names = find_all_linear_names(bits, model) - logging.info(f"found linear modules: {repr(linear_names)}") + LOG.info(f"found linear modules: {repr(linear_names)}") lora_target_modules = list(set(lora_target_modules + linear_names)) lora_config = LoraConfig( diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 7d0d1dd83..b2d1df400 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -5,6 +5,8 @@ from termcolor import colored +LOG = logging.getLogger("axolotl") + def check_dataset_labels(dataset, tokenizer): # the dataset is already shuffled, so let's just check the first 5 elements @@ -32,7 +34,7 @@ def check_example_labels(example, tokenizer): ) colored_tokens.append(colored_token) - logging.info(" ".join(colored_tokens)) - logging.info("\n\n\n") + LOG.info(" 
".join(colored_tokens)) + LOG.info("\n\n\n") return " ".join(colored_tokens) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 3dcebeb8a..bdd760526 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -26,6 +26,8 @@ get_cosine_schedule_with_quadratic_warmup, ) +LOG = logging.getLogger("axolotl") + class AxolotlTrainingArguments(TrainingArguments): """ @@ -324,7 +326,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): set_model_mem_id(model, tokenizer) - logging.info("Adding landmark attention tokens to dataset") + LOG.info("Adding landmark attention tokens to dataset") for dataset in [train_dataset, eval_dataset]: dataset = dataset.map( diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 40dfb84a9..06669cba2 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -4,6 +4,8 @@ import torch +LOG = logging.getLogger("axolotl") + def validate_config(cfg): if cfg.gradient_accumulation_steps and cfg.batch_size: @@ -11,7 +13,7 @@ def validate_config(cfg): "please set only one of gradient_accumulation_steps or batch_size" ) if cfg.batch_size: - logging.warning( + LOG.warning( "%s\n%s", "batch_size is not recommended. Please use gradient_accumulation_steps instead.", "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", @@ -44,10 +46,10 @@ def validate_config(cfg): raise ValueError("Require cfg.load_in_4bit to be True for qlora") if not cfg.load_in_8bit and cfg.adapter == "lora": - logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") + LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") if cfg.trust_remote_code: - logging.warning( + LOG.warning( "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." ) @@ -66,31 +68,29 @@ def validate_config(cfg): if cfg.flash_optimum is True: if cfg.adapter: - logging.warning( - "BetterTransformers probably doesn't work with PEFT adapters" - ) + LOG.warning("BetterTransformers probably doesn't work with PEFT adapters") if cfg.fp16 or cfg.bf16: raise ValueError("AMP is not supported with BetterTransformer") if cfg.float16 is not True and cfg.bloat16 is not True: - logging.warning( + LOG.warning( "You should probably set bfloat16 or float16 to true to " "load the model in float16 for BetterTransformers" ) if int(torch.__version__.split(".")[0]) < 2: - logging.warning("torch>=2.0.0 required") + LOG.warning("torch>=2.0.0 required") raise ValueError( f"flash_optimum for BetterTransformers may not be used with {torch.__version__}" ) if cfg.pretraining_dataset and cfg.group_by_length: - logging.warning( + LOG.warning( "You probably want to disable group_by_length as it will force a streamed dataset to download completely." 
) if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and ( not cfg.optimizer or "adamw" not in cfg.optimizer ): - logging.warning("adamw hyperparameters found, but no adamw optimizer set") + LOG.warning("adamw hyperparameters found, but no adamw optimizer set") if cfg.push_to_hub_model_id: raise ValueError( diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 3ddbe77bf..a3e4cdbdf 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -17,7 +17,7 @@ ) from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter -logging.basicConfig(level="INFO") +LOG = logging.getLogger("axolotl") class TestPromptTokenizationStrategies(unittest.TestCase):
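
The pattern this patch standardizes on — configure logging once via the new src/axolotl/logging_config.py, then have each module log through a named child of the "axolotl" logger — can be exercised on its own roughly as follows. This is a minimal sketch based only on the code added above; the module name "example" is illustrative, not part of the patch.

import logging

from axolotl.logging_config import configure_logging

# Installs the dictConfig defaults added in this patch: a stdout StreamHandler,
# the "simple" formatter, and a root level taken from LOG_LEVEL (default INFO).
configure_logging()

# Modules then fetch a child of the shared "axolotl" logger instead of calling
# the root logging module directly, as done throughout the diff above.
LOG = logging.getLogger("axolotl.example")  # hypothetical module name

LOG.info("loading tokenizer...")
# Prints something like:
# [2023-07-03 12:00:00,000] [INFO] [axolotl.example.<module>:14] [PID:4242] loading tokenizer...

Because every module-level LOG is a child of "axolotl" (or "axolotl.scripts" in scripts/finetune.py), a single LOG_LEVEL environment variable and one console handler now govern output for all of the files touched above.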