Add pedal msgs to tokenizer #18

Merged · 4 commits · Mar 11, 2024
77 changes: 51 additions & 26 deletions amt/infer.py
@@ -7,19 +7,22 @@

 from torch.multiprocessing import Queue
 from tqdm import tqdm
+from functools import wraps
+from torch.cuda import is_bf16_supported

 from amt.model import AmtEncoderDecoder
 from amt.tokenizer import AmtTokenizer
-from amt.audio import AudioTransform
+from amt.audio import AudioTransform, pad_or_trim
 from amt.data import get_wav_mid_segments


 MAX_SEQ_LEN = 4096
 LEN_MS = 30000
 STRIDE_FACTOR = 3
 CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR
-BEAM = 3
-ONSET_TOLERANCE = 50
-VEL_TOLERANCE = 50
+BEAM = 5
+ONSET_TOLERANCE = 61
+VEL_TOLERANCE = 100
+

 def _setup_logger():
@@ -105,10 +108,6 @@ def calculate_onset(
     return tokenizer.tok_to_id[("onset", new_onset)]


-from functools import wraps
-from torch.cuda import is_bf16_supported
-
-
 def optional_bf16_autocast(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -145,7 +144,7 @@ def process_segments(
         tokenizer.trunc_seq(prefix, MAX_SEQ_LEN) for prefix in raw_prefixes
     ]
     seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda()
-    eos_seen = [False for _ in prefixes]
+    end_idxs = [MAX_SEQ_LEN for _ in prefixes]

     kv_cache = model.get_empty_cache()

@@ -173,7 +172,7 @@ def process_segments(
         next_tok_ids = torch.argmax(logits[:, -1], dim=-1)

         for batch_idx in range(logits.shape[0]):
-            if eos_seen[batch_idx] is not False:
+            if idx > end_idxs[batch_idx]:
                 # End already seen, add pad token
                 tok_id = tokenizer.pad_id
             elif idx >= prefix_lens[batch_idx]:
@@ -192,20 +191,24 @@ def process_segments(
                 tok_id = tokenizer.tok_to_id[prefixes[batch_idx][idx]]

             seq[batch_idx, idx] = tok_id
-            if tokenizer.id_to_tok[tok_id] == tokenizer.eos_tok:
-                eos_seen[batch_idx] = idx
-
-        if all(eos_seen):
+            tok = tokenizer.id_to_tok[tok_id]
+            if tok == tokenizer.eos_tok:
+                end_idxs[batch_idx] = idx
+            elif (
+                type(tok) is tuple
+                and tok[0] == "onset"
+                and tok[1] >= LEN_MS - CHUNK_LEN_MS
+            ):
+                end_idxs[batch_idx] = idx - 2
+
+        if all(_idx <= idx for _idx in end_idxs):
             break

-    if not all(eos_seen):
+    if not all(_idx <= idx for _idx in end_idxs):
         logger.warning("Context length overflow when transcribing segment")
-        for _idx in range(seq.shape[0]):
-            if eos_seen[_idx] == False:
-                eos_seen[_idx] = MAX_SEQ_LEN

     results = [
-        tokenizer.decode(seq[_idx, : eos_seen[_idx] + 1])
+        tokenizer.decode(seq[_idx, : end_idxs[_idx] + 1])
         for _idx in range(seq.shape[0])
     ]

@@ -218,7 +221,7 @@ def gpu_manager(
     model: AmtEncoderDecoder,
     batch_size: int,
 ):
-    # model.compile()
+    model.compile()
     logger = _setup_logger()
     audio_transform = AudioTransform().cuda()
     tokenizer = AmtTokenizer(return_tensors=True)
@@ -283,7 +286,7 @@ def _truncate_seq(
     except:
         return ["<S>"]
     else:
-        return res[: res.index(tokenizer.eos_tok)]
+        return res[: res.index(tokenizer.eos_tok)]  # Needs to change


 def process_file(
@@ -302,8 +305,15 @@ def process_file(
             audio_path=file_path, stride_factor=STRIDE_FACTOR
         )
     ]
-    seq = ["<S>"]
-    res = ["<S>"]
+
+    # Add additional (padded) final audio segment
+    _last_seg = audio_segments[-1]
+    audio_segments.append(
+        pad_or_trim(_last_seg[len(_last_seg) // STRIDE_FACTOR :])
+    )
+
+    seq = [tokenizer.bos_tok]
+    res = [tokenizer.bos_tok]
     for idx, audio_seg in enumerate(audio_segments):
         init_idx = len(seq)

@@ -318,15 +328,18 @@ def process_file(
         result_queue.put(gpu_result)

         res += _shift_onset(
-            seq[init_idx : seq.index(tokenizer.eos_tok)],
+            seq[init_idx:],
             idx * CHUNK_LEN_MS,
         )

         if idx == len(audio_segments) - 1:
             break
+        elif res[-1] == tokenizer.eos_tok:
+            logger.info(f"Exiting early")
+            break
         else:
-            seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS)
-            if len(seq) == 1:
+            seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS)
+            if len(seq) <= 2:
                 logger.info(f"Exiting early")
                 return res

Expand Down Expand Up @@ -441,3 +454,15 @@ def batch_transcribe(
p.join()

gpu_manager_process.join()


def sample_top_p(probs, p):
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)

return next_token
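
For reference, the sample_top_p helper added above is standard top-p (nucleus) sampling: it keeps the smallest set of tokens whose cumulative probability exceeds p, renormalizes, and samples from that set. A minimal usage sketch — the input distribution below is illustrative, not taken from this PR:

    import torch
    from amt.infer import sample_top_p

    # Toy 1 x 4 next-token distribution; batch dimension first.
    probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), dim=-1)
    next_token = sample_top_p(probs, p=0.9)  # shape (1, 1): a sampled vocab index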
115 changes: 90 additions & 25 deletions amt/tokenizer.py
@@ -46,6 +46,7 @@ def __init__(self, return_tensors: bool = False):
         self.prev_tokens = [("prev", i) for i in range(128)]
         self.note_on_tokens = [("on", i) for i in range(128)]
         self.note_off_tokens = [("off", i) for i in range(128)]
+        self.pedal_tokens = [("pedal", 0), ("pedal", 1)]
         self.velocity_tokens = [("vel", i) for i in self.velocity_quantizations]
         self.onset_tokens = [
             ("onset", i) for i in self.onset_time_quantizations
@@ -56,6 +57,7 @@ def __init__(self, return_tensors: bool = False):
             + self.prev_tokens
             + self.note_on_tokens
             + self.note_off_tokens
+            + self.pedal_tokens
             + self.velocity_tokens
             + self.onset_tokens
         )
@@ -76,7 +78,10 @@ def _quantize_velocity(self, velocity: int):
         else:
             return velocity_quantized

-    # This method needs to be cleaned up completely, variables renamed
+    # TODO:
+    # - I need to make this method more robust, as it will have to handle
+    #   an arbitrary MIDI file
+    # - Decide whether to put pedal messages as prev tokens
     def _tokenize_midi_dict(
         self,
         midi_dict: MidiDict,
@@ -88,6 +93,12 @@ def _tokenize_midi_dict(
         ), "Invalid values for start_ms, end_ms"

         midi_dict.resolve_pedal()  # Important !!
+        pedal_intervals = midi_dict._build_pedal_intervals()
+        if len(pedal_intervals.keys()) > 1:
+            print("Warning: midi_dict has more than one pedal channel")
+        pedal_intervals = pedal_intervals[0]
+
+        last_msg_ms = -1
         on_off_notes = []
         prev_notes = []
         for msg in midi_dict.note_msgs:
@@ -109,6 +120,9 @@ def _tokenize_midi_dict(
                 ticks_per_beat=midi_dict.ticks_per_beat,
             )

+            if note_end_ms > last_msg_ms:
+                last_msg_ms = note_end_ms
+
             rel_note_start_ms_q = self._quantize_onset(note_start_ms - start_ms)
             rel_note_end_ms_q = self._quantize_onset(note_end_ms - start_ms)
             velocity_q = self._quantize_velocity(_velocity)
@@ -149,35 +163,70 @@ def _tokenize_midi_dict(
                     ("off", _pitch, rel_note_end_ms_q, None)
                 )

-        on_off_notes.sort(key=lambda x: (x[2], x[0] == "on"))
+        on_off_pedal = []
+        for pedal_on_tick, pedal_off_tick in pedal_intervals:
+            pedal_on_ms = get_duration_ms(
+                start_tick=0,
+                end_tick=pedal_on_tick,
+                tempo_msgs=midi_dict.tempo_msgs,
+                ticks_per_beat=midi_dict.ticks_per_beat,
+            )
+            pedal_off_ms = get_duration_ms(
+                start_tick=0,
+                end_tick=pedal_off_tick,
+                tempo_msgs=midi_dict.tempo_msgs,
+                ticks_per_beat=midi_dict.ticks_per_beat,
+            )
+
+            rel_on_ms_q = self._quantize_onset(pedal_on_ms - start_ms)
+            rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms)
+
+            # On message
+            if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms:
+                continue
+            else:
+                on_off_pedal.append(("pedal", 1, rel_on_ms_q, None))
+
+            # Off message
+            if pedal_off_ms <= start_ms or pedal_off_ms >= end_ms:
+                continue
+            else:
+                on_off_pedal.append(("pedal", 0, rel_off_ms_q, None))
+
+        on_off_combined = on_off_notes + on_off_pedal
+        on_off_combined.sort(
+            key=lambda x: (
+                x[2],
+                (0 if x[0] == "pedal" else 1 if x[0] == "off" else 2),
+            )
+        )
         random.shuffle(prev_notes)

         tokenized_seq = []
-        note_status = {}
-        for pitch in prev_notes:
-            note_status[pitch] = True
-        for note in on_off_notes:
-            _type, _pitch, _onset, _velocity = note
+        for tok in on_off_combined:
+            _type, _val, _onset, _velocity = tok
             if _type == "on":
-                if note_status.get(_pitch) == True:
-                    # Place holder - we can remove note_status logic now
-                    raise Exception
-
-                tokenized_seq.append(("on", _pitch))
+                tokenized_seq.append(("on", _val))
                 tokenized_seq.append(("onset", _onset))
                 tokenized_seq.append(("vel", _velocity))
-                note_status[_pitch] = True
             elif _type == "off":
-                if note_status.get(_pitch) == False:
-                    # Place holder - we can remove note_status logic now
-                    raise Exception
-                else:
-                    tokenized_seq.append(("off", _pitch))
+                tokenized_seq.append(("off", _val))
                 tokenized_seq.append(("onset", _onset))
+            elif _type == "pedal":
+                if _val == 0:
+                    tokenized_seq.append(("pedal", _val))
+                    tokenized_seq.append(("onset", _onset))
+                elif _val:
+                    tokenized_seq.append(("pedal", _val))
+                    tokenized_seq.append(("onset", _onset))
-                    note_status[_pitch] = False

         prefix = [("prev", p) for p in prev_notes]
-        return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
+
+        # Add eos_tok only if segment includes end of midi_dict
+        if last_msg_ms < end_ms:
+            return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
+        else:
+            return prefix + [self.bos_tok] + tokenized_seq

     def _detokenize_midi_dict(
         self,
@@ -243,16 +292,29 @@ def _detokenize_midi_dict(
                 print("Unexpected token order: 'prev' seen after '<S>'")
                 if DEBUG:
                     raise Exception
+            elif tok_1_type == "pedal":
+                # Pedal information is contained in note-off messages, so we
+                # don't need to process them manually
+                _pedal_data = tok_1_data
+                _tick = tok_2_data
+                note_msgs.append(
+                    {
+                        "type": "pedal",
+                        "data": _pedal_data,
+                        "tick": _tick,
+                        "channel": 0,
+                    }
+                )
             elif tok_1_type == "on":
                 if (tok_2_type, tok_3_type) != ("onset", "vel"):
-                    print("Unexpected token order")
+                    print("Unexpected token order:", tok_1, tok_2, tok_3)
                     if DEBUG:
                         raise Exception
                 else:
                     notes_to_close[tok_1_data] = (tok_2_data, tok_3_data)
             elif tok_1_type == "off":
                 if tok_2_type != "onset":
-                    print("Unexpected token order")
+                    print("Unexpected token order:", tok_1, tok_2, tok_3)
                     if DEBUG:
                         raise Exception
                 else:
@@ -336,9 +398,6 @@ def export_data_aug(self):

     def export_msg_mixup(self):
         def msg_mixup(src: list):
-            def round_to_base(n, base=150):
-                return base * round(n / base)
-
             # Process bos, eos, and pad tokens
             orig_len = len(src)
             seen_pad_tok = False
@@ -387,13 +446,19 @@ def round_to_base(n, base=150):
             elif tok_1_type == "off":
                 _onset = tok_2_data
                 buffer[_onset]["off"].append((tok_1, tok_2))
+            elif tok_1_type == "pedal":
+                _onset = tok_2_data
+                buffer[_onset]["pedal"].append((tok_1, tok_2))
             else:
                 pass

             # Shuffle order and re-append to result
             for k, v in sorted(buffer.items()):
                 random.shuffle(v["on"])
                 random.shuffle(v["off"])
+                for item in v["pedal"]:
+                    res.append(item[0])  # Pedal
+                    res.append(item[1])  # Onset
                 for item in v["off"]:
                     res.append(item[0])  # Pitch
                     res.append(item[1])  # Onset
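
The net effect on the tokenizer: sustain-pedal events become first-class tokens, ordered pedal-before-off-before-on at equal onset times per the sort key above. A hypothetical sanity-check sketch (assuming the default AmtTokenizer construction and that the vocabulary is built from the token groups shown in the diff):

    from amt.tokenizer import AmtTokenizer

    tokenizer = AmtTokenizer()
    # After this PR the vocabulary contains both pedal states:
    assert ("pedal", 0) in tokenizer.tok_to_id
    assert ("pedal", 1) in tokenizer.tok_to_id
    # A pedal press at ~100ms and release at ~500ms is emitted as interleaved
    # pairs in the token stream, e.g. (exact values depend on onset quantization):
    #   ("pedal", 1), ("onset", 100), ..., ("pedal", 0), ("onset", 500)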