
Error while saving checkpoint during training #26732

Closed · 2 of 4 tasks
ghost opened this issue Oct 11, 2023 · 12 comments · Fixed by #26570 or #27099

Comments

@ghost commented Oct 11, 2023

System Info

  • transformers version: 4.34.0
  • Platform: Linux-5.4.0-162-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.18.0
  • Safetensors version: 0.4.0
  • Accelerate version: 0.23.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am training the CodeLlama model on a custom dataset. Training starts, but when it tries to save a checkpoint it raises the error below and stops training.

ERROR:
2023-10-11 11:34:18,589 - ERROR - Error in Logs due to Object of type method is not JSON serializable
CODE:

import json
import torch
import pandas as pd
import datasets
from peft import LoraConfig,PeftModel
from transformers import (AutoModelForCausalLM,AutoTokenizer,TrainingArguments,BitsAndBytesConfig)
import transformers
from trl import SFTTrainer
import os

import logging
import sys

RANK = 16
LR = 1e-4
EPOCH = 10
BATCH = 11


output_dir = f"../results/10-10-2023/{RANK}_RANK--{LR}_LR--{EPOCH}_EPOCH--{BATCH}_BATCH/"


if not os.path.exists(output_dir):
    # If the directory doesn't exist, create it
    os.makedirs(output_dir)
    print(f"Directory '{output_dir}' created.")
else:
    print(f"Directory '{output_dir}' already exists.")


# Create a logger instance
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Create a formatter with the desired format
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

# Create a stream handler to output log messages to the console
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

# Create a file handler to log messages to a file
file_handler = logging.FileHandler(f'{output_dir}/trl-trainer-codellama.txt', encoding='utf-8')  # Specify the file name here
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
console_handler = logging.StreamHandler(stream=sys.stdout)


# DEVICE = "cuda:0" if torch.cuda.is_available() else 'cpu'



MODEL_NAME = "./CodeLlama-7b-Instruct-HF"

# loading dataset
dataset = datasets.load_from_disk("../verilog-dataset/codellama_800L_74052E/")
# loading model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,use_safetensors=True,load_in_8bit=True,trust_remote_code=True,device_map='auto')
# loading tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, add_special_tokens=False, add_eos_token=False, add_bos_token=False)
tokenizer.pad_token = "[PAD]"

# LORA Configuration
peft_config = LoraConfig(
    lora_alpha=RANK*2,
    lora_dropout=0.05,
    r = RANK,
    bias="none",
    task_type = "CAUSAL_LM",
    target_modules = ["q_proj", "v_proj","lm_head"]
)



training_arguments = TrainingArguments(
    per_device_train_batch_size=BATCH,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    learning_rate=LR,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=EPOCH,
    warmup_ratio=0.05,
    logging_steps=5,
    save_total_limit=100,
    save_strategy="steps",
    save_steps=2,
    group_by_length=True,
    output_dir=output_dir,
    report_to="tensorboard",
    save_safetensors=True,
    lr_scheduler_type="cosine",
    seed=42)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=800,
    tokenizer=tokenizer,
    args=training_arguments,
)


try:
    trainer.train()
except Exception as e:
    logger.error(f"Error in Logs due to {e}")

Expected behavior

I am expecting that model should continue training without stopping while saving the checkpoints.

@LysandreJik (Member)

Hmmm, we have very little visibility into the error because your script only logs the error message. Would it be possible to let the exception raise completely so we can see the full traceback?

Also, could you try installing from source to see if your problem is fixed? You can do so with pip install git+https://github.com/huggingface/transformers.
Thanks!
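
For example, a minimal tweak to the try/except in your script (just a sketch, reusing the trainer and logger already defined there) would record the message and still re-raise, so the complete traceback shows up:

try:
    trainer.train()
except Exception:
    # logger.exception records the message together with the full traceback
    logger.exception("Error while saving checkpoint during training")
    # re-raise so the complete stack trace is also printed to the console
    raise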

@ghost (Author) commented Oct 12, 2023

@LysandreJik Please find the full traceback below

TypeError                                 Traceback (most recent call last)
Cell In[x], line 3
      1 # Save the fine-tuned model
----> 3 tokenizer.save_pretrained("tokenfile")

File /3tb/share/anaconda3/envs/ak_env/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:2130, in PreTrainedTokenizerBase.save_pretrained(self, save_directory, legacy_format, filename_prefix, push_to_hub, **kwargs)
   2128 write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
   2129 with open(special_tokens_map_file, "w", encoding="utf-8") as f:
-> 2130     out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
   2131     f.write(out_str)
   2132 logger.info(f"Special tokens file saved in {special_tokens_map_file}")

File /3tb/share/anaconda3/envs/ak_env/lib/python3.10/json/__init__.py:238, in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
    232 if cls is None:
    233     cls = JSONEncoder
    234 return cls(
    235     skipkeys=skipkeys, ensure_ascii=ensure_ascii,
    236     check_circular=check_circular, allow_nan=allow_nan, indent=indent,
    237     separators=separators, default=default, sort_keys=sort_keys,
--> 238     **kw).encode(obj)

File /3tb/share/anaconda3/envs/ak_env/lib/python3.10/json/encoder.py:201, in JSONEncoder.encode(self, o)
    199 chunks = self.iterencode(o, _one_shot=True)
...
    178     """
--> 179     raise TypeError(f'Object of type {o.__class__.__name__} '
    180                     f'is not JSON serializable')

TypeError: Object of type property is not JSON serializable

Temporary Fix 🔧

The issue was happening when the tokenizer was saved as part of the checkpoint. I was able to work around it by removing the tokenizer parameter from the trainer, as below:

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=800,
    # tokenizer=tokenizer,
    args=training_arguments,
)
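
Note that with the tokenizer removed, the Trainer no longer writes tokenizer files into the checkpoints. As a rough sketch (the "final_adapter" path is just an example), the trained adapter can still be saved manually once training finishes:

# Save the fine-tuned PEFT adapter weights manually after training;
# the tokenizer is skipped here because saving it currently raises the error above.
trainer.model.save_pretrained(f"{output_dir}/final_adapter")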

@LysandreJik (Member)

cc @ArthurZucker

@ArthurZucker (Collaborator)

Hey! I think this will be fixed by #26570! I'll keep you updated.

@ArthurZucker (Collaborator)

Hey @humza-sami, could you try running your script with #26570?
Doing something like gh pr checkout 26570 (if you have installed from source) should help.

@ghost (Author) commented Oct 17, 2023

Hi @ArthurZucker, I followed these steps:

pip uninstall transformers
git clone https://github.com/huggingface/transformers.git
cd transformers/
git fetch origin pull/26570/head:pull_26570
git checkout pull_26570
pip install .

Still, when I save the tokenizer, the error is the same.

from transformers import (AutoModelForCausalLM,AutoTokenizer,TrainingArguments,BitsAndBytesConfig)
MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, add_special_tokens=False, add_eos_token=False, add_bos_token=False)
tokenizer.pad_token = None
tokenizer.save_pretrained("sample")

ERROR


Using pad_token, but it is not set yet.
Using pad_token, but it is not set yet.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 1
----> 1 tokenizer.save_pretrained("sample")

File /usr/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:2435, in PreTrainedTokenizerBase.save_pretrained(self, save_directory, legacy_format, filename_prefix, push_to_hub, **kwargs)
   2432     tokenizer_config.pop("special_tokens_map_file", None)
   2434 with open(tokenizer_config_file, "w", encoding="utf-8") as f:
-> 2435     out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
   2436     f.write(out_str)
   2437 logger.info(f"tokenizer config file saved in {tokenizer_config_file}")

File /usr/lib/python3.8/json/__init__.py:234, in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
    232 if cls is None:
    233     cls = JSONEncoder
--> 234 return cls(
    235     skipkeys=skipkeys, ensure_ascii=ensure_ascii,
    236     check_circular=check_circular, allow_nan=allow_nan, indent=indent,
    237     separators=separators, default=default, sort_keys=sort_keys,
    238     **kw).encode(obj)

File /usr/lib/python3.8/json/encoder.py:201, in JSONEncoder.encode(self, o)
    199 chunks = self.iterencode(o, _one_shot=True)
    200 if not isinstance(chunks, (list, tuple)):
--> 201     chunks = list(chunks)
    202 return ''.join(chunks)

File /usr/lib/python3.8/json/encoder.py:431, in _make_iterencode.<locals>._iterencode(o, _current_indent_level)
    429     yield from _iterencode_list(o, _current_indent_level)
    430 elif isinstance(o, dict):
--> 431     yield from _iterencode_dict(o, _current_indent_level)
    432 else:
    433     if markers is not None:

File /usr/lib/python3.8/json/encoder.py:405, in _make_iterencode.<locals>._iterencode_dict(dct, _current_indent_level)
    403         else:
    404             chunks = _iterencode(value, _current_indent_level)
--> 405         yield from chunks
    406 if newline_indent is not None:
    407     _current_indent_level -= 1

File /usr/lib/python3.8/json/encoder.py:438, in _make_iterencode.<locals>._iterencode(o, _current_indent_level)
    436         raise ValueError("Circular reference detected")
    437     markers[markerid] = o
--> 438 o = _default(o)
    439 yield from _iterencode(o, _current_indent_level)
    440 if markers is not None:

File /usr/lib/python3.8/json/encoder.py:179, in JSONEncoder.default(self, o)
    160 def default(self, o):
    161     """Implement this method in a subclass such that it returns
    162     a serializable object for ``o``, or calls the base implementation
    163     (to raise a ``TypeError``).
   (...)
    177 
    178     """
--> 179     raise TypeError(f'Object of type {o.__class__.__name__} '
    180                     f'is not JSON serializable')

TypeError: Object of type method is not JSON serializable

@ArthurZucker (Collaborator)

A bit strange, this worked for me.

@ghost (Author) commented Oct 19, 2023

@ArthurZucker, if possible, can you share the test code snippet you are using so I can compare it against my code?
Please see the simple snippet below, which is causing the issue:

from transformers import AutoTokenizer
MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, add_special_tokens=False, add_eos_token=False, add_bos_token=False)
tokenizer.pad_token = "[PAD]"
tokenizer.save_pretrained("sample")

It's giving me the error. I am using the latest version of transformers, v4.34.1.

@ArthurZucker (Collaborator)

Alright, I can indeed reproduce it now. The init_kwargs include add_special_tokens, which is both a method on the tokenizer and a flag we usually pass to specify whether special tokens should be added. It should not be saved in init_kwargs / it should be filtered out when we serialize. I'll push a fix soon.
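
Until the fix is merged, one possible interim workaround (just a sketch, assuming the offending entry ends up in tokenizer.init_kwargs as described above; behaviour may differ across versions) is to drop that key before saving:

from transformers import AutoTokenizer

MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, add_special_tokens=False)
tokenizer.pad_token = "[PAD]"

# Drop the non-serializable kwarg that from_pretrained stored in init_kwargs,
# so json.dumps inside save_pretrained no longer sees the bound method.
tokenizer.init_kwargs.pop("add_special_tokens", None)

tokenizer.save_pretrained("sample")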

@ArthurZucker (Collaborator)

I'm still working on the PR 😉

@ArthurZucker (Collaborator)

It's planned for this release! 🤗 One small test to fix and it will be merged.

@ArthurZucker (Collaborator)

Thanks for your patience @ghost (oops), this is now fixed.
