Logging update: added PID and formatting #276

Merged 4 commits on Jul 16, 2023
Changes from 3 commits
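For quick reference, the formatter introduced in src/axolotl/logging_config.py prefixes every record with a timestamp, level, logger location, and process ID. A record emitted after this change would look roughly like the line below (the timestamp, line number, and PID are illustrative, not taken from a real run):

[2023-07-16 10:15:42,123] [INFO] [axolotl.scripts.train:311] [PID:4242] Starting trainer...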
3 changes: 3 additions & 0 deletions scripts/alpaca_json_to_jsonl.py
@@ -15,6 +15,9 @@
JsonToJsonlConverter,
StdoutWriter,
)
from axolotl.logging_config import configure_logging

configure_logging()

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
31 changes: 17 additions & 14 deletions scripts/finetune.py
@@ -24,13 +24,16 @@
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
from axolotl.logging_config import configure_logging

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

configure_logging()
LOG = logging.getLogger("axolotl.scripts")


logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"


@@ -212,7 +215,7 @@ def train(

# load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
logging.info(f"loading tokenizer... {tokenizer_config}")
LOG.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)

if (
@@ -234,7 +237,7 @@ def train(
eval_dataset = None

if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...")
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
@@ -243,11 +246,11 @@
)

if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
LOG.info("Finished preparing dataset. Exiting...")
return

# Load the model and tokenizer
logging.info("loading model and peft_config...")
LOG.info("loading model and peft_config...")
model, peft_config = load_model(
cfg.base_model,
cfg.base_model_config,
@@ -258,17 +261,17 @@ def train(
)

if "merge_lora" in kwargs and cfg.adapter is not None:
logging.info("running merge of LoRA with base model")
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)

if cfg.local_rank == 0:
logging.info("saving merged model")
LOG.info("saving merged model")
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return

if cfg.inference:
logging.info("calling do_inference function")
LOG.info("calling do_inference function")
prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
@@ -287,12 +290,12 @@ def train(
model.config.use_cache = False

if torch.__version__ >= "2" and sys.platform != "win32":
logging.info("Compiling torch model")
LOG.info("Compiling torch model")
model = torch.compile(model)

# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)

# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
@@ -308,9 +311,9 @@ def terminate_handler(_, __, model):
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)

logging.info("Starting trainer...")
LOG.info("Starting trainer...")
if cfg.group_by_length:
logging.info("hang tight... sorting dataset for group_by_length")
LOG.info("hang tight... sorting dataset for group_by_length")
resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
@@ -322,7 +325,7 @@ def terminate_handler(_, __, model):
key=lambda path: int(path.split("-")[-1]),
)
resume_from_checkpoint = sorted_paths[-1]
logging.info(
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
)

@@ -336,7 +339,7 @@ def terminate_handler(_, __, model):
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
3 changes: 2 additions & 1 deletion src/axolotl/datasets.py
@@ -14,6 +14,7 @@
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets

LOG = logging.getLogger("axolotl")

class TokenizedPromptDataset(IterableDataset):
"""
@@ -115,7 +116,7 @@ def __iter__(self):
"attention_mask": attention_mask,
}
else:
logging.warning(
LOG.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
)
buffer = {
27 changes: 27 additions & 0 deletions src/axolotl/logging_config.py
@@ -0,0 +1,27 @@
import sys
from logging.config import dictConfig
from typing import Any, Dict

DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
},
"filters": {},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "simple",
"filters": [],
"stream": sys.stdout,
},
},
"root": {"handlers": ["console"], "level": "INFO"},
}


def configure_logging():
"""Configure with default logging"""
dictConfig(DEFAULT_LOGGING_CONFIG)
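The other files in this PR follow the same opt-in pattern: call configure_logging() once near the top of the entry point, then fetch a named logger under the "axolotl" namespace instead of logging through the root logging module. A minimal sketch of that usage, assuming the axolotl package is importable and using a hypothetical message:

import logging

from axolotl.logging_config import configure_logging

configure_logging()  # installs the stdout handler with the PID-aware format on the root logger
LOG = logging.getLogger("axolotl.scripts")  # named child logger, as in scripts/finetune.py

LOG.info("dataset prepared")  # hypothetical message; rendered with the format defined above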
3 changes: 1 addition & 2 deletions src/axolotl/monkeypatch/llama_landmark_attn.py
@@ -52,8 +52,7 @@
logging,
replace_return_docstrings,
)

logger = logging.get_logger(__name__)
LOG = logging.getLogger("axolotl")

_CONFIG_FOR_DOC = "LlamaConfig"

2 changes: 1 addition & 1 deletion src/axolotl/prompt_strategies/pygmalion.py
@@ -64,7 +64,7 @@ def tokenize_prompt(self, prompt):
*copy.deepcopy(res["input_ids"])
][len(self.bot_prefix_token_ids) :]
else:
logging.warning(f"unknown role in conversation: {role}")
LOG.warning(f"unknown role in conversation: {role}")
res = defaultdict(lambda: [])

# pylint: disable=duplicate-code
4 changes: 3 additions & 1 deletion src/axolotl/prompt_tokenizers.py
@@ -10,6 +10,8 @@

from axolotl.prompters import IGNORE_TOKEN_ID

LOG = logging.getLogger("axolotl")

IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
@@ -384,7 +386,7 @@ def tokenize_prompt(self, prompt):
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
logging.warning(f"unhandled role: {part[0]}")
LOG.warning(f"unhandled role: {part[0]}")

# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
2 changes: 1 addition & 1 deletion src/axolotl/prompters.py
@@ -241,7 +241,7 @@ def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
if message:
yield (role + ":", " " + message)
else:
logging.warning(f"role with empty message: {role}")
LOG.warning(f"role with empty message: {role}")
yield (role + ":", "")

def copy(self):
38 changes: 20 additions & 18 deletions src/axolotl/utils/data.py
@@ -35,6 +35,8 @@
SummarizeTLDRPrompter,
)

LOG = logging.getLogger("axolotl")


def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
@@ -73,17 +75,17 @@ def load_tokenized_prepared_datasets(
if dataset:
...
elif any(prepared_ds_path.glob("*")):
logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
logging.info("Prepared dataset loaded from disk...")
LOG.info("Prepared dataset loaded from disk...")
else:
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
logging.info("Loading raw datasets...")
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
LOG.info("Loading raw datasets...")

if cfg.seed:
seed = cfg.seed
else:
logging.info("No seed provided, using default seed of 42")
LOG.info("No seed provided, using default seed of 42")
seed = 42

datasets = []
@@ -256,25 +258,25 @@ def load_tokenized_prepared_datasets(
suffix = ""
if ":load_" in d.type:
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
logging.error(
LOG.error(
f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
)
raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
)
logging.info("tokenizing, merging, and shuffling master dataset")
LOG.info("tokenizing, merging, and shuffling master dataset")

samples: List[int] = []
for d in datasets:
samples = samples + list(d)
dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0:
logging.info(
LOG.info(
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
)
dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub:
logging.info(
LOG.info(
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
@@ -325,7 +327,7 @@ def load_prepare_datasets(
use_auth_token = cfg.hf_use_auth_token
try:
if cfg.push_dataset_to_hub:
logging.info(
LOG.info(
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset = load_dataset(
@@ -339,13 +341,13 @@
if dataset:
...
elif any(prepared_ds_path.glob("*")):
logging.info(
LOG.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
)
dataset = load_from_disk(str(prepared_ds_path))
logging.info("Prepared packed dataset loaded from disk...")
LOG.info("Prepared packed dataset loaded from disk...")
if cfg.push_dataset_to_hub:
logging.info(
LOG.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
@@ -364,7 +366,7 @@ def load_prepare_datasets(
[dataset],
seq_length=max_packed_sequence_len,
)
logging.info(
LOG.info(
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
)
dataset = Dataset.from_list(list(constant_len_dataset))
@@ -382,12 +384,12 @@ def load_prepare_datasets(
)

if cfg.local_rank == 0:
logging.info(
LOG.info(
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
)
dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub:
logging.info(
LOG.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
@@ -400,7 +402,7 @@ def load_prepare_datasets(
)

if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
logging.info(
LOG.info(
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
)
dataset = dataset.shard(
@@ -521,7 +523,7 @@ def encode_pretraining(tokenizer, max_tokens, examples):
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}

logging.debug(len(ret["input_ids"]))
LOG.debug(len(ret["input_ids"]))
return ret

