"""
PyTorch implementation of a Language Model Environment.
There are two main components in this documentation:
- We use GPT-2 as the base language model and construct an environment around it.
- We demonstrate this environment: users can type prompts in the command line to interact with the language model.
"""
import torch
import gym
from typing import Callable, Optional, Dict, Tuple
# For more information about GPT2, please refer to this doc: <link https://huggingface.co/transformers/v3.0.2/model_doc/gpt2.html#gpt2lmheadmodel link>.
from transformers import GPT2Tokenizer, GPT2LMHeadModel
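# Note: running this file requires the ``torch``, ``gym`` and ``transformers`` packages
# (for example, installed via ``pip install torch gym transformers``); exact versions are not pinned here.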
def calculate_perplexity(model: GPT2LMHeadModel, query: torch.Tensor, response: torch.Tensor) -> float:
"""
**Overview:**
Calculate the perplexity of the response, given a language model, query token ids and response token ids. \
In essence, the perplexity is the exponential of the cross-entropy loss, which reflects the quality of \
the generated response to some extent.
**Arguments:**
- model: The language model to calculate perplexity.
- query: The token ids for query.
- response: The token ids for response.
"""
# Concatenate the query and response.
total_input = torch.cat([query, response], dim=0)
# Calculate the logits given the token ids.
# Add a batch dimension for the forward pass, then drop it, so that ``logits`` has shape (seq_len, vocab_size).
logits = model(total_input.unsqueeze(0), return_dict=True).logits[0]
# Shift the logits and the labels so that they are aligned:
# the logit at position i is used to predict the token at position i + 1.
start = query.shape[0]
shifted_logits = logits[start:-1, :]
shifted_labels = total_input[start+1:]
# Use cross entropy loss to calculate the perplexity.
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(shifted_logits, shifted_labels)
ppl = torch.exp(loss).item()
return ppl
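# A minimal usage sketch of ``calculate_perplexity`` (the example strings are hypothetical, and ``model`` /
# ``tokenizer`` are assumed to be the GPT-2 objects loaded as in ``test_env`` below):
#   query = tokenizer("The weather today is", return_tensors="pt").input_ids[0]
#   response = tokenizer(" sunny and warm.", return_tensors="pt").input_ids[0]
#   ppl = calculate_perplexity(model, query, response)
# Since perplexity is exp(cross-entropy), it is always >= 1, and a lower value means the model assigns a
# higher probability to the response given the query.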
class TextHistory:
"""
**Overview:**
The TextHistory class keeps track of the history of an interaction between the language model and the environment.
"""
def __init__(self, text: str, tokens: Optional[torch.Tensor]):
"""
**Overview:**
Initialize TextHistory.
**Arguments:**
- text: The text of the first segment.
- tokens: The tokens of the first segment.
"""
if len(text) == 0:
self.text, self.tokens = None, None
return
# Record the total text generated by both the user and the language model.
# Start from an empty string; the first segment is added through ``append_segment`` below.
self.text = ""
# Record the character ranges of the text for each reply.
self.text_spans = []
# Record the token ranges for each reply (used to build the observation mask).
self.token_spans = []
# Record the tokenized total text generated by both the user and the language model.
# Start from an empty tensor; the first segment is added through ``append_segment`` below.
self.tokens = torch.tensor([], dtype=tokens.dtype, device=tokens.device)
# This flag shows whether this record is finished.
self.completed = False
self.append_segment(text, tokens)
# delimiter
def append_segment(self, text: str, tokens: torch.Tensor) -> None:
"""
**Overview:**
Append a new segment to the history.
**Arguments:**
- text: The text of the new segment.
- tokens: The tokens of the new segment.
"""
# If the text or the token list is empty, raise an error.
if len(text) == 0 or len(tokens) == 0:
raise ValueError("Can't append empty text or token list to history.")
# Add the new text to ``self.text``
original_text_length = len(self.text)
self.text += text
# Update the range of this new text segment.
self.text_spans.append((original_text_length, len(self.text)))
# Add the new tokens to ``self.tokens`` and record the token range of this new segment.
original_token_length = len(self.tokens)
self.tokens = torch.cat((self.tokens, tokens))
self.token_spans.append((original_token_length, len(self.tokens)))
# delimiter
@property
def last_text_segment(self) -> str:
"""
**Overview:**
Get the last text segment.
"""
start, end = self.text_spans[-1]
return self.text[start:end]
def to_obs(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""
**Overview:**
Convert the history object into an observation tensor and the corresponding mask. \
The observation tensor will be padded to a fixed length (1024). \
For ids generated by the user, the mask value is 1; for ids generated by the language model, the mask value is 2; for padding ids, the mask value is 0.
"""
# Pad the observation to 1024.
obs = self.tokens
if len(obs) < 1024:
obs = torch.nn.functional.pad(obs, (0, 1024-len(obs)))
# Generate the corresponding mask. The mask is indexed by token position, so it is built from
# ``self.token_spans`` (token ranges) rather than the character-level ``self.text_spans``.
mask = torch.zeros_like(obs)
for i in range(len(self.token_spans)):
sli = self.token_spans[i]
# For ids generated by the user, the mask value is 1.
if i % 2 == 0:
mask[sli[0]: sli[1]] = 1
# For ids generated by the language model, the mask value is 2.
else:
mask[sli[0]: sli[1]] = 2
return obs, mask
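# A small illustrative sketch of how ``TextHistory`` tracks segments (the concrete strings are hypothetical):
#   history = TextHistory("Hello", hello_tokens)
#   history.append_segment(" world", world_tokens)
# After these calls, ``history.text`` is "Hello world", ``history.text_spans`` is [(0, 5), (5, 11)],
# ``history.token_spans`` holds the corresponding token ranges, and ``history.last_text_segment`` is " world".
# ``history.to_obs()`` then returns the token ids padded to length 1024 together with a mask that is 1 over
# the "Hello" tokens (user segment), 2 over the " world" tokens (model segment) and 0 over the padding.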
# delimiter
class TextEnvironment(gym.Env):
"""
**Overview:**
The TextEnvironment enables an LLM (large language model) to interact with an environment.
"""
def __init__(self, model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, reward_fn: Callable,
max_turns: int = 4, generation_kwargs: Optional[Dict] = None):
"""
**Overview:**
Initialize the TextEnvironment.
**Arguments:**
- model: The model to use for generation.
- tokenizer: The tokenizer to use for generation.
- reward_fn: A callable that takes the language model, query token ids and response token ids, and returns a scalar reward.
- max_turns: The maximum number of turns to allow.
- generation_kwargs: A dictionary of keyword arguments to pass to the model's generate method.
"""
# Initialize the arguments.
self.model = model
self.tokenizer = tokenizer
self.reward_fn = reward_fn
self.max_turns = max_turns
# Prepare the arguments for text generation.
if generation_kwargs is None:
self.generation_kwargs = dict()
else:
self.generation_kwargs = generation_kwargs
# Count the number of times ``self.step()`` has been called.
self.turn = 0
# Preserve the history of interactions.
self.history = TextHistory("", None)
# Determine the device on which the model runs (cpu or cuda).
self.current_device = self.model.device
# Define the action-space, reward-space and observation-space.
# The action space is a sentence (string type).
self._action_space = gym.spaces.Text(max_length=1024)
# In this demo, we use the negative perplexity as the reward; since perplexity >= 1, the reward lies in (-inf, -1], bounded here by the space (-inf, 0).
self._reward_space = gym.spaces.Box(-float('inf'), 0)
# The observation is the history query and response, whose shape is 1024.
# If the total length of history < 1024, it will be padded. Detailed implementation is shown in ``TextHistory.to_obs``.
# For each element of the observation, the value range is [0, vocab_size).
self._observation_space = gym.spaces.Box(0, tokenizer.vocab_size, [1024])
# delimiter
def reset(self):
"""
**Overview:**
Reset the environment.
"""
# Reset the history and the counter of step number.
self.history = TextHistory("", None)
self.turn = 0
return self.history
# delimiter
def generate(self) -> torch.Tensor:
"""
**Overview:**
Generate responses for a history.
"""
# The input to the model is the entire interaction history.
query_tensors = self.history.tokens
# Generate reply.
response_tensors = self._generate(query_tensors)
# Decode the reply into string format.
response_texts = self.tokenizer.decode(response_tensors)
# Add the newly generated reply to ``self.history``.
self.history.append_segment(response_texts, response_tensors)
return response_tensors
# delimiter
def step(self, query: str) -> Tuple[torch.Tensor, float, bool, Dict]:
"""
**Overview:**
The step function of the language model environment.
"""
# The history is not initialized. Create a new history.
if self.history.tokens is None:
query_tokens = self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.current_device)
self.history = TextHistory(query, query_tokens)
# The history is already initialized. Append to the original history.
else:
query_tokens = self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.current_device)
self.history.append_segment(query, query_tokens)
# Generate response.
response_tokens = self.generate()
# Calculate the reward.
rew = self.reward_fn(self.model, query_tokens, response_tokens)
# Check whether the environment is finished.
self.turn += 1
self.history.completed = self.turn >= self.max_turns
obs, mask = self.history.to_obs()
return obs, rew, self.history.completed, {"mask": mask}
# delimiter
def _generate(self, query_tensors: torch.Tensor) -> torch.Tensor:
"""
**Overview:**
Generate a response for the given query tokens.
**Arguments:**
- query_tensors (torch.Tensor): A 1D tensor of query token ids to generate a response for.
"""
# Add the batch_size dimension to the original input. Shape: [T] -> [1, T]
query_tensors = query_tensors.unsqueeze(0)
# Generate the corresponding mask tensor.
batch_mask = torch.ones_like(query_tensors)
inputs = {"input_ids": query_tensors, "attention_mask": batch_mask}
# Call the ``generate()`` API of GPT-2.
generation = self.model.generate(**inputs, **self.generation_kwargs,
pad_token_id=self.tokenizer.eos_token_id)
# Remove the prompt from the generated sequence, keeping only the newly generated tokens.
output = generation[0, batch_mask[0].sum():]
return output
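# A brief sketch of the prompt-stripping logic above (the shapes are illustrative): if the query holds
# 10 tokens and up to 20 new tokens are requested, ``generation`` has shape [1, 10 + k] with k <= 20,
# ``batch_mask[0].sum()`` equals 10, and ``generation[0, 10:]`` keeps only the k newly generated tokens
# that are appended to the history.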
# delimiter
def test_env():
"""
**Overview:**
In this function, we test the language model environment and interact with it by typing prompts in the command line.
"""
# Load the pretrained model and tokenizer.
# The first time this function is called, the pretrained files will be automatically downloaded from <link https://huggingface.co/ link>.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2')
# For simplicity, we set the reward function to be the negative perplexity.
reward_function = lambda lm, query, response: - calculate_perplexity(lm, query, response)
# Arguments for text generation.
generation_kwargs = {
# The maximum number of new tokens the language model can generate is 20.
'max_new_tokens': 20,
# Use sampling (non-deterministic decoding) so that the generated results can differ each time.
'do_sample': True,
# The temperature of the softmax function used for sampling.
'temperature': 0.7,
# Penalize the language model for generating repeated tokens.
'repetition_penalty': 2.0
}
# Initialize the environment.
env = TextEnvironment(model=model, tokenizer=tokenizer, max_turns=3, reward_fn=reward_function,
generation_kwargs=generation_kwargs)
env.reset()
# Interaction loop.
while True:
# The user types in a question.
q = input("Please type in your question:")
# The environment steps once.
obs, reward, done, info = env.step(q)
# Print the response and reward.
print("Response (Reward={:.2f}):{}".format(reward, env.history.last_text_segment))
# If the environment is done, break the interaction loop.
if done:
break
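# A hypothetical non-interactive variant of the loop above: the same environment can be driven by a
# scripted list of prompts instead of ``input()``, e.g.
#   env.reset()
#   for q in ["Hello!", "Tell me more.", "Why is that?"]:
#       obs, reward, done, info = env.step(q)
#       if done:
#           break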
if __name__ == '__main__':
test_env()