Skip to content

Commit

Permalink
Merge pull request #111 from OpenAccess-AI-Collective/sharegpt-token-…
Browse files Browse the repository at this point in the history
…tests

add unit test for sharegpt tokenization
  • Loading branch information
winglian authored May 30, 2023
2 parents e43bcc6 + e65aeed commit a924a33
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 41 deletions.
67 changes: 33 additions & 34 deletions ds_config.json
Original file line number Diff line number Diff line change
@@ -1,58 +1,57 @@
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 0,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "Adam",
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-8,
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"type": "OneCycle",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
"cycle_min_lr": 0.00001,
"cycle_max_lr": 0.00003,
"cycle_first_step_size": 120
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"contiguous_gradients": true,
"reduce_bucket_size": "auto",
"reduce_scatter": true,
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 5,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false,
"round_robin_gradients": true
"wall_clock_breakdown": false
}
10 changes: 5 additions & 5 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def load_model(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
Expand Down Expand Up @@ -174,7 +174,7 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
device_map="auto" if cfg.world_size == 1 else cfg.device_map,
**model_kwargs,
)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
Expand Down Expand Up @@ -273,13 +273,13 @@ def load_model(
if (
torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1
and cfg.gptq
and (cfg.gptq or cfg.load_in_4bit)
):
# llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
model.is_parallelizable = True
model.model_parallel = True
setattr(model, 'is_parallelizable', True)
setattr(model, 'model_parallel', True)

requires_grad = []
for name, param in model.named_parameters(recurse=True):
Expand Down
5 changes: 3 additions & 2 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
output_dir=cfg.output_dir,
save_total_limit=3,
load_best_model_at_end=True
if cfg.val_set_size > 0
if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False
and cfg.val_set_size > 0
and save_steps is not None
and save_steps % eval_steps == 0
and cfg.load_in_8bit is not True
Expand Down Expand Up @@ -218,7 +219,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):

trainer_cls = (
OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and cfg.fsdp
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
else transformers.Trainer
)
trainer = trainer_cls(
Expand Down
52 changes: 52 additions & 0 deletions tests/fixtures/conversation.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
{
"conversations": [
{
"from": "human",
"value": "how could i implement a minesweeper algorithm that utilises algebraic topology to solve boards?"
},
{
"from": "gpt",
"value": "It's possible to use algebraic topology to solve Minesweeper, but it would likely be quite complex. Algebraic topology is a branch of mathematics that studies the properties of spaces that are preserved under continuous deformations, such as stretching and bending. In the context of Minesweeper, this could be used to analyze the structure of the game board and determine the locations of mines based on certain topological invariants. However, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper. It's not something that can be explained in a simple answer."
},
{
"from": "human",
"value": "could you give me the gist of how it could though?"
},
{
"from": "gpt",
"value": "Sure. The basic idea would be to use algebraic topology to analyze the structure of the Minesweeper game board and determine the locations of mines based on certain topological invariants. This could involve representing the game board as a topological space and using topological invariants such as homotopy groups and Betti numbers to infer information about the location of mines on the board. However, as I mentioned before, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper, so it's not something that can be explained in a simple answer."
},
{
"from": "human",
"value": "give me a specific example"
},
{
"from": "gpt",
"value": "As a specific example, suppose we have a Minesweeper game board with the following configuration:\n```\n 1 1 1 1 1\n 1 ? ? ? 1\n 1 ? ? ? 1\n 1 ? ? ? 1\n 1 1 1 1 1\n```\nIn this case, the game board can be represented as a topological space, where each square on the board corresponds to a point in the space. The mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space. Using algebraic topology, we can analyze the structure of this space and determine the locations of the mines based on certain topological invariants.\n\nFor instance, one way to do this is to compute the homotopy groups of the space. Homotopy groups are a way of describing the holes in a space, and they can be used to infer information about the locations of the mines on the Minesweeper board. In the case of the board above, we can compute the first homotopy group of the space, which describes the number and \"shape\" of the holes in the space. In this case, the first homotopy group would tell us that there are four holes in the space, arranged in a square shape.\n\nBased on this information, we can infer that the mines in the game are located in the four corners of the board, as shown below:\n```\n 1 1 1 1 1\n 1 M ? ? 1\n 1 ? ? ? 1\n 1 ? ? M 1\n 1 1 1 1 1\n```\nOf course, this is just one example of how algebraic topology could be used to solve Minesweeper. There are many other possible approaches, and the specific details of how to implement them would depend on the specific game board and the topological invariants being used. As I mentioned before, implementing a solution like this would require a deep understanding of both algebraic topology and the game of Minesweeper."
},
{
"from": "human",
"value": "assuming i was someone with a deep understanding of both, could you explain the same concept to me with that assumption?"
},
{
"from": "gpt",
"value": "Sure. Assuming you have a deep understanding of both algebraic topology and the game of Minesweeper, here is how you might use algebraic topology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Use topological invariants to infer information about the locations of the mines on the board. There are many possible topological invariants that could be used for this, such as homotopy groups, Betti numbers, or persistent homology.\n3. Use the information inferred from the topological invariants to determine the locations of the mines on the board. For example, if the first homotopy group of the space indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using algebraic topology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the topological invariants being used. However, this should give you a general idea of how algebraic topology can be used to solve Minesweeper."
},
{
"from": "human",
"value": "how would i do it with persistent homology or any of the other topological invariants?"
},
{
"from": "gpt",
"value": "To use persistent homology or any of the other topological invariants to solve a Minesweeper game board, you would follow the same basic steps as described above, but with some differences in the specific details.\n\nFor instance, here is how you might use persistent homology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Compute the persistent homology of the space. Persistent homology is a way of describing the holes in a space over time, and it can be used to infer information about the locations of the mines on the Minesweeper board.\n3. Use the information from the persistent homology to determine the locations of the mines on the board. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using persistent homology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the specific details of the persistent homology computation. However, this should give you a general idea of how persistent homology can be used to solve Minesweeper."
},
{
"from": "human",
"value": "how could i implement this in python?"
},
{
"from": "gpt",
"value": "To implement a Minesweeper solver that uses algebraic topology in Python, you would need to do the following:\n\n1. Install a Python package for algebraic topology. There are several options available, such as Gudhi, Dionysus, or scikit-tda. These packages provide algorithms and data structures for computing topological invariants, such as homotopy groups, Betti numbers, or persistent homology.\n2. Represent the Minesweeper game board as a topological space. This could be done using a data structure provided by the algebraic topology package, such as a simplicial complex or a Vietoris-Rips complex. Each square on the board would correspond to a point in the space, and the mines on the board would be represented as \"holes\" in the space.\n3. Compute the topological invariants of the space. This could be done using algorithms provided by the algebraic topology package. For example, you could compute the persistent homology of the space using the persistent homology algorithm.\n4. Use the information from the topological invariants to determine the locations of the mines on the board. This could be done using a combination of mathematical reasoning and programming. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then you could use a piece of code to infer that the mines are located in the four corners of the board.\n5. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you could use a piece of code to automatically click on the safe squares and clear the board.\n\nOf course, this is just one possible approach to implementing a Minesweeper solver that uses algebraic topology in Python. There may be other ways to do it, depending on the specific details of the implementation. However, this should give you a general idea of how to get started with such a project."
}
]
}
1 change: 1 addition & 0 deletions tests/fixtures/conversation.tokenized.json

Large diffs are not rendered by default.

47 changes: 47 additions & 0 deletions tests/test_prompt_tokenizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import json
import logging
import unittest
from pathlib import Path

from transformers import AutoTokenizer

from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompter

logging.basicConfig(level="INFO")


class TestPromptTokenizationStrategies(unittest.TestCase):
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)

def test_sharegpt_integration(self):
print(Path(__file__).parent)
with open(Path(__file__).parent / "fixtures/conversation.json", "r") as fin:
data = fin.read()
conversation = json.loads(data)
with open(Path(__file__).parent / "fixtures/conversation.tokenized.json", "r") as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
prompter = ShareGPTPrompter("chat")
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
example = strat.tokenize_prompt(conversation)
for fields in ["input_ids", "attention_mask", "labels"]:
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])


if __name__ == "__main__":
unittest.main()

0 comments on commit a924a33

Please sign in to comment.