Da 24/reading comprehension #74
Conversation
Looking good! A few things I think would be useful to include in the README:
- Examples of downloading and processing a public or local dataset with both approaches.
- An example of calling trainer.py.
- A reference to the upstream codebase where the regex code was originally sourced and modified.
model_prefix = f"{temp_dir}/domain"

# Train the SentencePiece model, the model is saved in the temporary directory
spm.SentencePieceTrainer.train(input=text, model_prefix=model_prefix, vocab_size=32000, character_coverage=1.0)
Is this missing an import for spm? When I opened it in VS Code, it flagged this line.
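Presumably the fix is the conventional sentencepiece import (an assumption on my part; the repo may already alias it differently elsewhere):

import sentencepiece as spm  # presumed missing import; spm is the conventional alias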
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
It would make this code a bit easier to run if a default model were provided.
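For example, something along these lines (the default shown is purely illustrative, not a project decision):

import argparse

parser = argparse.ArgumentParser()
# sketch: give --model_name an illustrative default so the script runs out of the box
parser.add_argument(
    "--model_name",
    type=str,
    default="HuggingFaceH4/zephyr-7b-beta",  # hypothetical default, for illustration only
    help="the model name",
)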
We should also document the types of models that are supported. I tried running:
python -m dalm.datasets.reading_comprehension_generation.synthetic_based --model_name "meta-llama/Llama-2-7b-hf" --input_directory "datasets" --output_directory generated_datasets --state_file "processing_state.txt"
and got error:
AttributeError: 'LlamaTokenizerFast' object has no attribute 'apply_chat_template'
Actually, after digging in a bit more, it seems like a problem with the version of the transformers library I was using rather than the model.
Running `pip install transformers --upgrade` got past the `no attribute 'apply_chat_template'` error. In my case, it went from transformers version 4.33.1 -> 4.35.2.
As part of this PR, should we add a constraint on the dependencies to require a newer version of the transformers library? (> 4.35?)
""" | ||
|
||
for index, (gen_text, context) in enumerate( | ||
generate_synthetic_dataset( |
With this many args, named arguments would help readability and prevent future errors around arg order:

generate_synthetic_dataset(
    model_name=args.model_name,
    .. etc
Also, I think an approach like this would reduce the cognitive load of this code block:

# Create a synthetic dataset generator
dataset_gen = generate_synthetic_dataset(
    args.model_name, args.input_directory, args.state_file, args.chunk, args.context_length
)
# Loop over synthetic dataset and extract question and answers
for index, (gen_text, context) in enumerate(dataset_gen):
    q_and_a = question_and_answer_extractor(gen_text, context)
    .. etc
parser.add_argument("--output_directory", type=str, required=True) | ||
parser.add_argument("--state_file", type=str, required=True) | ||
parser.add_argument("--context_length", type=int, default=2048) | ||
parser.add_argument("--chunk", action="store_true") |
Could a default=False be added to make this code easier to run? E.g., fewer arguments for users to have to think about.
Actually, on second glance, maybe chunk should default to true. When I ran it on this dataset I got this warning:
This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
and the model generated complete gibberish:
outputs: [{'generated_text': 'ЋЪЋ.ЪЪЉЁЉЋЋЏЪЉЏЉЉЉЋЋЪЪЋЋЪЉЉЉЋЉЉЏЉЋЪЋЏЋЋЉЉЪЪЏЉЪЪ...'}]
After adding --chunk it managed to generate a reading comprehension dataset! It definitely seems like we should default to chunk=True.
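A minimal sketch of one way to flip the default while keeping an opt-out (the --no-chunk flag name is illustrative, not the PR's):

import argparse

parser = argparse.ArgumentParser()
# sketch: chunking on by default, with an explicit opt-out flag
parser.add_argument("--no-chunk", dest="chunk", action="store_false", help="disable chunking of long inputs")
parser.set_defaults(chunk=True)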
parser.add_argument("--model_name", type=str, required=True) | ||
parser.add_argument("--input_directory", type=str, required=True) | ||
parser.add_argument("--output_directory", type=str, required=True) | ||
parser.add_argument("--state_file", type=str, required=True) |
Any chance state_file could be made optional? Meaning if the user didn't pass anything, the code would just create a state file as a temp file somewhere?
State file = something to track processing state?
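Something like this sketch, perhaps (the temp-file fallback is hypothetical, not existing behavior):

import argparse
import os
import tempfile

parser = argparse.ArgumentParser()
# sketch: make --state_file optional and fall back to a temporary file when it isn't given
parser.add_argument("--state_file", type=str, default=None)
args = parser.parse_args()
if args.state_file is None:
    args.state_file = os.path.join(tempfile.mkdtemp(), "processing_state.txt")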
Also, is a separate processing state file even needed? If the file has been processed, won't it be present in the output directory? (and likewise, if not processed, not present in the output directory)
I was about to remove the state file.
I think because the function has no control over the naming of the output file, keeping a record of the files already processed is the most reliable way of keeping track.
I also made a minor change giving the user the ability to switch off this behaviour.
🤔 that said, it would be cleaner to give the caller of the function the reins to state keeping of this nature
> I think because the function has no control over the naming of the output file, keeping a record of the files already processed is the most reliable way of keeping track.

Yeah, for this to work:

> If the file has been processed, won't it be present in the output directory? (and likewise, if not processed, not present in the output directory)

The output file would need to have the same name as the input file, or use a "content addressable" scheme of some sort (hash of content). Currently it uses an index in the filename.
If there's a 1:1 mapping between input and output files, then using the same filename in both input and output directories should make it easy to track the processing state without needing an extra file.
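A sketch of that idea (the helper name and the same-filename assumption are mine, not the PR's):

import os

# sketch: treat "an output file with the same name exists" as the processing-state check
def already_processed(input_file: str, output_dir: str) -> bool:
    return os.path.exists(os.path.join(output_dir, os.path.basename(input_file)))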
> give the caller of the function the reins to state keeping of this nature

I think only a small percentage of users would want any control here. So it would be better to have the default behavior "just work" while allowing power users to override it if needed.
I think the code is too eagerly adding files to the state file, because I was able to end up in a state where:
- My output directory is empty
- My only input file (pubmed25.csv) has been added to the state file
root@04e369b03ae7:/# ls -alh generated_datasets/
total 4.0K
drwxr-xr-x 2 root root 10 Nov 20 13:14 .
drwxr-xr-x 1 root root 4.0K Nov 20 13:26 ..
root@04e369b03ae7:/# strings processing_state.txt
processed_files
pubmed25.csv
Eliminating the separate state file (as suggested above) would eliminate this drift. Or alternatively, a file should only be added to the state file after it successfully generated the corresponding output file.
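A sketch of that second alternative (the output format and the plain-text state file here are illustrative; the PR uses its own formats):

import json
import os

# sketch: write the output first, then record the input as processed only on success
def save_output_and_mark_processed(output_dir: str, state_file: str, input_name: str, records: list) -> None:
    output_path = os.path.join(output_dir, f"{input_name}.json")
    with open(output_path, "w", encoding="utf-8") as fh:
        json.dump(records, fh)
    # only reached if the dump above succeeded
    with open(state_file, "a", encoding="utf-8") as fh:
        fh.write(input_name + "\n")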
At the time you reported this, a csv file was not considered a valid input. It's very likely the processing was correct and the one file didn't make it through because the Q&A parser didn't parse anything out of the LLM's output.
@metric-space Ok I will try this again and see if I observe the same issue
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments

from trl import SFTTrainer
I think trl needs to get added to the pyproject.toml dependencies.
    return first_prompt + chat_chain


# TODO: type hinting is very necessary here
Yeah type hinting would be very helpful
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-alpha", help="the model name")
Should this be zephyr-7b-beta?
"--dataset_name", type=str, default="arcee-ai/azure-reading-comprehension-dataset", help="the dataset name" | ||
) | ||
parser.add_argument("--split", type=str, default="train", help="the split to use") | ||
parser.add_argument("--size_valid_set", type=int, default=4000, help="the size of the validation set") |
It seems easier for users to express this in terms of percentage of training set rather than absolute size. Eg, 20%
I believe this argument is for streaming; in that scenario a preset size is required.
Ah got it, for streaming we don't know the full size of the dataset until we load it. I think it's confusing and misleading, though, to advertise this parameter without mentioning the fact that it's streaming only.
WDYT of changing to this?

parser.add_argument(
    "--size_valid_set_streaming",
    type=int,
    default=4000,
    help="the size of the validation set when used in streaming mode, ignored otherwise",
)
    return train_dataset, valid_dataset


def chars_token_ratio(dataset, tokenizer, formatting_func, nb_examples=400):
What does nb stand for? If "number", then I think num_examples is a bit clearer.
)
parser.add_argument("--split", type=str, default="train", help="the split to use")
parser.add_argument("--size_valid_set", type=int, default=4000, help="the size of the validation set")
parser.add_argument("--streaming", type=bool, default=False, help="whether to stream the dataset")
Any reason not to just default to True to keep the memory footprint low by default? What are the downsides of streaming the dataset?
I haven't tested the advantages of streaming over just vanilla loading, so it's hard for me to set that as a default for the user.
    train_data = dataset.skip(size_valid_set)
    train_data = train_data.shuffle(buffer_size=shuffle_buffer, seed=None)
else:
    dataset = dataset.train_test_split(test_size=0.05, seed=None)
Looks like it ignores size_valid_set in this code path - is that a bug?
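If it is, a sketch of one way to honour it in the non-streaming path (assuming dataset and size_valid_set from the surrounding code):

# sketch: derive the test split fraction from size_valid_set instead of the hard-coded 0.05
test_fraction = min(size_valid_set / len(dataset), 0.5)
dataset = dataset.train_test_split(test_size=test_fraction, seed=None)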
parser.add_argument("--weight_decay", type=float, default=0.05, help="the weight decay") | ||
parser.add_argument("--optimizer_type", type=str, default="paged_adamw_32bit", help="the optimizer type") | ||
|
||
parser.add_argument("--output_dir", type=str, default="./generator_finetuned_model", help="the output directory") |
I think we should add this directory to .gitignore so that it doesn't show up in the list of changed files.
It wouldn't show up in the list of changed files. It would show up in the list of untracked files, no?
Yeah, but that's just as bad. Users who want to contribute would be confused by the untracked files being created when they run git status. Also, it would be a nuisance for core committers to have to ignore them when running git add.
    streaming=streaming,
)
if streaming:
    print("Loading the dataset in streaming mode")
We should use loggers instead of raw print. When running in cloud training environments, only the logger output is captured and the print statements get lost.
Here's an example of creating a logger and then using it, along the lines of what the DALM repo does elsewhere.
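Roughly this pattern (a minimal sketch; the exact helper in the DALM repo may differ):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# and using it in place of the raw print:
logger.info("Loading the dataset in streaming mode")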
"--output_dir", type=str, help="directory of the output reading comprehension texts", default="./output" | ||
) | ||
parser.add_argument( | ||
"--ori_spm_path", type=str, help="path of the original sentencepiece model", default="./tokenizers/general.spm" |
Could the code auto-download that somehow? If not, it seems like something that is good to add to the README example. Ditto for --domain_spm_path and --domain_tokenizer_training_text.
Good one. But this domain tokenizer needs to get trained with user data.
Is the ori in ori_spm_path referring to "original"? If so, we should definitely change to orig instead of ori, which makes the meaning a lot clearer at the cost of one char.
2. regex-based-gen now creates a domain tokenizer if neither a domain sentencepiece model nor domain text is explicitly given 3. attempt at pipeline
"--output_dir", type=str, help="directory of the output reading comprehension texts", default="./output" | ||
) | ||
parser.add_argument( | ||
"--ori_spm_path", type=str, help="path of the original sentencepiece model", default="./tokenizers/general.spm" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good one. But this domain tokenizer needs to get trained with user data.
    model_output_dir: Optional[str] = "model_output_dir",
    log_freq: Optional[int] = 100,
    neftune_noise_alpha: Optional[int] = 5,
    log_with: Optional[str] = "wandb",
    generation_state_file: Optional[str] = "generation_state.pkl",
):
    domain_spm = spm.SentencePieceProcessor(model_file=domain_spm_path)
    ori_spm = spm.SentencePieceProcessor(model_file=general_spm_path)

    # generate regex based reading comprehension dataset
    if comprehension_type in [SynthMode.REGEX, SynthMode.BOTH]:
Can we use both if possible?
    dataset.save_to_disk("reading_comprehension_dataset")  # TODO: change name from

    del dataset, a1, a2  # TODO: change name
    # del dataset  # TODO: change name

    train_generator(
I guess we are using DeepSpeed during the prod training?
Using del is suspect and makes me think the code isn't optimally structured (functions too big).
Is there a way to refactor the code so that these automatically go out of scope after they are no longer useful?
Oh sorry, this is just removing some del statements and adding a comment.
To have it automatically deleted, you could factor out a separate function:

def write_dataset(list_of_data, dataset_name):
    dataset = datasets.Dataset.from_list(list_of_data)
    dataset.save_to_disk(dataset_name)

and in the caller:

write_dataset(list_of_data, "reading_comprehension_dataset")
train_generator(
    model_name=model_name,
    dataset_name="reading_comprehension_dataset",
    .. etc
)

As soon as write_dataset() returns, the dataset local variable will go out of scope and be GC'd.
Also, I just noticed something suspect in this code block:

if comprehension_type == SynthMode.BOTH:
    dataset = datasets.Dataset.from_list(list_of_data)

dataset.save_to_disk("reading_comprehension_dataset")  # TODO: change name from

I don't see dataset ever being created if comprehension_type != SynthMode.BOTH, and it looks like that code will throw an exception unless it is set to SynthMode.BOTH.
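If so, a sketch of one possible fix (assuming datasets and list_of_data from the surrounding code):

# sketch: always build the dataset from list_of_data so the save below never hits an unbound name
dataset = datasets.Dataset.from_list(list_of_data)
dataset.save_to_disk("reading_comprehension_dataset")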
If we use packing=True in the SFT trainer, do we still need to use the ConstantLengthDataset object?
https://huggingface.co/docs/trl/sft_trainer#packing-dataset--constantlengthdataset-
It defaults to
but the user can override the value. If it makes the code too complex (eg, dynamically using a
def gen_prompt(text):
prompt = f"There are 4 types of reading comprehension tasks. The point of reading comprehension tasks is to be assigned a text and questions to prompt answers so as to test conceptual and procedural knowledge present in the text. The four types of reading comprehension tasks are : 1. complete-the-sentence Q&A TASK 2.true/false Q&A TASK (description: a sentence is posed and the user is asked to state the correctness of the statement)3. frame a sentence with domain specific keywords(these keywords are required to be present in the text) Q&A TASK 4. Normal questions and answer Task (description: longform Q&A to test procedural and conceptual knowledge). An example of all four tasks given an example text is as follows: \n EXAMPLE TEXT: The insights into the mechanisms of memory consolidation during the sleep processes in human and animal brain led to other biologically inspired approaches. While declarative memories are in the classical picture consolidated by hippocampo-neocortical dialog during NREM phase of sleep, some types of procedural memories were suggested not to rely on the hippocampus and involve REM phase of the sleep. This inspired models where internal representations (memories) created by previous learning are spontaneously replayed during sleep-like periods in the network itself (i.e. without help of secondary network performed by generative replay approaches mentioned above).\nQuestion: [type: true/false] Is the following sentence true? all types of procedural memories rely on the hippocampus\nAnswer: False. The text clearly states there are some types of procedural memories not reliant on the hippocampus\n--------\nQuestion [type: complete-the-sentence] Complete the following sentence: The insights into ____ in human and animal brain led to other _____ approaches\nAnswer: The insights into the mechanisms of memory consolidation during the sleep processes in human and animal brain led to other biologically inspired approaches\n------\nQuestion [type 3 domain-keywords] Make a sentence with the following keywords 'hippocampo-neocortical', 'declarative' and 'NREM'\nAnswer: declarative memories are in the classical picture consolidated by hippocampo-neocortical dialog during NREM phase of sleep\n-------\nQuestion [type: normal q&a] Some types of procedural memories were suggested not to rely on the hippocampus and involve REM phase of the sleep. What did this go on to inspire?\nAnswer This inspired models where internal representations (memories) created by previous learning are spontaneously replayed during sleep-like periods in the network itself [END OF EXAMPLE]\n\n Similar to the above, could you craft 4 different reading comprehension tasks (make sure your output is a list of question answer pairs and each question is labelled QUESTION and answer is labelled ANSWER and there is one question and answer per task) based solely and completely focused on the following TEXT: {text}" |
This string should be word wrapped so the line isn't so long
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--input_directory", type=str, required=True)
I think --input-directory would be slightly easier to read, and would match the convention used by https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher
While the suggestion makes sense, every other place in our codebase has it as snake case, so perhaps doing this when we schedule an upgrade might make more sense.
Ok good call. Let's try to do it for the entire codebase at some point to be consistent.
parser.add_argument("--context_length", type=int, default=2048) | ||
parser.add_argument("--chunk", action="store_true") | ||
|
||
args = parser.parse_args() |
To fail fast, it would be good to check that all required args were passed:

if not args.model_name:
    parser.error("--model_name is a required argument")
# ... etc

Otherwise what might happen is that it takes several minutes to run, then throws an exception because the output_directory argument was missing. That would be frustrating for the user.
Nevermind, argparse should take care of this for us as long as we mark all args as required=True (see other comment regarding model_name).
parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-alpha", help="the model name") | ||
parser.add_argument("--log_with", type=str, default="wandb", help="use 'wandb' to log with wandb") | ||
parser.add_argument( | ||
"--dataset_name", type=str, default="arcee-ai/azure-reading-comprehension-dataset", help="the dataset name" |
Should the help message be more descriptive and say something like: "The dataset Hugging Face repo name or the local directory where the dataset is stored. Must be in reading comprehension format."?
I tried running the pipeline with the following command:
but hit this error:
Full error stacktrace
1. Via regex based methods that combs the input data for match and aligns them into questions and answers
2. Via prompting a large language model to come up with questions and answers
As a user, how should I decide which approach to use? Here's a stab:
- Use regex based reading comprehension dataset generation when it works on that dataset
- Otherwise, fall back to the slower synthetic data generation approach
    output_dataset_name: str,
    input: str,
    model_output_dir: str,
    generation_state_file: str,
Can you add a default value here? This has a default value for the CLI, but not for the API.
"bitsandbytes", | ||
"typer>=0.9.0,<1.0", | ||
"pydantic==1.10.9", # Sync w/ other platform components | ||
"pysbd", | ||
"sentencepiece" | ||
] |
There is also a dependency on wandb due to this line:

log_with: str = "wandb",

Can you add that to the list of dependencies?
I think the best way forward is to remove this default; given the range of tracker options for accelerate, it makes sense to not assume anything about the user.
What do you think?
Yeah agreed
if os.path.isdir(directory_or_file):
    for file in os.listdir(directory_or_file):
        file_path = os.path.join(directory_or_file, file)
        if os.path.isfile(file_path):  # Ensures that we are reading files
            try:
                with open(file_path, "r", encoding="utf-8") as file_contents:
                    contents = file_contents.read()
            except UnicodeDecodeError:
                with open(file_path, "r", encoding="utf-8", errors="replace") as file_contents:
                    contents = file_contents.read()

            yield file, contents
There's a very nasty bug here: if you pass a csv file, the CSV-aware code below kicks in (directory_or_file.endswith(".csv") and csv_column:) and will work as expected and emit rows of the CSV.
However, if you instead pass a directory that contains CSV files, it will treat each CSV file as raw text and return the entire text (or an entire chunk) rather than rows from the CSV file.
The fix is to treat CSV files the same whether a single CSV is passed in or a directory containing CSV files is passed in.
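A sketch of that fix (the function name and return shape are illustrative, not the PR's API):

import csv
import os
from typing import Iterator, Optional, Tuple

# sketch: route .csv files through the same CSV-aware path whether a single file or a directory is passed
def iter_texts(directory_or_file: str, csv_column: Optional[str] = None) -> Iterator[Tuple[str, str]]:
    if os.path.isdir(directory_or_file):
        paths = [os.path.join(directory_or_file, f) for f in os.listdir(directory_or_file)]
    else:
        paths = [directory_or_file]
    for file_path in paths:
        if not os.path.isfile(file_path):
            continue
        if file_path.endswith(".csv") and csv_column:
            with open(file_path, "r", encoding="utf-8", errors="replace", newline="") as fh:
                for row in csv.DictReader(fh):
                    yield os.path.basename(file_path), row[csv_column]
        else:
            with open(file_path, "r", encoding="utf-8", errors="replace") as fh:
                yield os.path.basename(file_path), fh.read()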
if not os.path.exists(output_dataset_name):
    dataset.save_to_disk(output_dataset_name)
This is a bit surprising behavior. If the caller passes in an empty directory that exists on disk, instead of saving the generated dataset to that directory, it will discard the dataset (not save it) and then later fail with an error:

FileNotFoundError: Directory output_dataset is neither a `Dataset` directory nor a `DatasetDict` directory.

A few ideas on how to fix:
- Just overwrite the existing dataset with the new dataset (remove the `if not os.path.exists()` check).
- Instead of asking the user to pass the name of the dataset, generate a unique dataset directory and use that, then inform the user where the dataset was generated (see the sketch below).
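A sketch of the second idea (assuming dataset and logger from the surrounding code):

import tempfile

# sketch: create a unique output directory and report it to the user
output_dataset_name = tempfile.mkdtemp(prefix="reading_comprehension_dataset_")
dataset.save_to_disk(output_dataset_name)
logger.info(f"Generated dataset saved to {output_dataset_name}")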
logger.info(f"Total files missed: {generation_state['files_missed']} out of {generation_state['total_files']}") | ||
logger.info(f"Total files processed: {generation_state['total_files']}") |
These stats are really helpful! The term "files" is a bit misleading, but it's clear from the output that it can also represent rows.
If the dataset is empty, however, we should throw an exception:

if not dataset:
    raise Exception("Failed to generate dataset")

This can happen when you have a small dataset and each chunk fails to generate a question/answer pair.
I believe it will fail on its own:
>>> datasets.Dataset.from_list([])
Dataset({
features: [],
num_rows: 0
})
>>> a = datasets.Dataset.from_list([])
>>> a.save_to_disk('./my_dataset')
Saving the dataset (0/1 shards): 0 examples [00:00, ? examples/s]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/xxx/Projects/trial/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 1530, in save_to_disk
for job_id, done, content in Dataset._save_to_disk_single(**kwargs):
File "/home/xxx/Projects/trial/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 1571, in _save_to_disk_single
num_examples, num_bytes = writer.finalize()
^^^^^^^^^^^^^^^^^
File "/home/xxx/Projects/trial/lib/python3.11/site-packages/datasets/arrow_writer.py", line 599, in finalize
raise SchemaInferenceError("Please pass `features` or at least one example when writing data")
datasets.arrow_writer.SchemaInferenceError: Please pass `features` or at least one example when writing data
but the message won't be clear so your recommendation stands
Yeah I hit that error testing with a small dataset and it was a pretty cryptic failure message. The idea is just to make it a lot clearer when that happens.
    TrainingArguments,
)
from trl import SFTTrainer  # type: ignore[import]
Can you add import wandb here? This will fail fast when wandb is not installed, rather than failing in the middle of a pipeline with this stacktrace:
File "/opt/conda/lib/python3.10/site-packages/dalm/pipelines/reading_comprehension_pipeline.py", line 180, in pipeline
train_generator(
File "/opt/conda/lib/python3.10/site-packages/dalm/training/generator_only/trainer.py", line 240, in train_generator
trainer = SFTTrainer(
File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 252, in __init__
super().__init__(
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 525, in __init__
self.callback_handler = CallbackHandler(
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 306, in __init__
self.add_callback(cb)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer_callback.py", line 323, in add_callback
cb = callback() if isinstance(callback, type) else callback
File "/opt/conda/lib/python3.10/site-packages/transformers/integrations/integration_utils.py", line 669, in __init__
raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
RuntimeError: WandbCallback requires wandb to be installed. Run `pip install wandb`.
I think, given the number of options that exist for trackers, separating ourselves from any particular tracker may be the best way forward.
Yeah, can we just default to no tracking? Then users can add it as needed, including libs
- remove default wandb value
- new logging statement
- change of terminology for logging (files -> texts)
Looking great overall! In my testing I'm close to having it working end-to-end via calling the Python API functions.
The main change I think we still need is to disable wandb by default (as discussed).
A few other changes I'd suggest (but these aren't necessarily blockers):
- Default to SynthMode.LLM - but that's up for debate. What do folks think? @shamanez @Jacobsolawetz
- It does seem odd to me to not mention any of this new code in the main PR of the repo. We can always add this later in a follow-up PR, but it seems like a good time to add it. Thoughts?
@@ -0,0 +1,1266 @@
# Modified version of code from https://github.com/microsoft/LMOps/blob/main/adaptllm/utils/read.py
👍
(Same, as above i.e assuming you have your dataset as a csv file with the column `text` containing the raw texts)

Please note there is the choice of passing in a domain sentence model in addition, but this is not required as
the script will train a domain speicifc sentencepiece model on the input corpus
Typo: should be "specific"
```

the output directory serves as a temporary holding place of all generated data before it can be made a dataset.
The generation process usually takes time. SO every step is taken to ensure if the process is interrupted, once back running
> The generation process usually takes time

I would change this to a concrete estimate like: "when using synthetic data generation, it takes approximately 10 minutes for each 100 rows of data you have". Even if it's not 100% accurate, it's a lot more helpful than "takes some time".

> SO every step is taken to ensure if the process is interrupted, once back running

A more succinct way to state this: "The script was designed to be idempotent"
> A more succinct way to state this: "The script was designed to be idempotent"

I think idempotency is a property of the current script/function, and state-tracking to ensure we can resume where we left off is connected but a distinct concept. I am unconvinced that the recommended statement works here.
The way espoused by the paper is generating reading comprehension questions and answers based on the raw corpora
and training a llm on said generated dataset can enhance its domain adaptiveness

We have two ways of generating  reading comprehension data
There is an extra space between generating and reading
if __name__ == "__main__":
    logger.setLevel(logging.INFO)
This should already be set (see the earlier comment).
"If local, be sure to set the local_dataset flag" | ||
), | ||
) | ||
parser.add_argument("--local_dataset", action="store_true", help="whether to use a local dataset") |
Just curious: is this the normal way of doing things in HF?
I thought the normal Hugging Face approach is to "just figure it out". E.g., try to load a local dataset, then fall back to loading from the HF hub (or vice versa).
Somehow it isn't the case; I don't understand why. HF datasets errors out and asks you to switch methods to load a local dataset. This param is just surfacing that.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
Should not be needed (see earlier comments).
dalm/pipelines/README.md
    --input input.csv --csv_column text \
    --output_dataset_name combined \
    --general_spm_path tokenizers/general.spm \
    --llm_synth_model_name meta-llama/Llama-2-13b \
Shouldn't this be meta-llama/Llama-2-13b-chat-hf? In order to:
- Use the RLHF'd chat model instead of the raw pretrained model
- Use the HF variation to avoid errors like this
@metric-space I just realized the problem. In the HF TrainingArgs docs:
And on the training container it currently has
I think what's happening is that
And then we pass an empty string to
@tleyden @Jacobsolawetz the main
LGTM! 🚀
I retested the latest round of changes and everything is working.
I will drive a follow-up PR to address #78
Yeah, agreed. I filed a tracking ticket for that: #79
What this is
Reading comprehension is based on the idea that we can enhance the domain adaptiveness of the generator via synthetically augmenting input data.
I want to emphasize that this targets the generator and not the retriever.
Integration of Reading Comprehension
Notes:
Why not stream the dataset by default?
I haven't verified how streaming behaves when it works as intended, so it's hard for me to recommend something whose benefits I don't know.
Commands to get things running
This assumes you have a csv file with raw texts in a certain column (for this example, let's say `text`)