From 45a1a664494f9dd0d31675bf41e8a3b36f60fb3a Mon Sep 17 00:00:00 2001 From: zwkkk <396635841@qq.com> Date: Mon, 3 Apr 2023 11:41:19 +0800 Subject: [PATCH] fix bug in tokenize_dataset_rows.py and infer.ipynb --- .DS_Store | Bin 0 -> 8196 bytes infer.ipynb | 27 +++++++++++++-------------- tokenize_dataset_rows.py | 6 ++++-- 3 files changed, 17 insertions(+), 16 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a0875276163294b90f5a1d1d65e8d93ab946a5d9 GIT binary patch literal 8196 zcmeHM&uudg2YLo)8a17Srjsx$ z3zMM;wK{mNawp+ww6#^hDzK^m=kB}ICP^36UC7_t$lmJ*D`$mz?->OY)BvK>S+ z0E;);5lZ=|GPL+L&;e^mttG#&oBV=t(l_&c3;I!*3@erIVrx5dZ71twoxJn1`9_YL z!`5(=)LMh*eDy@iD7b95f~R43>=$kx%XrudQM~-W96n7HEZR!Ch z=j8lCaWdJj?w8!X2M5!VJ2|LUO76Y8_ovgGvwQo_;i-Qf^`zwb^`fte z{x7NM&(SGnJ)$~20+#Wqca|8u57LYRc3hX9##MV literal 0 HcmV?d00001 diff --git a/infer.ipynb b/infer.ipynb index 78a4e33..b750434 100644 --- a/infer.ipynb +++ b/infer.ipynb @@ -134,18 +134,17 @@ " for idx, item in enumerate(instructions[:3]):\n", " feature = format_example(item)\n", " input_text = feature['context']\n", - " ids = tokenizer.encode(input_text)\n", - " input_ids = torch.LongTensor([ids])\n", - " out = model.generate(\n", - " input_ids=input_ids,\n", - " max_length=150,\n", - " do_sample=False,\n", - " temperature=0\n", - " )\n", - " out_text = tokenizer.decode(out[0])\n", - " answer = out_text.replace(input_text, \"\").replace(\"\\nEND\", \"\").strip()\n", + " input_ids = tokenizer.encode(input_text, return_tensors=\"pt\")\n", + " inputs = model.prepare_inputs_for_generation(input_ids)\n", + " for k,v in inputs.items():\n", + " if v is not None:\n", + " inputs[k] = v.to(\"cuda\")\n", + " outputs = model.generate(**inputs, max_length=512, eos_token_id=tokenizer.eop_token_id)\n", + " out = outputs[0].tolist()[input_ids.size()[-1]:]\n", + " answer = tokenizer.decode(out)\n", " item['infer_answer'] = answer\n", - " print(out_text)\n", + " print(input_text)\n", + " print(answer)\n", " print(f\"### {idx+1}.Answer:\\n\", item.get('output'), '\\n\\n')\n", " answers.append({'index': idx, **item})" ] @@ -153,7 +152,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3.9.6 64-bit", "language": "python", "name": "python3" }, @@ -167,12 +166,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.9.6" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "25273a2a68c96ebac13d7fb9e0db516f9be0772777a0507fe06d682a441a3ba7" + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, diff --git a/tokenize_dataset_rows.py b/tokenize_dataset_rows.py index 5352a32..f0824c8 100644 --- a/tokenize_dataset_rows.py +++ b/tokenize_dataset_rows.py @@ -9,13 +9,15 @@ def preprocess(tokenizer, config, example, max_seq_length): prompt = example["context"] target = example["target"] - prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True) + prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True,return_attention_mask=False, + add_special_tokens=False) target_ids = tokenizer.encode( target, max_length=max_seq_length, truncation=True, + return_attention_mask=False, add_special_tokens=False) - input_ids = prompt_ids + target_ids + [config.eos_token_id] + input_ids = prompt_ids + [150001, 150004] + target_ids + [150005] return {"input_ids": input_ids, "seq_len": len(prompt_ids)}