-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate_utils.py
79 lines (68 loc) · 3.81 KB
/
generate_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn.functional as F
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPT2Config
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
import pretty_midi
from pychord import Chord
# Module-wide default device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The filtering is applied in place: ``logits`` itself is modified and then
    returned, so pass a clone if the unfiltered values are still needed.

    Args:
        logits: logits tensor whose last dimension is the vocabulary, e.g.
            shape (vocabulary size,) or (batch size x vocabulary size).
        top_k: if > 0, keep only the top k tokens with highest probability
            (top-k filtering). Values above the vocabulary size are clamped.
        top_p: if > 0.0, keep the most probable tokens up to and including
            the first one that pushes the cumulative probability above
            top_p (nucleus filtering); everything after it is removed.
            Nucleus filtering is described in Holtzman et al.
            (http://arxiv.org/abs/1904.09751)
        filter_value: value written over every removed logit.

    Returns:
        The same ``logits`` tensor, with removed entries set to filter_value.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is kept
        # too; the most probable token is therefore never removed.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Scatter the mask back to vocabulary order. Working on the last
        # dimension (rather than a fixed dim=1) lets this accept un-batched
        # 1-D logits as well as batched input.
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits
def sample_sequence(model, length, context, num_samples=1, temperature=1,
                    top_k=0, top_p=0.0, repetition_penalty=1.0, device='cpu'):
    """Autoregressively extend ``context`` by ``length`` tokens.

    Args:
        model: callable invoked as ``model(input_ids=tensor)``; element 0 of
            its output is indexed as ``[:, -1, :]`` to obtain the logits for
            the next token (presumably shaped batch x sequence x vocabulary
            -- confirm against the model in use).
        length: number of tokens to generate.
        context: sequence of token ids used as the prompt; it is converted to
            a long tensor on ``device``.
        num_samples: number of sequences generated in parallel, each starting
            from a copy of ``context``.
        temperature: logits are divided by this value when it is > 0. A value
            of exactly 0 selects greedy decoding (argmax) instead of sampling.
        top_k: forwarded to ``top_k_top_p_filtering``.
        top_p: forwarded to ``top_k_top_p_filtering``.
        repetition_penalty: penalty applied to the logit of every token that
            already occurs in the sequence (CTRL, https://arxiv.org/abs/1909.05858).
            1.0 disables it.
        device: device on which the context tensor is created.

    Returns:
        Long tensor of shape (num_samples, len(context) + length) holding the
        prompt followed by the generated tokens.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    generated = context.unsqueeze(0).repeat(num_samples, 1)

    with torch.no_grad():
        for _ in range(length):
            outputs = model(input_ids=generated)
            # The division yields a fresh tensor, so the in-place edits below
            # never touch the model's own output.
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)

            if repetition_penalty != 1.0:
                # Penalise every token already present in each sequence.
                # Dividing a negative logit would make the token MORE likely,
                # so negative logits are multiplied by the penalty instead.
                seen_logits = torch.gather(next_token_logits, 1, generated)
                seen_logits = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )
                # Duplicate token ids gather the same original logit, so they
                # all scatter the same penalised value back.
                next_token_logits.scatter_(1, generated, seen_logits)

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:  # greedy sampling
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
    return generated
def create_midi(chords, output_path='chord.mid', length=1):
    """Write a chord progression to a MIDI file as block piano chords.

    Every chord is played on an 'Acoustic Grand Piano' track with velocity
    100; chord ``n`` starts at ``n * length`` and ends at ``(n + 1) * length``,
    so the chords follow one another without gaps.

    Args:
        chords: iterable of chord objects exposing
            ``components_with_pitch(root_pitch=...)`` that returns note names
            accepted by ``pretty_midi.note_name_to_number`` (presumably
            ``pychord.Chord`` instances, which this file imports -- confirm
            against the caller).
        output_path: destination of the MIDI file. Defaults to 'chord.mid'.
        length: duration of each chord in pretty_midi time units (seconds).
            Defaults to 1.
    """
    midi_data = pretty_midi.PrettyMIDI()
    piano_program = pretty_midi.instrument_name_to_program('Acoustic Grand Piano')
    piano = pretty_midi.Instrument(program=piano_program)

    for n, chord in enumerate(chords):
        start, end = n * length, (n + 1) * length
        # Spell the chord out with its root in octave 4.
        for note_name in chord.components_with_pitch(root_pitch=4):
            note_number = pretty_midi.note_name_to_number(note_name)
            piano.notes.append(
                pretty_midi.Note(velocity=100, pitch=note_number, start=start, end=end)
            )

    midi_data.instruments.append(piano)
    midi_data.write(output_path)