
feat: Add multi and single turn chat support #415

Merged

Conversation

dushyantbehl
Contributor

@dushyantbehl dushyantbehl commented Dec 12, 2024

Description of the change

Multi-turn chat support in SFTTrainer is enabled by passing the response and instruction templates to DataCollatorForCompletionOnlyLM.

This PR adds two new arguments to the data args:

    chat_template: str = field(
        default=None,
        metadata={
            "help": "Chat template to use for tokenization. \
            No need to pass this if the tokenizer already has a chat_template; \
            if passed, it will overwrite tokenizer.chat_template if it exists."
        },
    )
    instruction_template: str = field(
        default=None,
        metadata={
            "help": "Should be provided for chat training. \
            Piece of text that determines the start of a human (user) turn."
        },
    )

A combination of these, along with response_template, is used to train on single- or multi-turn chat data.

  • If a chat_template is supplied, it is passed to the tokenizer at initialization (overwriting any existing tokenizer.chat_template).
  • If both response and instruction templates are provided, the dataset is assumed to be a chat dataset and the chat-style collator is used by passing the appropriate arguments to DataCollatorForCompletionOnlyLM, as sketched below.
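
For illustration, a minimal sketch of how these three templates combine. The model name and the direct use of trl's DataCollatorForCompletionOnlyLM here are assumptions for the example; the PR routes these arguments through its own config and get_data_collator.

# Sketch only: hypothetical wiring, not the PR's actual entry point.
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.0-2b-instruct")  # assumed model

# If a chat_template is supplied, it overwrites the tokenizer's template.
chat_template = None  # e.g. a Jinja template string; None keeps the tokenizer default
if chat_template is not None:
    tokenizer.chat_template = chat_template

# Both templates present => chat dataset => chat-style collator that masks
# everything except the assistant turns in the labels.
collator = DataCollatorForCompletionOnlyLM(
    response_template="<|assistant|>",
    instruction_template="<|user|>",
    tokenizer=tokenizer,
    ignore_index=-100,
)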

Related issue number

How to verify the PR

Was the PR tested

This PR was tested locally as shown below.

For single turn data

>>> nntok.chat_template
"{%- if tools %}\n    {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n    {%- for tool in tools %}\n    {{- tool | tojson(indent=4) }}\n    {%- if not loop.last %}\n        {{- '\n\n' }}\n    {%- endif %}\n    {%- endfor %}\n    {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages['messages'] %}\n    {%- if message['role'] == 'system' %}\n    {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'user' %}\n    {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'assistant' %}\n    {{- '<|start_of_role|>assistant<|end_of_role|>'  + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'assistant_tool_call' %}\n    {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'tool_response' %}\n    {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- endif %}\n    {%- if loop.last and add_generation_prompt %}\n    {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n    {%- endif %}\n{%- endfor %}"
>>> 
>>> single_turn_msg = msg[0]
>>> single_turn=nntok.apply_chat_template(single_turn_msg, tokenize=False)
>>> single_turn
'<|start_of_role|>system<|end_of_role|>You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What is kurtosis?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Kurtosis is a crucial concept in statistical analysis, particularly when evaluating the shape of probability distributions. It is a measure that quantifies the "tailedness" of a distribution, describing how much more peaked or flat it is compared to a normal distribution.\n\nIn more technical terms, kurtosis is the fourth standardized moment, calculated as the fourth power of the standard deviation divided by the square of the variance. It provides a measure of the distribution\'s tails\' heaviness, which can help identify outliers or extreme values in a dataset.\n\nFor a normal distribution, the kurtosis value is 3. However, when the kurtosis value is greater than 3, the distribution is considered leptokurtic, meaning it has heavier tails and a more pronounced peak than the normal distribution. Conversely, if the kurtosis value is less than 3, the distribution is platykurtic, characterized by lighter tails and a flatter peak.\n\nUnderstanding kurtosis is essential for data analysts and researchers, as it can offer valuable insights into the data\'s underlying structure and potential outliers. This information can be particularly useful in fields such as finance, engineering, and biology, where identifying extreme values and tail behavior is critical for risk assessment, stability analysis, and phenotype distribution studies.\n\nIn summary, kurtosis is a statistical measure that quantifies the shape of a probability distribution, with a value of 3 indicating a normal distribution. Values greater than 3 suggest a leptokurtic distribution with heavier tails, while values less than 3 indicate a platykurtic distribution with lighter tails. This measure can help data analysts and researchers better understand their data and identify potential outliers or extreme values.<|end_of_text|>\n'
>>> 
>>> tok_single_turn_msg=nntok(single_turn, return_length=True)
>>> data = Dataset.from_list([tok_single_turn_msg])
>>> data
Dataset({
    features: ['input_ids', 'attention_mask', 'length'],
    num_rows: 1
})
>>> 
>>> from tuning.data.data_preprocessing_utils import get_data_collator
>>> 
>>> collator = get_data_collator(packing=False, response_template="assistant<|end_of_role|>", tokenizer=nntok, is_traindata_tokenized=False, max_seq_length=2048, instruction_template="user<|end_of_role|>")
>>> 
>>> dataloader = torch.utils.data.DataLoader(dataset=data, collate_fn=collator, batch_size=1)
>>> 
>>> for batch in dataloader: print(batch)
... 
{'input_ids': tensor([[128000, 128002,   9125, 128003,   2675,    527,    459,  15592,   4221,
           1646,   8040,    555,  29022,   8483,     13,   1472,    527,    264,
          46878,  18328,     13,   1472,  15884,   1833,  11470,     13,   1472,
            527,  11190,    323,  53997,    323,    499,   1833,  31308,  17959,
            323,  12192,   6928,   7865,     13, 128001,    198, 128002,    882,
         128003,   3923,    374,    597,   5757,  10934,     30, 128001,    198,
         128002,  78191, 128003,     42,   5757,  10934,    374,    264,  16996,
           7434,    304,  29564,   6492,     11,   8104,    994,  38663,    279,
           6211,    315,  19463,  43785,     13,   1102,    374,    264,   6767,
            430,  10484,   9803,    279,    330,   2629,   2230,   2136,      1,
            315,    264,   8141,     11,  23524,   1268,   1790,    810,  78292,
            477,  10269,    433,    374,   7863,    311,    264,   4725,   8141,
            382,    644,    810,  11156,   3878,     11,    597,   5757,  10934,
            374,    279,  11999,  51114,   4545,     11,  16997,    439,    279,
          11999,   2410,    315,    279,   5410,  38664,  18255,    555,    279,
           9518,    315,    279,  33373,     13,   1102,   5825,    264,   6767,
            315,    279,   8141,    596,  64614,      6,  13710,   1918,     11,
            902,    649,   1520,  10765,  87763,    477,  14560,   2819,    304,
            264,  10550,    382,   2520,    264,   4725,   8141,     11,    279,
            597,   5757,  10934,    907,    374,    220,     18,     13,   4452,
             11,    994,    279,    597,   5757,  10934,    907,    374,   7191,
           1109,    220,     18,     11,    279,   8141,    374,   6646,  95540,
            564,   5757,    292,     11,   7438,    433,    706,  44922,  64614,
            323,    264,    810,  38617,  16557,   1109,    279,   4725,   8141,
             13,  82671,     11,    422,    279,    597,   5757,  10934,    907,
            374,   2753,   1109,    220,     18,     11,    279,   8141,    374,
          46089,  73640,   5757,    292,     11,  32971,    555,  30673,  64614,
            323,    264,   1344,   1683,  16557,    382,  71251,    597,   5757,
          10934,    374,   7718,    369,    828,  31499,    323,  12074,     11,
            439,    433,    649,   3085,  15525,  26793,   1139,    279,    828,
            596,  16940,   6070,    323,   4754,  87763,     13,   1115,   2038,
            649,    387,   8104,   5505,    304,   5151,   1778,    439,  17452,
             11,  15009,     11,    323,  34458,     11,   1405,  25607,  14560,
           2819,    323,   9986,   7865,    374,   9200,    369,   5326,  15813,
             11,  20334,   6492,     11,    323,  82423,   8141,   7978,    382,
            644,  12399,     11,    597,   5757,  10934,    374,    264,  29564,
           6767,    430,  10484,   9803,    279,   6211,    315,    264,  19463,
           8141,     11,    449,    264,    907,    315,    220,     18,  19392,
            264,   4725,   8141,     13,  26028,   7191,   1109,    220,     18,
           4284,    264,  95540,    564,   5757,    292,   8141,    449,  44922,
          64614,     11,   1418,   2819,   2753,   1109,    220,     18,  13519,
            264,  46089,  73640,   5757,    292,   8141,    449,  30673,  64614,
             13,   1115,   6767,    649,   1520,    828,  31499,    323,  12074,
           2731,   3619,    872,    828,    323,  10765,   4754,  87763,    477,
          14560,   2819,     13, 128001,    198]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'length': tensor([[401]]), 'labels': tensor([[  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,     42,   5757,  10934,    374,    264,  16996,
           7434,    304,  29564,   6492,     11,   8104,    994,  38663,    279,
           6211,    315,  19463,  43785,     13,   1102,    374,    264,   6767,
            430,  10484,   9803,    279,    330,   2629,   2230,   2136,      1,
            315,    264,   8141,     11,  23524,   1268,   1790,    810,  78292,
            477,  10269,    433,    374,   7863,    311,    264,   4725,   8141,
            382,    644,    810,  11156,   3878,     11,    597,   5757,  10934,
            374,    279,  11999,  51114,   4545,     11,  16997,    439,    279,
          11999,   2410,    315,    279,   5410,  38664,  18255,    555,    279,
           9518,    315,    279,  33373,     13,   1102,   5825,    264,   6767,
            315,    279,   8141,    596,  64614,      6,  13710,   1918,     11,
            902,    649,   1520,  10765,  87763,    477,  14560,   2819,    304,
            264,  10550,    382,   2520,    264,   4725,   8141,     11,    279,
            597,   5757,  10934,    907,    374,    220,     18,     13,   4452,
             11,    994,    279,    597,   5757,  10934,    907,    374,   7191,
           1109,    220,     18,     11,    279,   8141,    374,   6646,  95540,
            564,   5757,    292,     11,   7438,    433,    706,  44922,  64614,
            323,    264,    810,  38617,  16557,   1109,    279,   4725,   8141,
             13,  82671,     11,    422,    279,    597,   5757,  10934,    907,
            374,   2753,   1109,    220,     18,     11,    279,   8141,    374,
          46089,  73640,   5757,    292,     11,  32971,    555,  30673,  64614,
            323,    264,   1344,   1683,  16557,    382,  71251,    597,   5757,
          10934,    374,   7718,    369,    828,  31499,    323,  12074,     11,
            439,    433,    649,   3085,  15525,  26793,   1139,    279,    828,
            596,  16940,   6070,    323,   4754,  87763,     13,   1115,   2038,
            649,    387,   8104,   5505,    304,   5151,   1778,    439,  17452,
             11,  15009,     11,    323,  34458,     11,   1405,  25607,  14560,
           2819,    323,   9986,   7865,    374,   9200,    369,   5326,  15813,
             11,  20334,   6492,     11,    323,  82423,   8141,   7978,    382,
            644,  12399,     11,    597,   5757,  10934,    374,    264,  29564,
           6767,    430,  10484,   9803,    279,   6211,    315,    264,  19463,
           8141,     11,    449,    264,    907,    315,    220,     18,  19392,
            264,   4725,   8141,     13,  26028,   7191,   1109,    220,     18,
           4284,    264,  95540,    564,   5757,    292,   8141,    449,  44922,
          64614,     11,   1418,   2819,   2753,   1109,    220,     18,  13519,
            264,  46089,  73640,   5757,    292,   8141,    449,  30673,  64614,
             13,   1115,   6767,    649,   1520,    828,  31499,    323,  12074,
           2731,   3619,    872,    828,    323,  10765,   4754,  87763,    477,
          14560,   2819,     13, 128001,    198]])}
>>> 
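
As a sanity check on the masking above (a sketch continuing the same session): decoding only the unmasked label tokens should reproduce exactly the assistant turn, confirming that the system and user turns were masked out.

>>> # Sketch: keep only the tokens the collator left unmasked in the labels.
>>> unmasked = [t for t in batch["labels"][0].tolist() if t != -100]
>>> nntok.decode(unmasked, skip_special_tokens=False)  # should match the assistant reply shown earlier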

For multi turn data

>>> multi_turn_msg = msg[1]
>>> multi_turn=nntok.apply_chat_template(multi_turn_msg, tokenize=False)
>>> multi_turn
"<|start_of_role|>system<|end_of_role|>You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Can an autistic patient with depression be treated with bupropion if he would like to avoid weight gain and sexual problems associated with SSRIs (especially chronic post-SSRI syndrome) or should SSRIs still be used as first-line medication in such a case?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Bupropion is an antidepressant that is often considered for patients who experience sexual side effects or weight gain from Selective Serotonin Reuptake Inhibitors (SSRIs). It is a norepinephrine-dopamine reuptake inhibitor (NDRI) that can help manage depression by increasing the levels of norepinephrine and dopamine in the brain.\n\nWhile SSRIs are commonly used as first-line treatment for depression, bupropion can be a suitable alternative for autistic patients who have specific concerns regarding weight gain and sexual dysfunction associated with SSRIs. However, it is essential to consider that bupropion may not be the best option for everyone, and the decision should be made based on individual patient needs, preferences, and medical history.\n\nIt is important to note that both bupropion and SSRIs can have side effects, and the patient's response to medication may vary. In some cases, bupropion might be more effective in managing depression without causing sexual dysfunction or weight gain, but it may not be as effective in treating anxiety symptoms that sometimes co-occur with autism and depression.\n\nIn this situation, a thorough discussion with the patient's healthcare provider is necessary to weigh the benefits and risks of each medication and determine the most appropriate treatment plan. The healthcare provider may consider factors such as the patient's overall mental health, previous treatment history, potential drug interactions, and personal preferences when making a recommendation.<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>What if the autistic patient would explicitly prefer a treatment with bupropion rather than SSRIs?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>If an autistic patient with depression expresses a preference for bupropion over SSRIs, it is crucial to involve the patient in the decision-making process and respect their preferences while considering the clinical rationale. The healthcare provider should evaluate whether bupropion is a suitable option based on the patient's medical history, current health status, and potential interactions with other medications.\n\nIn this case, the healthcare provider should have an open conversation with the patient about the potential benefits and risks of bupropion compared to SSRIs. The provider should also discuss any concerns the patient may have regarding side effects, such as sleep disturbances or anxiety, which can sometimes occur with bupropion.\n\nUltimately, the decision to use bupropion or an SSRI should be made collaboratively between the patient and the healthcare provider, considering the patient's individual needs, preferences, and medical history. Regular follow-ups and monitoring for side effects and treatment response are essential to ensure the chosen treatment remains effective and safe for the patient.<|end_of_text|>\n"
>>> tok_multi_turn_msg=nntok(multi_turn, return_length=True)
>>> data = Dataset.from_list([tok_multi_turn_msg])
>>> data
Dataset({
    features: ['input_ids', 'attention_mask', 'length'],
    num_rows: 1
})
>>> dataloader = torch.utils.data.DataLoader(dataset=data, collate_fn=collator, batch_size=1)
>>> for batch in dataloader: print(batch)
... 
{'input_ids': tensor([[128000, 128002,   9125, 128003,   2675,    527,    459,  15592,   4221,
           1646,   8040,    555,  29022,   8483,     13,   1472,    527,    264,
          46878,  18328,     13,   1472,  15884,   1833,  11470,     13,   1472,
            527,  11190,    323,  53997,    323,    499,   1833,  31308,  17959,
            323,  12192,   6928,   7865,     13, 128001,    198, 128002,    882,
         128003,   6854,    459,  81391,   8893,    449,  18710,    387,  12020,
            449,    293,    455,    897,    290,    422,    568,   1053,   1093,
            311,   5766,   4785,   8895,    323,   7392,   5435,   5938,    449,
          96404,   3957,    320,  36046,  21249,   1772,     12,   1242,   4403,
          28439,      8,    477,   1288,  96404,   3957,   2103,    387,   1511,
            439,   1176,   8614,  24099,    304,   1778,    264,   1162,     30,
         128001,    198, 128002,  78191, 128003,     33,    455,    897,    290,
            374,    459,  65211,    519,    430,    374,   3629,   6646,    369,
           6978,    889,   3217,   7392,   3185,   6372,    477,   4785,   8895,
            505,   8593,    535,   8409,  68055,   1050,   7717,    731,    763,
           5923,  12170,    320,   1242,     49,   3957,    570,   1102,    374,
            264,    308,    461,  39138,    764,  40101,   1773,    454,  20588,
            312,   7717,    731,  70785,    320,   8225,   4403,      8,    430,
            649,   1520,  10299,  18710,    555,   7859,    279,   5990,    315,
            308,    461,  39138,    764,  40101,    323,  66128,    304,    279,
           8271,    382,   8142,  96404,   3957,    527,  17037,   1511,    439,
           1176,   8614,   6514,    369,  18710,     11,    293,    455,    897,
            290,    649,    387,    264,  14791,  10778,    369,  81391,   6978,
            889,    617,   3230,  10742,   9002,   4785,   8895,    323,   7392,
          32403,   5938,    449,  96404,   3957,     13,   4452,     11,    433,
            374,   7718,    311,   2980,    430,    293,    455,    897,    290,
           1253,    539,    387,    279,   1888,   3072,    369,   5127,     11,
            323,    279,   5597,   1288,    387,   1903,   3196,    389,   3927,
           8893,   3966,     11,  19882,     11,    323,   6593,   3925,    382,
           2181,    374,   3062,    311,   5296,    430,   2225,    293,    455,
            897,    290,    323,  96404,   3957,    649,    617,   3185,   6372,
             11,    323,    279,   8893,    596,   2077,    311,  24099,   1253,
          13592,     13,    763,   1063,   5157,     11,    293,    455,    897,
            290,   2643,    387,    810,   7524,    304,  18646,  18710,   2085,
          14718,   7392,  32403,    477,   4785,   8895,     11,    719,    433,
           1253,    539,    387,    439,   7524,    304,  27723,  18547,  13803,
            430,   7170,   1080,     12,    511,   2407,    449,  38281,    323,
          18710,    382,    644,    420,   6671,     11,    264,  17879,  10430,
            449,    279,   8893,    596,  18985,   9287,    374,   5995,    311,
          17988,    279,   7720,    323,  15635,    315,   1855,  24099,    323,
           8417,    279,   1455,   8475,   6514,   3197,     13,    578,  18985,
           9287,   1253,   2980,   9547,   1778,    439,    279,   8893,    596,
           8244,  10723,   2890,     11,   3766,   6514,   3925,     11,   4754,
           5623,  22639,     11,    323,   4443,  19882,    994,   3339,    264,
          28782,     13, 128001,    198, 128002,    882, 128003,   3923,    422,
            279,  81391,   8893,   1053,  21650,  10932,    264,   6514,    449,
            293,    455,    897,    290,   4856,   1109,  96404,   3957,     30,
         128001,    198, 128002,  78191, 128003,   2746,    459,  81391,   8893,
            449,  18710,  61120,    264,  22698,    369,    293,    455,    897,
            290,    927,  96404,   3957,     11,    433,    374,  16996,    311,
          21736,    279,   8893,    304,    279,   5597,  28846,   1920,    323,
           5201,    872,  19882,   1418,  13126,    279,  14830,  57916,     13,
            578,  18985,   9287,   1288,  15806,   3508,    293,    455,    897,
            290,    374,    264,  14791,   3072,   3196,    389,    279,   8893,
            596,   6593,   3925,     11,   1510,   2890,   2704,     11,    323,
           4754,  22639,    449,   1023,  31010,    382,    644,    420,   1162,
             11,    279,  18985,   9287,   1288,    617,    459,   1825,  10652,
            449,    279,   8893,    922,    279,   4754,   7720,    323,  15635,
            315,    293,    455,    897,    290,   7863,    311,  96404,   3957,
             13,    578,   9287,   1288,   1101,   4358,    904,  10742,    279,
           8893,   1253,    617,   9002,   3185,   6372,     11,   1778,    439,
           6212,  85160,    477,  18547,     11,    902,    649,   7170,  12446,
            449,    293,    455,    897,    290,    382,  67343,     11,    279,
           5597,    311,   1005,    293,    455,    897,    290,    477,    459,
          18679,   4403,   1288,    387,   1903,  11430,   8046,   1990,    279,
           8893,    323,    279,  18985,   9287,     11,  13126,    279,   8893,
            596,   3927,   3966,     11,  19882,     11,    323,   6593,   3925,
             13,  29900,   1833,  27859,    323,  16967,    369,   3185,   6372,
            323,   6514,   2077,    527,   7718,    311,   6106,    279,  12146,
           6514,   8625,   7524,    323,   6220,    369,    279,   8893,     13,
         128001,    198]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]]), 'length': tensor([[632]]), 'labels': tensor([[  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,     33,    455,    897,    290,
            374,    459,  65211,    519,    430,    374,   3629,   6646,    369,
           6978,    889,   3217,   7392,   3185,   6372,    477,   4785,   8895,
            505,   8593,    535,   8409,  68055,   1050,   7717,    731,    763,
           5923,  12170,    320,   1242,     49,   3957,    570,   1102,    374,
            264,    308,    461,  39138,    764,  40101,   1773,    454,  20588,
            312,   7717,    731,  70785,    320,   8225,   4403,      8,    430,
            649,   1520,  10299,  18710,    555,   7859,    279,   5990,    315,
            308,    461,  39138,    764,  40101,    323,  66128,    304,    279,
           8271,    382,   8142,  96404,   3957,    527,  17037,   1511,    439,
           1176,   8614,   6514,    369,  18710,     11,    293,    455,    897,
            290,    649,    387,    264,  14791,  10778,    369,  81391,   6978,
            889,    617,   3230,  10742,   9002,   4785,   8895,    323,   7392,
          32403,   5938,    449,  96404,   3957,     13,   4452,     11,    433,
            374,   7718,    311,   2980,    430,    293,    455,    897,    290,
           1253,    539,    387,    279,   1888,   3072,    369,   5127,     11,
            323,    279,   5597,   1288,    387,   1903,   3196,    389,   3927,
           8893,   3966,     11,  19882,     11,    323,   6593,   3925,    382,
           2181,    374,   3062,    311,   5296,    430,   2225,    293,    455,
            897,    290,    323,  96404,   3957,    649,    617,   3185,   6372,
             11,    323,    279,   8893,    596,   2077,    311,  24099,   1253,
          13592,     13,    763,   1063,   5157,     11,    293,    455,    897,
            290,   2643,    387,    810,   7524,    304,  18646,  18710,   2085,
          14718,   7392,  32403,    477,   4785,   8895,     11,    719,    433,
           1253,    539,    387,    439,   7524,    304,  27723,  18547,  13803,
            430,   7170,   1080,     12,    511,   2407,    449,  38281,    323,
          18710,    382,    644,    420,   6671,     11,    264,  17879,  10430,
            449,    279,   8893,    596,  18985,   9287,    374,   5995,    311,
          17988,    279,   7720,    323,  15635,    315,   1855,  24099,    323,
           8417,    279,   1455,   8475,   6514,   3197,     13,    578,  18985,
           9287,   1253,   2980,   9547,   1778,    439,    279,   8893,    596,
           8244,  10723,   2890,     11,   3766,   6514,   3925,     11,   4754,
           5623,  22639,     11,    323,   4443,  19882,    994,   3339,    264,
          28782,     13, 128001,    198, 128002,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   2746,    459,  81391,   8893,
            449,  18710,  61120,    264,  22698,    369,    293,    455,    897,
            290,    927,  96404,   3957,     11,    433,    374,  16996,    311,
          21736,    279,   8893,    304,    279,   5597,  28846,   1920,    323,
           5201,    872,  19882,   1418,  13126,    279,  14830,  57916,     13,
            578,  18985,   9287,   1288,  15806,   3508,    293,    455,    897,
            290,    374,    264,  14791,   3072,   3196,    389,    279,   8893,
            596,   6593,   3925,     11,   1510,   2890,   2704,     11,    323,
           4754,  22639,    449,   1023,  31010,    382,    644,    420,   1162,
             11,    279,  18985,   9287,   1288,    617,    459,   1825,  10652,
            449,    279,   8893,    922,    279,   4754,   7720,    323,  15635,
            315,    293,    455,    897,    290,   7863,    311,  96404,   3957,
             13,    578,   9287,   1288,   1101,   4358,    904,  10742,    279,
           8893,   1253,    617,   9002,   3185,   6372,     11,   1778,    439,
           6212,  85160,    477,  18547,     11,    902,    649,   7170,  12446,
            449,    293,    455,    897,    290,    382,  67343,     11,    279,
           5597,    311,   1005,    293,    455,    897,    290,    477,    459,
          18679,   4403,   1288,    387,   1903,  11430,   8046,   1990,    279,
           8893,    323,    279,  18985,   9287,     11,  13126,    279,   8893,
            596,   3927,   3966,     11,  19882,     11,    323,   6593,   3925,
             13,  29900,   1833,  27859,    323,  16967,    369,   3185,   6372,
            323,   6514,   2077,    527,   7718,    311,   6106,    279,  12146,
           6514,   8625,   7524,    323,   6220,    369,    279,   8893,     13,
         128001,    198]])}
>>> 
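
Note that the labels above contain two unmasked spans, one per assistant turn, with the intervening user turn masked to -100. A quick sketch (continuing the session) to confirm the span count:

>>> # Sketch: contiguous unmasked spans in labels == trainable assistant turns.
>>> import itertools
>>> labels = batch["labels"][0].tolist()
>>> sum(1 for keep, _ in itertools.groupby(labels, key=lambda t: t != -100) if keep)
2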
  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass


Thanks for making a pull request! 😃
One of the maintainers will review and advise on the next steps.

@github-actions github-actions bot added the feat label Dec 12, 2024
@dushyantbehl dushyantbehl marked this pull request as ready for review December 13, 2024 17:04
@dushyantbehl dushyantbehl changed the title feat: [WIP] Add multi turn chat support. feat: Add multi turn chat support. Dec 13, 2024
@kmehant
Collaborator

kmehant commented Dec 16, 2024

Are we planning to add training test cases for single turn and multi turn as well?

@dushyantbehl
Contributor Author

Are we planning to add training test cases for single turn and multi turn as well?

Yes we are, for both single turn and multi turn; working on a couple.

Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
@dushyantbehl
Contributor Author

Added test cases for e2e unit testing using single and multi turn data.

The test cases are e2e, but a dissection of the logs is shown below, as above.

Single turn

>>> single_turn_msg = {"messages": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Using the word \"grace\", come up with a word that rhymes and has the same number of syllables\n<nopace>", "role": "user"}, {"content": "Certainly! Here's a word that rhymes with \"grace\" and has the same number of syllables:\n1\\. Space", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 1}"}
>>> msg = tokenizer.apply_chat_template(single_turn_msg, tokenize=False, add_generation_prompt=False)
>>> msg
'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.</s><|user|>\nUsing the word "grace", come up with a word that rhymes and has the same number of syllables\n<nopace></s><|assistant|>\nCertainly! Here\'s a word that rhymes with "grace" and has the same number of syllables:\n1\\. Space</s>'
>>> tok_msg = tokenizer(msg)
>>> tok_msg
{'input_ids': [1, 32002, 31822, 13, 3838, 397, 363, 7421, 3067, 2228, 3321, 417, 16019, 2858, 31843, 864, 397, 260, 25281, 8825, 31843, 864, 8310, 1085, 8954, 31843, 864, 397, 8032, 291, 29374, 291, 365, 1085, 12515, 8469, 291, 6178, 3575, 4956, 8596, 31829, 31901, 32000, 31822, 13, 10568, 281, 266, 1693, 495, 31839, 4466, 1742, 1412, 550, 351, 260, 1693, 342, 14168, 1276, 277, 291, 470, 266, 1128, 1277, 287, 25775, 2880, 13, 31903, 31828, 386, 568, 3138, 31829, 31901, 32001, 31822, 13, 31851, 1373, 326, 31905, 2611, 31876, 31829, 260, 1693, 342, 14168, 1276, 277, 351, 495, 31839, 4466, 31875, 291, 470, 266, 1128, 1277, 287, 25775, 2880, 31871, 13, 31853, 31890, 31843, 5867, 1089, 31829, 31901], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
>>> data = Dataset.from_list([tok_msg])
>>> data
Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 1
})
>>> collator = get_data_collator(packing=False, response_template="<|assistant|>", tokenizer=tokenizer, is_traindata_tokenized=False, max_seq_length=2048, instruction_template="<|user|>")
>>> collator
DataCollatorForCompletionOnlyLM(tokenizer=LlamaTokenizerFast(name_or_path='/Volumes/Projects/Projects/projects/2023/ai-platform-engg/fms-hf-tuning/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<PAD>', 'additional_special_tokens': ['<|user|>', '<|assistant|>', '<|system|>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	32000: AddedToken("<|user|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32001: AddedToken("<|assistant|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32002: AddedToken("<|system|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}, mlm=False, mlm_probability=0.15, pad_to_multiple_of=None, tf_experimental_compile=False, return_tensors='pt')
>>> dataloader = torch.utils.data.DataLoader(dataset=data, collate_fn=collator, batch_size=1)
>>> for batch in dataloader: print(batch)
... 
{'input_ids': tensor([[    1, 32002, 31822,    13,  3838,   397,   363,  7421,  3067,  2228,
          3321,   417, 16019,  2858, 31843,   864,   397,   260, 25281,  8825,
         31843,   864,  8310,  1085,  8954, 31843,   864,   397,  8032,   291,
         29374,   291,   365,  1085, 12515,  8469,   291,  6178,  3575,  4956,
          8596, 31829, 31901, 32000, 31822,    13, 10568,   281,   266,  1693,
           495, 31839,  4466,  1742,  1412,   550,   351,   260,  1693,   342,
         14168,  1276,   277,   291,   470,   266,  1128,  1277,   287, 25775,
          2880,    13, 31903, 31828,   386,   568,  3138, 31829, 31901, 32001,
         31822,    13, 31851,  1373,   326, 31905,  2611, 31876, 31829,   260,
          1693,   342, 14168,  1276,   277,   351,   495, 31839,  4466, 31875,
           291,   470,   266,  1128,  1277,   287, 25775,  2880, 31871,    13,
         31853, 31890, 31843,  5867,  1089, 31829, 31901]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         31822,    13, 31851,  1373,   326, 31905,  2611, 31876, 31829,   260,
          1693,   342, 14168,  1276,   277,   351,   495, 31839,  4466, 31875,
           291,   470,   266,  1128,  1277,   287, 25775,  2880, 31871,    13,
         31853, 31890, 31843,  5867,  1089, 31829, 31901]])}
>>> decoded_text = tokenizer.decode([31822,  13, 31851, 1373,  326, 31905, 2611, 31876, 31829,  260, 1693,  342, 14168, 1276,  277,  351,  495, 31839, 4466, 31875, 291,  470,  266, 1128, 1277,  287, 25775, 2880, 31871,  13, 31853, 31890, 31843, 5867, 1089, 31829, 31901], skip_special_tokens=False)
>>> decoded_text
'\nCertainly! Here\'s a word that rhymes with "grace" and has the same number of syllables:\n1\\. Space</s>'
>>> 

Multi turn

>>> tokenizer.chat_template = "{% for message in messages['messages'] %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>\n'  + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' }}{% endif %}{% endfor %}"
>>> tok_msg = tokenizer.apply_chat_template(multi_turn_data, tokenize=False)
>>> tok_msg
'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.</s><|user|>\nLook up a word that rhymes with exist</s><|assistant|>\nI found a word that rhymes with "exist":\n1\\. Mist</s><|user|>\nLook up a word that rhymes with exist</s><|assistant|>\nI found a word that rhymes with "exist":\n1\\. Mist</s>'
>>> tok_multi_turn_msg = tokenizer(tok_msg, return_length=True)
>>> tok_multi_turn_msg
{'input_ids': [1, 32002, 31822, 13, 3838, 397, 363, 7421, 3067, 2228, 3321, 417, 16019, 2858, 31843, 864, 397, 260, 25281, 8825, 31843, 864, 8310, 1085, 8954, 31843, 864, 397, 8032, 291, 29374, 291, 365, 1085, 12515, 8469, 291, 6178, 3575, 4956, 8596, 31829, 31901, 32000, 31822, 13, 18114, 550, 260, 1693, 342, 14168, 1276, 277, 351, 2203, 1089, 31829, 31901, 32001, 31822, 13, 31850, 1111, 260, 1693, 342, 14168, 1276, 277, 351, 495, 969, 379, 4412, 13, 31853, 31890, 31843, 17601, 1089, 31829, 31901, 32000, 31822, 13, 18114, 550, 260, 1693, 342, 14168, 1276, 277, 351, 2203, 1089, 31829, 31901, 32001, 31822, 13, 31850, 1111, 260, 1693, 342, 14168, 1276, 277, 351, 495, 969, 379, 4412, 13, 31853, 31890, 31843, 17601, 1089, 31829, 31901], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'length': [123]}
>>> data = Dataset.from_list([tok_multi_turn_msg])
>>> dataloader = torch.utils.data.DataLoader(dataset=data, collate_fn=collator, batch_size=1)
>>> for batch in dataloader: print(batch)
... 
{'input_ids': tensor([[    1, 32002, 31822,    13,  3838,   397,   363,  7421,  3067,  2228,
          3321,   417, 16019,  2858, 31843,   864,   397,   260, 25281,  8825,
         31843,   864,  8310,  1085,  8954, 31843,   864,   397,  8032,   291,
         29374,   291,   365,  1085, 12515,  8469,   291,  6178,  3575,  4956,
          8596, 31829, 31901, 32000, 31822,    13, 18114,   550,   260,  1693,
           342, 14168,  1276,   277,   351,  2203,  1089, 31829, 31901, 32001,
         31822,    13, 31850,  1111,   260,  1693,   342, 14168,  1276,   277,
           351,   495,   969,   379,  4412,    13, 31853, 31890, 31843, 17601,
          1089, 31829, 31901, 32000, 31822,    13, 18114,   550,   260,  1693,
           342, 14168,  1276,   277,   351,  2203,  1089, 31829, 31901, 32001,
         31822,    13, 31850,  1111,   260,  1693,   342, 14168,  1276,   277,
           351,   495,   969,   379,  4412,    13, 31853, 31890, 31843, 17601,
          1089, 31829, 31901]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1]]), 'length': tensor([[123]]), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         31822,    13, 31850,  1111,   260,  1693,   342, 14168,  1276,   277,
           351,   495,   969,   379,  4412,    13, 31853, 31890, 31843, 17601,
          1089, 31829, 31901,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         31822,    13, 31850,  1111,   260,  1693,   342, 14168,  1276,   277,
           351,   495,   969,   379,  4412,    13, 31853, 31890, 31843, 17601,
          1089, 31829, 31901]])}
>>> decoded_text = tokenizer.decode([31822,  13, 31850, 1111,  260, 1693,  342, 14168, 1276,  277, 351,  495,  969,  379, 4412,  13, 31853, 31890, 31843, 17601, 1089, 31829, 31901], skip_special_tokens=False)                                              
>>> decoded_text
'\nI found a word that rhymes with "exist":\n1\\. Mist</s>'

@dushyantbehl dushyantbehl force-pushed the multi_turn_chat branch 2 times, most recently from efa6a2c to ae91828 Compare December 17, 2024 14:26
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
Collaborator

@Abhishek-TAMU Abhishek-TAMU left a comment


Thanks @dushyantbehl for the unit tests. Have a few changes to suggest.

tuning/config/configs.py (outdated, resolved)
tests/test_sft_trainer.py (outdated, resolved)
tests/test_sft_trainer.py (outdated, resolved)
tests/artifacts/testdata/__init__.py (outdated, resolved)
Comment on lines 244 to 247:

elif data_args.instruction_template and data_args.response_template:
    # Data Format 3: Chat dataset with instruction and response template
    # We don't do processing for chat dataset
    handlers, dataset_text_field = [], None
Collaborator


In this unit test, data_args.dataset_text_field is not None, so it will always satisfy the elif on line 239 instead of this elif. Hence, set data_args.dataset_text_field to None in this unit test.

Collaborator


Also, a question: if we are giving data_args.dataset_text_field as None to SFTTrainer, then how would SFTTrainer get to know that the messages key has to be picked for the data, as the dataset has multiple keys:

'messages': [], 'group': 'lab_extension', 'dataset': 'base/full-extension', 'metadata': '{"num_turns": 2}'}

Contributor Author


This is fixed. I noticed that SFTTrainer was not properly applying chat templates, so having an explicit handler which applies the chat template is an easier option. We can keep dataset_text_field intact and also have more control if we need to do any handling before or after applying the chat template.
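
For reference, a minimal sketch of what such a handler could look like (the name and signature are illustrative, not necessarily the PR's actual handler):

def apply_custom_chat_template(element, tokenizer, dataset_text_field):
    # Render the conversation with the tokenizer's (possibly overridden)
    # chat template and store the flattened text under dataset_text_field,
    # leaving the other columns ("group", "dataset", "metadata") untouched.
    element[dataset_text_field] = tokenizer.apply_chat_template(
        element, tokenize=False, add_generation_prompt=False
    )
    return element

# Usage sketch:
# dataset = dataset.map(apply_custom_chat_template,
#                       fn_kwargs={"tokenizer": tokenizer,
#                                  "dataset_text_field": "formatted_chat"})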

@dushyantbehl dushyantbehl force-pushed the multi_turn_chat branch 3 times, most recently from 82113e4 to a3f84a4 Compare December 18, 2024 06:14
@kmehant
Collaborator

kmehant commented Dec 18, 2024

@dushyantbehl Are we planning to add usage docs to README.md?

@dushyantbehl
Contributor Author

Yes @kmehant, that will be in the documentation PR, as discussed internally on Slack.

@kmehant
Collaborator

kmehant commented Dec 18, 2024

@dushyantbehl Can you also elaborate what response_template would mean in chat training context? Can you add that in help doc for response_template?

@dushyantbehl
Contributor Author

@dushyantbehl Can you also elaborate what response_template would mean in chat training context? Can you add that in help doc for response_template?

@kmehant done

@kmehant kmehant changed the title feat: Add multi turn chat support. feat: Add multi turn chat support Dec 18, 2024
@kmehant kmehant changed the title feat: Add multi turn chat support feat: Add multi and single turn chat support Dec 18, 2024
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
Collaborator

@kmehant kmehant left a comment


Looks good to me.

@dushyantbehl
Contributor Author

Waiting for approval from @Abhishek-TAMU @ashokponkumar @willmj

Collaborator

@Abhishek-TAMU Abhishek-TAMU left a comment


Thanks @dushyantbehl for the fix and adding handler. LGTM!

Collaborator

@willmj willmj left a comment


LGTM! Thanks Dushyant

@ashokponkumar ashokponkumar merged commit 42e3077 into foundation-model-stack:main Dec 18, 2024
8 checks passed
@dushyantbehl dushyantbehl deleted the multi_turn_chat branch December 20, 2024 05:56