diff --git a/control_codes.py b/control_codes.py
new file mode 100644
index 0000000..3ab7296
--- /dev/null
+++ b/control_codes.py
@@ -0,0 +1,58 @@
+
+CONTROL_CODES = {
+    "Pregnancy": 168629,
+    "Christianity": 7675,
+    "Explain": 106423,
+    "Fitness": 63440,
+    "Saving": 63163,
+    "Ask": 27171,
+    "Ass": 95985,
+    "Joke": 163509,
+    "Questions": 45622,
+    "Thoughts": 49605,
+    "Retail": 52342,
+    "Feminism": 164338,
+    "Writing": 11992,
+    "Atheism": 192263,
+    "Netflix": 48616,
+    "Computing": 39639,
+    "Opinion": 43213,
+    "Alone": 44967,
+    "Funny": 58917,
+    "Gaming": 40358,
+    "Human": 4088,
+    "India": 1331,
+    "Joker": 77138,
+    "Diet": 36206,
+    "Legal": 11859,
+    "Norman": 4939,
+    "Tip": 72689,
+    "Weight": 52343,
+    "Movies": 46273,
+    "Running": 23425,
+    "Science": 2090,
+    "Horror": 37793,
+    "Confession": 60572,
+    "Finance": 12250,
+    "Politics": 16360,
+    "Scary": 191985,
+    "Support": 12654,
+    "Technologies": 32516,
+    "Teenage": 66160,
+    "Event": 32769,
+    "Learned": 67460,
+    "Notion": 182770,
+    "Wikipedia": 37583,
+    "Books": 6665,
+    "Extract": 76050,
+    "Confessions": 102701,
+    "Conspiracy": 75932,
+    "Links": 63674,
+    "Narcissus": 150425,
+    "Relationship": 54766,
+    "Relationships": 134796,
+    "Reviews": 41671,
+    "News": 4256,
+    "Translation": 26820,
+    "multilingual": 128406,
+}
diff --git a/generation.py b/generation.py
index b029926..ed680b7 100644
--- a/generation.py
+++ b/generation.py
@@ -15,6 +15,7 @@
 from tensorflow.python.ops import embedding_ops
 import fastBPE
 import platform
+from control_codes import CONTROL_CODES
 
 use_py3 = platform.python_version()[0] == '3'
 
@@ -170,6 +171,8 @@ def serving_input_fn():
 
   # tokenize provided prompt
   split_prompt = bpe.apply([prompt])[0].split()
+  if not any(split_prompt[0] == x for x in CONTROL_CODES.keys()):
+    print("WARNING! You are not starting your generation from a control code so you won't get good results")
   text = [word2idx[i] for i in split_prompt]
 
   # pad with 0s and create a mini-batch of 2 (arbitrary, for ease of code)
diff --git a/pytorch_generation.py b/pytorch_generation.py
index 9a95c14..ea094f4 100644
--- a/pytorch_generation.py
+++ b/pytorch_generation.py
@@ -12,6 +12,7 @@
 import tensorflow as tf
 import fastBPE
 from tensorflow.python import pywrap_tensorflow
+from control_codes import CONTROL_CODES
 
 use_py3 = platform.python_version()[0] == '3'
 
@@ -163,6 +164,8 @@ def predict_fn(inputs):
 
   # tokenize provided prompt
   split_prompt = bpe.apply([prompt])[0].split()
+  if not any(split_prompt[0] == x for x in CONTROL_CODES.keys()):
+    print("WARNING! You are not starting your generation from a control code so you won't get good results")
   text = [word2idx[i] for i in split_prompt]