Max Length rework #741

Merged: 32 commits, merged on Jun 6, 2024. Changes shown from 21 of 32 commits.

Commits
1f83575  implementation (psinger, May 28, 2024)
d7ca016  c (psinger, May 28, 2024)
6635139  Merge remote-tracking branch 'origin/psi/dpo_nonlora' into psi/freeze (psinger, May 28, 2024)
cfde40f  implementation (psinger, May 29, 2024)
067aae8  docs (psinger, May 29, 2024)
c2d49bb  Merge branch 'main' into psi/freeze (psinger, May 29, 2024)
39b1e42  c (psinger, May 29, 2024)
7c92aba  dont allow freeze and lora (psinger, May 31, 2024)
e950cf3  implementation (psinger, Jun 4, 2024)
aa62d85  format (psinger, Jun 4, 2024)
0bf74a0  rm (psinger, Jun 4, 2024)
a92c1e3  Merge branch 'main' into psi/maxlength (psinger, Jun 4, 2024)
8188859  c (psinger, Jun 4, 2024)
2919641  c (psinger, Jun 4, 2024)
51395fa  c (psinger, Jun 4, 2024)
ef2a5fd  fix (psinger, Jun 5, 2024)
5ef1b8a  r (psinger, Jun 5, 2024)
2467b17  Merge branch 'main' into psi/maxlength (psinger, Jun 5, 2024)
2bc4a68  c (psinger, Jun 5, 2024)
1aed424  format (psinger, Jun 5, 2024)
11a864c  Merge branch 'main' into psi/maxlength (psinger, Jun 5, 2024)
d904d64  fix (psinger, Jun 5, 2024)
443afd5  Merge branch 'psi/maxlength' of https://github.com/h2oai/h2o-llmstudi… (psinger, Jun 5, 2024)
0b50ed0  docs (psinger, Jun 5, 2024)
20ad2b0  Merge branch 'main' into psi/maxlength (psinger, Jun 5, 2024)
8755f20  readme (psinger, Jun 5, 2024)
025b297  ui test (psinger, Jun 5, 2024)
3fb997d  fixing unfreeze + lora + dpo (psinger, Jun 6, 2024)
1fcd6df  Merge branch 'psi/maxlength' of https://github.com/h2oai/h2o-llmstudi… (psinger, Jun 6, 2024)
cf71314  fixing position_id issue (psinger, Jun 6, 2024)
5119663  c (psinger, Jun 6, 2024)
caf0db8  c (psinger, Jun 6, 2024)

13 changes: 13 additions & 0 deletions Makefile
@@ -99,6 +99,19 @@ test: reports
-o log_cli=true -o log_level=INFO -o log_file=reports/tests.log \
tests/* 2>&1 | tee reports/tests.log'


.PHONY: test-debug
test-debug: reports
@bash -c 'set -o pipefail; export PYTHONPATH=$(PWD); \
$(PIPENV) run pytest -v --junitxml=reports/junit.xml \
--import-mode importlib \
--html=./reports/pytest.html \
-k test_encode \
-s \
-o log_cli=false -o log_level=WARNING -o log_file=/dev/null \
tests/*'


.PHONY: test-ui
test-ui: reports setup-ui
@bash -c 'set -o pipefail; \
2 changes: 0 additions & 2 deletions documentation/docs/get-started/llm-studio-performance.md
@@ -137,8 +137,6 @@ problem_type: text_causal_language_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 512
max_length_answer: 256
max_length_prompt: 256
padding_quantile: 1.0
training:
batch_size: 2
@@ -1,3 +1,3 @@
The column in the dataset containing the expected output.

For classification, this needs to be an integer column containing the class label.
For classification, this needs to be an integer column starting from zero containing the class label.
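A small illustrative sketch (not part of this pull request's diff) of preparing such a column; the column name and pandas usage are assumptions for illustration only:

    import pandas as pd

    df = pd.DataFrame({"label": ["cat", "dog", "cat", "bird"]})
    # factorize assigns zero-based integer ids in order of first appearance,
    # matching the "integer column starting from zero" requirement above.
    df["label"], classes = pd.factorize(df["label"])
    print(df["label"].tolist())  # [0, 1, 0, 2]
    print(list(classes))         # ['cat', 'dog', 'bird']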

This file was deleted.


This file was deleted.

2 changes: 0 additions & 2 deletions examples/example_oasst2.yaml
@@ -76,8 +76,6 @@ problem_type: text_causal_language_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 512
max_length_answer: 256
max_length_prompt: 256
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training:
2 changes: 1 addition & 1 deletion llm_studio/app_utils/sections/chat_update.py
@@ -113,7 +113,7 @@ async def answer_chat(q: Q) -> str:
logger.info(f"Full prompt: {full_prompt}")

inputs = cfg.dataset.dataset_class.encode(
tokenizer, full_prompt, cfg.tokenizer.max_length_prompt, "left"
tokenizer, full_prompt, cfg.tokenizer.max_length, "left"
)
inputs["prompt_input_ids"] = (
inputs.pop("input_ids").unsqueeze(0).to(cfg.environment._device)
@@ -92,14 +92,11 @@ def __post_init__(self):

@dataclass
class ConfigNLPCausalClassificationTokenizer(ConfigNLPCausalLMTokenizer):
max_length_prompt: int = 512
max_length: int = 512

def __post_init__(self):
super().__post_init__()

self._visibility["max_length_answer"] = -1


@dataclass
class ConfigNLPCausalClassificationArchitecture(ConfigNLPCausalLMArchitecture):
@@ -174,8 +174,8 @@ def __post_init__(self):
)
self._possible_values["differential_learning_rate_layers"] = (
possible_values.String(
values=("backbone", "embed"),
allow_custom=False,
values=("backbone", "embed", "head"),
allow_custom=True,
placeholder="Select optional layers...",
)
)
@@ -250,17 +250,13 @@ def __post_init__(self):

@dataclass
class ConfigNLPCausalLMTokenizer(DefaultConfig):
max_length_prompt: int = 256
max_length_answer: int = 256
max_length: int = 512
add_prompt_answer_tokens: bool = False
padding_quantile: float = 1.0
tokenizer_kwargs: str = '{"use_fast": true, "add_prefix_space": false}'

def __post_init__(self):
super().__post_init__()
self._possible_values["max_length_prompt"] = (32, 1024 * 16, 32)
self._possible_values["max_length_answer"] = (32, 1024 * 16, 32)
self._possible_values["max_length"] = (32, 1024 * 16, 32)
self._possible_values["padding_quantile"] = (0, 1, 0.01)
self._padding_side = "left"
@@ -353,7 +349,7 @@ def __post_init__(self):

self._possible_values["num_beams"] = (1, 4, 1)
self._possible_values["temperature"] = (0, 10, 0.05)
self._possible_values["repetition_penalty"] = (1, 10, 0.05)
self._possible_values["repetition_penalty"] = (1, 10, 0.025)
self._possible_values["top_k"] = (0, 100, 1)
self._possible_values["top_p"] = (0.5, 1, 0.05)
self._possible_values["num_history"] = (1, 50, 1)
43 changes: 17 additions & 26 deletions llm_studio/src/datasets/text_causal_language_modeling_ds.py
@@ -42,6 +42,9 @@ def __getitem__(self, idx: int) -> Dict:
input_text_dict["prompts"] = [
self.parse_prompt(self.cfg, prompt) for prompt in input_text_dict["prompts"]
]
input_text_dict["answers"] = [
self.parse_answer(self.cfg, answer) for answer in input_text_dict["answers"]
]

sample = dict()
system_encoding, prompt_encodings, answer_encodings = self.get_encodings(
@@ -72,7 +75,7 @@ def __getitem__(self, idx: int) -> Dict:
self.pad_tokens(
answer_encodings[-1],
attention_mask=torch.ones_like(answer_encodings[-1]),
max_length=self.cfg.tokenizer.max_length_answer,
max_length=self.cfg.tokenizer.max_length,
pad_token_id=self.tokenizer.pad_token_id,
direction="right",
prefix="answer_",
@@ -99,14 +102,6 @@ def __getitem__(self, idx: int) -> Dict:
)
)

# make sure system encoding is always prepended if max_length exceeded
if sample["input_ids"][0] != self.tokenizer.pad_token_id:
sample["input_ids"][: len(system_encoding)] = system_encoding
if self.cfg.dataset.mask_prompt_labels and "labels" in sample.keys():
sample["labels"][: len(system_encoding)] = -100
if sample["prompt_input_ids"][0] != self.tokenizer.pad_token_id:
sample["prompt_input_ids"][: len(system_encoding)] = system_encoding

return sample

@staticmethod
@@ -122,6 +117,12 @@ def parse_prompt(cfg: Any, prompt: str):
)
return prompt

@staticmethod
def parse_answer(cfg: Any, answer: str):
if cfg.dataset.add_eos_token_to_answer:
answer += cfg._tokenizer_eos_token
return answer

@staticmethod
def parse_system(cfg: Any, system: str):
# no system tokens if empty
@@ -375,9 +376,6 @@ def get_labels(self, prompt_encodings, answer_encodings):
]
).to(torch.bool)
labels.masked_fill_(prompt_mask, -100)
if self.cfg.dataset.add_eos_token_to_answer:
# eos_token may be equal to pad_token. Add the label back manually.
labels[-1] = self.tokenizer.eos_token_id
if self.cfg.tokenizer.max_length < len(labels):
labels = labels[-self.cfg.tokenizer.max_length :]

@@ -446,27 +444,16 @@ def augment_data(self, encodings):
def _get_sample_encoding(self, system: str, prompt: str, answer: str) -> List:
if len(system) > 0:
system_encoding = self.encode(
self.tokenizer, system, self.cfg.tokenizer.max_length_prompt, "right"
self.tokenizer, system, self.cfg.tokenizer.max_length, "right"
)["input_ids"]
else:
system_encoding = torch.empty(0)
prompt_encoding = self.encode(
self.tokenizer, prompt, self.cfg.tokenizer.max_length_prompt, "left"
self.tokenizer, prompt, self.cfg.tokenizer.max_length, "left"
)["input_ids"]
max_length_answer = self.cfg.tokenizer.max_length_answer - int(
self.cfg.dataset.add_eos_token_to_answer
)
answer_encoding = self.encode(
self.tokenizer, answer, max_length_answer, "right"
self.tokenizer, answer, self.cfg.tokenizer.max_length, "right"
)["input_ids"]
if self.cfg.dataset.add_eos_token_to_answer:
answer_encoding = torch.cat(
[
answer_encoding,
torch.Tensor([self.tokenizer.eos_token_id]),
],
dim=0,
)

return [system_encoding, prompt_encoding, answer_encoding]

@@ -482,6 +469,10 @@ def pad_tokens(
sample = {}

if max_length < len(input_ids):
logger.info(
f"Input length of {len(input_ids)} exceeds max_length of "
f"{max_length}, truncating sample."
)
input_ids = input_ids[-max_length:]
attention_mask = attention_mask[-max_length:]
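
For orientation, a minimal sketch (not part of the diff) of the reworked behaviour: system, prompt, and answer are now all encoded against the single cfg.tokenizer.max_length, and pad_tokens keeps only the rightmost max_length tokens when an input is too long. The tokenizer choice and length below are assumptions for illustration:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice
    max_length = 8  # single limit replacing max_length_prompt / max_length_answer

    input_ids = tokenizer(
        "A fairly long prompt that will not fit into eight tokens",
        return_tensors="pt",
    )["input_ids"][0]
    attention_mask = torch.ones_like(input_ids)

    if max_length < len(input_ids):
        # Keep the rightmost tokens, dropping the oldest context (left truncation).
        input_ids = input_ids[-max_length:]
        attention_mask = attention_mask[-max_length:]

    print(input_ids.shape)  # torch.Size([8])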

@@ -5,7 +5,11 @@
from transformers import AutoModelForCausalLM

from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import create_nlp_backbone, prepare_lora
from llm_studio.src.utils.modeling_utils import (
create_nlp_backbone,
get_position_ids,
prepare_lora,
)

logger = logging.getLogger(__name__)

@@ -70,6 +74,7 @@ def forward(
output = self.backbone(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
position_ids=get_position_ids(batch["prompt_input_ids"]),
)

output.logits = self.classification_head(output[0][:, -1].float())
@@ -9,6 +9,7 @@
from llm_studio.src.utils.modeling_utils import (
create_nlp_backbone,
generate,
get_position_ids,
prepare_lora,
)

@@ -95,6 +96,7 @@ def forward(
output = self.backbone(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
position_ids=get_position_ids(batch["attention_mask"]),
)

if "labels" in batch:
8 changes: 8 additions & 0 deletions llm_studio/src/models/text_dpo_modeling_model.py
@@ -14,6 +14,7 @@
from llm_studio.src.utils.modeling_utils import (
create_nlp_backbone,
generate,
get_position_ids,
prepare_lora,
)

@@ -140,6 +141,7 @@ def forward(
logits = self.backbone(
input_ids=batch[f"{answer}_input_ids"],
attention_mask=batch[f"{answer}_attention_mask"],
position_ids=get_position_ids(batch[f"{answer}_attention_mask"]),
).logits

logits_dict[answer] = logits
@@ -157,12 +159,18 @@ def forward(
reference_logits = self.backbone(
input_ids=batch[f"{answer}_input_ids"],
attention_mask=batch[f"{answer}_attention_mask"],
position_ids=get_position_ids(
batch[f"{answer}_attention_mask"]
),
).logits
else:
with torch.no_grad():
reference_logits = self.backbone_orig(
input_ids=batch[f"{answer}_input_ids"],
attention_mask=batch[f"{answer}_attention_mask"],
position_ids=get_position_ids(
batch[f"{answer}_attention_mask"]
),
).logits
outputs[f"{answer}_reference_logps"] = get_batch_logps(
reference_logits,
10 changes: 8 additions & 2 deletions llm_studio/src/utils/modeling_utils.py
@@ -681,13 +681,13 @@ def update_backbone_config(config: Any, cfg: DefaultConfigProblemBase):
if config.eos_token_id != tokenizer.eos_token_id:
logger.warning(
"EOS token id not matching between config and tokenizer. "
"Overwriting with tokenizer id."
f"Overwriting with tokenizer id {tokenizer.eos_token_id}."
)
config.eos_token_id = tokenizer.eos_token_id
if config.pad_token_id != tokenizer.pad_token_id:
logger.warning(
"PAD token id not matching between config and tokenizer. "
"Overwriting with tokenizer id."
f"Overwriting with tokenizer id {tokenizer.pad_token_id}."
)
config.pad_token_id = tokenizer.pad_token_id
# no warning needed as not used
@@ -1137,3 +1137,9 @@ def get_torch_dtype(dtype):
return torch.bfloat16
else:
return torch.float32


def get_position_ids(attention_mask):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
return position_ids
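
A brief illustration (not part of the diff) of what the new get_position_ids helper returns for a left-padded batch; the attention mask below is made up. Padding positions receive a dummy id of 1 and real tokens are numbered from 0, which avoids position drift when sequences are left-padded:

    import torch

    def get_position_ids(attention_mask):
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        return position_ids

    # 0 = padding, 1 = real token (left padding on the first row).
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])
    print(get_position_ids(attention_mask))
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])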
2 changes: 1 addition & 1 deletion prompt.py
@@ -122,7 +122,7 @@ def parse_param(cfg, prompt):
print(prompt)

inputs = cfg.dataset.dataset_class.encode(
tokenizer, prompt, cfg.tokenizer.max_length_prompt, "left"
tokenizer, prompt, cfg.tokenizer.max_length, "left"
)
inputs["prompt_input_ids"] = inputs.pop("input_ids").unsqueeze(0).to(DEVICE)
inputs["prompt_attention_mask"] = (
@@ -62,8 +62,6 @@ problem_type: text_causal_classification_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 512
max_length_answer: 256
max_length_prompt: 256
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training:
@@ -61,8 +61,6 @@ problem_type: text_causal_classification_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 32
max_length_answer: 16
max_length_prompt: 16
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training:
@@ -70,8 +70,6 @@ problem_type: text_causal_language_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 512
max_length_answer: 256
max_length_prompt: 256
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training:
@@ -69,8 +69,6 @@ problem_type: text_causal_language_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 32
max_length_answer: 16
max_length_prompt: 16
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training:
@@ -62,8 +62,6 @@ problem_type: text_causal_classification_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 512
max_length_answer: 256
max_length_prompt: 256
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training:
@@ -61,8 +61,6 @@ problem_type: text_causal_classification_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 32
max_length_answer: 16
max_length_prompt: 16
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training:
@@ -69,8 +69,6 @@ problem_type: text_sequence_to_sequence_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 512
max_length_answer: 256
max_length_prompt: 256
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training:
@@ -69,8 +69,6 @@ problem_type: text_sequence_to_sequence_modeling
tokenizer:
add_prompt_answer_tokens: false
max_length: 32
max_length_answer: 16
max_length_prompt: 16
padding_quantile: 1.0
tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}'
training: