[enhancement] refactoring for better usability #87

Merged · 7 commits · Jul 26, 2024
8 changes: 4 additions & 4 deletions Install.md
@@ -38,7 +38,7 @@ then enter the following code:
```python
import mlora
mlora.setup_logging("INFO")
-mlora.get_backend().check_available()
+mlora.backend.check_available()
```
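For reference, a minimal sketch of the renamed accessor in use. This PR replaces the `mlora.get_backend()` call with a module-level `mlora.backend` object across the codebase; the snippet below assumes a working m-LoRA installation and uses only calls that appear in this diff:

```python
import mlora

mlora.setup_logging("INFO")

# `mlora.backend` is now a module-level object rather than the result of a
# get_backend() call; it exposes the same availability and device helpers.
if mlora.backend.check_available():
    print(mlora.backend.default_device_name())
```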

Expected output:
@@ -88,7 +88,7 @@ then enter the following code:
```python
import mlora
mlora.setup_logging("INFO")
-mlora.get_backend().check_available()
+mlora.backend.check_available()
```

Expected output:
@@ -138,7 +138,7 @@ then enter the following code:
```python
import mlora
mlora.setup_logging("INFO")
-mlora.get_backend().check_available()
+mlora.backend.check_available()
```

Expected output:
@@ -183,7 +183,7 @@ then enter the following code:
```python
import mlora
mlora.setup_logging("INFO")
-mlora.get_backend().check_available()
+mlora.backend.check_available()
```

Expected output:
10 changes: 10 additions & 0 deletions README.md
@@ -108,6 +108,16 @@ It's crucial to note that regardless of the settings, **LoRA weights are always

For users with NVIDIA Ampere or newer GPU architectures, the `--tf32` option can be utilized to enable full-precision calculation acceleration.

+## Offline Configuration
+
+m-LoRA relies on the **HuggingFace Hub** to download the necessary models, datasets, etc. If you cannot access the Internet or need to deploy m-LoRA in an offline environment, please refer to the following guide.
+
+1. Use `git-lfs` to manually download models and datasets from the [HuggingFace Hub](https://huggingface.co).
+2. Set `--data_path` to the local path of the datasets when executing `launch.py gen`.
+3. Clone the [evaluate](https://github.com/huggingface/evaluate) code repository locally.
+4. Set the environment variable `MLORA_METRIC_PATH` to the local path of the evaluate code repository.
+5. Set `--base_model` to the local path of the models when executing `launch.py run`.
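To make these steps concrete, here is a hedged sketch of an offline setup. Every path below is a hypothetical placeholder, and the flags are the ones documented in this PR:

```python
import os

# Hypothetical local paths produced by the manual git-lfs downloads (steps 1 and 3).
model_dir = "/models/TinyLlama-1.1B"   # local clone of a base model
dataset_dir = "/datasets/dummy-mlora"  # local clone of a dataset
evaluate_dir = "/opt/evaluate"         # local clone of huggingface/evaluate

# Step 4: point m-LoRA's metric loader at the local evaluate checkout.
os.environ["MLORA_METRIC_PATH"] = evaluate_dir

# Steps 2 and 5 then become ordinary CLI invocations, for example:
#   python launch.py gen --template <template> --tasks <task> --data_path /datasets/dummy-mlora
#   python launch.py run --base_model /models/TinyLlama-1.1B
```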

## Known issues

+ Quantization with Qwen2 has no effect (same as with transformers).
6 changes: 4 additions & 2 deletions evaluator.py
@@ -7,6 +7,7 @@
def main(
    base_model: str,
    task_name: str,
+    data_path: str = None,
    lora_weights: str = None,
    load_16bit: bool = True,
    load_8bit: bool = False,
@@ -15,12 +16,12 @@ def main(
    save_file: str = None,
    batch_size: int = 32,
    router_profile: bool = False,
-    device: str = mlora.get_backend().default_device_name(),
+    device: str = mlora.backend.default_device_name(),
):

    mlora.setup_logging("INFO")

-    if not mlora.get_backend().check_available():
+    if not mlora.backend.check_available():
        exit(-1)

    model = mlora.LLMModel.from_pretrained(
@@ -39,6 +40,7 @@ def main(
    evaluate_paramas = mlora.EvaluateConfig(
        adapter_name=adapter_name,
        task_name=task_name,
+        data_path=data_path,
        batch_size=batch_size,
        router_profile=router_profile,
    )
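A hedged usage sketch of the new argument: with `data_path` set, the evaluator can read a local dataset copy instead of downloading one. The task name and path below are illustrative, not values from this PR:

```python
import mlora

# `data_path` is the optional argument added in this PR; the other
# keyword arguments already existed on EvaluateConfig.
evaluate_config = mlora.EvaluateConfig(
    adapter_name="lora_0",
    task_name="mmlu",            # hypothetical task name
    data_path="/datasets/mmlu",  # hypothetical local dataset path
    batch_size=32,
)
```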
2 changes: 1 addition & 1 deletion generate.py
@@ -22,7 +22,7 @@ def main(
    flash_attn: bool = False,
    max_seq_len: int = None,
    stream: bool = False,
-    device: str = mlora.get_backend().default_device_name(),
+    device: str = mlora.backend.default_device_name(),
):

    model = mlora.LLMModel.from_pretrained(
2 changes: 1 addition & 1 deletion inference.py
@@ -72,7 +72,7 @@ def main(
    load_8bit: bool = False,
    load_4bit: bool = False,
    flash_attn: bool = False,
-    device: str = mlora.get_backend().default_device_name(),
+    device: str = mlora.backend.default_device_name(),
    server_name: str = "0.0.0.0",
    share_gradio: bool = False,
):
24 changes: 19 additions & 5 deletions launch.py
@@ -65,13 +65,15 @@ def update_record(dict_: dict, key_, value_):
def gen_config(
    # essential
    template: str,
-    tasks: str,
+    task_list: str,
    # optional
    adapter_name: str = None,
    file_name: str = "mlora.json",
+    data_path: str = None,
    multi_task: bool = False,
    append: bool = False,
    # default value provided by template
+    prompt_template: str = None,
    cutoff_len: int = None,
    save_step: int = None,
    lr_scheduler: str = None,
@@ -107,14 +109,22 @@ def gen_config(

    index = len(template_obj["lora"])
    if multi_task:
-        tasks = [tasks]
+        task_list = [task_list]
+        path_list = [data_path]
    else:
-        tasks = tasks.split(";")
+        task_list = task_list.split(";")
+        path_list = (
+            [None] * len(task_list) if data_path is None else data_path.split(";")
+        )

    for lora_template in lora_templates:
-        for task_name in tasks:
+        for task_name, data_path in zip(task_list, path_list):
            lora_config = lora_template.copy()
-            if task_name not in mlora.tasks.task_dict:
+            if multi_task:
+                lora_config["name"] = f"multi_task_{index}"
+                lora_config["task_name"] = task_name
+            elif task_name not in mlora.tasks.task_dict:
                assert os.path.exists(task_name), f"File '{task_name}' not exist."
                lora_config["name"] = f"casual_{index}"
                lora_config["task_name"] = "casual"
                lora_config["data"] = task_name
@@ -128,6 +138,8 @@
            if adapter_name is not None:
                lora_config["name"] = f"{adapter_name}_{index}"

+            update_record(lora_config, "data", data_path)
+            update_record(lora_config, "prompt", prompt_template)
            update_record(lora_config, "scheduler_type", lr_scheduler)
            update_record(lora_config, "warmup_steps", warmup_steps)
            update_record(lora_config, "lr", learning_rate)
@@ -173,8 +185,10 @@ def show_help():
    --tasks            task names separate by ';'
    --adapter_name     default is task name
    --file_name        default is 'mlora.json'
+    --data_path        path to input data
    --multi_task       multi-task training
    --append           append to existed config
+    --prompt_template  [alpaca]
    --cutoff_len
    --save_step
    --warmup_steps
11 changes: 7 additions & 4 deletions misc/finetune-demo.ipynb
@@ -66,7 +66,7 @@
"\n",
"model = mlora.LLMModel.from_pretrained(\n",
" base_model,\n",
" device=mlora.get_backend().default_device_name(),\n",
" device=mlora.backend.default_device_name(),\n",
" load_dtype=torch.bfloat16,\n",
")\n",
"tokenizer = mlora.Tokenizer(base_model)"
@@ -96,12 +96,13 @@
"model.init_adapter(lora_config)\n",
"\n",
"train_config = mlora.TrainConfig(\n",
" adapter_name=\"lora_0\",\n",
" data_path=\"TUDB-Labs/Dummy-mLoRA\",\n",
" num_epochs=10,\n",
" batch_size=16,\n",
" micro_batch_size=8,\n",
" learning_rate=1e-4,\n",
" casual_train_data=\"TUDB-Labs/Dummy-mLoRA\",\n",
").init(lora_config)\n",
")\n",
"\n",
"mlora.train(model=model, tokenizer=tokenizer, configs=[train_config])"
]
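Pulled out of the notebook JSON, a minimal sketch of the updated `TrainConfig` construction; the values are copied from the cell above, and the old `casual_train_data` argument and `.init(lora_config)` chain are gone:

```python
import mlora

# TrainConfig now takes the adapter name and data path directly.
train_config = mlora.TrainConfig(
    adapter_name="lora_0",
    data_path="TUDB-Labs/Dummy-mLoRA",
    num_epochs=10,
    batch_size=16,
    micro_batch_size=8,
    learning_rate=1e-4,
)
```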
@@ -125,7 +126,9 @@
" stop_token=\"\\n\",\n",
")\n",
"\n",
"output = mlora.generate(model=model, tokenizer=tokenizer, configs=[generate_config])\n",
"output = mlora.generate(\n",
" model=model, tokenizer=tokenizer, configs=[generate_config], max_gen_len=128\n",
")\n",
"\n",
"print(output[\"lora_0\"][0])"
]
6 changes: 3 additions & 3 deletions misc/mmlu_evaluate.py
@@ -263,7 +263,7 @@ def do_evaluate(
    model_dtype: str,
    adapter_names: List[str],
    batch_size: int = 2,
-    device: str = mlora.get_backend().default_device_name(),
+    device: str = mlora.backend.default_device_name(),
    output: str = "mmlu_scores.csv",
):
    tokenizer = mlora.Tokenizer(model_name)
@@ -301,9 +301,9 @@ def do_evaluate(


def main(config: str):
-    mlora.get_backend().manual_seed(66)
+    mlora.backend.manual_seed(66)
    mlora.setup_logging("INFO")
-    if not mlora.get_backend().check_available():
+    if not mlora.backend.check_available():
        exit(-1)
    with open(config, "r", encoding="utf8") as fp:
        mmlu_config = json.load(fp)
79 changes: 17 additions & 62 deletions mlora.py
@@ -148,81 +148,36 @@ def init_adapter_config(
logging.info(f"Setting cutoff_len to {llm_model.max_seq_len_} automatically.")

for lora_config in config["lora"]:
lora_weight = None
config_class = mlora.lora_config_factory(lora_config)
config_class.adapter_name = lora_config["name"]
config_class.task_name = lora_config.get("task_name", "casual")
config_class.device = args.device

adapter_file_path = (
args.dir + os.sep + config_class.adapter_name + os.sep + "adapter_model.bin"
)
if args.load_adapter:
adapter_config_path = (
args.dir
+ os.sep
+ config_class.adapter_name
+ os.sep
+ "adapter_config.json"
)
logging.info(f"Load adapter: {adapter_file_path}")
with open(adapter_config_path, "r", encoding="utf8") as fp:
adapter_config = json.load(fp)
base_model_name_or_path = adapter_config.get(
"base_model_name_or_path", ""
)
if (
base_model_name_or_path != ""
and base_model_name_or_path != llm_model.name_or_path_
):
raise ValueError(
"loading adapter with unmatched base model."
+ f" current is {llm_model.name_or_path_}, provided {base_model_name_or_path}"
)
lora_weight = torch.load(adapter_file_path, map_location=args.device)
elif os.path.isfile(adapter_file_path):
adapter_name = lora_config["name"]
adapter_path = f"{args.dir}{os.sep}{adapter_name}"
if not args.load_adapter and os.path.exists(adapter_path):
if args.overwrite:
logging.warning(
f"Overwriting existed adapter model file: {adapter_file_path}"
f"Overwriting existed adapter model file: {adapter_path}"
)
elif not query_yes_no(
f"Existed adapter model file detected: {adapter_file_path}\n"
+ "Overwrite?"
f"Existed adapter model file detected: {adapter_path}\n" + "Overwrite?"
):
logging.info("User canceled training due to file conflict.")
exit(0)

if args.verbose:
logging.info(config_class.__dict__)
if args.load_adapter:
llm_model.load_adapter(adapter_path, adapter_name)
else:
llm_model.init_adapter(mlora.lora_config_factory(lora_config))

llm_model.init_adapter(config_class, lora_weight)
if args.inference:
config_class = mlora.GenerateConfig(adapter_name=config_class.adapter_name)
config_class = mlora.GenerateConfig(adapter_name=adapter_name)
if not args.disable_prompter:
config_class.prompt_template = lora_config.get("prompt", None)
config_list.append(config_class)
elif args.evaluate:
if ";" in config_class.task_name:
for task_name in config_class.task_name.split(";"):
config_list.append(
mlora.EvaluateConfig(
adapter_name=config_class.adapter_name,
task_name=task_name,
batch_size=lora_config["evaluate_batch_size"],
)
)
else:
config_list.append(
mlora.EvaluateConfig(
adapter_name=config_class.adapter_name,
task_name=config_class.task_name,
batch_size=lora_config["evaluate_batch_size"],
)
)
config_list.extend(mlora.EvaluateConfig.from_config(lora_config))
else:
config_list.append(
mlora.TrainConfig().from_config(lora_config).init_for(config_class)
)
config_list.append(mlora.TrainConfig.from_config(lora_config))

if args.verbose:
logging.info(config_list[-1].__dict__)

return config_list
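The net effect of this hunk: per-adapter setup collapses to two `LLMModel` calls. A hedged sketch of the new flow, assuming only the API shown in this hunk; the model path and the minimal config dict are illustrative placeholders, and a real config entry carries the full LoRA hyperparameters:

```python
import mlora
import torch

base_model = "/models/TinyLlama-1.1B"  # placeholder path

llm_model = mlora.LLMModel.from_pretrained(
    base_model,
    device=mlora.backend.default_device_name(),
    load_dtype=torch.bfloat16,
)

lora_config = {"name": "lora_0"}  # illustrative; real entries hold LoRA hyperparameters
adapter_name = lora_config["name"]
adapter_path = f"./adapters/{adapter_name}"  # illustrative directory layout

resume = False  # True to reuse previously saved adapter weights
if resume:
    # load_adapter replaces the old manual adapter_config.json validation
    # plus the torch.load of adapter_model.bin.
    llm_model.load_adapter(adapter_path, adapter_name)
else:
    llm_model.init_adapter(mlora.lora_config_factory(lora_config))
```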

@@ -277,7 +232,7 @@ def inference(

mlora.setup_logging("INFO", args.log_file)

mlora_backend = mlora.get_backend()
mlora_backend = mlora.backend

if not mlora_backend.check_available():
exit(-1)
@@ -293,7 +248,7 @@
        args.attn_impl = "eager"

    if args.device is None:
-        args.device = mlora.get_backend().default_device_name()
+        args.device = mlora.backend.default_device_name()

    mlora_backend.use_deterministic_algorithms(args.deterministic)
    mlora_backend.allow_tf32(args.tf32)
5 changes: 2 additions & 3 deletions mlora/__init__.py
@@ -1,4 +1,4 @@
-from .backends import get_backend
+from .backends import backend
from .common import (
AdapterConfig,
Cache,
@@ -26,7 +26,6 @@
"transformers", "4.43.0"
), "m-LoRA requires transformers>=4.43.0"


setup_logging()

__all__ = [
@@ -53,5 +52,5 @@
"Prompter",
"Tokenizer",
"setup_logging",
"get_backend",
"backend",
]