chpt3
joecummings committed Dec 22, 2022
1 parent b5c3210 commit 183b80c
Showing 2 changed files with 36 additions and 26 deletions.
10 changes: 5 additions & 5 deletions test/torchtext_unittest/prototype/test_generate.py
@@ -16,10 +16,10 @@ def setUp(self) -> None:
         self.inputs = self.transform(
             [
                 "summarize: studies have shown that owning a dog is good for you",
-                # "translate English to German: That is good.",
-                # "cola sentence: The course is jumping well.",
-                # "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
-                # "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
+                "translate English to German: That is good.",
+                "cola sentence: The course is jumping well.",
+                "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
+                "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
             ]
         )
         torch.manual_seed(0)
@@ -55,7 +55,7 @@ def test_warns_when_no_max_len_provided(self, mock) -> None:
     def test_beam_search(self) -> None:
         generation_model = GenerationUtil(self.model)
 
-        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=100)
+        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30)
 
         generated_text = self.transform.decode(tokens.tolist())
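Note: the updated test now exercises batched beam search over four T5 task prompts at once, with max_len shortened from 100 to 30. A minimal sketch of the same call pattern, assuming the setUp builds its model and transform from torchtext's prototype T5 bundle (the bundle name and setup are assumptions; only generate(..., num_beams=3, max_len=30) and transform.decode(...) appear in this diff):

import torch
from torchtext.prototype.generate import GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION  # assumed bundle; not shown in this diff

transform = T5_BASE_GENERATION.transform()
model = T5_BASE_GENERATION.get_model()

inputs = transform(
    [
        "summarize: studies have shown that owning a dog is good for you",
        "translate English to German: That is good.",
    ]
)

torch.manual_seed(0)
generation_model = GenerationUtil(model)

# Beam search over the whole batch; each output row is padded out to max_len.
tokens = generation_model.generate(inputs, num_beams=3, max_len=30)
print(transform.decode(tokens.tolist()))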
52 changes: 31 additions & 21 deletions torchtext/prototype/generate.py
@@ -115,26 +115,33 @@ def beam_search(
     ) -> torch.Tensor:
 
         # Is this right?
-        T = max_len
+        # T = max_len
         N = vocab_size
 
+        emissions = model_kwargs["encoder_outputs"].get("encoder_output")
+
         def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
-            # `emissions_ptr` should always be the same (from encoder output)
-            # N is not needed
-            # T is not needed
+            # Currently, can only handle decoding one at a time
+            i = T  # Access the current seq in inputs
+            new_model_kwargs = model_kwargs.copy()
+
             if timestep == 0:
-                prev_step_token_idxs = input_ids
+                prev_step_token_idxs = [input_ids[i]]
                 prev_step_model_states = [
                     create_emitting_model_state(
                         Seq2SeqModelState(
                             timestep=0,
                             hidden_states=None,
-                            sequence=input_ids,
+                            sequence=input_ids[i].unsqueeze(0),
                             lm_scores=None
                         )
                     )
                 ]
 
+            new_model_kwargs["encoder_outputs"]["encoder_output"] = emissions[i, :, :].unsqueeze(0)
+
+            # import pdb
+            # pdb.set_trace()
+
             out_probs, model_states = [], []
             for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
@@ -145,10 +152,10 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
                 prev_model_state = get_obj_from_emitting_model_state(model_state_ptr)
 
                 # Create new decoder token ids
-                new_input_ids = torch.cat([prev_model_state.sequence[:, -1], idx], dim=-1)
+                new_input_ids = torch.cat([prev_model_state.sequence, idx.unsqueeze(0)], dim=-1)
 
                 # Forward pass
-                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids.unsqueeze(dim=0), **model_kwargs)
+                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids, **new_model_kwargs)
                 if self.is_huggingface_model:
                     model_inputs["return_dict"] = True
                     model_inputs["output_hidden_states"] = True
@@ -166,7 +173,7 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
                         Seq2SeqModelState(
                             timestep=timestep,
                             hidden_states=outputs["decoder_hidden_states"],
-                            sequence=new_input_ids.unsqueeze(dim=0),
+                            sequence=new_input_ids,
                             lm_scores=lm_scores
                         )
                     )
@@ -179,7 +186,7 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
             beam_size_token=self.model.config.vocab_size,
             beam_threshold=50,
             lm_weight=0.0,
-            eos_score=1.0,
+            eos_score=0.5,
             log_add=True,
         )
@@ -191,21 +198,24 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
             max_output_length=max_len
         )
 
-        emissions = model_kwargs["encoder_outputs"].get("encoder_output")
+        all_final_tokens = []
+        for i in range(len(input_ids)):
 
-        decoder.decode_step(emissions.data_ptr(), T, N)
-        hyps = decoder.get_all_final_hypothesis()
+            decoder.decode_step(emissions.data_ptr(), i, N)
+            hyps = decoder.get_all_final_hypothesis()
 
-        token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
-        max_tokens = max(token_scores, key=lambda x: x[1])
+            token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
+            max_tokens = max(token_scores, key=lambda x: x[1])
 
-        filtered = list(filter(lambda x: x != -1, max_tokens[0]))
-        final_tokens = [0] + filtered
-
-        import pdb
-        pdb.set_trace()
+            filtered = list(filter(lambda x: x != -1, max_tokens[0]))
+            final_tokens = filtered
+
+            while len(final_tokens) < max_len:
+                final_tokens += [0]
+
+            all_final_tokens.append(torch.Tensor(final_tokens).to(torch.long))
 
-        return torch.Tensor(final_tokens).to(torch.long)
+        return torch.stack(all_final_tokens, dim=0)
 
     def generate(
         self,
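Taken together, the generate.py changes make beam_search decode the batch one sequence at a time: the batch index i is passed through the decoder's otherwise-unused T argument, each best hypothesis has the decoder's -1 sentinel tokens stripped, is right-padded with 0 up to max_len, and the per-sequence results are stacked into a single (batch, max_len) tensor. Here is the pad-and-stack step in isolation as a self-contained sketch; the -1 sentinel and 0 pad id mirror the diff, while the helper name and example values are hypothetical:

import torch

def pad_and_stack(hypotheses, max_len):
    # Mirrors the loop added in beam_search: drop -1 sentinel tokens,
    # right-pad each sequence with 0 up to max_len, and stack the rows
    # into a (batch, max_len) LongTensor.
    all_final_tokens = []
    for tokens in hypotheses:
        final_tokens = [t for t in tokens if t != -1]
        while len(final_tokens) < max_len:
            final_tokens += [0]
        all_final_tokens.append(torch.tensor(final_tokens, dtype=torch.long))
    return torch.stack(all_final_tokens, dim=0)

print(pad_and_stack([[13, 5, -1], [7, 8, 9, 2]], max_len=6))
# tensor([[13,  5,  0,  0,  0,  0],
#         [ 7,  8,  9,  2,  0,  0]])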
