Max Length rework (#741)
* implementation

* c

* implementation

* docs

* c

* don't allow freeze and lora

* implementation

* format

* rm

* c

* c

* c

* fix

* r

* c

* format

* fix

* docs

* readme

* ui test

* fixing unfreeze + lora + dpo

* fixing position_id issue

* c

* c
psinger authored Jun 6, 2024
1 parent 9cafe8c commit bf7f7f9
Showing 33 changed files with 137 additions and 133 deletions.
13 changes: 13 additions & 0 deletions Makefile
@@ -99,6 +99,19 @@ test: reports
-o log_cli=true -o log_level=INFO -o log_file=reports/tests.log \
tests/* 2>&1 | tee reports/tests.log'


.PHONY: test-debug
test-debug: reports
@bash -c 'set -o pipefail; export PYTHONPATH=$(PWD); \
$(PIPENV) run pytest -v --junitxml=reports/junit.xml \
--import-mode importlib \
--html=./reports/pytest.html \
-k test_encode \
-s \
-o log_cli=false -o log_level=WARNING -o log_file=/dev/null \
tests/*'


.PHONY: test-ui
test-ui: reports setup-ui
@bash -c 'set -o pipefail; \
1 change: 1 addition & 0 deletions README.md
@@ -54,6 +54,7 @@ Using CLI for fine-tuning LLMs:
## What's New

- [PR 747](https://github.com/h2oai/h2o-llmstudio/pull/747) Fully removed RLHF in favor of DPO/IPO/KTO optimization.
- [PR 741](https://github.com/h2oai/h2o-llmstudio/pull/741) Removed the separate max length settings for prompt and answer in favor of a single `max_length` setting that better resembles the `chat_template` functionality from `transformers` (see the sketch after this list).
- [PR 592](https://github.com/h2oai/h2o-llmstudio/pull/599) Added `KTOPairLoss` for DPO modeling, allowing training of models with simple preference data. Data currently needs to be manually prepared by randomly matching positive and negative examples as pairs.
- [PR 592](https://github.com/h2oai/h2o-llmstudio/pull/592) Starting to deprecate RLHF in favor of DPO/IPO optimization. Training is disabled, but old experiments are still viewable. RLHF will be fully removed in a future release.
- [PR 530](https://github.com/h2oai/h2o-llmstudio/pull/530) Introduced a new problem type for DPO/IPO optimization. This optimization technique can be used as an alternative to RLHF.
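As a rough illustration of the behaviour the new single `max_length` aims for (a minimal sketch, not H2O LLM Studio code; the tokenizer name and the 512 budget are placeholders), the whole templated conversation is tokenized once and truncated from the left instead of budgeting prompt and answer separately:

```python
from transformers import AutoTokenizer

# Placeholder model name; any tokenizer with a chat template works here.
tokenizer = AutoTokenizer.from_pretrained("h2oai/h2o-danube2-1.8b-chat")

messages = [
    {"role": "user", "content": "Explain tokenization in one sentence."},
    {"role": "assistant", "content": "Tokenization splits text into model-readable units."},
]

# One shared budget for the whole sample, analogous to the single `max_length` setting.
max_length = 512
input_ids = tokenizer.apply_chat_template(messages, tokenize=True)
# Keep the most recent tokens when the conversation exceeds the budget (left truncation).
input_ids = input_ids[-max_length:]
```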
2 changes: 0 additions & 2 deletions documentation/docs/get-started/llm-studio-performance.md
@@ -137,8 +137,6 @@ problem_type: text_causal_language_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 512
max_length_answer: 256
max_length_prompt: 256
padding_quantile: 1.0
training:
batch_size: 2
10 changes: 0 additions & 10 deletions documentation/docs/guide/experiments/experiment-settings.md
@@ -19,8 +19,6 @@ import DStextAnswerSeparator from '../../tooltips/experiments/_text-answer-separ
import DSaddEosTokentoprompt from '../../tooltips/experiments/_add-eos-token-to-prompt.mdx';
import DSaddEosTokentoanswer from '../../tooltips/experiments/_add-eos-token-to-answer.mdx';
import DSmaskPromptlabels from '../../tooltips/experiments/_mask-prompt-labels.mdx';
import TSmaxLengthPrompt from '../../tooltips/experiments/_max-length-prompt.mdx';
import TSmaxLengthAnswer from '../../tooltips/experiments/_max-length-answer.mdx';
import TSmaxLength from '../../tooltips/experiments/_max-length.mdx';
import TSaddpromptanswertokens from '../../tooltips/experiments/_add-prompt-answer-tokens.mdx';
import TSpaddingQuantile from '../../tooltips/experiments/_padding-quantile.mdx';
@@ -173,14 +171,6 @@ The settings under each category are listed and described below.

## Tokenizer settings

### Max length prompt

<TSmaxLengthPrompt/>

### Max length answer

<TSmaxLengthAnswer/>

### Max length

<TSmaxLength/>
2 changes: 1 addition & 1 deletion documentation/docs/tooltips/experiments/_answer-column.mdx
@@ -1,3 +1,3 @@
The column in the dataset containing the expected output.

For classification, this needs to be an integer column containing the class label.
For classification, this needs to be an integer column containing the class label, with labels starting from zero.

This file was deleted.

This file was deleted.

2 changes: 0 additions & 2 deletions examples/example_oasst2.yaml
@@ -76,8 +76,6 @@ problem_type: text_causal_language_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 512
max_length_answer: 256
max_length_prompt: 256
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training:
2 changes: 1 addition & 1 deletion llm_studio/app_utils/sections/chat_update.py
@@ -113,7 +113,7 @@ async def answer_chat(q: Q) -> str:
logger.info(f"Full prompt: {full_prompt}")

inputs = cfg.dataset.dataset_class.encode(
tokenizer, full_prompt, cfg.tokenizer.max_length_prompt, "left"
tokenizer, full_prompt, cfg.tokenizer.max_length, "left"
)
inputs["prompt_input_ids"] = (
inputs.pop("input_ids").unsqueeze(0).to(cfg.environment._device)
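The dataset class's `encode` helper is only visible at its call sites in this commit. As a rough mental model (an assumption, not the actual implementation in `text_causal_language_modeling_ds.py`), it tokenizes the text and truncates to `max_length` from the requested side:

```python
import torch


def encode(tokenizer, text: str, max_length: int, truncation_side: str) -> dict:
    """Assumed behavior: tokenize `text`, then truncate to `max_length` from one side."""
    input_ids = tokenizer(text, add_special_tokens=False, return_tensors="pt")["input_ids"][0]
    if truncation_side == "left":
        input_ids = input_ids[-max_length:]
    else:
        input_ids = input_ids[:max_length]
    return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}
```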
@@ -92,14 +92,11 @@ def __post_init__(self):

@dataclass
class ConfigNLPCausalClassificationTokenizer(ConfigNLPCausalLMTokenizer):
max_length_prompt: int = 512
max_length: int = 512

def __post_init__(self):
super().__post_init__()

self._visibility["max_length_answer"] = -1


@dataclass
class ConfigNLPCausalClassificationArchitecture(ConfigNLPCausalLMArchitecture):
@@ -174,8 +174,8 @@ def __post_init__(self):
)
self._possible_values["differential_learning_rate_layers"] = (
possible_values.String(
values=("backbone", "embed"),
allow_custom=False,
values=("backbone", "embed", "head"),
allow_custom=True,
placeholder="Select optional layers...",
)
)
@@ -250,17 +250,13 @@

@dataclass
class ConfigNLPCausalLMTokenizer(DefaultConfig):
max_length_prompt: int = 256
max_length_answer: int = 256
max_length: int = 512
add_prompt_answer_tokens: bool = False
padding_quantile: float = 1.0
tokenizer_kwargs: str = '{"use_fast": true, "add_prefix_space": false}'

def __post_init__(self):
super().__post_init__()
self._possible_values["max_length_prompt"] = (32, 1024 * 16, 32)
self._possible_values["max_length_answer"] = (32, 1024 * 16, 32)
self._possible_values["max_length"] = (32, 1024 * 16, 32)
self._possible_values["padding_quantile"] = (0, 1, 0.01)
self._padding_side = "left"
@@ -353,7 +349,7 @@ def __post_init__(self):

self._possible_values["num_beams"] = (1, 4, 1)
self._possible_values["temperature"] = (0, 10, 0.05)
self._possible_values["repetition_penalty"] = (1, 10, 0.05)
self._possible_values["repetition_penalty"] = (1, 10, 0.025)
self._possible_values["top_k"] = (0, 100, 1)
self._possible_values["top_p"] = (0.5, 1, 0.05)
self._possible_values["num_history"] = (1, 50, 1)
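A side change in this hunk: `differential_learning_rate_layers` now offers `head` and accepts custom layer names. As background, differential learning rates are typically wired up by grouping parameters via substring matches on their names; a hedged sketch of that general pattern (illustrative values, not the actual llm_studio optimizer code):

```python
import torch


def build_param_groups(model, diff_layers, base_lr=1e-4, diff_lr=1e-5):
    """Group parameters whose names contain any of `diff_layers` under their own LR."""
    diff_params, default_params = [], []
    for name, param in model.named_parameters():
        if any(layer in name for layer in diff_layers):
            diff_params.append(param)
        else:
            default_params.append(param)
    return [
        {"params": default_params, "lr": base_lr},
        {"params": diff_params, "lr": diff_lr},
    ]


# Example: a lower learning rate for embeddings and the newly selectable "head" layers.
# optimizer = torch.optim.AdamW(build_param_groups(model, ["embed", "head"]))
```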
40 changes: 14 additions & 26 deletions llm_studio/src/datasets/text_causal_language_modeling_ds.py
@@ -42,6 +42,9 @@ def __getitem__(self, idx: int) -> Dict:
input_text_dict["prompts"] = [
self.parse_prompt(self.cfg, prompt) for prompt in input_text_dict["prompts"]
]
input_text_dict["answers"] = [
self.parse_answer(self.cfg, answer) for answer in input_text_dict["answers"]
]

sample = dict()
system_encoding, prompt_encodings, answer_encodings = self.get_encodings(
@@ -72,7 +75,7 @@ def __getitem__(self, idx: int) -> Dict:
self.pad_tokens(
answer_encodings[-1],
attention_mask=torch.ones_like(answer_encodings[-1]),
max_length=self.cfg.tokenizer.max_length_answer,
max_length=self.cfg.tokenizer.max_length,
pad_token_id=self.tokenizer.pad_token_id,
direction="right",
prefix="answer_",
@@ -99,14 +102,6 @@ def __getitem__(self, idx: int) -> Dict:
)
)

# make sure system encoding is always prepended if max_length exceeded
if sample["input_ids"][0] != self.tokenizer.pad_token_id:
sample["input_ids"][: len(system_encoding)] = system_encoding
if self.cfg.dataset.mask_prompt_labels and "labels" in sample.keys():
sample["labels"][: len(system_encoding)] = -100
if sample["prompt_input_ids"][0] != self.tokenizer.pad_token_id:
sample["prompt_input_ids"][: len(system_encoding)] = system_encoding

return sample

@staticmethod
@@ -122,6 +117,12 @@ def parse_prompt(cfg: Any, prompt: str):
)
return prompt

@staticmethod
def parse_answer(cfg: Any, answer: str):
if cfg.dataset.add_eos_token_to_answer:
answer += cfg._tokenizer_eos_token
return answer

@staticmethod
def parse_system(cfg: Any, system: str):
# no system tokens if empty
@@ -375,9 +376,6 @@ def get_labels(self, prompt_encodings, answer_encodings):
]
).to(torch.bool)
labels.masked_fill_(prompt_mask, -100)
if self.cfg.dataset.add_eos_token_to_answer:
# eos_token may be equal to pad_token. Add the label back manually.
labels[-1] = self.tokenizer.eos_token_id
if self.cfg.tokenizer.max_length < len(labels):
labels = labels[-self.cfg.tokenizer.max_length :]

@@ -446,27 +444,16 @@ def augment_data(self, encodings):
def _get_sample_encoding(self, system: str, prompt: str, answer: str) -> List:
if len(system) > 0:
system_encoding = self.encode(
self.tokenizer, system, self.cfg.tokenizer.max_length_prompt, "right"
self.tokenizer, system, self.cfg.tokenizer.max_length, "right"
)["input_ids"]
else:
system_encoding = torch.empty(0)
prompt_encoding = self.encode(
self.tokenizer, prompt, self.cfg.tokenizer.max_length_prompt, "left"
self.tokenizer, prompt, self.cfg.tokenizer.max_length, "left"
)["input_ids"]
max_length_answer = self.cfg.tokenizer.max_length_answer - int(
self.cfg.dataset.add_eos_token_to_answer
)
answer_encoding = self.encode(
self.tokenizer, answer, max_length_answer, "right"
self.tokenizer, answer, self.cfg.tokenizer.max_length, "right"
)["input_ids"]
if self.cfg.dataset.add_eos_token_to_answer:
answer_encoding = torch.cat(
[
answer_encoding,
torch.Tensor([self.tokenizer.eos_token_id]),
],
dim=0,
)

return [system_encoding, prompt_encoding, answer_encoding]

@@ -482,6 +469,7 @@ def pad_tokens(
sample = {}

if max_length < len(input_ids):
logger.info(f"Input exceeds max_length of {max_length}, truncating sample.")
input_ids = input_ids[-max_length:]
attention_mask = attention_mask[-max_length:]

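Taken together, these dataset changes append the EOS token to the raw answer text via the new `parse_answer` and truncate system, prompt, and answer against the same `cfg.tokenizer.max_length`. A simplified sketch of the resulting flow (hypothetical helper name, reusing the `encode` sketch above, not the exact class methods):

```python
def build_sample_encodings(tokenizer, cfg, system: str, prompt: str, answer: str):
    """Sketch of the reworked flow: one shared max_length for every part of the sample."""
    if cfg.dataset.add_eos_token_to_answer:
        # EOS is now appended to the raw answer text (see `parse_answer` above),
        # so labels no longer need the EOS id patched back in afterwards.
        answer = answer + tokenizer.eos_token

    max_length = cfg.tokenizer.max_length
    system_ids = (
        encode(tokenizer, system, max_length, "right")["input_ids"] if system else None
    )
    prompt_ids = encode(tokenizer, prompt, max_length, "left")["input_ids"]
    answer_ids = encode(tokenizer, answer, max_length, "right")["input_ids"]
    return system_ids, prompt_ids, answer_ids
```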
@@ -5,7 +5,11 @@
from transformers import AutoModelForCausalLM

from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import create_nlp_backbone, prepare_lora
from llm_studio.src.utils.modeling_utils import (
create_nlp_backbone,
forward,
prepare_lora,
)

logger = logging.getLogger(__name__)

@@ -67,7 +71,8 @@ def forward(
padding_side=self.cfg.tokenizer._padding_side,
)

output = self.backbone(
output = forward(
self.backbone,
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
)
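The model classes now route their backbone calls through a shared `forward` helper imported from `modeling_utils`, which is not part of this diff. Given the "fixing position_id issue" line in the commit message, a plausible (purely speculative) sketch is a thin wrapper that derives `position_ids` from the attention mask so left-padded batches get correct positions; the real helper may do more or less than this:

```python
import torch


def forward(backbone, input_ids, attention_mask):
    """Speculative sketch: build position_ids from the attention mask for left-padded batches."""
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return backbone(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
    )
```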
6 changes: 2 additions & 4 deletions llm_studio/src/models/text_causal_language_modeling_model.py
@@ -8,6 +8,7 @@
from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import (
create_nlp_backbone,
forward,
generate,
prepare_lora,
)
@@ -92,10 +93,7 @@ def forward(
padding_side=self.cfg.tokenizer._padding_side,
)

output = self.backbone(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
output = forward(self.backbone, batch["input_ids"], batch["attention_mask"])

if "labels" in batch:
loss = self.loss_fn(output.logits, batch["labels"])
22 changes: 16 additions & 6 deletions llm_studio/src/models/text_dpo_modeling_model.py
@@ -13,6 +13,7 @@
from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import (
create_nlp_backbone,
forward,
generate,
prepare_lora,
)
@@ -83,7 +84,12 @@ def __init__(self, cfg: Any):

if cfg.training.lora:
self.backbone = prepare_lora(cfg=cfg, backbone=self.backbone)

if cfg.training.lora and not cfg.training.lora_unfreeze_layers:
self.backbone_orig = None
else:
if cfg.environment._local_rank == 0:
logger.info("Duplicating backbone for reference model.")
self.backbone_orig, self.backbone_orig_config = create_nlp_backbone(
cfg, model_class=AutoModelForCausalLM
)
@@ -137,7 +143,8 @@ def forward(
f"{answer}_labels",
],
)
logits = self.backbone(
logits = forward(
self.backbone,
input_ids=batch[f"{answer}_input_ids"],
attention_mask=batch[f"{answer}_attention_mask"],
).logits
@@ -152,18 +159,21 @@
)

with torch.no_grad():
if self.cfg.training.lora:
with self.backbone.disable_adapter():
reference_logits = self.backbone(
if self.backbone_orig:
with torch.no_grad():
reference_logits = forward(
self.backbone_orig,
input_ids=batch[f"{answer}_input_ids"],
attention_mask=batch[f"{answer}_attention_mask"],
).logits
else:
with torch.no_grad():
reference_logits = self.backbone_orig(
with self.backbone.disable_adapter():
reference_logits = forward(
self.backbone,
input_ids=batch[f"{answer}_input_ids"],
attention_mask=batch[f"{answer}_attention_mask"],
).logits

outputs[f"{answer}_reference_logps"] = get_batch_logps(
reference_logits,
batch[f"{answer}_labels"],
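The reference-policy handling above is the interesting part of this DPO change: with LoRA and a fully frozen base model, disabling the adapters recovers the reference model for free, but once layers are unfrozen via `lora_unfreeze_layers` (or LoRA is off entirely) a separate frozen copy of the backbone (`backbone_orig`) is kept instead. `get_batch_logps` itself is only called here, not shown; a hedged sketch of the usual DPO-style implementation (the in-repo version may differ in details):

```python
import torch
import torch.nn.functional as F


def get_batch_logps(logits, labels, average_log_prob=False):
    """Sketch: log-probability of the label tokens under `logits`, ignoring -100 positions."""
    # Shift so that the token at position t is predicted from everything before t.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    mask = labels != -100
    labels[labels == -100] = 0  # any valid index; these positions are masked out below

    per_token_logps = torch.gather(
        F.log_softmax(logits, dim=-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)

    if average_log_prob:
        return (per_token_logps * mask).sum(-1) / mask.sum(-1)
    return (per_token_logps * mask).sum(-1)
```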