Finetune fastchat with Zephyr format #2918

Open
nghidinhit opened this issue Jan 12, 2024 · 2 comments

How can I perform fine-tuning on FastChat using the Zephyr format? I've noticed that within the preprocess function, there is hardcoded logic intended for fine-tuning with the Vicuna template.

(screenshot of the preprocess function showing the hardcoded Vicuna template logic)
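
For context on why the masking is template-specific, rendering both templates side by side makes the differing role tags and separators obvious. A minimal sketch using FastChat's `get_conversation_template` (assuming your FastChat checkout registers a zephyr template):

```python
from fastchat.model.model_adapter import get_conversation_template

# Render the same two-turn exchange with both templates to compare the
# role tags and separators that the masking code has to split on.
for name in ("vicuna", "zephyr"):
    conv = get_conversation_template(name)
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi there!")
    print(f"--- {name} ---")
    print(conv.get_prompt())
```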

jwong8314 commented Jan 21, 2024

You can just change "vicuna" to "zephyr". The bigger issue is the rest of preprocess, where you need to rewrite the hardcoded mask generation.
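
Concretely, that first change is a one-line edit at the top of `preprocess` (a sketch against fastchat/train/train.py; the mask generation below it still needs the rework described next):

```python
conv = get_conversation_template("zephyr")  # was: get_conversation_template("vicuna")
```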

I wrote something like the following to support dolphin's format:

if dolphin_format:
    turns = conversation.split(sep)

    cur_len = 1  # account for the leading <s> (BOS) token
    target[:cur_len] = IGNORE_TOKEN_ID
    for i, turn in enumerate(turns):
        if turn == "":
            break
        turn_len = len(tokenizer(turn).input_ids)

        if turn.split("\n")[0] != conv.roles[1]:
            # Not an assistant turn: mask the whole turn plus the two separator tokens.
            target[cur_len: cur_len + turn_len + 2] = IGNORE_TOKEN_ID
        else:
            # Assistant turn: mask only the role tag at the start and the two
            # separator tokens at the end, keeping the reply as the training target.
            size_of_role = len(tokenizer(conv.roles[1]).input_ids)
            target[cur_len: cur_len + size_of_role] = IGNORE_TOKEN_ID
            target[cur_len + turn_len: cur_len + turn_len + 2] = IGNORE_TOKEN_ID

        cur_len += turn_len + 2
    target[cur_len:] = IGNORE_TOKEN_ID

If I have some time later, I'll push a PR that should make finetuning agnostic to formatting. In the meantime, feel free to just switch to dolphin's format.
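
For the format-agnostic direction, one option is to stop splitting the rendered text on template-specific separators altogether and mask by token length instead. A minimal sketch (not FastChat code; the function name and arguments are illustrative):

```python
IGNORE_TOKEN_ID = -100  # label value ignored by the HF cross-entropy loss


def mask_non_assistant(tokenizer, prefix: str, reply: str, max_length: int):
    """Return (input_ids, labels) where only `reply` tokens contribute to the loss.

    `prefix` is the rendered conversation up to and including the assistant
    role tag; `reply` is the assistant's answer plus any trailing separator.
    Apply this once per assistant turn and concatenate the results.
    """
    prefix_ids = tokenizer(prefix, add_special_tokens=False).input_ids
    full_ids = tokenizer(prefix + reply, add_special_tokens=False).input_ids
    # Caveat: some tokenizers merge tokens across the prefix/reply boundary,
    # so decode the unmasked labels once to verify the split is clean.
    labels = [IGNORE_TOKEN_ID] * len(prefix_ids) + full_ids[len(prefix_ids):]
    return full_ids[:max_length], labels[:max_length]
```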

nghidinhit commented Jan 22, 2024

This is my custom code:
```python
def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    conv = get_conversation_template("zephyr")

    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        conv.system_message = source[0]["value"].strip()
        source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids

    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.roles[1] + "\n"
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split("<|user|>\n")
        cur_len = 1  # for the <s> special token
        target[:cur_len] = IGNORE_TOKEN_ID

        for i, turn in enumerate(turns):
            if turn == "":
                break

            if i == 0:  # system message
                parts = [turn, ""]
            else:
                turn = f"<|user|>\n{turn}"
                parts = turn.split(sep)  # user text and assistant text
                if len(parts) != 2:
                    break
                parts[0] += sep
            turn_len = len(tokenizer(turn).input_ids) - 1  # drop the <s> token
            instruction_len = len(tokenizer(parts[0]).input_ids) - 1
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID

            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            rank0_print(tokenizer.decode(z))
            exit()

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" #turn = {len(turns) - 1}. (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
```
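
To sanity-check the masking, you can run something like the following in the same script where `preprocess` and `IGNORE_TOKEN_ID` are defined. The checkpoint, max length, and padding settings below are just examples (they mirror the right-padding and unk-token padding that FastChat's train script sets up):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # example checkpoint
tokenizer.model_max_length = 2048          # example training length
tokenizer.padding_side = "right"           # mirrors FastChat's train script
tokenizer.pad_token = tokenizer.unk_token  # mirrors FastChat's train script

sample = [[
    {"from": "system", "value": "You are a helpful assistant."},
    {"from": "human", "value": "Hello!"},
    {"from": "gpt", "value": "Hi there!"},
]]
batch = preprocess(sample, tokenizer)

labels = batch["labels"][0]
kept = labels[labels != IGNORE_TOKEN_ID]
# If the masking is correct, this prints only the assistant reply
# (plus its trailing separator); everything else is masked out.
print(tokenizer.decode(kept))
```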
