Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support custom field for completion from yml #580

Merged
merged 3 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
winglian marked this conversation as resolved.
Show resolved Hide resolved

# huggingface repo with multiple named configurations/subsets
datasets:
Expand Down
5 changes: 5 additions & 0 deletions src/axolotl/prompt_strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module to load prompt strategies."""

import importlib
import inspect

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig

Expand All @@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg):
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
return None
20 changes: 20 additions & 0 deletions src/axolotl/prompt_strategies/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Basic completion text
"""
from typing import Any, Dict, Optional

from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
from axolotl.prompters import CompletionPrompter


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "field" in ds_cfg:
strat.field = ds_cfg["field"]

return strat
25 changes: 24 additions & 1 deletion src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,31 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
Tokenizing strategy for Completion prompts.
"""

_field: str = "text"

@property
def field(self) -> str:
return self._field

@field.setter
def field(self, new_field: str):
self._field = new_field

def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt[self.field],
"",
"",
)

def tokenize_prompt(self, prompt):
full_prompt = self._build_full_prompt(prompt["text"], None, None)
(
instruction,
_,
_,
) = self.parse_instruction_fields(prompt)

full_prompt = self._build_full_prompt(instruction, None, None)
tokenized_full_prompt = self._tokenize(full_prompt)

return tokenized_full_prompt
Expand Down
11 changes: 0 additions & 11 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
CompletionPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
Expand All @@ -31,7 +30,6 @@
)
from axolotl.prompters import (
AlpacaPrompter,
CompletionPrompter,
GPTeacherPrompter,
JeopardyPrompter,
MultipleChoiceConcisePrompter,
Expand Down Expand Up @@ -327,15 +325,6 @@ def for_d_in_datasets(dataset_configs):
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "completion":
ds_strategy = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
suffix = ""
if ":load_" in d.type:
Expand Down