From 5a1d8999cc538d52bcbd72db478ce0600c8ffea0 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 31 Dec 2024 16:56:44 +0000 Subject: [PATCH 01/72] demo --- demo/demo.py | 785 +++++++++++++++++++++++++++++++++++++ demo/midi-tunnel-client.py | 130 ++++++ demo/midi-tunnel-server.py | 61 +++ 3 files changed, 976 insertions(+) create mode 100644 demo/demo.py create mode 100644 demo/midi-tunnel-client.py create mode 100755 demo/midi-tunnel-server.py diff --git a/demo/demo.py b/demo/demo.py new file mode 100644 index 0000000..1bcc753 --- /dev/null +++ b/demo/demo.py @@ -0,0 +1,785 @@ +#!/usr/bin/env python3 + +import argparse +import os +import time +import copy +import logging +import threading +import queue +import torch +import mido +import torch._dynamo.config +import torch._inductor.config + +from torch.cuda import is_available as cuda_is_available +from ariautils.midi import MidiDict, NoteMessage, midi_to_dict + +from aria.tokenizer import InferenceAbsTokenizer +from aria.utils import _load_weight +from aria.inference import TransformerLM +from aria.model import ModelConfig +from aria.config import load_model_config, load_config +from aria.sample import prefill, decode_one, sample_top_p, update_seq_ids_ + +# torch._inductor.config.coordinate_descent_tuning = True +# torch._inductor.config.triton.unique_kernel_names = True +# torch._inductor.config.fx_graph_cache = True + +DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 +MAX_SEQ_LEN = 8192 + +# CONTROL FLOW: + +# 1. Loads model, compiles forward +# 2. Listen on MIDI port for first note +# 3. Start timer at first message seen +# 4. Wait for control-signal +# 3. Signal seen -> prefill all closed notes +# 4. Wait for all notes to close (ignore new notes) -> prefill the rest of the notes +# 5. Init main loop: + +# Generate next token +# Decode NoteMessage +# Add to message global NoteMessages +# Convert into on/off MIDI message and add to dequeue +# Check next message against current time and send message while msg_time <= curr_time +# Listen for control-signal + +# 6. Pop all note-on msgs, pop all note-off messages that are not processed yet +# 7. 
Go back to (3) - sending messages from extra list arg of off-msgs + + +# TODO: +# - Implement with flex attention to speed up prefill/decoding +# - Possibly compile different kernels for decoding different shapes, and for prefill + +file_handler = logging.FileHandler("./demo.log", mode="w") +file_handler.setLevel(logging.DEBUG) + + +def get_logger(name: str | None = None) -> logging.Logger: + logger = logging.getLogger(name) + if not logger.handlers: + logger.propagate = False + logger.setLevel(logging.DEBUG) + if name is not None: + formatter = logging.Formatter( + "[%(asctime)s]: [%(levelname)s] [%(name)s] %(message)s" + ) + else: + formatter = logging.Formatter( + "[%(asctime)s]: [%(levelname)s] %(message)s" + ) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + # Reuse shared file handler + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def compile_model(model: TransformerLM, max_seq_len: int): + logger = get_logger() + assert 10 < max_seq_len <= MAX_SEQ_LEN + + model.eval() + model.setup_cache( + batch_size=1, + max_seq_len=max_seq_len, + dtype=DTYPE, + ) + + global decode_one + decode_one = torch.compile( + decode_one, + mode="reduce-overhead", + fullgraph=True, + ) + + # Might need to pass in pad_idxs? + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + start_compile_time_s = time.time() + logger.info(f"Compiling forward pass") + decode_one( + model, + idxs=torch.tensor([[0]]).cuda(), + input_pos=torch.tensor([0], device="cuda", dtype=torch.int), + ) + logger.info( + f"Finished compiling - took {time.time() - start_compile_time_s:.4f} seconds" + ) + + for _ in range(100): + decode_one( + model, + idxs=torch.tensor([[0]]).cuda(), + input_pos=torch.tensor([0], device="cuda", dtype=torch.int), + ) + + compiled_forward_start_s = time.time() + decode_one( + model, + idxs=torch.tensor([[0]]).cuda(), + input_pos=torch.tensor([0], device="cuda", dtype=torch.int), + ) + compiled_forward_ms = (time.time() - compiled_forward_start_s) * 1000 + compiled_forward_its = 1000 / compiled_forward_ms + logger.info( + f"Compiled forward pass benchmark: {compiled_forward_ms:.2f} ms/it ({compiled_forward_its:.2f} it/s)" + ) + + return model + + +def load_model( + checkpoint_path: str, +): + logger = get_logger() + if not cuda_is_available(): + raise Exception("CUDA device is not available.") + + init_start_time_s = time.time() + + tokenizer = InferenceAbsTokenizer() + model_config = ModelConfig(**load_model_config("medium")) + model_config.set_vocab_size(tokenizer.vocab_size) + model_config.grad_checkpoint = False + model = TransformerLM(model_config).cuda() + + logging.info(f"Loading model weights from {checkpoint_path}") + model_state = _load_weight(checkpoint_path, "cuda") + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + model.load_state_dict(model_state) + + logger.info( + f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds" + ) + + return model + + +# TODO: For testing purposes only - remove later +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def prefill_model( + seq: list, tokenizer: InferenceAbsTokenizer, model: TransformerLM +): + logger = get_logger() + enc_seq = tokenizer.encode(seq) + idxs = torch.tensor([enc_seq, enc_seq], dtype=torch.int, device="cuda") + print(idxs.shape) + + start_s = time.time() + 
prefill( + model=model, + idxs=idxs, + input_pos=torch.arange( + 0, idxs.shape[1], device="cuda", dtype=torch.int + ), + ) + logger.info(f"Took {(time.time() - start_s) * 1000:.4f}") + + +# TODO: Support CFG +# TODO: Context length switching +# TODO: Fix issue with generating onsets before current time +# - As a workaround we can not send messages which are more than 0.2s before curr_time +# TODO: Handle introduction of prompt tags +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def generate_tokens( + priming_seq: list, + tokenizer: InferenceAbsTokenizer, + model: TransformerLM, + control_seen_sentinel: threading.Event, + generated_tokens_queue: queue.Queue, + temperature: float = 0.95, + top_p: float = 0.95, +): + logger = get_logger("GENERATE") + + priming_seq_len = len(priming_seq) + enc_seq = torch.tensor( + [ + tokenizer.encode( + priming_seq + + [tokenizer.pad_tok] * (MAX_SEQ_LEN - len(priming_seq)) + ) + ], + device="cuda", + ) + logger.debug(priming_seq) + logger.info(f"Priming sequence length: {priming_seq_len}") + + prefill_start_s = time.time() + prefill( + model, + idxs=enc_seq[:, :priming_seq_len], + input_pos=torch.arange(0, priming_seq_len, device="cuda"), + ) + logger.info( + f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" + ) + + idx = priming_seq_len + while (not control_seen_sentinel.is_set()) and idx < MAX_SEQ_LEN: + decode_one_start_time_s = time.time() + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + logits = decode_one( + model, + idxs=torch.tensor([[enc_seq[0, idx - 1]]]).cuda(), # Workaround + # idxs=enc_seq[:, idx - 1 : idx], + input_pos=torch.tensor( + [idx - 1], device="cuda", dtype=torch.int + ), + ) + + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( + "-inf" + ) + + if temperature > 0.0: + probs = torch.softmax(logits / temperature, dim=-1) + next_token_ids = sample_top_p(probs, top_p).flatten() + else: + next_token_ids = torch.argmax(logits, dim=-1).flatten() + + enc_seq[:, idx] = next_token_ids + next_token = tokenizer.id_to_tok[next_token_ids[0].item()] + logger.info( + f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" + ) + generated_tokens_queue.put(next_token) + idx += 1 + + while generated_tokens_queue.qsize() > 250: + logger.info(f"Sleeping for 0.1s") + time.sleep(0.1) + + +def decode_tokens_to_midi( + generated_tokens_queue: queue.Queue, + midi_messages_queue: queue.Queue, + tokenizer: InferenceAbsTokenizer, + first_on_msg_epoch_ms: float, + priming_seq_last_onset_ms: float, + control_seen_sentinel: threading.Event, +): + logger = get_logger("DECODE") + + assert ( + first_on_msg_epoch_ms + priming_seq_last_onset_ms < time.time() * 1000 + ) + + logger.info(f"first_on_msg_epoch_ms: {first_on_msg_epoch_ms}") + logger.info(f"priming_seq_last_onset_ms: {priming_seq_last_onset_ms}") + logger.info(f"curr_time_ms: {round(time.time() * 1000)}") + + note_buffer = [] + num_time_toks = priming_seq_last_onset_ms // 5000 + + while not control_seen_sentinel.is_set(): + while True: + tok = generated_tokens_queue.get() + logger.info(f"Seen token: {tok}") + note_buffer.append(tok) + if isinstance(tok, tuple) and tok[0] == "dur": + break + + while note_buffer and note_buffer[0] == tokenizer.time_tok: + logger.info("Popping time_tok") + num_time_toks += 1 + note_buffer.pop(0) + + assert len(note_buffer) == 3 + 
logger.info(f"Decoded note: {note_buffer}") + note_tok, onset_tok, dur_tok = note_buffer + _, pitch, vel = note_tok + _, onset = onset_tok + _, dur = dur_tok + + onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset + offset_epoch_ms = onset_epoch_ms + dur + on_msg = {"pitch": pitch, "vel": vel, "epoch_time_ms": onset_epoch_ms} + off_msg = {"pitch": pitch, "vel": 0, "epoch_time_ms": offset_epoch_ms} + + midi_messages_queue.put(on_msg) + midi_messages_queue.put(off_msg) + logger.info(f"Put message: {on_msg}") + logger.info(f"Put message: {off_msg}") + logger.info(f"Ahead by {onset_epoch_ms - round(time.time() * 1000)}ms") + + note_buffer = [] + + +# TODO: Add sent midi messages to msgs list (adjust time) +# TODO: There is a bug here im 99% sure +def stream_midi( + midi_messages_queue: queue.Queue, + msgs: list[mido.Message], + prev_msg_epoch_time_ms: float, + midi_output_port: str, + control_seen_sentinel: threading.Event, +): + logger = get_logger("STREAM") + logger.info( + f"Sending generated messages on MIDI port: '{midi_output_port}'" + ) + active_pitches = [] + midi_messages = [] + + with mido.open_output(midi_output_port) as midi_out: + while not control_seen_sentinel.is_set(): + if len(midi_messages) == 0: + while True: + try: + msg = midi_messages_queue.get(timeout=0.001) + except queue.Empty: + break + else: + logger.info(f"Got message: {msg}") + midi_messages.append(msg) + + midi_messages = sorted( + midi_messages, key=lambda msg: msg["epoch_time_ms"] + ) + + while midi_messages: + curr_epoch_time_ms = round(time.time() * 1000) + msg = midi_messages[0] + + if 0 < curr_epoch_time_ms - msg["epoch_time_ms"] < 200: + mido_msg = mido.Message( + "note_on", + note=msg["pitch"], + velocity=msg["vel"], + channel=0, + time=0, + ) + + if msg["vel"] > 0: + active_pitches.append(msg["pitch"]) + elif msg["vel"] == 0 and msg["pitch"] in active_pitches: + active_pitches.remove(msg["pitch"]) + + mido_msg_with_time = copy.deepcopy(mido_msg) + mido_msg_with_time.time = max( + 0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms + ) + prev_msg_epoch_time_ms = curr_epoch_time_ms + + midi_out.send(mido_msg) + msgs.append(mido_msg_with_time) + logger.info(f"Sent message: {mido_msg}") + midi_messages.pop(0) + + elif curr_epoch_time_ms - msg["epoch_time_ms"] > 200: + # Too far in the past + # TODO: Potential BUG with notes not being turned off? 
+ midi_messages.pop(0) + else: + break + + time.sleep(0.01) + + # Control sentinel seen + while True: + try: + msg = midi_messages_queue.get_nowait() + except queue.Empty: + break + else: + midi_messages.append(msg) + + midi_messages = sorted( + midi_messages, key=lambda msg: msg["epoch_time_ms"] + ) + + # Turn off active pitches straight away + for msg in midi_messages: + if msg["vel"] == 0 and msg["pitch"] in active_pitches: + mido_msg = mido.Message( + "note_on", + note=msg["pitch"], + velocity=0, + channel=0, + time=0, + ) + + curr_epoch_time_ms = round(time.time() * 1000) + mido_msg_with_time = copy.deepcopy(mido_msg) + mido_msg_with_time.time = max( + 0, curr_epoch_time_ms - prev_msg_epoch_time_ms + ) + + midi_out.send(mido_msg) + msgs.append(mido_msg_with_time) + logger.info(f"Sent message: {mido_msg}") + prev_msg_epoch_time_ms = curr_epoch_time_ms + active_pitches.remove(msg["pitch"]) + + return msgs + + +def stream_msgs( + model: TransformerLM, + tokenizer: InferenceAbsTokenizer, + msgs: list[mido.Message], + midi_output_port: str, + first_on_msg_epoch_ms: int, + control_seen_sentinel: threading.Event, +): + midi = convert_msgs_to_midi(msgs=msgs) + midi_dict = MidiDict(**midi_to_dict(midi)) + priming_seq = tokenizer.tokenize( + midi_dict=midi_dict, prompt_intervals_ms=[] + ) + priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] + if tokenizer.dim_tok in priming_seq: + priming_seq.remove(tokenizer.dim_tok) + + generated_tokens_queue = queue.Queue() + midi_messages_queue = queue.Queue() + + generate_tokens_thread = threading.Thread( + target=generate_tokens, + kwargs={ + "priming_seq": priming_seq, + "tokenizer": tokenizer, + "model": model, + "control_seen_sentinel": control_seen_sentinel, + "generated_tokens_queue": generated_tokens_queue, + }, + ) + generate_tokens_thread.start() + + decode_tokens_to_midi_thread = threading.Thread( + target=decode_tokens_to_midi, + kwargs={ + "generated_tokens_queue": generated_tokens_queue, + "midi_messages_queue": midi_messages_queue, + "tokenizer": tokenizer, + "first_on_msg_epoch_ms": first_on_msg_epoch_ms, + "priming_seq_last_onset_ms": tokenizer.calc_length_ms( + priming_seq[priming_seq.index(tokenizer.bos_tok) :], + onset=True, + ), + "control_seen_sentinel": control_seen_sentinel, + }, + ) + decode_tokens_to_midi_thread.start() + + # stream_midi_thread = threading.Thread( + # target=stream_midi, + # kwargs={ + # "midi_messages_queue": midi_messages_queue, + # "msgs": msgs, + # "prev_msg_epoch_time_ms": first_on_msg_epoch_ms + # + tokenizer.calc_length_ms( + # priming_seq[priming_seq.index(tokenizer.bos_tok) :], + # onset=False, + # ), + # "midi_output_port": midi_output_port, + # "control_seen_sentinel": control_seen_sentinel, + # }, + # ) + # stream_midi_thread.start() + + msgs = stream_midi( + midi_messages_queue=midi_messages_queue, + msgs=msgs, + prev_msg_epoch_time_ms=first_on_msg_epoch_ms + + tokenizer.calc_length_ms( + priming_seq[priming_seq.index(tokenizer.bos_tok) :], + onset=False, + ), + midi_output_port=midi_output_port, + control_seen_sentinel=control_seen_sentinel, + ) + + generate_tokens_thread.join() + decode_tokens_to_midi_thread.join() + + # Generate next token + # Decode NoteMessage + # Add to message global NoteMessages + # Convert into on/off MIDI message and add to dequeue + # Check next message against current time and send message while msg_time <= curr_time + # Listen for control-signal + + # output_msg_queue = [] + # while True: + # pass + + +def convert_msgs_to_midi(msgs: list[mido.Message]): + 
track = mido.MidiTrack() + track.append(mido.MetaMessage("set_tempo", tempo=500000, time=0)) + track.append(mido.Message("program_change", program=0, channel=0, time=0)) + for msg in msgs: + track.append(msg) + + mid = mido.MidiFile(type=0) + mid.ticks_per_beat = 500 + mid.tracks.append(track) + + return mid + + +# TODO: Make sure that pedal is closing? +def capture_midi_input(midi_input_port: str): + logger = get_logger("CAPTURE") + logger.info(f"Listening on MIDI port: '{midi_input_port}'") + received_messages = [] + seen_control = False + pedal_on = False + active_pitches = set() + first_on_msg_epoch_ms = None + prev_msg_epoch_time_ms = None + + with mido.open_input(midi_input_port) as midi_input, mido.open_output( + "Midi Through:Midi Through Port-2" + ) as midi_output: + for _msg in midi_input: + + if prev_msg_epoch_time_ms is None: + msg_time_ms = 0 + else: + msg_time_ms = round( + round((time.time() * 1000) - prev_msg_epoch_time_ms) + ) + + prev_msg_epoch_time_ms = round(time.time() * 1000) + msg = copy.deepcopy(_msg) + msg.time = msg_time_ms + msg.channel = 0 + logger.info(f"({prev_msg_epoch_time_ms}) {msg}") + + if msg.is_meta is True or msg.type == "program_change": + continue + + if ( + msg.type == "note_on" and msg.velocity == 0 + ) or msg.type == "note_off": + active_pitches.discard(msg.note) + received_messages.append(msg) + midi_output.send(msg) + elif msg.type == "control_change" and msg.control == 64: + received_messages.append(msg) + midi_output.send(msg) + if msg.value < 64: + pedal_on = False + else: + pedal_on = True + elif ( + msg.type == "control_change" + and msg.control == 66 + and msg.value > 0 + ): + # TODO: Change this to control arg + logger.info("Control signal seen") + logger.info(f"Active pitches: {active_pitches}") + seen_control = True + + # TURN OFF ALL MSGS + if seen_control is True: + if active_pitches: + pitch = active_pitches.pop() + __msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=0, + time=msg_time_ms, + ) + received_messages.append(__msg) + midi_output.send(msg) + while active_pitches: + pitch = active_pitches.pop() + __msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=0, + time=0, + ) + received_messages.append(__msg) + midi_output.send(msg) + + __msg = mido.Message( + type="control_change", + control=66, + value=0, + channel=0, + time=0, + ) + received_messages.append(__msg) + midi_output.send(msg) + + return received_messages, first_on_msg_epoch_ms + + # LET MSGS END NATURALLY + # if seen_control == True: + # if pedal_on == False and len(active_pitches) == 0: + # return received_messages, first_on_msg_epoch_ms + # else: + # continue + elif msg.type == "note_on" and msg.velocity > 0: + if first_on_msg_epoch_ms is None: + first_on_msg_epoch_ms = round(time.time() * 1000) + + received_messages.append(msg) + midi_output.send(msg) + active_pitches.add(msg.note) + + +def play_midi_file(midi_port: str, midi_path: str): + logger = get_logger("FILE") + logger.info(f"Playing file at {midi_path} on MIDI port {midi_port}") + time.sleep(1) + with mido.open_output(midi_port) as output_port: + for msg in mido.MidiFile(midi_path).play(): + # logger.info(f"MIDI OUTPUT: {msg}") + output_port.send(msg) + + +def listen_for_control(midi_port: str, control_seen_sentinel: threading.Event): + logger = get_logger("LISTEN") + with mido.open_input(midi_port) as input_port: + for msg in input_port: + logger.info(f"MIDI INPUT: {msg}") + if ( + msg.is_meta is False + and msg.type == "control_change" + and msg.control == 
66 + and msg.value > 0 + ): + logger.info("Control signal seen") + control_seen_sentinel.set() + return + + +def parse_args(): + argp = argparse.ArgumentParser() + argp.add_argument("-cp", help="path to model checkpoint") + argp.add_argument("-midi_in", required=False, help="MIDI input port") + argp.add_argument("-midi_out", required=True, help="MIDI output port") + argp.add_argument( + "-midi_path", + required=False, + help="Use MIDI file instead of MIDI input port", + ) + argp.add_argument( + "-control_signal", + default=66, + help="MIDI control change message for AI takeover", + ) + argp.add_argument( + "-temp", + help="sampling temperature value", + type=float, + required=False, + default=0.95, + ) + argp.add_argument( + "-top_p", + help="sampling top_p value", + type=float, + required=False, + default=0.95, + ) + argp.add_argument( + "-cfg", + help="sampling cfg gamma value", + type=float, + required=False, + ) + argp.add_argument( + "-metadata", + nargs=2, + metavar=("KEY", "VALUE"), + action="append", + help="manually add metadata key-value pair when sampling", + ) + argp.add_argument( + "-guidance_path", type=str, help="path to guidance MIDI", required=False + ) + argp.add_argument( + "-guidance_start_ms", + help="guidance interval start (ms)", + type=int, + required=False, + ) + argp.add_argument( + "-guidance_end_ms", + help="guidance interval end (ms)", + type=int, + required=False, + ) + + return argp.parse_args() + + +def main(): + args = parse_args() + logger = get_logger() + + tokenizer = InferenceAbsTokenizer() + model = load_model(checkpoint_path=args.cp) + model = compile_model(model=model, max_seq_len=MAX_SEQ_LEN) + + assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in + if args.midi_path: + midi_input_port = "Midi Through:Midi Through Port-0" + play_file_thread = threading.Thread( + target=play_midi_file, args=(midi_input_port, args.midi_path) + ) + play_file_thread.start() + else: + midi_input_port = args.midi_in + + # TODO: All of the below logic should be in a loop: + + msgs, first_on_msg_epoch_ms = capture_midi_input(midi_input_port) + + # midi = convert_msgs_to_midi(msgs=msgs) + # midi.save("/home/loubb/Dropbox/shared/res.mid") + # midi_dict = MidiDict(**midi_to_dict(midi)) + # prefill_model( + # seq=tokenizer.tokenize(midi_dict, prompt_intervals_ms=[]), + # tokenizer=tokenizer, + # model=model, + # ) + # raise Exception + + control_seen_sentinel = threading.Event() + + # listen_for_control_thread = threading.Thread( + # target=listen_for_control, args=(midi_input_port, control_seen_sentinel) + # ) + # listen_for_control_thread.start() + + stream_msgs( + model=model, + tokenizer=tokenizer, + msgs=msgs, + midi_output_port=args.midi_out, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + control_seen_sentinel=control_seen_sentinel, + ) + + +if __name__ == "__main__": + main() diff --git a/demo/midi-tunnel-client.py b/demo/midi-tunnel-client.py new file mode 100644 index 0000000..ae79385 --- /dev/null +++ b/demo/midi-tunnel-client.py @@ -0,0 +1,130 @@ +import socket +import rtmidi +import time +import subprocess +import signal +import sys +import os +import argparse + +def parse_arguments(): + parser = argparse.ArgumentParser(description='MIDI UDP bridge with SSH tunnel') + parser.add_argument('-p', '--port', type=int, default=5004, + help='UDP port number (default: 5004)') + return parser.parse_args() + +def kill_existing_process(port): + # Check and kill existing process on remote server + check_command = f"ssh home-4090 'lsof -ti :{port}'" + 
try:
+        pid = subprocess.check_output(check_command, shell=True).decode().strip()
+        if pid:
+            print(f"Found existing process {pid} on port {port}, killing it...")
+            kill_command = f"ssh home-4090 'kill -9 {pid}'"
+            subprocess.run(kill_command, shell=True)
+            # Wait a moment for the port to be freed
+            time.sleep(1)
+    except subprocess.CalledProcessError:
+        # No existing process found (lsof exits non-zero when the port is free)
+        pass
+
+def setup_ssh_tunnel(port):
+    # Kill any existing process first
+    kill_existing_process(port)
+
+    # Start SSH tunnel using socat
+    ssh_command = f"ssh home-4090 'socat -u UDP4-RECV:{port} STDOUT'"
+    local_socat = f"socat -u STDIN UDP4-SEND:localhost:{port}"
+
+    # Run the remote reader in its own session/process group so that cleanup()
+    # can os.killpg() it without signalling this script's own process group
+    ssh_process = subprocess.Popen(
+        ssh_command, shell=True, stdout=subprocess.PIPE, start_new_session=True
+    )
+    socat_process = subprocess.Popen(local_socat, shell=True, stdin=ssh_process.stdout)
+
+    # Give the tunnel a moment to establish
+    time.sleep(1)
+    return ssh_process, socat_process
+
+def create_virtual_port(port):
+    midi_out = rtmidi.MidiOut()
+    # Create a virtual MIDI port with port number in name
+    midi_out.open_virtual_port(f"UDP_{port}")
+    return midi_out
+
+def start_udp_listener(port):
+    # Create UDP socket
+    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    sock.bind(('localhost', port))
+    return sock
+
+def split_midi_messages(data):
+    """Split a byte array into individual MIDI messages."""
+    messages = []
+    data_list = list(data)
+    i = 0
+    while i < len(data_list):
+        # Check if we have a status byte (most significant bit is 1)
+        if data_list[i] >= 0x80:
+            # Most MIDI messages are 3 bytes
+            if i + 2 < len(data_list):
+                messages.append(data_list[i:i+3])
+                i += 3
+            else:
+                # Handle incomplete message at end of buffer
+                break
+        else:
+            # Skip non-status bytes (shouldn't happen in properly formatted MIDI)
+            i += 1
+    return messages
+
+def cleanup(ssh_process, socat_process, midi_out, sock):
+    print("\nCleaning up...")
+    # Kill the SSH and socat processes
+    if ssh_process:
+        os.killpg(os.getpgid(ssh_process.pid), signal.SIGTERM)
+    if socat_process:
+        socat_process.terminate()
+    # Close MIDI and socket
+    if midi_out:
+        midi_out.close_port()
+    if sock:
+        sock.close()
+
+def main():
+    args = parse_arguments()
+    port = args.port
+
+    ssh_process = None
+    socat_process = None
+    midi_out = None
+    sock = None
+
+    try:
+        # Setup SSH tunnel first
+        print(f"Setting up SSH tunnel on port {port}...")
+        ssh_process, socat_process = setup_ssh_tunnel(port)
+
+        # Setup MIDI and UDP
+        print(f"Creating virtual MIDI port UDP_{port}...")
+        midi_out = create_virtual_port(port)
+        print(f"Starting UDP listener on port {port}...")
+        sock = start_udp_listener(port)
+
+        print(f"UDP MIDI Bridge started - listening on port {port}")
+
+        while True:
+            data, addr = sock.recvfrom(1024)
+            if data:
+                # Split the data into individual MIDI messages
+                midi_messages = split_midi_messages(data)
+                for midi_message in midi_messages:
+                    print(f"Sending MIDI message: {midi_message}")
+                    midi_out.send_message(midi_message)

+    except KeyboardInterrupt:
+        print("\nShutting down UDP MIDI Bridge...")
+    except Exception as e:
+        print(f"Error: {e}")
+    finally:
+        cleanup(ssh_process, socat_process, midi_out, sock)

+if __name__ == "__main__":
+    main()
diff --git a/demo/midi-tunnel-server.py b/demo/midi-tunnel-server.py
new file mode 100755
index 0000000..01eaf2b
--- /dev/null
+++ b/demo/midi-tunnel-server.py
@@ -0,0 +1,61 @@
+import rtmidi
+import socket
+import time
+import struct
+import argparse
+
+class MIDIRouter:
+    def __init__(self, midi_port="14:0", udp_port=5004):
+        self.midi_in = rtmidi.MidiIn()
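+        # A single datagram socket is reused by every MIDI callback; packets
+        # are sent to localhost, where the remote socat listener picks them up
+        # and relays them over the SSH tunnel to the client machine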
+ self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.udp_port = udp_port + + # Print available ports + ports = self.midi_in.get_ports() + print(f"Available MIDI ports: {ports}") + + # Find and open MIDI port + for i, port in enumerate(ports): + if midi_port in port: + print(f"Opening MIDI port {i}: {port}") + self.midi_in.open_port(i) + break + else: + print(f"Warning: Could not find port containing '{midi_port}'") + + self.midi_in.set_callback(self._midi_callback) + + def _midi_callback(self, message, timestamp): + try: + print(f"Received MIDI message: {message[0]}") + midi_data = struct.pack(f'B' * len(message[0]), *message[0]) + self.socket.sendto(midi_data, ('localhost', self.udp_port)) + print(f"Sent {len(midi_data)} bytes to localhost:{self.udp_port}") + except Exception as e: + print(f"Error in callback: {e}") + + def start(self): + print(f"Routing MIDI messages through SSH tunnel on port {self.udp_port}...") + try: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + self.stop() + + def stop(self): + print("Shutting down...") + self.midi_in.close_port() + self.socket.close() + +def parse_args(): + parser = argparse.ArgumentParser(description='MIDI to UDP router') + parser.add_argument('-midi-p', type=str, default="14:0", + help='MIDI port identifier (default: 14:0)') + parser.add_argument('-udp-p', type=int, default=5004, + help='UDP port for forwarding (default: 5004)') + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + router = MIDIRouter(midi_port=args.midi_p, udp_port=args.udp_p) + router.start() From cfae8eea8051a1653dd930294a1bb1d3211ce1e8 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 4 Jan 2025 19:15:43 +0000 Subject: [PATCH 02/72] demo fix --- demo/demo.py | 531 +++++++++++++++++++++---------------- demo/demo.sh | 7 + demo/dump.py | 75 ++++++ demo/midi-tunnel-server.py | 4 +- 4 files changed, 383 insertions(+), 234 deletions(-) create mode 100644 demo/demo.sh create mode 100644 demo/dump.py diff --git a/demo/demo.py b/demo/demo.py index 1bcc753..1c31352 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -2,7 +2,10 @@ import argparse import os +import keyboard import time +import functools +import uuid import copy import logging import threading @@ -13,22 +16,30 @@ import torch._inductor.config from torch.cuda import is_available as cuda_is_available -from ariautils.midi import MidiDict, NoteMessage, midi_to_dict +from contextlib import ExitStack +from ariautils.midi import MidiDict, midi_to_dict from aria.tokenizer import InferenceAbsTokenizer from aria.utils import _load_weight from aria.inference import TransformerLM from aria.model import ModelConfig -from aria.config import load_model_config, load_config -from aria.sample import prefill, decode_one, sample_top_p, update_seq_ids_ +from aria.config import load_model_config +from aria.sample import prefill, decode_one, sample_top_p -# torch._inductor.config.coordinate_descent_tuning = True -# torch._inductor.config.triton.unique_kernel_names = True -# torch._inductor.config.fx_graph_cache = True +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 MAX_SEQ_LEN = 8192 + +# TODO: +# - Add CFG support +# - Add beam-search for first generated onset (watch out for ) +# - Add loop functionality + + # CONTROL FLOW: # 1. Loads model, compiles forward @@ -49,11 +60,6 @@ # 6. 
Pop all note-on msgs, pop all note-off messages that are not processed yet # 7. Go back to (3) - sending messages from extra list arg of off-msgs - -# TODO: -# - Implement with flex attention to speed up prefill/decoding -# - Possibly compile different kernels for decoding different shapes, and for prefill - file_handler = logging.FileHandler("./demo.log", mode="w") file_handler.setLevel(logging.DEBUG) @@ -63,13 +69,21 @@ def get_logger(name: str | None = None) -> logging.Logger: if not logger.handlers: logger.propagate = False logger.setLevel(logging.DEBUG) + + # Custom formatter class to handle millisecond timestamps + class MillisecondFormatter(logging.Formatter): + def formatTime(self, record, datefmt=None): + # Get milliseconds since epoch using int() to remove decimal places + created_ms = int(record.created * 1000) + return str(created_ms) + if name is not None: - formatter = logging.Formatter( - "[%(asctime)s]: [%(levelname)s] [%(name)s] %(message)s" + formatter = MillisecondFormatter( + "%(asctime)s: [%(levelname)s] [%(name)s] %(message)s" ) else: - formatter = logging.Formatter( - "[%(asctime)s]: [%(levelname)s] %(message)s" + formatter = MillisecondFormatter( + "%(asctime)s: [%(levelname)s] %(message)s" ) ch = logging.StreamHandler() @@ -190,24 +204,28 @@ def prefill_model( logger.info(f"Took {(time.time() - start_s) * 1000:.4f}") -# TODO: Support CFG +# TODO: Support CFG, guidance, and metadata tags # TODO: Context length switching -# TODO: Fix issue with generating onsets before current time -# - As a workaround we can not send messages which are more than 0.2s before curr_time -# TODO: Handle introduction of prompt tags +# TODO: Get the model to predict durations of notes trucated by @torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def generate_tokens( priming_seq: list, tokenizer: InferenceAbsTokenizer, model: TransformerLM, - control_seen_sentinel: threading.Event, + control_sentinel: threading.Event, generated_tokens_queue: queue.Queue, + num_preceding_active_pitches: int, temperature: float = 0.95, top_p: float = 0.95, + # cfg_gamma: float | None = None, ): logger = get_logger("GENERATE") + logger.info( + f"Using sampling parameters: temperature={temperature}, top_p={top_p}" + ) + dur_tok_ids = {tokenizer.tok_to_id[tok] for tok in tokenizer.dur_tokens} priming_seq_len = len(priming_seq) enc_seq = torch.tensor( [ @@ -231,10 +249,12 @@ def generate_tokens( f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" ) - idx = priming_seq_len - while (not control_seen_sentinel.is_set()) and idx < MAX_SEQ_LEN: + idx = priming_seq_len - (3 * num_preceding_active_pitches) + logger.info(f"Starting from idx={idx}") + while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: decode_one_start_time_s = time.time() with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + # BUG: Slicing is causing re-compilation which is annoying logits = decode_one( model, idxs=torch.tensor([[enc_seq[0, idx - 1]]]).cuda(), # Workaround @@ -256,17 +276,30 @@ def generate_tokens( else: next_token_ids = torch.argmax(logits, dim=-1).flatten() + # NOTE: This logic controls re-sampling of potentially truncated notes + # (durations) due to control-signal interruption in capture_midi_input + if idx < priming_seq_len: + if enc_seq[0, idx].item() not in dur_tok_ids: + logger.info( + f"Override prediction {tokenizer.id_to_tok[next_token_ids[0].item()]} -> {tokenizer.id_to_tok[enc_seq[:, idx].item()]}" + ) + next_token_ids = enc_seq[:, idx] + else: + logger.info( + 
f"Resampled ground truth {tokenizer.id_to_tok[enc_seq[:, idx].item()]} -> {tokenizer.id_to_tok[next_token_ids[0].item()]}" + ) + enc_seq[:, idx] = next_token_ids next_token = tokenizer.id_to_tok[next_token_ids[0].item()] logger.info( f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" ) - generated_tokens_queue.put(next_token) - idx += 1 - while generated_tokens_queue.qsize() > 250: - logger.info(f"Sleeping for 0.1s") - time.sleep(0.1) + # To account for re-sampling + if idx >= priming_seq_len: + generated_tokens_queue.put(next_token) + + idx += 1 def decode_tokens_to_midi( @@ -275,7 +308,7 @@ def decode_tokens_to_midi( tokenizer: InferenceAbsTokenizer, first_on_msg_epoch_ms: float, priming_seq_last_onset_ms: float, - control_seen_sentinel: threading.Event, + control_sentinel: threading.Event, ): logger = get_logger("DECODE") @@ -285,12 +318,11 @@ def decode_tokens_to_midi( logger.info(f"first_on_msg_epoch_ms: {first_on_msg_epoch_ms}") logger.info(f"priming_seq_last_onset_ms: {priming_seq_last_onset_ms}") - logger.info(f"curr_time_ms: {round(time.time() * 1000)}") note_buffer = [] num_time_toks = priming_seq_last_onset_ms // 5000 - while not control_seen_sentinel.is_set(): + while not control_sentinel.is_set(): while True: tok = generated_tokens_queue.get() logger.info(f"Seen token: {tok}") @@ -310,10 +342,21 @@ def decode_tokens_to_midi( _, onset = onset_tok _, dur = dur_tok + _uuid = uuid.uuid4() onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset offset_epoch_ms = onset_epoch_ms + dur - on_msg = {"pitch": pitch, "vel": vel, "epoch_time_ms": onset_epoch_ms} - off_msg = {"pitch": pitch, "vel": 0, "epoch_time_ms": offset_epoch_ms} + on_msg = { + "pitch": pitch, + "vel": vel, + "epoch_time_ms": onset_epoch_ms, + "uuid": _uuid, + } + off_msg = { + "pitch": pitch, + "vel": 0, + "epoch_time_ms": offset_epoch_ms, + "uuid": _uuid, + } midi_messages_queue.put(on_msg) midi_messages_queue.put(off_msg) @@ -324,43 +367,45 @@ def decode_tokens_to_midi( note_buffer = [] -# TODO: Add sent midi messages to msgs list (adjust time) -# TODO: There is a bug here im 99% sure def stream_midi( midi_messages_queue: queue.Queue, msgs: list[mido.Message], prev_msg_epoch_time_ms: float, midi_output_port: str, - control_seen_sentinel: threading.Event, + control_sentinel: threading.Event, ): logger = get_logger("STREAM") logger.info( f"Sending generated messages on MIDI port: '{midi_output_port}'" ) - active_pitches = [] + last_pitch_uuid = {} midi_messages = [] with mido.open_output(midi_output_port) as midi_out: - while not control_seen_sentinel.is_set(): - if len(midi_messages) == 0: - while True: - try: - msg = midi_messages_queue.get(timeout=0.001) - except queue.Empty: - break - else: - logger.info(f"Got message: {msg}") - midi_messages.append(msg) + while not control_sentinel.is_set(): + + while True: + try: + msg = midi_messages_queue.get_nowait() + except queue.Empty: + break + else: + # logger.info(f"Got message: {msg}") + midi_messages.append(msg) midi_messages = sorted( - midi_messages, key=lambda msg: msg["epoch_time_ms"] + midi_messages, + key=lambda msg: ( + msg["epoch_time_ms"], + msg["vel"], + ), ) while midi_messages: curr_epoch_time_ms = round(time.time() * 1000) msg = midi_messages[0] - if 0 < curr_epoch_time_ms - msg["epoch_time_ms"] < 200: + if 0 < curr_epoch_time_ms - msg["epoch_time_ms"] <= 50: mido_msg = mido.Message( "note_on", note=msg["pitch"], @@ -370,29 +415,43 @@ def stream_midi( ) if msg["vel"] > 0: - active_pitches.append(msg["pitch"]) - 
elif msg["vel"] == 0 and msg["pitch"] in active_pitches: - active_pitches.remove(msg["pitch"]) - - mido_msg_with_time = copy.deepcopy(mido_msg) - mido_msg_with_time.time = max( - 0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms - ) - prev_msg_epoch_time_ms = curr_epoch_time_ms - - midi_out.send(mido_msg) - msgs.append(mido_msg_with_time) - logger.info(f"Sent message: {mido_msg}") + last_pitch_uuid[msg["pitch"]] = msg["uuid"] + should_send = True + else: + # Only send note_off if it matches the last note_on UUID + should_send = ( + last_pitch_uuid.get(msg["pitch"]) == msg["uuid"] + ) + + if should_send is True: + mido_msg_with_time = copy.deepcopy(mido_msg) + mido_msg_with_time.time = max( + 0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms + ) + prev_msg_epoch_time_ms = curr_epoch_time_ms + + midi_out.send(mido_msg) + msgs.append(mido_msg_with_time) + logger.info( + f"(D={msg['epoch_time_ms'] - curr_epoch_time_ms}) Sent message: {msg}" + ) + else: + logger.info( + f"(D={msg['epoch_time_ms'] - curr_epoch_time_ms}) Skipping note_off message due to uuid mismatch: {msg}" + ) midi_messages.pop(0) - elif curr_epoch_time_ms - msg["epoch_time_ms"] > 200: - # Too far in the past - # TODO: Potential BUG with notes not being turned off? + elif curr_epoch_time_ms - msg["epoch_time_ms"] > 100: + # Message occurs too far in the past + logger.info( + f"(D={msg["epoch_time_ms"] - curr_epoch_time_ms}) Skipping message occurring too far in the past: {msg}" + ) midi_messages.pop(0) else: + # Message occurs in the future break - time.sleep(0.01) + time.sleep(0.005) # Control sentinel seen while True: @@ -404,49 +463,61 @@ def stream_midi( midi_messages.append(msg) midi_messages = sorted( - midi_messages, key=lambda msg: msg["epoch_time_ms"] + midi_messages, + key=lambda msg: (msg["epoch_time_ms"], msg["vel"]), ) - # Turn off active pitches straight away - for msg in midi_messages: - if msg["vel"] == 0 and msg["pitch"] in active_pitches: - mido_msg = mido.Message( - "note_on", - note=msg["pitch"], - velocity=0, - channel=0, - time=0, - ) - - curr_epoch_time_ms = round(time.time() * 1000) - mido_msg_with_time = copy.deepcopy(mido_msg) - mido_msg_with_time.time = max( - 0, curr_epoch_time_ms - prev_msg_epoch_time_ms - ) - - midi_out.send(mido_msg) - msgs.append(mido_msg_with_time) - logger.info(f"Sent message: {mido_msg}") - prev_msg_epoch_time_ms = curr_epoch_time_ms - active_pitches.remove(msg["pitch"]) + # # Turn off active pitches straight away + # for msg in midi_messages: + # if msg["vel"] == 0 and msg["pitch"] in active_pitches: + # mido_msg = mido.Message( + # "note_on", + # note=msg["pitch"], + # velocity=0, + # channel=0, + # time=0, + # ) + + # curr_epoch_time_ms = round(time.time() * 1000) + # mido_msg_with_time = copy.deepcopy(mido_msg) + # mido_msg_with_time.time = max( + # 0, curr_epoch_time_ms - prev_msg_epoch_time_ms + # ) + + # midi_out.send(mido_msg) + # msgs.append(mido_msg_with_time) + # logger.info(f"Sent message: {mido_msg}") + # prev_msg_epoch_time_ms = curr_epoch_time_ms + # active_pitches.remove(msg["pitch"]) return msgs +# TODO: Control sentinel needs to terminate generate and midi_msgs_queue +# It also needs to keep sending the note_off msgs, if and only if they are on time def stream_msgs( model: TransformerLM, tokenizer: InferenceAbsTokenizer, msgs: list[mido.Message], midi_output_port: str, first_on_msg_epoch_ms: int, - control_seen_sentinel: threading.Event, + control_sentinel: threading.Event, + temperature: float, + top_p: float, + num_preceding_active_pitches: int, ): midi = 
convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) priming_seq = tokenizer.tokenize( - midi_dict=midi_dict, prompt_intervals_ms=[] + midi_dict=midi_dict, + # prompt_intervals_ms=[ + # (0, round(time.time() * 1000) - first_on_msg_epoch_ms) + # ], + prompt_intervals_ms=[], ) priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] + # priming_seq = priming_seq[: priming_seq.index(tokenizer.prompt_end_tok) + 1] + if tokenizer.dim_tok in priming_seq: priming_seq.remove(tokenizer.dim_tok) @@ -459,9 +530,13 @@ def stream_msgs( "priming_seq": priming_seq, "tokenizer": tokenizer, "model": model, - "control_seen_sentinel": control_seen_sentinel, + "control_sentinel": control_sentinel, "generated_tokens_queue": generated_tokens_queue, + "temperature": temperature, + "top_p": top_p, + "num_preceding_active_pitches": num_preceding_active_pitches, }, + daemon=True, ) generate_tokens_thread.start() @@ -476,27 +551,12 @@ def stream_msgs( priming_seq[priming_seq.index(tokenizer.bos_tok) :], onset=True, ), - "control_seen_sentinel": control_seen_sentinel, + "control_sentinel": control_sentinel, }, + daemon=True, ) decode_tokens_to_midi_thread.start() - # stream_midi_thread = threading.Thread( - # target=stream_midi, - # kwargs={ - # "midi_messages_queue": midi_messages_queue, - # "msgs": msgs, - # "prev_msg_epoch_time_ms": first_on_msg_epoch_ms - # + tokenizer.calc_length_ms( - # priming_seq[priming_seq.index(tokenizer.bos_tok) :], - # onset=False, - # ), - # "midi_output_port": midi_output_port, - # "control_seen_sentinel": control_seen_sentinel, - # }, - # ) - # stream_midi_thread.start() - msgs = stream_midi( midi_messages_queue=midi_messages_queue, msgs=msgs, @@ -506,23 +566,12 @@ def stream_msgs( onset=False, ), midi_output_port=midi_output_port, - control_seen_sentinel=control_seen_sentinel, + control_sentinel=control_sentinel, ) generate_tokens_thread.join() decode_tokens_to_midi_thread.join() - # Generate next token - # Decode NoteMessage - # Add to message global NoteMessages - # Convert into on/off MIDI message and add to dequeue - # Check next message against current time and send message while msg_time <= curr_time - # Listen for control-signal - - # output_msg_queue = [] - # while True: - # pass - def convert_msgs_to_midi(msgs: list[mido.Message]): track = mido.MidiTrack() @@ -538,21 +587,36 @@ def convert_msgs_to_midi(msgs: list[mido.Message]): return mid -# TODO: Make sure that pedal is closing? 
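+# Capture incoming MIDI until the control signal fires; held pitches are then
+# closed with synthesized note_offs so the priming sequence ends cleanly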
-def capture_midi_input(midi_input_port: str): +def capture_midi_input( + midi_input_port: str, + control_sentinel: threading.Event, + midi_control_signal: int | None = None, + midi_through_port: str | None = None, +): logger = get_logger("CAPTURE") - logger.info(f"Listening on MIDI port: '{midi_input_port}'") received_messages = [] - seen_control = False - pedal_on = False active_pitches = set() first_on_msg_epoch_ms = None prev_msg_epoch_time_ms = None - with mido.open_input(midi_input_port) as midi_input, mido.open_output( - "Midi Through:Midi Through Port-2" - ) as midi_output: - for _msg in midi_input: + logger.info(f"Listening on MIDI port: '{midi_input_port}'") + logger.info(f"Using MIDI control signal: {midi_control_signal}") + if midi_through_port is not None: + logger.info(f"Sending through on MIDI port: '{midi_through_port}'") + + with ExitStack() as stack: + midi_input = stack.enter_context(mido.open_input(midi_input_port)) + midi_through = ( + stack.enter_context(mido.open_output(midi_through_port)) + if midi_through_port + else None + ) + + while not control_sentinel.is_set(): + msg = midi_input.receive(block=False) + if msg is None: + time.sleep(0.001) + continue if prev_msg_epoch_time_ms is None: msg_time_ms = 0 @@ -562,10 +626,9 @@ def capture_midi_input(midi_input_port: str): ) prev_msg_epoch_time_ms = round(time.time() * 1000) - msg = copy.deepcopy(_msg) msg.time = msg_time_ms msg.channel = 0 - logger.info(f"({prev_msg_epoch_time_ms}) {msg}") + logger.info(f"{msg}") if msg.is_meta is True or msg.type == "program_change": continue @@ -575,100 +638,97 @@ def capture_midi_input(midi_input_port: str): ) or msg.type == "note_off": active_pitches.discard(msg.note) received_messages.append(msg) - midi_output.send(msg) + if midi_through is not None: + midi_through.send(msg) + elif msg.type == "note_on" and msg.velocity > 0: + if first_on_msg_epoch_ms is None: + first_on_msg_epoch_ms = round(time.time() * 1000) + + active_pitches.add(msg.note) + received_messages.append(msg) + if midi_through is not None: + midi_through.send(msg) elif msg.type == "control_change" and msg.control == 64: received_messages.append(msg) - midi_output.send(msg) - if msg.value < 64: - pedal_on = False - else: - pedal_on = True elif ( msg.type == "control_change" - and msg.control == 66 + and msg.control == midi_control_signal and msg.value > 0 ): - # TODO: Change this to control arg - logger.info("Control signal seen") - logger.info(f"Active pitches: {active_pitches}") - seen_control = True - - # TURN OFF ALL MSGS - if seen_control is True: - if active_pitches: - pitch = active_pitches.pop() - __msg = mido.Message( - type="note_on", - note=pitch, - velocity=0, - channel=0, - time=msg_time_ms, - ) - received_messages.append(__msg) - midi_output.send(msg) - while active_pitches: - pitch = active_pitches.pop() - __msg = mido.Message( - type="note_on", - note=pitch, - velocity=0, - channel=0, - time=0, - ) - received_messages.append(__msg) - midi_output.send(msg) - - __msg = mido.Message( - type="control_change", - control=66, - value=0, - channel=0, - time=0, - ) - received_messages.append(__msg) - midi_output.send(msg) - - return received_messages, first_on_msg_epoch_ms - - # LET MSGS END NATURALLY - # if seen_control == True: - # if pedal_on == False and len(active_pitches) == 0: - # return received_messages, first_on_msg_epoch_ms - # else: - # continue - elif msg.type == "note_on" and msg.velocity > 0: - if first_on_msg_epoch_ms is None: - first_on_msg_epoch_ms = round(time.time() * 1000) + 
control_sentinel.set() + + logger.info("Control signal seen") + logger.info(f"Active pitches: {active_pitches}") + num_active_pitches = len(active_pitches) + + if active_pitches: + pitch = active_pitches.pop() + msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=0, + time=msg_time_ms, + ) + received_messages.append(msg) + if midi_through is not None: + midi_through.send(msg) + + while active_pitches: + pitch = active_pitches.pop() + msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=0, + time=0, + ) + received_messages.append(msg) + if midi_through is not None: + midi_through.send(msg) + + msg = mido.Message( + type="control_change", + control=64, + value=0, + channel=0, + time=0, + ) + received_messages.append(msg) + if midi_through is not None: + midi_through.send(msg) + + # Workaround for the way that file-playback is implemented - delete + msg = mido.Message( + type="control_change", + control=66, + value=0, + channel=0, + time=0, + ) + if midi_through is not None: + midi_through.send(msg) - received_messages.append(msg) - midi_output.send(msg) - active_pitches.add(msg.note) + return received_messages, first_on_msg_epoch_ms, num_active_pitches def play_midi_file(midi_port: str, midi_path: str): logger = get_logger("FILE") - logger.info(f"Playing file at {midi_path} on MIDI port {midi_port}") + logger.info(f"Playing file at {midi_path} on MIDI port '{midi_port}'") time.sleep(1) with mido.open_output(midi_port) as output_port: for msg in mido.MidiFile(midi_path).play(): - # logger.info(f"MIDI OUTPUT: {msg}") + logger.debug(f"{msg}") output_port.send(msg) -def listen_for_control(midi_port: str, control_seen_sentinel: threading.Event): - logger = get_logger("LISTEN") - with mido.open_input(midi_port) as input_port: - for msg in input_port: - logger.info(f"MIDI INPUT: {msg}") - if ( - msg.is_meta is False - and msg.type == "control_change" - and msg.control == 66 - and msg.value > 0 - ): - logger.info("Control signal seen") - control_seen_sentinel.set() - return +def listen_for_control_signal_keypress(control_sentinel: threading.Event): + logger = get_logger("KEYBOARD") + for _ in range(2): + input() + logger.info("Keypress seen") + control_sentinel.set() + time.sleep(5) def parse_args(): @@ -676,14 +736,19 @@ def parse_args(): argp.add_argument("-cp", help="path to model checkpoint") argp.add_argument("-midi_in", required=False, help="MIDI input port") argp.add_argument("-midi_out", required=True, help="MIDI output port") + argp.add_argument( + "-midi_through", + required=False, + help="MIDI through port for received input", + ) argp.add_argument( "-midi_path", required=False, help="Use MIDI file instead of MIDI input port", ) argp.add_argument( - "-control_signal", - default=66, + "-midi_control_signal", + type=int, help="MIDI control change message for AI takeover", ) argp.add_argument( @@ -734,8 +799,6 @@ def parse_args(): def main(): args = parse_args() - logger = get_logger() - tokenizer = InferenceAbsTokenizer() model = load_model(checkpoint_path=args.cp) model = compile_model(model=model, max_seq_len=MAX_SEQ_LEN) @@ -744,41 +807,45 @@ def main(): if args.midi_path: midi_input_port = "Midi Through:Midi Through Port-0" play_file_thread = threading.Thread( - target=play_midi_file, args=(midi_input_port, args.midi_path) + target=play_midi_file, + args=(midi_input_port, args.midi_path), + daemon=True, ) play_file_thread.start() else: midi_input_port = args.midi_in - # TODO: All of the below logic should be in a loop: - - msgs, 
first_on_msg_epoch_ms = capture_midi_input(midi_input_port) - - # midi = convert_msgs_to_midi(msgs=msgs) - # midi.save("/home/loubb/Dropbox/shared/res.mid") - # midi_dict = MidiDict(**midi_to_dict(midi)) - # prefill_model( - # seq=tokenizer.tokenize(midi_dict, prompt_intervals_ms=[]), - # tokenizer=tokenizer, - # model=model, - # ) - # raise Exception + # TODO: All of the below logic should be in a loop with additional handling + # for the control sentinel - control_seen_sentinel = threading.Event() + control_sentinel = threading.Event() + keypress_thread = threading.Thread( + target=listen_for_control_signal_keypress, + args=[control_sentinel], + daemon=True, + ) + keypress_thread.start() - # listen_for_control_thread = threading.Thread( - # target=listen_for_control, args=(midi_input_port, control_seen_sentinel) - # ) - # listen_for_control_thread.start() + msgs, first_on_msg_epoch_ms, num_active_pitches = capture_midi_input( + midi_input_port=midi_input_port, + control_sentinel=control_sentinel, + midi_control_signal=args.midi_control_signal, + midi_through_port=args.midi_through, + ) + control_sentinel.clear() stream_msgs( model=model, tokenizer=tokenizer, msgs=msgs, midi_output_port=args.midi_out, first_on_msg_epoch_ms=first_on_msg_epoch_ms, - control_seen_sentinel=control_seen_sentinel, + control_sentinel=control_sentinel, + temperature=args.temp, + top_p=args.top_p, + num_preceding_active_pitches=num_active_pitches, ) + keypress_thread.join() if __name__ == "__main__": diff --git a/demo/demo.sh b/demo/demo.sh new file mode 100644 index 0000000..615328e --- /dev/null +++ b/demo/demo.sh @@ -0,0 +1,7 @@ +python /home/loubb/work/aria/demo/demo.py \ + -cp /mnt/ssd1/aria/v2/medium-75-ft.safetensors \ + -midi_path /home/loubb/Dropbox/shared/audio.mid \ + -midi_out "Midi Through:Midi Through Port-1" \ + -midi_through "Midi Through:Midi Through Port-2" \ + -midi_control_signal 66 \ + -temp 0.96 diff --git a/demo/dump.py b/demo/dump.py new file mode 100644 index 0000000..d1a47d2 --- /dev/null +++ b/demo/dump.py @@ -0,0 +1,75 @@ +import mido +import sys +from mido import tick2second, second2tick + +# Check if correct number of arguments is provided +if len(sys.argv) != 3: + print("Usage: python script.py ") + sys.exit(1) + +# Get command line arguments +input_file = sys.argv[1] +target_seconds = float(sys.argv[2]) + +try: + mid = mido.MidiFile(input_file) +except Exception as e: + print(f"Error loading MIDI file: {e}") + sys.exit(1) + +curr_tick = 0 +idx = 0 +tempo = None + +# First get the tempo +for msg in mid.tracks[0]: + if msg.type == "set_tempo": + tempo = msg.tempo + break + +print(f"Found tempo: {tempo}") + +# Then find the right index +curr_tick = 0 +for idx, msg in enumerate(mid.tracks[0]): + curr_tick += msg.time + seconds = tick2second( + tick=curr_tick, + ticks_per_beat=mid.ticks_per_beat, + tempo=tempo, + ) + print(f"At index {idx}, time: {seconds:.2f} seconds") + if seconds > target_seconds: + print(f"Breaking at index {idx}") + break + +print(f"Inserting at index {idx}") + +# Insert the messages at the found index +mid.tracks[0].insert( + idx, + mido.Message( + type="control_change", + control=66, + value=127, + time=0, + ), +) +mid.tracks[0].insert( + idx + 1, + mido.Message( + type="control_change", + control=66, + value=0, + time=second2tick( + second=0.01, + ticks_per_beat=mid.ticks_per_beat, + tempo=tempo, + ), + ), +) + +# Generate output filename based on input filename +output_path = "/home/loubb/Dropbox/shared/test.mid" +mid.save(output_path) +print(f"Saved modified 
MIDI file to: {output_path}") diff --git a/demo/midi-tunnel-server.py b/demo/midi-tunnel-server.py index 01eaf2b..988e200 100755 --- a/demo/midi-tunnel-server.py +++ b/demo/midi-tunnel-server.py @@ -49,9 +49,9 @@ def stop(self): def parse_args(): parser = argparse.ArgumentParser(description='MIDI to UDP router') - parser.add_argument('-midi-p', type=str, default="14:0", + parser.add_argument('-midi_p', type=str, default="14:0", help='MIDI port identifier (default: 14:0)') - parser.add_argument('-udp-p', type=int, default=5004, + parser.add_argument('-udp_p', type=int, default=5004, help='UDP port for forwarding (default: 5004)') return parser.parse_args() From ea68e75ad5b11b9744c6d28324b0fffc00be7852 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 4 Jan 2025 23:25:21 +0000 Subject: [PATCH 03/72] mess it all up agian --- demo/demo.py | 282 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 246 insertions(+), 36 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 1c31352..2f17805 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -98,6 +98,10 @@ def formatTime(self, record, datefmt=None): return logger +def get_epoch_time_ms() -> int: + return round(time.time() * 1000) + + @torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def compile_model(model: TransformerLM, max_seq_len: int): @@ -182,26 +186,211 @@ def load_model( return model -# TODO: For testing purposes only - remove later @torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() -def prefill_model( - seq: list, tokenizer: InferenceAbsTokenizer, model: TransformerLM +def recalculate_dur_tokens( + priming_seq: list, + enc_seq: torch.Tensor, + tokenizer: InferenceAbsTokenizer, + model: TransformerLM, + start_idx: int, ): - logger = get_logger() - enc_seq = tokenizer.encode(seq) - idxs = torch.tensor([enc_seq, enc_seq], dtype=torch.int, device="cuda") - print(idxs.shape) + logger = get_logger("GENERATE") + priming_seq_len = len(priming_seq) - start_s = time.time() - prefill( - model=model, - idxs=idxs, - input_pos=torch.arange( - 0, idxs.shape[1], device="cuda", dtype=torch.int - ), + for idx in range(priming_seq_len - start_idx, priming_seq_len): + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + prev_tok_id = enc_seq[0, idx - 1] + logits = decode_one( + model, + idxs=torch.tensor([[prev_tok_id]]).cuda(), + input_pos=torch.tensor( + [idx - 1], device="cuda", dtype=torch.int + ), + ) + + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( + "-inf" + ) + + next_token_ids = torch.argmax(logits, dim=-1).flatten() + priming_tok = tokenizer.id_to_tok[enc_seq[0, idx].item()] + predicted_tok = tokenizer.id_to_tok[next_token_ids[0].item()] + + resample = False + if isinstance(priming_tok, tuple) and priming_tok[0] == "dur": + priming_dur = priming_tok[1] + predicted_dur = predicted_tok[1] + + if predicted_dur > priming_dur: + resample = True + + if resample is True: + logger.info( + f"Resampled ground truth {tokenizer.id_to_tok[enc_seq[:, idx].item()]} -> {tokenizer.id_to_tok[next_token_ids[0].item()]}" + ) + enc_seq[:, idx] = next_token_ids + + return enc_seq + + +# TODO: Clean this up +# - Replace the log statements with log statements of the normal form in generate for generating tokens +# - clean up logging +# - Make sure there is no bugs + + +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def decode_first_onset( + 
model: TransformerLM, + enc_seq: torch.Tensor, + priming_seq: list, + tokenizer: InferenceAbsTokenizer, + generated_tokens_queue: queue.Queue, + first_on_msg_epoch_ms: int, +): + logger = get_logger("GENERATE-FIRST") + BEAM_WIDTH = 5 + time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms + num_time_toks = priming_seq.count(tokenizer.time_tok) + time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] + idx = len(priming_seq) + + # DEBUG + logger.info(f"Priming seq if length {idx}") + logger.info(f"MS since start: {time_since_first_onset_ms}") + logger.info(f"Number of time_toks in priming seq: {num_time_toks}") + # END DEBUG + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + prev_tok_id = enc_seq[0, idx - 1] + logits = decode_one( + model, + idxs=torch.tensor([[prev_tok_id]]).cuda(), + input_pos=torch.tensor([idx - 1], device="cuda", dtype=torch.int), + ) + logger.info(f"Sampled logits for tok at pos {idx}") + _, top_ids = torch.topk(logits, k=BEAM_WIDTH, dim=-1) + idx += 1 + + num_time_toks_to_add = ( + (time_since_first_onset_ms + 200) // 5000 + ) - num_time_toks + append_time_toks = (num_time_toks_to_add > 0) or ( + tokenizer.tok_to_id[tokenizer.time_tok] in top_ids[0].tolist() + ) + + # DEBUG + logger.info( + f"top_toks = {[tokenizer.id_to_tok[id] for id in top_ids[0].tolist()]}" + ) + logger.info(f"Append time tok: {append_time_toks}") + logger.info(f"Num time toks to add: {num_time_toks_to_add}") + # END DEBUG + + if append_time_toks: + if num_time_toks_to_add == 0: + num_time_toks_to_add += 1 + + while num_time_toks_to_add > 0: + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH + ): + enc_seq[:, idx] = torch.tensor([[time_tok_id]]).cuda() + generated_tokens_queue.put(tokenizer.time_tok) + logits = decode_one( + model, + idxs=torch.tensor([[time_tok_id]]).cuda(), + input_pos=torch.tensor( + [idx - 1], device="cuda", dtype=torch.int + ), + ) + logger.info( + f"Sampled logits for tok at pos {idx} by adding time_tok" + ) + num_time_toks_to_add -= 1 + idx += 1 + + # BEAM SEARCH + probs = torch.softmax(logits, dim=-1) + top_probs, top_ids = torch.topk(probs, k=BEAM_WIDTH, dim=-1) + + # DEBUG + logger.info( + f"top_toks = {[tokenizer.id_to_tok[id] for id in top_ids[0].tolist()]}" + ) + logger.info(f"top_probs = {top_probs}") + # END DEBUG + + if append_time_toks is False: + masked_onset_ids = [ + tokenizer.tok_to_id[tok] + for tok in tokenizer.onset_tokens + if tok[1] < (time_since_first_onset_ms % 5000) + ] + else: + masked_onset_ids = [] + + logger.info( + f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms})" + ) + + best_score = 0 + for i in range(BEAM_WIDTH): + tok_id = top_ids[0, i].item() + tok_prob = top_probs[0, i] + assert tok_id != tokenizer.tok_to_id[tokenizer.time_tok] + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + next_logits = decode_one( + model, + idxs=torch.tensor([[tok_id]]).cuda(), + input_pos=torch.tensor( + [idx - 1], device="cuda", dtype=torch.int + ), + ) + logger.info( + f"Sampled logits for tok at pos {idx} by adding {tokenizer.id_to_tok[tok_id]}" + ) + + next_probs = torch.softmax(next_logits, dim=-1) + next_probs[:, masked_onset_ids] = 0 + next_tok_prob, next_tok_id = torch.max(next_probs, dim=-1) + + logger.info( + f"Sampled {tokenizer.id_to_tok[next_tok_id[0].item()]} with p={next_tok_prob}" + ) + + score = (tok_prob * next_tok_prob).item() + if score > best_score: + tok_id_1, tok_id_2 = tok_id, next_tok_id.item() + best_score = 
score + + logger.info(f"Score={score}") + + logger.info( + f"Filling in kv at position {idx-1} with {tokenizer.id_to_tok[tok_id_1]} " + ) + + decode_one( + model, + idxs=torch.tensor([[tok_id_1]]).cuda(), + input_pos=torch.tensor([idx - 1], device="cuda", dtype=torch.int), + ) + + logger.info( + f"Selecting {tokenizer.id_to_tok[tok_id_1], tokenizer.id_to_tok[tok_id_2]}" ) - logger.info(f"Took {(time.time() - start_s) * 1000:.4f}") + + enc_seq[:, idx - 1] = tok_id_1 + enc_seq[:, idx] = tok_id_2 + generated_tokens_queue.put(tokenizer.id_to_tok[tok_id_1]) + generated_tokens_queue.put(tokenizer.id_to_tok[tok_id_2]) + + return enc_seq, idx + 1 # TODO: Support CFG, guidance, and metadata tags @@ -216,6 +405,7 @@ def generate_tokens( control_sentinel: threading.Event, generated_tokens_queue: queue.Queue, num_preceding_active_pitches: int, + first_on_msg_epoch_ms: int, temperature: float = 0.95, top_p: float = 0.95, # cfg_gamma: float | None = None, @@ -225,7 +415,6 @@ def generate_tokens( f"Using sampling parameters: temperature={temperature}, top_p={top_p}" ) - dur_tok_ids = {tokenizer.tok_to_id[tok] for tok in tokenizer.dur_tokens} priming_seq_len = len(priming_seq) enc_seq = torch.tensor( [ @@ -249,16 +438,24 @@ def generate_tokens( f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" ) - idx = priming_seq_len - (3 * num_preceding_active_pitches) + # TODO: Still not 100% sure that decode_first_onset is completely correct + enc_seq, idx = decode_first_onset( + model=model, + enc_seq=enc_seq, + priming_seq=priming_seq, + tokenizer=tokenizer, + generated_tokens_queue=generated_tokens_queue, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + ) + logger.info(f"Starting from idx={idx}") while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: decode_one_start_time_s = time.time() with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - # BUG: Slicing is causing re-compilation which is annoying + prev_tok_id = enc_seq[0, idx - 1] logits = decode_one( model, - idxs=torch.tensor([[enc_seq[0, idx - 1]]]).cuda(), # Workaround - # idxs=enc_seq[:, idx - 1 : idx], + idxs=torch.tensor([[prev_tok_id]]).cuda(), input_pos=torch.tensor( [idx - 1], device="cuda", dtype=torch.int ), @@ -279,15 +476,23 @@ def generate_tokens( # NOTE: This logic controls re-sampling of potentially truncated notes # (durations) due to control-signal interruption in capture_midi_input if idx < priming_seq_len: - if enc_seq[0, idx].item() not in dur_tok_ids: - logger.info( - f"Override prediction {tokenizer.id_to_tok[next_token_ids[0].item()]} -> {tokenizer.id_to_tok[enc_seq[:, idx].item()]}" - ) - next_token_ids = enc_seq[:, idx] - else: + priming_tok = tokenizer.id_to_tok[enc_seq[0, idx].item()] + predicted_tok = tokenizer.id_to_tok[next_token_ids[0].item()] + + resample = False + if isinstance(priming_tok, tuple) and priming_tok[0] == "dur": + priming_dur = priming_tok[1] + predicted_dur = predicted_tok[1] + + if predicted_dur > priming_dur: + resample = True + + if resample is True: logger.info( f"Resampled ground truth {tokenizer.id_to_tok[enc_seq[:, idx].item()]} -> {tokenizer.id_to_tok[next_token_ids[0].item()]}" ) + else: + next_token_ids = enc_seq[:, idx] enc_seq[:, idx] = next_token_ids next_token = tokenizer.id_to_tok[next_token_ids[0].item()] @@ -306,14 +511,14 @@ def decode_tokens_to_midi( generated_tokens_queue: queue.Queue, midi_messages_queue: queue.Queue, tokenizer: InferenceAbsTokenizer, - first_on_msg_epoch_ms: float, - priming_seq_last_onset_ms: float, + 
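
# Context for the beam scoring above: patch 04 below replaces the
# probability product (tok_prob * next_tok_prob) with a sum of
# log-probabilities. A toy check, with made-up numbers, that the two
# scorings rank candidates identically while log-space avoids underflow
# on long products:
import math

pairs = {"A": (0.30, 0.20), "B": (0.25, 0.26)}
for name, (p1, p2) in pairs.items():
    prod = p1 * p2                         # probability-space score
    log_sum = math.log(p1) + math.log(p2)  # same ranking, log-space
    print(f"{name}: prod={prod:.4f} log_sum={log_sum:.4f}")
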
first_on_msg_epoch_ms: int, + priming_seq_last_onset_ms: int, control_sentinel: threading.Event, ): logger = get_logger("DECODE") assert ( - first_on_msg_epoch_ms + priming_seq_last_onset_ms < time.time() * 1000 + first_on_msg_epoch_ms + priming_seq_last_onset_ms < get_epoch_time_ms() ) logger.info(f"first_on_msg_epoch_ms: {first_on_msg_epoch_ms}") @@ -362,7 +567,7 @@ def decode_tokens_to_midi( midi_messages_queue.put(off_msg) logger.info(f"Put message: {on_msg}") logger.info(f"Put message: {off_msg}") - logger.info(f"Ahead by {onset_epoch_ms - round(time.time() * 1000)}ms") + logger.info(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") note_buffer = [] @@ -402,7 +607,7 @@ def stream_midi( ) while midi_messages: - curr_epoch_time_ms = round(time.time() * 1000) + curr_epoch_time_ms = get_epoch_time_ms() msg = midi_messages[0] if 0 < curr_epoch_time_ms - msg["epoch_time_ms"] <= 50: @@ -535,6 +740,7 @@ def stream_msgs( "temperature": temperature, "top_p": top_p, "num_preceding_active_pitches": num_preceding_active_pitches, + "first_on_msg_epoch_ms": first_on_msg_epoch_ms, }, daemon=True, ) @@ -614,6 +820,12 @@ def capture_midi_input( while not control_sentinel.is_set(): msg = midi_input.receive(block=False) + # DEBUG REMEMBER TO REMOVE + # if ( + # first_on_msg_epoch_ms is not None + # and get_epoch_time_ms() - first_on_msg_epoch_ms > 14100 + # ): + # control_sentinel.set() if msg is None: time.sleep(0.001) continue @@ -621,11 +833,9 @@ def capture_midi_input( if prev_msg_epoch_time_ms is None: msg_time_ms = 0 else: - msg_time_ms = round( - round((time.time() * 1000) - prev_msg_epoch_time_ms) - ) + msg_time_ms = get_epoch_time_ms() - prev_msg_epoch_time_ms - prev_msg_epoch_time_ms = round(time.time() * 1000) + prev_msg_epoch_time_ms = get_epoch_time_ms() msg.time = msg_time_ms msg.channel = 0 logger.info(f"{msg}") @@ -642,7 +852,7 @@ def capture_midi_input( midi_through.send(msg) elif msg.type == "note_on" and msg.velocity > 0: if first_on_msg_epoch_ms is None: - first_on_msg_epoch_ms = round(time.time() * 1000) + first_on_msg_epoch_ms = get_epoch_time_ms() active_pitches.add(msg.note) received_messages.append(msg) From d6a865b0786181f1c318e8533d34d186ead5cd6a Mon Sep 17 00:00:00 2001 From: Louis Date: Sun, 5 Jan 2025 22:51:48 +0000 Subject: [PATCH 04/72] demo finished --- demo/demo.py | 562 +++++++++++++++++++------------------ demo/demo.sh | 5 +- demo/midi-tunnel-server.py | 186 ++++++++---- 3 files changed, 425 insertions(+), 328 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 2f17805..65d4112 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -2,9 +2,7 @@ import argparse import os -import keyboard import time -import functools import uuid import copy import logging @@ -33,32 +31,9 @@ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 MAX_SEQ_LEN = 8192 - # TODO: # - Add CFG support -# - Add beam-search for first generated onset (watch out for ) -# - Add loop functionality - - -# CONTROL FLOW: - -# 1. Loads model, compiles forward -# 2. Listen on MIDI port for first note -# 3. Start timer at first message seen -# 4. Wait for control-signal -# 3. Signal seen -> prefill all closed notes -# 4. Wait for all notes to close (ignore new notes) -> prefill the rest of the notes -# 5. Init main loop: - -# Generate next token -# Decode NoteMessage -# Add to message global NoteMessages -# Convert into on/off MIDI message and add to dequeue -# Check next message against current time and send message while msg_time <= curr_time -# Listen for control-signal - -# 6. 
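
# The MillisecondFormatter kept above and the new get_epoch_time_ms helper
# both standardise on integer epoch milliseconds. A self-contained sketch of
# the same idea, assuming only the stdlib (the logger name and format string
# are illustrative):
import logging
import time

class EpochMsFormatter(logging.Formatter):
    def formatTime(self, record, datefmt=None):
        # record.created is float seconds since epoch; render as integer ms
        return str(int(record.created * 1000))

handler = logging.StreamHandler()
handler.setFormatter(EpochMsFormatter("[%(asctime)s]: %(message)s"))
log = logging.getLogger("epoch-ms-example")
log.addHandler(handler)
log.setLevel(logging.INFO)
log.info(f"now: {round(time.time() * 1000)} ms")
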
Pop all note-on msgs, pop all note-off messages that are not processed yet -# 7. Go back to (3) - sending messages from extra list arg of off-msgs +# - Add looping functionality file_handler = logging.FileHandler("./demo.log", mode="w") file_handler.setLevel(logging.DEBUG) @@ -70,10 +45,8 @@ def get_logger(name: str | None = None) -> logging.Logger: logger.propagate = False logger.setLevel(logging.DEBUG) - # Custom formatter class to handle millisecond timestamps class MillisecondFormatter(logging.Formatter): def formatTime(self, record, datefmt=None): - # Get milliseconds since epoch using int() to remove decimal places created_ms = int(record.created * 1000) return str(created_ms) @@ -91,7 +64,6 @@ def formatTime(self, record, datefmt=None): ch.setFormatter(formatter) logger.addHandler(ch) - # Reuse shared file handler file_handler.setFormatter(formatter) logger.addHandler(file_handler) @@ -122,7 +94,6 @@ def compile_model(model: TransformerLM, max_seq_len: int): fullgraph=True, ) - # Might need to pass in pad_idxs? with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): start_compile_time_s = time.time() logger.info(f"Compiling forward pass") @@ -189,18 +160,25 @@ def load_model( @torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def recalculate_dur_tokens( + model: TransformerLM, priming_seq: list, enc_seq: torch.Tensor, tokenizer: InferenceAbsTokenizer, - model: TransformerLM, start_idx: int, ): logger = get_logger("GENERATE") + assert start_idx > 0 + priming_seq_len = len(priming_seq) + num_time_toks_seen = priming_seq[:start_idx].count(tokenizer.time_tok) + curr_onset = num_time_toks_seen * 5000 + last_offset = tokenizer.calc_length_ms(priming_seq) + LAST_OFFSET_BUFFER_MS = 50 - for idx in range(priming_seq_len - start_idx, priming_seq_len): + for idx in range(start_idx, priming_seq_len): with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): prev_tok_id = enc_seq[0, idx - 1] + prev_tok = tokenizer.id_to_tok[prev_tok_id.item()] logits = decode_one( model, idxs=torch.tensor([[prev_tok_id]]).cuda(), @@ -208,141 +186,131 @@ def recalculate_dur_tokens( [idx - 1], device="cuda", dtype=torch.int ), ) - - logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") - logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") - logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( - "-inf" - ) + logger.debug( + f"Sampled logits for position {idx} by inserting {prev_tok} at position {idx-1}" + ) next_token_ids = torch.argmax(logits, dim=-1).flatten() priming_tok = tokenizer.id_to_tok[enc_seq[0, idx].item()] predicted_tok = tokenizer.id_to_tok[next_token_ids[0].item()] + logger.debug( + f"Ground truth token: {priming_tok}, resampled token: {predicted_tok}" + ) + resample = False - if isinstance(priming_tok, tuple) and priming_tok[0] == "dur": + if isinstance(priming_tok, tuple) and priming_tok[0] == "onset": + curr_onset = (num_time_toks_seen * 5000) + priming_tok[1] + elif priming_tok == tokenizer.time_tok: + num_time_toks_seen += 1 + curr_onset = num_time_toks_seen * 5000 + elif isinstance(priming_tok, tuple) and priming_tok[0] == "dur": + assert ( + isinstance(predicted_tok, tuple) and predicted_tok[0] == "dur" + ) + priming_dur = priming_tok[1] predicted_dur = predicted_tok[1] - if predicted_dur > priming_dur: + if (predicted_dur > priming_dur) and ( + curr_onset + priming_dur > last_offset - LAST_OFFSET_BUFFER_MS + ): resample = True if resample is True: logger.info( - f"Resampled ground truth 
{tokenizer.id_to_tok[enc_seq[:, idx].item()]} -> {tokenizer.id_to_tok[next_token_ids[0].item()]}" + f"Replaced ground truth for position {idx}: {tokenizer.id_to_tok[enc_seq[:, idx].item()]} -> {tokenizer.id_to_tok[next_token_ids[0].item()]}" ) enc_seq[:, idx] = next_token_ids + priming_seq[idx] = predicted_tok + + last_tok_id = enc_seq[0, idx] + last_tok = tokenizer.id_to_tok[last_tok_id.item()] - return enc_seq + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + next_token_logits = decode_one( + model, + idxs=torch.tensor([[last_tok_id]]).cuda(), + input_pos=torch.tensor([idx], device="cuda", dtype=torch.int), + ) + logger.info(f"Updated KV-Cache by inserting {last_tok} at position {idx}") -# TODO: Clean this up -# - Replace the log statements with log statements of the normal form in generate for generating tokens -# - clean up logging -# - Make sure there is no bugs + return enc_seq, priming_seq, next_token_logits @torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() -def decode_first_onset( +def decode_first_tokens( model: TransformerLM, + first_token_logits: torch.Tensor, enc_seq: torch.Tensor, priming_seq: list, tokenizer: InferenceAbsTokenizer, generated_tokens_queue: queue.Queue, first_on_msg_epoch_ms: int, ): - logger = get_logger("GENERATE-FIRST") + logger = get_logger("GENERATE") + BEAM_WIDTH = 5 + BUFFER_MS = 50 + TIME_TOK_ID = tokenizer.tok_to_id[tokenizer.time_tok] + + logits = first_token_logits time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms - num_time_toks = priming_seq.count(tokenizer.time_tok) - time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] - idx = len(priming_seq) + idx = len(priming_seq) + 1 - # DEBUG - logger.info(f"Priming seq if length {idx}") - logger.info(f"MS since start: {time_since_first_onset_ms}") - logger.info(f"Number of time_toks in priming seq: {num_time_toks}") - # END DEBUG + num_time_toks_required = (time_since_first_onset_ms + BUFFER_MS) // 5000 + num_time_toks_in_priming_seq = priming_seq.count(tokenizer.time_tok) + num_time_toks_to_add = num_time_toks_required - num_time_toks_in_priming_seq - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - prev_tok_id = enc_seq[0, idx - 1] - logits = decode_one( - model, - idxs=torch.tensor([[prev_tok_id]]).cuda(), - input_pos=torch.tensor([idx - 1], device="cuda", dtype=torch.int), - ) - logger.info(f"Sampled logits for tok at pos {idx}") - _, top_ids = torch.topk(logits, k=BEAM_WIDTH, dim=-1) - idx += 1 + logger.info(f"Time since first onset: {time_since_first_onset_ms}ms") - num_time_toks_to_add = ( - (time_since_first_onset_ms + 200) // 5000 - ) - num_time_toks - append_time_toks = (num_time_toks_to_add > 0) or ( - tokenizer.tok_to_id[tokenizer.time_tok] in top_ids[0].tolist() - ) + while num_time_toks_to_add > 0: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + generated_tokens_queue.put(tokenizer.time_tok) + logits = decode_one( + model, + idxs=torch.tensor([[TIME_TOK_ID]]).cuda(), + input_pos=torch.tensor( + [idx - 1], device="cuda", dtype=torch.int + ), + ) - # DEBUG - logger.info( - f"top_toks = {[tokenizer.id_to_tok[id] for id in top_ids[0].tolist()]}" - ) - logger.info(f"Append time tok: {append_time_toks}") - logger.info(f"Num time toks to add: {num_time_toks_to_add}") - # END DEBUG + logger.info(f"Inserted time_tok at position {idx-1}") + num_time_toks_to_add -= 1 + enc_seq[:, idx - 1] = torch.tensor([[TIME_TOK_ID]]).cuda() + idx += 1 - if append_time_toks: - if num_time_toks_to_add == 
0: - num_time_toks_to_add += 1 + log_probs = torch.log_softmax(logits, dim=-1) + top_log_probs, top_ids = torch.topk(log_probs, k=BEAM_WIDTH, dim=-1) - while num_time_toks_to_add > 0: - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.MATH - ): - enc_seq[:, idx] = torch.tensor([[time_tok_id]]).cuda() - generated_tokens_queue.put(tokenizer.time_tok) - logits = decode_one( - model, - idxs=torch.tensor([[time_tok_id]]).cuda(), - input_pos=torch.tensor( - [idx - 1], device="cuda", dtype=torch.int - ), - ) - logger.info( - f"Sampled logits for tok at pos {idx} by adding time_tok" - ) - num_time_toks_to_add -= 1 - idx += 1 + if TIME_TOK_ID not in top_ids[0].tolist(): + top_ids[0, -1] = TIME_TOK_ID + top_log_probs[0, -1] = log_probs[0, TIME_TOK_ID] - 1 - # BEAM SEARCH - probs = torch.softmax(logits, dim=-1) - top_probs, top_ids = torch.topk(probs, k=BEAM_WIDTH, dim=-1) + top_toks = [tokenizer.id_to_tok[id] for id in top_ids[0].tolist()] - # DEBUG - logger.info( - f"top_toks = {[tokenizer.id_to_tok[id] for id in top_ids[0].tolist()]}" + logger.debug(f"Calculated top {BEAM_WIDTH} tokens={top_toks}") + logger.debug( + f"Calculated top {BEAM_WIDTH} scores={top_log_probs[0].tolist()}" ) - logger.info(f"top_probs = {top_probs}") - # END DEBUG - - if append_time_toks is False: - masked_onset_ids = [ - tokenizer.tok_to_id[tok] - for tok in tokenizer.onset_tokens - if tok[1] < (time_since_first_onset_ms % 5000) - ] - else: - masked_onset_ids = [] - logger.info( - f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms})" + masked_onset_ids = [ + tokenizer.tok_to_id[tok] + for tok in tokenizer.onset_tokens + if tok[1] < ((time_since_first_onset_ms + BUFFER_MS) % 5000) + ] + + logger.debug( + f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms + BUFFER_MS})" ) - best_score = 0 + best_score = float("-inf") for i in range(BEAM_WIDTH): + tok = top_toks[i] tok_id = top_ids[0, i].item() - tok_prob = top_probs[0, i] - assert tok_id != tokenizer.tok_to_id[tokenizer.time_tok] + tok_log_prob = top_log_probs[0, i] with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): next_logits = decode_one( @@ -352,50 +320,117 @@ def decode_first_onset( [idx - 1], device="cuda", dtype=torch.int ), ) - logger.info( - f"Sampled logits for tok at pos {idx} by adding {tokenizer.id_to_tok[tok_id]}" + logger.debug( + f"Sampled logits for positions {idx} by inserting {tok} at position {idx-1}" ) - next_probs = torch.softmax(next_logits, dim=-1) - next_probs[:, masked_onset_ids] = 0 - next_tok_prob, next_tok_id = torch.max(next_probs, dim=-1) + next_log_probs = torch.log_softmax(next_logits, dim=-1) + next_log_probs[:, masked_onset_ids] = float("-inf") + next_tok_log_prob, next_tok_id = torch.max(next_log_probs, dim=-1) + next_tok = tokenizer.id_to_tok[next_tok_id.item()] + score = tok_log_prob + next_tok_log_prob logger.info( - f"Sampled {tokenizer.id_to_tok[next_tok_id[0].item()]} with p={next_tok_prob}" + f"Calculated tuple {(tok, next_tok)} with scores {(tok_log_prob.item(), next_tok_log_prob.item())} (combined={score.item()})" ) - score = (tok_prob * next_tok_prob).item() if score > best_score: - tok_id_1, tok_id_2 = tok_id, next_tok_id.item() + best_tok_id_1, best_tok_id_2 = tok_id, next_tok_id.item() + best_tok_1, best_tok_2 = ( + tokenizer.id_to_tok[best_tok_id_1], + tokenizer.id_to_tok[best_tok_id_2], + ) best_score = score - logger.info(f"Score={score}") - logger.info( - f"Filling in kv at position {idx-1} with 
{tokenizer.id_to_tok[tok_id_1]} " + f"Chose tuple {(best_tok_1, best_tok_2)} with score {best_score.item()}" ) - decode_one( - model, - idxs=torch.tensor([[tok_id_1]]).cuda(), - input_pos=torch.tensor([idx - 1], device="cuda", dtype=torch.int), + enc_seq[:, idx - 1] = best_tok_id_1 + enc_seq[:, idx] = best_tok_id_2 + generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_1]) + generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_2]) + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + decode_one( + model, + idxs=torch.tensor([[best_tok_id_1]]).cuda(), + input_pos=torch.tensor([idx - 1], device="cuda", dtype=torch.int), + ) + + logger.info( + f"Updated KV-Cache by re-inserting {best_tok_1} at position {idx-1}" ) + logger.info( + f"Inserted {best_tok_2} at position {idx} without updating KV-Cache" + ) + + return enc_seq, idx + 1 + +def decode_tokens( + model: TransformerLM, + enc_seq: torch.Tensor, + tokenizer: InferenceAbsTokenizer, + control_sentinel: threading.Event, + generated_tokens_queue: queue.Queue, + idx: int, + temperature: float, + top_p: float, +): + logger = get_logger("GENERATE") logger.info( - f"Selecting {tokenizer.id_to_tok[tok_id_1], tokenizer.id_to_tok[tok_id_2]}" + f"Using sampling parameters: temperature={temperature}, top_p={top_p}" ) - enc_seq[:, idx - 1] = tok_id_1 - enc_seq[:, idx] = tok_id_2 - generated_tokens_queue.put(tokenizer.id_to_tok[tok_id_1]) - generated_tokens_queue.put(tokenizer.id_to_tok[tok_id_2]) + while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: + decode_one_start_time_s = time.time() - return enc_seq, idx + 1 + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + prev_tok_id = enc_seq[0, idx - 1] + prev_tok = tokenizer.id_to_tok[prev_tok_id.item()] + + logits = decode_one( + model, + idxs=torch.tensor([[prev_tok_id]]).cuda(), + input_pos=torch.tensor( + [idx - 1], device="cuda", dtype=torch.int + ), + ) + + logger.debug( + f"Sampled logits for positions {idx} by inserting {prev_tok} at position {idx-1}" + ) + + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( + "-inf" + ) + + if temperature > 0.0: + probs = torch.softmax(logits / temperature, dim=-1) + next_token_ids = sample_top_p(probs, top_p).flatten() + else: + next_token_ids = torch.argmax(logits, dim=-1).flatten() + + enc_seq[:, idx] = next_token_ids + next_token = tokenizer.id_to_tok[next_token_ids[0].item()] + logger.info( + f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" + ) + + generated_tokens_queue.put(next_token) + idx += 1 + + logger.info("Seen exit signal") + generated_tokens_queue.put(None) # TODO: Support CFG, guidance, and metadata tags # TODO: Context length switching -# TODO: Get the model to predict durations of notes trucated by +# TODO: BUG: I'm still not 100% sure that the KV is being calculated correctly +# TODO: BUG: Potentially a bug with dim_toks ect... 
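
# decode_tokens above defers to sample_top_p imported from aria.sample. For
# reference, a minimal nucleus-sampling sketch — an assumption about the
# general technique, not necessarily that function's exact implementation:
import torch

def sample_top_p_sketch(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_probs, sorted_ids = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens whose preceding cumulative mass already exceeds top_p
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    picked = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_ids, -1, picked)

probs = torch.softmax(torch.randn(1, 32), dim=-1)
print(sample_top_p_sketch(probs, top_p=0.95))
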
being removed during kv-preprocessing @torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def generate_tokens( @@ -408,14 +443,12 @@ def generate_tokens( first_on_msg_epoch_ms: int, temperature: float = 0.95, top_p: float = 0.95, - # cfg_gamma: float | None = None, ): logger = get_logger("GENERATE") - logger.info( - f"Using sampling parameters: temperature={temperature}, top_p={top_p}" - ) + generate_start_s = time.time() priming_seq_len = len(priming_seq) + start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches) enc_seq = torch.tensor( [ tokenizer.encode( @@ -425,22 +458,41 @@ def generate_tokens( ], device="cuda", ) - logger.debug(priming_seq) + + logger.debug(f"Priming sequence {priming_seq}") logger.info(f"Priming sequence length: {priming_seq_len}") + logger.info(f"Prefilling up to (and including) position: {start_idx-2}") + # In theory we could reuse the logits from prefill prefill_start_s = time.time() prefill( model, - idxs=enc_seq[:, :priming_seq_len], - input_pos=torch.arange(0, priming_seq_len, device="cuda"), + idxs=enc_seq[:, : start_idx - 1], + input_pos=torch.arange(0, start_idx - 1, device="cuda"), ) + logger.info( f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" ) + logger.info(f"Starting duration recalculation from: {start_idx}") + + recalculate_dur_start_s = time.time() + enc_seq, priming_seq, next_token_logits = recalculate_dur_tokens( + model=model, + priming_seq=priming_seq, + enc_seq=enc_seq, + tokenizer=tokenizer, + start_idx=start_idx, + ) - # TODO: Still not 100% sure that decode_first_onset is completely correct - enc_seq, idx = decode_first_onset( + logger.info( + f"Recalculating durations took {(time.time() - recalculate_dur_start_s) * 1000:.2f} milliseconds" + ) + + decode_first_s = time.time() + enc_seq, idx = decode_first_tokens( model=model, + first_token_logits=next_token_logits, enc_seq=enc_seq, priming_seq=priming_seq, tokenizer=tokenizer, @@ -448,63 +500,23 @@ def generate_tokens( first_on_msg_epoch_ms=first_on_msg_epoch_ms, ) - logger.info(f"Starting from idx={idx}") - while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: - decode_one_start_time_s = time.time() - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - prev_tok_id = enc_seq[0, idx - 1] - logits = decode_one( - model, - idxs=torch.tensor([[prev_tok_id]]).cuda(), - input_pos=torch.tensor( - [idx - 1], device="cuda", dtype=torch.int - ), - ) - - logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") - logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") - logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( - "-inf" - ) - - if temperature > 0.0: - probs = torch.softmax(logits / temperature, dim=-1) - next_token_ids = sample_top_p(probs, top_p).flatten() - else: - next_token_ids = torch.argmax(logits, dim=-1).flatten() - - # NOTE: This logic controls re-sampling of potentially truncated notes - # (durations) due to control-signal interruption in capture_midi_input - if idx < priming_seq_len: - priming_tok = tokenizer.id_to_tok[enc_seq[0, idx].item()] - predicted_tok = tokenizer.id_to_tok[next_token_ids[0].item()] - - resample = False - if isinstance(priming_tok, tuple) and priming_tok[0] == "dur": - priming_dur = priming_tok[1] - predicted_dur = predicted_tok[1] - - if predicted_dur > priming_dur: - resample = True - - if resample is True: - logger.info( - f"Resampled ground truth {tokenizer.id_to_tok[enc_seq[:, idx].item()]} -> {tokenizer.id_to_tok[next_token_ids[0].item()]}" - 
) - else: - next_token_ids = enc_seq[:, idx] - - enc_seq[:, idx] = next_token_ids - next_token = tokenizer.id_to_tok[next_token_ids[0].item()] - logger.info( - f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" - ) - - # To account for re-sampling - if idx >= priming_seq_len: - generated_tokens_queue.put(next_token) + logger.info( + f"Decode first two tokens took {(time.time() - decode_first_s) * 1000:.2f} milliseconds" + ) + logger.info( + f"Time to first token took {(time.time() - generate_start_s) * 1000:.2f} milliseconds" + ) - idx += 1 + decode_tokens( + model=model, + enc_seq=enc_seq, + tokenizer=tokenizer, + control_sentinel=control_sentinel, + generated_tokens_queue=generated_tokens_queue, + idx=idx, + temperature=temperature, + top_p=top_p, + ) def decode_tokens_to_midi( @@ -521,22 +533,30 @@ def decode_tokens_to_midi( first_on_msg_epoch_ms + priming_seq_last_onset_ms < get_epoch_time_ms() ) - logger.info(f"first_on_msg_epoch_ms: {first_on_msg_epoch_ms}") - logger.info(f"priming_seq_last_onset_ms: {priming_seq_last_onset_ms}") + logger.info(f"Priming sequence last onset: {priming_seq_last_onset_ms}") + logger.info( + f"Total time elapsed since first onset: {get_epoch_time_ms() - first_on_msg_epoch_ms}" + ) note_buffer = [] num_time_toks = priming_seq_last_onset_ms // 5000 - while not control_sentinel.is_set(): + while True: while True: tok = generated_tokens_queue.get() - logger.info(f"Seen token: {tok}") + if tok is None: + # This is triggered iff control sentinel is set a second time + logger.info("Seen exit signal") + midi_messages_queue.put(None) + return + + logger.debug(f"Seen token: {tok}") note_buffer.append(tok) if isinstance(tok, tuple) and tok[0] == "dur": break while note_buffer and note_buffer[0] == tokenizer.time_tok: - logger.info("Popping time_tok") + logger.debug("Popping time_tok") num_time_toks += 1 note_buffer.pop(0) @@ -565,8 +585,8 @@ def decode_tokens_to_midi( midi_messages_queue.put(on_msg) midi_messages_queue.put(off_msg) - logger.info(f"Put message: {on_msg}") - logger.info(f"Put message: {off_msg}") + logger.debug(f"Put message: {on_msg}") + logger.debug(f"Put message: {off_msg}") logger.info(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") note_buffer = [] @@ -588,14 +608,13 @@ def stream_midi( with mido.open_output(midi_output_port) as midi_out: while not control_sentinel.is_set(): - while True: try: msg = midi_messages_queue.get_nowait() except queue.Empty: break else: - # logger.info(f"Got message: {msg}") + logger.debug(f"Received message: {msg}") midi_messages.append(msg) midi_messages = sorted( @@ -633,23 +652,22 @@ def stream_midi( mido_msg_with_time.time = max( 0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms ) - prev_msg_epoch_time_ms = curr_epoch_time_ms + prev_msg_epoch_time_ms = msg["epoch_time_ms"] midi_out.send(mido_msg) msgs.append(mido_msg_with_time) - logger.info( - f"(D={msg['epoch_time_ms'] - curr_epoch_time_ms}) Sent message: {msg}" - ) + logger.info(mido_msg_with_time) + logger.info(f"Sent message: {msg}") else: logger.info( - f"(D={msg['epoch_time_ms'] - curr_epoch_time_ms}) Skipping note_off message due to uuid mismatch: {msg}" + f"Skipping note_off message due to uuid mismatch: {msg}" ) midi_messages.pop(0) elif curr_epoch_time_ms - msg["epoch_time_ms"] > 100: # Message occurs too far in the past logger.info( - f"(D={msg["epoch_time_ms"] - curr_epoch_time_ms}) Skipping message occurring too far in the past: {msg}" + f"Skipping message occurring too far ({curr_epoch_time_ms - 
msg["epoch_time_ms"]}ms) in the past: {msg}" ) midi_messages.pop(0) else: @@ -658,48 +676,32 @@ def stream_midi( time.sleep(0.005) - # Control sentinel seen - while True: - try: - msg = midi_messages_queue.get_nowait() - except queue.Empty: - break - else: - midi_messages.append(msg) + logger.info("Processing remaining note_off messages") + + remaining_note_off_messages = [ + msg + for msg in midi_messages + if msg["vel"] == 0 + and last_pitch_uuid.get(msg["pitch"]) == msg["uuid"] + ] + + while remaining_note_off_messages: + msg = remaining_note_off_messages.pop(0) + mido_msg = mido.Message( + "note_on", + note=msg["pitch"], + velocity=0, + channel=0, + time=msg["epoch_time_ms"] - prev_msg_epoch_time_ms, + ) + prev_msg_epoch_time_ms = msg["epoch_time_ms"] + midi_out.send(mido_msg) + logger.info(f"Sent message: {msg}") + msgs.append(mido_msg) + + return msgs - midi_messages = sorted( - midi_messages, - key=lambda msg: (msg["epoch_time_ms"], msg["vel"]), - ) - # # Turn off active pitches straight away - # for msg in midi_messages: - # if msg["vel"] == 0 and msg["pitch"] in active_pitches: - # mido_msg = mido.Message( - # "note_on", - # note=msg["pitch"], - # velocity=0, - # channel=0, - # time=0, - # ) - - # curr_epoch_time_ms = round(time.time() * 1000) - # mido_msg_with_time = copy.deepcopy(mido_msg) - # mido_msg_with_time.time = max( - # 0, curr_epoch_time_ms - prev_msg_epoch_time_ms - # ) - - # midi_out.send(mido_msg) - # msgs.append(mido_msg_with_time) - # logger.info(f"Sent message: {mido_msg}") - # prev_msg_epoch_time_ms = curr_epoch_time_ms - # active_pitches.remove(msg["pitch"]) - - return msgs - - -# TODO: Control sentinel needs to terminate generate and midi_msgs_queue -# It also needs to keep sending the note_off msgs, if and only if they are on time def stream_msgs( model: TransformerLM, tokenizer: InferenceAbsTokenizer, @@ -778,6 +780,8 @@ def stream_msgs( generate_tokens_thread.join() decode_tokens_to_midi_thread.join() + return msgs + def convert_msgs_to_midi(msgs: list[mido.Message]): track = mido.MidiTrack() @@ -820,12 +824,13 @@ def capture_midi_input( while not control_sentinel.is_set(): msg = midi_input.receive(block=False) - # DEBUG REMEMBER TO REMOVE + # if ( # first_on_msg_epoch_ms is not None - # and get_epoch_time_ms() - first_on_msg_epoch_ms > 14100 + # and get_epoch_time_ms() - first_on_msg_epoch_ms > 14300 # ): # control_sentinel.set() + if msg is None: time.sleep(0.001) continue @@ -838,7 +843,7 @@ def capture_midi_input( prev_msg_epoch_time_ms = get_epoch_time_ms() msg.time = msg_time_ms msg.channel = 0 - logger.info(f"{msg}") + logger.info(f"Received message: [{msg}]") if msg.is_meta is True or msg.type == "program_change": continue @@ -878,7 +883,7 @@ def capture_midi_input( note=pitch, velocity=0, channel=0, - time=msg_time_ms, + time=get_epoch_time_ms() - prev_msg_epoch_time_ms, ) received_messages.append(msg) if midi_through is not None: @@ -911,7 +916,7 @@ def capture_midi_input( # Workaround for the way that file-playback is implemented - delete msg = mido.Message( type="control_change", - control=66, + control=midi_control_signal, value=0, channel=0, time=0, @@ -935,10 +940,10 @@ def play_midi_file(midi_port: str, midi_path: str): def listen_for_control_signal_keypress(control_sentinel: threading.Event): logger = get_logger("KEYBOARD") for _ in range(2): + time.sleep(1) input() logger.info("Keypress seen") control_sentinel.set() - time.sleep(5) def parse_args(): @@ -1003,12 +1008,19 @@ def parse_args(): type=int, required=False, ) + argp.add_argument( + 
"-save_path", + type=str, + required=False, + help="Path to save complete MIDI file", + ) return argp.parse_args() def main(): args = parse_args() + logger = get_logger() tokenizer = InferenceAbsTokenizer() model = load_model(checkpoint_path=args.cp) model = compile_model(model=model, max_seq_len=MAX_SEQ_LEN) @@ -1025,9 +1037,6 @@ def main(): else: midi_input_port = args.midi_in - # TODO: All of the below logic should be in a loop with additional handling - # for the control sentinel - control_sentinel = threading.Event() keypress_thread = threading.Thread( target=listen_for_control_signal_keypress, @@ -1044,7 +1053,7 @@ def main(): ) control_sentinel.clear() - stream_msgs( + msgs = stream_msgs( model=model, tokenizer=tokenizer, msgs=msgs, @@ -1057,6 +1066,11 @@ def main(): ) keypress_thread.join() + if args.save_path: + logger.info(f"Saving result to {args.save_path}") + midi = convert_msgs_to_midi(msgs=msgs) + midi.save(args.save_path) + if __name__ == "__main__": main() diff --git a/demo/demo.sh b/demo/demo.sh index 615328e..05a0328 100644 --- a/demo/demo.sh +++ b/demo/demo.sh @@ -1,7 +1,8 @@ python /home/loubb/work/aria/demo/demo.py \ -cp /mnt/ssd1/aria/v2/medium-75-ft.safetensors \ - -midi_path /home/loubb/Dropbox/shared/audio.mid \ + -midi_path /home/loubb/Dropbox/shared/prompt/nocturne.mid \ -midi_out "Midi Through:Midi Through Port-1" \ -midi_through "Midi Through:Midi Through Port-2" \ + -save_path /home/loubb/Dropbox/shared/output.mid \ -midi_control_signal 66 \ - -temp 0.96 + -temp 0.95 diff --git a/demo/midi-tunnel-server.py b/demo/midi-tunnel-server.py index 988e200..d4c6f83 100755 --- a/demo/midi-tunnel-server.py +++ b/demo/midi-tunnel-server.py @@ -1,61 +1,143 @@ -import rtmidi import socket +import rtmidi import time -import struct +import subprocess +import signal +import sys +import os import argparse -class MIDIRouter: - def __init__(self, midi_port="14:0", udp_port=5004): - self.midi_in = rtmidi.MidiIn() - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.udp_port = udp_port - - # Print available ports - ports = self.midi_in.get_ports() - print(f"Available MIDI ports: {ports}") - - # Find and open MIDI port - for i, port in enumerate(ports): - if midi_port in port: - print(f"Opening MIDI port {i}: {port}") - self.midi_in.open_port(i) +SSH_SERVER = "home-4090.remote" +def parse_arguments(): + parser = argparse.ArgumentParser(description='MIDI UDP bridge with SSH tunnel') + parser.add_argument('-p', '--port', type=int, default=5004, + help='UDP port number (default: 5004)') + return parser.parse_args() + +def kill_existing_process(port): + # Check and kill existing process on remote server + check_command = f"ssh {SSH_SERVER} 'lsof -ti :{port}'" + try: + pid = subprocess.check_output(check_command, shell=True).decode().strip() + if pid: + print(f"Found existing process {pid} on port {port}, killing it...") + kill_command = f"ssh {SSH_SERVER} 'kill -9 {pid}'" + subprocess.run(kill_command, shell=True) + # Wait a moment for the port to be freed + time.sleep(1) + except subprocess.CalledProcessError: + # No existing process found + pass + +def setup_ssh_tunnel(port): + while True: + try: + # Kill any existing process first + kill_existing_process(port) + + # Start SSH tunnel using socat + print(f"Attempting to establish SSH tunnel on port {port}...") + ssh_command = f"ssh {SSH_SERVER} 'socat -u UDP4-RECV:{port} STDOUT'" + local_socat = f"socat -u STDIN UDP4-SEND:localhost:{port}" + + ssh_process = subprocess.Popen(ssh_command, shell=True, 
stdout=subprocess.PIPE) + socat_process = subprocess.Popen(local_socat, shell=True, stdin=ssh_process.stdout) + + # Check if the processes started successfully + time.sleep(1) + if ssh_process.poll() is not None: # Process terminated + raise subprocess.CalledProcessError(ssh_process.returncode, ssh_command) + + print("SSH tunnel established successfully!") + return ssh_process, socat_process + + except (subprocess.CalledProcessError, OSError) as e: + print(f"Failed to establish SSH tunnel: {str(e)}") + print("Retrying in 1 second...") + time.sleep(1) + +def create_virtual_port(port): + midi_out = rtmidi.MidiOut() + # Create a virtual MIDI port with port number in name + midi_out.open_virtual_port(f"UDP_{port}") + return midi_out + +def start_udp_listener(port): + # Create UDP socket + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind(('localhost', port)) + return sock + +def split_midi_messages(data): + """Split a byte array into individual MIDI messages.""" + messages = [] + data_list = list(data) + i = 0 + while i < len(data_list): + # Check if we have a status byte (most significant bit is 1) + if data_list[i] >= 0x80: + # Most MIDI messages are 3 bytes + if i + 2 < len(data_list): + messages.append(data_list[i:i+3]) + i += 3 + else: + # Handle incomplete message at end of buffer break else: - print(f"Warning: Could not find port containing '{midi_port}'") - - self.midi_in.set_callback(self._midi_callback) + # Skip non-status bytes (shouldn't happen in properly formatted MIDI) + i += 1 + return messages - def _midi_callback(self, message, timestamp): - try: - print(f"Received MIDI message: {message[0]}") - midi_data = struct.pack(f'B' * len(message[0]), *message[0]) - self.socket.sendto(midi_data, ('localhost', self.udp_port)) - print(f"Sent {len(midi_data)} bytes to localhost:{self.udp_port}") - except Exception as e: - print(f"Error in callback: {e}") - - def start(self): - print(f"Routing MIDI messages through SSH tunnel on port {self.udp_port}...") - try: - while True: - time.sleep(0.1) - except KeyboardInterrupt: - self.stop() - - def stop(self): - print("Shutting down...") - self.midi_in.close_port() - self.socket.close() - -def parse_args(): - parser = argparse.ArgumentParser(description='MIDI to UDP router') - parser.add_argument('-midi_p', type=str, default="14:0", - help='MIDI port identifier (default: 14:0)') - parser.add_argument('-udp_p', type=int, default=5004, - help='UDP port for forwarding (default: 5004)') - return parser.parse_args() +def cleanup(ssh_process, socat_process, midi_out, sock): + print("\nCleaning up...") + # Kill the SSH and socat processes + if ssh_process: + os.killpg(os.getpgid(ssh_process.pid), signal.SIGTERM) + if socat_process: + socat_process.terminate() + # Close MIDI and socket + if midi_out: + midi_out.close_port() + if sock: + sock.close() + +def main(): + args = parse_arguments() + port = args.port + + ssh_process = None + socat_process = None + midi_out = None + sock = None + + try: + # Setup SSH tunnel first + print(f"Setting up SSH tunnel on port {port}...") + ssh_process, socat_process = setup_ssh_tunnel(port) + + # Setup MIDI and UDP + print(f"Creating virtual MIDI port UDP_{port}...") + midi_out = create_virtual_port(port) + print(f"Starting UDP listener on port {port}...") + sock = start_udp_listener(port) + + print(f"UDP MIDI Bridge started - listening on port {port}") + + while True: + data, addr = sock.recvfrom(1024) + if data: + # Split the data into individual MIDI messages + midi_messages = 
split_midi_messages(data) + for midi_message in midi_messages: + print(f"Sending MIDI message: {midi_message}") + midi_out.send_message(midi_message) + + except KeyboardInterrupt: + print("\nShutting down UDP MIDI Bridge...") + except Exception as e: + print(f"Error: {e}") + finally: + cleanup(ssh_process, socat_process, midi_out, sock) if __name__ == "__main__": - args = parse_args() - router = MIDIRouter(midi_port=args.midi_p, udp_port=args.udp_p) - router.start() + main() From 977f54b00e2bd2c545e2d6bf1b74604ad8f9423b Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 6 Jan 2025 15:39:00 +0000 Subject: [PATCH 05/72] undo mistake --- demo/midi-tunnel-client.py | 55 ++++++----- demo/midi-tunnel-server.py | 186 +++++++++++-------------------------- 2 files changed, 86 insertions(+), 155 deletions(-) diff --git a/demo/midi-tunnel-client.py b/demo/midi-tunnel-client.py index ae79385..d4c6f83 100644 --- a/demo/midi-tunnel-client.py +++ b/demo/midi-tunnel-client.py @@ -7,6 +7,7 @@ import os import argparse +SSH_SERVER = "home-4090.remote" def parse_arguments(): parser = argparse.ArgumentParser(description='MIDI UDP bridge with SSH tunnel') parser.add_argument('-p', '--port', type=int, default=5004, @@ -15,12 +16,12 @@ def parse_arguments(): def kill_existing_process(port): # Check and kill existing process on remote server - check_command = f"ssh home-4090 'lsof -ti :{port}'" + check_command = f"ssh {SSH_SERVER} 'lsof -ti :{port}'" try: pid = subprocess.check_output(check_command, shell=True).decode().strip() if pid: print(f"Found existing process {pid} on port {port}, killing it...") - kill_command = f"ssh home-4090 'kill -9 {pid}'" + kill_command = f"ssh {SSH_SERVER} 'kill -9 {pid}'" subprocess.run(kill_command, shell=True) # Wait a moment for the port to be freed time.sleep(1) @@ -29,19 +30,31 @@ def kill_existing_process(port): pass def setup_ssh_tunnel(port): - # Kill any existing process first - kill_existing_process(port) - - # Start SSH tunnel using socat - ssh_command = f"ssh home-4090 'socat -u UDP4-RECV:{port} STDOUT'" - local_socat = f"socat -u STDIN UDP4-SEND:localhost:{port}" - - ssh_process = subprocess.Popen(ssh_command, shell=True, stdout=subprocess.PIPE) - socat_process = subprocess.Popen(local_socat, shell=True, stdin=ssh_process.stdout) - - # Give the tunnel a moment to establish - time.sleep(1) - return ssh_process, socat_process + while True: + try: + # Kill any existing process first + kill_existing_process(port) + + # Start SSH tunnel using socat + print(f"Attempting to establish SSH tunnel on port {port}...") + ssh_command = f"ssh {SSH_SERVER} 'socat -u UDP4-RECV:{port} STDOUT'" + local_socat = f"socat -u STDIN UDP4-SEND:localhost:{port}" + + ssh_process = subprocess.Popen(ssh_command, shell=True, stdout=subprocess.PIPE) + socat_process = subprocess.Popen(local_socat, shell=True, stdin=ssh_process.stdout) + + # Check if the processes started successfully + time.sleep(1) + if ssh_process.poll() is not None: # Process terminated + raise subprocess.CalledProcessError(ssh_process.returncode, ssh_command) + + print("SSH tunnel established successfully!") + return ssh_process, socat_process + + except (subprocess.CalledProcessError, OSError) as e: + print(f"Failed to establish SSH tunnel: {str(e)}") + print("Retrying in 1 second...") + time.sleep(1) def create_virtual_port(port): midi_out = rtmidi.MidiOut() @@ -91,25 +104,25 @@ def cleanup(ssh_process, socat_process, midi_out, sock): def main(): args = parse_arguments() port = args.port - + ssh_process = None socat_process = 
None midi_out = None sock = None - + try: # Setup SSH tunnel first print(f"Setting up SSH tunnel on port {port}...") ssh_process, socat_process = setup_ssh_tunnel(port) - + # Setup MIDI and UDP print(f"Creating virtual MIDI port UDP_{port}...") midi_out = create_virtual_port(port) print(f"Starting UDP listener on port {port}...") sock = start_udp_listener(port) - + print(f"UDP MIDI Bridge started - listening on port {port}") - + while True: data, addr = sock.recvfrom(1024) if data: @@ -118,7 +131,7 @@ def main(): for midi_message in midi_messages: print(f"Sending MIDI message: {midi_message}") midi_out.send_message(midi_message) - + except KeyboardInterrupt: print("\nShutting down UDP MIDI Bridge...") except Exception as e: diff --git a/demo/midi-tunnel-server.py b/demo/midi-tunnel-server.py index d4c6f83..988e200 100755 --- a/demo/midi-tunnel-server.py +++ b/demo/midi-tunnel-server.py @@ -1,143 +1,61 @@ -import socket import rtmidi +import socket import time -import subprocess -import signal -import sys -import os +import struct import argparse -SSH_SERVER = "home-4090.remote" -def parse_arguments(): - parser = argparse.ArgumentParser(description='MIDI UDP bridge with SSH tunnel') - parser.add_argument('-p', '--port', type=int, default=5004, - help='UDP port number (default: 5004)') - return parser.parse_args() - -def kill_existing_process(port): - # Check and kill existing process on remote server - check_command = f"ssh {SSH_SERVER} 'lsof -ti :{port}'" - try: - pid = subprocess.check_output(check_command, shell=True).decode().strip() - if pid: - print(f"Found existing process {pid} on port {port}, killing it...") - kill_command = f"ssh {SSH_SERVER} 'kill -9 {pid}'" - subprocess.run(kill_command, shell=True) - # Wait a moment for the port to be freed - time.sleep(1) - except subprocess.CalledProcessError: - # No existing process found - pass - -def setup_ssh_tunnel(port): - while True: - try: - # Kill any existing process first - kill_existing_process(port) - - # Start SSH tunnel using socat - print(f"Attempting to establish SSH tunnel on port {port}...") - ssh_command = f"ssh {SSH_SERVER} 'socat -u UDP4-RECV:{port} STDOUT'" - local_socat = f"socat -u STDIN UDP4-SEND:localhost:{port}" - - ssh_process = subprocess.Popen(ssh_command, shell=True, stdout=subprocess.PIPE) - socat_process = subprocess.Popen(local_socat, shell=True, stdin=ssh_process.stdout) - - # Check if the processes started successfully - time.sleep(1) - if ssh_process.poll() is not None: # Process terminated - raise subprocess.CalledProcessError(ssh_process.returncode, ssh_command) - - print("SSH tunnel established successfully!") - return ssh_process, socat_process - - except (subprocess.CalledProcessError, OSError) as e: - print(f"Failed to establish SSH tunnel: {str(e)}") - print("Retrying in 1 second...") - time.sleep(1) - -def create_virtual_port(port): - midi_out = rtmidi.MidiOut() - # Create a virtual MIDI port with port number in name - midi_out.open_virtual_port(f"UDP_{port}") - return midi_out - -def start_udp_listener(port): - # Create UDP socket - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.bind(('localhost', port)) - return sock - -def split_midi_messages(data): - """Split a byte array into individual MIDI messages.""" - messages = [] - data_list = list(data) - i = 0 - while i < len(data_list): - # Check if we have a status byte (most significant bit is 1) - if data_list[i] >= 0x80: - # Most MIDI messages are 3 bytes - if i + 2 < len(data_list): - messages.append(data_list[i:i+3]) - i += 3 - 
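
# split_midi_messages above assumes every message is 3 bytes, which holds
# for note_on/note_off but not for all channel messages. For reference, the
# data lengths implied by a status byte's high nibble — standard MIDI facts
# stated as background, not code from this repo:
def midi_message_length(status: int) -> int:
    high = status & 0xF0
    if high in (0xC0, 0xD0):  # program change, channel pressure
        return 2
    if high in (0x80, 0x90, 0xA0, 0xB0, 0xE0):
        return 3              # note on/off, poly pressure, CC, pitch bend
    raise ValueError(f"unsupported status byte: {status:#x}")

assert midi_message_length(0x90) == 3  # note_on
assert midi_message_length(0xC0) == 2  # program_change
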
else: - # Handle incomplete message at end of buffer +class MIDIRouter: + def __init__(self, midi_port="14:0", udp_port=5004): + self.midi_in = rtmidi.MidiIn() + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.udp_port = udp_port + + # Print available ports + ports = self.midi_in.get_ports() + print(f"Available MIDI ports: {ports}") + + # Find and open MIDI port + for i, port in enumerate(ports): + if midi_port in port: + print(f"Opening MIDI port {i}: {port}") + self.midi_in.open_port(i) break else: - # Skip non-status bytes (shouldn't happen in properly formatted MIDI) - i += 1 - return messages - -def cleanup(ssh_process, socat_process, midi_out, sock): - print("\nCleaning up...") - # Kill the SSH and socat processes - if ssh_process: - os.killpg(os.getpgid(ssh_process.pid), signal.SIGTERM) - if socat_process: - socat_process.terminate() - # Close MIDI and socket - if midi_out: - midi_out.close_port() - if sock: - sock.close() - -def main(): - args = parse_arguments() - port = args.port - - ssh_process = None - socat_process = None - midi_out = None - sock = None - - try: - # Setup SSH tunnel first - print(f"Setting up SSH tunnel on port {port}...") - ssh_process, socat_process = setup_ssh_tunnel(port) - - # Setup MIDI and UDP - print(f"Creating virtual MIDI port UDP_{port}...") - midi_out = create_virtual_port(port) - print(f"Starting UDP listener on port {port}...") - sock = start_udp_listener(port) + print(f"Warning: Could not find port containing '{midi_port}'") + + self.midi_in.set_callback(self._midi_callback) - print(f"UDP MIDI Bridge started - listening on port {port}") - - while True: - data, addr = sock.recvfrom(1024) - if data: - # Split the data into individual MIDI messages - midi_messages = split_midi_messages(data) - for midi_message in midi_messages: - print(f"Sending MIDI message: {midi_message}") - midi_out.send_message(midi_message) - - except KeyboardInterrupt: - print("\nShutting down UDP MIDI Bridge...") - except Exception as e: - print(f"Error: {e}") - finally: - cleanup(ssh_process, socat_process, midi_out, sock) + def _midi_callback(self, message, timestamp): + try: + print(f"Received MIDI message: {message[0]}") + midi_data = struct.pack(f'B' * len(message[0]), *message[0]) + self.socket.sendto(midi_data, ('localhost', self.udp_port)) + print(f"Sent {len(midi_data)} bytes to localhost:{self.udp_port}") + except Exception as e: + print(f"Error in callback: {e}") + + def start(self): + print(f"Routing MIDI messages through SSH tunnel on port {self.udp_port}...") + try: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + self.stop() + + def stop(self): + print("Shutting down...") + self.midi_in.close_port() + self.socket.close() + +def parse_args(): + parser = argparse.ArgumentParser(description='MIDI to UDP router') + parser.add_argument('-midi_p', type=str, default="14:0", + help='MIDI port identifier (default: 14:0)') + parser.add_argument('-udp_p', type=int, default=5004, + help='UDP port for forwarding (default: 5004)') + return parser.parse_args() if __name__ == "__main__": - main() + args = parse_args() + router = MIDIRouter(midi_port=args.midi_p, udp_port=args.udp_p) + router.start() From 877d6e0bd0e582dec4690ec1ec2ed4b730cf07d3 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 7 Jan 2025 17:11:29 +0000 Subject: [PATCH 06/72] update demo --- demo/demo.py | 55 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 65d4112..a3d1e84 
100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -255,6 +255,7 @@ def decode_first_tokens( BEAM_WIDTH = 5 BUFFER_MS = 50 TIME_TOK_ID = tokenizer.tok_to_id[tokenizer.time_tok] + TIME_TOK_WEIGHTING = -3 logits = first_token_logits time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms @@ -282,12 +283,16 @@ def decode_first_tokens( enc_seq[:, idx - 1] = torch.tensor([[TIME_TOK_ID]]).cuda() idx += 1 + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float("-inf") + log_probs = torch.log_softmax(logits, dim=-1) top_log_probs, top_ids = torch.topk(log_probs, k=BEAM_WIDTH, dim=-1) if TIME_TOK_ID not in top_ids[0].tolist(): top_ids[0, -1] = TIME_TOK_ID - top_log_probs[0, -1] = log_probs[0, TIME_TOK_ID] - 1 + top_log_probs[0, -1] = log_probs[0, TIME_TOK_ID] + TIME_TOK_WEIGHTING top_toks = [tokenizer.id_to_tok[id] for id in top_ids[0].tolist()] @@ -326,6 +331,9 @@ def decode_first_tokens( next_log_probs = torch.log_softmax(next_logits, dim=-1) next_log_probs[:, masked_onset_ids] = float("-inf") + if tok_id == TIME_TOK_ID: + next_log_probs[:, TIME_TOK_ID] = float("-inf") + next_tok_log_prob, next_tok_id = torch.max(next_log_probs, dim=-1) next_tok = tokenizer.id_to_tok[next_tok_id.item()] score = tok_log_prob + next_tok_log_prob @@ -423,6 +431,9 @@ def decode_tokens( generated_tokens_queue.put(next_token) idx += 1 + while not control_sentinel.is_set(): + time.sleep(0.1) + logger.info("Seen exit signal") generated_tokens_queue.put(None) @@ -525,7 +536,6 @@ def decode_tokens_to_midi( tokenizer: InferenceAbsTokenizer, first_on_msg_epoch_ms: int, priming_seq_last_onset_ms: int, - control_sentinel: threading.Event, ): logger = get_logger("DECODE") @@ -617,6 +627,9 @@ def stream_midi( logger.debug(f"Received message: {msg}") midi_messages.append(msg) + if control_sentinel.is_set(): + break + midi_messages = sorted( midi_messages, key=lambda msg: ( @@ -712,18 +725,23 @@ def stream_msgs( temperature: float, top_p: float, num_preceding_active_pitches: int, + guidance_midi_dict: MidiDict | None = None, + guidance_start_ms: int | None = None, + guidance_end_ms: int | None = None, ): midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) priming_seq = tokenizer.tokenize( midi_dict=midi_dict, # prompt_intervals_ms=[ - # (0, round(time.time() * 1000) - first_on_msg_epoch_ms) + # (0, (get_epoch_time_ms() - 5000) - first_on_msg_epoch_ms) # ], prompt_intervals_ms=[], + guidance_midi_dict=guidance_midi_dict, + guidance_start_ms=guidance_start_ms, + guidance_end_ms=guidance_end_ms, ) priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] - # priming_seq = priming_seq[: priming_seq.index(tokenizer.prompt_end_tok) + 1] if tokenizer.dim_tok in priming_seq: priming_seq.remove(tokenizer.dim_tok) @@ -759,7 +777,6 @@ def stream_msgs( priming_seq[priming_seq.index(tokenizer.bos_tok) :], onset=True, ), - "control_sentinel": control_sentinel, }, daemon=True, ) @@ -825,12 +842,6 @@ def capture_midi_input( while not control_sentinel.is_set(): msg = midi_input.receive(block=False) - # if ( - # first_on_msg_epoch_ms is not None - # and get_epoch_time_ms() - first_on_msg_epoch_ms > 14300 - # ): - # control_sentinel.set() - if msg is None: time.sleep(0.001) continue @@ -1025,6 +1036,25 @@ def main(): model = load_model(checkpoint_path=args.cp) model = compile_model(model=model, max_seq_len=MAX_SEQ_LEN) + if args.guidance_path: + assert 
( + args.guidance_start_ms is not None and args.guidance_start_ms >= 0 + ) + assert args.guidance_end_ms is not None and args.guidance_end_ms >= 0 + assert ( + tokenizer._config["guidance"]["min_ms"] + <= args.guidance_end_ms - args.guidance_start_ms + <= tokenizer._config["guidance"]["max_ms"] + ) + guidance_midi_dict = MidiDict.from_midi(args.guidance_path) + + logger.info( + f"Using guidance from {args.guidance_path} in interval {[args.guidance_start_ms, args.guidance_end_ms]}" + ) + + else: + guidance_midi_dict = None + assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in if args.midi_path: midi_input_port = "Midi Through:Midi Through Port-0" @@ -1063,6 +1093,9 @@ def main(): temperature=args.temp, top_p=args.top_p, num_preceding_active_pitches=num_active_pitches, + guidance_midi_dict=guidance_midi_dict, + guidance_start_ms=args.guidance_start_ms, + guidance_end_ms=args.guidance_end_ms, ) keypress_thread.join() From 9a5c0112b5473d4ac732e131d5d58f4a9dcb9daa Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 7 Jan 2025 20:53:06 +0000 Subject: [PATCH 07/72] add prefill compile --- demo/demo.py | 144 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 115 insertions(+), 29 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index a3d1e84..9f7447f 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -27,9 +27,11 @@ torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.fx_graph_cache = True +# torch.set_float32_matmul_precision("high") DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 MAX_SEQ_LEN = 8192 +PREFILL_COMPILE_SEQ_LEN = 1024 # TODO: # - Add CFG support @@ -74,19 +76,59 @@ def get_epoch_time_ms() -> int: return round(time.time() * 1000) -@torch.autocast("cuda", dtype=DTYPE) -@torch.inference_mode() -def compile_model(model: TransformerLM, max_seq_len: int): - logger = get_logger() - assert 10 < max_seq_len <= MAX_SEQ_LEN +def compiled_prefill( + model: TransformerLM, + enc_seq: torch.Tensor, +): + return prefill( + model=model, + idxs=enc_seq[:, :PREFILL_COMPILE_SEQ_LEN], + input_pos=torch.arange(0, PREFILL_COMPILE_SEQ_LEN, device="cuda"), + ) - model.eval() - model.setup_cache( - batch_size=1, - max_seq_len=max_seq_len, - dtype=DTYPE, + +def _compile_prefill( + model: TransformerLM, + logger: logging.Logger, +): + global compiled_prefill + compiled_prefill = torch.compile( + compiled_prefill, + mode="reduce-overhead", + fullgraph=True, + ) + + start_compile_time_s = time.time() + logger.info(f"Compiling prefill") + compiled_prefill( + model, enc_seq=torch.ones(1, 8192, device="cuda", dtype=torch.int) + ) + print + logger.info( + f"Finished compiling - took {time.time() - start_compile_time_s:.4f} seconds" ) + for _ in range(5): + compiled_prefill( + model, + enc_seq=torch.ones(1, 8192, device="cuda", dtype=torch.int), + ) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + compiled_prefill( + model, enc_seq=torch.ones(1, 8192, device="cuda", dtype=torch.int) + ) + end_event.record() + end_event.synchronize() + compiled_prefill_ms = start_event.elapsed_time(end_event) + logger.info(f"Compiled prefill benchmark: {compiled_prefill_ms:.2f}ms") + + return model + + +def _compile_decode_one(model: TransformerLM, logger: logging.Logger): global decode_one decode_one = torch.compile( decode_one, @@ -99,32 +141,57 @@ def compile_model(model: TransformerLM, max_seq_len: int): 
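
# _compile_prefill above (and the decode_one benchmark that follows) times
# compiled kernels with torch.cuda.Event rather than time.time(), so the
# measurement waits for the GPU work to finish. A small generic sketch of
# the same pattern; the matmul workload is purely illustrative:
import torch

def benchmark_cuda_ms(fn, warmup: int = 5) -> float:
    for _ in range(warmup):
        fn()  # warm-up iterations, as done for the compiled forward pass
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    end.synchronize()
    return start.elapsed_time(end)  # elapsed milliseconds

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    print(f"{benchmark_cuda_ms(lambda: x @ x):.3f} ms")
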
logger.info(f"Compiling forward pass") decode_one( model, - idxs=torch.tensor([[0]]).cuda(), + idxs=torch.tensor([[0]], device="cuda", dtype=torch.int), input_pos=torch.tensor([0], device="cuda", dtype=torch.int), ) logger.info( f"Finished compiling - took {time.time() - start_compile_time_s:.4f} seconds" ) - for _ in range(100): + for _ in range(5): decode_one( model, - idxs=torch.tensor([[0]]).cuda(), + idxs=torch.tensor([[0]], device="cuda", dtype=torch.int).cuda(), input_pos=torch.tensor([0], device="cuda", dtype=torch.int), ) - compiled_forward_start_s = time.time() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() decode_one( model, - idxs=torch.tensor([[0]]).cuda(), + idxs=torch.tensor([[0]], device="cuda", dtype=torch.int).cuda(), input_pos=torch.tensor([0], device="cuda", dtype=torch.int), ) - compiled_forward_ms = (time.time() - compiled_forward_start_s) * 1000 + end_event.record() + end_event.synchronize() + + compiled_forward_ms = start_event.elapsed_time(end_event) compiled_forward_its = 1000 / compiled_forward_ms logger.info( f"Compiled forward pass benchmark: {compiled_forward_ms:.2f} ms/it ({compiled_forward_its:.2f} it/s)" ) + return model + + +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def compile_model(model: TransformerLM, max_seq_len: int): + logger = get_logger() + assert 10 < max_seq_len <= MAX_SEQ_LEN + + model.eval() + model.setup_cache( + batch_size=1, + max_seq_len=max_seq_len, + dtype=DTYPE, + ) + + model = _compile_decode_one(model=model, logger=logger) + model = _compile_prefill(model=model, logger=logger) + return model @@ -181,7 +248,9 @@ def recalculate_dur_tokens( prev_tok = tokenizer.id_to_tok[prev_tok_id.item()] logits = decode_one( model, - idxs=torch.tensor([[prev_tok_id]]).cuda(), + idxs=torch.tensor( + [[prev_tok_id]], device="cuda", dtype=torch.int + ), input_pos=torch.tensor( [idx - 1], device="cuda", dtype=torch.int ), @@ -230,7 +299,7 @@ def recalculate_dur_tokens( with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): next_token_logits = decode_one( model, - idxs=torch.tensor([[last_tok_id]]).cuda(), + idxs=torch.tensor([[last_tok_id]], device="cuda", dtype=torch.int), input_pos=torch.tensor([idx], device="cuda", dtype=torch.int), ) @@ -253,7 +322,7 @@ def decode_first_tokens( logger = get_logger("GENERATE") BEAM_WIDTH = 5 - BUFFER_MS = 50 + BUFFER_MS = 100 TIME_TOK_ID = tokenizer.tok_to_id[tokenizer.time_tok] TIME_TOK_WEIGHTING = -3 @@ -272,7 +341,9 @@ def decode_first_tokens( generated_tokens_queue.put(tokenizer.time_tok) logits = decode_one( model, - idxs=torch.tensor([[TIME_TOK_ID]]).cuda(), + idxs=torch.tensor( + [[TIME_TOK_ID]], device="cuda", dtype=torch.int + ), input_pos=torch.tensor( [idx - 1], device="cuda", dtype=torch.int ), @@ -320,7 +391,7 @@ def decode_first_tokens( with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): next_logits = decode_one( model, - idxs=torch.tensor([[tok_id]]).cuda(), + idxs=torch.tensor([[tok_id]], device="cuda", dtype=torch.int), input_pos=torch.tensor( [idx - 1], device="cuda", dtype=torch.int ), @@ -362,7 +433,9 @@ def decode_first_tokens( with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): decode_one( model, - idxs=torch.tensor([[best_tok_id_1]]).cuda(), + idxs=torch.tensor( + [[best_tok_id_1]], device="cuda", dtype=torch.int + ), input_pos=torch.tensor([idx - 1], device="cuda", dtype=torch.int), ) @@ -400,7 +473,9 @@ def decode_tokens( logits = 
decode_one( model, - idxs=torch.tensor([[prev_tok_id]]).cuda(), + idxs=torch.tensor( + [[prev_tok_id]], device="cuda", dtype=torch.int + ), input_pos=torch.tensor( [idx - 1], device="cuda", dtype=torch.int ), @@ -468,6 +543,7 @@ def generate_tokens( ) ], device="cuda", + dtype=torch.int, ) logger.debug(f"Priming sequence {priming_seq}") @@ -476,12 +552,22 @@ def generate_tokens( # In theory we could reuse the logits from prefill prefill_start_s = time.time() - prefill( - model, - idxs=enc_seq[:, : start_idx - 1], - input_pos=torch.arange(0, start_idx - 1, device="cuda"), - ) + if start_idx < PREFILL_COMPILE_SEQ_LEN: + logger.info( + f"Using compiled prefill for sequence length: {PREFILL_COMPILE_SEQ_LEN}" + ) + compiled_prefill( + model=model, + enc_seq=enc_seq, + ) + else: + prefill( + model, + idxs=enc_seq[:, : start_idx - 1], + input_pos=torch.arange(0, start_idx - 1, device="cuda"), + ) + torch.cuda.synchronize() logger.info( f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" ) @@ -734,7 +820,7 @@ def stream_msgs( priming_seq = tokenizer.tokenize( midi_dict=midi_dict, # prompt_intervals_ms=[ - # (0, (get_epoch_time_ms() - 5000) - first_on_msg_epoch_ms) + # (0, (get_epoch_time_ms() - 3000) - first_on_msg_epoch_ms) # ], prompt_intervals_ms=[], guidance_midi_dict=guidance_midi_dict, From 3dd44bf7a0db2149fc81247ac7793f8d309f50ec Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 20 Feb 2025 17:40:08 +0000 Subject: [PATCH 08/72] add class finetuning --- aria/embeddings/__init__.py | 25 + aria/embeddings/classifier_finetune.py | 662 +++++++++++++++++++++++++ aria/embeddings/eval.py | 464 +++++++++++++++++ aria/model.py | 44 ++ config/models/medium-genre.json | 15 + 5 files changed, 1210 insertions(+) create mode 100644 aria/embeddings/__init__.py create mode 100644 aria/embeddings/classifier_finetune.py create mode 100644 aria/embeddings/eval.py create mode 100644 config/models/medium-genre.json diff --git a/aria/embeddings/__init__.py b/aria/embeddings/__init__.py new file mode 100644 index 0000000..6b4e906 --- /dev/null +++ b/aria/embeddings/__init__.py @@ -0,0 +1,25 @@ +""" +Plan: + +Embeddings experiments: + ARIA-MIDI - Classification of genre classical vs jazz (vs other?) + ARIA-MIDI - Classification musical time period (classical, baroque, ect...) 
+ ARIA-MIDI - Classification of top 5 composers + ARIA-MIDI - Classification of top 5 pianists + + Pianist8 - Classification -- should work + VGMIDI - Classification -- Probably won't work as it's multi-track + WikiMT - Classification -- No idea + +Ablation comparisons: + Frozen classical embeddings (both mean and last-token from pretrained model) + Aria finetuned on these specific classification tasks (define new token) - TODO + Aria trained from scratch on these specific classification tasks (define new token) - TODO + Aria finetuned with contrastive learning - TODO + Aria trained with contrastive learning without next-token pretraining - TODO (Maybe skip) + +Other model comparisons: + Clamp2 or Clamp3 + MusicBERT + +""" diff --git a/aria/embeddings/classifier_finetune.py b/aria/embeddings/classifier_finetune.py new file mode 100644 index 0000000..7ff7330 --- /dev/null +++ b/aria/embeddings/classifier_finetune.py @@ -0,0 +1,662 @@ +import torch +import os +import mmap +import argparse +import logging +import random +import copy +import functools +import accelerate +import multiprocessing +import json +import jsonlines + +from aria.config import load_model_config +from aria.utils import _load_weight +from ariautils.tokenizer import AbsTokenizer +from ariautils.midi import MidiDict +from aria.model import TransformerCL, ModelConfig + +from torch import nn +from torch.utils.data import DataLoader, Dataset + +from accelerate.logging import get_logger +from logging.handlers import RotatingFileHandler +from tqdm import tqdm + + +def setup_logger(project_dir: str): + # Get logger and reset all handlers + logger = logging.getLogger(__name__) + for h in logger.handlers[:]: + logger.removeHandler(h) + + logger.propagate = False + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", + ) + fh = RotatingFileHandler( + os.path.join(project_dir, "logs.txt"), backupCount=5, maxBytes=1024**3 + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return get_logger(__name__) + + +def setup_project_dir(project_dir: str | None): + if not project_dir: + # Create project directory + if not os.path.isdir("./experiments"): + os.mkdir("./experiments") + + project_dirs = [ + _dir + for _dir in os.listdir("./experiments") + if os.path.isdir(os.path.join("experiments", _dir)) + ] + + ind = 0 + while True: + if str(ind) not in project_dirs: + break + else: + ind += 1 + + project_dir_abs = os.path.abspath(os.path.join("experiments", str(ind))) + assert not os.path.isdir(project_dir_abs) + os.mkdir(project_dir_abs) + + elif project_dir: + if os.path.isdir(project_dir): + assert ( + len(os.listdir(project_dir)) == 0 + ), "Provided project directory is not empty" + project_dir_abs = os.path.abspath(project_dir) + elif os.path.isfile(project_dir): + raise FileExistsError( + "The provided path points toward an existing file" + ) + else: + try: + os.mkdir(project_dir) + except Exception as e: + raise Exception( + f"Failed to create project directory at {project_dir}" + ) from e + + project_dir_abs = os.path.abspath(project_dir) + + os.mkdir(os.path.join(project_dir_abs, "checkpoints")) + + return project_dir_abs + + +def process_entry( + entry, + metadata_category: str, + tag_ids: dict, + min_slice_notes: int, + max_slice_notes: int, + max_seq_len: int, + tokenizer: AbsTokenizer, +): + midi_dict = 
MidiDict.from_msg_dict(entry) + metadata_tag = midi_dict.metadata.get(metadata_category, None) + + # Skip if metadata tag is missing or not in tag_ids. + if metadata_tag is None: + return [] + elif metadata_tag not in tag_ids: + metadata_tag = "other" + + outputs = [] + note_msgs = midi_dict.note_msgs + idx = 0 + + while idx < len(note_msgs): + slice_length = random.randint(min_slice_notes, max_slice_notes) + chunk = note_msgs[idx : idx + slice_length] + + # If the chunk is too short, break out of the loop. + if len(chunk) < min_slice_notes: + break + + idx += slice_length + + # Create slice + slice_midi_dict = copy.deepcopy(midi_dict) + slice_midi_dict.note_msgs = chunk + slice_midi_dict.metadata = {} + + # Format + tokenized_slice = tokenizer.tokenize(slice_midi_dict) + if tokenizer.dim_tok in tokenized_slice: + tokenized_slice.remove(tokenizer.dim_tok) + + # Use EOS tok for classification head + tokenized_slice = tokenized_slice[:max_seq_len] + tokenized_slice += [tokenizer.pad_tok] * ( + max_seq_len - len(tokenized_slice) + ) + if tokenizer.eos_tok not in tokenized_slice: + tokenized_slice[-1] = tokenizer.eos_tok + + pos = tokenized_slice.index(tokenizer.eos_tok) + + outputs.append( + {"seq": tokenized_slice, "tag": metadata_tag, "pos": pos} + ) + + return outputs + + +class FinetuningDataset(Dataset): + def __init__(self, load_path: str, tag_ids: dict): + self.load_path = load_path + self.tag_ids = tag_ids + self.tokenizer = AbsTokenizer() + self.index = [] + + self.file_buff = open(self.load_path, "rb") + self.mmap_obj = mmap.mmap( + self.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + while True: + pos = self.mmap_obj.tell() + line = self.mmap_obj.readline() + if not line: + break + self.index.append(pos) + + def __getitem__(self, idx: int): + def _format(tok): + # Required because json formats tuples into lists + if isinstance(tok, list): + return tuple(tok) + return tok + + file_pos = self.index[idx] + self.mmap_obj.seek(file_pos) + + raw_data = self.mmap_obj.readline().decode("utf-8") + json_data = json.loads(raw_data) + + seq, tag, pos = json_data["seq"], json_data["tag"], json_data["pos"] + assert tag in self.tag_ids.keys() + assert pos < len(seq) + + seq = [_format(tok) for tok in seq] + seq_enc = torch.tensor(self.tokenizer.encode(seq)) + tag_enc = torch.tensor(self.tag_ids[tag]) + pos_enc = torch.tensor(pos) + + assert seq_enc[pos_enc.item()].item() == 1 # EOS ID + + return seq_enc, tag_enc, pos_enc + + def __len__(self): + return len(self.index) + + @classmethod + def export_worker_init_fn(cls): + def worker_init_fn(worker_id: int): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + + if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: + dataset.mmap_obj.close() + + f = open(dataset.load_path, "rb") + dataset.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + return worker_init_fn + + @classmethod + def build( + cls, + midi_dataset_load_path: str, + save_path: str, + min_slice_notes: int, + max_slice_notes: int, + max_seq_len: int, + metadata_category: str, + tag_ids: dict, + ): + assert os.path.isfile(midi_dataset_load_path) + assert os.path.isfile(save_path) is False + + tokenizer = AbsTokenizer() + + with jsonlines.open( + midi_dataset_load_path, "r" + ) as midi_dataset, jsonlines.open(save_path, "w") as writer: + + cnt = 0 + with multiprocessing.Pool() as pool: + for result in pool.imap_unordered( + functools.partial( + process_entry, + metadata_category=metadata_category, + tag_ids=tag_ids, + 
min_slice_notes=min_slice_notes,
+                        max_slice_notes=max_slice_notes,
+                        max_seq_len=max_seq_len,
+                        tokenizer=tokenizer,
+                    ),
+                    midi_dataset,
+                    chunksize=10,
+                ):
+                    cnt += 1
+                    if cnt % 500 == 0:
+                        print(f"Completed {cnt}")
+
+                    for chunk in result:
+                        writer.write(chunk)
+
+
+def _get_optim(
+    lr: float,
+    model: nn.Module,
+    num_epochs: int,
+    steps_per_epoch: int,
+    warmup: int = 100,
+    end_ratio: float = 0.1,
+):
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=lr,
+        weight_decay=0.1,
+        betas=(0.9, 0.95),
+        eps=1e-5,
+    )
+
+    warmup_lrs = torch.optim.lr_scheduler.LinearLR(
+        optimizer,
+        start_factor=0.000001,
+        end_factor=1,
+        total_iters=warmup,
+    )
+    linear_decay_lrs = torch.optim.lr_scheduler.LinearLR(
+        optimizer,
+        start_factor=1,
+        end_factor=end_ratio,
+        total_iters=(num_epochs * steps_per_epoch) - warmup,
+    )
+
+    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+        optimizer,
+        schedulers=[warmup_lrs, linear_decay_lrs],
+        milestones=[warmup],
+    )
+
+    return optimizer, lr_scheduler
+
+
+def get_optim(
+    model: nn.Module,
+    num_epochs: int,
+    steps_per_epoch: int,
+):
+    LR = 1e-5
+    END_RATIO = 0.1
+    WARMUP_STEPS = 1000
+
+    return _get_optim(
+        lr=LR,
+        model=model,
+        num_epochs=num_epochs,
+        steps_per_epoch=steps_per_epoch,
+        warmup=WARMUP_STEPS,
+        end_ratio=END_RATIO,
+    )
+
+
+def get_dataloaders(
+    train_data_path: str,
+    val_data_path: str,
+    batch_size: int,
+    num_workers: int,
+    apply_aug=True,
+):
+    TAG_IDS = {"classical": 0, "jazz": 1, "other": 2}
+    train_dataset = FinetuningDataset(
+        load_path=train_data_path,
+        tag_ids=TAG_IDS,
+    )
+    val_dataset = FinetuningDataset(
+        load_path=val_data_path,
+        tag_ids=TAG_IDS,
+    )
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers,
+    )
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+    )
+    return train_loader, val_loader
+
+
+def _train(
+    num_epochs: int,
+    accelerator: accelerate.Accelerator,
+    model: TransformerCL,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    optimizer: torch.optim.Optimizer,
+    scheduler: torch.optim.lr_scheduler.LRScheduler = None,
+    project_dir: str | None = None,
+):
+    def make_checkpoint(
+        _accelerator: accelerate.Accelerator, _epoch: int, _step: int
+    ):
+        if accelerator.is_main_process:
+            checkpoint_dir = os.path.join(
+                project_dir,
+                "checkpoints",
+                f"epoch{_epoch}_step{_step}",
+            )
+
+            logger.info(
+                f"EPOCH {_epoch}/{num_epochs}: Saving checkpoint - {checkpoint_dir}"
+            )
+            _accelerator.save_state(checkpoint_dir)
+
+    def train_loop(
+        dataloader: DataLoader, _epoch: int, steps_per_checkpoint: int
+    ):
+        loss = torch.tensor([0.0])
+        avg_train_loss = 0
+        trailing_loss = 0
+        loss_buffer = []
+
+        # Fall back to the optimizer's lr if the scheduler cannot report one
+        try:
+            lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0])
+        except Exception:
+            lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])
+
+        model.train()
+        for __step, batch in (
+            pbar := tqdm(
+                enumerate(dataloader),
+                total=len(dataloader),
+                initial=0,
+                leave=False,
+            )
+        ):
+            pbar.set_postfix_str(
+                f"lr={lr_for_print}, "
+                f"loss={round(loss.item(), 4)}, "
+                f"trailing={round(trailing_loss, 4)}"
+            )
+
+            with accelerator.accumulate(model):
+                step = __step + 1
+
+                seqs, labels, eos_pos = batch
+                logits = model(seqs)  # (b_sz, s_len, class_size)
+                logits = logits[
+                    torch.arange(logits.shape[0], device=logits.device), eos_pos
+                ]
+                loss = loss_fn(logits, labels)
+
+                # Calculate statistics
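+                # accelerator.gather pulls the loss from every device, so
+                # loss_buffer tracks the global mean; trailing_loss averages
+                # the most recent TRAILING_LOSS_STEPS entries as a smoothed
+                # progress signal
+                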
loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + avg_train_loss = sum(loss_buffer) / len(loss_buffer) + + # Logging + logger.debug( + f"EPOCH {_epoch} STEP {step}: " + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing_loss={round(trailing_loss, 4)}, " + f"average_loss={round(avg_train_loss, 4)}" + ) + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + + if steps_per_checkpoint: + if step % steps_per_checkpoint == 0: + make_checkpoint( + _accelerator=accelerator, + _epoch=_epoch, + _step=step, + ) + + return avg_train_loss + + def val_loop(dataloader: DataLoader, _epoch: int): + model.eval() + val_loss_buffer = [] + total_correct = 0 + total_samples = 0 + + with torch.no_grad(): + pbar = tqdm( + dataloader, desc=f"Validation Epoch {_epoch}", leave=False + ) + for batch in pbar: + seqs, labels, eos_pos = batch + logits = model(seqs) # (b_sz, s_len, class_size) + logits = logits[ + torch.arange(logits.shape[0], device=logits.device), eos_pos + ] + loss = loss_fn(logits, labels) + # Gather loss from all devices (if applicable) + val_loss_buffer.append( + accelerator.gather(loss).mean(dim=0).item() + ) + + # Compute predictions and update accuracy stats + preds = torch.argmax(logits, dim=-1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + current_accuracy = ( + total_correct / total_samples if total_samples > 0 else 0.0 + ) + current_avg_loss = sum(val_loss_buffer) / len(val_loss_buffer) + + pbar.set_postfix_str( + f"loss={round(current_avg_loss,4)}, acc={round(current_accuracy,4)}" + ) + + avg_val_loss = sum(val_loss_buffer) / len(val_loss_buffer) + accuracy = total_correct / total_samples if total_samples > 0 else 0.0 + + logger.info( + f"Validation Epoch {_epoch}: average_loss={round(avg_val_loss, 4)}, accuracy={round(accuracy, 4)}" + ) + return avg_val_loss, accuracy + + logger = get_logger(__name__) + loss_fn = nn.CrossEntropyLoss() + TRAILING_LOSS_STEPS = 100 + + train_loop(dataloader=train_dataloader, _epoch=0, steps_per_checkpoint=2000) + make_checkpoint(_accelerator=accelerator, _epoch=1, _step=0) + val_loop(dataloader=val_dataloader, _epoch=0) + + +def train( + model_name: str, + train_data_path: str, + val_data_path: str, + num_workers: int, + num_epochs: int, + batch_size: int, + grad_acc_steps: int, + project_dir: str | None = None, + checkpoint_path: str | None = None, +): + accelerator = accelerate.Accelerator( + project_dir=project_dir, + gradient_accumulation_steps=grad_acc_steps, + ) + + if accelerator.is_main_process: + project_dir = setup_project_dir(project_dir) + logger = setup_logger(os.path.join(project_dir)) + else: + # In other processes, we won't create logs + project_dir = project_dir or "./experiments" + logger = get_logger(__name__) + + logger.info(f"Project directory: {project_dir}") + logger.info( + f"Training config: epochs={num_epochs}, batch_size={batch_size}, num_workers={num_workers}" + ) + + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config(model_name)) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerCL(model_config) + + if checkpoint_path is not None: + logger.info(f"Loading checkpoint from {checkpoint_path}") + model_state = _load_weight(checkpoint_path) + model_state = { + k.replace("_orig_mod.", ""): v for 
k, v in model_state.items() + } + if "lm_head.weight" in model_state.keys(): + del model_state["lm_head.weight"] + + model_state = { + k.replace("model.", ""): v for k, v in model_state.items() + } + model.model.load_state_dict(model_state) + else: + logger.info("No checkpoint path provided") + + model.compile() + + train_dataloader, val_dataloader = get_dataloaders( + train_data_path=train_data_path, + val_data_path=val_data_path, + batch_size=batch_size, + num_workers=num_workers, + apply_aug=True, + ) + + optimizer, scheduler = get_optim( + model=model, + num_epochs=num_epochs, + steps_per_epoch=len(train_dataloader), + ) + + ( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) = accelerator.prepare( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) + + _train( + num_epochs=num_epochs, + accelerator=accelerator, + model=model, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + optimizer=optimizer, + scheduler=scheduler, + project_dir=project_dir, + ) + + +def test_build_dataset(): + FinetuningDataset.build( + midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl", + save_path="/mnt/ssd1/aria/data/train.jsonl", + min_slice_notes=100, + max_slice_notes=165, + max_seq_len=512, + metadata_category="genre", + tag_ids={"classical": 0, "jazz": 1, "other": 2}, + ) + + # FinetuningDataset.build( + # midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl", + # save_path="/mnt/ssd1/aria/data/val.jsonl", + # min_slice_notes=100, + # max_slice_notes=165, + # max_seq_len=512, + # metadata_category="genre", + # tag_ids={"classical": 0, "jazz": 1, "other": 2}, + # ) + + +def test_dataset(): + dataset = FinetuningDataset( + load_path="/mnt/ssd1/aria/data/test.jsonl", + tag_ids={"classical": 0, "jazz": 1, "other": 2}, + ) + + for idx, entry in enumerate(dataset): + print(idx) + # print(entry) + # input("") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Finetune a model for classification." 
+ ) + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--checkpoint_path", type=str, default=None) + parser.add_argument("--train_data_path", type=str, required=True) + parser.add_argument("--val_data_path", type=str, required=True) + parser.add_argument("--batch_size", type=int) + parser.add_argument("--num_epochs", type=int) + parser.add_argument("--num_workers", type=int, default=1) + parser.add_argument("--grad_acc_steps", type=int, default=1) + parser.add_argument("--project_dir", type=str, default=None) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + train( + model_name=args.model_name, + checkpoint_path=args.checkpoint_path, + train_data_path=args.train_data_path, + val_data_path=args.val_data_path, + batch_size=args.batch_size, + num_epochs=args.num_epochs, + num_workers=args.num_workers, + grad_acc_steps=args.grad_acc_steps, + project_dir=args.project_dir, + ) + + # test_build_dataset() + # test_dataset() diff --git a/aria/embeddings/eval.py b/aria/embeddings/eval.py new file mode 100644 index 0000000..014a59f --- /dev/null +++ b/aria/embeddings/eval.py @@ -0,0 +1,464 @@ +import torch +import accelerate +import os +import mmap +import time +import json +import functools +import multiprocessing +import copy +import jsonlines +import torch.nn as nn +import torch.nn.functional as F + +from tqdm import tqdm +from collections import deque +from typing import Callable +from concurrent.futures import ThreadPoolExecutor + +from aria.model import ModelConfig, TransformerLM +from aria.config import load_model_config +from aria.utils import _load_weight +from ariautils.midi import MidiDict +from ariautils.tokenizer import AbsTokenizer + +MODEL_PATH = "/mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors" +MAX_SEQ_LEN = 512 +TAG_IDS = {"classical": 0, "jazz": 1, "other": 2} +ID_TO_TAG = {v: k for k, v in TAG_IDS.items()} + + +def chunk_and_pad(lst: list, n: int): + return [lst[i : i + n] for i in range(0, len(lst), n)] + + +def init_worker(): + global tokenizer + tokenizer = AbsTokenizer() + + +def write_entries(writer, entries): + for entry in entries: + writer.write(entry) + + +# The worker function processes a single JSON-lines entry. 
+def process_entry(
+    entry,
+    metadata_category: str,
+    tag_ids: dict,
+    slice_len_notes: int,
+    max_seq_len: int,
+):
+    midi_dict = MidiDict.from_msg_dict(entry)
+    metadata_tag = midi_dict.metadata.get(metadata_category, None)
+
+    # Skip entries with no metadata tag; map unrecognised tags to "other"
+    if metadata_tag is None:
+        return []
+    elif metadata_tag not in tag_ids.keys():
+        metadata_tag = "other"
+
+    outputs = []
+    for slice_note_msgs in chunk_and_pad(
+        lst=midi_dict.note_msgs, n=slice_len_notes
+    ):
+        if len(slice_note_msgs) < 20:
+            break
+
+        slice_midi_dict = copy.deepcopy(midi_dict)
+        slice_midi_dict.note_msgs = slice_note_msgs
+        slice_midi_dict.metadata = {}
+        tokenized_slice = tokenizer.tokenize(slice_midi_dict)
+        if tokenizer.eos_tok in tokenized_slice:
+            tokenized_slice.remove(tokenizer.eos_tok)
+        if tokenizer.dim_tok in tokenized_slice:
+            tokenized_slice.remove(tokenizer.dim_tok)
+
+        tokenized_slice = tokenized_slice[:max_seq_len]
+
+        outputs.append({"seq": tokenized_slice, "tag": metadata_tag})
+
+    return outputs
+
+
+@torch.autocast("cuda", dtype=torch.bfloat16)
+@torch.inference_mode()
+def get_baseline_embedding(
+    seqs: list,
+    hook_model: nn.Module,
+    hook_max_seq_len: int,
+    hook_tokenizer: AbsTokenizer,
+    pool_mode: str = "last",  # "last" or "mean"
+):
+    orig_lengths = [len(seq) for seq in seqs]
+    last_tok_positions = [length - 1 for length in orig_lengths]
+    seqs = [
+        seq + ([hook_tokenizer.pad_tok] * (hook_max_seq_len - len(seq)))
+        for seq in seqs
+    ]
+
+    enc_seqs = torch.tensor(
+        [hook_tokenizer.encode(seq) for seq in seqs], device="cuda"
+    )
+    hidden_states = hook_model(enc_seqs)
+
+    if pool_mode == "last":
+        idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
+        emb = hidden_states[idx, last_tok_positions].tolist()
+    elif pool_mode == "mean":
+        # Use the tokenizer passed to this function, not the module-level one
+        pad_id = hook_tokenizer.pad_id
+        # Create a mask by comparing enc_seqs to pad_id.
+        mask = (enc_seqs != pad_id).unsqueeze(-1).to(hidden_states.dtype)
+        # Sum over valid tokens and average.
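+        # e.g. enc_seqs = [[5, 7, pad], [9, pad, pad]] gives
+        # mask.sum(dim=1) = [[2.], [1.]], so each row is averaged over its
+        # real tokens only and padding never dilutes the embedding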
+ sum_hidden = (hidden_states * mask).sum(dim=1) + valid_counts = mask.sum(dim=1) + mean_hidden = sum_hidden / valid_counts + emb = mean_hidden.tolist() + else: + raise ValueError(f"Unsupported pool_mode: {pool_mode}") + + return emb + + +class EvaluationDataset(torch.utils.data.Dataset): + def __init__(self, load_path: str, tag_ids: dict): + self.load_path = load_path + self.tag_ids = tag_ids + self.tokenizer = AbsTokenizer() + self.index = [] + + self.file_buff = open(self.load_path, "rb") + self.mmap_obj = mmap.mmap( + self.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + while True: + pos = self.mmap_obj.tell() + line = self.mmap_obj.readline() + if not line: + break + self.index.append(pos) + + def __getitem__(self, idx: int): + pos = self.index[idx] + self.mmap_obj.seek(pos) + + raw_data = self.mmap_obj.readline().decode("utf-8") + json_data = json.loads(raw_data) + + emb = json_data["emb"] + tag = json_data["tag"] + + assert tag in self.tag_ids + tag_tensor = torch.tensor(self.tag_ids[tag]) + emb_tensor = torch.tensor(emb) + + return emb_tensor, tag_tensor + + def __len__(self): + return len(self.index) + + @classmethod + def export_worker_init_fn(cls): + def worker_init_fn(worker_id: int): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + + if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: + dataset.mmap_obj.close() + + f = open(dataset.load_path, "rb") + dataset.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + return worker_init_fn + + @classmethod + def build( + cls, + midi_dataset_load_path: str, + save_path: str, + slice_len_notes: int, + max_seq_len: int, + metadata_category: str, + tag_ids: dict, + batch_size: int, + embedding_hook: Callable, + **embedding_hook_kwargs, + ): + assert os.path.isfile(midi_dataset_load_path) + assert os.path.isfile(save_path) is False + + with jsonlines.open( + midi_dataset_load_path, "r" + ) as midi_dataset, jsonlines.open(save_path, "w") as writer: + + cnt = 0 + buffer = deque() + write_executor = ThreadPoolExecutor(max_workers=1) + with multiprocessing.Pool( + processes=8, initializer=init_worker + ) as pool: + for result in pool.imap_unordered( + functools.partial( + process_entry, + metadata_category=metadata_category, + tag_ids=tag_ids, + slice_len_notes=slice_len_notes, + max_seq_len=max_seq_len, + ), + midi_dataset, + chunksize=10, + ): + + cnt += 1 + if cnt % 500 == 0: + print(f"Completed {cnt}") + + for entry in result: + buffer.append(entry) + + # Inside your processing loop: + if len(buffer) >= batch_size: + _buffer = [buffer.popleft() for _ in range(batch_size)] + _seqs = [entry["seq"] for entry in _buffer] + _tags = [entry["tag"] for entry in _buffer] + _embs = embedding_hook( + seqs=_seqs, **embedding_hook_kwargs + ) + + # Prepare the write objects + write_objs = [ + {"seq": _seq, "emb": _emb, "tag": _tag} + for _seq, _emb, _tag in zip(_seqs, _embs, _tags) + ] + + write_executor.submit(write_entries, writer, write_objs) + + if buffer: + _seqs = [entry["seq"] for entry in buffer] + _tags = [entry["tag"] for entry in buffer] + _embs = embedding_hook(seqs=_seqs, **embedding_hook_kwargs) + for _seq, _tag, _emb in zip(_seqs, _tags, _embs): + writer.write({"seq": _seq, "emb": _emb, "tag": _tag}) + + +def _get_optim( + lr: float, + model: nn.Module, + total_steps: int, + warmup: int = 100, + end_ratio: int = 0.1, +): + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-5, + ) + + warmup_lrs = 
torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.000001, + end_factor=1, + total_iters=warmup, + ) + linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=total_steps - warmup, + ) + + lr_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_lrs, linear_decay_lrs], + milestones=[warmup], + ) + + return optimizer, lr_scheduler + + +class ClassifierHead(nn.Module): + def __init__(self, d_emb: int, hidden_dim: int, num_class: int): + super().__init__() + self.fc1 = nn.Linear(d_emb, hidden_dim) + self.activation = nn.ReLU() + self.fc2 = nn.Linear(hidden_dim, num_class) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + logits = self.fc2(x) + return logits + + +def _train( + accelerator: accelerate.Accelerator, + model: nn.Module, + train_dataloader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, +): + TRAILING_LOSS_STEPS = 100 + loss = torch.tensor([0.0]) + trailing_loss = 0 + lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) + loss_buffer = [] + + model.train() + loss_fn = nn.CrossEntropyLoss() + + for __step, batch in ( + pbar := tqdm(enumerate(train_dataloader), leave=False) + ): + pbar.set_postfix_str( + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing={round(trailing_loss, 4)}" + ) + + emb, tag_ids = batch + tag_ids = tag_ids.view(-1) + + logits = model(emb) + loss = loss_fn(logits, tag_ids) + + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + + if accelerator.is_main_process: + accelerator.save_state("/mnt/ssd1/aria/test") + + return model + + +def train_classifier( + emb_d: int, + train_dataset: EvaluationDataset, + tag_ids: dict, + batch_size: int, +): + train_dataloader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=8, + worker_init_fn=EvaluationDataset.export_worker_init_fn(), + ) + + model = ClassifierHead( + d_emb=emb_d, + hidden_dim=emb_d, + num_class=len(tag_ids.keys()), + ) + optimizer, scheduler = _get_optim( + lr=3e-4, + model=model, + total_steps=len(train_dataloader), + ) + accelerator = accelerate.Accelerator() + + model, train_dataloader, optimizer, scheduler = accelerator.prepare( + model, + train_dataloader, + optimizer, + scheduler, + ) + + return _train( + accelerator=accelerator, + model=model, + train_dataloader=train_dataloader, + optimizer=optimizer, + scheduler=scheduler, + ) + + +def eval(model: nn.Module): + val_dataset = EvaluationDataset( + load_path="/mnt/ssd1/aria/data/val.jsonl", + tag_ids=TAG_IDS, + ) + model = model.cpu() + + correct = 0 + total = 0 + dist = { + "classical": 0, + "jazz": 0, + "other": 0, + } + for midi_emb, tag_id in val_dataset: + with torch.no_grad(): + logits = model(torch.tensor(midi_emb.view(1, -1))) + probs = F.softmax(logits) + pred_tag_id = probs.argmax(dim=-1).item() + dist[ID_TO_TAG[tag_id.item()]] += 1 + + if ID_TO_TAG[tag_id.item()] == "other": + continue + + if pred_tag_id == tag_id.item(): + correct += 1 + total += 1 + + print(ID_TO_TAG[tag_id.item()], ID_TO_TAG[pred_tag_id]) + input("...") + + print(f"Total accuracy: {correct/total}") 
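+    # NOTE: "other" samples are tallied in dist but skipped for accuracy,
+    # so correct/total measures classical-vs-jazz performance only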
+ print(f"Label distribution: {dist}") + + +if __name__ == "__main__": + tokenizer = AbsTokenizer() + + dataset = EvaluationDataset( + load_path="/mnt/ssd1/aria/data/train.jsonl", + tag_ids=TAG_IDS, + ) + + model = train_classifier( + emb_d=1536, + train_dataset=dataset, + batch_size=32, + tag_ids=TAG_IDS, + ) + eval(model=model) + + # model_state = _load_weight(MODEL_PATH, "cuda") + # model_state = { + # k.replace("_orig_mod.", ""): v for k, v in model_state.items() + # } + # pretrained_model_config = ModelConfig(**load_model_config("medium")) + # pretrained_model_config.set_vocab_size(tokenizer.vocab_size) + # pretrained_model_config.grad_checkpoint = False + # pretrained_model = TransformerLM(pretrained_model_config) + # pretrained_model.load_state_dict(model_state) + # pretrained_model.eval() + + # EvaluationDataset.build( + # midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_train.jsonl", + # save_path="/mnt/ssd1/aria/data/train.jsonl", + # max_seq_len=MAX_SEQ_LEN, + # slice_len_notes=165, + # metadata_category="genre", + # tag_ids=TAG_IDS, + # batch_size=128, + # embedding_hook=functools.partial( + # get_baseline_embedding, pool_mode="mean" + # ), + # hook_model=pretrained_model.model.cuda(), + # hook_max_seq_len=512, + # hook_tokenizer=tokenizer, + # ) diff --git a/aria/model.py b/aria/model.py index b4753f7..5dc14b3 100644 --- a/aria/model.py +++ b/aria/model.py @@ -20,6 +20,8 @@ class ModelConfig: max_seq_len: int grad_checkpoint: bool vocab_size: Optional[int] = None + class_size: Optional[int] = None + tag_to_id: Optional[dict] = None def set_vocab_size(self, vocab_size: int): self.vocab_size = vocab_size @@ -197,6 +199,7 @@ class TransformerLM(nn.Module): def __init__(self, model_config: ModelConfig): super().__init__() + assert model_config.vocab_size is not None self.max_seq_len = model_config.max_seq_len self.model = Transformer(model_config) @@ -228,6 +231,47 @@ def forward( return logits +class TransformerCL(nn.Module): + """Transformer decoder with head for classification. + + Args: + model_config (ModelConfig): Model config settings. + """ + + def __init__(self, model_config: ModelConfig): + super().__init__() + assert model_config.class_size is not None + + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) + self.class_head = nn.Linear( + model_config.d_model, model_config.class_size, bias=False + ) + + def forward( + self, + src: torch.Tensor, + ): + """Forward pass of Transformer decoder with CL head. + + Args: + src (torch.tensor): Input to encoder block, of shape (batch_size, + seq_len, d_model). + attn_mask (Optional[torch.tensor]): Attention mask of shape + (batch_size, seq_len). Defaults to None. + past_kv (Optional[list[KVCache]]): a list of kv caches. The list index + corresponds to the layer index. + + Returns: + torch.tensor: Forward pass of src through Transformer and CL head. + Has shape (batch_size, seq_len, vocab_size). 
+ """ + hidden = self.model(src) + logits = self.class_head(hidden) + + return logits + + def precompute_freqs_cis( seq_len: int, n_elem: int, diff --git a/config/models/medium-genre.json b/config/models/medium-genre.json new file mode 100644 index 0000000..97a4b89 --- /dev/null +++ b/config/models/medium-genre.json @@ -0,0 +1,15 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 3, + "tag_to_id": { + "classical": 0, + "jazz": 1, + "other": 2 + } +} From 7d99bad727e8c8ef819729afa731a56c7610ff95 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 21 Feb 2025 15:01:31 +0000 Subject: [PATCH 09/72] add seq sep option to PretrainingDataset --- aria/datasets.py | 53 ++++++++++++++++++++++++++++++++++++------------ aria/run.py | 6 ++++++ 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/aria/datasets.py b/aria/datasets.py index b1d3b51..01524ab 100644 --- a/aria/datasets.py +++ b/aria/datasets.py @@ -808,6 +808,7 @@ def build( num_epochs: int, midi_dataset: MidiDataset = None, midi_dataset_path: str = None, + separate_sequences: bool = False, ): """Builds and returns PretrainingDataset.""" @@ -822,21 +823,47 @@ def _build_epoch(_save_path, _midi_dataset): } ) - buffer = [] - _idx = 0 - for entry in reservoir(get_seqs(tokenizer, _midi_dataset), 10): - if entry is not None: - buffer += entry - while len(buffer) >= max_seq_len: + if separate_sequences is False: + buffer = [] + _idx = 0 + for entry in reservoir( + get_seqs(tokenizer, _midi_dataset), 10 + ): + if entry is not None: + buffer += entry + while len(buffer) >= max_seq_len: + writer.write(buffer[:max_seq_len]) + buffer = buffer[max_seq_len:] + + _idx += 1 + if _idx % 250 == 0: + logger.info(f"Finished processing {_idx}") + + buffer += [tokenizer.pad_tok] * (max_seq_len - len(buffer)) + writer.write(buffer[:max_seq_len]) + elif separate_sequences is True: + _idx = 0 + for entry in reservoir( + get_seqs(tokenizer, _midi_dataset), 10 + ): + if entry is None: + continue + + buffer = entry + while len(buffer) >= max_seq_len: + writer.write(buffer[:max_seq_len]) + buffer = buffer[max_seq_len:] + + buffer += [tokenizer.pad_tok] * ( + max_seq_len - len(buffer) + ) writer.write(buffer[:max_seq_len]) - buffer = buffer[max_seq_len:] - - _idx += 1 - if _idx % 250 == 0: - logger.info(f"Finished processing {_idx}") - buffer += [tokenizer.pad_tok] * (max_seq_len - len(buffer)) - writer.write(buffer[:max_seq_len]) + _idx += 1 + if _idx % 250 == 0: + logger.info(f"Finished processing {_idx}") + else: + raise ValueError logger = setup_logger() assert max_seq_len > 0, "max_seq_len must be greater than 0" diff --git a/aria/run.py b/aria/run.py index 182d1f5..a1302f9 100644 --- a/aria/run.py +++ b/aria/run.py @@ -252,6 +252,11 @@ def _parse_pretrain_dataset_args(): ) argp.add_argument("-l", help="max sequence length", type=int, default=4096) argp.add_argument("-e", help="num epochs", type=int, default=1) + argp.add_argument( + "-sep_sequences", + help="start each with a new entry", + action="store_true", + ) return argp.parse_args(sys.argv[2:]) @@ -271,6 +276,7 @@ def build_pretraining_dataset(args): max_seq_len=args.l, num_epochs=args.e, midi_dataset_path=args.load_path, + separate_sequences=args.sep_sequences, ) From 014f0b9ab713b7e55b63fdb315ddf2bb939387ff Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 21 Feb 2025 18:14:54 +0000 Subject: [PATCH 10/72] change from genre to composer --- aria/embeddings/classifier_finetune.py | 59 
+++++++++++++++++--------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/aria/embeddings/classifier_finetune.py b/aria/embeddings/classifier_finetune.py index 7ff7330..ac4b68a 100644 --- a/aria/embeddings/classifier_finetune.py +++ b/aria/embeddings/classifier_finetune.py @@ -24,6 +24,28 @@ from logging.handlers import RotatingFileHandler from tqdm import tqdm +TAG_IDS = { + "chopin": 0, + "bach": 1, + "beethoven": 2, + "liszt": 3, + "mozart": 4, + "debussy": 5, + "schumann": 6, + "schubert": 7, + "rachmaninoff": 8, + "brahms": 9, + "tchaikovsky": 10, + "haydn": 11, + "scriabin": 12, + "mendelssohn": 13, + "czerny": 14, + "ravel": 15, + "scarlatti": 16, + "other": 17, +} +METADATA_CATEGORY = "composer" + def setup_logger(project_dir: str): # Get logger and reset all handlers @@ -327,7 +349,6 @@ def get_dataloaders( num_workers: int, apply_aug=True, ): - TAG_IDS = {"classical": 0, "jazz": 1, "other": 2} train_dataset = FinetuningDataset( load_path=train_data_path, tag_ids=TAG_IDS, @@ -601,8 +622,8 @@ def test_build_dataset(): min_slice_notes=100, max_slice_notes=165, max_seq_len=512, - metadata_category="genre", - tag_ids={"classical": 0, "jazz": 1, "other": 2}, + metadata_category=METADATA_CATEGORY, + tag_ids=TAG_IDS, ) # FinetuningDataset.build( @@ -611,15 +632,15 @@ def test_build_dataset(): # min_slice_notes=100, # max_slice_notes=165, # max_seq_len=512, - # metadata_category="genre", - # tag_ids={"classical": 0, "jazz": 1, "other": 2}, + # metadata_category=METADATA_CATEGORY, + # tag_ids=TAG_IDS, # ) def test_dataset(): dataset = FinetuningDataset( load_path="/mnt/ssd1/aria/data/test.jsonl", - tag_ids={"classical": 0, "jazz": 1, "other": 2}, + tag_ids=TAG_IDS, ) for idx, entry in enumerate(dataset): @@ -645,18 +666,18 @@ def parse_args(): if __name__ == "__main__": - args = parse_args() - train( - model_name=args.model_name, - checkpoint_path=args.checkpoint_path, - train_data_path=args.train_data_path, - val_data_path=args.val_data_path, - batch_size=args.batch_size, - num_epochs=args.num_epochs, - num_workers=args.num_workers, - grad_acc_steps=args.grad_acc_steps, - project_dir=args.project_dir, - ) + # args = parse_args() + # train( + # model_name=args.model_name, + # checkpoint_path=args.checkpoint_path, + # train_data_path=args.train_data_path, + # val_data_path=args.val_data_path, + # batch_size=args.batch_size, + # num_epochs=args.num_epochs, + # num_workers=args.num_workers, + # grad_acc_steps=args.grad_acc_steps, + # project_dir=args.project_dir, + # ) - # test_build_dataset() + test_build_dataset() # test_dataset() From 6b53ba4668308fe09186361555fc065825b566c8 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 24 Feb 2025 14:37:40 +0000 Subject: [PATCH 11/72] update emb eval scripts --- aria/embeddings/eval.py | 110 ++++++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 45 deletions(-) diff --git a/aria/embeddings/eval.py b/aria/embeddings/eval.py index 014a59f..119c9c4 100644 --- a/aria/embeddings/eval.py +++ b/aria/embeddings/eval.py @@ -2,7 +2,6 @@ import accelerate import os import mmap -import time import json import functools import multiprocessing @@ -16,7 +15,7 @@ from typing import Callable from concurrent.futures import ThreadPoolExecutor -from aria.model import ModelConfig, TransformerLM +from aria.model import ModelConfig, TransformerLM, TransformerCL from aria.config import load_model_config from aria.utils import _load_weight from ariautils.midi import MidiDict @@ -24,8 +23,28 @@ MODEL_PATH = 
"/mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors" MAX_SEQ_LEN = 512 -TAG_IDS = {"classical": 0, "jazz": 1, "other": 2} -ID_TO_TAG = {v: k for k, v in TAG_IDS.items()} +TAG_TO_ID = { + "chopin": 0, + "bach": 1, + "beethoven": 2, + "liszt": 3, + "mozart": 4, + "debussy": 5, + "schumann": 6, + "schubert": 7, + "rachmaninoff": 8, + "brahms": 9, + "tchaikovsky": 10, + "haydn": 11, + "scriabin": 12, + "mendelssohn": 13, + "czerny": 14, + "ravel": 15, + "scarlatti": 16, + "other": 17, +} +ID_TO_TAG = {v: k for k, v in TAG_TO_ID.items()} +METADATA_CATEGORY = "composer" def chunk_and_pad(lst: list, n: int): @@ -385,20 +404,17 @@ def train_classifier( ) -def eval(model: nn.Module): +def evaluate_model(model: nn.Module, val_dataset_path: str): val_dataset = EvaluationDataset( - load_path="/mnt/ssd1/aria/data/val.jsonl", - tag_ids=TAG_IDS, + load_path=val_dataset_path, + tag_ids=TAG_TO_ID, ) model = model.cpu() correct = 0 total = 0 - dist = { - "classical": 0, - "jazz": 0, - "other": 0, - } + dist = {k: 0 for k in TAG_TO_ID.keys()} + for midi_emb, tag_id in val_dataset: with torch.no_grad(): logits = model(torch.tensor(midi_emb.view(1, -1))) @@ -413,52 +429,56 @@ def eval(model: nn.Module): correct += 1 total += 1 - print(ID_TO_TAG[tag_id.item()], ID_TO_TAG[pred_tag_id]) - input("...") - print(f"Total accuracy: {correct/total}") print(f"Label distribution: {dist}") +def build_dataset(): + MODEL_PATH = "/mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors" + + model_state = _load_weight(MODEL_PATH, "cuda") + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + pretrained_model_config = ModelConfig(**load_model_config("medium")) + pretrained_model_config.set_vocab_size(tokenizer.vocab_size) + pretrained_model_config.grad_checkpoint = False + pretrained_model = TransformerLM(pretrained_model_config) + pretrained_model.load_state_dict(model_state) + pretrained_model.eval() + + EvaluationDataset.build( + midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_train.jsonl", + save_path="/mnt/ssd1/aria/data/train.jsonl", + max_seq_len=MAX_SEQ_LEN, + slice_len_notes=165, + metadata_category="genre", + tag_ids=TAG_TO_ID, + batch_size=128, + embedding_hook=functools.partial( + get_baseline_embedding, pool_mode="mean" + ), + hook_model=pretrained_model.model.cuda(), + hook_max_seq_len=512, + hook_tokenizer=tokenizer, + ) + + if __name__ == "__main__": tokenizer = AbsTokenizer() dataset = EvaluationDataset( load_path="/mnt/ssd1/aria/data/train.jsonl", - tag_ids=TAG_IDS, + tag_ids=TAG_TO_ID, ) model = train_classifier( emb_d=1536, train_dataset=dataset, batch_size=32, - tag_ids=TAG_IDS, + tag_ids=TAG_TO_ID, + ) + evaluate_model( + model=model, + val_dataset_path="/mnt/ssd1/aria/data/val.jsonl", ) - eval(model=model) - - # model_state = _load_weight(MODEL_PATH, "cuda") - # model_state = { - # k.replace("_orig_mod.", ""): v for k, v in model_state.items() - # } - # pretrained_model_config = ModelConfig(**load_model_config("medium")) - # pretrained_model_config.set_vocab_size(tokenizer.vocab_size) - # pretrained_model_config.grad_checkpoint = False - # pretrained_model = TransformerLM(pretrained_model_config) - # pretrained_model.load_state_dict(model_state) - # pretrained_model.eval() - - # EvaluationDataset.build( - # midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_train.jsonl", - # save_path="/mnt/ssd1/aria/data/train.jsonl", - # max_seq_len=MAX_SEQ_LEN, - # slice_len_notes=165, - # 
metadata_category="genre", - # tag_ids=TAG_IDS, - # batch_size=128, - # embedding_hook=functools.partial( - # get_baseline_embedding, pool_mode="mean" - # ), - # hook_model=pretrained_model.model.cuda(), - # hook_max_seq_len=512, - # hook_tokenizer=tokenizer, - # ) From e6c1e2a799d04995e6e42ed6343c5ffa97d843ba Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 24 Feb 2025 14:37:58 +0000 Subject: [PATCH 12/72] add explore script --- aria/embeddings/explore_midi.py | 112 ++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 aria/embeddings/explore_midi.py diff --git a/aria/embeddings/explore_midi.py b/aria/embeddings/explore_midi.py new file mode 100644 index 0000000..13d38a0 --- /dev/null +++ b/aria/embeddings/explore_midi.py @@ -0,0 +1,112 @@ +import copy +import torch + +from aria.config import load_model_config +from aria.utils import _load_weight +from ariautils.midi import MidiDict +from ariautils.tokenizer import AbsTokenizer +from aria.model import TransformerCL, ModelConfig + +TAG_IDS = { + "chopin": 0, + "bach": 1, + "beethoven": 2, + "liszt": 3, + "mozart": 4, + "debussy": 5, + "schumann": 6, + "schubert": 7, + "rachmaninoff": 8, + "brahms": 9, + "tchaikovsky": 10, + "haydn": 11, + "scriabin": 12, + "mendelssohn": 13, + "czerny": 14, + "ravel": 15, + "scarlatti": 16, + "other": 17, +} +ID_TO_TAG = {v: k for k, v in TAG_IDS.items()} + + +def explore_midi( + midi_path: str, + checkpoint_path: str, + metadata_category: str, + slice_len_notes: int = 500, + max_seq_len: int = 2048, +): + midi_dict = MidiDict.from_midi(midi_path) + print(midi_dict.instrument_msgs) + + tag = midi_dict.metadata.get(metadata_category, None) + if tag is not None and tag not in TAG_IDS: + tag = "other" + + note_msgs = midi_dict.note_msgs + slices = [ + note_msgs[i : i + slice_len_notes] + for i in range(0, len(note_msgs), slice_len_notes) + ] + slices = [s for s in slices if len(s) >= 20] + + print(f"Found {len(slices)} slices in the MIDI file.") + + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config("medium-composer")) + model_config.set_vocab_size(tokenizer.vocab_size) + model_config.grad_checkpoint = False + model_state = _load_weight(checkpoint_path, device="cuda") + model = TransformerCL(model_config) + model.load_state_dict(model_state) + model.eval() + model.cuda() + + for idx, note_slice in enumerate(slices): + slice_midi = copy.deepcopy(midi_dict) + slice_midi.note_msgs = note_slice + slice_midi.metadata = {} + + tokenized_seq = tokenizer.tokenize(slice_midi) + tokenizer.detokenize(tokenized_seq).to_midi().save( + "/home/loubb/Dropbox/shared/test.mid" + ) + if tokenizer.dim_tok in tokenized_seq: + tokenized_seq.remove(tokenizer.dim_tok) + tokenized_seq = tokenized_seq[:max_seq_len] + if tokenizer.eos_tok not in tokenized_seq: + tokenized_seq[-1] = tokenizer.eos_tok + + tokenizer + encoded_seq = tokenizer.encode(tokenized_seq) + input_tensor = torch.tensor([encoded_seq]).cuda() + + # Forward pass + with torch.inference_mode(): + logits = model(input_tensor)[0, -1, :] + probs = torch.softmax(logits, dim=-1) + # Get the top 5 probabilities and their corresponding indices + top_probs, top_indices = torch.topk(probs, k=5) + formatted_top_probs = [ + float(f"{p:.4f}") for p in top_probs.tolist() + ] + top_tags = [ + ID_TO_TAG.get(idx.item(), "unknown") for idx in top_indices + ] + + print("Top 5 Predictions:") + for tag, prob in zip(top_tags, formatted_top_probs): + print(f"{tag}: {prob}") + + input("\nPress Enter to continue to the next slice...") + + 
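The per-slice readout above is a plain softmax-then-top-k. A self-contained sketch of the same pattern, with a made-up four-tag space standing in for the real 18-tag `ID_TO_TAG` map:

```python
import torch

id_to_tag = {0: "chopin", 1: "bach", 2: "liszt", 3: "other"}  # toy mapping
logits = torch.tensor([1.5, 0.2, -0.3, 0.8])
probs = torch.softmax(logits, dim=-1)
top_probs, top_ids = torch.topk(probs, k=2)
for i, p in zip(top_ids.tolist(), top_probs.tolist()):
    print(f"{id_to_tag[i]}: {p:.4f}")  # highest-probability tags first
```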
+if __name__ == "__main__": + explore_midi( + midi_path="/home/loubb/Dropbox/shared/audio.mid", + checkpoint_path="/home/loubb/work/aria/models/medium-composer.safetensors", + metadata_category="composer", + slice_len_notes=150, + max_seq_len=512, + ) From 1a94cb44770aea459d8c1b95985a2942d6564db9 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 24 Feb 2025 18:33:57 +0000 Subject: [PATCH 13/72] add contrastive ft --- aria/embeddings/contrastive_finetune.py | 666 +++++++++++++++++++++++ aria/embeddings/finetune_contrastive.py | 683 ++++++++++++++++++++++++ 2 files changed, 1349 insertions(+) create mode 100644 aria/embeddings/contrastive_finetune.py create mode 100644 aria/embeddings/finetune_contrastive.py diff --git a/aria/embeddings/contrastive_finetune.py b/aria/embeddings/contrastive_finetune.py new file mode 100644 index 0000000..2df6552 --- /dev/null +++ b/aria/embeddings/contrastive_finetune.py @@ -0,0 +1,666 @@ +import torch +import os +import mmap +import argparse +import logging +import random +import copy +import accelerate +import json + +from aria.config import load_model_config +from aria.utils import _load_weight +from ariautils.tokenizer import AbsTokenizer +from ariautils.midi import MidiDict +from aria.model import TransformerEMB, ModelConfig + +from torch import nn +from torch.nn import functional as F +from torch.utils.data import DataLoader, Dataset + +from accelerate.logging import get_logger +from logging.handlers import RotatingFileHandler +from tqdm import tqdm + + +def setup_logger(project_dir: str): + # Get logger and reset all handlers + logger = logging.getLogger(__name__) + for h in logger.handlers[:]: + logger.removeHandler(h) + + logger.propagate = False + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", + ) + fh = RotatingFileHandler( + os.path.join(project_dir, "logs.txt"), backupCount=5, maxBytes=1024**3 + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return get_logger(__name__) + + +def setup_project_dir(project_dir: str | None): + if not project_dir: + # Create project directory + if not os.path.isdir("./experiments"): + os.mkdir("./experiments") + + project_dirs = [ + _dir + for _dir in os.listdir("./experiments") + if os.path.isdir(os.path.join("experiments", _dir)) + ] + + ind = 0 + while True: + if str(ind) not in project_dirs: + break + else: + ind += 1 + + project_dir_abs = os.path.abspath(os.path.join("experiments", str(ind))) + assert not os.path.isdir(project_dir_abs) + os.mkdir(project_dir_abs) + + elif project_dir: + if os.path.isdir(project_dir): + assert ( + len(os.listdir(project_dir)) == 0 + ), "Provided project directory is not empty" + project_dir_abs = os.path.abspath(project_dir) + elif os.path.isfile(project_dir): + raise FileExistsError( + "The provided path points toward an existing file" + ) + else: + try: + os.mkdir(project_dir) + except Exception as e: + raise Exception( + f"Failed to create project directory at {project_dir}" + ) from e + + project_dir_abs = os.path.abspath(project_dir) + + os.mkdir(os.path.join(project_dir_abs, "checkpoints")) + + return project_dir_abs + + +class ContrastiveDataset(Dataset): + def __init__( + self, + load_path: str, + min_number_slice_notes: int, + max_number_slice_notes: int, + max_seq_len: int, + ): + self.load_path = load_path + self.min_number_slice_notes = 
min_number_slice_notes + self.max_number_slice_notes = max_number_slice_notes + self.max_seq_len = max_seq_len + + self.tokenizer = AbsTokenizer() + self.index = [] + + self.file_buff = open(self.load_path, "rb") + self.mmap_obj = mmap.mmap( + self.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + while True: + pos = self.mmap_obj.tell() + line = self.mmap_obj.readline() + if not line: + break + self.index.append(pos) + + def get_slice( + self, + midi_dict: MidiDict, + min_num_notes: int, + max_num_notes: int, + max_seq_len: int, + ): + _midi_dict = copy.deepcopy(midi_dict) + slice_length = random.randint(min_num_notes, max_num_notes) + idx = random.randint(0, len(_midi_dict.note_msgs) - min_num_notes) + + _midi_dict.note_msgs = _midi_dict.note_msgs[idx : idx + slice_length] + _midi_dict.metadata = {} + + tokenized_slice = self.tokenizer.tokenize(_midi_dict) + if self.tokenizer.dim_tok in tokenized_slice: + tokenized_slice.remove(self.tokenizer.dim_tok) + + # Use EOS tok for classification head + tokenized_slice = tokenized_slice[:max_seq_len] + tokenized_slice += [self.tokenizer.pad_tok] * ( + max_seq_len - len(tokenized_slice) + ) + if self.tokenizer.eos_tok not in tokenized_slice: + tokenized_slice[-1] = self.tokenizer.eos_tok + + pos = tokenized_slice.index(self.tokenizer.eos_tok) + + return tokenized_slice, pos + + def __getitem__(self, idx: int): + def _format(tok): + # Required because json formats tuples into lists + if isinstance(tok, list): + return tuple(tok) + return tok + + file_pos = self.index[idx] + self.mmap_obj.seek(file_pos) + + raw_data = self.mmap_obj.readline().decode("utf-8") + json_data = json.loads(raw_data) + midi_dict = MidiDict.from_msg_dict(json_data) + + slice_seq_1, slice_pos_1 = self.get_slice( + midi_dict=midi_dict, + min_num_notes=self.min_number_slice_notes, + max_num_notes=self.max_number_slice_notes, + max_seq_len=self.max_seq_len, + ) + slice_seq_2, slice_pos_2 = self.get_slice( + midi_dict=midi_dict, + min_num_notes=self.min_number_slice_notes, + max_num_notes=self.max_number_slice_notes, + max_seq_len=self.max_seq_len, + ) + + slice_seq_1 = [_format(tok) for tok in slice_seq_1] + slice_seq_2 = [_format(tok) for tok in slice_seq_2] + + assert len(slice_seq_1) <= self.max_seq_len + assert len(slice_seq_2) <= self.max_seq_len + assert slice_pos_1 < self.max_seq_len + assert slice_pos_2 < self.max_seq_len + assert slice_seq_1[slice_pos_1] == self.tokenizer.eos_tok + assert slice_seq_2[slice_pos_2] == self.tokenizer.eos_tok + + slices_enc = torch.tensor( + [ + self.tokenizer.encode(slice_seq_1), + self.tokenizer.encode(slice_seq_2), + ] + ) + + slices_pos = torch.tensor([slice_pos_1, slice_pos_2]) + + return slices_enc, slices_pos + + def __len__(self): + return len(self.index) + + @classmethod + def export_worker_init_fn(cls): + def worker_init_fn(worker_id: int): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + + if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: + dataset.mmap_obj.close() + + dataset.file_buff = open(dataset.load_path, "rb") + dataset.mmap_obj = mmap.mmap( + dataset.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + return worker_init_fn + + +def _get_optim( + lr: float, + model: nn.Module, + num_epochs: int, + steps_per_epoch: int, + warmup: int = 100, + end_ratio: int = 0.1, +): + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-5, + ) + + warmup_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.000001, 
+        end_factor=1,
+        total_iters=warmup,
+    )
+    linear_decay_lrs = torch.optim.lr_scheduler.LinearLR(
+        optimizer,
+        start_factor=1,
+        end_factor=end_ratio,
+        total_iters=(num_epochs * steps_per_epoch) - warmup,
+    )
+
+    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+        optimizer,
+        schedulers=[warmup_lrs, linear_decay_lrs],
+        milestones=[warmup],
+    )
+
+    return optimizer, lr_scheduler
+
+
+def get_optim(
+    model: nn.Module,
+    num_epochs: int,
+    steps_per_epoch: int,
+):
+    LR = 1e-5
+    END_RATIO = 0.1
+    WARMUP_STEPS = 1000
+
+    return _get_optim(
+        lr=LR,
+        model=model,
+        num_epochs=num_epochs,
+        steps_per_epoch=steps_per_epoch,
+        warmup=WARMUP_STEPS,
+        end_ratio=END_RATIO,
+    )
+
+
+def get_dataloaders(
+    train_data_path: str,
+    val_data_path: str,
+    batch_size: int,
+    num_workers: int,
+    min_number_slice_notes: int = 100,
+    max_number_slice_notes: int = 300,
+    max_seq_len: int = 1024,
+):
+    train_dataset = ContrastiveDataset(
+        load_path=train_data_path,
+        min_number_slice_notes=min_number_slice_notes,
+        max_number_slice_notes=max_number_slice_notes,
+        max_seq_len=max_seq_len,
+    )
+    val_dataset = ContrastiveDataset(
+        load_path=val_data_path,
+        min_number_slice_notes=min_number_slice_notes,
+        max_number_slice_notes=max_number_slice_notes,
+        max_seq_len=max_seq_len,
+    )
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers,
+        worker_init_fn=ContrastiveDataset.export_worker_init_fn(),
+    )
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        worker_init_fn=ContrastiveDataset.export_worker_init_fn(),
+    )
+
+    return train_loader, val_loader
+
+
+# TODO: This might not be 100% correct (verify CEL calculation)
+def symmetric_nt_xent_loss_cosine(
+    z1: torch.Tensor, z2: torch.Tensor, temperature=0.5
+):
+    bsz = z1.shape[0]
+
+    z1 = F.normalize(z1, dim=1)  # First view
+    z2 = F.normalize(z2, dim=1)  # Second view
+
+    sim_matrix = (
+        F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1)
+        / temperature
+    )
+
+    labels = torch.arange(bsz, device=z1.device)
+
+    loss1 = F.cross_entropy(sim_matrix, labels)
+    loss2 = F.cross_entropy(sim_matrix.T, labels)
+
+    return (loss1 + loss2) / 2.0
+
+
+def _train(
+    num_epochs: int,
+    accelerator: accelerate.Accelerator,
+    model: TransformerEMB,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    optimizer: torch.optim.Optimizer,
+    scheduler: torch.optim.lr_scheduler.LRScheduler = None,
+    project_dir: str | None = None,
+):
+    def make_checkpoint(
+        _accelerator: accelerate.Accelerator, _epoch: int, _step: int
+    ):
+        if accelerator.is_main_process:
+            checkpoint_dir = os.path.join(
+                project_dir,
+                "checkpoints",
+                f"epoch{_epoch}_step{_step}",
+            )
+
+            logger.info(
+                f"EPOCH {_epoch}/{num_epochs}: Saving checkpoint - {checkpoint_dir}"
+            )
+            _accelerator.save_state(checkpoint_dir)
+
+    def train_loop(
+        dataloader: DataLoader, _epoch: int, steps_per_checkpoint: int
+    ):
+        loss = torch.tensor([0.0])
+        avg_train_loss = 0
+        trailing_loss = 0
+        loss_buffer = []
+
+        # Fall back to the optimizer's lr if the scheduler cannot report one
+        try:
+            lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0])
+        except Exception:
+            lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])
+
+        model.train()
+        for __step, batch in (
+            pbar := tqdm(
+                enumerate(dataloader),
+                total=len(dataloader),
+                initial=0,
+                leave=False,
+            )
+        ):
+            pbar.set_postfix_str(
+                f"lr={lr_for_print}, "
+                f"loss={round(loss.item(), 4)}, "
+                f"trailing={round(trailing_loss, 4)}"
+            )
+
+            with 
accelerator.accumulate(model): + step = __step + 1 + seqs, eos_pos = batch + + seqs = seqs.contiguous() + bsz = seqs.size(0) + seqs_flat = seqs.view(2 * bsz, seqs.size(-1)) + + outputs = model(seqs_flat) + z1_full = outputs[0::2] + z2_full = outputs[1::2] + + batch_indices = torch.arange(bsz, device=z1_full.device) + eos_pos_1 = eos_pos[:, 0] + eos_pos_2 = eos_pos[:, 1] + + z1 = z1_full[batch_indices, eos_pos_1] + z2 = z2_full[batch_indices, eos_pos_2] + + # seqs_1 = seqs[:, 0, :] + # eos_pos_1 = eos_pos[:, 0] + # z1_full = model(seqs_1) + # z1 = z1_full[ + # torch.arange(z1_full.shape[0], device=z1_full.device), + # eos_pos_1, + # ] + + # seqs_2 = seqs[:, 1, :] + # eos_pos_2 = eos_pos[:, 1] + # z2_full = model(seqs_2) + # z2 = z2_full[ + # torch.arange(z2_full.shape[0], device=z2_full.device), + # eos_pos_2, + # ] + + #### + + loss = symmetric_nt_xent_loss_cosine(z1, z2) + + # Calculate statistics + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + avg_train_loss = sum(loss_buffer) / len(loss_buffer) + + # Logging + logger.debug( + f"EPOCH {_epoch} STEP {step}: " + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing_loss={round(trailing_loss, 4)}, " + f"average_loss={round(avg_train_loss, 4)}" + ) + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + + if steps_per_checkpoint: + if step % steps_per_checkpoint == 0: + make_checkpoint( + _accelerator=accelerator, + _epoch=_epoch, + _step=step, + ) + + return avg_train_loss + + def val_loop(dataloader: DataLoader, _epoch: int): + model.eval() + val_loss_buffer = [] + + with torch.no_grad(): + pbar = tqdm( + dataloader, desc=f"Validation Epoch {_epoch}", leave=False + ) + for batch in pbar: + seqs, eos_pos = batch + + seqs = seqs.contiguous() + bsz = seqs.size(0) + seqs_flat = seqs.view(2 * bsz, seqs.size(-1)) + + outputs = model(seqs_flat) + z1_full = outputs[0::2] + z2_full = outputs[1::2] + + batch_indices = torch.arange(bsz, device=z1_full.device) + eos_pos_1 = eos_pos[:, 0] + eos_pos_2 = eos_pos[:, 1] + + z1 = z1_full[batch_indices, eos_pos_1] + z2 = z2_full[batch_indices, eos_pos_2] + + loss = symmetric_nt_xent_loss_cosine(z1, z2) + # Gather loss from all devices (if applicable) + val_loss_buffer.append( + accelerator.gather(loss).mean(dim=0).item() + ) + + current_avg_loss = sum(val_loss_buffer) / len(val_loss_buffer) + + pbar.set_postfix_str(f"avg_loss={round(current_avg_loss,4)}") + + avg_val_loss = sum(val_loss_buffer) / len(val_loss_buffer) + + logger.info( + f"Validation Epoch {_epoch}: average_loss={round(avg_val_loss, 4)}" + ) + return avg_val_loss + + logger = get_logger(__name__) + TRAILING_LOSS_STEPS = 100 + + train_loop(dataloader=train_dataloader, _epoch=0, steps_per_checkpoint=2000) + make_checkpoint(_accelerator=accelerator, _epoch=1, _step=0) + val_loop(dataloader=val_dataloader, _epoch=0) + + +def train( + model_name: str, + train_data_path: str, + val_data_path: str, + num_workers: int, + num_epochs: int, + batch_size: int, + grad_acc_steps: int, + project_dir: str | None = None, + checkpoint_path: str | None = None, +): + accelerator = accelerate.Accelerator( + project_dir=project_dir, + gradient_accumulation_steps=grad_acc_steps, + ) + + if accelerator.is_main_process: + project_dir = setup_project_dir(project_dir) + logger = 
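The 0::2 / 1::2 slicing in this loop relies on view() interleaving the two slices of each item when (bsz, 2, s_len) is flattened to (2*bsz, s_len). A small shape check of that invariant, with dummy sizes:

import torch

bsz, s_len = 4, 6
seqs = torch.arange(bsz * 2 * s_len).view(bsz, 2, s_len)
flat = seqs.view(2 * bsz, s_len)  # rows: item0-view0, item0-view1, item1-view0, ...
assert torch.equal(flat[0::2], seqs[:, 0, :])  # every first view
assert torch.equal(flat[1::2], seqs[:, 1, :])  # every second view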
setup_logger(os.path.join(project_dir)) + else: + # In other processes, we won't create logs + project_dir = project_dir or "./experiments" + logger = get_logger(__name__) + + logger.info(f"Project directory: {project_dir}") + logger.info( + f"Training config: epochs={num_epochs}, batch_size={batch_size}, num_workers={num_workers}" + ) + + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config(model_name)) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerEMB(model_config) + + if checkpoint_path is not None: + logger.info(f"Loading checkpoint from {checkpoint_path}") + model_state = _load_weight(checkpoint_path) + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + if "lm_head.weight" in model_state.keys(): + del model_state["lm_head.weight"] + + model_state = { + k.replace("model.", ""): v for k, v in model_state.items() + } + model.model.load_state_dict(model_state) + else: + logger.info("No checkpoint path provided") + + train_dataloader, val_dataloader = get_dataloaders( + train_data_path=train_data_path, + val_data_path=val_data_path, + batch_size=batch_size, + num_workers=num_workers, + ) + + optimizer, scheduler = get_optim( + model=model, + num_epochs=num_epochs, + steps_per_epoch=len(train_dataloader), + ) + + ( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) = accelerator.prepare( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) + + _train( + num_epochs=num_epochs, + accelerator=accelerator, + model=model, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + optimizer=optimizer, + scheduler=scheduler, + project_dir=project_dir, + ) + + +def test_dataset(): + tokenizer = AbsTokenizer() + dataset = ContrastiveDataset( + load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl", + min_number_slice_notes=150, + max_number_slice_notes=300, + max_seq_len=1024, + ) + + for idx, (enc, pos) in enumerate(dataset): + seq_1 = enc[0].tolist() + midi_dict_1 = tokenizer.detokenize(tokenizer.decode(seq_1)) + midi_dict_1.to_midi().save("/home/loubb/Dropbox/shared/test1.mid") + + seq_2 = enc[1].tolist() + midi_dict_2 = tokenizer.detokenize(tokenizer.decode(seq_2)) + midi_dict_2.to_midi().save("/home/loubb/Dropbox/shared/test2.mid") + + print(enc.shape) + print(pos.shape, pos) + input("") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Finetune a model contrastive_embeddings" + ) + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--checkpoint_path", type=str, default=None) + parser.add_argument("--train_data_path", type=str, required=True) + parser.add_argument("--val_data_path", type=str, required=True) + parser.add_argument("--batch_size", type=int) + parser.add_argument("--num_epochs", type=int) + parser.add_argument("--num_workers", type=int, default=1) + parser.add_argument("--grad_acc_steps", type=int, default=1) + parser.add_argument("--project_dir", type=str, default=None) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + train( + model_name=args.model_name, + checkpoint_path=args.checkpoint_path, + train_data_path=args.train_data_path, + val_data_path=args.val_data_path, + batch_size=args.batch_size, + num_epochs=args.num_epochs, + num_workers=args.num_workers, + grad_acc_steps=args.grad_acc_steps, + project_dir=args.project_dir, + ) + + # test_dataset() diff --git a/aria/embeddings/finetune_contrastive.py 
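The checkpoint loading just above strips the "_orig_mod." prefix that torch.compile adds to parameter names, drops the pretrained LM head, and unwraps the "model." prefix. A toy version of that remapping (the state dict is a stand-in, not a real checkpoint; note that passing count=1 to str.replace is slightly safer than a global replace, since it cannot touch a later "model." substring inside a key):

state = {
    "_orig_mod.model.tok_embeddings.weight": "w0",
    "_orig_mod.lm_head.weight": "w1",
}
state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
state.pop("lm_head.weight", None)  # finetuning attaches its own head
state = {k.replace("model.", "", 1): v for k, v in state.items()}
print(state)  # {'tok_embeddings.weight': 'w0'}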
b/aria/embeddings/finetune_contrastive.py new file mode 100644 index 0000000..3a314a4 --- /dev/null +++ b/aria/embeddings/finetune_contrastive.py @@ -0,0 +1,683 @@ +import torch +import os +import mmap +import argparse +import logging +import random +import copy +import functools +import accelerate +import multiprocessing +import json +import jsonlines + +from aria.config import load_model_config +from aria.utils import _load_weight +from ariautils.tokenizer import AbsTokenizer +from ariautils.midi import MidiDict +from aria.model import TransformerCL, ModelConfig + +from torch import nn +from torch.utils.data import DataLoader, Dataset + +from accelerate.logging import get_logger +from logging.handlers import RotatingFileHandler +from tqdm import tqdm + +TAG_IDS = { + "chopin": 0, + "bach": 1, + "beethoven": 2, + "liszt": 3, + "mozart": 4, + "debussy": 5, + "schumann": 6, + "schubert": 7, + "rachmaninoff": 8, + "brahms": 9, + "tchaikovsky": 10, + "haydn": 11, + "scriabin": 12, + "mendelssohn": 13, + "czerny": 14, + "ravel": 15, + "scarlatti": 16, + "other": 17, +} +METADATA_CATEGORY = "composer" + + +def setup_logger(project_dir: str): + # Get logger and reset all handlers + logger = logging.getLogger(__name__) + for h in logger.handlers[:]: + logger.removeHandler(h) + + logger.propagate = False + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", + ) + fh = RotatingFileHandler( + os.path.join(project_dir, "logs.txt"), backupCount=5, maxBytes=1024**3 + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return get_logger(__name__) + + +def setup_project_dir(project_dir: str | None): + if not project_dir: + # Create project directory + if not os.path.isdir("./experiments"): + os.mkdir("./experiments") + + project_dirs = [ + _dir + for _dir in os.listdir("./experiments") + if os.path.isdir(os.path.join("experiments", _dir)) + ] + + ind = 0 + while True: + if str(ind) not in project_dirs: + break + else: + ind += 1 + + project_dir_abs = os.path.abspath(os.path.join("experiments", str(ind))) + assert not os.path.isdir(project_dir_abs) + os.mkdir(project_dir_abs) + + elif project_dir: + if os.path.isdir(project_dir): + assert ( + len(os.listdir(project_dir)) == 0 + ), "Provided project directory is not empty" + project_dir_abs = os.path.abspath(project_dir) + elif os.path.isfile(project_dir): + raise FileExistsError( + "The provided path points toward an existing file" + ) + else: + try: + os.mkdir(project_dir) + except Exception as e: + raise Exception( + f"Failed to create project directory at {project_dir}" + ) from e + + project_dir_abs = os.path.abspath(project_dir) + + os.mkdir(os.path.join(project_dir_abs, "checkpoints")) + + return project_dir_abs + + +def process_entry( + entry, + metadata_category: str, + tag_ids: dict, + min_slice_notes: int, + max_slice_notes: int, + max_seq_len: int, + tokenizer: AbsTokenizer, +): + midi_dict = MidiDict.from_msg_dict(entry) + metadata_tag = midi_dict.metadata.get(metadata_category, None) + + # Skip if metadata tag is missing or not in tag_ids. 
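process_entry below repeats the pad/EOS bookkeeping used by ContrastiveDataset.get_slice: truncate to max_seq_len, right-pad, force an EOS token back in if truncation dropped it, and record its index for the classification head. Condensed into a standalone helper (the token strings are placeholders for AbsTokenizer's real pad/eos tokens):

def pad_and_mark(seq, max_seq_len, pad_tok="<P>", eos_tok="<E>"):
    seq = seq[:max_seq_len]
    seq = seq + [pad_tok] * (max_seq_len - len(seq))
    if eos_tok not in seq:
        seq[-1] = eos_tok  # truncation removed it; re-add at the end
    return seq, seq.index(eos_tok)

seq, pos = pad_and_mark(["a", "b", "c", "<E>"], 8)
print(seq, pos)  # ['a', 'b', 'c', '<E>', '<P>', '<P>', '<P>', '<P>'] 3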
+ if metadata_tag is None: + return [] + elif metadata_tag not in tag_ids: + metadata_tag = "other" + + outputs = [] + note_msgs = midi_dict.note_msgs + idx = 0 + + while idx < len(note_msgs): + slice_length = random.randint(min_slice_notes, max_slice_notes) + chunk = note_msgs[idx : idx + slice_length] + + # If the chunk is too short, break out of the loop. + if len(chunk) < min_slice_notes: + break + + idx += slice_length + + # Create slice + slice_midi_dict = copy.deepcopy(midi_dict) + slice_midi_dict.note_msgs = chunk + slice_midi_dict.metadata = {} + + # Format + tokenized_slice = tokenizer.tokenize(slice_midi_dict) + if tokenizer.dim_tok in tokenized_slice: + tokenized_slice.remove(tokenizer.dim_tok) + + # Use EOS tok for classification head + tokenized_slice = tokenized_slice[:max_seq_len] + tokenized_slice += [tokenizer.pad_tok] * ( + max_seq_len - len(tokenized_slice) + ) + if tokenizer.eos_tok not in tokenized_slice: + tokenized_slice[-1] = tokenizer.eos_tok + + pos = tokenized_slice.index(tokenizer.eos_tok) + + outputs.append( + {"seq": tokenized_slice, "tag": metadata_tag, "pos": pos} + ) + + return outputs + + +class ContrastiveDataset(Dataset): + def __init__(self, load_path: str, tag_ids: dict): + self.load_path = load_path + self.tag_ids = tag_ids + self.tokenizer = AbsTokenizer() + self.index = [] + + self.file_buff = open(self.load_path, "rb") + self.mmap_obj = mmap.mmap( + self.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + while True: + pos = self.mmap_obj.tell() + line = self.mmap_obj.readline() + if not line: + break + self.index.append(pos) + + def __getitem__(self, idx: int): + def _format(tok): + # Required because json formats tuples into lists + if isinstance(tok, list): + return tuple(tok) + return tok + + file_pos = self.index[idx] + self.mmap_obj.seek(file_pos) + + raw_data = self.mmap_obj.readline().decode("utf-8") + json_data = json.loads(raw_data) + + seq, tag, pos = json_data["seq"], json_data["tag"], json_data["pos"] + assert tag in self.tag_ids.keys() + assert pos < len(seq) + + seq = [_format(tok) for tok in seq] + seq_enc = torch.tensor(self.tokenizer.encode(seq)) + tag_enc = torch.tensor(self.tag_ids[tag]) + pos_enc = torch.tensor(pos) + + assert seq_enc[pos_enc.item()].item() == 1 # EOS ID + + return seq_enc, tag_enc, pos_enc + + def __len__(self): + return len(self.index) + + @classmethod + def export_worker_init_fn(cls): + def worker_init_fn(worker_id: int): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + + if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: + dataset.mmap_obj.close() + + f = open(dataset.load_path, "rb") + dataset.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + return worker_init_fn + + @classmethod + def build( + cls, + midi_dataset_load_path: str, + save_path: str, + min_slice_notes: int, + max_slice_notes: int, + max_seq_len: int, + metadata_category: str, + tag_ids: dict, + ): + assert os.path.isfile(midi_dataset_load_path) + assert os.path.isfile(save_path) is False + + tokenizer = AbsTokenizer() + + with jsonlines.open( + midi_dataset_load_path, "r" + ) as midi_dataset, jsonlines.open(save_path, "w") as writer: + + cnt = 0 + with multiprocessing.Pool() as pool: + for result in pool.imap_unordered( + functools.partial( + process_entry, + metadata_category=metadata_category, + tag_ids=tag_ids, + min_slice_notes=min_slice_notes, + max_slice_notes=max_slice_notes, + max_seq_len=max_seq_len, + tokenizer=tokenizer, + ), + midi_dataset, + chunksize=10, + ): + cnt += 1 
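Both dataset classes use the same mmap trick: scan the JSONL file once, record each line's byte offset, then seek+readline for O(1) random access from __getitem__. A self-contained sketch of that index (the temp file is illustrative):

import json
import mmap
import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    for i in range(3):
        f.write(json.dumps({"id": i}) + "\n")
    path = f.name

file_buff = open(path, "rb")
mm = mmap.mmap(file_buff.fileno(), 0, access=mmap.ACCESS_READ)
index = []
while True:
    pos = mm.tell()
    if not mm.readline():
        break
    index.append(pos)

mm.seek(index[2])                 # jump straight to the third record
print(json.loads(mm.readline()))  # {'id': 2}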
+ if cnt % 500 == 0: + print(f"Completed {cnt}") + + for chunk in result: + writer.write(chunk) + + +def _get_optim( + lr: float, + model: nn.Module, + num_epochs: int, + steps_per_epoch: int, + warmup: int = 100, + end_ratio: int = 0.1, +): + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-5, + ) + + warmup_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.000001, + end_factor=1, + total_iters=warmup, + ) + linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=(num_epochs * steps_per_epoch) - warmup, + ) + + lr_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_lrs, linear_decay_lrs], + milestones=[warmup], + ) + + return optimizer, lr_scheduler + + +def get_optim( + model: nn.Module, + num_epochs: int, + steps_per_epoch: int, +): + LR = 1e-5 + END_RATIO = 0.1 + WARMUP_STEPS = 1000 + + return _get_optim( + lr=LR, + model=model, + num_epochs=num_epochs, + steps_per_epoch=steps_per_epoch, + warmup=WARMUP_STEPS, + end_ratio=END_RATIO, + ) + + +def get_dataloaders( + train_data_path: str, + val_data_path: str, + batch_size: int, + num_workers: int, + apply_aug=True, +): + train_dataset = FinetuningDataset( + load_path=train_data_path, + tag_ids=TAG_IDS, + ) + val_dataset = FinetuningDataset( + load_path=val_data_path, + tag_ids=TAG_IDS, + ) + + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + ) + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + ) + return train_loader, val_loader + + +def _train( + num_epochs: int, + accelerator: accelerate.Accelerator, + model: TransformerCL, + train_dataloader: DataLoader, + val_dataloader: DataLoader, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler = None, + project_dir: str | None = None, +): + def make_checkpoint( + _accelerator: accelerate.Accelerator, _epoch: int, _step: int + ): + if accelerator.is_main_process: + checkpoint_dir = os.path.join( + project_dir, + "checkpoints", + f"epoch{_epoch}_step{_step}", + ) + + logger.info( + f"EPOCH {_epoch}/{num_epochs}: Saving checkpoint - {checkpoint_dir}" + ) + _accelerator.save_state(checkpoint_dir) + + def train_loop( + dataloader: DataLoader, _epoch: int, steps_per_checkpoint: int + ): + loss = torch.tensor([0.0]) + avg_train_loss = 0 + trailing_loss = 0 + loss_buffer = [] + + try: + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + except Exception: + pass + else: + lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) + + model.train() + for __step, batch in ( + pbar := tqdm( + enumerate(dataloader), + total=len(dataloader), + initial=0, + leave=False, + ) + ): + pbar.set_postfix_str( + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing={round(trailing_loss, 4)}" + ) + + with accelerator.accumulate(model): + step = __step + 1 + + seqs, labels, eos_pos = batch + logits = model(seqs) # (b_sz, s_len, class_size) + logits = logits[ + torch.arange(logits.shape[0], device=logits.device), eos_pos + ] + loss = loss_fn(logits, labels) + + # Calculate statistics + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + avg_train_loss = sum(loss_buffer) / len(loss_buffer) + + # Logging + logger.debug( + f"EPOCH 
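The logits[arange(bsz), eos_pos] expression used throughout these loops is plain advanced indexing: it picks one (class_size,) row per sequence, at that sequence's EOS position. A quick check with dummy shapes:

import torch

bsz, s_len, n_cls = 4, 10, 18
logits = torch.randn(bsz, s_len, n_cls)
eos_pos = torch.tensor([9, 3, 5, 0])
picked = logits[torch.arange(bsz), eos_pos]  # (bsz, n_cls)
assert picked.shape == (bsz, n_cls)
assert torch.equal(picked[1], logits[1, 3])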
{_epoch} STEP {step}: " + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing_loss={round(trailing_loss, 4)}, " + f"average_loss={round(avg_train_loss, 4)}" + ) + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + + if steps_per_checkpoint: + if step % steps_per_checkpoint == 0: + make_checkpoint( + _accelerator=accelerator, + _epoch=_epoch, + _step=step, + ) + + return avg_train_loss + + def val_loop(dataloader: DataLoader, _epoch: int): + model.eval() + val_loss_buffer = [] + total_correct = 0 + total_samples = 0 + + with torch.no_grad(): + pbar = tqdm( + dataloader, desc=f"Validation Epoch {_epoch}", leave=False + ) + for batch in pbar: + seqs, labels, eos_pos = batch + logits = model(seqs) # (b_sz, s_len, class_size) + logits = logits[ + torch.arange(logits.shape[0], device=logits.device), eos_pos + ] + loss = loss_fn(logits, labels) + # Gather loss from all devices (if applicable) + val_loss_buffer.append( + accelerator.gather(loss).mean(dim=0).item() + ) + + # Compute predictions and update accuracy stats + preds = torch.argmax(logits, dim=-1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + current_accuracy = ( + total_correct / total_samples if total_samples > 0 else 0.0 + ) + current_avg_loss = sum(val_loss_buffer) / len(val_loss_buffer) + + pbar.set_postfix_str( + f"loss={round(current_avg_loss,4)}, acc={round(current_accuracy,4)}" + ) + + avg_val_loss = sum(val_loss_buffer) / len(val_loss_buffer) + accuracy = total_correct / total_samples if total_samples > 0 else 0.0 + + logger.info( + f"Validation Epoch {_epoch}: average_loss={round(avg_val_loss, 4)}, accuracy={round(accuracy, 4)}" + ) + return avg_val_loss, accuracy + + logger = get_logger(__name__) + loss_fn = nn.CrossEntropyLoss() + TRAILING_LOSS_STEPS = 100 + + train_loop(dataloader=train_dataloader, _epoch=0, steps_per_checkpoint=2000) + make_checkpoint(_accelerator=accelerator, _epoch=1, _step=0) + val_loop(dataloader=val_dataloader, _epoch=0) + + +def train( + model_name: str, + train_data_path: str, + val_data_path: str, + num_workers: int, + num_epochs: int, + batch_size: int, + grad_acc_steps: int, + project_dir: str | None = None, + checkpoint_path: str | None = None, +): + accelerator = accelerate.Accelerator( + project_dir=project_dir, + gradient_accumulation_steps=grad_acc_steps, + ) + + if accelerator.is_main_process: + project_dir = setup_project_dir(project_dir) + logger = setup_logger(os.path.join(project_dir)) + else: + # In other processes, we won't create logs + project_dir = project_dir or "./experiments" + logger = get_logger(__name__) + + logger.info(f"Project directory: {project_dir}") + logger.info( + f"Training config: epochs={num_epochs}, batch_size={batch_size}, num_workers={num_workers}" + ) + + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config(model_name)) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerCL(model_config) + + if checkpoint_path is not None: + logger.info(f"Loading checkpoint from {checkpoint_path}") + model_state = _load_weight(checkpoint_path) + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + if "lm_head.weight" in model_state.keys(): + del model_state["lm_head.weight"] + + model_state = { + k.replace("model.", ""): v for k, v in model_state.items() + } + model.model.load_state_dict(model_state) + else: + 
logger.info("No checkpoint path provided") + + model.compile() + + train_dataloader, val_dataloader = get_dataloaders( + train_data_path=train_data_path, + val_data_path=val_data_path, + batch_size=batch_size, + num_workers=num_workers, + apply_aug=True, + ) + + optimizer, scheduler = get_optim( + model=model, + num_epochs=num_epochs, + steps_per_epoch=len(train_dataloader), + ) + + ( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) = accelerator.prepare( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) + + _train( + num_epochs=num_epochs, + accelerator=accelerator, + model=model, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + optimizer=optimizer, + scheduler=scheduler, + project_dir=project_dir, + ) + + +def test_build_dataset(): + FinetuningDataset.build( + midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl", + save_path="/mnt/ssd1/aria/data/train.jsonl", + min_slice_notes=100, + max_slice_notes=165, + max_seq_len=512, + metadata_category=METADATA_CATEGORY, + tag_ids=TAG_IDS, + ) + + # FinetuningDataset.build( + # midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl", + # save_path="/mnt/ssd1/aria/data/val.jsonl", + # min_slice_notes=100, + # max_slice_notes=165, + # max_seq_len=512, + # metadata_category=METADATA_CATEGORY, + # tag_ids=TAG_IDS, + # ) + + +def test_dataset(): + dataset = FinetuningDataset( + load_path="/mnt/ssd1/aria/data/test.jsonl", + tag_ids=TAG_IDS, + ) + + for idx, entry in enumerate(dataset): + print(idx) + # print(entry) + # input("") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Finetune a model for classification." + ) + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--checkpoint_path", type=str, default=None) + parser.add_argument("--train_data_path", type=str, required=True) + parser.add_argument("--val_data_path", type=str, required=True) + parser.add_argument("--batch_size", type=int) + parser.add_argument("--num_epochs", type=int) + parser.add_argument("--num_workers", type=int, default=1) + parser.add_argument("--grad_acc_steps", type=int, default=1) + parser.add_argument("--project_dir", type=str, default=None) + return parser.parse_args() + + +if __name__ == "__main__": + # args = parse_args() + # train( + # model_name=args.model_name, + # checkpoint_path=args.checkpoint_path, + # train_data_path=args.train_data_path, + # val_data_path=args.val_data_path, + # batch_size=args.batch_size, + # num_epochs=args.num_epochs, + # num_workers=args.num_workers, + # grad_acc_steps=args.grad_acc_steps, + # project_dir=args.project_dir, + # ) + + test_build_dataset() + # test_dataset() From 3dc695367e72e33c9f364888a683245dce985cb4 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 24 Feb 2025 18:39:08 +0000 Subject: [PATCH 14/72] add missing changes --- aria/model.py | 43 ++++++++++++++++++++++++++---- config/models/medium-composer.json | 30 +++++++++++++++++++++ config/models/medium-emb.json | 10 +++++++ 3 files changed, 78 insertions(+), 5 deletions(-) create mode 100644 config/models/medium-composer.json create mode 100644 config/models/medium-emb.json diff --git a/aria/model.py b/aria/model.py index 5dc14b3..13104cc 100644 --- a/aria/model.py +++ b/aria/model.py @@ -22,6 +22,7 @@ class ModelConfig: vocab_size: Optional[int] = None class_size: Optional[int] = None tag_to_id: Optional[dict] = None + emb_size: Optional[dict] = None def set_vocab_size(self, vocab_size: int): self.vocab_size = 
vocab_size @@ -257,14 +258,10 @@ def forward( Args: src (torch.tensor): Input to encoder block, of shape (batch_size, seq_len, d_model). - attn_mask (Optional[torch.tensor]): Attention mask of shape - (batch_size, seq_len). Defaults to None. - past_kv (Optional[list[KVCache]]): a list of kv caches. The list index - corresponds to the layer index. Returns: torch.tensor: Forward pass of src through Transformer and CL head. - Has shape (batch_size, seq_len, vocab_size). + Has shape (batch_size, seq_len, class_size). """ hidden = self.model(src) logits = self.class_head(hidden) @@ -272,6 +269,42 @@ def forward( return logits +class TransformerEMB(nn.Module): + """Transformer decoder with head for embedding. + + Args: + model_config (ModelConfig): Model config settings. + """ + + def __init__(self, model_config: ModelConfig): + super().__init__() + assert model_config.emb_size is not None + + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) + self.emb_head = nn.Linear( + model_config.d_model, model_config.emb_size, bias=False + ) + + def forward( + self, + src: torch.Tensor, + ): + """Forward pass of Transformer decoder with EMB head. + + Args: + src (torch.tensor): Input to encoder block, of shape (batch_size, + seq_len, d_model). + Returns: + torch.tensor: Forward pass of src through Transformer and EMB head. + Has shape (batch_size, seq_len, emb_size). + """ + hidden = self.model(src) + emb = self.emb_head(hidden) + + return emb + + def precompute_freqs_cis( seq_len: int, n_elem: int, diff --git a/config/models/medium-composer.json b/config/models/medium-composer.json new file mode 100644 index 0000000..ece2209 --- /dev/null +++ b/config/models/medium-composer.json @@ -0,0 +1,30 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 18, + "tag_to_id": { + "chopin": 0, + "bach": 1, + "beethoven": 2, + "liszt": 3, + "mozart": 4, + "debussy": 5, + "schumann": 6, + "schubert": 7, + "rachmaninoff": 8, + "brahms": 9, + "tchaikovsky": 10, + "haydn": 11, + "scriabin": 12, + "mendelssohn": 13, + "czerny": 14, + "ravel": 15, + "scarlatti": 16, + "other": 17 + } +} \ No newline at end of file diff --git a/config/models/medium-emb.json b/config/models/medium-emb.json new file mode 100644 index 0000000..f4f6579 --- /dev/null +++ b/config/models/medium-emb.json @@ -0,0 +1,10 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "emb_size": 512 +} From 32b832dfd8d772a04fab90032d316b04fc88ccd1 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 24 Feb 2025 18:51:37 +0000 Subject: [PATCH 15/72] add loop --- aria/embeddings/contrastive_finetune.py | 9 +- aria/embeddings/finetune_contrastive.py | 683 ------------------------ 2 files changed, 6 insertions(+), 686 deletions(-) delete mode 100644 aria/embeddings/finetune_contrastive.py diff --git a/aria/embeddings/contrastive_finetune.py b/aria/embeddings/contrastive_finetune.py index 2df6552..d3ee942 100644 --- a/aria/embeddings/contrastive_finetune.py +++ b/aria/embeddings/contrastive_finetune.py @@ -515,9 +515,12 @@ def val_loop(dataloader: DataLoader, _epoch: int): logger = get_logger(__name__) TRAILING_LOSS_STEPS = 100 - train_loop(dataloader=train_dataloader, _epoch=0, steps_per_checkpoint=2000) - make_checkpoint(_accelerator=accelerator, _epoch=1, _step=0) - val_loop(dataloader=val_dataloader, _epoch=0) + for _epoch_num in 
range(num_epochs): + train_loop(dataloader=train_dataloader, _epoch=_epoch_num) + make_checkpoint( + _accelerator=accelerator, _epoch=_epoch_num + 1, _step=0 + ) + val_loop(dataloader=val_dataloader, _epoch=_epoch_num) def train( diff --git a/aria/embeddings/finetune_contrastive.py b/aria/embeddings/finetune_contrastive.py deleted file mode 100644 index 3a314a4..0000000 --- a/aria/embeddings/finetune_contrastive.py +++ /dev/null @@ -1,683 +0,0 @@ -import torch -import os -import mmap -import argparse -import logging -import random -import copy -import functools -import accelerate -import multiprocessing -import json -import jsonlines - -from aria.config import load_model_config -from aria.utils import _load_weight -from ariautils.tokenizer import AbsTokenizer -from ariautils.midi import MidiDict -from aria.model import TransformerCL, ModelConfig - -from torch import nn -from torch.utils.data import DataLoader, Dataset - -from accelerate.logging import get_logger -from logging.handlers import RotatingFileHandler -from tqdm import tqdm - -TAG_IDS = { - "chopin": 0, - "bach": 1, - "beethoven": 2, - "liszt": 3, - "mozart": 4, - "debussy": 5, - "schumann": 6, - "schubert": 7, - "rachmaninoff": 8, - "brahms": 9, - "tchaikovsky": 10, - "haydn": 11, - "scriabin": 12, - "mendelssohn": 13, - "czerny": 14, - "ravel": 15, - "scarlatti": 16, - "other": 17, -} -METADATA_CATEGORY = "composer" - - -def setup_logger(project_dir: str): - # Get logger and reset all handlers - logger = logging.getLogger(__name__) - for h in logger.handlers[:]: - logger.removeHandler(h) - - logger.propagate = False - logger.setLevel(logging.DEBUG) - formatter = logging.Formatter( - "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", - ) - fh = RotatingFileHandler( - os.path.join(project_dir, "logs.txt"), backupCount=5, maxBytes=1024**3 - ) - fh.setLevel(logging.DEBUG) - fh.setFormatter(formatter) - logger.addHandler(fh) - - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - ch.setFormatter(formatter) - logger.addHandler(ch) - - return get_logger(__name__) - - -def setup_project_dir(project_dir: str | None): - if not project_dir: - # Create project directory - if not os.path.isdir("./experiments"): - os.mkdir("./experiments") - - project_dirs = [ - _dir - for _dir in os.listdir("./experiments") - if os.path.isdir(os.path.join("experiments", _dir)) - ] - - ind = 0 - while True: - if str(ind) not in project_dirs: - break - else: - ind += 1 - - project_dir_abs = os.path.abspath(os.path.join("experiments", str(ind))) - assert not os.path.isdir(project_dir_abs) - os.mkdir(project_dir_abs) - - elif project_dir: - if os.path.isdir(project_dir): - assert ( - len(os.listdir(project_dir)) == 0 - ), "Provided project directory is not empty" - project_dir_abs = os.path.abspath(project_dir) - elif os.path.isfile(project_dir): - raise FileExistsError( - "The provided path points toward an existing file" - ) - else: - try: - os.mkdir(project_dir) - except Exception as e: - raise Exception( - f"Failed to create project directory at {project_dir}" - ) from e - - project_dir_abs = os.path.abspath(project_dir) - - os.mkdir(os.path.join(project_dir_abs, "checkpoints")) - - return project_dir_abs - - -def process_entry( - entry, - metadata_category: str, - tag_ids: dict, - min_slice_notes: int, - max_slice_notes: int, - max_seq_len: int, - tokenizer: AbsTokenizer, -): - midi_dict = MidiDict.from_msg_dict(entry) - metadata_tag = midi_dict.metadata.get(metadata_category, None) - - # Skip if metadata tag is missing or not in tag_ids. 
- if metadata_tag is None: - return [] - elif metadata_tag not in tag_ids: - metadata_tag = "other" - - outputs = [] - note_msgs = midi_dict.note_msgs - idx = 0 - - while idx < len(note_msgs): - slice_length = random.randint(min_slice_notes, max_slice_notes) - chunk = note_msgs[idx : idx + slice_length] - - # If the chunk is too short, break out of the loop. - if len(chunk) < min_slice_notes: - break - - idx += slice_length - - # Create slice - slice_midi_dict = copy.deepcopy(midi_dict) - slice_midi_dict.note_msgs = chunk - slice_midi_dict.metadata = {} - - # Format - tokenized_slice = tokenizer.tokenize(slice_midi_dict) - if tokenizer.dim_tok in tokenized_slice: - tokenized_slice.remove(tokenizer.dim_tok) - - # Use EOS tok for classification head - tokenized_slice = tokenized_slice[:max_seq_len] - tokenized_slice += [tokenizer.pad_tok] * ( - max_seq_len - len(tokenized_slice) - ) - if tokenizer.eos_tok not in tokenized_slice: - tokenized_slice[-1] = tokenizer.eos_tok - - pos = tokenized_slice.index(tokenizer.eos_tok) - - outputs.append( - {"seq": tokenized_slice, "tag": metadata_tag, "pos": pos} - ) - - return outputs - - -class ContrastiveDataset(Dataset): - def __init__(self, load_path: str, tag_ids: dict): - self.load_path = load_path - self.tag_ids = tag_ids - self.tokenizer = AbsTokenizer() - self.index = [] - - self.file_buff = open(self.load_path, "rb") - self.mmap_obj = mmap.mmap( - self.file_buff.fileno(), 0, access=mmap.ACCESS_READ - ) - - while True: - pos = self.mmap_obj.tell() - line = self.mmap_obj.readline() - if not line: - break - self.index.append(pos) - - def __getitem__(self, idx: int): - def _format(tok): - # Required because json formats tuples into lists - if isinstance(tok, list): - return tuple(tok) - return tok - - file_pos = self.index[idx] - self.mmap_obj.seek(file_pos) - - raw_data = self.mmap_obj.readline().decode("utf-8") - json_data = json.loads(raw_data) - - seq, tag, pos = json_data["seq"], json_data["tag"], json_data["pos"] - assert tag in self.tag_ids.keys() - assert pos < len(seq) - - seq = [_format(tok) for tok in seq] - seq_enc = torch.tensor(self.tokenizer.encode(seq)) - tag_enc = torch.tensor(self.tag_ids[tag]) - pos_enc = torch.tensor(pos) - - assert seq_enc[pos_enc.item()].item() == 1 # EOS ID - - return seq_enc, tag_enc, pos_enc - - def __len__(self): - return len(self.index) - - @classmethod - def export_worker_init_fn(cls): - def worker_init_fn(worker_id: int): - worker_info = torch.utils.data.get_worker_info() - dataset = worker_info.dataset - - if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: - dataset.mmap_obj.close() - - f = open(dataset.load_path, "rb") - dataset.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) - - return worker_init_fn - - @classmethod - def build( - cls, - midi_dataset_load_path: str, - save_path: str, - min_slice_notes: int, - max_slice_notes: int, - max_seq_len: int, - metadata_category: str, - tag_ids: dict, - ): - assert os.path.isfile(midi_dataset_load_path) - assert os.path.isfile(save_path) is False - - tokenizer = AbsTokenizer() - - with jsonlines.open( - midi_dataset_load_path, "r" - ) as midi_dataset, jsonlines.open(save_path, "w") as writer: - - cnt = 0 - with multiprocessing.Pool() as pool: - for result in pool.imap_unordered( - functools.partial( - process_entry, - metadata_category=metadata_category, - tag_ids=tag_ids, - min_slice_notes=min_slice_notes, - max_slice_notes=max_slice_notes, - max_seq_len=max_seq_len, - tokenizer=tokenizer, - ), - midi_dataset, - chunksize=10, - ): - cnt += 1 
- if cnt % 500 == 0: - print(f"Completed {cnt}") - - for chunk in result: - writer.write(chunk) - - -def _get_optim( - lr: float, - model: nn.Module, - num_epochs: int, - steps_per_epoch: int, - warmup: int = 100, - end_ratio: int = 0.1, -): - optimizer = torch.optim.AdamW( - model.parameters(), - lr=lr, - weight_decay=0.1, - betas=(0.9, 0.95), - eps=1e-5, - ) - - warmup_lrs = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=0.000001, - end_factor=1, - total_iters=warmup, - ) - linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=1, - end_factor=end_ratio, - total_iters=(num_epochs * steps_per_epoch) - warmup, - ) - - lr_scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, - schedulers=[warmup_lrs, linear_decay_lrs], - milestones=[warmup], - ) - - return optimizer, lr_scheduler - - -def get_optim( - model: nn.Module, - num_epochs: int, - steps_per_epoch: int, -): - LR = 1e-5 - END_RATIO = 0.1 - WARMUP_STEPS = 1000 - - return _get_optim( - lr=LR, - model=model, - num_epochs=num_epochs, - steps_per_epoch=steps_per_epoch, - warmup=WARMUP_STEPS, - end_ratio=END_RATIO, - ) - - -def get_dataloaders( - train_data_path: str, - val_data_path: str, - batch_size: int, - num_workers: int, - apply_aug=True, -): - train_dataset = FinetuningDataset( - load_path=train_data_path, - tag_ids=TAG_IDS, - ) - val_dataset = FinetuningDataset( - load_path=val_data_path, - tag_ids=TAG_IDS, - ) - - train_loader = DataLoader( - train_dataset, - batch_size=batch_size, - shuffle=True, - num_workers=num_workers, - ) - val_loader = DataLoader( - val_dataset, - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, - ) - return train_loader, val_loader - - -def _train( - num_epochs: int, - accelerator: accelerate.Accelerator, - model: TransformerCL, - train_dataloader: DataLoader, - val_dataloader: DataLoader, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler = None, - project_dir: str | None = None, -): - def make_checkpoint( - _accelerator: accelerate.Accelerator, _epoch: int, _step: int - ): - if accelerator.is_main_process: - checkpoint_dir = os.path.join( - project_dir, - "checkpoints", - f"epoch{_epoch}_step{_step}", - ) - - logger.info( - f"EPOCH {_epoch}/{num_epochs}: Saving checkpoint - {checkpoint_dir}" - ) - _accelerator.save_state(checkpoint_dir) - - def train_loop( - dataloader: DataLoader, _epoch: int, steps_per_checkpoint: int - ): - loss = torch.tensor([0.0]) - avg_train_loss = 0 - trailing_loss = 0 - loss_buffer = [] - - try: - lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) - except Exception: - pass - else: - lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) - - model.train() - for __step, batch in ( - pbar := tqdm( - enumerate(dataloader), - total=len(dataloader), - initial=0, - leave=False, - ) - ): - pbar.set_postfix_str( - f"lr={lr_for_print}, " - f"loss={round(loss.item(), 4)}, " - f"trailing={round(trailing_loss, 4)}" - ) - - with accelerator.accumulate(model): - step = __step + 1 - - seqs, labels, eos_pos = batch - logits = model(seqs) # (b_sz, s_len, class_size) - logits = logits[ - torch.arange(logits.shape[0], device=logits.device), eos_pos - ] - loss = loss_fn(logits, labels) - - # Calculate statistics - loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) - trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( - loss_buffer[-TRAILING_LOSS_STEPS:] - ) - avg_train_loss = sum(loss_buffer) / len(loss_buffer) - - # Logging - logger.debug( - f"EPOCH 
{_epoch} STEP {step}: " - f"lr={lr_for_print}, " - f"loss={round(loss.item(), 4)}, " - f"trailing_loss={round(trailing_loss, 4)}, " - f"average_loss={round(avg_train_loss, 4)}" - ) - - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - if scheduler: - scheduler.step() - lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) - - if steps_per_checkpoint: - if step % steps_per_checkpoint == 0: - make_checkpoint( - _accelerator=accelerator, - _epoch=_epoch, - _step=step, - ) - - return avg_train_loss - - def val_loop(dataloader: DataLoader, _epoch: int): - model.eval() - val_loss_buffer = [] - total_correct = 0 - total_samples = 0 - - with torch.no_grad(): - pbar = tqdm( - dataloader, desc=f"Validation Epoch {_epoch}", leave=False - ) - for batch in pbar: - seqs, labels, eos_pos = batch - logits = model(seqs) # (b_sz, s_len, class_size) - logits = logits[ - torch.arange(logits.shape[0], device=logits.device), eos_pos - ] - loss = loss_fn(logits, labels) - # Gather loss from all devices (if applicable) - val_loss_buffer.append( - accelerator.gather(loss).mean(dim=0).item() - ) - - # Compute predictions and update accuracy stats - preds = torch.argmax(logits, dim=-1) - total_correct += (preds == labels).sum().item() - total_samples += labels.size(0) - current_accuracy = ( - total_correct / total_samples if total_samples > 0 else 0.0 - ) - current_avg_loss = sum(val_loss_buffer) / len(val_loss_buffer) - - pbar.set_postfix_str( - f"loss={round(current_avg_loss,4)}, acc={round(current_accuracy,4)}" - ) - - avg_val_loss = sum(val_loss_buffer) / len(val_loss_buffer) - accuracy = total_correct / total_samples if total_samples > 0 else 0.0 - - logger.info( - f"Validation Epoch {_epoch}: average_loss={round(avg_val_loss, 4)}, accuracy={round(accuracy, 4)}" - ) - return avg_val_loss, accuracy - - logger = get_logger(__name__) - loss_fn = nn.CrossEntropyLoss() - TRAILING_LOSS_STEPS = 100 - - train_loop(dataloader=train_dataloader, _epoch=0, steps_per_checkpoint=2000) - make_checkpoint(_accelerator=accelerator, _epoch=1, _step=0) - val_loop(dataloader=val_dataloader, _epoch=0) - - -def train( - model_name: str, - train_data_path: str, - val_data_path: str, - num_workers: int, - num_epochs: int, - batch_size: int, - grad_acc_steps: int, - project_dir: str | None = None, - checkpoint_path: str | None = None, -): - accelerator = accelerate.Accelerator( - project_dir=project_dir, - gradient_accumulation_steps=grad_acc_steps, - ) - - if accelerator.is_main_process: - project_dir = setup_project_dir(project_dir) - logger = setup_logger(os.path.join(project_dir)) - else: - # In other processes, we won't create logs - project_dir = project_dir or "./experiments" - logger = get_logger(__name__) - - logger.info(f"Project directory: {project_dir}") - logger.info( - f"Training config: epochs={num_epochs}, batch_size={batch_size}, num_workers={num_workers}" - ) - - tokenizer = AbsTokenizer() - model_config = ModelConfig(**load_model_config(model_name)) - model_config.set_vocab_size(tokenizer.vocab_size) - model = TransformerCL(model_config) - - if checkpoint_path is not None: - logger.info(f"Loading checkpoint from {checkpoint_path}") - model_state = _load_weight(checkpoint_path) - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() - } - if "lm_head.weight" in model_state.keys(): - del model_state["lm_head.weight"] - - model_state = { - k.replace("model.", ""): v for k, v in model_state.items() - } - model.model.load_state_dict(model_state) - else: - 
logger.info("No checkpoint path provided") - - model.compile() - - train_dataloader, val_dataloader = get_dataloaders( - train_data_path=train_data_path, - val_data_path=val_data_path, - batch_size=batch_size, - num_workers=num_workers, - apply_aug=True, - ) - - optimizer, scheduler = get_optim( - model=model, - num_epochs=num_epochs, - steps_per_epoch=len(train_dataloader), - ) - - ( - model, - train_dataloader, - val_dataloader, - optimizer, - scheduler, - ) = accelerator.prepare( - model, - train_dataloader, - val_dataloader, - optimizer, - scheduler, - ) - - _train( - num_epochs=num_epochs, - accelerator=accelerator, - model=model, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - optimizer=optimizer, - scheduler=scheduler, - project_dir=project_dir, - ) - - -def test_build_dataset(): - FinetuningDataset.build( - midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl", - save_path="/mnt/ssd1/aria/data/train.jsonl", - min_slice_notes=100, - max_slice_notes=165, - max_seq_len=512, - metadata_category=METADATA_CATEGORY, - tag_ids=TAG_IDS, - ) - - # FinetuningDataset.build( - # midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl", - # save_path="/mnt/ssd1/aria/data/val.jsonl", - # min_slice_notes=100, - # max_slice_notes=165, - # max_seq_len=512, - # metadata_category=METADATA_CATEGORY, - # tag_ids=TAG_IDS, - # ) - - -def test_dataset(): - dataset = FinetuningDataset( - load_path="/mnt/ssd1/aria/data/test.jsonl", - tag_ids=TAG_IDS, - ) - - for idx, entry in enumerate(dataset): - print(idx) - # print(entry) - # input("") - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Finetune a model for classification." - ) - parser.add_argument("--model_name", type=str, required=True) - parser.add_argument("--checkpoint_path", type=str, default=None) - parser.add_argument("--train_data_path", type=str, required=True) - parser.add_argument("--val_data_path", type=str, required=True) - parser.add_argument("--batch_size", type=int) - parser.add_argument("--num_epochs", type=int) - parser.add_argument("--num_workers", type=int, default=1) - parser.add_argument("--grad_acc_steps", type=int, default=1) - parser.add_argument("--project_dir", type=str, default=None) - return parser.parse_args() - - -if __name__ == "__main__": - # args = parse_args() - # train( - # model_name=args.model_name, - # checkpoint_path=args.checkpoint_path, - # train_data_path=args.train_data_path, - # val_data_path=args.val_data_path, - # batch_size=args.batch_size, - # num_epochs=args.num_epochs, - # num_workers=args.num_workers, - # grad_acc_steps=args.grad_acc_steps, - # project_dir=args.project_dir, - # ) - - test_build_dataset() - # test_dataset() From 68fe37872118d8a5aa932e3fbb49888d104a4c5b Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 24 Feb 2025 18:55:54 +0000 Subject: [PATCH 16/72] fix arg bug --- aria/embeddings/contrastive_finetune.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aria/embeddings/contrastive_finetune.py b/aria/embeddings/contrastive_finetune.py index d3ee942..5920a9f 100644 --- a/aria/embeddings/contrastive_finetune.py +++ b/aria/embeddings/contrastive_finetune.py @@ -368,7 +368,9 @@ def make_checkpoint( _accelerator.save_state(checkpoint_dir) def train_loop( - dataloader: DataLoader, _epoch: int, steps_per_checkpoint: int + dataloader: DataLoader, + _epoch: int, + steps_per_checkpoint: int | None = None, ): loss = torch.tensor([0.0]) avg_train_loss = 0 From dae0b03e1dac44c152976e1dd65152a480ac2ba6 Mon Sep 17 
00:00:00 2001 From: Louis Date: Thu, 27 Feb 2025 17:11:30 +0000 Subject: [PATCH 17/72] update eval --- aria/embeddings/eval.py | 451 +++++++++++++++++++++++++++------------- 1 file changed, 305 insertions(+), 146 deletions(-) diff --git a/aria/embeddings/eval.py b/aria/embeddings/eval.py index 119c9c4..db3e9aa 100644 --- a/aria/embeddings/eval.py +++ b/aria/embeddings/eval.py @@ -3,26 +3,26 @@ import os import mmap import json +import time import functools import multiprocessing +import queue import copy import jsonlines import torch.nn as nn import torch.nn.functional as F from tqdm import tqdm -from collections import deque from typing import Callable from concurrent.futures import ThreadPoolExecutor -from aria.model import ModelConfig, TransformerLM, TransformerCL +from aria.model import ModelConfig, TransformerLM, TransformerEMB from aria.config import load_model_config from aria.utils import _load_weight from ariautils.midi import MidiDict from ariautils.tokenizer import AbsTokenizer -MODEL_PATH = "/mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors" -MAX_SEQ_LEN = 512 +METADATA_CATEGORY = "composer" TAG_TO_ID = { "chopin": 0, "bach": 1, @@ -44,16 +44,17 @@ "other": 17, } ID_TO_TAG = {v: k for k, v in TAG_TO_ID.items()} -METADATA_CATEGORY = "composer" -def chunk_and_pad(lst: list, n: int): - return [lst[i : i + n] for i in range(0, len(lst), n)] +def model_forward( + model: nn.Module, + idxs: torch.Tensor, +): + return model(idxs) -def init_worker(): - global tokenizer - tokenizer = AbsTokenizer() +def chunk_and_pad(lst: list, n: int): + return [lst[i : i + n] for i in range(0, len(lst), n)] def write_entries(writer, entries): @@ -61,22 +62,13 @@ def write_entries(writer, entries): writer.write(entry) -# The worker function processes a single JSON-lines entry. 
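chunk_and_pad in the updated eval.py splits a token list into fixed-size slices; despite the name it does not pad — the final chunk may simply be shorter (padding happens later, in _pad_seq). For example:

def chunk_and_pad(lst, n):
    return [lst[i : i + n] for i in range(0, len(lst), n)]

print(chunk_and_pad(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]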
def process_entry( entry, - metadata_category: str, - tag_ids: dict, slice_len_notes: int, max_seq_len: int, + tokenizer: AbsTokenizer, ): midi_dict = MidiDict.from_msg_dict(entry) - metadata_tag = midi_dict.metadata.get(metadata_category, None) - - # Skip metadata tag - if metadata_tag is None: - return [] - elif metadata_tag not in tag_ids.keys(): - metadata_tag = "other" outputs = [] for slice_note_msgs in chunk_and_pad( @@ -89,18 +81,53 @@ def process_entry( slice_midi_dict.note_msgs = slice_note_msgs slice_midi_dict.metadata = {} tokenized_slice = tokenizer.tokenize(slice_midi_dict) - if tokenizer.eos_tok in tokenized_slice: - tokenized_slice.remove(tokenizer.eos_tok) if tokenizer.dim_tok in tokenized_slice: tokenized_slice.remove(tokenizer.dim_tok) tokenized_slice = tokenized_slice[:max_seq_len] - outputs.append({"seq": tokenized_slice, "tag": metadata_tag}) + outputs.append({"seq": tokenized_slice, "metadata": midi_dict.metadata}) return outputs +def _pad_seq(seq: list, tokenizer: AbsTokenizer, max_seq_len: int): + seq = seq[:max_seq_len] + seq += [tokenizer.pad_tok] * (max_seq_len - len(seq)) + + if tokenizer.eos_tok not in seq: + seq[-1] = tokenizer.eos_tok + + return seq + + +@torch.autocast("cuda", dtype=torch.bfloat16) +@torch.inference_mode() +def get_contrastive_embedding( + seqs: list, + hook_model: nn.Module, + hook_max_seq_len: int, + hook_tokenizer: AbsTokenizer, + hook_model_forward: Callable, +): + seqs = [ + _pad_seq( + seq=seq, tokenizer=hook_tokenizer, max_seq_len=hook_max_seq_len + ) + for seq in seqs + ] + + eos_positions = [seq.index(hook_tokenizer.eos_tok) for seq in seqs] + enc_seqs = torch.tensor( + [hook_tokenizer.encode(seq) for seq in seqs], device="cuda" + ) + hidden_states = hook_model_forward(model=hook_model, idxs=enc_seqs) + idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) + emb = hidden_states[idx, eos_positions].tolist() + + return emb + + @torch.autocast("cuda", dtype=torch.bfloat16) @torch.inference_mode() def get_baseline_embedding( @@ -110,6 +137,10 @@ def get_baseline_embedding( hook_tokenizer: AbsTokenizer, pool_mode: str = "last", # "last" or "mean" ): + for seq in seqs: + if hook_tokenizer.eos_tok in seq: + seq.remove(hook_tokenizer.eos_tok) + orig_lengths = [len(seq) for seq in seqs] last_tok_positions = [length - 1 for length in orig_lengths] seqs = [ @@ -126,7 +157,7 @@ def get_baseline_embedding( idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) emb = hidden_states[idx, last_tok_positions].tolist() elif pool_mode == "mean": - pad_id = tokenizer.pad_id + pad_id = hook_tokenizer.pad_id # Create a mask by comparing enc_seqs to pad_id. mask = (enc_seqs != pad_id).unsqueeze(-1).to(hidden_states.dtype) # Sum over valid tokens and average. 
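The mean-pool branch of get_baseline_embedding masks out pad positions before averaging. Standalone version of that masking, with tiny made-up tensors:

import torch

pad_id = 0
enc_seqs = torch.tensor([[5, 7, pad_id], [3, pad_id, pad_id]])
hidden = torch.randn(2, 3, 4)
mask = (enc_seqs != pad_id).unsqueeze(-1).to(hidden.dtype)
emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
assert torch.allclose(emb[1], hidden[1, 0])  # one real token -> its own state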
@@ -141,9 +172,10 @@ def get_baseline_embedding( class EvaluationDataset(torch.utils.data.Dataset): - def __init__(self, load_path: str, tag_ids: dict): + def __init__(self, load_path: str, tag_ids: dict, metadata_category: str): self.load_path = load_path self.tag_ids = tag_ids + self.metadata_category = metadata_category self.tokenizer = AbsTokenizer() self.index = [] @@ -167,7 +199,9 @@ def __getitem__(self, idx: int): json_data = json.loads(raw_data) emb = json_data["emb"] - tag = json_data["tag"] + metadata = json_data["metadata"] + tag = metadata.get(self.metadata_category, "other") + tag = tag if tag in self.tag_ids.keys() else "other" assert tag in self.tag_ids tag_tensor = torch.tensor(self.tag_ids[tag]) @@ -199,67 +233,128 @@ def build( save_path: str, slice_len_notes: int, max_seq_len: int, - metadata_category: str, - tag_ids: dict, batch_size: int, embedding_hook: Callable, **embedding_hook_kwargs, ): - assert os.path.isfile(midi_dataset_load_path) - assert os.path.isfile(save_path) is False - - with jsonlines.open( - midi_dataset_load_path, "r" - ) as midi_dataset, jsonlines.open(save_path, "w") as writer: - + def batch_producer( + results_queue: queue.Queue, + batch_queue: queue.Queue, + batch_size: int, + ): + buffer = [] + while True: + if batch_queue.qsize() >= 5: + time.sleep(1) + + try: + result = results_queue.get(timeout=0.01) + if result is None: + if len(buffer) > 0: + batch_queue.put(buffer) + break + + buffer.append(result) + if len(buffer) == batch_size: + batch_queue.put(buffer) + buffer = [] + except queue.Empty: + pass + + def producer( + midi_dataset_load_path: str, + midi_dict_queue: queue.Queue, + ): cnt = 0 - buffer = deque() - write_executor = ThreadPoolExecutor(max_workers=1) - with multiprocessing.Pool( - processes=8, initializer=init_worker - ) as pool: - for result in pool.imap_unordered( - functools.partial( - process_entry, - metadata_category=metadata_category, - tag_ids=tag_ids, - slice_len_notes=slice_len_notes, - max_seq_len=max_seq_len, - ), - midi_dataset, - chunksize=10, - ): - + with jsonlines.open(midi_dataset_load_path, "r") as midi_dataset: + for midi_dict in midi_dataset: + while midi_dict_queue.qsize() >= 250: + time.sleep(0.1) + midi_dict_queue.put(midi_dict) cnt += 1 - if cnt % 500 == 0: - print(f"Completed {cnt}") - - for entry in result: - buffer.append(entry) - - # Inside your processing loop: - if len(buffer) >= batch_size: - _buffer = [buffer.popleft() for _ in range(batch_size)] - _seqs = [entry["seq"] for entry in _buffer] - _tags = [entry["tag"] for entry in _buffer] - _embs = embedding_hook( - seqs=_seqs, **embedding_hook_kwargs - ) - # Prepare the write objects - write_objs = [ - {"seq": _seq, "emb": _emb, "tag": _tag} - for _seq, _emb, _tag in zip(_seqs, _embs, _tags) - ] + if cnt % 500 == 0: + print(f"Finished {cnt}") + + for _ in range(16): + midi_dict_queue.put(None) + + def worker( + midi_dict_queue: queue.Queue, + results_queue: queue.Queue, + slice_len_notes: int, + max_seq_len: int, + ): + tokenizer = AbsTokenizer() + + while True: + midi_dict = midi_dict_queue.get() + if midi_dict is None: + results_queue.put(None) + break + + while results_queue.qsize() > 500: + time.sleep(0.5) + + _result = process_entry( + entry=midi_dict, + slice_len_notes=slice_len_notes, + max_seq_len=max_seq_len, + tokenizer=tokenizer, + ) + for _sub_result in _result: + results_queue.put(_sub_result) - write_executor.submit(write_entries, writer, write_objs) + assert os.path.isfile(midi_dataset_load_path) + assert os.path.isfile(save_path) is 
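EvaluationDataset.build is rewritten below around a producer / worker / batch-producer pipeline with queue back-pressure and None sentinels for shutdown. A toy pipeline of the same shape, using threads as stand-ins for the multiprocessing processes:

import queue
import threading

N_WORKERS = 2
work_q = queue.Queue(maxsize=8)  # bounded: producer blocks when full
out_q = queue.Queue()

def producer():
    for i in range(10):
        work_q.put(i)
    for _ in range(N_WORKERS):
        work_q.put(None)          # one sentinel per worker

def worker():
    while (item := work_q.get()) is not None:
        out_q.put(item * item)
    out_q.put(None)               # tell the consumer this worker is done

threads = [threading.Thread(target=producer)]
threads += [threading.Thread(target=worker) for _ in range(N_WORKERS)]
for t in threads:
    t.start()

done, results = 0, []
while done < N_WORKERS:
    r = out_q.get()
    if r is None:
        done += 1
    else:
        results.append(r)
print(sorted(results))            # squares of 0..9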
False - if buffer: - _seqs = [entry["seq"] for entry in buffer] - _tags = [entry["tag"] for entry in buffer] - _embs = embedding_hook(seqs=_seqs, **embedding_hook_kwargs) - for _seq, _tag, _emb in zip(_seqs, _tags, _embs): - writer.write({"seq": _seq, "emb": _emb, "tag": _tag}) + write_executor = ThreadPoolExecutor(max_workers=1) + results_queue = multiprocessing.Queue() + midi_dict_queue = multiprocessing.Queue() + batch_queue = multiprocessing.Queue() + producer_process = multiprocessing.Process( + target=producer, args=(midi_dataset_load_path, midi_dict_queue) + ) + batch_producer_process = multiprocessing.Process( + target=batch_producer, args=(results_queue, batch_queue, batch_size) + ) + worker_processes = [ + multiprocessing.Process( + target=worker, + args=( + midi_dict_queue, + results_queue, + slice_len_notes, + max_seq_len, + ), + ) + for _ in range(4) + ] + + producer_process.start() + batch_producer_process.start() + for p in worker_processes: + p.start() + + with jsonlines.open(save_path, "w") as writer: + while batch_producer_process.is_alive() or not batch_queue.empty(): + try: + batch = batch_queue.get(timeout=0.01) + + _seqs = [item["seq"] for item in batch] + _metadata = [item["metadata"] for item in batch] + _embs = embedding_hook(seqs=_seqs, **embedding_hook_kwargs) + + write_objs = [ + {"seq": s, "emb": e, "metadata": m} + for s, e, m in zip(_seqs, _embs, _metadata) + ] + write_executor.submit(write_entries, writer, write_objs) + + except queue.Empty: + continue + + write_executor.shutdown(wait=True) def _get_optim( @@ -300,17 +395,12 @@ def _get_optim( class ClassifierHead(nn.Module): - def __init__(self, d_emb: int, hidden_dim: int, num_class: int): + def __init__(self, d_emb: int, num_class: int): super().__init__() - self.fc1 = nn.Linear(d_emb, hidden_dim) - self.activation = nn.ReLU() - self.fc2 = nn.Linear(hidden_dim, num_class) + self.linear = nn.Linear(d_emb, num_class) - def forward(self, x): - x = self.fc1(x) - x = self.activation(x) - logits = self.fc2(x) - return logits + def forward(self, x: torch.Tensor): + return self.linear(x) def _train( @@ -319,6 +409,7 @@ def _train( train_dataloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, + num_epochs: int = 1, ): TRAILING_LOSS_STEPS = 100 loss = torch.tensor([0.0]) @@ -329,32 +420,33 @@ def _train( model.train() loss_fn = nn.CrossEntropyLoss() - for __step, batch in ( - pbar := tqdm(enumerate(train_dataloader), leave=False) - ): - pbar.set_postfix_str( - f"lr={lr_for_print}, " - f"loss={round(loss.item(), 4)}, " - f"trailing={round(trailing_loss, 4)}" - ) - - emb, tag_ids = batch - tag_ids = tag_ids.view(-1) - - logits = model(emb) - loss = loss_fn(logits, tag_ids) - - loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) - trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( - loss_buffer[-TRAILING_LOSS_STEPS:] - ) - - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - if scheduler: - scheduler.step() - lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + for _epoch in range(num_epochs): + for __step, batch in ( + pbar := tqdm(enumerate(train_dataloader), leave=False) + ): + pbar.set_postfix_str( + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing={round(trailing_loss, 4)}" + ) + + emb, tag_ids = batch + tag_ids = tag_ids.view(-1) + + logits = model(emb) + loss = loss_fn(logits, tag_ids) + + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = 
sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) if accelerator.is_main_process: accelerator.save_state("/mnt/ssd1/aria/test") @@ -368,23 +460,23 @@ def train_classifier( tag_ids: dict, batch_size: int, ): + num_epochs = 1 train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=batch_size, shuffle=True, - num_workers=8, + num_workers=24, worker_init_fn=EvaluationDataset.export_worker_init_fn(), ) model = ClassifierHead( d_emb=emb_d, - hidden_dim=emb_d, num_class=len(tag_ids.keys()), ) optimizer, scheduler = _get_optim( lr=3e-4, model=model, - total_steps=len(train_dataloader), + total_steps=num_epochs * len(train_dataloader), ) accelerator = accelerate.Accelerator() @@ -401,6 +493,7 @@ def train_classifier( train_dataloader=train_dataloader, optimizer=optimizer, scheduler=scheduler, + num_epochs=num_epochs, ) @@ -408,34 +501,55 @@ def evaluate_model(model: nn.Module, val_dataset_path: str): val_dataset = EvaluationDataset( load_path=val_dataset_path, tag_ids=TAG_TO_ID, + metadata_category=METADATA_CATEGORY, ) - model = model.cpu() + model = model.cpu().eval() - correct = 0 - total = 0 - dist = {k: 0 for k in TAG_TO_ID.keys()} + # Count true values and correct predictions per tag. + dist = {k: {"correct": 0, "total": 0} for k in TAG_TO_ID.keys()} + # New dictionary to count predictions per tag. + pred_dist = {k: 0 for k in TAG_TO_ID.keys()} for midi_emb, tag_id in val_dataset: with torch.no_grad(): logits = model(torch.tensor(midi_emb.view(1, -1))) - probs = F.softmax(logits) + probs = F.softmax(logits, dim=-1) pred_tag_id = probs.argmax(dim=-1).item() - dist[ID_TO_TAG[tag_id.item()]] += 1 - if ID_TO_TAG[tag_id.item()] == "other": - continue + true_tag = ID_TO_TAG[tag_id.item()] + pred_tag = ID_TO_TAG[pred_tag_id] - if pred_tag_id == tag_id.item(): - correct += 1 - total += 1 + dist[true_tag]["total"] += 1 + pred_dist[pred_tag] += 1 - print(f"Total accuracy: {correct/total}") - print(f"Label distribution: {dist}") + if pred_tag_id == tag_id.item(): + dist[true_tag]["correct"] += 1 + + total_correct = sum(v["correct"] for v in dist.values()) + total_samples = sum(v["total"] for v in dist.values()) + print(f"Total accuracy: {total_correct/total_samples}") + + for tag in TAG_TO_ID.keys(): + TP = dist[tag]["correct"] + FN = dist[tag]["total"] - TP + FP = pred_dist[tag] - TP + precision = TP / (TP + FP) if (TP + FP) > 0 else 0 + recall = TP / (TP + FN) if (TP + FN) > 0 else 0 + f1 = ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + print( + f"{tag} -- Accuracy: {TP/dist[tag]['total']}, Precision: {precision}, Recall: {recall}, F1: {f1}" + ) -def build_dataset(): +def build_baseline_dataset(): + MAX_SEQ_LEN = 512 MODEL_PATH = "/mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors" + tokenizer = AbsTokenizer() model_state = _load_weight(MODEL_PATH, "cuda") model_state = { k.replace("_orig_mod.", ""): v for k, v in model_state.items() @@ -447,38 +561,83 @@ def build_dataset(): pretrained_model.load_state_dict(model_state) pretrained_model.eval() + global model_forward + model_forward = torch.compile( + model_forward, + mode="reduce-overhead", + fullgraph=True, + ) + EvaluationDataset.build( midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_train.jsonl", 
save_path="/mnt/ssd1/aria/data/train.jsonl", max_seq_len=MAX_SEQ_LEN, slice_len_notes=165, - metadata_category="genre", - tag_ids=TAG_TO_ID, batch_size=128, embedding_hook=functools.partial( get_baseline_embedding, pool_mode="mean" ), hook_model=pretrained_model.model.cuda(), - hook_max_seq_len=512, + hook_max_seq_len=MAX_SEQ_LEN, hook_tokenizer=tokenizer, ) -if __name__ == "__main__": +def build_contrastive_dataset(): + MAX_SEQ_LEN = 1024 + MODEL_PATH = ( + "/home/loubb/work/aria/models/medium-emb-t0.5-s1024-e20.safetensors" + ) + tokenizer = AbsTokenizer() + model_state = _load_weight(MODEL_PATH, "cuda") + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + pretrained_model_config = ModelConfig(**load_model_config("medium-emb")) + pretrained_model_config.set_vocab_size(tokenizer.vocab_size) + pretrained_model_config.grad_checkpoint = False + pretrained_model = TransformerEMB(pretrained_model_config) + pretrained_model.load_state_dict(model_state) + pretrained_model.eval() - dataset = EvaluationDataset( - load_path="/mnt/ssd1/aria/data/train.jsonl", - tag_ids=TAG_TO_ID, + hook_model_forward = torch.compile( + model_forward, + mode="reduce-overhead", + fullgraph=True, ) - model = train_classifier( - emb_d=1536, - train_dataset=dataset, - batch_size=32, - tag_ids=TAG_TO_ID, - ) - evaluate_model( - model=model, - val_dataset_path="/mnt/ssd1/aria/data/val.jsonl", + EvaluationDataset.build( + midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-all_train.jsonl", + save_path="/mnt/ssd1/aria/data/eval/test.jsonl", + max_seq_len=MAX_SEQ_LEN, + slice_len_notes=300, + batch_size=128, + embedding_hook=get_contrastive_embedding, + hook_model=pretrained_model.cuda(), + hook_max_seq_len=MAX_SEQ_LEN, + hook_tokenizer=tokenizer, + hook_model_forward=hook_model_forward, ) + + +if __name__ == "__main__": + # tokenizer = AbsTokenizer() + # dataset = EvaluationDataset( + # load_path="/mnt/ssd1/aria/data/eval/temp-train.jsonl", + # tag_ids=TAG_TO_ID, + # metadata_category=METADATA_CATEGORY, + # ) + + # model = train_classifier( + # emb_d=512, + # train_dataset=dataset, + # batch_size=32, + # tag_ids=TAG_TO_ID, + # ) + # evaluate_model( + # model=model, + # val_dataset_path="/mnt/ssd1/aria/data/eval/temp-val.jsonl", + # ) + + build_contrastive_dataset() From 06ef3381f4f6b66fd27176ceb61139fec231bb16 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 28 Feb 2025 14:20:35 +0000 Subject: [PATCH 18/72] fix eval hang --- aria/embeddings/eval.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/aria/embeddings/eval.py b/aria/embeddings/eval.py index db3e9aa..635a329 100644 --- a/aria/embeddings/eval.py +++ b/aria/embeddings/eval.py @@ -241,34 +241,37 @@ def batch_producer( results_queue: queue.Queue, batch_queue: queue.Queue, batch_size: int, + total_workers: int, ): buffer = [] - while True: + termination_signals = 0 + while termination_signals < total_workers: if batch_queue.qsize() >= 5: time.sleep(1) - try: result = results_queue.get(timeout=0.01) if result is None: - if len(buffer) > 0: - batch_queue.put(buffer) - break - + termination_signals += 1 + continue buffer.append(result) if len(buffer) == batch_size: batch_queue.put(buffer) buffer = [] except queue.Empty: - pass + continue + + if buffer: + batch_queue.put(buffer) def producer( midi_dataset_load_path: str, midi_dict_queue: queue.Queue, + num_workers: int, ): cnt = 0 with jsonlines.open(midi_dataset_load_path, "r") as midi_dataset: for midi_dict in midi_dataset: - while 
midi_dict_queue.qsize() >= 250: + while midi_dict_queue.qsize() >= 1000: time.sleep(0.1) midi_dict_queue.put(midi_dict) cnt += 1 @@ -276,7 +279,7 @@ def producer( if cnt % 500 == 0: print(f"Finished {cnt}") - for _ in range(16): + for _ in range(num_workers): midi_dict_queue.put(None) def worker( @@ -293,7 +296,7 @@ def worker( results_queue.put(None) break - while results_queue.qsize() > 500: + while results_queue.qsize() > 1000: time.sleep(0.5) _result = process_entry( @@ -308,15 +311,18 @@ def worker( assert os.path.isfile(midi_dataset_load_path) assert os.path.isfile(save_path) is False + TOTAL_WORKERS = 8 write_executor = ThreadPoolExecutor(max_workers=1) results_queue = multiprocessing.Queue() midi_dict_queue = multiprocessing.Queue() batch_queue = multiprocessing.Queue() producer_process = multiprocessing.Process( - target=producer, args=(midi_dataset_load_path, midi_dict_queue) + target=producer, + args=(midi_dataset_load_path, midi_dict_queue, TOTAL_WORKERS), ) batch_producer_process = multiprocessing.Process( - target=batch_producer, args=(results_queue, batch_queue, batch_size) + target=batch_producer, + args=(results_queue, batch_queue, batch_size, TOTAL_WORKERS), ) worker_processes = [ multiprocessing.Process( @@ -328,7 +334,7 @@ def worker( max_seq_len, ), ) - for _ in range(4) + for _ in range(TOTAL_WORKERS) ] producer_process.start() @@ -612,7 +618,7 @@ def build_contrastive_dataset(): save_path="/mnt/ssd1/aria/data/eval/test.jsonl", max_seq_len=MAX_SEQ_LEN, slice_len_notes=300, - batch_size=128, + batch_size=256, embedding_hook=get_contrastive_embedding, hook_model=pretrained_model.cuda(), hook_max_seq_len=MAX_SEQ_LEN, From 19f979508d50c8f8b43a8bf5f13bb6ada25ee909 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 5 Mar 2025 13:21:28 +0000 Subject: [PATCH 19/72] add data aug --- aria/embeddings/contrastive_finetune.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/aria/embeddings/contrastive_finetune.py b/aria/embeddings/contrastive_finetune.py index 5920a9f..5fb1129 100644 --- a/aria/embeddings/contrastive_finetune.py +++ b/aria/embeddings/contrastive_finetune.py @@ -104,15 +104,22 @@ def __init__( min_number_slice_notes: int, max_number_slice_notes: int, max_seq_len: int, + apply_aug: bool = False, ): self.load_path = load_path self.min_number_slice_notes = min_number_slice_notes self.max_number_slice_notes = max_number_slice_notes self.max_seq_len = max_seq_len + self.apply_aug = apply_aug self.tokenizer = AbsTokenizer() - self.index = [] + if apply_aug is True: + self.aug_fns = self.tokenizer.export_data_aug() + else: + self.aug_fns = None + + self.index = [] self.file_buff = open(self.load_path, "rb") self.mmap_obj = mmap.mmap( self.file_buff.fileno(), 0, access=mmap.ACCESS_READ @@ -185,6 +192,12 @@ def _format(tok): slice_seq_1 = [_format(tok) for tok in slice_seq_1] slice_seq_2 = [_format(tok) for tok in slice_seq_2] + if self.apply_aug: + assert self.aug_fns + for fn in self.aug_fns: + slice_seq_1 = fn(slice_seq_1) + slice_seq_2 = fn(slice_seq_2) + assert len(slice_seq_1) <= self.max_seq_len assert len(slice_seq_2) <= self.max_seq_len assert slice_pos_1 < self.max_seq_len From 2c086e145aa56a5d9acd62765870377655d5efa5 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 5 Mar 2025 13:37:25 +0000 Subject: [PATCH 20/72] fix data aug --- aria/embeddings/contrastive_finetune.py | 46 +++++++------------------ 1 file changed, 13 insertions(+), 33 deletions(-) diff --git a/aria/embeddings/contrastive_finetune.py 
b/aria/embeddings/contrastive_finetune.py index 5fb1129..394af89 100644 --- a/aria/embeddings/contrastive_finetune.py +++ b/aria/embeddings/contrastive_finetune.py @@ -138,6 +138,7 @@ def get_slice( min_num_notes: int, max_num_notes: int, max_seq_len: int, + apply_aug: bool = False, ): _midi_dict = copy.deepcopy(midi_dict) slice_length = random.randint(min_num_notes, max_num_notes) @@ -147,6 +148,15 @@ def get_slice( _midi_dict.metadata = {} tokenized_slice = self.tokenizer.tokenize(_midi_dict) + + if apply_aug: + assert self.aug_fns + for fn in self.aug_fns: + tokenized_slice = fn(tokenized_slice) + + while self.tokenizer.pad_tok in tokenized_slice: + tokenized_slice.remove(self.tokenizer.pad_tok) + if self.tokenizer.dim_tok in tokenized_slice: tokenized_slice.remove(self.tokenizer.dim_tok) @@ -163,12 +173,6 @@ def get_slice( return tokenized_slice, pos def __getitem__(self, idx: int): - def _format(tok): - # Required because json formats tuples into lists - if isinstance(tok, list): - return tuple(tok) - return tok - file_pos = self.index[idx] self.mmap_obj.seek(file_pos) @@ -181,23 +185,16 @@ def _format(tok): min_num_notes=self.min_number_slice_notes, max_num_notes=self.max_number_slice_notes, max_seq_len=self.max_seq_len, + apply_aug=self.apply_aug, ) slice_seq_2, slice_pos_2 = self.get_slice( midi_dict=midi_dict, min_num_notes=self.min_number_slice_notes, max_num_notes=self.max_number_slice_notes, max_seq_len=self.max_seq_len, + apply_aug=self.apply_aug, ) - slice_seq_1 = [_format(tok) for tok in slice_seq_1] - slice_seq_2 = [_format(tok) for tok in slice_seq_2] - - if self.apply_aug: - assert self.aug_fns - for fn in self.aug_fns: - slice_seq_1 = fn(slice_seq_1) - slice_seq_2 = fn(slice_seq_2) - assert len(slice_seq_1) <= self.max_seq_len assert len(slice_seq_2) <= self.max_seq_len assert slice_pos_1 < self.max_seq_len @@ -431,24 +428,6 @@ def train_loop( z1 = z1_full[batch_indices, eos_pos_1] z2 = z2_full[batch_indices, eos_pos_2] - # seqs_1 = seqs[:, 0, :] - # eos_pos_1 = eos_pos[:, 0] - # z1_full = model(seqs_1) - # z1 = z1_full[ - # torch.arange(z1_full.shape[0], device=z1_full.device), - # eos_pos_1, - # ] - - # seqs_2 = seqs[:, 1, :] - # eos_pos_2 = eos_pos[:, 1] - # z2_full = model(seqs_2) - # z2 = z2_full[ - # torch.arange(z2_full.shape[0], device=z2_full.device), - # eos_pos_2, - # ] - - #### - loss = symmetric_nt_xent_loss_cosine(z1, z2) # Calculate statistics @@ -634,6 +613,7 @@ def test_dataset(): min_number_slice_notes=150, max_number_slice_notes=300, max_seq_len=1024, + apply_aug=True, ) for idx, (enc, pos) in enumerate(dataset): From b524ab31b237d2db22522fedc1a7df8aaaeb712b Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 8 Mar 2025 20:34:56 +0000 Subject: [PATCH 21/72] formalize eval --- aria/embeddings/evaluate.py | 721 ++++++++++++++++++ aria/embeddings/m3/__init__.py | 0 aria/embeddings/m3/config.py | 79 ++ aria/embeddings/m3/emb.py | 216 ++++++ aria/embeddings/m3/utils.py | 702 +++++++++++++++++ aria/embeddings/mert/__init__.py | 0 aria/embeddings/mert/emb.py | 150 ++++ .../scripts/build_embedding_eval_datasets.py | 208 +++++ .../scripts/evaluate_embedding_with_probe.py | 0 paper/scripts/get_unique_tags.py | 31 + paper/scripts/make_cl_split.py | 113 +++ 11 files changed, 2220 insertions(+) create mode 100644 aria/embeddings/evaluate.py create mode 100644 aria/embeddings/m3/__init__.py create mode 100644 aria/embeddings/m3/config.py create mode 100644 aria/embeddings/m3/emb.py create mode 100644 aria/embeddings/m3/utils.py create mode 100644 
aria/embeddings/mert/__init__.py create mode 100644 aria/embeddings/mert/emb.py create mode 100644 paper/scripts/build_embedding_eval_datasets.py create mode 100644 paper/scripts/evaluate_embedding_with_probe.py create mode 100644 paper/scripts/get_unique_tags.py create mode 100644 paper/scripts/make_cl_split.py diff --git a/aria/embeddings/evaluate.py b/aria/embeddings/evaluate.py new file mode 100644 index 0000000..d7a0957 --- /dev/null +++ b/aria/embeddings/evaluate.py @@ -0,0 +1,721 @@ +import torch +import accelerate +import os +import mmap +import json +import time +import functools +import multiprocessing +import queue +import copy +import jsonlines +import torch.nn as nn +import torch.nn.functional as F + +from tqdm import tqdm +from typing import Callable +from concurrent.futures import ThreadPoolExecutor + +from aria.model import ModelConfig, TransformerLM +from aria.config import load_model_config +from aria.utils import _load_weight +from ariautils.midi import MidiDict +from ariautils.tokenizer import AbsTokenizer + +CATEGORY_TAGS = { + "genre": { + "classical": 0, + "jazz": 1, + }, + "music_period": { + "baroque": 0, + "classical": 1, + "romantic": 2, + "impressionist": 3, + }, + "composer": { + "beethoven": 0, + "debussy": 1, + "brahms": 2, + "rachmaninoff": 3, + "schumann": 4, + "mozart": 5, + "liszt": 6, + "bach": 7, + "chopin": 8, + "schubert": 9, + }, + "form": { + "nocturne": 0, + "sonata": 1, + "improvisation": 2, + "etude": 3, + "fugue": 4, + "waltz": 5, + }, +} + + +def model_forward( + model: nn.Module, + idxs: torch.Tensor, +): + return model(idxs) + + +def chunk_and_pad(lst: list, n: int): + return [lst[i : i + n] for i in range(0, len(lst), n)] + + +def write_entries(writer, entries): + for entry in entries: + writer.write(entry) + + +def process_entry( + entry, + slice_len_notes: int, + max_seq_len: int, + tokenizer: AbsTokenizer, +): + midi_dict = MidiDict.from_msg_dict(entry) + + outputs = [] + for slice_note_msgs in chunk_and_pad( + lst=midi_dict.note_msgs, n=slice_len_notes + ): + if len(slice_note_msgs) < 20: + break + + slice_midi_dict = copy.deepcopy(midi_dict) + slice_midi_dict.note_msgs = slice_note_msgs + slice_midi_dict.metadata = {} + tokenized_slice = tokenizer.tokenize(slice_midi_dict) + if tokenizer.dim_tok in tokenized_slice: + tokenized_slice.remove(tokenizer.dim_tok) + + tokenized_slice = tokenized_slice[:max_seq_len] + + outputs.append({"seq": tokenized_slice, "metadata": midi_dict.metadata}) + + return outputs + + +def _pad_seq(seq: list, tokenizer: AbsTokenizer, max_seq_len: int): + seq = seq[:max_seq_len] + seq += [tokenizer.pad_tok] * (max_seq_len - len(seq)) + + if tokenizer.eos_tok not in seq: + seq[-1] = tokenizer.eos_tok + + return seq + + +@torch.autocast( + "cuda", + dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, +) +@torch.inference_mode() +def get_aria_contrastive_embedding( + seqs: list, + hook_model: nn.Module, + hook_max_seq_len: int, + hook_tokenizer: AbsTokenizer, + hook_model_forward: Callable, + hook_max_batch_size: int = 64, +): + all_emb = [] + + for i in range(0, len(seqs), hook_max_batch_size): + batch_seqs = seqs[i : i + hook_max_batch_size] + padded_seqs = [ + _pad_seq( + seq=seq, tokenizer=hook_tokenizer, max_seq_len=hook_max_seq_len + ) + for seq in batch_seqs + ] + eos_positions = [ + seq.index(hook_tokenizer.eos_tok) for seq in padded_seqs + ] + enc_seqs = torch.tensor( + [hook_tokenizer.encode(seq) for seq in padded_seqs], device="cuda" + ) + hidden_states = 
hook_model_forward(model=hook_model, idxs=enc_seqs) + idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) + batch_emb = hidden_states[idx, eos_positions].tolist() + all_emb.extend(batch_emb) + + return all_emb + + +def get_mert_embedding( + seqs: list, + hook_model: nn.Module, + hook_processor, + hook_tokenizer: AbsTokenizer, + hook_pianoteq_exec_path: str, + hook_pianoteq_num_procs: int, +): + from aria.embeddings.mert.emb import ( + seq_to_audio_path, + compute_audio_embedding, + ) + + with multiprocessing.Pool(hook_pianoteq_num_procs) as pool: + audio_paths = pool.imap( + functools.partial( + seq_to_audio_path, + tokenizer=hook_tokenizer, + pianoteq_exec_path=hook_pianoteq_exec_path, + ), + seqs, + ) + + emb = [ + compute_audio_embedding( + audio_path=path, + model=hook_model, + processor=hook_processor, + delete_audio=True, + ).tolist() + for path in audio_paths + ] + + return emb + + +def get_clamp3_embedding( + seqs: list, + hook_model: nn.Module, + hook_patchilizer, + hook_tokenizer: AbsTokenizer, +): + from aria.embeddings.m3.emb import get_midi_embedding + + emb = [ + get_midi_embedding( + mid=hook_tokenizer.detokenize(seq).to_midi(), + model=hook_model, + patchilizer=hook_patchilizer, + get_global=True, + ).tolist() + for seq in seqs + ] + + return emb + + +@torch.autocast("cuda", dtype=torch.bfloat16) +@torch.inference_mode() +def get_baseline_embedding( + seqs: list, + hook_model: nn.Module, + hook_max_seq_len: int, + hook_tokenizer: AbsTokenizer, + pool_mode: str = "last", # "last" or "mean" +): + for seq in seqs: + if hook_tokenizer.eos_tok in seq: + seq.remove(hook_tokenizer.eos_tok) + + orig_lengths = [len(seq) for seq in seqs] + last_tok_positions = [length - 1 for length in orig_lengths] + seqs = [ + seq + ([hook_tokenizer.pad_tok] * (hook_max_seq_len - len(seq))) + for seq in seqs + ] + + enc_seqs = torch.tensor( + [hook_tokenizer.encode(seq) for seq in seqs], device="cuda" + ) + hidden_states = hook_model(enc_seqs) + + if pool_mode == "last": + idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) + emb = hidden_states[idx, last_tok_positions].tolist() + elif pool_mode == "mean": + pad_id = hook_tokenizer.pad_id + # Create a mask by comparing enc_seqs to pad_id. + mask = (enc_seqs != pad_id).unsqueeze(-1).to(hidden_states.dtype) + # Sum over valid tokens and average. 
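+        # Worked example of the masked mean (hypothetical values, added for
+        # clarity): with hidden_states = [[[1., 1.], [3., 3.], [9., 9.]]]
+        # and the third position padded, mask = [[[1.], [1.], [0.]]], so
+        # sum_hidden = [[4., 4.]], valid_counts = [[2.]], and the mean is
+        # [[2., 2.]] -- the padded row contributes nothing.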
+ sum_hidden = (hidden_states * mask).sum(dim=1) + valid_counts = mask.sum(dim=1) + mean_hidden = sum_hidden / valid_counts + emb = mean_hidden.tolist() + else: + raise ValueError(f"Unsupported pool_mode: {pool_mode}") + + return emb + + +class EvaluationDataset(torch.utils.data.Dataset): + def __init__(self, load_path: str, tag_to_id: dict, metadata_category: str): + self.load_path = load_path + self.tag_to_id = tag_to_id + self.metadata_category = metadata_category + self.tokenizer = AbsTokenizer() + self.index = [] + + self.file_buff = open(self.load_path, "rb") + self.mmap_obj = mmap.mmap( + self.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + while True: + pos = self.mmap_obj.tell() + line = self.mmap_obj.readline() + if not line: + break + self.index.append(pos) + + def __getitem__(self, idx: int): + pos = self.index[idx] + self.mmap_obj.seek(pos) + + raw_data = self.mmap_obj.readline().decode("utf-8") + json_data = json.loads(raw_data) + + emb = json_data["emb"] + metadata = json_data["metadata"] + tag = metadata.get(self.metadata_category, "other") + tag = tag if tag in self.tag_to_id.keys() else "other" + + assert tag in self.tag_to_id + tag_tensor = torch.tensor(self.tag_to_id[tag]) + emb_tensor = torch.tensor(emb) + + return emb_tensor, tag_tensor + + def __len__(self): + return len(self.index) + + @classmethod + def export_worker_init_fn(cls): + def worker_init_fn(worker_id: int): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + + if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: + dataset.mmap_obj.close() + + f = open(dataset.load_path, "rb") + dataset.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + return worker_init_fn + + @classmethod + def build( + cls, + midi_dataset_load_path: str, + save_path: str, + slice_len_notes: int, + max_seq_len: int, + batch_size: int, + embedding_hook: Callable, + **embedding_hook_kwargs, + ): + def batch_producer( + results_queue: queue.Queue, + batch_queue: queue.Queue, + batch_size: int, + total_workers: int, + ): + buffer = [] + termination_signals = 0 + while termination_signals < total_workers: + if batch_queue.qsize() >= 5: + time.sleep(1) + try: + result = results_queue.get(timeout=0.01) + if result is None: + termination_signals += 1 + continue + buffer.append(result) + if len(buffer) == batch_size: + batch_queue.put(buffer) + buffer = [] + except queue.Empty: + continue + + if buffer: + batch_queue.put(buffer) + + def producer( + midi_dataset_load_path: str, + midi_dict_queue: queue.Queue, + num_workers: int, + ): + cnt = 0 + with jsonlines.open(midi_dataset_load_path, "r") as midi_dataset: + for midi_dict in midi_dataset: + while midi_dict_queue.qsize() >= 1000: + time.sleep(0.1) + midi_dict_queue.put(midi_dict) + cnt += 1 + + if cnt % 500 == 0: + print(f"Finished {cnt}") + + for _ in range(num_workers): + midi_dict_queue.put(None) + + def worker( + midi_dict_queue: queue.Queue, + results_queue: queue.Queue, + slice_len_notes: int, + max_seq_len: int, + ): + tokenizer = AbsTokenizer() + + while True: + midi_dict = midi_dict_queue.get() + if midi_dict is None: + results_queue.put(None) + break + + while results_queue.qsize() > 1000: + time.sleep(0.5) + + _result = process_entry( + entry=midi_dict, + slice_len_notes=slice_len_notes, + max_seq_len=max_seq_len, + tokenizer=tokenizer, + ) + for _sub_result in _result: + results_queue.put(_sub_result) + + assert os.path.isfile(midi_dataset_load_path) + assert os.path.isfile(save_path) is False + + TOTAL_WORKERS = 8 + 
write_executor = ThreadPoolExecutor(max_workers=1) + results_queue = multiprocessing.Queue() + midi_dict_queue = multiprocessing.Queue() + batch_queue = multiprocessing.Queue() + producer_process = multiprocessing.Process( + target=producer, + args=(midi_dataset_load_path, midi_dict_queue, TOTAL_WORKERS), + ) + batch_producer_process = multiprocessing.Process( + target=batch_producer, + args=(results_queue, batch_queue, batch_size, TOTAL_WORKERS), + ) + worker_processes = [ + multiprocessing.Process( + target=worker, + args=( + midi_dict_queue, + results_queue, + slice_len_notes, + max_seq_len, + ), + ) + for _ in range(TOTAL_WORKERS) + ] + + producer_process.start() + batch_producer_process.start() + for p in worker_processes: + p.start() + + with jsonlines.open(save_path, "w") as writer: + while batch_producer_process.is_alive() or not batch_queue.empty(): + try: + batch = batch_queue.get(timeout=0.01) + + _seqs = [item["seq"] for item in batch] + _metadata = [item["metadata"] for item in batch] + _embs = embedding_hook(seqs=_seqs, **embedding_hook_kwargs) + + write_objs = [ + {"seq": s, "emb": e, "metadata": m} + for s, e, m in zip(_seqs, _embs, _metadata) + ] + write_executor.submit(write_entries, writer, write_objs) + + except queue.Empty: + continue + + write_executor.shutdown(wait=True) + + +def _get_optim( + lr: float, + model: nn.Module, + total_steps: int, + warmup: int = 100, + end_ratio: int = 0.1, +): + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-5, + ) + + warmup_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.000001, + end_factor=1, + total_iters=warmup, + ) + linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=total_steps - warmup, + ) + + lr_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_lrs, linear_decay_lrs], + milestones=[warmup], + ) + + return optimizer, lr_scheduler + + +class ClassifierHead(nn.Module): + def __init__(self, d_emb: int, num_class: int): + super().__init__() + self.linear = nn.Linear(d_emb, num_class) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +def _train( + accelerator: accelerate.Accelerator, + model: nn.Module, + train_dataloader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + num_epochs: int = 1, +): + TRAILING_LOSS_STEPS = 100 + loss = torch.tensor([0.0]) + trailing_loss = 0 + lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) + loss_buffer = [] + + model.train() + loss_fn = nn.CrossEntropyLoss() + + for _epoch in range(num_epochs): + for __step, batch in ( + pbar := tqdm(enumerate(train_dataloader), leave=False) + ): + pbar.set_postfix_str( + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing={round(trailing_loss, 4)}" + ) + + emb, tag_ids = batch + tag_ids = tag_ids.view(-1) + + logits = model(emb) + loss = loss_fn(logits, tag_ids) + + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + + if accelerator.is_main_process: + accelerator.save_state("/mnt/ssd1/aria/test") + + return model + + +def train_classifier( + emb_d: int, + 
train_dataset: EvaluationDataset,
+    tag_to_id: dict,
+    batch_size: int,
+):
+    num_epochs = 1
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset=train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=24,
+        worker_init_fn=EvaluationDataset.export_worker_init_fn(),
+    )
+
+    model = ClassifierHead(
+        d_emb=emb_d,
+        num_class=len(tag_to_id.keys()),
+    )
+    optimizer, scheduler = _get_optim(
+        lr=3e-4,
+        model=model,
+        total_steps=num_epochs * len(train_dataloader),
+    )
+    accelerator = accelerate.Accelerator()
+
+    model, train_dataloader, optimizer, scheduler = accelerator.prepare(
+        model,
+        train_dataloader,
+        optimizer,
+        scheduler,
+    )
+
+    return _train(
+        accelerator=accelerator,
+        model=model,
+        train_dataloader=train_dataloader,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        num_epochs=num_epochs,
+    )
+
+
+def evaluate_model(
+    model: nn.Module,
+    val_dataset_path: str,
+    metadata_category: str,
+    tag_to_id: dict,
+):
+    id_to_tag = {v: k for k, v in tag_to_id.items()}
+    val_dataset = EvaluationDataset(
+        load_path=val_dataset_path,
+        tag_to_id=tag_to_id,
+        metadata_category=metadata_category,
+    )
+    model = model.cpu().eval()
+
+    dist = {k: {"correct": 0, "total": 0} for k in tag_to_id.keys()}
+    pred_dist = {k: 0 for k in tag_to_id.keys()}
+
+    for midi_emb, tag_id in val_dataset:
+        with torch.no_grad():
+            # midi_emb is already a tensor; re-wrapping it in torch.tensor()
+            # only copies it and raises a UserWarning.
+            logits = model(midi_emb.view(1, -1))
+            probs = F.softmax(logits, dim=-1)
+            pred_tag_id = probs.argmax(dim=-1).item()
+
+        true_tag = id_to_tag[tag_id.item()]
+        pred_tag = id_to_tag[pred_tag_id]
+
+        dist[true_tag]["total"] += 1
+        pred_dist[pred_tag] += 1
+
+        if pred_tag_id == tag_id.item():
+            dist[true_tag]["correct"] += 1
+
+    total_correct = sum(v["correct"] for v in dist.values())
+    total_samples = sum(v["total"] for v in dist.values())
+    print(f"Total accuracy: {total_correct/total_samples}")
+
+    for tag in tag_to_id.keys():
+        TP = dist[tag]["correct"]
+        FN = dist[tag]["total"] - TP
+        FP = pred_dist[tag] - TP
+        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
+        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
+        f1 = (
+            2 * precision * recall / (precision + recall)
+            if (precision + recall) > 0
+            else 0
+        )
+        # Per-tag accuracy over true instances (TP / total) is identical to
+        # recall, so it is not reported separately.
+        print(
+            f"{tag} -- Precision: {precision}, Recall: {recall}, F1: {f1}"
+        )
+
+
+# TODO: Move this to the build_embedding_eval_datasets.py script
+def build_baseline_dataset():
+    MAX_SEQ_LEN = 512
+    MODEL_PATH = "/mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors"
+
+    tokenizer = AbsTokenizer()
+    model_state = _load_weight(MODEL_PATH, "cuda")
+    model_state = {
+        k.replace("_orig_mod.", ""): v for k, v in model_state.items()
+    }
+    pretrained_model_config = ModelConfig(**load_model_config("medium"))
+    pretrained_model_config.set_vocab_size(tokenizer.vocab_size)
+    pretrained_model_config.grad_checkpoint = False
+    pretrained_model = TransformerLM(pretrained_model_config)
+    pretrained_model.load_state_dict(model_state)
+    pretrained_model.eval()
+
+    global model_forward
+    model_forward = torch.compile(
+        model_forward,
+        mode="reduce-overhead",
+        fullgraph=True,
+    )
+
+    EvaluationDataset.build(
+        midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_train.jsonl",
+        save_path="/mnt/ssd1/aria/data/train.jsonl",
+        max_seq_len=MAX_SEQ_LEN,
+        slice_len_notes=165,
+        batch_size=128,
+        embedding_hook=functools.partial(
+            get_baseline_embedding, pool_mode="mean"
+        ),
+        hook_model=pretrained_model.model.cuda(),
+        hook_max_seq_len=MAX_SEQ_LEN,
+        hook_tokenizer=tokenizer,
+    )
+
+
+def 
eval_all(): + metadata_category = "music_period" + tag_to_id = CATEGORY_TAGS[metadata_category] + + dataset = EvaluationDataset( + load_path="/mnt/ssd1/aria/data/paper/classification/period-aria/train-aria.jsonl", + metadata_category=metadata_category, + tag_to_id=tag_to_id, + ) + model = train_classifier( + emb_d=512, + train_dataset=dataset, + batch_size=8, + tag_to_id=tag_to_id, + ) + print("ARIA aria_midi-test:") + evaluate_model( + model=model, + val_dataset_path="/mnt/ssd1/aria/data/paper/classification/period-aria/test-aria.jsonl", + metadata_category=metadata_category, + tag_to_id=tag_to_id, + ) + + ### + + dataset = EvaluationDataset( + load_path="/mnt/ssd1/aria/data/paper/classification/period-aria/train-m3.jsonl", + metadata_category=metadata_category, + tag_to_id=tag_to_id, + ) + model = train_classifier( + emb_d=768, + train_dataset=dataset, + batch_size=8, + tag_to_id=tag_to_id, + ) + print("M3 aria_midi-test:") + evaluate_model( + model=model, + val_dataset_path="/mnt/ssd1/aria/data/paper/classification/period-aria/test-m3.jsonl", + metadata_category=metadata_category, + tag_to_id=tag_to_id, + ) + + +if __name__ == "__main__": + # TODO: Move this + eval_all() diff --git a/aria/embeddings/m3/__init__.py b/aria/embeddings/m3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aria/embeddings/m3/config.py b/aria/embeddings/m3/config.py new file mode 100644 index 0000000..893a1c0 --- /dev/null +++ b/aria/embeddings/m3/config.py @@ -0,0 +1,79 @@ +EVAL_SPLIT = 0.01 # Fraction of training data used for evaluation +WANDB_KEY = "" # Weights and Biases API key + +# -------------------- Configuration for M3 Training -------------------- +M3_TRAIN_FOLDERS = [ + "" # Directory containing training data for M3 +] + +M3_EVAL_FOLDERS = [ + "" # Directory containing evaluation data for M3 (optional) +] + +PATCH_SIZE = 64 # Size of each patch +PATCH_LENGTH = 512 # Length of the patches +PATCH_NUM_LAYERS = 12 # Number of layers in the encoder +TOKEN_NUM_LAYERS = 3 # Number of layers in the decoder +M3_HIDDEN_SIZE = 768 # Size of the hidden layer + +M3_NUM_EPOCH = 100 # Maximum number of epochs for training +M3_LEARNING_RATE = 1e-4 # Learning rate for the optimizer +M3_BATCH_SIZE = 16 # Batch size per GPU (single card) during training +M3_MASK_RATIO = 0.45 # Ratio of masked elements during training +M3_DETERMINISTIC = True # Ensures deterministic results with random seeds +M3_WANDB_LOG = True # Enable logging to Weights and Biases +M3_LOAD_CKPT = True # Load model weights from a checkpoint if available + +M3_WEIGHTS_PATH = ( + "weights_m3"+ + "_h_size_" + str(M3_HIDDEN_SIZE) + + "_t_layers_" + str(TOKEN_NUM_LAYERS) + + "_p_layers_" + str(PATCH_NUM_LAYERS) + + "_p_size_" + str(PATCH_SIZE) + + "_p_length_" + str(PATCH_LENGTH) + + "_lr_" + str(M3_LEARNING_RATE) + + "_batch_" + str(M3_BATCH_SIZE) + + "_mask_" + str(M3_MASK_RATIO) + ".pth" +) # Path to store the model weights +M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs + +# -------------------- Configuration for CLaMP3 Training ---------------- +CLAMP3_TRAIN_JSONL = "" # Path to the JSONL file with training data for CLaMP3 +CLAMP3_EVAL_JSONL = "" # Path to the JSONL file with evaluation data for CLaMP3 (optional) + +CLAMP3_HIDDEN_SIZE = 768 # Size of the hidden layer +TEXT_MODEL_NAME = "FacebookAI/xlm-roberta-base" # Name of the pre-trained text model +MAX_TEXT_LENGTH = 128 # Maximum allowed length for text input + +AUDIO_HIDDEN_SIZE = 768 # Size of the hidden layer 
for audio features +AUDIO_NUM_LAYERS = 12 # Number of layers in the audio encoder +MAX_AUDIO_LENGTH = 128 # Maximum allowed length for audio input + +CLAMP3_NUM_EPOCH = 100 # Maximum number of epochs for training +CLAMP3_LEARNING_RATE = 1e-5 # Learning rate for the optimizer +CLAMP3_BATCH_SIZE = 256 # Batch size per GPU (single card) during training +LOGIT_SCALE = 1 # Scaling factor for contrastive loss + +FREEZE_TEXT = False # Freeze the weights of the text model and text projection layer +TEXT_DROPOUT = True # Whether to apply dropout during text processing +CLAMP3_DETERMINISTIC = True # Ensures deterministic results with random seeds +CLAMP3_LOAD_M3 = True # Load weights from the M3 model +CLAMP3_WANDB_LOG = True # Enable logging to Weights and Biases +CLAMP3_LOAD_CKPT = True # Load weights from a checkpoint if available +SAVE_EVERY = 5 # Save model weights every SAVE_EVERY epochs + +CLAMP3_WEIGHTS_PATH = ( + "weights_clamp3_saas" + + "_h_size_" + str(CLAMP3_HIDDEN_SIZE) + + "_t_model_" + TEXT_MODEL_NAME.replace("/", "_") + + "_t_length_" + str(MAX_TEXT_LENGTH) + + "_a_size_" + str(AUDIO_HIDDEN_SIZE) + + "_a_layers_" + str(AUDIO_NUM_LAYERS) + + "_a_length_" + str(MAX_AUDIO_LENGTH) + + "_s_size_" + str(M3_HIDDEN_SIZE) + + "_s_layers_" + str(PATCH_NUM_LAYERS) + + "_p_size_" + str(PATCH_SIZE) + + "_p_length_" + str(PATCH_LENGTH) + ".pth" + +) # Path to store CLaMP3 model weights +CLAMP3_LOGS_PATH = CLAMP3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs diff --git a/aria/embeddings/m3/emb.py b/aria/embeddings/m3/emb.py new file mode 100644 index 0000000..9b67ef2 --- /dev/null +++ b/aria/embeddings/m3/emb.py @@ -0,0 +1,216 @@ +import os +import torch +import mido +from transformers import BertConfig + +from aria.embeddings.m3.config import ( + AUDIO_HIDDEN_SIZE, + AUDIO_NUM_LAYERS, + MAX_AUDIO_LENGTH, + M3_HIDDEN_SIZE, + PATCH_NUM_LAYERS, + PATCH_LENGTH, + PATCH_SIZE, + CLAMP3_HIDDEN_SIZE, + CLAMP3_LOAD_M3, + TEXT_MODEL_NAME, +) + +from aria.embeddings.m3.utils import CLaMP3Model, M3Patchilizer + + +def msg_to_str(msg): + str_msg = "" + for key, value in msg.dict().items(): + str_msg += " " + str(value) + return str_msg.strip().encode("unicode_escape").decode("utf-8") + + +def load_midi( + filename: str | None = None, + mid: mido.MidiFile | None = None, + m3_compatible: bool = True, +): + """ + Load a MIDI file and convert it to MTF format. + """ + + if mid is None: + assert os.path.isfile(filename) + mid = mido.MidiFile(filename) + + msg_list = ["ticks_per_beat " + str(mid.ticks_per_beat)] + + # Merge tracks manually using mido.merge_tracks() + merged = mido.merge_tracks(mid.tracks) + + for msg in merged: + if m3_compatible and msg.is_meta: + if msg.type in [ + "text", + "copyright", + "track_name", + "instrument_name", + "lyrics", + "marker", + "cue_marker", + "device_name", + ]: + continue + str_msg = msg_to_str(msg) + msg_list.append(str_msg) + + return "\n".join(msg_list) + + +def load_clamp3_model(checkpoint_path: str): + """ + Loads the CLaMP3 model along with its configuration and M3Patchilizer. + The model weights are loaded from the checkpoint specified in CLAMP3_WEIGHTS_PATH. + """ + # Build the configurations for audio and symbolic (M3) parts. 
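+    # Sizing follows the usual BERT conventions, assuming the constants in
+    # config.py: one attention head per 64 hidden dims and a 4x feed-forward
+    # expansion, so e.g. AUDIO_HIDDEN_SIZE = 768 yields 12 heads and an
+    # intermediate size of 3072.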
+ audio_config = BertConfig( + vocab_size=1, + hidden_size=AUDIO_HIDDEN_SIZE, + num_hidden_layers=AUDIO_NUM_LAYERS, + num_attention_heads=AUDIO_HIDDEN_SIZE // 64, + intermediate_size=AUDIO_HIDDEN_SIZE * 4, + max_position_embeddings=MAX_AUDIO_LENGTH, + ) + symbolic_config = BertConfig( + vocab_size=1, + hidden_size=M3_HIDDEN_SIZE, + num_hidden_layers=PATCH_NUM_LAYERS, + num_attention_heads=M3_HIDDEN_SIZE // 64, + intermediate_size=M3_HIDDEN_SIZE * 4, + max_position_embeddings=PATCH_LENGTH, + ) + + # Instantiate the CLaMP3 model. + model = CLaMP3Model( + audio_config=audio_config, + symbolic_config=symbolic_config, + text_model_name=TEXT_MODEL_NAME, + hidden_size=CLAMP3_HIDDEN_SIZE, + load_m3=CLAMP3_LOAD_M3, + ) + model = model.to("cuda") + model.eval() + + # Determine checkpoint path. + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + # Load checkpoint weights. + checkpoint = torch.load( + checkpoint_path, map_location="cuda", weights_only=True + ) + model.load_state_dict(checkpoint["model"]) + + # Instantiate the patchilizer from utils. + patchilizer = M3Patchilizer() + + return model, patchilizer + + +def get_midi_embedding( + mid: mido.MidiFile, + model: torch.nn.Module, + patchilizer: M3Patchilizer, + get_global=True, +): + device = "cuda" + # Step 1: Convert MIDI to an MTF-format string. + mtf_str = load_midi(mid=mid, m3_compatible=True) + + # Step 3: Encode the MTF string to patches. + # The patchilizer returns a list of patches (each patch is a list of token IDs). + patches = patchilizer.encode(mtf_str, add_special_patches=True) + + # Convert the list of patches to a tensor. + # Each patch is of length PATCH_SIZE; we assume tokens are integers. + # The resulting tensor shape will be [num_patches, PATCH_SIZE]. + token_tensor = torch.tensor(patches, dtype=torch.long).to(device) + + # Step 4: Create segments of fixed length PATCH_LENGTH. + num_tokens = token_tensor.size(0) + segments = [] + seg_weights = [] + for i in range(0, num_tokens, PATCH_LENGTH): + seg = token_tensor[i : i + PATCH_LENGTH] + cur_len = seg.size(0) + segments.append(seg) + seg_weights.append(cur_len) + # For global feature extraction, ensure the last segment is exactly PATCH_LENGTH tokens + if num_tokens > PATCH_LENGTH: + segments[-1] = token_tensor[-PATCH_LENGTH:] + seg_weights[-1] = segments[-1].size(0) + + # Step 5: Process each segment to obtain features. + processed_feats = [] + for seg in segments: + cur_len = seg.size(0) + # Pad the segment if it's shorter than PATCH_LENGTH. + if cur_len < PATCH_LENGTH: + pad = torch.full( + ( + PATCH_LENGTH - cur_len, + token_tensor.size(1), + ), # include PATCH_SIZE dimension + patchilizer.pad_token_id, + dtype=torch.long, + device=device, + ) + seg = torch.cat([seg, pad], dim=0) + seg = seg.unsqueeze(0) # Add batch dimension. + # Create a mask: 1 for valid tokens, 0 for padding. + mask = torch.cat( + [ + torch.ones(cur_len, device=device), + torch.zeros(PATCH_LENGTH - cur_len, device=device), + ], + dim=0, + ).unsqueeze(0) + with torch.no_grad(): + feat = model.get_symbolic_features( + symbolic_inputs=seg, symbolic_masks=mask, get_global=get_global + ) + # When not getting a global feature, you may wish to trim the output to the valid length. + if not get_global: + feat = feat[:, : int(mask.sum().item()), :] + processed_feats.append(feat) + + # Step 6: Combine segment features into a single embedding. + if not get_global: + # Concatenate features along the time dimension. 
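+        # (Per-token mode: segment features are stacked along the time
+        # dimension after the per-segment trimming above. The get_global
+        # branch below instead reduces each segment to one vector and takes
+        # a length-weighted mean; note that for inputs longer than
+        # PATCH_LENGTH the final segment is the last PATCH_LENGTH patches,
+        # so it is weighted by the full window size.)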
+ embedding = torch.cat( + [feat.squeeze(0) for feat in processed_feats], dim=0 + ) + else: + # For a global embedding, compute a weighted average of segment features. + feats = torch.stack( + [feat.squeeze(0) for feat in processed_feats], dim=0 + ) + weights = torch.tensor( + seg_weights, dtype=torch.float, device=device + ).view(-1, 1) + embedding = (feats * weights).sum(dim=0) / weights.sum() + + return embedding.view(-1) + + +# Example usage: +if __name__ == "__main__": + midi_file_path = "/home/loubb/Dropbox/shared/test/test.mid" + model, patchilizer = load_clamp3_model( + "/home/loubb/work/clamp3/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth" + ) + mid = mido.MidiFile(midi_file_path) + embedding = get_midi_embedding( + mid, + model=model, + patchilizer=patchilizer, + get_global=True, + ) + print(embedding) + print(embedding.shape) diff --git a/aria/embeddings/m3/utils.py b/aria/embeddings/m3/utils.py new file mode 100644 index 0000000..3952ac7 --- /dev/null +++ b/aria/embeddings/m3/utils.py @@ -0,0 +1,702 @@ +import re +import os +import math +import torch +import random +from aria.embeddings.m3.config import * +from unidecode import unidecode +from torch.nn import functional as F +from transformers import ( + AutoModel, + BertModel, + GPT2LMHeadModel, + PreTrainedModel, + GPT2Config, +) + +try: + import torch.distributed.nn + from torch import distributed as dist + + has_distributed = True +except ImportError: + has_distributed = False + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +class ClipLoss(torch.nn.Module): + + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def gather_features( + self, + image_features, + text_features, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False, + ): + assert ( + has_distributed + ), "torch.distributed did not import correctly, please use a PyTorch version with support." 
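+        # Two gather strategies follow: Horovod's allgather, or
+        # torch.distributed. With gather_with_grad=True the differentiable
+        # torch.distributed.nn.all_gather is used, so gradients flow through
+        # every rank's features; the plain all_gather path fills fresh
+        # buffers with no gradient, which is why the local shard is spliced
+        # back in afterwards when local_loss is False.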
+ if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + else: + with torch.no_grad(): + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features = list( + all_image_features.chunk(world_size, dim=0) + ) + gathered_text_features = list( + all_text_features.chunk(world_size, dim=0) + ) + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat( + gathered_image_features, dim=0 + ) + all_text_features = torch.cat(gathered_text_features, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat( + torch.distributed.nn.all_gather(image_features), dim=0 + ) + all_text_features = torch.cat( + torch.distributed.nn.all_gather(text_features), dim=0 + ) + else: + gathered_image_features = [ + torch.zeros_like(image_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + def get_ground_truth(self, device, num_logits) -> torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = self.gather_features( + image_features, + text_features, + self.local_loss, + self.gather_with_grad, + self.rank, + self.world_size, + self.use_horovod, + ) + + if self.local_loss: + logits_per_image = ( + logit_scale * image_features @ all_text_features.T + ) + logits_per_text = ( + logit_scale * text_features @ all_image_features.T + ) + else: + logits_per_image = ( + logit_scale * all_image_features @ all_text_features.T + ) + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward( + self, image_features, text_features, logit_scale, output_dict=False + ): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits( + image_features, text_features, logit_scale + ) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class 
M3Patchilizer: + def __init__(self): + self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"] + self.regexPattern = ( + "(" + "|".join(map(re.escape, self.delimiters)) + ")" + ) + self.pad_token_id = 0 + self.bos_token_id = 1 + self.eos_token_id = 2 + self.mask_token_id = 3 + + def split_bars(self, body): + bars = re.split(self.regexPattern, "".join(body)) + bars = list(filter(None, bars)) # remove empty strings + if bars[0] in self.delimiters: + bars[1] = bars[0] + bars[1] + bars = bars[1:] + bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)] + return bars + + def bar2patch(self, bar, patch_size=PATCH_SIZE): + patch = ( + [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id] + ) + patch = patch[:patch_size] + patch += [self.pad_token_id] * (patch_size - len(patch)) + return patch + + def patch2bar(self, patch): + return "".join( + chr(idx) if idx > self.mask_token_id else "" for idx in patch + ) + + def encode( + self, + item, + patch_size=PATCH_SIZE, + add_special_patches=False, + truncate=False, + random_truncate=False, + ): + item = item.replace("L:1/8\n", "") + item = unidecode(item) + lines = re.findall(r".*?\n|.*$", item) + lines = list(filter(None, lines)) # remove empty lines + + patches = [] + + if lines[0].split(" ")[0] == "ticks_per_beat": + patch = "" + for line in lines: + if patch.startswith(line.split(" ")[0]) and ( + len(patch) + len(" ".join(line.split(" ")[1:])) + <= patch_size - 2 + ): + patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:]) + else: + if patch: + patches.append(patch) + patch = line + if patch != "": + patches.append(patch) + else: + for line in lines: + if len(line) > 1 and ( + (line[0].isalpha() and line[1] == ":") + or line.startswith("%%") + ): + patches.append(line) + else: + bars = self.split_bars(line) + if bars: + bars[-1] += "\n" + patches.extend(bars) + + if add_special_patches: + bos_patch = chr(self.bos_token_id) * patch_size + eos_patch = chr(self.eos_token_id) * patch_size + patches = [bos_patch] + patches + [eos_patch] + + if len(patches) > PATCH_LENGTH and truncate: + choices = ["head", "tail", "middle"] + choice = random.choice(choices) + if choice == "head" or random_truncate == False: + patches = patches[:PATCH_LENGTH] + elif choice == "tail": + patches = patches[-PATCH_LENGTH:] + else: + start = random.randint(1, len(patches) - PATCH_LENGTH) + patches = patches[start : start + PATCH_LENGTH] + + patches = [self.bar2patch(patch) for patch in patches] + + return patches + + def decode(self, patches): + return "".join(self.patch2bar(patch) for patch in patches) + + +class M3PatchEncoder(PreTrainedModel): + def __init__(self, config): + super(M3PatchEncoder, self).__init__(config) + self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, M3_HIDDEN_SIZE) + torch.nn.init.normal_(self.patch_embedding.weight, std=0.02) + self.base = BertModel(config=config) + self.pad_token_id = 0 + self.bos_token_id = 1 + self.eos_token_id = 2 + self.mask_token_id = 3 + + def forward( + self, + input_patches, # [batch_size, seq_length, hidden_size] + input_masks, + ): # [batch_size, seq_length] + # Transform input_patches into embeddings + input_patches = torch.nn.functional.one_hot( + input_patches, num_classes=128 + ) + input_patches = input_patches.reshape( + len(input_patches), -1, PATCH_SIZE * 128 + ).type(torch.FloatTensor) + input_patches = self.patch_embedding(input_patches.to(self.device)) + + # Apply BERT model to input_patches and input_masks + return self.base( + inputs_embeds=input_patches, 
attention_mask=input_masks + ) + + +class M3TokenDecoder(PreTrainedModel): + def __init__(self, config): + super(M3TokenDecoder, self).__init__(config) + self.base = GPT2LMHeadModel(config=config) + self.pad_token_id = 0 + self.bos_token_id = 1 + self.eos_token_id = 2 + self.mask_token_id = 3 + + def forward( + self, patch_features, target_patches # [batch_size, hidden_size] + ): # [batch_size, seq_length] + # get input embeddings + inputs_embeds = torch.nn.functional.embedding( + target_patches, self.base.transformer.wte.weight + ) + + # concatenate the encoded patches with the input embeddings + inputs_embeds = torch.cat( + (patch_features.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1 + ) + + # preparing the labels for model training + target_masks = target_patches == self.pad_token_id + target_patches = target_patches.clone().masked_fill_(target_masks, -100) + + # get the attention mask + target_masks = ~target_masks + target_masks = target_masks.type(torch.int) + + return self.base( + inputs_embeds=inputs_embeds, + attention_mask=target_masks, + labels=target_patches, + ) + + def generate(self, patch_feature, tokens): + # reshape the patch_feature and tokens + patch_feature = patch_feature.reshape(1, 1, -1) + tokens = tokens.reshape(1, -1) + + # get input embeddings + tokens = torch.nn.functional.embedding( + tokens, self.base.transformer.wte.weight + ) + + # concatenate the encoded patches with the input embeddings + tokens = torch.cat((patch_feature, tokens[:, 1:, :]), dim=1) + + # get the outputs from the model + outputs = self.base(inputs_embeds=tokens) + + # get the probabilities of the next token + probs = torch.nn.functional.softmax( + outputs.logits.squeeze(0)[-1], dim=-1 + ) + + return probs.detach().cpu().numpy() + + +class M3Model(PreTrainedModel): + def __init__(self, encoder_config, decoder_config): + super(M3Model, self).__init__(encoder_config) + self.encoder = M3PatchEncoder(encoder_config) + self.decoder = M3TokenDecoder(decoder_config) + self.pad_token_id = 0 + self.bos_token_id = 1 + self.eos_token_id = 2 + self.mask_token_id = 3 + + def forward( + self, + input_patches, # [batch_size, seq_length, hidden_size] + input_masks, # [batch_size, seq_length] + selected_indices, # [batch_size, seq_length] + target_patches, + ): # [batch_size, seq_length, hidden_size] + input_patches = input_patches.reshape( + len(input_patches), -1, PATCH_SIZE + ).to(self.device) + input_masks = input_masks.to(self.device) + selected_indices = selected_indices.to(self.device) + target_patches = target_patches.reshape( + len(target_patches), -1, PATCH_SIZE + ).to(self.device) + + # Pass the input_patches and input_masks through the encoder + outputs = self.encoder(input_patches, input_masks)["last_hidden_state"] + + # Use selected_indices to form target_patches + target_patches = target_patches[selected_indices.bool()] + patch_features = outputs[selected_indices.bool()] + + # Pass patch_features and target_patches through the decoder + return self.decoder(patch_features, target_patches) + + +class CLaMP3Model(PreTrainedModel): + def __init__( + self, + audio_config, + symbolic_config, + global_rank=None, + world_size=None, + text_model_name=TEXT_MODEL_NAME, + hidden_size=CLAMP3_HIDDEN_SIZE, + load_m3=CLAMP3_LOAD_M3, + ): + super(CLaMP3Model, self).__init__(symbolic_config) + + self.text_model = AutoModel.from_pretrained( + text_model_name + ) # Load the text model + self.text_proj = torch.nn.Linear( + self.text_model.config.hidden_size, hidden_size + ) # Linear layer for text projections + 
torch.nn.init.normal_( + self.text_proj.weight, std=0.02 + ) # Initialize weights with normal distribution + + self.symbolic_model = M3PatchEncoder( + symbolic_config + ) # Initialize the symbolic model + self.symbolic_proj = torch.nn.Linear( + M3_HIDDEN_SIZE, hidden_size + ) # Linear layer for symbolic projections + torch.nn.init.normal_( + self.symbolic_proj.weight, std=0.02 + ) # Initialize weights with normal distribution + + self.audio_model = BertModel(audio_config) # Initialize the audio model + self.audio_proj = torch.nn.Linear( + audio_config.hidden_size, hidden_size + ) # Linear layer for audio projections + torch.nn.init.normal_( + self.audio_proj.weight, std=0.02 + ) # Initialize weights with normal distribution + + if global_rank == None or world_size == None: + global_rank = 0 + world_size = 1 + + self.loss_fn = ClipLoss( + local_loss=False, + gather_with_grad=True, + cache_labels=False, + rank=global_rank, + world_size=world_size, + use_horovod=False, + ) + + if load_m3 and os.path.exists(M3_WEIGHTS_PATH): + checkpoint = torch.load( + M3_WEIGHTS_PATH, map_location="cpu", weights_only=True + ) + decoder_config = GPT2Config( + vocab_size=128, + n_positions=PATCH_SIZE, + n_embd=M3_HIDDEN_SIZE, + n_layer=TOKEN_NUM_LAYERS, + n_head=M3_HIDDEN_SIZE // 64, + n_inner=M3_HIDDEN_SIZE * 4, + ) + model = M3Model(symbolic_config, decoder_config) + model.load_state_dict(checkpoint["model"]) + self.symbolic_model = model.encoder + model = None + print( + f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}" + ) + + def set_trainable(self, freeze_list): + if "text_model" in freeze_list: + self.text_model.eval() + for param in self.text_model.parameters(): + param.requires_grad = False + print("Text Model Frozen") + else: + self.text_model.train() + for param in self.text_model.parameters(): + param.requires_grad = True + print("Text Model Training") + + if "text_proj" in freeze_list: + self.text_proj.eval() + for param in self.text_proj.parameters(): + param.requires_grad = False + print("Text Projection Layer Frozen") + else: + self.text_proj.train() + for param in self.text_proj.parameters(): + param.requires_grad = True + print("Text Projection Layer Training") + + if "symbolic_model" in freeze_list: + self.symbolic_model.eval() + for param in self.symbolic_model.parameters(): + param.requires_grad = False + print("Symbolic Model Frozen") + else: + self.symbolic_model.train() + for param in self.symbolic_model.parameters(): + param.requires_grad = True + print("Symbolic Model Training") + + if "symbolic_proj" in freeze_list: + self.symbolic_proj.eval() + for param in self.symbolic_proj.parameters(): + param.requires_grad = False + print("Symbolic Projection Layer Frozen") + else: + self.symbolic_proj.train() + for param in self.symbolic_proj.parameters(): + param.requires_grad = True + print("Symbolic Projection Layer Training") + + if "audio_model" in freeze_list: + self.audio_model.eval() + for param in self.audio_model.parameters(): + param.requires_grad = False + print("Audio Model Frozen") + else: + self.audio_model.train() + for param in self.audio_model.parameters(): + param.requires_grad = True + print("Audio Model Training") + + if "audio_proj" in freeze_list: + self.audio_proj.eval() + for param in self.audio_proj.parameters(): + param.requires_grad = False + print("Audio Projection Layer Frozen") + else: + self.audio_proj.train() + for param in self.audio_proj.parameters(): + param.requires_grad = True + print("Audio 
Projection Layer Training") + + def avg_pooling(self, input_features, input_masks): + input_masks = input_masks.unsqueeze(-1).to( + self.device + ) # add a dimension to match the feature dimension + input_features = ( + input_features * input_masks + ) # apply mask to input_features + avg_pool = input_features.sum(dim=1) / input_masks.sum( + dim=1 + ) # calculate average pooling + + return avg_pool + + def get_text_features(self, text_inputs, text_masks, get_global=False): + text_features = self.text_model( + text_inputs.to(self.device), + attention_mask=text_masks.to(self.device), + )["last_hidden_state"] + + if get_global: + text_features = self.avg_pooling(text_features, text_masks) + text_features = self.text_proj(text_features) + + return text_features + + def get_symbolic_features( + self, symbolic_inputs, symbolic_masks, get_global=False + ): + symbolic_features = self.symbolic_model( + symbolic_inputs.to(self.device), symbolic_masks.to(self.device) + )["last_hidden_state"] + + if get_global: + symbolic_features = self.avg_pooling( + symbolic_features, symbolic_masks + ) + symbolic_features = self.symbolic_proj(symbolic_features) + + return symbolic_features + + def get_audio_features(self, audio_inputs, audio_masks, get_global=False): + audio_features = self.audio_model( + inputs_embeds=audio_inputs.to(self.device), + attention_mask=audio_masks.to(self.device), + )["last_hidden_state"] + + if get_global: + audio_features = self.avg_pooling(audio_features, audio_masks) + audio_features = self.audio_proj(audio_features) + + return audio_features + + def forward( + self, + text_inputs, # [batch_size, seq_length] + text_masks, # [batch_size, seq_length] + music_inputs, # [batch_size, seq_length, hidden_size] + music_masks, # [batch_size, seq_length] + music_modality, + ): # "symbolic" or "audio" + # Compute the text features + text_features = self.get_text_features( + text_inputs, text_masks, get_global=True + ) + + # Compute the music features + if music_modality == "symbolic": + music_features = self.get_symbolic_features( + music_inputs, music_masks, get_global=True + ) + elif music_modality == "audio": + music_features = self.get_audio_features( + music_inputs, music_masks, get_global=True + ) + else: + raise ValueError( + "music_modality must be either 'symbolic' or 'audio'" + ) + + return self.loss_fn( + text_features, music_features, LOGIT_SCALE, output_dict=False + ) + + +def split_data(data, eval_ratio=EVAL_SPLIT): + random.shuffle(data) + split_idx = int(len(data) * eval_ratio) + eval_set = data[:split_idx] + train_set = data[split_idx:] + return train_set, eval_set + + +def mask_patches(target_patches, patchilizer, mode): + indices = list(range(len(target_patches))) + random.shuffle(indices) + selected_indices = indices[: math.ceil(M3_MASK_RATIO * len(indices))] + sorted_indices = sorted(selected_indices) + input_patches = torch.tensor(target_patches) + + if mode == "eval": + choice = "original" + else: + choice = random.choices( + ["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1] + )[0] + + if choice == "mask": + input_patches[sorted_indices] = torch.tensor( + [patchilizer.mask_token_id] * PATCH_SIZE + ) + elif choice == "shuffle": + for idx in sorted_indices: + patch = input_patches[idx] + try: + index_eos = (patch == patchilizer.eos_token_id).nonzero().item() + except: + index_eos = len(patch) + + indices = list(range(1, index_eos)) + random.shuffle(indices) + indices = [0] + indices + list(range(index_eos, len(patch))) + input_patches[idx] = patch[indices] + + 
selected_indices = torch.zeros(len(target_patches)) + selected_indices[sorted_indices] = 1.0 + + return input_patches, selected_indices + + +def remove_instrument_info(item): + # remove instrument information from symbolic music + lines = re.findall(r".*?\n|.*$", item) + lines = list(filter(None, lines)) + if lines[0].split(" ")[0] == "ticks_per_beat": + type = "mtf" + else: + type = "abc" + + cleaned_lines = [] + for line in lines: + if type == "abc" and line.startswith("V:"): + # find the position of " nm=" or " snm=" + nm_pos = line.find(" nm=") + snm_pos = line.find(" snm=") + # keep the part before " nm=" or " snm=" + if nm_pos != -1: + line = line[:nm_pos] + elif snm_pos != -1: + line = line[:snm_pos] + if nm_pos != -1 or snm_pos != -1: + line += "\n" + elif type == "mtf" and line.startswith("program_change"): + line = " ".join(line.split(" ")[:-1]) + " 0\n" + + cleaned_lines.append(line) + + return "".join(cleaned_lines) diff --git a/aria/embeddings/mert/__init__.py b/aria/embeddings/mert/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aria/embeddings/mert/emb.py b/aria/embeddings/mert/emb.py new file mode 100644 index 0000000..9f7d184 --- /dev/null +++ b/aria/embeddings/mert/emb.py @@ -0,0 +1,150 @@ +import torch +import tempfile +import shlex +import os +import torchaudio + +import torchaudio.transforms as T +import torch.nn.functional as F +import torch.nn as nn + +from ariautils.midi import MidiDict +from ariautils.tokenizer import AbsTokenizer + +from transformers import Wav2Vec2FeatureExtractor, AutoModel + + +def seq_to_audio_path( + seq: list, tokenizer: AbsTokenizer, pianoteq_exec_path: str +): + mid_temp = tempfile.NamedTemporaryFile(suffix=".mid", delete=False) + mid_path = mid_temp.name + mid_temp.close() + + mid = tokenizer.detokenize(seq) + mid.to_midi().save(mid_path) + + # Step 3: Create a temporary WAV file for output + audio_temp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + audio_path = audio_temp.name + audio_temp.close() # Close so CLI can write to it + + # Step 4: Run CLI command to generate audio using Pianoteq + # EXEC_PATH = "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE" + preset = "NY Steinway D Classical Recording" + + pianoteq_cmd = f"{shlex.quote(pianoteq_exec_path)} --preset {shlex.quote(preset)} --rate 24000 --midi {mid_path} --wav {audio_path}" + os.system(pianoteq_cmd) + + os.remove(mid_path) + + return audio_path + + +def compute_audio_embedding( + audio_path: str, model: nn.Module, processor, delete_audio: bool = False +) -> torch.Tensor: + """ + Loads the MERT-v1-330M model and processor, reads an mp3 file, + segments the audio into 5-second chunks, computes a segment embedding by averaging + over the time dimension (for each layer) and across layers, and then aggregates + the segment embeddings using average pooling to produce a final embedding. + + Parameters: + file_path (str): Path to the mp3 audio file. + + Returns: + torch.Tensor: The final audio embedding. 
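+
+    Notes:
+        ``model`` and ``processor`` are assumed to be the MERT-v1-330M
+        model and its Wav2Vec2FeatureExtractor, as returned by
+        load_mert_model(), with the model already on CUDA. If
+        ``delete_audio`` is True, the file at ``audio_path`` is removed
+        once the embedding has been computed.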
+ """ + # Load the mp3 file and convert to mono if necessary (waveform shape: [channels, time]) + waveform, sr = torchaudio.load(audio_path) + + if waveform.size(0) > 1: + waveform = waveform.mean(dim=0, keepdim=True) + + # Resample if needed (target_sr for MERT-v1-330M is typically 24000 Hz) + target_sr = processor.sampling_rate + if sr != target_sr: + resampler = T.Resample(orig_freq=sr, new_freq=target_sr) + waveform = resampler(waveform) + + # Remove channel dimension to get [n_samples] + waveform = waveform.squeeze(0) + + # Define the segment length for 5 seconds + segment_length = target_sr * 5 + total_samples = waveform.size(0) + segments = [] + + # Split the waveform into segments; pad the final segment if needed + for start in range(0, total_samples, segment_length): + segment = waveform[start : start + segment_length] + if segment.size(0) < segment_length: + padding = segment_length - segment.size(0) + segment = F.pad(segment, (0, padding)) + segments.append(segment.numpy()) + + # Process all segments in one batch. The processor accepts a list of numpy arrays. + inputs = processor(segments, sampling_rate=target_sr, return_tensors="pt") + inputs = {k: v.cuda() for k, v in inputs.items()} + + # Forward pass through the model in batch mode + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True) + + # outputs.hidden_states is a tuple of tensors (one per layer) of shape: + # [batch_size, time_steps, feature_dim] for each layer. + # Stack them to get shape: [num_layers, batch_size, time_steps, feature_dim] + hidden_states = torch.stack(outputs.hidden_states) + + # Average over the time dimension for each segment in each layer: + # result shape: [num_layers, batch_size, feature_dim] + layer_time_avg = hidden_states.mean(dim=2) + + # Average over layers to obtain one embedding per segment: + # result shape: [batch_size, feature_dim] + segment_embeddings = layer_time_avg.mean(dim=0) + + # Finally, average the segment embeddings to get a final representation: + # shape: [feature_dim] + final_embedding = segment_embeddings.mean(dim=0) + + if delete_audio is True: + os.remove(audio_path) + + return final_embedding + + +def load_mert_model(): + + return AutoModel.from_pretrained( + "m-a-p/MERT-v1-330M", trust_remote_code=True + ).cuda(), Wav2Vec2FeatureExtractor.from_pretrained( + "m-a-p/MERT-v1-330M", trust_remote_code=True + ) + + +def main(): + model = AutoModel.from_pretrained( + "m-a-p/MERT-v1-330M", trust_remote_code=True + ).cuda() + processor = Wav2Vec2FeatureExtractor.from_pretrained( + "m-a-p/MERT-v1-330M", trust_remote_code=True + ) + + tokenizer = AbsTokenizer() + mid_dict = MidiDict.from_midi("/home/loubb/Dropbox/shared/test.mid") + seq = tokenizer.tokenize(mid_dict) + + audio_path = seq_to_audio_path(seq, tokenizer) + emb = compute_audio_embedding( + audio_path=audio_path, + model=model, + processor=processor, + delete_audio=True, + ) + print(emb.shape) + + +if __name__ == "__main__": + main() diff --git a/paper/scripts/build_embedding_eval_datasets.py b/paper/scripts/build_embedding_eval_datasets.py new file mode 100644 index 0000000..8c09b8e --- /dev/null +++ b/paper/scripts/build_embedding_eval_datasets.py @@ -0,0 +1,208 @@ +import os +import argparse +import torch + +import torch.nn as nn + +from ariautils.tokenizer import AbsTokenizer +from aria.embeddings.evaluate import ( + EvaluationDataset, + get_aria_contrastive_embedding, + get_clamp3_embedding, + get_mert_embedding, +) + +MAX_SEQ_LEN = 1024 +NUM_SLICE_NOTES = 300 +SEQS_BATCH_SIZE = 128 + + +def 
aria_model_forward( + model: nn.Module, + idxs: torch.Tensor, +): + return model(idxs) + + +def build_aria_dataset( + midi_dataset_load_path: str, + embedding_dataset_save_path: str, + checkpoint_path: str, + max_batch_size, +): + from aria.config import load_model_config + from aria.utils import _load_weight + from aria.model import ModelConfig, TransformerEMB + + assert os.path.isfile(midi_dataset_load_path) + assert os.path.isfile(checkpoint_path) + assert not os.path.isfile(embedding_dataset_save_path) + + tokenizer = AbsTokenizer() + model_state = _load_weight(checkpoint_path, "cuda") + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + pretrained_model_config = ModelConfig(**load_model_config("medium-emb")) + pretrained_model_config.set_vocab_size(tokenizer.vocab_size) + pretrained_model_config.grad_checkpoint = False + pretrained_model = TransformerEMB(pretrained_model_config) + pretrained_model.load_state_dict(model_state) + pretrained_model.eval() + + hook_model_forward = torch.compile( + aria_model_forward, + mode="reduce-overhead", + fullgraph=True, + ) + + EvaluationDataset.build( + midi_dataset_load_path=midi_dataset_load_path, + save_path=embedding_dataset_save_path, + max_seq_len=MAX_SEQ_LEN, + slice_len_notes=NUM_SLICE_NOTES, + batch_size=SEQS_BATCH_SIZE, + embedding_hook=get_aria_contrastive_embedding, + hook_model=pretrained_model.cuda(), + hook_max_seq_len=MAX_SEQ_LEN, + hook_tokenizer=tokenizer, + hook_model_forward=hook_model_forward, + hook_max_batch_size=max_batch_size, + ) + + +def build_m3_dataset( + midi_dataset_load_path: str, + embedding_dataset_save_path: str, + checkpoint_path: str, +): + from aria.embeddings.m3.emb import load_clamp3_model + + assert os.path.isfile(midi_dataset_load_path) + assert os.path.isfile(checkpoint_path) + assert not os.path.isfile(embedding_dataset_save_path) + + tokenizer = AbsTokenizer() + model, patchilizer = load_clamp3_model(checkpoint_path=checkpoint_path) + + EvaluationDataset.build( + midi_dataset_load_path=midi_dataset_load_path, + save_path=embedding_dataset_save_path, + max_seq_len=MAX_SEQ_LEN, + slice_len_notes=NUM_SLICE_NOTES, + batch_size=SEQS_BATCH_SIZE, + embedding_hook=get_clamp3_embedding, + hook_model=model, + hook_patchilizer=patchilizer, + hook_tokenizer=tokenizer, + ) + + +def build_mert_dataset( + midi_dataset_load_path: str, + embedding_dataset_save_path: str, + pianoteq_exec_path: str, + pianoteq_num_procs: int, +): + from aria.embeddings.mert.emb import load_mert_model + + assert pianoteq_num_procs > 0 + assert os.path.isfile(midi_dataset_load_path) + assert not os.path.isfile(embedding_dataset_save_path) + + tokenizer = AbsTokenizer() + model, processor = load_mert_model() + + EvaluationDataset.build( + midi_dataset_load_path=midi_dataset_load_path, + save_path=embedding_dataset_save_path, + max_seq_len=MAX_SEQ_LEN, + slice_len_notes=NUM_SLICE_NOTES, + batch_size=SEQS_BATCH_SIZE, + embedding_hook=get_mert_embedding, + hook_model=model, + hook_processor=processor, + hook_tokenizer=tokenizer, + hook_pianoteq_exec_path=pianoteq_exec_path, + hook_pianoteq_num_procs=pianoteq_num_procs, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Process model and dataset paths." 
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        choices=["aria", "mert", "m3"],
+        required=True,
+    )
+    parser.add_argument(
+        "--model_cp_path",
+        type=str,
+        required=False,
+        help="Path from which to load the model.",
+    )
+    parser.add_argument(
+        "--dataset_load_path",
+        type=str,
+        required=True,
+        help="Path from which to load the dataset.",
+    )
+    parser.add_argument(
+        "--dataset_save_path",
+        type=str,
+        required=True,
+        help="Path where the dataset will be saved.",
+    )
+    parser.add_argument(
+        "--aria_max_batch_size",
+        type=int,
+        default=128,
+        help="Max batch size for aria embedding forward pass",
+    )
+    parser.add_argument(
+        "--mert_pianoteq_exec_path",
+        type=str,
+        required=False,
+        help="Path to pianoteq executable",
+    )
+    parser.add_argument(
+        "--mert_pianoteq_num_procs",
+        type=int,
+        default=16,
+        help="Number of processes to use for audio synthesis",
+    )
+
+    args = parser.parse_args()
+
+    if args.model == "aria":
+        assert args.aria_max_batch_size > 0
+
+        # TODO: TEST
+        build_aria_dataset(
+            midi_dataset_load_path=args.dataset_load_path,
+            embedding_dataset_save_path=args.dataset_save_path,
+            checkpoint_path=args.model_cp_path,
+            max_batch_size=args.aria_max_batch_size,
+        )
+    elif args.model == "m3":
+        # TODO: TEST
+        build_m3_dataset(
+            midi_dataset_load_path=args.dataset_load_path,
+            embedding_dataset_save_path=args.dataset_save_path,
+            checkpoint_path=args.model_cp_path,
+        )
+    elif args.model == "mert":
+        # TODO: TEST
+        build_mert_dataset(
+            midi_dataset_load_path=args.dataset_load_path,
+            embedding_dataset_save_path=args.dataset_save_path,
+            pianoteq_exec_path=args.mert_pianoteq_exec_path,
+            pianoteq_num_procs=args.mert_pianoteq_num_procs,
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paper/scripts/evaluate_embedding_with_probe.py b/paper/scripts/evaluate_embedding_with_probe.py
new file mode 100644
index 0000000..e69de29
diff --git a/paper/scripts/get_unique_tags.py b/paper/scripts/get_unique_tags.py
new file mode 100644
index 0000000..437da85
--- /dev/null
+++ b/paper/scripts/get_unique_tags.py
@@ -0,0 +1,31 @@
+import json
+
+
+# TODO: Make this into argparse script
+
+
+JSONL_FILE_PATH = (
+    "/mnt/ssd1/aria/data/paper/classification/form-aria/train-mididict.jsonl"
+)
+MUSIC_PERIOD_KEY = "form"
+
+
+def main():
+    unique_periods = set()
+
+    with open(JSONL_FILE_PATH, "r") as f:
+        for line in f:
+            # Parse each line as a JSON object.
+            record = json.loads(line)
+            # Get the 'metadata' dictionary, defaulting to an empty dict if missing.
+            metadata = record.get("metadata", {})
+            # Check if the metadata contains the music period key.
+            if MUSIC_PERIOD_KEY in metadata:
+                unique_periods.add(metadata[MUSIC_PERIOD_KEY])
+
+    # Print the list of unique music periods.
+ print(list(unique_periods)) + + +if __name__ == "__main__": + main() diff --git a/paper/scripts/make_cl_split.py b/paper/scripts/make_cl_split.py new file mode 100644 index 0000000..45bd511 --- /dev/null +++ b/paper/scripts/make_cl_split.py @@ -0,0 +1,113 @@ +import json +import random +from collections import Counter +from pathlib import Path + +from aria.datasets import build_mididict_dataset + +random.seed(42) + + +# TODO: Make this into argparse script + +DATASET_DIR = "/mnt/ssd1/aria-midi/final/v1/aria-midi-v1-emb-int" +METADATA_PATH = f"{DATASET_DIR}/metadata.json" +METADATA_CATEGORY = "form" +METADATA_TAGS = [ + "fugue", + "sonata", + "etude", + "nocturne", + "waltz", + "improvisation", +] + +MIDI_DATASET_TRAIN_SIZE = 10000 +MIDI_DATASET_TEST_SIZE = 1000 + +TRAIN_SAVE_PATH = ( + "/mnt/ssd1/aria/data/paper/classification/form-aria/train-mididict.jsonl" +) +TEST_SAVE_PATH = ( + "/mnt/ssd1/aria/data/paper/classification/form-aria/test-mididict.jsonl" +) + + +def get_midi_paths(): + with open(METADATA_PATH, "r") as f: + metadata_dict = json.load(f) + metadata_dict = {k: v["metadata"] for k, v in metadata_dict.items()} + + midi_paths = list(Path(DATASET_DIR).rglob("*.mid")) + buckets = {tag: [] for tag in METADATA_TAGS} + + for midi_file in midi_paths: + # Extract metadata key from file name (e.g., "000001_0" -> "1") + key = str(int(midi_file.stem.split("_")[0])) + metadata = metadata_dict.get(key) + if not metadata: + continue + tag = metadata.get(METADATA_CATEGORY) + if tag in METADATA_TAGS: + buckets[tag].append(midi_file) + + # Calculate the desired count per tag for both splits + num_tags = len(METADATA_TAGS) + desired_train_per_tag = MIDI_DATASET_TRAIN_SIZE // num_tags + desired_test_per_tag = MIDI_DATASET_TEST_SIZE // num_tags + + train_paths = [] + test_paths = [] + for tag, files in buckets.items(): + random.shuffle(files) + total_files = len(files) + total_desired = desired_train_per_tag + desired_test_per_tag + + if total_files >= total_desired: + # Enough files: use fixed numbers. + train_paths.extend(files[:desired_train_per_tag]) + test_paths.extend( + files[ + desired_train_per_tag : desired_train_per_tag + + desired_test_per_tag + ] + ) + else: + # Not enough files: split based on the desired ratio. 
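+            # e.g. with 6 tags, desired_train_per_tag = 10000 // 6 = 1666
+            # and desired_test_per_tag = 1000 // 6 = 166; a tag with only
+            # 900 files then contributes round(900 * 1666 / 1832) = 818
+            # files to train and the remaining 82 to test.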
+ train_ratio = desired_train_per_tag / total_desired + train_count = round(total_files * train_ratio) + test_count = total_files - train_count # all remaining go to test + train_paths.extend(files[:train_count]) + test_paths.extend(files[train_count : train_count + test_count]) + + def _extract_tag(midi_file): + key = str(int(midi_file.stem.split("_")[0])) + return metadata_dict.get(key, {}).get(METADATA_CATEGORY, "unknown") + + train_distribution = Counter(_extract_tag(mp) for mp in train_paths) + test_distribution = Counter(_extract_tag(mp) for mp in test_paths) + + print( + f"Finished with splits: train={len(train_paths)}, test={len(test_paths)}" + ) + print("Train distribution:", dict(train_distribution)) + print("Test distribution:", dict(test_distribution)) + + return train_paths, test_paths + + +def main(): + train_paths, test_paths = get_midi_paths() + + build_mididict_dataset( + mid_paths=train_paths, + stream_save_path=TRAIN_SAVE_PATH, + ) + build_mididict_dataset( + mid_paths=test_paths, + stream_save_path=TEST_SAVE_PATH, + ) + + +if __name__ == "__main__": + main() From 87c82ac58ef7a5f2f4206f786a629cf58b674492 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 10 Mar 2025 16:45:56 +0000 Subject: [PATCH 22/72] eval scripts --- aria/datasets.py | 37 +++-- aria/embeddings/evaluate.py | 142 ++++++++---------- aria/embeddings/m3/emb.py | 1 + aria/model.py | 4 - .../build_dataset/build_aria_dataset.sh | 7 + .../scripts/build_dataset/build_m3_dataset.sh | 5 + .../build_dataset/build_mert_dataset.sh | 6 + .../scripts/build_embedding_eval_datasets.py | 42 ++++-- .../scripts/evaluate_embedding_with_probe.py | 94 ++++++++++++ paper/scripts/evaluate_embeddings.sh | 6 + paper/scripts/get_unique_tags.py | 31 ---- .../{make_cl_split.py => make_eval_split.py} | 86 +++++++---- 12 files changed, 289 insertions(+), 172 deletions(-) create mode 100644 paper/scripts/build_dataset/build_aria_dataset.sh create mode 100644 paper/scripts/build_dataset/build_m3_dataset.sh create mode 100644 paper/scripts/build_dataset/build_mert_dataset.sh create mode 100644 paper/scripts/evaluate_embeddings.sh delete mode 100644 paper/scripts/get_unique_tags.py rename paper/scripts/{make_cl_split.py => make_eval_split.py} (61%) diff --git a/aria/datasets.py b/aria/datasets.py index 01524ab..dc2d6ae 100644 --- a/aria/datasets.py +++ b/aria/datasets.py @@ -330,7 +330,8 @@ def _preprocess_mididict(_mid_dict: MidiDict): def build_mididict_dataset( - dir: str, + dir: str | None = None, + mid_paths: list[str] = [], recur: bool = False, stream_save_path: str = None, overwrite: bool = False, @@ -398,13 +399,16 @@ def _get_mididicts_mp(_paths): "will slow down dataset building" ) - paths = [] - if recur is True: - paths += Path(dir).rglob(f"*.mid") - paths += Path(dir).rglob(f"*.midi") - else: - paths += Path(dir).glob(f"*.mid") - paths += Path(dir).glob(f"*.midi") + assert mid_paths or dir, "Must provider paths or a directory to glob files" + + paths = mid_paths if mid_paths else [] + if dir is not None: + if recur is True: + paths += Path(dir).rglob(f"*.mid") + paths += Path(dir).rglob(f"*.midi") + else: + paths += Path(dir).glob(f"*.mid") + paths += Path(dir).glob(f"*.midi") num_paths = len(paths) if num_paths == 0: @@ -839,8 +843,12 @@ def _build_epoch(_save_path, _midi_dataset): if _idx % 250 == 0: logger.info(f"Finished processing {_idx}") - buffer += [tokenizer.pad_tok] * (max_seq_len - len(buffer)) - writer.write(buffer[:max_seq_len]) + if buffer: + buffer += [tokenizer.pad_tok] * ( + max_seq_len - len(buffer) + ) + 
writer.write(buffer[:max_seq_len]) + elif separate_sequences is True: _idx = 0 for entry in reservoir( @@ -854,10 +862,11 @@ def _build_epoch(_save_path, _midi_dataset): writer.write(buffer[:max_seq_len]) buffer = buffer[max_seq_len:] - buffer += [tokenizer.pad_tok] * ( - max_seq_len - len(buffer) - ) - writer.write(buffer[:max_seq_len]) + if buffer: + buffer += [tokenizer.pad_tok] * ( + max_seq_len - len(buffer) + ) + writer.write(buffer[:max_seq_len]) _idx += 1 if _idx % 250 == 0: diff --git a/aria/embeddings/evaluate.py b/aria/embeddings/evaluate.py index d7a0957..dc22c2d 100644 --- a/aria/embeddings/evaluate.py +++ b/aria/embeddings/evaluate.py @@ -54,6 +54,7 @@ "waltz": 5, }, } +LEARNING_RATE = 3e-4 def model_forward( @@ -235,9 +236,9 @@ def get_baseline_embedding( emb = hidden_states[idx, last_tok_positions].tolist() elif pool_mode == "mean": pad_id = hook_tokenizer.pad_id - # Create a mask by comparing enc_seqs to pad_id. + # Create a mask by comparing enc_seqs to pad_id mask = (enc_seqs != pad_id).unsqueeze(-1).to(hidden_states.dtype) - # Sum over valid tokens and average. + # Sum over valid tokens and average sum_hidden = (hidden_states * mask).sum(dim=1) valid_counts = mask.sum(dim=1) mean_hidden = sum_hidden / valid_counts @@ -312,6 +313,7 @@ def build( max_seq_len: int, batch_size: int, embedding_hook: Callable, + per_file_embeddings: bool = False, **embedding_hook_kwargs, ): def batch_producer( @@ -319,25 +321,36 @@ def batch_producer( batch_queue: queue.Queue, batch_size: int, total_workers: int, + per_file: bool = False, ): buffer = [] termination_signals = 0 + while termination_signals < total_workers: - if batch_queue.qsize() >= 5: - time.sleep(1) + if batch_queue.qsize() > 10: + time.sleep(0.25) + try: result = results_queue.get(timeout=0.01) - if result is None: - termination_signals += 1 - continue - buffer.append(result) - if len(buffer) == batch_size: - batch_queue.put(buffer) - buffer = [] except queue.Empty: continue + if result is None: + termination_signals += 1 + continue - if buffer: + if per_file: + batch_queue.put(result) + if len(result) > batch_size: + print( + f"WARNING: Generated batch of size {len(result)} (batch_size={batch_size})" + ) + else: + buffer.extend(result) + while len(buffer) >= batch_size: + batch_queue.put(buffer[:batch_size]) + buffer = buffer[batch_size:] + + if not per_file and buffer: batch_queue.put(buffer) def producer( @@ -373,7 +386,7 @@ def worker( results_queue.put(None) break - while results_queue.qsize() > 1000: + while results_queue.qsize() > 250: time.sleep(0.5) _result = process_entry( @@ -382,8 +395,7 @@ def worker( max_seq_len=max_seq_len, tokenizer=tokenizer, ) - for _sub_result in _result: - results_queue.put(_sub_result) + results_queue.put(_result) assert os.path.isfile(midi_dataset_load_path) assert os.path.isfile(save_path) is False @@ -399,7 +411,13 @@ def worker( ) batch_producer_process = multiprocessing.Process( target=batch_producer, - args=(results_queue, batch_queue, batch_size, TOTAL_WORKERS), + args=( + results_queue, + batch_queue, + batch_size, + TOTAL_WORKERS, + per_file_embeddings, + ), ) worker_processes = [ multiprocessing.Process( @@ -428,10 +446,20 @@ def worker( _metadata = [item["metadata"] for item in batch] _embs = embedding_hook(seqs=_seqs, **embedding_hook_kwargs) - write_objs = [ - {"seq": s, "emb": e, "metadata": m} - for s, e, m in zip(_seqs, _embs, _metadata) - ] + if not per_file_embeddings: + write_objs = [ + {"emb": e, "metadata": m} + for e, m in zip(_embs, _metadata) + ] + else: + 
avg_emb = torch.tensor(_embs).mean(dim=0).tolist() + write_objs = [ + { + "emb": avg_emb, + "metadata": _metadata[0], + } + ] + write_executor.submit(write_entries, writer, write_objs) except queue.Empty: @@ -441,7 +469,6 @@ def worker( def _get_optim( - lr: float, model: nn.Module, total_steps: int, warmup: int = 100, @@ -449,7 +476,7 @@ def _get_optim( ): optimizer = torch.optim.AdamW( model.parameters(), - lr=lr, + lr=LEARNING_RATE, weight_decay=0.1, betas=(0.9, 0.95), eps=1e-5, @@ -538,12 +565,18 @@ def _train( def train_classifier( - emb_d: int, - train_dataset: EvaluationDataset, + embedding_dimension: int, + train_dataset_path: str, + metadata_category: str, tag_to_id: dict, batch_size: int, + num_epochs: int = 1, ): - num_epochs = 1 + train_dataset = EvaluationDataset( + load_path=train_dataset_path, + tag_to_id=tag_to_id, + metadata_category=metadata_category, + ) train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=batch_size, @@ -553,11 +586,10 @@ def train_classifier( ) model = ClassifierHead( - d_emb=emb_d, + d_emb=embedding_dimension, num_class=len(tag_to_id.keys()), ) optimizer, scheduler = _get_optim( - lr=3e-4, model=model, total_steps=num_epochs * len(train_dataloader), ) @@ -580,15 +612,15 @@ def train_classifier( ) -def evaluate_model( +def evaluate_classifier( model: nn.Module, - val_dataset_path: str, + evaluation_dataset_path: str, metadata_category: str, tag_to_id: dict, ): id_to_tag = {v: k for k, v in tag_to_id.items()} val_dataset = EvaluationDataset( - load_path=val_dataset_path, + load_path=evaluation_dataset_path, tag_to_id=tag_to_id, metadata_category=metadata_category, ) @@ -669,53 +701,3 @@ def build_baseline_dataset(): hook_max_seq_len=MAX_SEQ_LEN, hook_tokenizer=tokenizer, ) - - -def eval_all(): - metadata_category = "music_period" - tag_to_id = CATEGORY_TAGS[metadata_category] - - dataset = EvaluationDataset( - load_path="/mnt/ssd1/aria/data/paper/classification/period-aria/train-aria.jsonl", - metadata_category=metadata_category, - tag_to_id=tag_to_id, - ) - model = train_classifier( - emb_d=512, - train_dataset=dataset, - batch_size=8, - tag_to_id=tag_to_id, - ) - print("ARIA aria_midi-test:") - evaluate_model( - model=model, - val_dataset_path="/mnt/ssd1/aria/data/paper/classification/period-aria/test-aria.jsonl", - metadata_category=metadata_category, - tag_to_id=tag_to_id, - ) - - ### - - dataset = EvaluationDataset( - load_path="/mnt/ssd1/aria/data/paper/classification/period-aria/train-m3.jsonl", - metadata_category=metadata_category, - tag_to_id=tag_to_id, - ) - model = train_classifier( - emb_d=768, - train_dataset=dataset, - batch_size=8, - tag_to_id=tag_to_id, - ) - print("M3 aria_midi-test:") - evaluate_model( - model=model, - val_dataset_path="/mnt/ssd1/aria/data/paper/classification/period-aria/test-m3.jsonl", - metadata_category=metadata_category, - tag_to_id=tag_to_id, - ) - - -if __name__ == "__main__": - # TODO: Move this - eval_all() diff --git a/aria/embeddings/m3/emb.py b/aria/embeddings/m3/emb.py index 9b67ef2..6b0dc9d 100644 --- a/aria/embeddings/m3/emb.py +++ b/aria/embeddings/m3/emb.py @@ -201,6 +201,7 @@ def get_midi_embedding( # Example usage: if __name__ == "__main__": + checkpoint_url = "https://huggingface.co/sander-wood/clamp3/resolve/main/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth" midi_file_path = "/home/loubb/Dropbox/shared/test/test.mid" model, patchilizer = 
load_clamp3_model( "/home/loubb/work/clamp3/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth" diff --git a/aria/model.py b/aria/model.py index 13104cc..6c85f86 100644 --- a/aria/model.py +++ b/aria/model.py @@ -148,10 +148,6 @@ def forward( Args: src (torch.tensor): Input to encoder block, of shape (batch_size, seq_len, d_model). - attn_mask (Optional[torch.tensor]): Attention mask of shape - (batch_size, seq_len). Defaults to None. - past_kv (Optional[list[KVCache]]): a list of kv caches. The list index - corresponds to the layer index. Returns: torch.tensor: Model outputs with shape (batch_size, seq_len, diff --git a/paper/scripts/build_dataset/build_aria_dataset.sh b/paper/scripts/build_dataset/build_aria_dataset.sh new file mode 100644 index 0000000..3489fab --- /dev/null +++ b/paper/scripts/build_dataset/build_aria_dataset.sh @@ -0,0 +1,7 @@ +python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ + --model aria \ + --model_cp_path /home/loubb/work/aria/models/emb-t0.1-s2048-e25.safetensors \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/genre-test-maestro_pijama/combined-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/genre-test-maestro_pijama/combined-aria-perfile.jsonl \ + --compute_per_file_embeddings \ + --aria_max_batch_size 128 \ No newline at end of file diff --git a/paper/scripts/build_dataset/build_m3_dataset.sh b/paper/scripts/build_dataset/build_m3_dataset.sh new file mode 100644 index 0000000..a285608 --- /dev/null +++ b/paper/scripts/build_dataset/build_m3_dataset.sh @@ -0,0 +1,5 @@ +python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ + --model m3 \ + --model_cp_path /home/loubb/work/clamp3/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-m3.jsonl \ No newline at end of file diff --git a/paper/scripts/build_dataset/build_mert_dataset.sh b/paper/scripts/build_dataset/build_mert_dataset.sh new file mode 100644 index 0000000..f23179f --- /dev/null +++ b/paper/scripts/build_dataset/build_mert_dataset.sh @@ -0,0 +1,6 @@ +python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ + --model mert \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-mert.jsonl \ + --mert_pianoteq_exec_path "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE" \ + --mert_pianoteq_num_procs 8 \ No newline at end of file diff --git a/paper/scripts/build_embedding_eval_datasets.py b/paper/scripts/build_embedding_eval_datasets.py index 8c09b8e..35444a4 100644 --- a/paper/scripts/build_embedding_eval_datasets.py +++ b/paper/scripts/build_embedding_eval_datasets.py @@ -28,7 +28,9 @@ def build_aria_dataset( midi_dataset_load_path: str, embedding_dataset_save_path: str, checkpoint_path: str, - max_batch_size, + per_file_embeddings: bool, + max_batch_size: int, + compile: bool, ): from aria.config import load_model_config from aria.utils import _load_weight @@ -50,11 +52,14 @@ def build_aria_dataset( pretrained_model.load_state_dict(model_state) pretrained_model.eval() - hook_model_forward = torch.compile( - 
aria_model_forward, - mode="reduce-overhead", - fullgraph=True, - ) + if compile is True: + hook_model_forward = torch.compile( + aria_model_forward, + mode="reduce-overhead", + fullgraph=True, + ) + else: + hook_model_forward = aria_model_forward EvaluationDataset.build( midi_dataset_load_path=midi_dataset_load_path, @@ -62,6 +67,7 @@ def build_aria_dataset( max_seq_len=MAX_SEQ_LEN, slice_len_notes=NUM_SLICE_NOTES, batch_size=SEQS_BATCH_SIZE, + per_file_embeddings=per_file_embeddings, embedding_hook=get_aria_contrastive_embedding, hook_model=pretrained_model.cuda(), hook_max_seq_len=MAX_SEQ_LEN, @@ -75,6 +81,7 @@ def build_m3_dataset( midi_dataset_load_path: str, embedding_dataset_save_path: str, checkpoint_path: str, + per_file_embeddings: bool, ): from aria.embeddings.m3.emb import load_clamp3_model @@ -91,6 +98,7 @@ def build_m3_dataset( max_seq_len=MAX_SEQ_LEN, slice_len_notes=NUM_SLICE_NOTES, batch_size=SEQS_BATCH_SIZE, + per_file_embeddings=per_file_embeddings, embedding_hook=get_clamp3_embedding, hook_model=model, hook_patchilizer=patchilizer, @@ -101,6 +109,7 @@ def build_m3_dataset( def build_mert_dataset( midi_dataset_load_path: str, embedding_dataset_save_path: str, + per_file_embeddings: bool, pianoteq_exec_path: str, pianoteq_num_procs: int, ): @@ -119,6 +128,7 @@ def build_mert_dataset( max_seq_len=MAX_SEQ_LEN, slice_len_notes=NUM_SLICE_NOTES, batch_size=SEQS_BATCH_SIZE, + per_file_embeddings=per_file_embeddings, embedding_hook=get_mert_embedding, hook_model=model, hook_processor=processor, @@ -156,12 +166,22 @@ def main(): required=True, help="Path where the dataset will be saved.", ) + parser.add_argument( + "--compute_per_file_embeddings", + action="store_true", + help="Compute embeddings on a per-file basis", + ) parser.add_argument( "--aria_max_batch_size", type=int, default=128, help="Max batch size for aria embedding forward pass", ) + parser.add_argument( + "--aria_compile", + action="store_true", + help="Compile forward pass", + ) parser.add_argument( "--mert_pianoteq_exec_path", type=str, @@ -179,26 +199,28 @@ def main(): if args.model == "aria": assert args.aria_max_batch_size > 0 - - # TODO: TEST build_aria_dataset( midi_dataset_load_path=args.dataset_load_path, embedding_dataset_save_path=args.dataset_save_path, checkpoint_path=args.model_cp_path, + per_file_embeddings=args.compute_per_file_embeddings, max_batch_size=args.aria_max_batch_size, + compile=args.aria_compile, ) elif args.model == "m3": - # TODO: TEST build_m3_dataset( midi_dataset_load_path=args.dataset_load_path, embedding_dataset_save_path=args.dataset_save_path, checkpoint_path=args.model_cp_path, + per_file_embeddings=args.compute_per_file_embeddings, ) elif args.model == "mert": - # TODO: TEST + assert args.mert_pianoteq_exec_path + assert args.mert_pianoteq_num_procs > 0 build_mert_dataset( midi_dataset_load_path=args.dataset_load_path, embedding_dataset_save_path=args.dataset_save_path, + per_file_embeddings=args.compute_per_file_embeddings, pianoteq_exec_path=args.mert_pianoteq_exec_path, pianoteq_num_procs=args.mert_pianoteq_num_procs, ) diff --git a/paper/scripts/evaluate_embedding_with_probe.py b/paper/scripts/evaluate_embedding_with_probe.py index e69de29..c6c9285 100644 --- a/paper/scripts/evaluate_embedding_with_probe.py +++ b/paper/scripts/evaluate_embedding_with_probe.py @@ -0,0 +1,94 @@ +import argparse + + +from aria.embeddings.evaluate import ( + train_classifier, + evaluate_classifier, + CATEGORY_TAGS, +) + +EMBEDDING_SIZE = { + "aria": 512, + "m3": 768, + "mert": 1024, +} + + 
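+# This script fits a linear probe: embeddings are precomputed offline by
+# build_embedding_eval_datasets.py, a ClassifierHead is then trained on the
+# frozen embeddings of the train split and scored on the held-out test
+# split. An illustrative invocation (dataset paths are placeholders):
+#
+#   python evaluate_embedding_with_probe.py \
+#       --model m3 \
+#       --metadata_category form \
+#       --train_dataset_path train-m3.jsonl \
+#       --test_dataset_path test-m3.jsonl
+
+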
+def evaluate_embeddings(
+    model_name: str,
+    metadata_category: str,
+    train_dataset_path: str,
+    test_dataset_path: str,
+    num_epochs: int,
+    batch_size: int,
+):
+    embedding_size = EMBEDDING_SIZE[model_name]
+    tag_to_id = CATEGORY_TAGS[metadata_category]
+
+    model = train_classifier(
+        embedding_dimension=embedding_size,
+        train_dataset_path=train_dataset_path,
+        metadata_category=metadata_category,
+        tag_to_id=tag_to_id,
+        batch_size=batch_size,
+        num_epochs=num_epochs,
+    )
+    evaluate_classifier(
+        model=model,
+        evaluation_dataset_path=test_dataset_path,
+        metadata_category=metadata_category,
+        tag_to_id=tag_to_id,
+    )
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Train and evaluate embeddings with a linear probe"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        choices=["aria", "mert", "m3"],
+        required=True,
+    )
+    parser.add_argument(
+        "--metadata_category",
+        type=str,
+        choices=["genre", "music_period", "composer", "form"],
+        required=True,
+    )
+    parser.add_argument(
+        "--train_dataset_path",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--test_dataset_path",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--num_epochs",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8,
+        help="Batch size for training the classifier",
+    )
+    args = parser.parse_args()
+
+    evaluate_embeddings(
+        model_name=args.model,
+        metadata_category=args.metadata_category,
+        train_dataset_path=args.train_dataset_path,
+        test_dataset_path=args.test_dataset_path,
+        num_epochs=args.num_epochs,
+        batch_size=args.batch_size,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paper/scripts/evaluate_embeddings.sh b/paper/scripts/evaluate_embeddings.sh
new file mode 100644
index 0000000..5211cdb
--- /dev/null
+++ b/paper/scripts/evaluate_embeddings.sh
@@ -0,0 +1,6 @@
+python /home/loubb/work/aria/paper/scripts/evaluate_embedding_with_probe.py \
+    --model aria \
+    --metadata_category genre \
+    --train_dataset_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-aria.jsonl \
+    --test_dataset_path /mnt/ssd1/aria/data/paper/clas/genre-test-maestro_pijama/combined-aria-perfile.jsonl \
+    --num_epochs 1
\ No newline at end of file
diff --git a/paper/scripts/get_unique_tags.py b/paper/scripts/get_unique_tags.py
deleted file mode 100644
index 437da85..0000000
--- a/paper/scripts/get_unique_tags.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import json
-
-
-# TODO: Make this into argparse script
-
-
-JSONL_FILE_PATH = (
-    "/mnt/ssd1/aria/data/paper/classification/form-aria/train-mididict.jsonl"
-)
-MUSIC_PERIOD_KEY = "form"
-
-
-def main():
-    unique_periods = set()
-
-    with open(JSONL_FILE_PATH, "r") as f:
-        for line in f:
-            # Parse each line as a JSON object.
-            record = json.loads(line)
-            # Get the 'metadata' dictionary, defaulting to an empty dict if missing.
-            metadata = record.get("metadata", {})
-            # Check if the metadata contains the music period key.
-            if MUSIC_PERIOD_KEY in metadata:
-                unique_periods.add(metadata[MUSIC_PERIOD_KEY])
-
-    # Print the list of unique music periods.
- print(list(unique_periods)) - - -if __name__ == "__main__": - main() diff --git a/paper/scripts/make_cl_split.py b/paper/scripts/make_eval_split.py similarity index 61% rename from paper/scripts/make_cl_split.py rename to paper/scripts/make_eval_split.py index 45bd511..6d60ee0 100644 --- a/paper/scripts/make_cl_split.py +++ b/paper/scripts/make_eval_split.py @@ -1,45 +1,30 @@ import json import random +import argparse from collections import Counter from pathlib import Path from aria.datasets import build_mididict_dataset +from aria.embeddings.evaluate import CATEGORY_TAGS random.seed(42) - -# TODO: Make this into argparse script - -DATASET_DIR = "/mnt/ssd1/aria-midi/final/v1/aria-midi-v1-emb-int" -METADATA_PATH = f"{DATASET_DIR}/metadata.json" -METADATA_CATEGORY = "form" -METADATA_TAGS = [ - "fugue", - "sonata", - "etude", - "nocturne", - "waltz", - "improvisation", -] - MIDI_DATASET_TRAIN_SIZE = 10000 MIDI_DATASET_TEST_SIZE = 1000 -TRAIN_SAVE_PATH = ( - "/mnt/ssd1/aria/data/paper/classification/form-aria/train-mididict.jsonl" -) -TEST_SAVE_PATH = ( - "/mnt/ssd1/aria/data/paper/classification/form-aria/test-mididict.jsonl" -) - -def get_midi_paths(): - with open(METADATA_PATH, "r") as f: +def get_midi_paths( + dataset_dir: str, + metadata_path: str, + metadata_category: str, +): + metadata_tags = list(CATEGORY_TAGS[metadata_category].keys()) + with open(metadata_path, "r") as f: metadata_dict = json.load(f) metadata_dict = {k: v["metadata"] for k, v in metadata_dict.items()} - midi_paths = list(Path(DATASET_DIR).rglob("*.mid")) - buckets = {tag: [] for tag in METADATA_TAGS} + midi_paths = list(Path(dataset_dir).rglob("*.mid")) + buckets = {tag: [] for tag in metadata_tags} for midi_file in midi_paths: # Extract metadata key from file name (e.g., "000001_0" -> "1") @@ -47,12 +32,12 @@ def get_midi_paths(): metadata = metadata_dict.get(key) if not metadata: continue - tag = metadata.get(METADATA_CATEGORY) - if tag in METADATA_TAGS: + tag = metadata.get(metadata_category) + if tag in metadata_tags: buckets[tag].append(midi_file) # Calculate the desired count per tag for both splits - num_tags = len(METADATA_TAGS) + num_tags = len(metadata_tags) desired_train_per_tag = MIDI_DATASET_TRAIN_SIZE // num_tags desired_test_per_tag = MIDI_DATASET_TEST_SIZE // num_tags @@ -82,7 +67,7 @@ def get_midi_paths(): def _extract_tag(midi_file): key = str(int(midi_file.stem.split("_")[0])) - return metadata_dict.get(key, {}).get(METADATA_CATEGORY, "unknown") + return metadata_dict.get(key, {}).get(metadata_category, "unknown") train_distribution = Counter(_extract_tag(mp) for mp in train_paths) test_distribution = Counter(_extract_tag(mp) for mp in test_paths) @@ -97,15 +82,50 @@ def _extract_tag(midi_file): def main(): - train_paths, test_paths = get_midi_paths() + parser = argparse.ArgumentParser( + description="Train and evaluate embeddings with linear prob" + ) + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + ) + parser.add_argument( + "--metadata_path", + type=str, + required=True, + ) + parser.add_argument( + "--metadata_category", + type=str, + choices=["genre", "music_period", "composer", "form"], + required=True, + ) + parser.add_argument( + "--train_save_path", + type=str, + required=True, + ) + parser.add_argument( + "--test_save_path", + type=str, + required=True, + ) + args = parser.parse_args() + + train_paths, test_paths = get_midi_paths( + dataset_dir=args.dataset_dir, + metadata_path=args.metadata_path, + metadata_category=args.metadata_category, + ) 
build_mididict_dataset( mid_paths=train_paths, - stream_save_path=TRAIN_SAVE_PATH, + stream_save_path=args.train_save_path, ) build_mididict_dataset( mid_paths=test_paths, - stream_save_path=TEST_SAVE_PATH, + stream_save_path=args.test_save_path, ) From 58d439f6cab86b6b2f05b1be61d1b8f815b67f42 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 10 Mar 2025 21:17:39 +0000 Subject: [PATCH 23/72] fix range bug --- aria/embeddings/contrastive_finetune.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aria/embeddings/contrastive_finetune.py b/aria/embeddings/contrastive_finetune.py index 394af89..9bf8e2a 100644 --- a/aria/embeddings/contrastive_finetune.py +++ b/aria/embeddings/contrastive_finetune.py @@ -142,7 +142,10 @@ def get_slice( ): _midi_dict = copy.deepcopy(midi_dict) slice_length = random.randint(min_num_notes, max_num_notes) - idx = random.randint(0, len(_midi_dict.note_msgs) - min_num_notes) + if len(_midi_dict.note_msgs) <= min_num_notes: + idx = 0 + else: + idx = random.randint(0, len(_midi_dict.note_msgs) - min_num_notes) _midi_dict.note_msgs = _midi_dict.note_msgs[idx : idx + slice_length] _midi_dict.metadata = {} From b238edcb6c6051d1547f30e9023d562d0ee7cf59 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 11 Mar 2025 16:50:47 +0000 Subject: [PATCH 24/72] add m3 only embeddings --- aria/embeddings/m3/config.py | 77 ++++++++++++++++++++++++------------ aria/embeddings/m3/emb.py | 73 ++++++++++++---------------------- aria/embeddings/m3/utils.py | 2 +- 3 files changed, 78 insertions(+), 74 deletions(-) diff --git a/aria/embeddings/m3/config.py b/aria/embeddings/m3/config.py index 893a1c0..ff53938 100644 --- a/aria/embeddings/m3/config.py +++ b/aria/embeddings/m3/config.py @@ -25,24 +25,37 @@ M3_LOAD_CKPT = True # Load model weights from a checkpoint if available M3_WEIGHTS_PATH = ( - "weights_m3"+ - "_h_size_" + str(M3_HIDDEN_SIZE) + - "_t_layers_" + str(TOKEN_NUM_LAYERS) + - "_p_layers_" + str(PATCH_NUM_LAYERS) + - "_p_size_" + str(PATCH_SIZE) + - "_p_length_" + str(PATCH_LENGTH) + - "_lr_" + str(M3_LEARNING_RATE) + - "_batch_" + str(M3_BATCH_SIZE) + - "_mask_" + str(M3_MASK_RATIO) + ".pth" + "weights_m3" + + "_h_size_" + + str(M3_HIDDEN_SIZE) + + "_t_layers_" + + str(TOKEN_NUM_LAYERS) + + "_p_layers_" + + str(PATCH_NUM_LAYERS) + + "_p_size_" + + str(PATCH_SIZE) + + "_p_length_" + + str(PATCH_LENGTH) + + "_lr_" + + str(M3_LEARNING_RATE) + + "_batch_" + + str(M3_BATCH_SIZE) + + "_mask_" + + str(M3_MASK_RATIO) + + ".pth" ) # Path to store the model weights -M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs +M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace( + "pth", "txt" +) # Path to save training logs # -------------------- Configuration for CLaMP3 Training ---------------- CLAMP3_TRAIN_JSONL = "" # Path to the JSONL file with training data for CLaMP3 CLAMP3_EVAL_JSONL = "" # Path to the JSONL file with evaluation data for CLaMP3 (optional) CLAMP3_HIDDEN_SIZE = 768 # Size of the hidden layer -TEXT_MODEL_NAME = "FacebookAI/xlm-roberta-base" # Name of the pre-trained text model +TEXT_MODEL_NAME = ( + "FacebookAI/xlm-roberta-base" # Name of the pre-trained text model +) MAX_TEXT_LENGTH = 128 # Maximum allowed length for text input AUDIO_HIDDEN_SIZE = 768 # Size of the hidden layer for audio features @@ -54,7 +67,9 @@ CLAMP3_BATCH_SIZE = 256 # Batch size per GPU (single card) during training LOGIT_SCALE = 1 # Scaling factor for contrastive loss -FREEZE_TEXT = False # Freeze the weights of 
the text model and text projection layer +FREEZE_TEXT = ( + False # Freeze the weights of the text model and text projection layer +) TEXT_DROPOUT = True # Whether to apply dropout during text processing CLAMP3_DETERMINISTIC = True # Ensures deterministic results with random seeds CLAMP3_LOAD_M3 = True # Load weights from the M3 model @@ -63,17 +78,29 @@ SAVE_EVERY = 5 # Save model weights every SAVE_EVERY epochs CLAMP3_WEIGHTS_PATH = ( - "weights_clamp3_saas" + - "_h_size_" + str(CLAMP3_HIDDEN_SIZE) + - "_t_model_" + TEXT_MODEL_NAME.replace("/", "_") + - "_t_length_" + str(MAX_TEXT_LENGTH) + - "_a_size_" + str(AUDIO_HIDDEN_SIZE) + - "_a_layers_" + str(AUDIO_NUM_LAYERS) + - "_a_length_" + str(MAX_AUDIO_LENGTH) + - "_s_size_" + str(M3_HIDDEN_SIZE) + - "_s_layers_" + str(PATCH_NUM_LAYERS) + - "_p_size_" + str(PATCH_SIZE) + - "_p_length_" + str(PATCH_LENGTH) + ".pth" - + "weights_clamp3_saas" + + "_h_size_" + + str(CLAMP3_HIDDEN_SIZE) + + "_t_model_" + + TEXT_MODEL_NAME.replace("/", "_") + + "_t_length_" + + str(MAX_TEXT_LENGTH) + + "_a_size_" + + str(AUDIO_HIDDEN_SIZE) + + "_a_layers_" + + str(AUDIO_NUM_LAYERS) + + "_a_length_" + + str(MAX_AUDIO_LENGTH) + + "_s_size_" + + str(M3_HIDDEN_SIZE) + + "_s_layers_" + + str(PATCH_NUM_LAYERS) + + "_p_size_" + + str(PATCH_SIZE) + + "_p_length_" + + str(PATCH_LENGTH) + + ".pth" ) # Path to store CLaMP3 model weights -CLAMP3_LOGS_PATH = CLAMP3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs +CLAMP3_LOGS_PATH = CLAMP3_WEIGHTS_PATH.replace("weights", "logs").replace( + "pth", "txt" +) # Path to save training logs diff --git a/aria/embeddings/m3/emb.py b/aria/embeddings/m3/emb.py index 6b0dc9d..cd68aea 100644 --- a/aria/embeddings/m3/emb.py +++ b/aria/embeddings/m3/emb.py @@ -1,7 +1,7 @@ import os import torch import mido -from transformers import BertConfig +from transformers import BertConfig, GPT2Config from aria.embeddings.m3.config import ( AUDIO_HIDDEN_SIZE, @@ -12,11 +12,11 @@ PATCH_LENGTH, PATCH_SIZE, CLAMP3_HIDDEN_SIZE, - CLAMP3_LOAD_M3, TEXT_MODEL_NAME, + TOKEN_NUM_LAYERS, ) -from aria.embeddings.m3.utils import CLaMP3Model, M3Patchilizer +from aria.embeddings.m3.utils import CLaMP3Model, M3Patchilizer, M3Model def msg_to_str(msg): @@ -63,12 +63,8 @@ def load_midi( return "\n".join(msg_list) -def load_clamp3_model(checkpoint_path: str): - """ - Loads the CLaMP3 model along with its configuration and M3Patchilizer. - The model weights are loaded from the checkpoint specified in CLAMP3_WEIGHTS_PATH. - """ - # Build the configurations for audio and symbolic (M3) parts. +def load_clamp3_model(checkpoint_path: str, m3_only: bool = False): + # Create audio and symbolic configurations. audio_config = BertConfig( vocab_size=1, hidden_size=AUDIO_HIDDEN_SIZE, @@ -85,29 +81,39 @@ def load_clamp3_model(checkpoint_path: str): intermediate_size=M3_HIDDEN_SIZE * 4, max_position_embeddings=PATCH_LENGTH, ) + decoder_config = GPT2Config( + vocab_size=128, + n_positions=PATCH_SIZE, + n_embd=M3_HIDDEN_SIZE, + n_layer=TOKEN_NUM_LAYERS, + n_head=M3_HIDDEN_SIZE // 64, + n_inner=M3_HIDDEN_SIZE * 4, + ) - # Instantiate the CLaMP3 model. model = CLaMP3Model( audio_config=audio_config, symbolic_config=symbolic_config, text_model_name=TEXT_MODEL_NAME, hidden_size=CLAMP3_HIDDEN_SIZE, - load_m3=CLAMP3_LOAD_M3, + load_m3=True, ) model = model.to("cuda") model.eval() - # Determine checkpoint path. if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - # Load checkpoint weights. 
checkpoint = torch.load( checkpoint_path, map_location="cuda", weights_only=True ) - model.load_state_dict(checkpoint["model"]) - # Instantiate the patchilizer from utils. + if m3_only is False: + model.load_state_dict(checkpoint["model"]) + else: + temp_m3_model = M3Model(symbolic_config, decoder_config) + temp_m3_model.load_state_dict(checkpoint["model"]) + model.symbolic_model.load_state_dict(temp_m3_model.encoder.state_dict()) + patchilizer = M3Patchilizer() return model, patchilizer @@ -115,24 +121,16 @@ def load_clamp3_model(checkpoint_path: str): def get_midi_embedding( mid: mido.MidiFile, - model: torch.nn.Module, + model: CLaMP3Model, patchilizer: M3Patchilizer, get_global=True, ): device = "cuda" - # Step 1: Convert MIDI to an MTF-format string. mtf_str = load_midi(mid=mid, m3_compatible=True) - - # Step 3: Encode the MTF string to patches. - # The patchilizer returns a list of patches (each patch is a list of token IDs). patches = patchilizer.encode(mtf_str, add_special_patches=True) - # Convert the list of patches to a tensor. - # Each patch is of length PATCH_SIZE; we assume tokens are integers. - # The resulting tensor shape will be [num_patches, PATCH_SIZE]. token_tensor = torch.tensor(patches, dtype=torch.long).to(device) - # Step 4: Create segments of fixed length PATCH_LENGTH. num_tokens = token_tensor.size(0) segments = [] seg_weights = [] @@ -141,12 +139,11 @@ def get_midi_embedding( cur_len = seg.size(0) segments.append(seg) seg_weights.append(cur_len) - # For global feature extraction, ensure the last segment is exactly PATCH_LENGTH tokens + if num_tokens > PATCH_LENGTH: segments[-1] = token_tensor[-PATCH_LENGTH:] seg_weights[-1] = segments[-1].size(0) - # Step 5: Process each segment to obtain features. processed_feats = [] for seg in segments: cur_len = seg.size(0) @@ -163,7 +160,7 @@ def get_midi_embedding( ) seg = torch.cat([seg, pad], dim=0) seg = seg.unsqueeze(0) # Add batch dimension. - # Create a mask: 1 for valid tokens, 0 for padding. + mask = torch.cat( [ torch.ones(cur_len, device=device), @@ -175,14 +172,12 @@ def get_midi_embedding( feat = model.get_symbolic_features( symbolic_inputs=seg, symbolic_masks=mask, get_global=get_global ) - # When not getting a global feature, you may wish to trim the output to the valid length. + if not get_global: feat = feat[:, : int(mask.sum().item()), :] processed_feats.append(feat) - # Step 6: Combine segment features into a single embedding. if not get_global: - # Concatenate features along the time dimension. 
embedding = torch.cat( [feat.squeeze(0) for feat in processed_feats], dim=0 ) @@ -197,21 +192,3 @@ def get_midi_embedding( embedding = (feats * weights).sum(dim=0) / weights.sum() return embedding.view(-1) - - -# Example usage: -if __name__ == "__main__": - checkpoint_url = "https://huggingface.co/sander-wood/clamp3/resolve/main/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth" - midi_file_path = "/home/loubb/Dropbox/shared/test/test.mid" - model, patchilizer = load_clamp3_model( - "/home/loubb/work/clamp3/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth" - ) - mid = mido.MidiFile(midi_file_path) - embedding = get_midi_embedding( - mid, - model=model, - patchilizer=patchilizer, - get_global=True, - ) - print(embedding) - print(embedding.shape) diff --git a/aria/embeddings/m3/utils.py b/aria/embeddings/m3/utils.py index 3952ac7..2713aaf 100644 --- a/aria/embeddings/m3/utils.py +++ b/aria/embeddings/m3/utils.py @@ -579,7 +579,7 @@ def get_symbolic_features( symbolic_features = self.avg_pooling( symbolic_features, symbolic_masks ) - symbolic_features = self.symbolic_proj(symbolic_features) + # symbolic_features = self.symbolic_proj(symbolic_features) return symbolic_features From b485c042c70d158fc47e8744417fb34d24ad92ed Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 11 Mar 2025 16:58:06 +0000 Subject: [PATCH 25/72] update script for m3 embeddings --- aria/embeddings/evaluate.py | 25 +++++++++++++------ .../scripts/build_embedding_eval_datasets.py | 20 ++++++++++++--- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/aria/embeddings/evaluate.py b/aria/embeddings/evaluate.py index dc22c2d..7ca4f5d 100644 --- a/aria/embeddings/evaluate.py +++ b/aria/embeddings/evaluate.py @@ -53,6 +53,16 @@ "fugue": 4, "waltz": 5, }, + "pianist": { + "hisaishi": 0, + "hancock": 1, + "bethel": 2, + "einaudi": 3, + "clayderman": 4, + "ryuichi": 5, + "yiruma": 6, + "hillsong": 7, + }, } LEARNING_RATE = 3e-4 @@ -64,15 +74,15 @@ def model_forward( return model(idxs) -def chunk_and_pad(lst: list, n: int): - return [lst[i : i + n] for i in range(0, len(lst), n)] - - def write_entries(writer, entries): for entry in entries: writer.write(entry) +def chunk_and_pad(lst: list, n: int): + return [lst[i : i + n] for i in range(0, len(lst), n)] + + def process_entry( entry, slice_len_notes: int, @@ -281,7 +291,7 @@ def __getitem__(self, idx: int): tag = metadata.get(self.metadata_category, "other") tag = tag if tag in self.tag_to_id.keys() else "other" - assert tag in self.tag_to_id + assert tag in self.tag_to_id, metadata tag_tensor = torch.tensor(self.tag_to_id[tag]) emb_tensor = torch.tensor(emb) @@ -448,13 +458,14 @@ def worker( if not per_file_embeddings: write_objs = [ - {"emb": e, "metadata": m} - for e, m in zip(_embs, _metadata) + {"seq": s, "emb": e, "metadata": m} + for s, e, m in zip(_seqs, _embs, _metadata) ] else: avg_emb = torch.tensor(_embs).mean(dim=0).tolist() write_objs = [ { + "seqs": _seqs, "emb": avg_emb, "metadata": _metadata[0], } diff --git a/paper/scripts/build_embedding_eval_datasets.py b/paper/scripts/build_embedding_eval_datasets.py index 35444a4..3f2dad6 100644 --- a/paper/scripts/build_embedding_eval_datasets.py +++ b/paper/scripts/build_embedding_eval_datasets.py @@ -1,7 +1,6 @@ import os import argparse import torch - import torch.nn as nn from 
ariautils.tokenizer import AbsTokenizer @@ -81,6 +80,7 @@ def build_m3_dataset( midi_dataset_load_path: str, embedding_dataset_save_path: str, checkpoint_path: str, + is_encoder_checkpoint: bool, per_file_embeddings: bool, ): from aria.embeddings.m3.emb import load_clamp3_model @@ -90,13 +90,19 @@ def build_m3_dataset( assert not os.path.isfile(embedding_dataset_save_path) tokenizer = AbsTokenizer() - model, patchilizer = load_clamp3_model(checkpoint_path=checkpoint_path) + model, patchilizer = load_clamp3_model( + checkpoint_path=checkpoint_path, m3_only=is_encoder_checkpoint + ) + + # Workaround to outsource global_emb calculation to model + slice_len_notes = NUM_SLICE_NOTES if per_file_embeddings is False else 10000 + max_seq_len = MAX_SEQ_LEN if per_file_embeddings is False else 100000 EvaluationDataset.build( midi_dataset_load_path=midi_dataset_load_path, save_path=embedding_dataset_save_path, - max_seq_len=MAX_SEQ_LEN, - slice_len_notes=NUM_SLICE_NOTES, + max_seq_len=max_seq_len, + slice_len_notes=slice_len_notes, batch_size=SEQS_BATCH_SIZE, per_file_embeddings=per_file_embeddings, embedding_hook=get_clamp3_embedding, @@ -182,6 +188,11 @@ def main(): action="store_true", help="Compile forward pass", ) + parser.add_argument( + "--m3_is_encoder_checkpoint", + action="store_true", + help="Checkpoint is for entire clamp model.", + ) parser.add_argument( "--mert_pianoteq_exec_path", type=str, @@ -212,6 +223,7 @@ def main(): midi_dataset_load_path=args.dataset_load_path, embedding_dataset_save_path=args.dataset_save_path, checkpoint_path=args.model_cp_path, + is_encoder_checkpoint=args.m3_is_encoder_checkpoint, per_file_embeddings=args.compute_per_file_embeddings, ) elif args.model == "mert": From bda70ac0e500e91dec741307bf22d84460ea71bc Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 11 Mar 2025 16:58:28 +0000 Subject: [PATCH 26/72] update for pianist eval --- paper/scripts/evaluate_embedding_with_probe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paper/scripts/evaluate_embedding_with_probe.py b/paper/scripts/evaluate_embedding_with_probe.py index c6c9285..9753b26 100644 --- a/paper/scripts/evaluate_embedding_with_probe.py +++ b/paper/scripts/evaluate_embedding_with_probe.py @@ -54,7 +54,7 @@ def main(): parser.add_argument( "--metadata_category", type=str, - choices=["genre", "music_period", "composer", "form"], + choices=["genre", "music_period", "composer", "form", "pianist"], required=True, ) parser.add_argument( From 3fa2b2971a96e050986aff87bae6473d66f199ef Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 11 Mar 2025 16:58:51 +0000 Subject: [PATCH 27/72] add pianist8 dataset script --- paper/scripts/make_pianist8_dataset.py | 118 +++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 paper/scripts/make_pianist8_dataset.py diff --git a/paper/scripts/make_pianist8_dataset.py b/paper/scripts/make_pianist8_dataset.py new file mode 100644 index 0000000..79e4163 --- /dev/null +++ b/paper/scripts/make_pianist8_dataset.py @@ -0,0 +1,118 @@ +import os +import re +import json +import random +import argparse +from pathlib import Path + +from ariautils.midi import MidiDict +from aria.datasets import MidiDataset + +random.seed(43) + +SPLIT_RATIO = 0.9 + + +def get_midi_paths(dataset_dir: str, test_split_file: str = None): + train_paths = [] + test_paths = [] + + test_pairs = set() + if test_split_file: + with open(test_split_file, "r") as f: + test_files = json.load(f) + for entry in test_files: + parts = re.split(r"[\\/]", entry) + assert len(parts) 
== 3 + + pianist = parts[1].lower() + file_name = parts[2].replace(".npy", ".mid") + test_pairs.add((pianist, file_name)) + + pianist_categories = os.listdir(dataset_dir) + for pianist in pianist_categories: + pianist_dir = os.path.join(dataset_dir, pianist) + mid_paths = list(Path(pianist_dir).glob("*.mid")) + random.shuffle(mid_paths) + print(f"Found {len(mid_paths)} MIDI files for {pianist}") + + if test_pairs: + for path in mid_paths: + if (pianist.lower(), path.name) in test_pairs: + test_paths.append( + {"path": path, "pianist": pianist.lower()} + ) + else: + train_paths.append( + {"path": path, "pianist": pianist.lower()} + ) + else: + split_idx = int(len(mid_paths) * SPLIT_RATIO) + train_paths += [ + {"path": path, "pianist": pianist.lower()} + for path in mid_paths[:split_idx] + ] + test_paths += [ + {"path": path, "pianist": pianist.lower()} + for path in mid_paths[split_idx:] + ] + + train_mididicts = [] + for path_entry in train_paths: + _mid_dict = MidiDict.from_midi(mid_path=path_entry["path"]) + _mid_dict.metadata["pianist"] = path_entry["pianist"] + train_mididicts.append(_mid_dict) + + test_mididicts = [] + for path_entry in test_paths: + _mid_dict = MidiDict.from_midi(mid_path=path_entry["path"]) + _mid_dict.metadata["pianist"] = path_entry["pianist"] + test_mididicts.append(_mid_dict) + + return train_mididicts, test_mididicts + + +def main(): + parser = argparse.ArgumentParser( + description="Create pianist8 dataset train-test split" + ) + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + ) + parser.add_argument( + "--train_save_path", + type=str, + required=True, + ) + parser.add_argument( + "--test_save_path", + type=str, + required=True, + ) + parser.add_argument( + "--test_split", + type=str, + default=None, + help="Path to JSON file listing test split files (paths like 'pianist8/<pianist>/<file>.npy')", + ) + args = parser.parse_args() + + assert os.path.isdir(args.dataset_dir) + assert not os.path.isfile(args.train_save_path) + assert not os.path.isfile(args.test_save_path) + + train_mididicts, test_mididicts = get_midi_paths( + dataset_dir=args.dataset_dir, + test_split_file=args.test_split, + ) + + MidiDataset(entries=train_mididicts).save(args.train_save_path) + MidiDataset(entries=test_mididicts).save(args.test_save_path) + + +if __name__ == "__main__": + main() From 4a7427eb1e477640001bf946f1d3a197c260f5e4 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 12 Mar 2025 14:50:06 +0000 Subject: [PATCH 28/72] adjust per file emb logic and update scripts --- aria/embeddings/evaluate.py | 47 ++++++++++++++----- aria/embeddings/mert/emb.py | 3 -- .../build_dataset/build_aria_dataset.sh | 14 ++++-- .../build_dataset/build_clamp_dataset.sh | 13 +++++ .../scripts/build_dataset/build_m3_dataset.sh | 16 +++++-- .../build_dataset/build_mert_dataset.sh | 15 ++++-- paper/scripts/evaluate_embeddings.sh | 14 ++++-- 7 files changed, 93 insertions(+), 29 deletions(-) create mode 100644 paper/scripts/build_dataset/build_clamp_dataset.sh diff --git a/aria/embeddings/evaluate.py b/aria/embeddings/evaluate.py index 7ca4f5d..97e949b 100644 --- a/aria/embeddings/evaluate.py +++ b/aria/embeddings/evaluate.py @@ -349,18 +349,24 @@ def batch_producer( continue if per_file: - batch_queue.put(result) - if len(result) > batch_size: + assert all( + "abs_load_path" in r["metadata"].keys() for r in result + ) + buffer.extend(result) + if len(buffer) > 2 * batch_size: print( f"WARNING:
Generated batch of size {len(buffer)} (batch_size={batch_size})" ) + if len(buffer) >= batch_size: + batch_queue.put(buffer) + buffer = [] else: buffer.extend(result) while len(buffer) >= batch_size: batch_queue.put(buffer[:batch_size]) buffer = buffer[batch_size:] - if not per_file and buffer: + if buffer: batch_queue.put(buffer) def producer( @@ -462,14 +468,31 @@ def worker( for s, e, m in zip(_seqs, _embs, _metadata) ] else: - avg_emb = torch.tensor(_embs).mean(dim=0).tolist() - write_objs = [ - { - "seqs": _seqs, - "emb": avg_emb, - "metadata": _metadata[0], - } - ] + # Calculate per-file emb by averaging over abs_load_path embs + groups = {} + for seq, emb, meta in zip(_seqs, _embs, _metadata): + file_path = meta["abs_load_path"] + if file_path not in groups: + groups[file_path] = { + "seqs": [], + "embs": [], + "metadata": meta, + } + groups[file_path]["seqs"].append(seq) + groups[file_path]["embs"].append(emb) + + write_objs = [] + for file_path, data in groups.items(): + avg_emb = ( + torch.tensor(data["embs"]).mean(dim=0).tolist() + ) + write_objs.append( + { + "seqs": data["seqs"], + "emb": avg_emb, + "metadata": data["metadata"], + } + ) write_executor.submit(write_entries, writer, write_objs) diff --git a/aria/embeddings/mert/emb.py b/aria/embeddings/mert/emb.py index 9f7d184..bcee333 100644 --- a/aria/embeddings/mert/emb.py +++ b/aria/embeddings/mert/emb.py @@ -24,13 +24,10 @@ def seq_to_audio_path( mid = tokenizer.detokenize(seq) mid.to_midi().save(mid_path) - # Step 3: Create a temporary WAV file for output audio_temp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) audio_path = audio_temp.name audio_temp.close() # Close so CLI can write to it - # Step 4: Run CLI command to generate audio using Pianoteq - # EXEC_PATH = "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE" preset = "NY Steinway D Classical Recording" pianoteq_cmd = f"{shlex.quote(pianoteq_exec_path)} --preset {shlex.quote(preset)} --rate 24000 --midi {mid_path} --wav {audio_path}" diff --git a/paper/scripts/build_dataset/build_aria_dataset.sh b/paper/scripts/build_dataset/build_aria_dataset.sh index 3489fab..a46a273 100644 --- a/paper/scripts/build_dataset/build_aria_dataset.sh +++ b/paper/scripts/build_dataset/build_aria_dataset.sh @@ -1,7 +1,15 @@ python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ --model aria \ --model_cp_path /home/loubb/work/aria/models/emb-t0.1-s2048-e25.safetensors \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/genre-test-maestro_pijama/combined-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/genre-test-maestro_pijama/combined-aria-perfile.jsonl \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/train-aria.jsonl \ --compute_per_file_embeddings \ - --aria_max_batch_size 128 \ No newline at end of file + --aria_max_batch_size 128 + +python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ + --model aria \ + --model_cp_path /home/loubb/work/aria/models/emb-t0.1-s2048-e25.safetensors \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/test-aria.jsonl \ + --compute_per_file_embeddings \ + --aria_max_batch_size 128 diff --git a/paper/scripts/build_dataset/build_clamp_dataset.sh b/paper/scripts/build_dataset/build_clamp_dataset.sh new file mode 100644 index 0000000..8c06845 --- /dev/null +++ 
b/paper/scripts/build_dataset/build_clamp_dataset.sh @@ -0,0 +1,13 @@ +python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ + --model m3 \ + --model_cp_path /home/loubb/work/clamp3/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/train-clamp.jsonl \ + --compute_per_file_embeddings + +python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ + --model m3 \ + --model_cp_path /home/loubb/work/clamp3/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/test-clamp.jsonl \ + --compute_per_file_embeddings diff --git a/paper/scripts/build_dataset/build_m3_dataset.sh b/paper/scripts/build_dataset/build_m3_dataset.sh index a285608..ee65960 100644 --- a/paper/scripts/build_dataset/build_m3_dataset.sh +++ b/paper/scripts/build_dataset/build_m3_dataset.sh @@ -1,5 +1,15 @@ python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ --model m3 \ - --model_cp_path /home/loubb/work/clamp3/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-m3.jsonl \ No newline at end of file + --model_cp_path /home/loubb/work/clamp3/weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/train-m3.jsonl \ + --compute_per_file_embeddings \ + --m3_is_encoder_checkpoint + +python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ + --model m3 \ + --model_cp_path /home/loubb/work/clamp3/weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/test-m3.jsonl \ + --compute_per_file_embeddings \ + --m3_is_encoder_checkpoint diff --git a/paper/scripts/build_dataset/build_mert_dataset.sh b/paper/scripts/build_dataset/build_mert_dataset.sh index f23179f..45f1215 100644 --- a/paper/scripts/build_dataset/build_mert_dataset.sh +++ b/paper/scripts/build_dataset/build_mert_dataset.sh @@ -1,6 +1,15 @@ python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ --model mert \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-mert.jsonl \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mert.jsonl \ --mert_pianoteq_exec_path "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE" \ - --mert_pianoteq_num_procs 8 \ No newline at end of file + --mert_pianoteq_num_procs 16 \ + 
--compute_per_file_embeddings + +python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ + --model mert \ + --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mididict.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mert.jsonl \ + --mert_pianoteq_exec_path "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE" \ + --mert_pianoteq_num_procs 16 \ + --compute_per_file_embeddings \ No newline at end of file diff --git a/paper/scripts/evaluate_embeddings.sh b/paper/scripts/evaluate_embeddings.sh index 5211cdb..d992cf1 100644 --- a/paper/scripts/evaluate_embeddings.sh +++ b/paper/scripts/evaluate_embeddings.sh @@ -1,6 +1,10 @@ +MODEL="m3" +CATEGORY="pianist" +echo "Evaluating model ${MODEL} on category: ${CATEGORY}" + python /home/loubb/work/aria/paper/scripts/evaluate_embedding_with_probe.py \ - --model aria \ - --metadata_category genre \ - --train_dataset_path /mnt/ssd1/aria/data/paper/clas/genre-aria/train-aria.jsonl \ - --test_dataset_path /mnt/ssd1/aria/data/paper/clas/genre-test-maestro_pijama/combined-aria-perfile.jsonl \ - --num_epochs 1 \ No newline at end of file + --model $MODEL \ + --metadata_category $CATEGORY \ + --train_dataset_path "/mnt/ssd1/aria/data/paper/clas/${CATEGORY}/train-${MODEL}.jsonl" \ + --test_dataset_path "/mnt/ssd1/aria/data/paper/clas/${CATEGORY}/test-${MODEL}.jsonl" \ + --num_epochs 50 From c8cc7b8e20c81072ea94864f84c0bdbbd92e998b Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 14 Mar 2025 14:55:09 +0000 Subject: [PATCH 29/72] update datasets/training/model scripts to support embedding conditioning --- aria/datasets.py | 418 ++++++----------------- aria/embeddings/eval.py | 649 ------------------------------------ aria/embeddings/evaluate.py | 11 +- aria/model.py | 117 +++++-- aria/run.py | 50 +-- aria/tokenizer.py | 6 + aria/train.py | 138 ++++---- config/config.json | 36 +- 8 files changed, 283 insertions(+), 1142 deletions(-) delete mode 100644 aria/embeddings/eval.py diff --git a/aria/datasets.py b/aria/datasets.py index dc2d6ae..1f79c6c 100644 --- a/aria/datasets.py +++ b/aria/datasets.py @@ -16,17 +16,14 @@ from mido.midifiles.units import second2tick from pathlib import Path from typing import List -from copy import deepcopy from typing import Callable, Iterable from collections import defaultdict from aria.config import load_config -from aria.tokenizer import InferenceAbsTokenizer from ariautils.tokenizer import Tokenizer from ariautils.midi import ( MidiDict, get_test_fn, - get_duration_ms, get_metadata_fn, ) @@ -477,7 +474,7 @@ def __init__(self, tokenizer: Tokenizer): def build(**kwargs): raise NotImplementedError - def get_loss_mask(self, src_seq: list, tgt_seq: list): + def get_loss_mask(self, src_seq: list, tgt_seq: list, offset: int = 0): # Should returns a bool Tensor with False indicating a masked loss raise NotImplementedError @@ -598,8 +595,10 @@ def _format(tok): mmap_obj = self.file_mmaps[file_idx] mmap_obj.seek(pos) - _debug = mmap_obj.readline() - seq = json.loads(_debug) # Load raw seq + entry_raw = mmap_obj.readline() + entry_dict = json.loads(entry_raw) + + seq = entry_dict["seq"] # Load raw seq seq = [_format(tok) for tok in seq] # Format into hashable if self._transform: seq = self._transform(seq) # Data augmentation @@ -607,11 +606,13 @@ def _format(tok): src = seq tgt = seq[1:] + [self.tokenizer.pad_tok] mask = self.get_loss_mask(src_seq=src, tgt_seq=tgt) + emb = entry_dict.get("emb", None) return ( torch.tensor(self.tokenizer.encode(src)), 
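+ # NOTE: __getitem__ now yields a 4-tuple (src, tgt, loss_mask, emb). When the epoch file was built without embeddings, emb falls back to torch.empty(0), so callers can detect missing conditioning data via emb.numel() == 0.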
torch.tensor(self.tokenizer.encode(tgt)), mask, + torch.tensor(emb) if emb is not None else torch.empty(0), ) def check_config(self, epoch_load_path: str): @@ -712,6 +713,8 @@ def _get_seqs( else: raise Exception + _file_path = _midi_dict.metadata["abs_load_path"] + try: if _tokenize_fn is not None: _tokenized_seq = _tokenize_fn(_midi_dict) @@ -720,11 +723,13 @@ def _get_seqs( except Exception as e: print(e) logger.info(f"Skipping midi_dict: {e}") + return else: if _tokenizer.unk_tok in _tokenized_seq: logger.warning("Unknown token seen while tokenizing midi_dict") - return _tokenized_seq + + return _tokenized_seq, _file_path def get_seqs( @@ -732,7 +737,7 @@ def get_seqs( midi_dict_iter: Iterable, tokenize_fn: Callable | None = None, ): - # Can't pickle geneator object when start method is spawn + # Can't pickle generator object when start method is spawn if multiprocessing.get_start_method() == "spawn": logging.info( "Converting generator to list due to multiprocessing start method" @@ -781,6 +786,7 @@ def random_selection_itt(iterables: list[Iterable]): pass +# GOAL: Modify this and then rename it, removing ft-dataset from codebase class PretrainingDataset(TrainingDataset): """Torch dataset object yielding sequences formatted for pre-training""" @@ -813,12 +819,13 @@ def build( midi_dataset: MidiDataset = None, midi_dataset_path: str = None, separate_sequences: bool = False, + file_embeddings: dict | None = None, ): """Builds and returns PretrainingDataset.""" - def _build_epoch(_save_path, _midi_dataset): + def _build_concat_epoch(_save_path: str, _midi_dataset: Iterable): + # Sequences are concatenated and sliced with jsonlines.open(_save_path, mode="w") as writer: - # Write tokenizer info into json on first line writer.write( { "tokenizer_config": tokenizer.config, @@ -826,53 +833,75 @@ def _build_epoch(_save_path, _midi_dataset): "max_seq_len": max_seq_len, } ) + seq_buffer = [] + _idx = 0 + for entry, file_path in reservoir( + get_seqs(tokenizer, _midi_dataset), 10 + ): + seq_buffer += entry + + while len(seq_buffer) >= max_seq_len: + writer.write({"seq": seq_buffer[:max_seq_len]}) + seq_buffer = seq_buffer[max_seq_len:] + + _idx += 1 + if _idx % 250 == 0: + logger.info(f"Finished processing {_idx}") - if separate_sequences is False: - buffer = [] - _idx = 0 - for entry in reservoir( - get_seqs(tokenizer, _midi_dataset), 10 - ): - if entry is not None: - buffer += entry - while len(buffer) >= max_seq_len: - writer.write(buffer[:max_seq_len]) - buffer = buffer[max_seq_len:] - - _idx += 1 - if _idx % 250 == 0: - logger.info(f"Finished processing {_idx}") - - if buffer: - buffer += [tokenizer.pad_tok] * ( - max_seq_len - len(buffer) + if seq_buffer: + seq_buffer += [tokenizer.pad_tok] * ( + max_seq_len - len(seq_buffer) + ) + writer.write({"seq": seq_buffer[:max_seq_len]}) + + def _build_epoch_separated( + _save_path: str, + _midi_dataset: Iterable, + _file_embeddings: dict | None, + ): + # Sequences always start with a new entry (requires padding) + with jsonlines.open(_save_path, mode="w") as writer: + writer.write( + { + "tokenizer_config": tokenizer.config, + "tokenizer_name": tokenizer.name, + "max_seq_len": max_seq_len, + } + ) + _idx = 0 + for entry, file_path in reservoir( + get_seqs(tokenizer, _midi_dataset), 10 + ): + seq_buffer = entry + embedding_data = ( + {"emb": _file_embeddings[file_path]} + if _file_embeddings + else {} + ) + + while len(seq_buffer) >= max_seq_len: + writer.write( + { + "seq": seq_buffer[:max_seq_len], + **embedding_data, + } ) - 
writer.write(buffer[:max_seq_len]) - - elif separate_sequences is True: - _idx = 0 - for entry in reservoir( - get_seqs(tokenizer, _midi_dataset), 10 - ): - if entry is None: - continue - - buffer = entry - while len(buffer) >= max_seq_len: - writer.write(buffer[:max_seq_len]) - buffer = buffer[max_seq_len:] - - if buffer: - buffer += [tokenizer.pad_tok] * ( - max_seq_len - len(buffer) - ) - writer.write(buffer[:max_seq_len]) - - _idx += 1 - if _idx % 250 == 0: - logger.info(f"Finished processing {_idx}") - else: - raise ValueError + seq_buffer = seq_buffer[max_seq_len:] + + if seq_buffer: + seq_buffer += [tokenizer.pad_tok] * ( + max_seq_len - len(seq_buffer) + ) + writer.write( + { + "seq": seq_buffer[:max_seq_len], + **embedding_data, + } + ) + + _idx += 1 + if _idx % 250 == 0: + logger.info(f"Finished processing {_idx}") logger = setup_logger() assert max_seq_len > 0, "max_seq_len must be greater than 0" @@ -914,114 +943,24 @@ def _build_epoch(_save_path, _midi_dataset): if midi_dataset_path: midi_dataset = MidiDataset.get_generator(midi_dataset_path) - _build_epoch( - _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), - _midi_dataset=midi_dataset, - ) + if separate_sequences is True: + _build_epoch_separated( + _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), + _midi_dataset=midi_dataset, + _file_embeddings=file_embeddings, + ) + else: + _build_concat_epoch( + _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), + _midi_dataset=midi_dataset, + ) logger.info( f"Finished building, saved PretrainingDataset to {save_dir}" ) -# TODO: Refactor for readability -def _get_combined_mididict( - clean_midi_dict: MidiDict, - noisy_midi_dict: MidiDict, - min_noisy_ms: int, - max_noisy_ms: int, - min_clean_ms: int, - max_clean_ms: int, -) -> MidiDict: - # NOTE: We adopt the tempo/ticks_per_beat of the clean_midi_dict, and - # adjust the noisy note messages accordingly. 
- assert len(clean_midi_dict.tempo_msgs) == 1, "Unsupported tempo msgs" - assert len(noisy_midi_dict.tempo_msgs) == 1, "Unsupported tempo msgs" - - total_length_ms = get_duration_ms( - start_tick=0, - end_tick=clean_midi_dict.note_msgs[-1]["data"]["start"], - tempo_msgs=clean_midi_dict.tempo_msgs, - ticks_per_beat=clean_midi_dict.ticks_per_beat, - ) - - # Create intervals - noisy_intervals = [] - clean_intervals = [] - prev_ms = -1 - add_noisy_next = random.choice([True, False]) - while True: - if add_noisy_next is True: - # Add noisy interval - noisy_end_ms = random.randint( - prev_ms + min_noisy_ms, prev_ms + max_noisy_ms - ) - noisy_intervals.append([prev_ms + 1, noisy_end_ms]) - prev_ms = noisy_end_ms - if prev_ms > total_length_ms: - break - else: - add_noisy_next = False - else: - # Add clean interval - clean_end_ms = random.randint( - prev_ms + min_clean_ms, prev_ms + max_clean_ms - ) - clean_intervals.append([prev_ms + 1, clean_end_ms]) - prev_ms = clean_end_ms - if prev_ms > total_length_ms: - break - else: - add_noisy_next = True - - # Merge note_msgs - clean_ms_to_tick = (clean_midi_dict.ticks_per_beat * 1e3) / ( - clean_midi_dict.tempo_msgs[0]["data"] - ) - - comb_note_msgs = [] - for _note_msg in noisy_midi_dict.note_msgs: - onset_time_ms = noisy_midi_dict.tick_to_ms(_note_msg["data"]["start"]) - - for _interval_start_ms, _interval_end_ms in noisy_intervals: - if _interval_start_ms < onset_time_ms < _interval_end_ms: - offset_time_ms = noisy_midi_dict.tick_to_ms( - _note_msg["data"]["end"] - ) - _adj_note_msg = copy.deepcopy(_note_msg) - _adj_onset_tick = int(onset_time_ms * clean_ms_to_tick) - _adj_offset_tick = int(offset_time_ms * clean_ms_to_tick) - _adj_note_msg["tick"] = _adj_onset_tick - _adj_note_msg["data"]["start"] = _adj_onset_tick - _adj_note_msg["data"]["end"] = _adj_offset_tick - - comb_note_msgs.append(_adj_note_msg) - break - - for _note_msg in clean_midi_dict.note_msgs: - onset_time_ms = clean_midi_dict.tick_to_ms(_note_msg["data"]["start"]) - - for _interval_start_ms, _interval_end_ms in clean_intervals: - if _interval_start_ms < onset_time_ms < _interval_end_ms: - comb_note_msgs.append(_note_msg) - break - - comb_metadata = deepcopy(clean_midi_dict.metadata) - comb_metadata["noisy_intervals"] = noisy_intervals - - # Maybe using clean pedal msgs here is bad? 
- return MidiDict( - meta_msgs=clean_midi_dict.meta_msgs, - tempo_msgs=clean_midi_dict.tempo_msgs, - pedal_msgs=clean_midi_dict.pedal_msgs, - instrument_msgs=clean_midi_dict.instrument_msgs, - note_msgs=comb_note_msgs, - ticks_per_beat=clean_midi_dict.ticks_per_beat, - metadata=comb_metadata, - ) - - -# TODO: Refactor this function for readability +# Unused but potentially useful in the future def _noise_midi_dict(midi_dict: MidiDict, config: dict): def _get_velocity_adjusted_msg( __note_msg: dict, @@ -1182,168 +1121,3 @@ def _get_onset_adjusted_msg( ticks_per_beat=midi_dict.ticks_per_beat, metadata=midi_dict.metadata, ) - - -def export_inference_abs_build_tokenize_fn( - midi_dict: MidiDict, tokenizer: InferenceAbsTokenizer -): - finetuning_config = load_config()["data"]["finetuning"] - GUIDANCE_PROB = finetuning_config["guidance_prob"] - NOISING_PROB = finetuning_config["noising"]["activation_prob"] - MIN_NOISY_MS = finetuning_config["min_noisy_interval_ms"] - MAX_NOISY_MS = finetuning_config["max_noisy_interval_ms"] - MIN_CLEAN_MS = finetuning_config["min_clean_interval_ms"] - MAX_CLEAN_MS = finetuning_config["max_clean_interval_ms"] - - if random.random() <= NOISING_PROB: - noisy_midi_dict = _noise_midi_dict( - midi_dict, config=finetuning_config["noising"] - ) - midi_dict_for_tokenization = _get_combined_mididict( - clean_midi_dict=midi_dict, - noisy_midi_dict=noisy_midi_dict, - min_noisy_ms=MIN_NOISY_MS, - max_noisy_ms=MAX_NOISY_MS, - min_clean_ms=MIN_CLEAN_MS, - max_clean_ms=MAX_CLEAN_MS, - ) - else: - midi_dict_for_tokenization = midi_dict - - if random.random() <= GUIDANCE_PROB: - return tokenizer.tokenize( - midi_dict=midi_dict_for_tokenization, - prompt_intervals_ms=midi_dict_for_tokenization.metadata.get( - "noisy_intervals", [] - ), - guidance_midi_dict=midi_dict, - ) - else: - return tokenizer.tokenize( - midi_dict=midi_dict_for_tokenization, - prompt_intervals_ms=midi_dict_for_tokenization.metadata.get( - "noisy_intervals", [] - ), - ) - - -class FinetuningDataset(TrainingDataset): - """Torch dataset object yielding sequences formatted for fine-tuning.""" - - def __init__( - self, dir_paths: List[str] | str, tokenizer: InferenceAbsTokenizer - ): - super().__init__(tokenizer=tokenizer) - - assert tokenizer.name == "inference_abs", "invalid tokenizer" - - if isinstance(dir_paths, str): - dir_paths = [dir_paths] - - self.dir_paths = dir_paths - self.get_epoch_files_by_dir(dir_paths) - self.init_epoch(0) - - def __len__(self): - return len(self.index) - - def get_loss_mask(self, src_seq: list, tgt_seq: list): - mask = [False] * len(tgt_seq) - inside_target = True - - for idx, (src_tok, tgt_tok) in enumerate(zip(src_seq, tgt_seq)): - if src_tok == self.tokenizer.guidance_start_tok: - inside_target = False - elif src_tok == self.tokenizer.guidance_end_tok: - inside_target = True - elif tgt_tok == self.tokenizer.prompt_start_tok: - inside_target = False - elif src_tok == self.tokenizer.prompt_end_tok: - inside_target = True - - if inside_target is True and tgt_tok != self.tokenizer.pad_tok: - mask[idx] = True - - return torch.tensor(mask, dtype=torch.bool) - - @classmethod - def build( - cls, - tokenizer: InferenceAbsTokenizer, - save_dir: str, - max_seq_len: int, - num_epochs: int, - midi_dataset_path: str, - ): - - def _build_epoch(_save_path, _midi_dataset): - with jsonlines.open(_save_path, mode="w") as writer: - # Write tokenizer info into json on first line - writer.write( - { - "tokenizer_config": tokenizer.config, - "tokenizer_name": tokenizer.name, - "max_seq_len": 
max_seq_len, - } - ) - - _idx = 0 - for entry in reservoir( - get_seqs( - tokenizer, - _midi_dataset, - tokenize_fn=functools.partial( - export_inference_abs_build_tokenize_fn, - tokenizer=tokenizer, - ), - ), - 10, - ): - for _entry in tokenizer.split(entry, max_seq_len): - writer.write(_entry) - - _idx += 1 - if _idx % 250 == 0: - logger.info(f"Finished processing {_idx}") - - logger = setup_logger() - assert max_seq_len > 0, "max_seq_len must be greater than 0" - assert num_epochs > 0, "num_epochs must be greater than 0" - assert os.path.isfile(midi_dataset_path), "file not found" - if multiprocessing.get_start_method() == "spawn": - logger.warning( - 'The current multiprocessing start method is "spawn", this ' - "will slow down dataset building" - ) - - if os.path.isdir(save_dir) and os.listdir(save_dir): - print( - f"The directory at {save_dir} in non-empty, type [Y/y] to " - "remove and continue:" - ) - if input() not in {"Y", "y"}: - print("Aborting") - return - else: - shutil.rmtree(save_dir) - - if not os.path.exists(save_dir): - os.mkdir(save_dir) - - logger.info( - f"Building FinetuningDataset with config: " - f"max_seq_len={max_seq_len}, " - f"tokenizer_name={tokenizer.name}" - ) - - for idx in range(num_epochs): - logger.info(f"Building epoch {idx}/{num_epochs - 1}...") - - # Reload the combined dataset for each epoch - midi_dataset = MidiDataset.get_generator(midi_dataset_path) - _build_epoch( - _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), - _midi_dataset=midi_dataset, - ) - - logger.info(f"Finished building, saved FinetuningDataset to {save_dir}") diff --git a/aria/embeddings/eval.py b/aria/embeddings/eval.py deleted file mode 100644 index 635a329..0000000 --- a/aria/embeddings/eval.py +++ /dev/null @@ -1,649 +0,0 @@ -import torch -import accelerate -import os -import mmap -import json -import time -import functools -import multiprocessing -import queue -import copy -import jsonlines -import torch.nn as nn -import torch.nn.functional as F - -from tqdm import tqdm -from typing import Callable -from concurrent.futures import ThreadPoolExecutor - -from aria.model import ModelConfig, TransformerLM, TransformerEMB -from aria.config import load_model_config -from aria.utils import _load_weight -from ariautils.midi import MidiDict -from ariautils.tokenizer import AbsTokenizer - -METADATA_CATEGORY = "composer" -TAG_TO_ID = { - "chopin": 0, - "bach": 1, - "beethoven": 2, - "liszt": 3, - "mozart": 4, - "debussy": 5, - "schumann": 6, - "schubert": 7, - "rachmaninoff": 8, - "brahms": 9, - "tchaikovsky": 10, - "haydn": 11, - "scriabin": 12, - "mendelssohn": 13, - "czerny": 14, - "ravel": 15, - "scarlatti": 16, - "other": 17, -} -ID_TO_TAG = {v: k for k, v in TAG_TO_ID.items()} - - -def model_forward( - model: nn.Module, - idxs: torch.Tensor, -): - return model(idxs) - - -def chunk_and_pad(lst: list, n: int): - return [lst[i : i + n] for i in range(0, len(lst), n)] - - -def write_entries(writer, entries): - for entry in entries: - writer.write(entry) - - -def process_entry( - entry, - slice_len_notes: int, - max_seq_len: int, - tokenizer: AbsTokenizer, -): - midi_dict = MidiDict.from_msg_dict(entry) - - outputs = [] - for slice_note_msgs in chunk_and_pad( - lst=midi_dict.note_msgs, n=slice_len_notes - ): - if len(slice_note_msgs) < 20: - break - - slice_midi_dict = copy.deepcopy(midi_dict) - slice_midi_dict.note_msgs = slice_note_msgs - slice_midi_dict.metadata = {} - tokenized_slice = tokenizer.tokenize(slice_midi_dict) - if tokenizer.dim_tok in tokenized_slice: - 
tokenized_slice.remove(tokenizer.dim_tok) - - tokenized_slice = tokenized_slice[:max_seq_len] - - outputs.append({"seq": tokenized_slice, "metadata": midi_dict.metadata}) - - return outputs - - -def _pad_seq(seq: list, tokenizer: AbsTokenizer, max_seq_len: int): - seq = seq[:max_seq_len] - seq += [tokenizer.pad_tok] * (max_seq_len - len(seq)) - - if tokenizer.eos_tok not in seq: - seq[-1] = tokenizer.eos_tok - - return seq - - -@torch.autocast("cuda", dtype=torch.bfloat16) -@torch.inference_mode() -def get_contrastive_embedding( - seqs: list, - hook_model: nn.Module, - hook_max_seq_len: int, - hook_tokenizer: AbsTokenizer, - hook_model_forward: Callable, -): - seqs = [ - _pad_seq( - seq=seq, tokenizer=hook_tokenizer, max_seq_len=hook_max_seq_len - ) - for seq in seqs - ] - - eos_positions = [seq.index(hook_tokenizer.eos_tok) for seq in seqs] - enc_seqs = torch.tensor( - [hook_tokenizer.encode(seq) for seq in seqs], device="cuda" - ) - hidden_states = hook_model_forward(model=hook_model, idxs=enc_seqs) - idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) - emb = hidden_states[idx, eos_positions].tolist() - - return emb - - -@torch.autocast("cuda", dtype=torch.bfloat16) -@torch.inference_mode() -def get_baseline_embedding( - seqs: list, - hook_model: nn.Module, - hook_max_seq_len: int, - hook_tokenizer: AbsTokenizer, - pool_mode: str = "last", # "last" or "mean" -): - for seq in seqs: - if hook_tokenizer.eos_tok in seq: - seq.remove(hook_tokenizer.eos_tok) - - orig_lengths = [len(seq) for seq in seqs] - last_tok_positions = [length - 1 for length in orig_lengths] - seqs = [ - seq + ([hook_tokenizer.pad_tok] * (hook_max_seq_len - len(seq))) - for seq in seqs - ] - - enc_seqs = torch.tensor( - [hook_tokenizer.encode(seq) for seq in seqs], device="cuda" - ) - hidden_states = hook_model(enc_seqs) - - if pool_mode == "last": - idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) - emb = hidden_states[idx, last_tok_positions].tolist() - elif pool_mode == "mean": - pad_id = hook_tokenizer.pad_id - # Create a mask by comparing enc_seqs to pad_id. - mask = (enc_seqs != pad_id).unsqueeze(-1).to(hidden_states.dtype) - # Sum over valid tokens and average. 
- sum_hidden = (hidden_states * mask).sum(dim=1) - valid_counts = mask.sum(dim=1) - mean_hidden = sum_hidden / valid_counts - emb = mean_hidden.tolist() - else: - raise ValueError(f"Unsupported pool_mode: {pool_mode}") - - return emb - - -class EvaluationDataset(torch.utils.data.Dataset): - def __init__(self, load_path: str, tag_ids: dict, metadata_category: str): - self.load_path = load_path - self.tag_ids = tag_ids - self.metadata_category = metadata_category - self.tokenizer = AbsTokenizer() - self.index = [] - - self.file_buff = open(self.load_path, "rb") - self.mmap_obj = mmap.mmap( - self.file_buff.fileno(), 0, access=mmap.ACCESS_READ - ) - - while True: - pos = self.mmap_obj.tell() - line = self.mmap_obj.readline() - if not line: - break - self.index.append(pos) - - def __getitem__(self, idx: int): - pos = self.index[idx] - self.mmap_obj.seek(pos) - - raw_data = self.mmap_obj.readline().decode("utf-8") - json_data = json.loads(raw_data) - - emb = json_data["emb"] - metadata = json_data["metadata"] - tag = metadata.get(self.metadata_category, "other") - tag = tag if tag in self.tag_ids.keys() else "other" - - assert tag in self.tag_ids - tag_tensor = torch.tensor(self.tag_ids[tag]) - emb_tensor = torch.tensor(emb) - - return emb_tensor, tag_tensor - - def __len__(self): - return len(self.index) - - @classmethod - def export_worker_init_fn(cls): - def worker_init_fn(worker_id: int): - worker_info = torch.utils.data.get_worker_info() - dataset = worker_info.dataset - - if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: - dataset.mmap_obj.close() - - f = open(dataset.load_path, "rb") - dataset.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) - - return worker_init_fn - - @classmethod - def build( - cls, - midi_dataset_load_path: str, - save_path: str, - slice_len_notes: int, - max_seq_len: int, - batch_size: int, - embedding_hook: Callable, - **embedding_hook_kwargs, - ): - def batch_producer( - results_queue: queue.Queue, - batch_queue: queue.Queue, - batch_size: int, - total_workers: int, - ): - buffer = [] - termination_signals = 0 - while termination_signals < total_workers: - if batch_queue.qsize() >= 5: - time.sleep(1) - try: - result = results_queue.get(timeout=0.01) - if result is None: - termination_signals += 1 - continue - buffer.append(result) - if len(buffer) == batch_size: - batch_queue.put(buffer) - buffer = [] - except queue.Empty: - continue - - if buffer: - batch_queue.put(buffer) - - def producer( - midi_dataset_load_path: str, - midi_dict_queue: queue.Queue, - num_workers: int, - ): - cnt = 0 - with jsonlines.open(midi_dataset_load_path, "r") as midi_dataset: - for midi_dict in midi_dataset: - while midi_dict_queue.qsize() >= 1000: - time.sleep(0.1) - midi_dict_queue.put(midi_dict) - cnt += 1 - - if cnt % 500 == 0: - print(f"Finished {cnt}") - - for _ in range(num_workers): - midi_dict_queue.put(None) - - def worker( - midi_dict_queue: queue.Queue, - results_queue: queue.Queue, - slice_len_notes: int, - max_seq_len: int, - ): - tokenizer = AbsTokenizer() - - while True: - midi_dict = midi_dict_queue.get() - if midi_dict is None: - results_queue.put(None) - break - - while results_queue.qsize() > 1000: - time.sleep(0.5) - - _result = process_entry( - entry=midi_dict, - slice_len_notes=slice_len_notes, - max_seq_len=max_seq_len, - tokenizer=tokenizer, - ) - for _sub_result in _result: - results_queue.put(_sub_result) - - assert os.path.isfile(midi_dataset_load_path) - assert os.path.isfile(save_path) is False - - TOTAL_WORKERS = 8 - write_executor = 
ThreadPoolExecutor(max_workers=1) - results_queue = multiprocessing.Queue() - midi_dict_queue = multiprocessing.Queue() - batch_queue = multiprocessing.Queue() - producer_process = multiprocessing.Process( - target=producer, - args=(midi_dataset_load_path, midi_dict_queue, TOTAL_WORKERS), - ) - batch_producer_process = multiprocessing.Process( - target=batch_producer, - args=(results_queue, batch_queue, batch_size, TOTAL_WORKERS), - ) - worker_processes = [ - multiprocessing.Process( - target=worker, - args=( - midi_dict_queue, - results_queue, - slice_len_notes, - max_seq_len, - ), - ) - for _ in range(TOTAL_WORKERS) - ] - - producer_process.start() - batch_producer_process.start() - for p in worker_processes: - p.start() - - with jsonlines.open(save_path, "w") as writer: - while batch_producer_process.is_alive() or not batch_queue.empty(): - try: - batch = batch_queue.get(timeout=0.01) - - _seqs = [item["seq"] for item in batch] - _metadata = [item["metadata"] for item in batch] - _embs = embedding_hook(seqs=_seqs, **embedding_hook_kwargs) - - write_objs = [ - {"seq": s, "emb": e, "metadata": m} - for s, e, m in zip(_seqs, _embs, _metadata) - ] - write_executor.submit(write_entries, writer, write_objs) - - except queue.Empty: - continue - - write_executor.shutdown(wait=True) - - -def _get_optim( - lr: float, - model: nn.Module, - total_steps: int, - warmup: int = 100, - end_ratio: int = 0.1, -): - optimizer = torch.optim.AdamW( - model.parameters(), - lr=lr, - weight_decay=0.1, - betas=(0.9, 0.95), - eps=1e-5, - ) - - warmup_lrs = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=0.000001, - end_factor=1, - total_iters=warmup, - ) - linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=1, - end_factor=end_ratio, - total_iters=total_steps - warmup, - ) - - lr_scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, - schedulers=[warmup_lrs, linear_decay_lrs], - milestones=[warmup], - ) - - return optimizer, lr_scheduler - - -class ClassifierHead(nn.Module): - def __init__(self, d_emb: int, num_class: int): - super().__init__() - self.linear = nn.Linear(d_emb, num_class) - - def forward(self, x: torch.Tensor): - return self.linear(x) - - -def _train( - accelerator: accelerate.Accelerator, - model: nn.Module, - train_dataloader: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - num_epochs: int = 1, -): - TRAILING_LOSS_STEPS = 100 - loss = torch.tensor([0.0]) - trailing_loss = 0 - lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) - loss_buffer = [] - - model.train() - loss_fn = nn.CrossEntropyLoss() - - for _epoch in range(num_epochs): - for __step, batch in ( - pbar := tqdm(enumerate(train_dataloader), leave=False) - ): - pbar.set_postfix_str( - f"lr={lr_for_print}, " - f"loss={round(loss.item(), 4)}, " - f"trailing={round(trailing_loss, 4)}" - ) - - emb, tag_ids = batch - tag_ids = tag_ids.view(-1) - - logits = model(emb) - loss = loss_fn(logits, tag_ids) - - loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) - trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( - loss_buffer[-TRAILING_LOSS_STEPS:] - ) - - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - if scheduler: - scheduler.step() - lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) - - if accelerator.is_main_process: - accelerator.save_state("/mnt/ssd1/aria/test") - - return model - - -def train_classifier( - emb_d: int, - train_dataset: 
EvaluationDataset, - tag_ids: dict, - batch_size: int, -): - num_epochs = 1 - train_dataloader = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_size=batch_size, - shuffle=True, - num_workers=24, - worker_init_fn=EvaluationDataset.export_worker_init_fn(), - ) - - model = ClassifierHead( - d_emb=emb_d, - num_class=len(tag_ids.keys()), - ) - optimizer, scheduler = _get_optim( - lr=3e-4, - model=model, - total_steps=num_epochs * len(train_dataloader), - ) - accelerator = accelerate.Accelerator() - - model, train_dataloader, optimizer, scheduler = accelerator.prepare( - model, - train_dataloader, - optimizer, - scheduler, - ) - - return _train( - accelerator=accelerator, - model=model, - train_dataloader=train_dataloader, - optimizer=optimizer, - scheduler=scheduler, - num_epochs=num_epochs, - ) - - -def evaluate_model(model: nn.Module, val_dataset_path: str): - val_dataset = EvaluationDataset( - load_path=val_dataset_path, - tag_ids=TAG_TO_ID, - metadata_category=METADATA_CATEGORY, - ) - model = model.cpu().eval() - - # Count true values and correct predictions per tag. - dist = {k: {"correct": 0, "total": 0} for k in TAG_TO_ID.keys()} - # New dictionary to count predictions per tag. - pred_dist = {k: 0 for k in TAG_TO_ID.keys()} - - for midi_emb, tag_id in val_dataset: - with torch.no_grad(): - logits = model(torch.tensor(midi_emb.view(1, -1))) - probs = F.softmax(logits, dim=-1) - pred_tag_id = probs.argmax(dim=-1).item() - - true_tag = ID_TO_TAG[tag_id.item()] - pred_tag = ID_TO_TAG[pred_tag_id] - - dist[true_tag]["total"] += 1 - pred_dist[pred_tag] += 1 - - if pred_tag_id == tag_id.item(): - dist[true_tag]["correct"] += 1 - - total_correct = sum(v["correct"] for v in dist.values()) - total_samples = sum(v["total"] for v in dist.values()) - print(f"Total accuracy: {total_correct/total_samples}") - - for tag in TAG_TO_ID.keys(): - TP = dist[tag]["correct"] - FN = dist[tag]["total"] - TP - FP = pred_dist[tag] - TP - precision = TP / (TP + FP) if (TP + FP) > 0 else 0 - recall = TP / (TP + FN) if (TP + FN) > 0 else 0 - f1 = ( - 2 * precision * recall / (precision + recall) - if (precision + recall) > 0 - else 0 - ) - print( - f"{tag} -- Accuracy: {TP/dist[tag]['total']}, Precision: {precision}, Recall: {recall}, F1: {f1}" - ) - - -def build_baseline_dataset(): - MAX_SEQ_LEN = 512 - MODEL_PATH = "/mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors" - - tokenizer = AbsTokenizer() - model_state = _load_weight(MODEL_PATH, "cuda") - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() - } - pretrained_model_config = ModelConfig(**load_model_config("medium")) - pretrained_model_config.set_vocab_size(tokenizer.vocab_size) - pretrained_model_config.grad_checkpoint = False - pretrained_model = TransformerLM(pretrained_model_config) - pretrained_model.load_state_dict(model_state) - pretrained_model.eval() - - global model_forward - model_forward = torch.compile( - model_forward, - mode="reduce-overhead", - fullgraph=True, - ) - - EvaluationDataset.build( - midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_train.jsonl", - save_path="/mnt/ssd1/aria/data/train.jsonl", - max_seq_len=MAX_SEQ_LEN, - slice_len_notes=165, - batch_size=128, - embedding_hook=functools.partial( - get_baseline_embedding, pool_mode="mean" - ), - hook_model=pretrained_model.model.cuda(), - hook_max_seq_len=MAX_SEQ_LEN, - hook_tokenizer=tokenizer, - ) - - -def build_contrastive_dataset(): - MAX_SEQ_LEN = 1024 - MODEL_PATH = ( - 
"/home/loubb/work/aria/models/medium-emb-t0.5-s1024-e20.safetensors" - ) - - tokenizer = AbsTokenizer() - model_state = _load_weight(MODEL_PATH, "cuda") - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() - } - pretrained_model_config = ModelConfig(**load_model_config("medium-emb")) - pretrained_model_config.set_vocab_size(tokenizer.vocab_size) - pretrained_model_config.grad_checkpoint = False - pretrained_model = TransformerEMB(pretrained_model_config) - pretrained_model.load_state_dict(model_state) - pretrained_model.eval() - - hook_model_forward = torch.compile( - model_forward, - mode="reduce-overhead", - fullgraph=True, - ) - - EvaluationDataset.build( - midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-all_train.jsonl", - save_path="/mnt/ssd1/aria/data/eval/test.jsonl", - max_seq_len=MAX_SEQ_LEN, - slice_len_notes=300, - batch_size=256, - embedding_hook=get_contrastive_embedding, - hook_model=pretrained_model.cuda(), - hook_max_seq_len=MAX_SEQ_LEN, - hook_tokenizer=tokenizer, - hook_model_forward=hook_model_forward, - ) - - -if __name__ == "__main__": - # tokenizer = AbsTokenizer() - # dataset = EvaluationDataset( - # load_path="/mnt/ssd1/aria/data/eval/temp-train.jsonl", - # tag_ids=TAG_TO_ID, - # metadata_category=METADATA_CATEGORY, - # ) - - # model = train_classifier( - # emb_d=512, - # train_dataset=dataset, - # batch_size=32, - # tag_ids=TAG_TO_ID, - # ) - # evaluate_model( - # model=model, - # val_dataset_path="/mnt/ssd1/aria/data/eval/temp-val.jsonl", - # ) - - build_contrastive_dataset() diff --git a/aria/embeddings/evaluate.py b/aria/embeddings/evaluate.py index 97e949b..df2042d 100644 --- a/aria/embeddings/evaluate.py +++ b/aria/embeddings/evaluate.py @@ -79,8 +79,11 @@ def write_entries(writer, entries): writer.write(entry) -def chunk_and_pad(lst: list, n: int): - return [lst[i : i + n] for i in range(0, len(lst), n)] +def get_chunks(note_msgs: list, chunk_len: int): + return [ + note_msgs[i : i + chunk_len] + for i in range(0, len(note_msgs), chunk_len) + ] def process_entry( @@ -92,8 +95,8 @@ def process_entry( midi_dict = MidiDict.from_msg_dict(entry) outputs = [] - for slice_note_msgs in chunk_and_pad( - lst=midi_dict.note_msgs, n=slice_len_notes + for slice_note_msgs in get_chunks( + note_msgs=midi_dict.note_msgs, chunk_len=slice_len_notes ): if len(slice_note_msgs) < 20: break diff --git a/aria/model.py b/aria/model.py index 6c85f86..75981ed 100644 --- a/aria/model.py +++ b/aria/model.py @@ -1,4 +1,4 @@ -"""Training implementation.""" +"""Training model implementation.""" from dataclasses import dataclass from typing import Optional @@ -118,10 +118,10 @@ def _ff_block(self, x: torch.Tensor): class Transformer(nn.Module): - """Transformer decoder with no language model head. + """Transformer decoder without a language model head. Args: - model_config (ModelConfig): Model config settings. + model_config (ModelConfig): Model configuration settings. """ def __init__(self, model_config: ModelConfig): @@ -142,19 +142,24 @@ def __init__(self, model_config: ModelConfig): def forward( self, src: torch.Tensor, + emb: torch.Tensor | None = None, ): - """Forward pass of Transformer. + """Perform a forward pass through the transformer. Args: - src (torch.tensor): Input to encoder block, of shape (batch_size, - seq_len, d_model). + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). + emb (Optional[torch.Tensor]): Optional extra embedding with shape (batch_size, emb_dim). 
Returns: - torch.tensor: Model outputs with shape (batch_size, seq_len, - d_model). + torch.Tensor: Output tensor with shape (batch_size, seq_len, d_model). """ + hidden_states = self.tok_embeddings(src) + if emb is not None: + emb = emb[:, None, :] + hidden_states = torch.cat([emb, hidden_states[:, :-1, :]], dim=1) + if self.freqs_cis is None: self.freqs_cis = precompute_freqs_cis( seq_len=self.model_config.max_seq_len, @@ -188,10 +193,10 @@ def custom_forward(*args): class TransformerLM(nn.Module): - """Transformer decoder with head for language modelling. + """Transformer decoder with a language modeling head. Args: - model_config (ModelConfig): Model config settings. + model_config (ModelConfig): Model configuration settings (vocab_size must be defined). """ def __init__(self, model_config: ModelConfig): @@ -208,20 +213,15 @@ def forward( self, src: torch.Tensor, ): - """Forward pass of Transformer decoder with LM head. + """Compute language modeling logits. Args: - src (torch.tensor): Input to encoder block, of shape (batch_size, - seq_len, d_model). - attn_mask (Optional[torch.tensor]): Attention mask of shape - (batch_size, seq_len). Defaults to None. - past_kv (Optional[list[KVCache]]): a list of kv caches. The list index - corresponds to the layer index. + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). Returns: - torch.tensor: Forward pass of src through Transformer and LM head. - Has shape (batch_size, seq_len, vocab_size). + torch.Tensor: Logits with shape (batch_size, seq_len, vocab_size). """ + hidden = self.model(src) logits = self.lm_head(hidden) @@ -229,10 +229,10 @@ def forward( class TransformerCL(nn.Module): - """Transformer decoder with head for classification. + """Transformer decoder with a classification head. Args: - model_config (ModelConfig): Model config settings. + model_config (ModelConfig): Model configuration settings (class_size must be defined). """ def __init__(self, model_config: ModelConfig): @@ -249,27 +249,78 @@ def forward( self, src: torch.Tensor, ): - """Forward pass of Transformer decoder with CL head. + """Compute classification logits. Args: - src (torch.tensor): Input to encoder block, of shape (batch_size, - seq_len, d_model). + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). Returns: - torch.tensor: Forward pass of src through Transformer and CL head. - Has shape (batch_size, seq_len, class_size). + torch.Tensor: Classification logits with shape (batch_size, seq_len, class_size). """ + hidden = self.model(src) logits = self.class_head(hidden) return logits +class TransformerLM_CND(nn.Module): + """Transformer decoder with a language modeling head and optional conditioning. + + Args: + model_config (ModelConfig): Model configuration settings (vocab_size and emb_size must be defined). + """ + + def __init__(self, model_config: ModelConfig): + super().__init__() + assert model_config.vocab_size is not None + + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) + self.lm_head = nn.Linear( + model_config.d_model, model_config.vocab_size, bias=False + ) + self.embedding_adapter = nn.Linear( + model_config.emb_size, model_config.d_model, bias=False + ) + + def forward( + self, + src: torch.Tensor, + emb: torch.Tensor | None = None, + ): + """Compute language modeling logits with optional conditioning. + + Args: + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). 
+ emb (Optional[torch.Tensor]): Optional conditioning embedding with shape (batch_size, emb_size). + + Returns: + torch.Tensor: Logits with shape (batch_size, seq_len, vocab_size). + Note: if emb is provided, the returned logits have sequence length seq_len - 1, since the logit predicted from the conditioning position is dropped. + + """ + + if emb is not None: + # The conditioning embedding is adapted to d_model and prepended to + # the token embeddings (the final token is dropped to preserve length). + # The first logit, predicted from the conditioning position alone, is + # then sliced off so that logits[:, i] predicts src[:, i + 1], which + # aligns with tgt[:, :-1] during training. + emb = self.embedding_adapter(emb) + hidden = self.model(src, emb) + logits = self.lm_head(hidden) + + return logits[:, 1:, :] + else: + hidden = self.model(src, emb) + logits = self.lm_head(hidden) + + return logits + + class TransformerEMB(nn.Module): - """Transformer decoder with head for embedding. + """Transformer decoder with an embedding head. Args: - model_config (ModelConfig): Model config settings. + model_config (ModelConfig): Model configuration settings (emb_size must be defined). """ def __init__(self, model_config: ModelConfig): @@ -286,15 +337,15 @@ def forward( self, src: torch.Tensor, ): - """Forward pass of Transformer decoder with EMB head. + """Compute output embeddings from the transformer. Args: - src (torch.tensor): Input to encoder block, of shape (batch_size, - seq_len, d_model). + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). + Returns: - torch.tensor: Forward pass of src through Transformer and EMB head. - Has shape (batch_size, seq_len, emb_size). + torch.Tensor: Output embeddings with shape (batch_size, seq_len, emb_size). """ + hidden = self.model(src) emb = self.emb_head(hidden) diff --git a/aria/run.py b/aria/run.py index a1302f9..861193b 100644 --- a/aria/run.py +++ b/aria/run.py @@ -2,7 +2,7 @@ import argparse import os -import re +import json import sys @@ -72,6 +72,7 @@ def _parse_sample_args(): return argp.parse_args(sys.argv[2:]) +# TODO: This is all broken due to tokenizer / embedding change -- need to fix def sample(args): """Entrypoint for sampling""" @@ -151,6 +152,8 @@ def sample(args): guidance_end_ms=args.guidance_end_ms, guidance_midi_dict=guidance_midi_dict, ) + # prompt_seq = prompt_seq[:-1] + print(prompt_seq) if len(prompt_seq) + args.l > model_config.max_seq_len: print( @@ -158,13 +161,14 @@ def sample(args): ) prompts = [prompt_seq for _ in range(num_variations)] - samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples") + samples_dir = "/home/loubb/Dropbox/shared" if os.path.isdir(samples_dir) is False: os.mkdir(samples_dir) if guidance_seq: tokenizer.detokenize(guidance_seq).to_midi().save( os.path.join(samples_dir, f"guidance.mid") ) + if args.cfg is not None and guidance_seq is not None: results = sample_batch_cfg( model=model, @@ -257,6 +261,7 @@ def _parse_pretrain_dataset_args(): help="start each with a new entry", action="store_true", ) + argp.add_argument("-embedding_dataset_path", help="path to per-file embedding dataset (jsonl)", required=False) return argp.parse_args(sys.argv[2:]) @@ -270,6 +275,16 @@ def build_pretraining_dataset(args): elif args.tokenizer_name == "rel": tokenizer = RelTokenizer() + if args.embedding_dataset_path is not None: + with open(args.embedding_dataset_path, "r") as f: + file_embeddings = { + data["metadata"]["abs_load_path"]: data["emb"] + for data in map(json.loads, f) + } + + else: + file_embeddings = None + PretrainingDataset.build( tokenizer=tokenizer, save_dir=args.save_dir, @@ -277,33 +292,7 @@ def build_pretraining_dataset(args): num_epochs=args.e, midi_dataset_path=args.load_path, separate_sequences=args.sep_sequences, - ) - - -def _parse_finetune_dataset_args(): - argp =
argparse.ArgumentParser(prog="aria finetune-dataset") - argp.add_argument( - "-midi_dataset_path", - help="path to midi_dict dataset", - ) - argp.add_argument("-save_dir", help="path to save dataset") - argp.add_argument("-l", help="max sequence length", type=int, default=4096) - argp.add_argument("-e", help="num epochs", type=int, default=1) - - return argp.parse_args(sys.argv[2:]) - - -def build_finetune_dataset(args): - from aria.tokenizer import InferenceAbsTokenizer - from aria.datasets import FinetuningDataset - - tokenizer = InferenceAbsTokenizer() - FinetuningDataset.build( - tokenizer=tokenizer, - save_dir=args.save_dir, - max_seq_len=args.l, - num_epochs=args.e, - midi_dataset_path=args.midi_dataset_path, + file_embeddings=file_embeddings, ) @@ -317,7 +306,6 @@ def main(): "sample", "midi-dataset", "pretrain-dataset", - "finetune-dataset", ), ) @@ -335,8 +323,6 @@ def main(): build_midi_dataset(args=_parse_midi_dataset_args()) elif args.command == "pretrain-dataset": build_pretraining_dataset(args=_parse_pretrain_dataset_args()) - elif args.command == "finetune-dataset": - build_finetune_dataset(args=_parse_finetune_dataset_args()) else: print("Unrecognized command") parser.print_help() diff --git a/aria/tokenizer.py b/aria/tokenizer.py index c142405..55ca846 100644 --- a/aria/tokenizer.py +++ b/aria/tokenizer.py @@ -125,9 +125,15 @@ def _add_prompt_tokens( elif ( curr_time_ms > prompt_end_ms and prompt_tok_inserted == True ): + # Res has already been shifted +1 when inserting prompt_tok res.insert(idx + 1, self.prompt_end_tok) break + if prompt_tok_inserted and self.prompt_end_tok not in res: + res.insert(-1, self.prompt_end_tok) + assert res[-1] == self.eos_tok + assert res[-2] == self.prompt_end_tok + return res def tokenize( diff --git a/aria/train.py b/aria/train.py index 5733e49..202cfe4 100644 --- a/aria/train.py +++ b/aria/train.py @@ -9,8 +9,6 @@ from torch import nn as nn from torch.utils.data import DataLoader -from torch.utils.flop_counter import FlopCounterMode -from triton.testing import do_bench from accelerate.logging import get_logger from safetensors.torch import load_file from logging.handlers import RotatingFileHandler @@ -18,13 +16,11 @@ from typing import List from aria.config import load_model_config -from aria.model import ModelConfig, TransformerLM +from aria.model import ModelConfig, TransformerLM, TransformerLM_CND from ariautils.tokenizer import Tokenizer, AbsTokenizer, RelTokenizer -from aria.tokenizer import InferenceAbsTokenizer from aria.datasets import ( TrainingDataset, PretrainingDataset, - FinetuningDataset, ) from aria.utils import _load_weight @@ -213,31 +209,18 @@ def get_dataloaders( tokenizer: Tokenizer, batch_size: int, num_workers: int, + use_embeddings: bool, init_epoch: int | None = None, apply_aug: bool = True, - finetune: bool = False, ): - logger = logging.getLogger(__name__) - if finetune == False: - train_dataset = PretrainingDataset( - dir_paths=train_data_dirs, - tokenizer=tokenizer, - ) - val_dataset = PretrainingDataset( - dir_paths=val_data_dir, - tokenizer=tokenizer, - ) - elif finetune == True: - train_dataset = FinetuningDataset( - dir_paths=train_data_dirs, - tokenizer=tokenizer, - ) - val_dataset = FinetuningDataset( - dir_paths=val_data_dir, - tokenizer=tokenizer, - ) - else: - raise ValueError + train_dataset = PretrainingDataset( + dir_paths=train_data_dirs, + tokenizer=tokenizer, + ) + val_dataset = PretrainingDataset( + dir_paths=val_data_dir, + tokenizer=tokenizer, + ) if init_epoch: 
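+ # When resuming, advance the training dataset to the epoch file (epoch{idx}.jsonl) it left off on.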
train_dataset.init_epoch(idx=init_epoch) @@ -262,6 +245,12 @@ def get_dataloaders( shuffle=False, ) + if use_embeddings is True: + _src, _tgt, _mask, _emb = train_dataset[0] + _src, _tgt, _mask, __emb = val_dataset[0] + assert _emb.numel() != 0, "Embeddings not present in train dataset" + assert __emb.numel() != 0, "Embeddings not present in val dataset" + return train_dataloader, val_dataloader @@ -271,6 +260,7 @@ def _train( model: TransformerLM, train_dataloader: DataLoader, val_dataloader: DataLoader, + use_embeddings: bool, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler = None, steps_per_checkpoint: int | None = None, @@ -278,34 +268,6 @@ def _train( resume_epoch: int | None = None, project_dir: str | None = None, ): - def profile_flops(dataloader: DataLoader): - def _bench(): - for batch in dataloader: - src, tgt = batch # (b_sz, s_len), (b_sz, s_len, v_sz) - logits = model(src) # (b_sz, s_len, v_sz) - logits = logits.transpose(1, 2) - loss = loss_fn(logits, tgt) - - # Backwards step - omit optimizer.step() - accelerator.backward(loss) - optimizer.zero_grad() - break - - logger.info( - f"Model has " - f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " - "parameters" - ) - - # logger.info("Profiling FLOP") - # flop_counter = FlopCounterMode(display=False) - # _bench() - - # with flop_counter: - # _bench() - # total_flop = sum(flop_counter.get_flop_counts()["Global"].values()) - # logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF") - def make_checkpoint( _accelerator: accelerate.Accelerator, _epoch: int, _step: int ): @@ -353,8 +315,16 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0): with accelerator.accumulate(model): step = __step + _resume_step + 1 - src, tgt, mask = batch # (b_sz, s_len), (b_sz, s_len, v_sz) - logits = model(src) # (b_sz, s_len, v_sz) + src, tgt, mask, emb = ( + batch # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb) + ) + if use_embeddings is True: + logits = model(src=src, emb=emb) # (b_sz, s_len - 1, v_sz) + tgt = tgt[:, :-1] # (b_sz, s_len - 1) + mask = mask[:, :-1] # (b_sz, s_len - 1) + else: + logits = model(src) # (b_sz, s_len, v_sz) + logits = logits.transpose( 1, 2 ) # Transpose for CrossEntropyLoss @@ -418,8 +388,16 @@ def val_loop(dataloader, _epoch: int): leave=False, ) ): - src, tgt, mask = batch # (b_sz, s_len), (b_sz, s_len, v_sz) - logits = model(src) # (b_sz, s_len, v_sz) + src, tgt, mask, emb = ( + batch # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb) + ) + if use_embeddings is True: + logits = model(src=src, emb=emb) # (b_sz, s_len - 1, v_sz) + tgt = tgt[:, :-1] # (b_sz, s_len - 1) + mask = mask[:, :-1] # (b_sz, s_len - 1) + else: + logits = model(src) # (b_sz, s_len, v_sz) + logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss loss = loss_fn(logits, tgt) @@ -451,7 +429,12 @@ def val_loop(dataloader, _epoch: int): PAD_ID = train_dataloader.dataset.tokenizer.pad_id logger = get_logger(__name__) # Accelerate logger loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID, reduction="none") - profile_flops(dataloader=train_dataloader) + + logger.info( + f"Model has " + f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " + "parameters" + ) if accelerator.is_main_process: loss_csv = open(os.path.join(project_dir, "loss.csv"), "w") @@ -504,10 +487,12 @@ def val_loop(dataloader, _epoch: int): epoch_csv.close() +# TODO: Add use_embeddings logic to this code path def resume_train( 
model_name: str, train_data_paths: str, val_data_path: str, + use_embeddings: bool, num_workers: int, batch_size: int, grad_acc_steps: int, @@ -533,8 +518,6 @@ def resume_train( tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) if tokenizer_name == "abs": tokenizer = AbsTokenizer() - elif tokenizer_name == "inference_abs": - tokenizer = InferenceAbsTokenizer() elif tokenizer_name == "rel": tokenizer = RelTokenizer() else: @@ -560,6 +543,7 @@ def resume_train( logger.info( f"Using training config: " f"model_name={model_name}, " + f"use_embeddings={use_embeddings}, " f"epochs={epochs}, " f"batch_size={batch_size}, " f"grad_acc_steps={grad_acc_steps}, " @@ -575,7 +559,12 @@ def resume_train( # Init model model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) - model = TransformerLM(model_config) + + if use_embeddings: + model = TransformerLM_CND(model_config) + else: + model = TransformerLM(model_config) + model.compile() train_dataloader, val_dataloader = get_dataloaders( @@ -586,6 +575,7 @@ def resume_train( batch_size=batch_size, num_workers=num_workers, apply_aug=True, + use_embeddings=use_embeddings, ) optimizer, scheduler = get_optim( model, @@ -624,6 +614,7 @@ def resume_train( model=model, train_dataloader=train_dataloader, val_dataloader=val_dataloader, + use_embeddings=use_embeddings, optimizer=optimizer, scheduler=scheduler, steps_per_checkpoint=steps_per_checkpoint, @@ -637,6 +628,7 @@ def train( model_name: str, train_data_paths: List[str], val_data_path: str, + use_embeddings: bool, num_workers: int, batch_size: int, grad_acc_steps: int, @@ -659,8 +651,6 @@ def train( tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) if tokenizer_name == "abs": tokenizer = AbsTokenizer() - elif tokenizer_name == "inference_abs": - tokenizer = InferenceAbsTokenizer() elif tokenizer_name == "rel": tokenizer = RelTokenizer() else: @@ -678,6 +668,7 @@ def train( logger.info( f"Using training config: " f"model_name={model_name}, " + f"use_embeddings={use_embeddings}, " f"checkpoint_path={checkpoint_path}, " if checkpoint_path else "" @@ -693,7 +684,12 @@ def train( # Init model model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) - model = TransformerLM(model_config) + + if use_embeddings is True: + model = TransformerLM_CND(model_config) + else: + model = TransformerLM(model_config) + model.compile() logger.info(f"Loaded model with config: {load_model_config(model_name)}") if checkpoint_path: @@ -714,7 +710,7 @@ def train( batch_size=batch_size, num_workers=num_workers, apply_aug=True, - finetune=True if checkpoint_path is not None else False, + use_embeddings=use_embeddings, ) assert ( @@ -798,6 +794,9 @@ def parse_resume_args(): argp.add_argument("-train_data", nargs="+", help="path to train dir") argp.add_argument("-val_data", help="path to val dir") argp.add_argument("-cp_dir", help="checkpoint dir", type=str, required=True) + argp.add_argument( + "-use_embeddings", help="prepend embeddings", action="store_true" + ) argp.add_argument("-r_step", help="resume step", type=int, required=True) argp.add_argument("-r_epoch", help="resume epoch", type=int, required=True) argp.add_argument("-epochs", help="train epochs", type=int, required=True) @@ -825,6 +824,9 @@ def parse_train_args(): argp.add_argument( "-cp_path", help="path to checkpoint", required=False, default=None ) + argp.add_argument( + "-use_embeddings", help="prepend embeddings", 
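With this flag in place, an embedding-conditioned run differs from a plain one only by passing -use_embeddings, which swaps TransformerLM for TransformerLM_CND and makes get_dataloaders assert that the datasets actually carry embeddings. A hypothetical invocation, assuming the usual accelerate launcher, train.py's existing train subcommand, and its -model flag (all paths illustrative):

    accelerate launch aria/train.py train \
        -model medium \
        -train_data /path/to/train_dataset_dir \
        -val_data /path/to/val_dataset_dir \
        -use_embeddings \
        -epochs 10 \
        -bs 16 \
        -workers 4
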
action="store_true" + ) argp.add_argument("-epochs", help="train epochs", type=int, required=True) argp.add_argument("-bs", help="batch size", type=int, default=32) argp.add_argument( @@ -860,6 +862,7 @@ def parse_train_args(): train( model_name=train_args.model, train_data_paths=train_args.train_data, + use_embeddings=train_args.use_embeddings, val_data_path=train_args.val_data, num_workers=train_args.workers, batch_size=train_args.bs, @@ -875,6 +878,7 @@ def parse_train_args(): model_name=resume_args.model, train_data_paths=resume_args.train_data, val_data_path=resume_args.val_data, + use_embeddings=resume_args.use_embeddings, num_workers=resume_args.workers, batch_size=resume_args.bs, grad_acc_steps=resume_args.grad_acc_steps, diff --git a/config/config.json b/config/config.json index 0d32671..dff8372 100644 --- a/config/config.json +++ b/config/config.json @@ -140,7 +140,7 @@ "metadata": { "functions": { "aria_midi_json": { - "run": true, + "run": false, "args": {} }, "composer_filename": { @@ -174,39 +174,6 @@ "form": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], "composer": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] } - }, - "finetuning": { - "guidance_prob": 0.5, - "min_noisy_interval_ms": 5000, - "max_noisy_interval_ms": 60000, - "min_clean_interval_ms": 60000, - "max_clean_interval_ms": 200000, - "noising": { - "activation_prob": 0.5, - "remove_notes": { - "activation_prob": 0.25, - "min_ratio": 0.0, - "max_ratio": 0.15 - }, - "adjust_velocity": { - "activation_prob": 0.25, - "min_adjust": 1, - "max_adjust": 20 - }, - "adjust_onsets": { - "activation_prob": 0.25, - "min_adjust_s": 0.005, - "max_adjust_s": 0.05, - "max_ratio": 0.0, - "min_ratio": 0.2 - }, - "quantize_onsets": { - "activation_prob": 0.05, - "min_quant_s": 0.05, - "max_quant_s": 0.1, - "max_vel_delta": 30 - } - } } }, "tokenizer": { @@ -219,5 +186,4 @@ } } - } From 626b5b4e770a2fddaaaa2aead55bfe1032499d22 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 14 Mar 2025 14:57:29 +0000 Subject: [PATCH 30/72] add ft-dataset script --- paper/scripts/build_aria_ft_emb_dataset.sh | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 paper/scripts/build_aria_ft_emb_dataset.sh diff --git a/paper/scripts/build_aria_ft_emb_dataset.sh b/paper/scripts/build_aria_ft_emb_dataset.sh new file mode 100644 index 0000000..7444685 --- /dev/null +++ b/paper/scripts/build_aria_ft_emb_dataset.sh @@ -0,0 +1,16 @@ +python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ + --model aria \ + --model_cp_path /home/loubb/work/aria/models/emb-t0.1-s2048-e25.safetensors \ + --dataset_load_path /mnt/ssd1/aria/data/mididict-ft_val.jsonl \ + --dataset_save_path /mnt/ssd1/aria/data/finetune/ft-val_emb.jsonl \ + --compute_per_file_embeddings \ + --aria_max_batch_size 128 + +aria pretrain-dataset \ + -tokenizer_name abs \ + -load_path /mnt/ssd1/aria/data/mididict-ft_val.jsonl \ + -embedding_dataset_path /mnt/ssd1/aria/data/finetune/ft-val_emb.jsonl \ + -save_dir /mnt/ssd1/aria/data/finetune/val \ + -l 8192 \ + -e 1 \ + -sep_sequences \ No newline at end of file From 93bcb22318794fa763edd6dfe38f0032cfb26ab8 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 15 Mar 2025 14:27:37 +0000 Subject: [PATCH 31/72] change use embeddings train logic --- aria/train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/aria/train.py b/aria/train.py index 202cfe4..83b8d42 100644 --- 
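The train.py hunk below (PATCH 31) gates the conditioning embedding on a coin flip, so roughly half of the batches train the unconditional path that sample_batch_cfg later interpolates against; this is the standard classifier-free-guidance training recipe. A minimal sketch of the idea, with hypothetical names rather than the aria API:

    import random

    import torch

    def maybe_drop_condition(
        emb: torch.Tensor, p_uncond: float = 0.5
    ) -> torch.Tensor | None:
        # With probability p_uncond the condition is dropped, so the model
        # learns p(x) alongside p(x | emb), as in classifier-free guidance.
        return None if random.random() < p_uncond else emb

One caveat: the same coin flip also runs inside val_loop, which makes validation loss stochastic from run to run; seeding the draw, or evaluating both branches, would make validation numbers comparable.
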
a/aria/train.py +++ b/aria/train.py @@ -3,6 +3,7 @@ import csv import argparse import logging +import random import torch import accelerate @@ -318,7 +319,10 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0): src, tgt, mask, emb = ( batch # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb) ) - if use_embeddings is True: + + use_embeddings_cond = use_embeddings and (random.random() > 0.5) + + if use_embeddings_cond is True: logits = model(src=src, emb=emb) # (b_sz, s_len - 1, v_sz) tgt = tgt[:, :-1] # (b_sz, s_len - 1) mask = mask[:, :-1] # (b_sz, s_len - 1) @@ -391,7 +395,9 @@ def val_loop(dataloader, _epoch: int): src, tgt, mask, emb = ( batch # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb) ) - if use_embeddings is True: + use_embeddings_cond = use_embeddings and (random.random() > 0.5) + + if use_embeddings_cond is True: logits = model(src=src, emb=emb) # (b_sz, s_len - 1, v_sz) tgt = tgt[:, :-1] # (b_sz, s_len - 1) mask = mask[:, :-1] # (b_sz, s_len - 1) From 72160f728653c35982defa78b80789c7c0648460 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 15 Mar 2025 15:06:36 +0000 Subject: [PATCH 32/72] fix model ft loading --- aria/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/aria/train.py b/aria/train.py index 83b8d42..e3dec2e 100644 --- a/aria/train.py +++ b/aria/train.py @@ -701,12 +701,13 @@ def train( if checkpoint_path: try: model.load_state_dict(_load_weight(checkpoint_path)) - except Exception as e: - raise Exception( - f"Failed to load checkpoint: {e}\n" - "This could be due to a mismatch between the tokenizer used " - "to build the pre-training and fine-tuning datasets" + except RuntimeError as e: + print(e) + logger.info( + f"Failed to load checkpoint into {model_name}, attempting with strict=False" ) + model.load_state_dict(_load_weight(checkpoint_path), strict=False) + logger.info(f"Loaded finetune checkpoint located at: {checkpoint_path}") train_dataloader, val_dataloader = get_dataloaders( From 50b27d38534cb139a3710bd6b75cd3b1bc75fa5f Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 15 Mar 2025 15:16:03 +0000 Subject: [PATCH 33/72] fix arg --- aria/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aria/train.py b/aria/train.py index e3dec2e..222eab1 100644 --- a/aria/train.py +++ b/aria/train.py @@ -716,7 +716,7 @@ def train( tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers, - apply_aug=True, + apply_aug=False, use_embeddings=use_embeddings, ) @@ -755,6 +755,7 @@ model=model, train_dataloader=train_dataloader, val_dataloader=val_dataloader, + use_embeddings=use_embeddings, optimizer=optimizer, scheduler=scheduler, steps_per_checkpoint=steps_per_checkpoint, From f9e15dedff46c8c166ab5325ab9ceb79b9a01831 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 15 Mar 2025 15:24:28 +0000 Subject: [PATCH 34/72] fix ddp model error --- aria/model.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/aria/model.py b/aria/model.py index 75981ed..5ed277f 100644 --- a/aria/model.py +++ b/aria/model.py @@ -310,8 +310,18 @@ def forward( return logits[:, 1:, :] else: - hidden = self.model(src, emb) + # Needed to avoid the torch DDP unused-parameter error + dummy_input = torch.zeros( + src.size(0), + self.embedding_adapter.in_features, + device=src.device, + ) + dummy_output = self.embedding_adapter(dummy_input) + dummy_loss = dummy_output.sum() * 0.0 + + hidden = self.model(src, None) logits = self.lm_head(hidden) + logits = logits + dummy_loss return
logits From 43689b4c5923b0689349deede709f92c8e7d1321 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 18 Mar 2025 14:56:04 +0000 Subject: [PATCH 35/72] add pca --- aria/embeddings/pca.py | 120 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 aria/embeddings/pca.py diff --git a/aria/embeddings/pca.py b/aria/embeddings/pca.py new file mode 100644 index 0000000..2d1c309 --- /dev/null +++ b/aria/embeddings/pca.py @@ -0,0 +1,120 @@ +import matplotlib.pyplot as plt +import json +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.manifold import TSNE + + +# Flag to choose between t-SNE and PCA +use_tsne = True + +# Load data from the JSON file +with open("aria_embeddings.json", "r") as f: + data = json.load(f) + +# Define the set of top composers +top_composers = { + "chopin", + "bach", + "handel", + "haydn", + "tchaikovsky", + "scriabin", + "beethoven", + "liszt", + "mozart", + "debussy", + "schumann", + "schubert", + "satie", + "rachmaninoff", + "brahms", + "ravel", +} + + +# Filter the data to include only entries for the top composers +filtered_data = [entry for entry in data if entry["composer"] in top_composers] + +# Extract embeddings and composers from the filtered data +embeddings = np.array([entry["emb"] for entry in filtered_data]) +embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) +composers = [entry["composer"].capitalize() for entry in filtered_data] + +# Perform dimensionality reduction based on the flag +if use_tsne: + reducer = TSNE( + n_components=2, perplexity=50, max_iter=2500, random_state=43 + ) + title = "t-SNE Visualization of Composer Embeddings" + filename = "/home/loubb/work/aria/tsne_plot.png" +else: + reducer = PCA(n_components=2) + filename = "/home/loubb/work/aria/pca_plot.png" + +embeddings_2d = reducer.fit_transform(embeddings) + +# Create a DataFrame for plotting +df = pd.DataFrame( + { + "Dimension 1": embeddings_2d[:, 0], + "Dimension 2": embeddings_2d[:, 1], + "Composer": composers, + } +) + +# Set the aesthetic style of the plots +sns.set_theme(style="whitegrid", font="Helvetica") + +# Create the scatter plot +plt.figure(figsize=(12, 8)) +scatter_plot = sns.scatterplot( + data=df, + x="Dimension 1", + y="Dimension 2", + hue="Composer", + palette="tab20", + s=50, # Marker size + edgecolor="w", + linewidth=0.5, +) + +plt.xlabel(None) # Remove x-axis label +plt.ylabel(None) # Remove y-axis label + +plt.xticks([]) # Remove numerical x-axis ticks +plt.yticks([]) # Remove numerical y-axis ticks + +plt.grid(True, linestyle="--", linewidth=0.5) # Keep the grid visible + +# Ensure grid is properly aligned +plt.gca().set_aspect("auto") # Prevent distortion +plt.gca().set_frame_on(True) # Keep figure frame +# plt.gca().set_xticks( +# np.linspace(df["Dimension 1"].min(), df["Dimension 1"].max(), num=6) +# ) +# plt.gca().set_yticks( +# np.linspace(df["Dimension 2"].min(), df["Dimension 2"].max(), num=6) +# ) + +# Move the legend outside the plot +plt.legend( + bbox_to_anchor=(0, -0.38, 1, 0), + loc="lower center", + ncol=4, # Arrange in multiple columns + fontsize=20, # Increase font size + columnspacing=1.05, # Increase the space between columns + title_fontsize=20, + title="Composer", +) + +# Set plot title and labels +# plt.title(title) + +# Save the plot as a high-resolution PNG file +plt.savefig(filename, dpi=300, bbox_inches="tight") + +# Display the plot +plt.show() From 46b9daf2771a6c3adba77d02e45ed64c55a6684e Mon Sep 17 00:00:00 
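A gloss on the PATCH 34 model.py hunk above: DDP's gradient reducer expects every registered parameter to contribute to the loss on each step, so a batch that skips embedding_adapter on the unconditional path would otherwise fail with the "expected to have finished reduction in the prior iteration" error. Feeding the adapter zeros and adding the result scaled by 0.0 gives its parameters a (zero) gradient without changing the logits. A self-contained toy showing the same trick, not the actual aria model:

    import torch
    import torch.nn as nn

    class ToyCondLM(nn.Module):
        def __init__(self, d_emb: int = 8, d_model: int = 16, vocab: int = 32):
            super().__init__()
            self.embedding_adapter = nn.Linear(d_emb, d_model)
            self.tok_emb = nn.Embedding(vocab, d_model)
            self.lm_head = nn.Linear(d_model, vocab)

        def forward(self, src: torch.Tensor, emb: torch.Tensor | None = None):
            if emb is not None:
                # Prepend the projected condition as an extra position, then
                # drop its slot, the analogue of logits[:, 1:, :] above.
                cond = self.embedding_adapter(emb).unsqueeze(1)
                hidden = torch.cat([cond, self.tok_emb(src)], dim=1)
                return self.lm_head(hidden)[:, 1:, :]
            # Unconditional path: touch the adapter with a zero-weighted term
            # so the DDP reducer still sees a gradient for its parameters.
            dummy_input = torch.zeros(
                src.size(0),
                self.embedding_adapter.in_features,
                device=src.device,
            )
            dummy = self.embedding_adapter(dummy_input).sum() * 0.0
            return self.lm_head(self.tok_emb(src)) + dummy

The alternative is constructing DistributedDataParallel with find_unused_parameters=True, which avoids the dummy pass at the cost of an extra graph traversal every iteration.
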
2001 From: Louis Date: Thu, 20 Mar 2025 16:35:51 +0000 Subject: [PATCH 36/72] keshav --- aria/embeddings/evaluate.py | 7 +- aria/run.py | 179 ++++++++++++++++++++---------------- 2 files changed, 104 insertions(+), 82 deletions(-) diff --git a/aria/embeddings/evaluate.py b/aria/embeddings/evaluate.py index df2042d..bd1ab2b 100644 --- a/aria/embeddings/evaluate.py +++ b/aria/embeddings/evaluate.py @@ -87,12 +87,15 @@ def get_chunks(note_msgs: list, chunk_len: int): def process_entry( - entry, + entry: MidiDict | dict, slice_len_notes: int, max_seq_len: int, tokenizer: AbsTokenizer, ): - midi_dict = MidiDict.from_msg_dict(entry) + if isinstance(entry, dict): + midi_dict = MidiDict.from_msg_dict(entry) + else: + midi_dict = entry outputs = [] for slice_note_msgs in get_chunks( diff --git a/aria/run.py b/aria/run.py index 861193b..c9f5c4e 100644 --- a/aria/run.py +++ b/aria/run.py @@ -8,9 +8,18 @@ def _parse_sample_args(): argp = argparse.ArgumentParser(prog="aria sample") - argp.add_argument("-m", help="name of model config file") - argp.add_argument("-c", help="path to model checkpoint") - argp.add_argument("-p", help="path to midi file") + argp.add_argument( + "-checkpoint_path", help="path to model used for decoding" + ) + argp.add_argument("-prompt_midi_path", help="path to midi file") + argp.add_argument( + "-embedding_checkpoint_path", + required=False, + help="path to model checkpoint used for embeddings", + ) + argp.add_argument( + "-embedding_midi_path", required=False, help="path to midi file" + ) argp.add_argument( "-temp", help="sampling temperature value", @@ -31,13 +40,6 @@ def _parse_sample_args(): type=float, required=False, ) - argp.add_argument( - "-metadata", - nargs=2, - metavar=("KEY", "VALUE"), - action="append", - help="manually add metadata key-value pair when sampling", - ) argp.add_argument( "-var", help="number of variations", @@ -52,67 +54,104 @@ def _parse_sample_args(): ) argp.add_argument("-e", action="store_true", help="enable force end") argp.add_argument("-l", type=int, help="generation length", default=1024) - argp.add_argument( - "-guidance_path", type=str, help="path to guidance MIDI", required=False + argp.add_argument("-compile", action="store_true", help="compile cudagraph") + + return argp.parse_args(sys.argv[2:]) + + +def _get_embedding( + embedding_checkpoint_path: str, + midi_path: str, +): + import torch + + from aria.model import TransformerEMB + from aria.model import ModelConfig + from aria.config import load_model_config + from aria.utils import _load_weight + from aria.embeddings.evaluate import ( + get_aria_contrastive_embedding, + process_entry, ) - argp.add_argument( - "-guidance_start_ms", - help="guidance interval start (ms)", - type=int, - required=False, + + from ariautils.midi import MidiDict + from ariautils.tokenizer import AbsTokenizer + + SLICE_NUM_NOTES = 300 + SLICE_MAX_SEQ_LEN = 1024 + + tokenizer = AbsTokenizer() + + model_state = _load_weight(embedding_checkpoint_path, "cuda") + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + + model_config = ModelConfig(**load_model_config("medium-emb")) + model_config.set_vocab_size(tokenizer.vocab_size) + model_config.grad_checkpoint = False + model = TransformerEMB(model_config).cuda().eval() + model.load_state_dict(model_state) + + seqs = process_entry( + entry=MidiDict.from_midi(midi_path), + slice_len_notes=SLICE_NUM_NOTES, + max_seq_len=SLICE_MAX_SEQ_LEN, + tokenizer=tokenizer, ) - argp.add_argument( - "-guidance_end_ms", - help="guidance interval 
end (ms)", - type=int, - required=False, + + def model_forward(model, idxs): + return model(idxs) + + embeddings = get_aria_contrastive_embedding( + seqs=[s["seq"] for s in seqs], + hook_model=model, + hook_max_seq_len=SLICE_MAX_SEQ_LEN, + hook_tokenizer=tokenizer, + hook_model_forward=model_forward, ) - argp.add_argument("-compile", action="store_true", help="compile cudagraph") + embedding = torch.tensor(embeddings, device="cuda").mean(0).tolist() - return argp.parse_args(sys.argv[2:]) + return embedding -# TODO: This is all broken due to tokenizer / embedding change -- need to fix def sample(args): """Entrypoint for sampling""" from torch.cuda import is_available as cuda_is_available from aria.inference import TransformerLM from aria.model import ModelConfig - from aria.config import load_model_config, load_config - from aria.tokenizer import InferenceAbsTokenizer - from aria.sample import ( - sample_batch_cfg, - sample_batch, - get_inference_prompt, - ) - from ariautils.midi import MidiDict + from aria.config import load_model_config + from aria.sample import sample_batch, sample_batch_cfg, get_inference_prompt from aria.utils import _load_weight + from ariautils.midi import MidiDict + from ariautils.tokenizer import AbsTokenizer + if not cuda_is_available(): raise Exception("CUDA device is not available.") - model_state = _load_weight(args.c, "cuda") - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() - } - - manual_metadata = {k: v for k, v in args.metadata} if args.metadata else {} - valid_metadata = load_config()["data"]["metadata"]["manual"] - for k, v in manual_metadata.copy().items(): - assert k in valid_metadata.keys(), f"{manual_metadata} is invalid" - if v not in valid_metadata[k]: - print(f"Ignoring invalid manual metadata: {k}") - print(f"Please choose from {valid_metadata[k]}") - del manual_metadata[k] - num_variations = args.var truncate_len = args.trunc force_end = args.e - model_name = args.m - tokenizer = InferenceAbsTokenizer() - model_config = ModelConfig(**load_model_config(model_name)) + tokenizer = AbsTokenizer() + + if args.embedding_checkpoint_path and args.embedding_midi_path: + print(f"Using embedding from {args.embedding_midi_path}") + embedding = _get_embedding( + embedding_checkpoint_path=args.embedding_checkpoint_path, + midi_path=args.embedding_midi_path, + ) + else: + embedding = None + + model_state = _load_weight(args.checkpoint_path, "cuda") + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + + model_config = ModelConfig(**load_model_config("medium-emb")) model_config.set_vocab_size(tokenizer.vocab_size) model_config.grad_checkpoint = False model = TransformerLM(model_config).cuda() @@ -120,39 +159,21 @@ def sample(args): try: model.load_state_dict(model_state) except Exception as e: - print( - "Failed to load model_state. This is likely due to an incompatibility " - "between the checkpoint file (-c) and model name/config (-m)." - ) - raise e + print("Failed to load model_state - loading with strict=False") + model.load_state_dict(model_state, strict=False) assert args.l > 0, "Generation length must be positive." 
max_new_tokens = args.l # Load and format prompts and metadata - midi_dict = MidiDict.from_midi(mid_path=args.p) - if args.guidance_path: - guidance_midi_dict = MidiDict.from_midi(mid_path=args.guidance_path) - else: - guidance_midi_dict = None - - for k, v in manual_metadata.items(): - midi_dict.metadata[k] = v - - print(f"Extracted metadata: {midi_dict.metadata}") - print( - f"Instruments: {set([MidiDict.get_program_to_instrument()[msg['data']] for msg in midi_dict.instrument_msgs])}" - ) + midi_dict = MidiDict.from_midi(mid_path=args.prompt_midi_path) - prompt_seq, guidance_seq = get_inference_prompt( + prompt_seq = get_inference_prompt( tokenizer=tokenizer, midi_dict=midi_dict, - truncate_len=truncate_len, - guidance_start_ms=args.guidance_start_ms, - guidance_end_ms=args.guidance_end_ms, - guidance_midi_dict=guidance_midi_dict, + prompt_len_ms=truncate_len * 1e3, ) - # prompt_seq = prompt_seq[:-1] + print(prompt_seq) if len(prompt_seq) + args.l > model_config.max_seq_len: @@ -164,22 +185,19 @@ def sample(args): samples_dir = "/home/loubb/Dropbox/shared" if os.path.isdir(samples_dir) is False: os.mkdir(samples_dir) - if guidance_seq: - tokenizer.detokenize(guidance_seq).to_midi().save( - os.path.join(samples_dir, f"guidance.mid") - ) - if args.cfg is not None and guidance_seq is not None: + if args.cfg and embedding is not None: results = sample_batch_cfg( model=model, tokenizer=tokenizer, prompts=prompts, max_new_tokens=max_new_tokens, - cfg_gamma=args.cfg, force_end=force_end, temperature=args.temp, top_p=args.top_p, + cfg_gamma=args.cfg, compile=args.compile, + embedding=embedding, ) else: results = sample_batch( @@ -191,6 +209,7 @@ def sample(args): temperature=args.temp, top_p=args.top_p, compile=args.compile, + embedding=embedding, ) for idx, tokenized_seq in enumerate(results): From 9aeafd2de8e3fab3938b8ec98c5b3b7618bf703d Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 20 Mar 2025 17:18:36 +0000 Subject: [PATCH 37/72] keshav add args --- aria/run.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/aria/run.py b/aria/run.py index c9f5c4e..95dc378 100644 --- a/aria/run.py +++ b/aria/run.py @@ -62,6 +62,8 @@ def _parse_sample_args(): def _get_embedding( embedding_checkpoint_path: str, midi_path: str, + start_ms: int | None = None, + end_ms: int | None = None, ): import torch @@ -93,6 +95,22 @@ def _get_embedding( model = TransformerEMB(model_config).cuda().eval() model.load_state_dict(model_state) + midi_dict = MidiDict.from_midi(midi_path) + midi_dict.note_msgs = [ + msg + for msg in midi_dict.note_msgs + if ( + midi_dict.tick_to_ms(msg["tick"]) >= start_ms + if start_ms is not None + else True + ) + and ( + midi_dict.tick_to_ms(msg["tick"]) <= end_ms + if end_ms is not None + else True + ) + ] + seqs = process_entry( entry=MidiDict.from_midi(midi_path), slice_len_notes=SLICE_NUM_NOTES, From 9ef4db9712ee860949e689203aa0256be80896cb Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 21 Mar 2025 13:28:16 +0000 Subject: [PATCH 38/72] fix keshav --- aria/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aria/run.py b/aria/run.py index 95dc378..b5b1207 100644 --- a/aria/run.py +++ b/aria/run.py @@ -112,7 +112,7 @@ def _get_embedding( ] seqs = process_entry( - entry=MidiDict.from_midi(midi_path), + entry=midi_dict, slice_len_notes=SLICE_NUM_NOTES, max_seq_len=SLICE_MAX_SEQ_LEN, tokenizer=tokenizer, From 67d86bbb8efc7e5ac05be4ca39d3776c367a2dcf Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 22 May 2025 16:33:51 +0000 Subject: [PATCH 39/72] 
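The two small patches above deserve a gloss: PATCH 37 windows the embedding MIDI by onset time before slicing, and PATCH 38 fixes the bug where process_entry was still handed a freshly loaded MidiDict, silently discarding that window. Restated as a helper (a sketch; tick_to_ms and the note_msgs layout are as used in the diff):

    def window_note_msgs(midi_dict, start_ms=None, end_ms=None):
        # Keep only notes whose onset falls inside [start_ms, end_ms];
        # None leaves that side of the window open.
        def in_window(msg) -> bool:
            onset_ms = midi_dict.tick_to_ms(msg["tick"])
            return (start_ms is None or onset_ms >= start_ms) and (
                end_ms is None or onset_ms <= end_ms
            )

        midi_dict.note_msgs = [m for m in midi_dict.note_msgs if in_window(m)]
        return midi_dict
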
update sampling and demo --- aria/run.py | 74 +++++++++------ aria/sample.py | 251 +++++++++++++++++++++++++++---------------------- demo/demo.py | 208 ++++++++++++++++++++-------------------- 3 files changed, 288 insertions(+), 245 deletions(-) diff --git a/aria/run.py b/aria/run.py index b5b1207..079b1e8 100644 --- a/aria/run.py +++ b/aria/run.py @@ -18,7 +18,10 @@ def _parse_sample_args(): help="path to model checkpoint used for embeddings", ) argp.add_argument( - "-embedding_midi_path", required=False, help="path to midi file" + "-embedding_midi_paths", + nargs="+", + required=False, + help="path(s) to midi file(s) used for embeddings", ) argp.add_argument( "-temp", @@ -32,7 +35,12 @@ def _parse_sample_args(): help="sampling top_p value", type=float, required=False, - default=0.95, + ) + argp.add_argument( + "-min_p", + help="sampling min_p value", + type=float, + required=False, ) argp.add_argument( "-cfg", @@ -61,7 +69,7 @@ def _parse_sample_args(): def _get_embedding( embedding_checkpoint_path: str, - midi_path: str, + midi_paths: list[str], start_ms: int | None = None, end_ms: int | None = None, ): @@ -95,28 +103,32 @@ def _get_embedding( model = TransformerEMB(model_config).cuda().eval() model.load_state_dict(model_state) - midi_dict = MidiDict.from_midi(midi_path) - midi_dict.note_msgs = [ - msg - for msg in midi_dict.note_msgs - if ( - midi_dict.tick_to_ms(msg["tick"]) >= start_ms - if start_ms is not None - else True - ) - and ( - midi_dict.tick_to_ms(msg["tick"]) <= end_ms - if end_ms is not None - else True + seqs = [] + for midi_path in midi_paths: + midi_dict = MidiDict.from_midi(midi_path) + midi_dict.note_msgs = [ + msg + for msg in midi_dict.note_msgs + if ( + midi_dict.tick_to_ms(msg["tick"]) >= start_ms + if start_ms is not None + else True + ) + and ( + midi_dict.tick_to_ms(msg["tick"]) <= end_ms + if end_ms is not None + else True + ) + ] + + seqs.extend( + process_entry( + entry=midi_dict, + slice_len_notes=SLICE_NUM_NOTES, + max_seq_len=SLICE_MAX_SEQ_LEN, + tokenizer=tokenizer, + ) ) - ] - - seqs = process_entry( - entry=midi_dict, - slice_len_notes=SLICE_NUM_NOTES, - max_seq_len=SLICE_MAX_SEQ_LEN, - tokenizer=tokenizer, - ) def model_forward(model, idxs): return model(idxs) @@ -155,11 +167,13 @@ def sample(args): tokenizer = AbsTokenizer() - if args.embedding_checkpoint_path and args.embedding_midi_path: - print(f"Using embedding from {args.embedding_midi_path}") + if args.embedding_checkpoint_path and args.embedding_midi_paths: + print(f"Using embedding from {args.embedding_midi_paths}") embedding = _get_embedding( embedding_checkpoint_path=args.embedding_checkpoint_path, - midi_path=args.embedding_midi_path, + midi_paths=args.embedding_midi_paths, + start_ms=args.trunc * 1e3, + end_ms=None, ) else: embedding = None @@ -211,8 +225,9 @@ def sample(args): prompts=prompts, max_new_tokens=max_new_tokens, force_end=force_end, - temperature=args.temp, + temp=args.temp, top_p=args.top_p, + min_p=args.min_p, cfg_gamma=args.cfg, compile=args.compile, embedding=embedding, @@ -224,8 +239,9 @@ def sample(args): prompts=prompts, max_new_tokens=max_new_tokens, force_end=force_end, - temperature=args.temp, + temp=args.temp, top_p=args.top_p, + min_p=args.min_p, compile=args.compile, embedding=embedding, ) diff --git a/aria/sample.py b/aria/sample.py index 01f9abe..653505a 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -1,6 +1,5 @@ """Contains generation/sampling code""" -import copy import torch import torch._dynamo.config import torch._inductor.config @@ -9,7 +8,6 @@ from 
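Three details in the sample.py rewrite that follows are worth spelling out. First, when a condition embedding occupies slot 0 of the KV cache, every input_pos is shifted by emb_offset = 1; for CFG each prompt is duplicated, the even batch rows keep the condition, and pad_idxs[1::2, 0] masks it out of the odd (unconditional) rows. Second, the guided logits are blended as logits_cfg = gamma * logits_cond + (1 - gamma) * logits_uncond, with gamma ramped linearly over the first 250 decoded tokens. Third, min-p sampling (https://arxiv.org/pdf/2407.01082) keeps only tokens whose probability is at least p_base times the largest probability: with p_base = 0.03 and a top probability of 0.4, anything below 0.012 is discarded before the multinomial draw. A hedged sketch of the last two, standalone tensor code rather than the aria functions:

    import torch

    def cfg_blend(l_cond, l_uncond, gamma: float, step: int, warmup: int = 250):
        # Linear warm-up keeps the first few tokens from being over-guided.
        g = min(gamma, gamma * step / warmup)
        return g * l_cond + (1 - g) * l_uncond

    def min_p_filter(probs: torch.Tensor, p_base: float = 0.03) -> torch.Tensor:
        # Keep tokens with prob >= p_base * max(prob), then renormalize.
        p_max = probs.max(dim=-1, keepdim=True).values
        masked = torch.where(probs >= p_base * p_max, probs, 0.0)
        return masked / masked.sum(dim=-1, keepdim=True)
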
tqdm import tqdm from aria.inference import TransformerLM -from aria.tokenizer import InferenceAbsTokenizer from ariautils.tokenizer import Tokenizer, AbsTokenizer from ariautils.midi import MidiDict @@ -18,15 +16,11 @@ torch._inductor.config.fx_graph_cache = True -def get_cfg_prompt(prompts: list, pad_tok: str, guidance_end_tok: str): +def get_cfg_prompt(prompts: list): cfg_prompts = [] for prompt in prompts: - prompt_no_guidance = prompt[prompt.index(guidance_end_tok) + 1 :] - prompt_no_guidance = [pad_tok] * ( - len(prompt) - len(prompt_no_guidance) - ) + prompt_no_guidance cfg_prompts.append(prompt) - cfg_prompts.append(prompt_no_guidance) + cfg_prompts.append(prompt) return cfg_prompts @@ -38,6 +32,8 @@ def decode_one( input_pos: torch.Tensor, pad_idxs: torch.Tensor | None = None, ): + assert input_pos.shape[-1] == 1 + logits = model.forward( idxs=idxs, input_pos=input_pos, @@ -93,7 +89,6 @@ def update_seq_ids_( seq[:, idx] = next_token_ids -# TODO: Not working @torch.autocast( "cuda", dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, @@ -105,21 +100,30 @@ def sample_batch( prompts: List[list], max_new_tokens: int, force_end=False, - temperature: float = 0.95, - top_p: float = 0.95, + temp: float = 0.95, + embedding: list[float] | None = None, + top_p: float | None = None, + min_p: float | None = None, compile: bool = False, ): + assert top_p is not None or min_p is not None + if top_p is not None: + assert 0.5 <= top_p <= 1.0 + if min_p is not None: + assert 0.0 <= min_p <= 1.0 + if temp is not None: + assert 0.0 <= temp <= 2.0 if force_end: assert max_new_tokens > 130, "prompt too long to use force_end=True" - _prompt_len = len(prompts[0]) - _num_prompts = len(prompts) - assert all([len(p) == _prompt_len for p in prompts]) + prompt_len = len(prompts[0]) + num_prompts = len(prompts) + assert all([len(p) == prompt_len for p in prompts]) model.eval() - dim_tok_inserted = [False for _ in range(_num_prompts)] - eos_tok_seen = [False for _ in range(_num_prompts)] - total_len = _prompt_len + max_new_tokens + dim_tok_inserted = [False for _ in range(num_prompts)] + eos_tok_seen = [False for _ in range(num_prompts)] + total_len = prompt_len + max_new_tokens seq = torch.stack( [ torch.tensor( @@ -138,48 +142,59 @@ def sample_batch( ) model.setup_cache( - batch_size=_num_prompts, + batch_size=num_prompts, max_seq_len=total_len, dtype=( torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 ), ) + if embedding: + condition_embedding = torch.tensor( + [embedding for _ in range(num_prompts)], device=seq.device + ) + model.fill_condition_kv(cond_emb=condition_embedding) + emb_offset = 1 + else: + emb_offset = 0 + print( - f"Using hyperparams: temp={temperature}, top_p={top_p}, gen_len={max_new_tokens}" + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gen_len={max_new_tokens}" ) for idx in ( pbar := tqdm( - range(_prompt_len, total_len), - total=total_len - _prompt_len, + range(prompt_len, total_len), + total=total_len - prompt_len, leave=False, ) ): with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - if idx == _prompt_len: + if idx == prompt_len: logits = prefill( model, idxs=seq[:, :idx], - input_pos=torch.arange(0, idx, device=seq.device), + input_pos=torch.arange( + emb_offset, idx + emb_offset, device=seq.device + ), ) else: logits = decode_one( model, idxs=seq[:, idx - 1 : idx], input_pos=torch.tensor( - [idx - 1], device=seq.device, dtype=torch.int + [(idx + emb_offset) - 1], + device=seq.device, + dtype=torch.int, 
), ) - if tokenizer.name == "inference_abs": - logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( - "-inf" - ) - - if temperature > 0.0: - probs = torch.softmax(logits / temperature, dim=-1) - next_token_ids = sample_top_p(probs, top_p).flatten() + if temp > 0.0: + probs = torch.softmax(logits / temp, dim=-1) + if min_p is not None: + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = sample_top_p(probs, top_p).flatten() else: next_token_ids = torch.argmax(logits, dim=-1).flatten() @@ -210,6 +225,7 @@ def sample_batch( return decoded_results +# Not tested but I think this works @torch.autocast( "cuda", dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, @@ -217,42 +233,46 @@ def sample_batch( @torch.inference_mode() def sample_batch_cfg( model: TransformerLM, - tokenizer: InferenceAbsTokenizer, + tokenizer: AbsTokenizer, prompts: List[list], max_new_tokens: int, cfg_gamma: float, + embedding: list[float], force_end=False, - temperature: float = 0.95, - top_p: float = 0.95, + temp: float = 0.95, + top_p: float | None = None, + min_p: float | None = None, compile: bool = False, ): - assert 0.0 <= cfg_gamma <= 2.0 - assert 0.0 <= temperature <= 2.0 - assert 0.5 <= top_p <= 1.0 - assert tokenizer.name == "inference_abs" + assert 0.0 <= cfg_gamma <= 15.0 + assert top_p is not None or min_p is not None + if top_p is not None: + assert 0.5 <= top_p <= 1.0 + if temp is not None: + assert 0.0 <= temp <= 2.0 if force_end: assert max_new_tokens > 130, "prompt too long to use force_end=True" - prompts = get_cfg_prompt( - prompts, tokenizer.pad_tok, tokenizer.guidance_end_tok - ) + prompts = get_cfg_prompt(prompts) - _prompt_len = len(prompts[0]) - _num_prompts = len(prompts) - assert all([len(p) == _prompt_len for p in prompts]) + prompt_len = len(prompts[0]) + num_prompts = len(prompts) + assert all([len(p) == prompt_len for p in prompts]) model.eval() - total_len = _prompt_len + max_new_tokens + total_context_len = prompt_len + max_new_tokens seq = torch.stack( [ torch.tensor( - tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) + tokenizer.encode( + p + [tokenizer.pad_tok] * (total_context_len - len(p)) + ) ) for p in prompts ] ).cuda() - dim_tok_inserted = [False for _ in range(_num_prompts)] - eos_tok_seen = [False for _ in range(_num_prompts)] + dim_tok_inserted = [False for _ in range(num_prompts)] + eos_tok_seen = [False for _ in range(num_prompts)] if compile is True: global decode_one @@ -263,51 +283,70 @@ def sample_batch_cfg( ) model.setup_cache( - batch_size=_num_prompts, - max_seq_len=total_len, + batch_size=num_prompts, + max_seq_len=total_context_len, dtype=( torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 ), ) + condition_embedding = torch.tensor( + [embedding for _ in range(num_prompts)], device=seq.device + ) + model.fill_condition_kv(cond_emb=condition_embedding) + embedding_offset = 1 + pad_idxs = torch.zeros_like(seq, dtype=torch.bool) + pad_idxs[1::2, 0] = True + print( - f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" ) + CFG_WARM_UP_STEPS = 250 + curr_step = 0 for idx in ( pbar := tqdm( - range(_prompt_len, total_len), - total=total_len - _prompt_len, + range(prompt_len, total_context_len), + total=total_context_len - prompt_len, leave=False, ) ): with 
torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - if idx == _prompt_len: + if idx == prompt_len: logits = prefill( model, idxs=seq[:, :idx], - input_pos=torch.arange(0, idx, device=seq.device), - pad_idxs=(seq == tokenizer.pad_id), + input_pos=torch.arange( + embedding_offset, + idx + embedding_offset, + device=seq.device, + ), + pad_idxs=pad_idxs, ) else: logits = decode_one( model, idxs=seq[:, idx - 1 : idx], input_pos=torch.tensor( - [idx - 1], device=seq.device, dtype=torch.int + [(idx + embedding_offset) - 1], + device=seq.device, + dtype=torch.int, ), - pad_idxs=(seq == tokenizer.pad_id), + pad_idxs=pad_idxs, ) - logits_cfg = cfg_gamma * logits[::2] + (1 - cfg_gamma) * logits[1::2] - logits_cfg[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( - "-inf" - ) + curr_step += 1 + _cfg_gamma = min(cfg_gamma, (curr_step / CFG_WARM_UP_STEPS) * cfg_gamma) + + logits_cfg = _cfg_gamma * logits[::2] + (1 - _cfg_gamma) * logits[1::2] logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") - if temperature > 0.0: - probs = torch.softmax(logits_cfg / temperature, dim=-1) - next_token_ids = sample_top_p(probs, top_p).flatten() + if temp > 0.0: + probs = torch.softmax(logits_cfg / temp, dim=-1) + if min_p is not None: + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = sample_top_p(probs, top_p).flatten() else: next_token_ids = torch.argmax(logits_cfg, dim=-1).flatten() @@ -318,7 +357,7 @@ def sample_batch_cfg( next_token_ids=next_token_ids, dim_tok_inserted=dim_tok_inserted, eos_tok_seen=eos_tok_seen, - max_len=total_len, + max_len=total_context_len, force_end=force_end, tokenizer=tokenizer, ) @@ -339,67 +378,49 @@ def sample_batch_cfg( return decoded_results +def sample_min_p(probs, p_base): + """See - https://arxiv.org/pdf/2407.01082""" + p_max, _ = torch.max(probs, dim=-1, keepdim=True) + p_scaled = p_base * p_max + mask = probs >= p_scaled + + masked_probs = probs.clone() + masked_probs[~mask] = 0.0 + masked_probs.div_(masked_probs.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(masked_probs, num_samples=1) + + return next_token + + def sample_top_p(probs, p): probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) mask = probs_sum - probs_sort > p probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) next_token = torch.multinomial(probs_sort, num_samples=1) next_token = torch.gather(probs_idx, -1, next_token) + return next_token def get_inference_prompt( - tokenizer: InferenceAbsTokenizer, - midi_dict: MidiDict, - truncate_len: int, - guidance_start_ms: int, - guidance_end_ms: int, - guidance_midi_dict: MidiDict | None = None, + midi_dict: MidiDict, tokenizer: AbsTokenizer, prompt_len_ms: int ): - assert tokenizer.name == "inference_abs" - - if guidance_midi_dict is not None: - assert guidance_start_ms is not None and guidance_start_ms >= 0 - assert guidance_end_ms is not None and guidance_end_ms >= 0 - assert ( - tokenizer._config["guidance"]["min_ms"] - <= guidance_end_ms - guidance_start_ms - <= tokenizer._config["guidance"]["max_ms"] - ) - - prompt_seq = tokenizer.tokenize( - midi_dict=midi_dict, - prompt_intervals_ms=( - [[0, truncate_len * 1e3]] if truncate_len > 0 else [] - ), - guidance_midi_dict=guidance_midi_dict, - guidance_start_ms=guidance_start_ms, - guidance_end_ms=guidance_end_ms, - ) - - if tokenizer.prompt_end_tok in prompt_seq: - prompt_seq = prompt_seq[ - : prompt_seq.index(tokenizer.prompt_end_tok) + 1 - ] - 
else: - print("No notes found in prompt region") - prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.bos_tok) + 1] + midi_dict.note_msgs = [ + msg + for msg in midi_dict.note_msgs + if midi_dict.tick_to_ms(msg["data"]["start"]) <= prompt_len_ms + ] - if tokenizer.dim_tok in prompt_seq: - prompt_seq.remove(tokenizer.dim_tok) + if len(midi_dict.note_msgs) == 0: + return [("prefix", "instrument", "piano"), tokenizer.bos_tok] - if ( - guidance_midi_dict is not None - and tokenizer.guidance_start_tok in prompt_seq - ): - guidance_seq = copy.deepcopy(prompt_seq) - guidance_seq = guidance_seq[ - : guidance_seq.index(tokenizer.guidance_end_tok) - ] - guidance_seq[0] = ("prefix", "instrument", "piano") - else: - guidance_seq = None + seq = tokenizer.tokenize(midi_dict=midi_dict) + if tokenizer.dim_tok in seq: + seq.remove(tokenizer.dim_tok) + if tokenizer.eos_tok in seq: + seq.remove(tokenizer.eos_tok) - return prompt_seq, guidance_seq + return seq diff --git a/demo/demo.py b/demo/demo.py index 9f7447f..c3083a5 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -10,32 +10,39 @@ import queue import torch import mido -import torch._dynamo.config import torch._inductor.config from torch.cuda import is_available as cuda_is_available from contextlib import ExitStack from ariautils.midi import MidiDict, midi_to_dict -from aria.tokenizer import InferenceAbsTokenizer +from ariautils.tokenizer import AbsTokenizer from aria.utils import _load_weight from aria.inference import TransformerLM from aria.model import ModelConfig from aria.config import load_model_config -from aria.sample import prefill, decode_one, sample_top_p +from aria.sample import prefill, decode_one, sample_min_p torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.fx_graph_cache = True -# torch.set_float32_matmul_precision("high") DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 MAX_SEQ_LEN = 8192 PREFILL_COMPILE_SEQ_LEN = 1024 +# HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS +# HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg +# HARDWARE: All messages are sent HARDWARE_LATENCY_MS early +MIN_NOTE_DELTA_MS = 50 +MIN_NOTE_LEN_MS = 50 +HARDWARE_LATENCY_MS = 0 + # TODO: -# - Add CFG support +# - Add CFG support (eek) # - Add looping functionality +# - Add ending functionality + file_handler = logging.FileHandler("./demo.log", mode="w") file_handler.setLevel(logging.DEBUG) @@ -204,8 +211,8 @@ def load_model( init_start_time_s = time.time() - tokenizer = InferenceAbsTokenizer() - model_config = ModelConfig(**load_model_config("medium")) + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config("medium-emb")) model_config.set_vocab_size(tokenizer.vocab_size) model_config.grad_checkpoint = False model = TransformerLM(model_config).cuda() @@ -215,7 +222,11 @@ def load_model( model_state = { k.replace("_orig_mod.", ""): v for k, v in model_state.items() } - model.load_state_dict(model_state) + try: + model.load_state_dict(model_state) + except Exception: + logger.info("Failed to load model, attempting with strict=False...") + model.load_state_dict(model_state, strict=False) logger.info( f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds" @@ -230,7 +241,7 @@ def recalculate_dur_tokens( model: TransformerLM, priming_seq: list, enc_seq: torch.Tensor, - tokenizer: InferenceAbsTokenizer, + tokenizer: AbsTokenizer, start_idx: int, ): logger = 
get_logger("GENERATE") @@ -293,6 +304,7 @@ def recalculate_dur_tokens( enc_seq[:, idx] = next_token_ids priming_seq[idx] = predicted_tok + # TODO: There is a bug here if no-notes are active when the signal is pressed last_tok_id = enc_seq[0, idx] last_tok = tokenizer.id_to_tok[last_tok_id.item()] @@ -315,16 +327,16 @@ def decode_first_tokens( first_token_logits: torch.Tensor, enc_seq: torch.Tensor, priming_seq: list, - tokenizer: InferenceAbsTokenizer, + tokenizer: AbsTokenizer, generated_tokens_queue: queue.Queue, first_on_msg_epoch_ms: int, ): logger = get_logger("GENERATE") BEAM_WIDTH = 5 - BUFFER_MS = 100 + BUFFER_MS = 100 + HARDWARE_LATENCY_MS TIME_TOK_ID = tokenizer.tok_to_id[tokenizer.time_tok] - TIME_TOK_WEIGHTING = -3 + TIME_TOK_WEIGHTING = -6 logits = first_token_logits time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms @@ -356,7 +368,6 @@ def decode_first_tokens( logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") - logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float("-inf") log_probs = torch.log_softmax(logits, dim=-1) top_log_probs, top_ids = torch.topk(log_probs, k=BEAM_WIDTH, dim=-1) @@ -452,16 +463,16 @@ def decode_first_tokens( def decode_tokens( model: TransformerLM, enc_seq: torch.Tensor, - tokenizer: InferenceAbsTokenizer, + tokenizer: AbsTokenizer, control_sentinel: threading.Event, generated_tokens_queue: queue.Queue, idx: int, temperature: float, - top_p: float, + min_p: float, ): logger = get_logger("GENERATE") logger.info( - f"Using sampling parameters: temperature={temperature}, top_p={top_p}" + f"Using sampling parameters: temperature={temperature}, min_p={min_p}" ) while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: @@ -487,19 +498,18 @@ def decode_tokens( logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") - logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( - "-inf" - ) + for dur_ms in range(0, MIN_NOTE_LEN_MS, 10): + logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") if temperature > 0.0: probs = torch.softmax(logits / temperature, dim=-1) - next_token_ids = sample_top_p(probs, top_p).flatten() + next_token_ids = sample_min_p(probs, min_p).flatten() else: next_token_ids = torch.argmax(logits, dim=-1).flatten() enc_seq[:, idx] = next_token_ids next_token = tokenizer.id_to_tok[next_token_ids[0].item()] - logger.info( + logger.debug( f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" ) @@ -513,22 +523,19 @@ def decode_tokens( generated_tokens_queue.put(None) -# TODO: Support CFG, guidance, and metadata tags -# TODO: Context length switching -# TODO: BUG: I'm still not 100% sure that the KV is being calculated correctly # TODO: BUG: Potentially a bug with dim_toks ect... 
being removed during kv-preprocessing @torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def generate_tokens( priming_seq: list, - tokenizer: InferenceAbsTokenizer, + tokenizer: AbsTokenizer, model: TransformerLM, control_sentinel: threading.Event, generated_tokens_queue: queue.Queue, num_preceding_active_pitches: int, first_on_msg_epoch_ms: int, - temperature: float = 0.95, - top_p: float = 0.95, + temperature: float = 0.97, + min_p: float = 0.03, ): logger = get_logger("GENERATE") @@ -612,14 +619,14 @@ def generate_tokens( generated_tokens_queue=generated_tokens_queue, idx=idx, temperature=temperature, - top_p=top_p, + min_p=min_p, ) def decode_tokens_to_midi( generated_tokens_queue: queue.Queue, midi_messages_queue: queue.Queue, - tokenizer: InferenceAbsTokenizer, + tokenizer: AbsTokenizer, first_on_msg_epoch_ms: int, priming_seq_last_onset_ms: int, ): @@ -657,7 +664,7 @@ def decode_tokens_to_midi( note_buffer.pop(0) assert len(note_buffer) == 3 - logger.info(f"Decoded note: {note_buffer}") + logger.debug(f"Decoded note: {note_buffer}") note_tok, onset_tok, dur_tok = note_buffer _, pitch, vel = note_tok _, onset = onset_tok @@ -683,7 +690,7 @@ def decode_tokens_to_midi( midi_messages_queue.put(off_msg) logger.debug(f"Put message: {on_msg}") logger.debug(f"Put message: {off_msg}") - logger.info(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") + logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") note_buffer = [] @@ -699,7 +706,12 @@ def stream_midi( logger.info( f"Sending generated messages on MIDI port: '{midi_output_port}'" ) + logger.info( + f"Applying hardware latency adjustment: {HARDWARE_LATENCY_MS}ms" + ) + last_pitch_uuid = {} + pitch_active = {} midi_messages = [] with mido.open_output(midi_output_port) as midi_out: @@ -725,10 +737,37 @@ def stream_midi( ) while midi_messages: - curr_epoch_time_ms = get_epoch_time_ms() + latency_adjusted_epoch_time_ms = ( + get_epoch_time_ms() + HARDWARE_LATENCY_MS + ) msg = midi_messages[0] - if 0 < curr_epoch_time_ms - msg["epoch_time_ms"] <= 50: + if ( + (msg["vel"] != 0) + and ( + msg["epoch_time_ms"] - latency_adjusted_epoch_time_ms + <= MIN_NOTE_DELTA_MS + ) + and pitch_active.get(msg["pitch"], False) + ): + midi_out.send( + mido.Message( + "note_on", + note=msg["pitch"], + velocity=0, + channel=0, + time=0, + ) + ) + pitch_active[msg["pitch"]] = False + logger.info(f"Sent early off for {msg}") + + if ( + 0 + < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] + <= 50 + ): + mido_msg = mido.Message( "note_on", note=msg["pitch"], @@ -755,18 +794,23 @@ def stream_midi( midi_out.send(mido_msg) msgs.append(mido_msg_with_time) - logger.info(mido_msg_with_time) - logger.info(f"Sent message: {msg}") + pitch_active[msg["pitch"]] = msg["vel"] != 0 + + logger.info( + f"[{get_epoch_time_ms() - msg["epoch_time_ms"]}ms] Sent message: {msg}" + ) else: logger.info( f"Skipping note_off message due to uuid mismatch: {msg}" ) midi_messages.pop(0) - elif curr_epoch_time_ms - msg["epoch_time_ms"] > 100: + elif ( + latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] > 100 + ): # Message occurs too far in the past logger.info( - f"Skipping message occurring too far ({curr_epoch_time_ms - msg["epoch_time_ms"]}ms) in the past: {msg}" + f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]}ms) in the past: {msg}" ) midi_messages.pop(0) else: @@ -803,30 +847,18 @@ def stream_midi( def stream_msgs( model: TransformerLM, - tokenizer: InferenceAbsTokenizer, + tokenizer: AbsTokenizer, msgs: 
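For orientation: decode_tokens_to_midi above consumes tokens in triplets, a note token ("piano", pitch, velocity), an onset token ("onset", onset_ms), and a duration token ("dur", dur_ms), and turns each triplet into an on/off pair stamped with absolute epoch times. The pair shares a uuid so that stream_midi can drop a stale note-off after the same pitch has retriggered, and note-offs travel as velocity-0 note-ons; stream_midi additionally sends an early off when a pitch is about to re-sound within MIN_NOTE_DELTA_MS. A stripped-down restatement of the conversion, with the time-token bookkeeping reduced to a single offset argument:

    import uuid

    def triplet_to_msgs(note_tok, onset_tok, dur_tok, offset_epoch_ms: int):
        _, pitch, vel = note_tok
        _, onset_ms = onset_tok
        _, dur_ms = dur_tok
        _uuid = uuid.uuid4()
        on_epoch_ms = offset_epoch_ms + onset_ms
        on_msg = {
            "pitch": pitch,
            "vel": vel,
            "epoch_time_ms": on_epoch_ms,
            "uuid": _uuid,
        }
        off_msg = {
            "pitch": pitch,
            "vel": 0,  # note-off encoded as a velocity-0 note-on
            "epoch_time_ms": on_epoch_ms + dur_ms,
            "uuid": _uuid,
        }
        return on_msg, off_msg
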
list[mido.Message], midi_output_port: str, first_on_msg_epoch_ms: int, control_sentinel: threading.Event, temperature: float, - top_p: float, + min_p: float, num_preceding_active_pitches: int, - guidance_midi_dict: MidiDict | None = None, - guidance_start_ms: int | None = None, - guidance_end_ms: int | None = None, ): midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) - priming_seq = tokenizer.tokenize( - midi_dict=midi_dict, - # prompt_intervals_ms=[ - # (0, (get_epoch_time_ms() - 3000) - first_on_msg_epoch_ms) - # ], - prompt_intervals_ms=[], - guidance_midi_dict=guidance_midi_dict, - guidance_start_ms=guidance_start_ms, - guidance_end_ms=guidance_end_ms, - ) + priming_seq = tokenizer.tokenize(midi_dict=midi_dict) priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] if tokenizer.dim_tok in priming_seq: @@ -844,7 +876,7 @@ def stream_msgs( "control_sentinel": control_sentinel, "generated_tokens_queue": generated_tokens_queue, "temperature": temperature, - "top_p": top_p, + "min_p": min_p, "num_preceding_active_pitches": num_preceding_active_pitches, "first_on_msg_epoch_ms": first_on_msg_epoch_ms, }, @@ -860,8 +892,7 @@ def stream_msgs( "tokenizer": tokenizer, "first_on_msg_epoch_ms": first_on_msg_epoch_ms, "priming_seq_last_onset_ms": tokenizer.calc_length_ms( - priming_seq[priming_seq.index(tokenizer.bos_tok) :], - onset=True, + priming_seq, onset=True ), }, daemon=True, @@ -872,10 +903,7 @@ def stream_msgs( midi_messages_queue=midi_messages_queue, msgs=msgs, prev_msg_epoch_time_ms=first_on_msg_epoch_ms - + tokenizer.calc_length_ms( - priming_seq[priming_seq.index(tokenizer.bos_tok) :], - onset=False, - ), + + tokenizer.calc_length_ms(priming_seq, onset=False), midi_output_port=midi_output_port, control_sentinel=control_sentinel, ) @@ -1028,8 +1056,22 @@ def play_midi_file(midi_port: str, midi_path: str): logger = get_logger("FILE") logger.info(f"Playing file at {midi_path} on MIDI port '{midi_port}'") time.sleep(1) + active_pitches = [] with mido.open_output(midi_port) as output_port: for msg in mido.MidiFile(midi_path).play(): + if msg.type == "note_on" and msg.velocity > 0: + if msg.note in active_pitches: + _off_msg = copy.deepcopy(msg) + _off_msg.velocity = 0 + output_port.send(_off_msg) + else: + active_pitches.append(msg.note) + elif msg.type == "note_off" or ( + msg.type == "note_on" and msg.velocity == 0 + ): + if msg.note in active_pitches: + active_pitches.remove(msg.note) + logger.debug(f"{msg}") output_port.send(msg) @@ -1071,11 +1113,11 @@ def parse_args(): default=0.95, ) argp.add_argument( - "-top_p", - help="sampling top_p value", + "-min_p", + help="sampling min_p value", type=float, required=False, - default=0.95, + default=0.03, ) argp.add_argument( "-cfg", @@ -1090,21 +1132,6 @@ def parse_args(): action="append", help="manually add metadata key-value pair when sampling", ) - argp.add_argument( - "-guidance_path", type=str, help="path to guidance MIDI", required=False - ) - argp.add_argument( - "-guidance_start_ms", - help="guidance interval start (ms)", - type=int, - required=False, - ) - argp.add_argument( - "-guidance_end_ms", - help="guidance interval end (ms)", - type=int, - required=False, - ) argp.add_argument( "-save_path", type=str, @@ -1118,29 +1145,10 @@ def parse_args(): def main(): args = parse_args() logger = get_logger() - tokenizer = InferenceAbsTokenizer() + tokenizer = AbsTokenizer() model = load_model(checkpoint_path=args.cp) model = compile_model(model=model, max_seq_len=MAX_SEQ_LEN) - if 
args.guidance_path: - assert ( - args.guidance_start_ms is not None and args.guidance_start_ms >= 0 - ) - assert args.guidance_end_ms is not None and args.guidance_end_ms >= 0 - assert ( - tokenizer._config["guidance"]["min_ms"] - <= args.guidance_end_ms - args.guidance_start_ms - <= tokenizer._config["guidance"]["max_ms"] - ) - guidance_midi_dict = MidiDict.from_midi(args.guidance_path) - - logger.info( - f"Using guidance from {args.guidance_path} in interval {[args.guidance_start_ms, args.guidance_end_ms]}" - ) - - else: - guidance_midi_dict = None - assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in if args.midi_path: midi_input_port = "Midi Through:Midi Through Port-0" @@ -1177,11 +1185,8 @@ def main(): first_on_msg_epoch_ms=first_on_msg_epoch_ms, control_sentinel=control_sentinel, temperature=args.temp, - top_p=args.top_p, + min_p=args.min_p, num_preceding_active_pitches=num_active_pitches, - guidance_midi_dict=guidance_midi_dict, - guidance_start_ms=args.guidance_start_ms, - guidance_end_ms=args.guidance_end_ms, ) keypress_thread.join() @@ -1191,5 +1196,6 @@ def main(): midi.save(args.save_path) +# TODO: Note this is all broken due to tokenizer changes -- fix if __name__ == "__main__": main() From 654d9dee8985b71424c169ab62b8bc74d22bcb9c Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 23 May 2025 18:53:27 +0000 Subject: [PATCH 40/72] add looping and ending to demo --- demo/demo.py | 215 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 173 insertions(+), 42 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index c3083a5..3774525 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -40,8 +40,6 @@ # TODO: # - Add CFG support (eek) -# - Add looping functionality -# - Add ending functionality file_handler = logging.FileHandler("./demo.log", mode="w") @@ -110,7 +108,6 @@ def _compile_prefill( compiled_prefill( model, enc_seq=torch.ones(1, 8192, device="cuda", dtype=torch.int) ) - print logger.info( f"Finished compiling - took {time.time() - start_compile_time_s:.4f} seconds" ) @@ -497,7 +494,7 @@ def decode_tokens( ) logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") - logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + # logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") for dur_ms in range(0, MIN_NOTE_LEN_MS, 10): logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") @@ -513,8 +510,13 @@ def decode_tokens( f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" ) - generated_tokens_queue.put(next_token) - idx += 1 + if next_token == tokenizer.eos_tok: + logger.info("EOS token produced, exiting...") + generated_tokens_queue.put(next_token) + return + else: + generated_tokens_queue.put(next_token) + idx += 1 while not control_sentinel.is_set(): time.sleep(0.1) @@ -523,7 +525,6 @@ def decode_tokens( generated_tokens_queue.put(None) -# TODO: BUG: Potentially a bug with dim_toks ect... 
being removed during kv-preprocessing @torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def generate_tokens( @@ -541,7 +542,7 @@ def generate_tokens( generate_start_s = time.time() priming_seq_len = len(priming_seq) - start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches) + start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches - 1) enc_seq = torch.tensor( [ tokenizer.encode( @@ -647,14 +648,26 @@ def decode_tokens_to_midi( while True: while True: tok = generated_tokens_queue.get() - if tok is None: - # This is triggered iff control sentinel is set a second time - logger.info("Seen exit signal") - midi_messages_queue.put(None) + if tok is tokenizer.eos_tok: + _uuid = uuid.uuid4() + end_msg = { + "pitch": -1, + "vel": -1, + "epoch_time_ms": offset_epoch_ms + 250, # Last note offset + "uuid": _uuid, + } # pitch=-1 denotes end_msg + midi_messages_queue.put(end_msg) + logger.info(f"Seen exit signal: EOS token") + logger.debug(f"Put message: {end_msg}") + return + + elif tok is None: + logger.info(f"Seen exit signal") return logger.debug(f"Seen token: {tok}") note_buffer.append(tok) + if isinstance(tok, tuple) and tok[0] == "dur": break @@ -701,6 +714,7 @@ def stream_midi( prev_msg_epoch_time_ms: float, midi_output_port: str, control_sentinel: threading.Event, + midi_stream_channel: int, ): logger = get_logger("STREAM") logger.info( @@ -743,7 +757,7 @@ def stream_midi( msg = midi_messages[0] if ( - (msg["vel"] != 0) + (msg["vel"] > 0) and ( msg["epoch_time_ms"] - latency_adjusted_epoch_time_ms <= MIN_NOTE_DELTA_MS @@ -767,6 +781,9 @@ def stream_midi( < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] <= 50 ): + if msg["pitch"] == -1: # End msg + control_sentinel.set() + break mido_msg = mido.Message( "note_on", @@ -787,6 +804,7 @@ def stream_midi( if should_send is True: mido_msg_with_time = copy.deepcopy(mido_msg) + mido_msg_with_time.channel = midi_stream_channel mido_msg_with_time.time = max( 0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms ) @@ -821,6 +839,8 @@ def stream_midi( logger.info("Processing remaining note_off messages") + logger.debug(midi_messages) + remaining_note_off_messages = [ msg for msg in midi_messages @@ -834,7 +854,7 @@ def stream_midi( "note_on", note=msg["pitch"], velocity=0, - channel=0, + channel=midi_stream_channel, time=msg["epoch_time_ms"] - prev_msg_epoch_time_ms, ) prev_msg_epoch_time_ms = msg["epoch_time_ms"] @@ -855,6 +875,8 @@ def stream_msgs( temperature: float, min_p: float, num_preceding_active_pitches: int, + midi_stream_channel: int, + is_ending: bool = False, ): midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) @@ -863,6 +885,8 @@ def stream_msgs( if tokenizer.dim_tok in priming_seq: priming_seq.remove(tokenizer.dim_tok) + if is_ending is True: + priming_seq.append(tokenizer.dim_tok) generated_tokens_queue = queue.Queue() midi_messages_queue = queue.Queue() @@ -906,6 +930,7 @@ def stream_msgs( + tokenizer.calc_length_ms(priming_seq, onset=False), midi_output_port=midi_output_port, control_sentinel=control_sentinel, + midi_stream_channel=midi_stream_channel, ) generate_tokens_thread.join() @@ -915,15 +940,26 @@ def stream_msgs( def convert_msgs_to_midi(msgs: list[mido.Message]): - track = mido.MidiTrack() - track.append(mido.MetaMessage("set_tempo", tempo=500000, time=0)) - track.append(mido.Message("program_change", program=0, channel=0, time=0)) + logger = get_logger("convert") + + channel_to_track = { + chan: mido.MidiTrack() + for chan in list(set([msg.channel for msg in 
msgs])) + } + for msg in msgs: - track.append(msg) + channel_to_track[msg.channel].append(msg) - mid = mido.MidiFile(type=0) + mid = mido.MidiFile(type=1) mid.ticks_per_beat = 500 - mid.tracks.append(track) + + for channel, track in channel_to_track.items(): + track.insert(0, mido.MetaMessage("set_tempo", tempo=500000, time=0)) + track.insert( + 0, + mido.Message("program_change", program=0, channel=channel, time=0), + ) + mid.tracks.append(track) return mid @@ -931,14 +967,16 @@ def convert_msgs_to_midi(msgs: list[mido.Message]): def capture_midi_input( midi_input_port: str, control_sentinel: threading.Event, + midi_capture_channel: int, midi_control_signal: int | None = None, midi_through_port: str | None = None, + first_msg_epoch_time_ms: int | None = None, ): logger = get_logger("CAPTURE") received_messages = [] active_pitches = set() first_on_msg_epoch_ms = None - prev_msg_epoch_time_ms = None + prev_msg_epoch_time_ms = first_msg_epoch_time_ms # logger.info(f"Listening on MIDI port: '{midi_input_port}'") logger.info(f"Using MIDI control signal: {midi_control_signal}") @@ -967,7 +1005,7 @@ def capture_midi_input( prev_msg_epoch_time_ms = get_epoch_time_ms() msg.time = msg_time_ms - msg.channel = 0 + msg.channel = midi_capture_channel logger.info(f"Received message: [{msg}]") if msg.is_meta is True or msg.type == "program_change": @@ -1007,7 +1045,7 @@ def capture_midi_input( type="note_on", note=pitch, velocity=0, - channel=0, + channel=midi_capture_channel, time=get_epoch_time_ms() - prev_msg_epoch_time_ms, ) received_messages.append(msg) @@ -1020,7 +1058,7 @@ def capture_midi_input( type="note_on", note=pitch, velocity=0, - channel=0, + channel=midi_capture_channel, time=0, ) received_messages.append(msg) @@ -1031,19 +1069,20 @@ def capture_midi_input( type="control_change", control=64, value=0, - channel=0, + channel=midi_capture_channel, time=0, ) received_messages.append(msg) if midi_through is not None: midi_through.send(msg) + # TODO: Need to figure out what the hell this does - is it needed? 
# Workaround for the way that file-playback is implemented - delete msg = mido.Message( type="control_change", control=midi_control_signal, value=0, - channel=0, + channel=midi_capture_channel, time=0, ) if midi_through is not None: @@ -1076,14 +1115,48 @@ def play_midi_file(midi_port: str, midi_path: str): output_port.send(msg) -def listen_for_control_signal_keypress(control_sentinel: threading.Event): +def listen_for_keypress_control_signal( + control_sentinel: threading.Event, + end_sentinel: threading.Event, +): logger = get_logger("KEYBOARD") - for _ in range(2): + while True: time.sleep(1) - input() - logger.info("Keypress seen") + _input = input() + logger.info(f'Keypress seen "{_input}"') control_sentinel.set() + if _input == "e": + end_sentinel.set() + + +# TODO: Not tested +def listen_for_midi_control_signal( + midi_input_port: str, + control_sentinel: threading.Event, + end_sentinel: threading.Event, + midi_control_signal: int | None = None, + midi_end_signal: int | None = None, +): + with mido.open_input(midi_input_port) as midi_input: + while True: + msg = midi_input.receive(block=False) + if msg is None: + time.sleep(0.01) + elif ( + msg.type == "control_change" + and msg.control == midi_control_signal + and msg.value > 0 + ): + control_sentinel.set() + elif ( + msg.type == "control_change" + and msg.control == midi_end_signal + and msg.value > 0 + ): + control_sentinel.set() + end_sentinel.set() + def parse_args(): argp = argparse.ArgumentParser() @@ -1105,6 +1178,11 @@ def parse_args(): type=int, help="MIDI control change message for AI takeover", ) + argp.add_argument( + "-midi_end_signal", + type=int, + help="MIDI control change message to generate ending", + ) argp.add_argument( "-temp", help="sampling temperature value", @@ -1142,6 +1220,11 @@ def parse_args(): return argp.parse_args() +# TODO: Test demo on real instrument with real MIDI interface +# TODO: Possibly a problem with endings being truncated +# TODO: Need functionality for handing case where we run out of model context + + def main(): args = parse_args() logger = get_logger() @@ -1162,21 +1245,69 @@ def main(): midi_input_port = args.midi_in control_sentinel = threading.Event() + end_sentinel = threading.Event() keypress_thread = threading.Thread( - target=listen_for_control_signal_keypress, - args=[control_sentinel], + target=listen_for_keypress_control_signal, + args=[control_sentinel, end_sentinel], + daemon=True, + ) + midi_control_thread = threading.Thread( + target=listen_for_midi_control_signal, + kwargs={ + "midi_input_port": midi_input_port, + "control_sentinel": control_sentinel, + "end_sentinel": end_sentinel, + "midi_control_signal": args.midi_control_signal, + "midi_end_signal": args.midi_end_signal, + }, daemon=True, ) keypress_thread.start() - - msgs, first_on_msg_epoch_ms, num_active_pitches = capture_midi_input( - midi_input_port=midi_input_port, - control_sentinel=control_sentinel, - midi_control_signal=args.midi_control_signal, - midi_through_port=args.midi_through, + midi_control_thread.start() + + msgs = [] + captured_msgs, first_on_msg_epoch_ms, num_active_pitches = ( + capture_midi_input( + midi_input_port=midi_input_port, + control_sentinel=control_sentinel, + midi_control_signal=args.midi_control_signal, + midi_through_port=args.midi_through, + midi_capture_channel=0, + ) ) - control_sentinel.clear() + itt = 0 + while True: + control_sentinel.clear() + msgs = stream_msgs( + model=model, + tokenizer=tokenizer, + msgs=msgs + captured_msgs, + midi_output_port=args.midi_out, + 
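            # Each round of the human/model exchange is written on its own
            # MIDI channel (midi_stream_channel=itt below), which is what
            # later lets convert_msgs_to_midi split the session into one
            # track per round.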
first_on_msg_epoch_ms=first_on_msg_epoch_ms, + control_sentinel=control_sentinel, + temperature=args.temp, + min_p=args.min_p, + num_preceding_active_pitches=num_active_pitches, + midi_stream_channel=itt, + is_ending=False, + ) + + control_sentinel.clear() + if end_sentinel.is_set(): + break + else: + itt += 1 + + captured_msgs, _, num_active_pitches = capture_midi_input( + midi_input_port=midi_input_port, + control_sentinel=control_sentinel, + midi_control_signal=args.midi_control_signal, + midi_through_port=args.midi_through, + midi_capture_channel=itt, + first_msg_epoch_time_ms=first_on_msg_epoch_ms, + ) + msgs = stream_msgs( model=model, tokenizer=tokenizer, @@ -1184,11 +1315,12 @@ def main(): midi_output_port=args.midi_out, first_on_msg_epoch_ms=first_on_msg_epoch_ms, control_sentinel=control_sentinel, - temperature=args.temp, + temperature=args.temp / 2, min_p=args.min_p, num_preceding_active_pitches=num_active_pitches, + midi_stream_channel=itt, + is_ending=True, ) - keypress_thread.join() if args.save_path: logger.info(f"Saving result to {args.save_path}") @@ -1196,6 +1328,5 @@ def main(): midi.save(args.save_path) -# TODO: Note this is all broken due to tokenizer changes -- fix if __name__ == "__main__": main() From 77d27b592a787eddbfd66efba4235fb7bce44a05 Mon Sep 17 00:00:00 2001 From: Louis Date: Sun, 25 May 2025 15:41:13 +0000 Subject: [PATCH 41/72] push mlx imp for test --- aria/inference/mlx.py | 350 +++++++++++++++++++++++++++++++++++ aria/inference/model_mlx.py | 264 ++++++++++++++++++++++++++ aria/inference/sample_mlx.py | 281 ++++++++++++++++++++++++++++ 3 files changed, 895 insertions(+) create mode 100644 aria/inference/mlx.py create mode 100644 aria/inference/model_mlx.py create mode 100644 aria/inference/sample_mlx.py diff --git a/aria/inference/mlx.py b/aria/inference/mlx.py new file mode 100644 index 0000000..7ec1b03 --- /dev/null +++ b/aria/inference/mlx.py @@ -0,0 +1,350 @@ +"""Inference implementation for mlx backend""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import ( + BaseModelArgs, + create_attention_mask, + scaled_dot_product_attention, +) +from .cache import ChunkedKVCache, KVCache +from .rope_utils import initialize_rope +from .switch_layers import SwitchGLU + + +@dataclass +class TextArgs(BaseModelArgs): + attention_bias: bool + attention_chunk_size: int + head_dim: int + hidden_act: str + hidden_size: int + interleave_moe_layer_step: int + intermediate_size: int + intermediate_size_mlp: int + max_position_embeddings: int + model_type: str + num_attention_heads: int + num_experts_per_tok: int + num_hidden_layers: int + num_key_value_heads: int + num_local_experts: int + rms_norm_eps: float + rope_scaling: Any + rope_theta: float + use_qk_norm: bool + vocab_size: int + attn_temperature_tuning: int = 4 + floor_scale: int = 8192 + attn_scale: float = 0.1 + + +@dataclass +class ModelArgs(BaseModelArgs): + text_config: Union[TextArgs, dict] + model_type: str + + def __post_init__(self): + self.text_config = TextArgs.from_dict(self.text_config) + + +class Attention(nn.Module): + def __init__(self, args: TextArgs, layer_idx: int): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.use_rope = int( + (layer_idx + 1) % 4 != 0 + ) # rope unused for dense layers + self.attn_temperature_tuning = args.attn_temperature_tuning + self.floor_scale = 
args.floor_scale + self.attn_scale = args.attn_scale + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + if hasattr(args, "attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + self.use_qk_norm = args.use_qk_norm and self.use_rope + + if self.use_rope: + self.rope = initialize_rope( + head_dim, + args.rope_theta, + traditional=True, + scaling_config=args.rope_scaling, + max_position_embeddings=args.max_position_embeddings, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + offset = cache.offset + else: + offset = 0 + + if self.use_rope: + queries = self.rope(queries, offset=offset) + keys = self.rope(keys, offset=offset) + + if self.use_qk_norm: + queries = mx.fast.rms_norm(queries, weight=None, eps=1e-6) + keys = mx.fast.rms_norm(keys, weight=None, eps=1e-6) + + if self.attn_temperature_tuning and not self.use_rope: + attn_scales = ( + mx.log( + mx.floor( + mx.arange(offset + 1, offset + L + 1) / self.floor_scale + ) + + 1.0 + ) + * self.attn_scale + + 1.0 + ) + attn_scales = attn_scales[:, None] + queries = (queries * attn_scales).astype(queries.dtype) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs, intermediate_size: int = None): + super().__init__() + + dim = args.hidden_size + hidden_dim = intermediate_size or args.intermediate_size + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class MoE(nn.Module): + def __init__(self, args): + super().__init__() + self.top_k = args.num_experts_per_tok + self.num_experts = args.num_local_experts + self.experts = SwitchGLU( + args.hidden_size, args.intermediate_size, self.num_experts + ) + self.router = nn.Linear( + args.hidden_size, args.num_local_experts, bias=False + ) + self.shared_expert = MLP(args) + + def __call__(self, x) -> mx.array: + logits = self.router(x) + k = self.top_k + indices = mx.argpartition(-logits, kth=k - 1, axis=-1)[..., :k] + scores = mx.take_along_axis(logits, indices, axis=-1) + scores = mx.sigmoid(scores.astype(mx.float32)).astype(x.dtype) + + out = self.experts(x * scores, indices).squeeze(2) + return out + self.shared_expert(x) + + +class TransformerBlock(nn.Module): + def __init__(self, args: TextArgs, layer_idx: int): + super().__init__() + self.num_attention_heads = 
args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args, layer_idx) + self.is_moe_layer = (layer_idx % args.interleave_moe_layer_step) == ( + args.interleave_moe_layer_step - 1 + ) + if self.is_moe_layer: + self.feed_forward = MoE(args) + else: + self.feed_forward = MLP(args, args.intermediate_size_mlp) + + self.input_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.feed_forward(self.post_attention_layernorm(h)) + out = h + r + return out + + +class LlamaModel(nn.Module): + def __init__(self, args: TextArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args, i) for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.attention_chunk_size = args.attention_chunk_size + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + h = self.embed_tokens(inputs) + + if cache is not None: + for idx, c in enumerate(cache): + if (idx + 1) % 4 != 0: + c.maybe_trim_front() + start = cache[0].start_position + offset = cache[0].offset + else: + start = 0 + offset = 0 + end = offset + h.shape[1] + linds = mx.arange(start, end) + rinds = mx.arange(offset, end)[:, None] + block_pos = mx.abs( + (linds // self.attention_chunk_size) + - (rinds // self.attention_chunk_size) + ) + token_pos = linds <= rinds + chunk_mask = (block_pos == 0) & token_pos + + if mask is None: + mask = create_attention_mask(h, cache) + else: + chunk_mask &= mask + + if cache is None: + cache = [None] * len(self.layers) + + for idx, (layer, c) in enumerate(zip(self.layers, cache)): + use_chunked_attention = (idx + 1) % 4 != 0 + if use_chunked_attention: + local_mask = chunk_mask + else: + local_mask = mask + h = layer(h, local_mask, cache=c) + + return self.norm(h) + + +class LanguageModel(nn.Module): + def __init__(self, args: TextArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = LlamaModel(self.args) + self.lm_head = nn.Linear( + self.args.hidden_size, self.args.vocab_size, bias=False + ) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + out = self.model(inputs, mask, cache) + return self.lm_head(out) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.language_model = LanguageModel(args.text_config) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + return self.language_model(inputs, mask, cache) + + def sanitize(self, weights): + def to_remove(k): + return "vision_model" in k or "multi_modal_projector" in k + + # Remove vision weights + weights = {k: v for k, v in weights.items() if not to_remove(k)} + + # Rename expert weights for SwitchGLU + for l in range(self.args.text_config.num_hidden_layers): + prefix = f"language_model.model.layers.{l}.feed_forward.experts" + if f"{prefix}.gate_up_proj" in weights: + v = 
weights.pop(f"{prefix}.gate_up_proj") + gate_k = f"{prefix}.gate_proj.weight" + up_k = f"{prefix}.up_proj.weight" + gate_proj, up_proj = mx.split(v, 2, axis=-1) + weights[gate_k] = mx.swapaxes(gate_proj, 1, 2) + weights[up_k] = mx.swapaxes(up_proj, 1, 2) + if f"{prefix}.down_proj" in weights: + down_proj = weights.pop(f"{prefix}.down_proj") + weights[f"{prefix}.down_proj.weight"] = mx.swapaxes( + down_proj, 1, 2 + ) + return weights + + @property + def layers(self): + return self.language_model.model.layers + + def make_cache(self): + chunk_size = self.args.text_config.attention_chunk_size + caches = [] + for i in range(len(self.layers)): + if (i + 1) % 4 != 0: + caches.append(ChunkedKVCache(chunk_size)) + else: + caches.append(KVCache()) + return caches diff --git a/aria/inference/model_mlx.py b/aria/inference/model_mlx.py new file mode 100644 index 0000000..5895351 --- /dev/null +++ b/aria/inference/model_mlx.py @@ -0,0 +1,264 @@ +"""Inference implementation for mlx backend""" + +from dataclasses import dataclass +from aria.model import ModelConfig + +import mlx.core as mx +import mlx.nn as nn + + +# TODO: Implement this with dynamic kv-size +class KVCache(nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + dtype: mx.Dtype = mx.bfloat16, + ): + super().__init__() + self.dtype = dtype + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.k_cache = mx.zeros(cache_shape, dtype=dtype) + self.v_cache = mx.zeros(cache_shape, dtype=dtype) + + def update(self, input_pos: mx.array, k_val: mx.array, v_val: mx.array): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class TransformerBlock(nn.Module): + def __init__( + self, + model_config: ModelConfig, + ): + super().__init__() + self.d_model = model_config.d_model + self.n_heads = model_config.n_heads + self.d_head = self.d_model // self.n_heads + self.max_seq_len = model_config.max_seq_len + self.scale = self.d_head**-0.5 + + # Att + self.mixed_qkv = nn.Linear( + input_dims=model_config.d_model, + output_dims=3 * model_config.d_model, + bias=False, + ) + self.att_proj_linear = nn.Linear( + input_dims=model_config.d_model, + output_dims=model_config.d_model, + bias=False, + ) + + # FF + self.ff_gate_proj = nn.Linear( + input_dims=model_config.d_model, + output_dims=model_config.d_model * model_config.ff_mult, + bias=False, + ) + self.ff_up_proj = nn.Linear( + input_dims=model_config.d_model, + output_dims=model_config.d_model * model_config.ff_mult, + bias=False, + ) + self.ff_down_proj = nn.Linear( + input_dims=model_config.d_model * model_config.ff_mult, + output_dims=model_config.d_model, + bias=False, + ) + + # Pre layer norms + self.norm1 = nn.LayerNorm(model_config.d_model) + self.norm2 = nn.LayerNorm(model_config.d_model) + + self.kv_cache = None + + def __call__( + self, + x: mx.array, + input_pos: mx.array, + offset: int, + mask: mx.array, + ): + assert self.kv_cache is not None, "Cache not initialized" + + x += self._att_block( + x=self.norm1(x), + input_pos=input_pos, + offset=offset, + mask=mask, + ) + x = x + self._ff_block(self.norm2(x)) + + return x + + def get_kv(self, k: mx.array, v: mx.array, input_pos: mx.array): + k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos) + + return k, v + + def _att_block( + self, + x: mx.array, + input_pos: mx.array, + 
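        # input_pos above selects the KV-cache slots to write into; offset
        # below is the scalar start position handed to mx.fast.rope (the
        # callers pass input_pos[0]).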
offset: int, + mask: mx.array, + ): + + qkv_splits = self.mixed_qkv(x).split(3, axis=2) + q, k, v = qkv_splits[0], qkv_splits[1], qkv_splits[2] + + batch_size, seq_len, _ = q.shape + q = q.reshape(batch_size, seq_len, self.n_heads, self.d_head) + k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head) + v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head) + + q = apply_rotary_emb_mlx(q, offset=offset) + k = apply_rotary_emb_mlx(k, offset=offset) + q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v)) + + k, v = self.get_kv(k, v, input_pos=input_pos) + wv = mx.fast.scaled_dot_product_attention( + q=q, + k=k, + v=v, + scale=self.scale, + mask=mask, + ) + + # (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d) + wv = wv.transpose(0, 2, 1, 3).reshape( + batch_size, seq_len, self.n_heads * self.d_head + ) + + return self.att_proj_linear(wv) + + def _ff_block(self, x: mx.array): + return self.ff_down_proj( + nn.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x) + ) + + +class Transformer(nn.Module): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + + self.tok_embeddings = nn.Embedding( + num_embeddings=model_config.vocab_size, + dims=model_config.d_model, + ) + self.encode_layers = [ + TransformerBlock(model_config) for _ in range(model_config.n_layers) + ] + self.out_layer_norm = nn.LayerNorm(model_config.d_model) + self.causal_mask = None + + def __call__( + self, + idxs: mx.array, + input_pos: mx.array, + offset: int, + pad_idxs: mx.array | None = None, + ): + assert self.causal_mask is not None, "Caches must be initialized first" + + mask = self.causal_mask[None, None, input_pos] + + if pad_idxs is not None: + pad_mask = mx.expand_dims(mx.expand_dims(pad_idxs, axis=1), axis=1) + mask = mask & ~pad_mask + + x = self.tok_embeddings(idxs) + for layer in self.encode_layers: + x = layer(x, input_pos, offset, mask) + + x = self.out_layer_norm(x) + + return x + + +class TransformerLM(nn.Module): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) # Implement + self.lm_head = nn.Linear( + model_config.d_model, model_config.vocab_size, bias=False + ) + + def __call__( + self, + idxs: mx.array, + input_pos: mx.array, + offset: int, + pad_idxs: mx.array | None = None, + ): + hidden_states = self.model( + idxs=idxs, + input_pos=input_pos, + offset=offset, + pad_idxs=pad_idxs, + ) + logits = self.lm_head(hidden_states) + + return logits + + def setup_cache( + self, + batch_size, + max_seq_len=4096, + dtype=mx.bfloat16, + ): + # Init cache + for b in self.model.encode_layers: + b.kv_cache = KVCache( + max_batch_size=batch_size, + max_seq_length=max_seq_len, + n_heads=self.model_config.n_heads, + head_dim=self.model_config.d_model // self.model_config.n_heads, + dtype=dtype, + ) + + # mx.bool isn't a thing? How do I do this in mlx, is it mx.bool_ ? 
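        # mx.bool_ is indeed the right spelling: like numpy's np.bool_, mlx
        # suffixes the boolean dtype with an underscore so it doesn't shadow
        # Python's builtin bool. Sanity check of the mask built below:
        #
        #   m = mx.tril(mx.ones((4, 4), dtype=mx.bool_))
        #   # m[i, j] is True iff j <= i, i.e. position i may attend only
        #   # to itself and to earlier positions.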
+ self.model.causal_mask = mx.tril( + mx.ones((max_seq_len, max_seq_len), dtype=mx.bool_) + ) + + +def apply_rotary_emb_mlx( + x: mx.array, + offset: int = 0, +) -> mx.array: + # Original x shape: (b_sz, s_len, n_head, d_head) + original_shape = x.shape + b_sz, s_len, n_head, d_head = original_shape + + # Transpose to (b_sz, n_head, s_len, d_head) + x_permuted = x.transpose(0, 2, 1, 3) + # Reshape for mx.fast.rope: (b_sz * n_head, s_len, d_head) + x_reshaped = x_permuted.reshape(-1, s_len, d_head) + + rotated_x_reshaped = mx.fast.rope( + x_reshaped, + dims=d_head, + traditional=False, + base=500000, + scale=1.0, + offset=offset, + ) + + rotated_x_permuted = rotated_x_reshaped.reshape(b_sz, n_head, s_len, d_head) + rotated_x = rotated_x_permuted.transpose(0, 2, 1, 3) + + return rotated_x diff --git a/aria/inference/sample_mlx.py b/aria/inference/sample_mlx.py new file mode 100644 index 0000000..578c804 --- /dev/null +++ b/aria/inference/sample_mlx.py @@ -0,0 +1,281 @@ +"""Contains generation/sampling code (mlx)""" + +import torch +import numpy as np +import mlx.core as mx + +from typing import List +from tqdm import tqdm + +from aria.inference.model_mlx import TransformerLM +from ariautils.tokenizer import Tokenizer + + +def decode_one( + model: TransformerLM, + idxs: mx.array, + input_pos: mx.array, + pad_idxs: mx.array | None = None, +): + assert input_pos.shape[-1] == 1 + + compiled_forward = mx.compile(model.__call__) + logits = compiled_forward( + idxs=idxs, + input_pos=input_pos, + offset=input_pos[0], + # pad_idxs=pad_idxs, + )[:, -1] + # logits = model( + # idxs=idxs, + # input_pos=input_pos, + # pad_idxs=pad_idxs, + # )[:, -1] + + return logits + + +def prefill( + model: TransformerLM, + idxs: mx.array, + input_pos: mx.array, + pad_idxs: mx.array | None = None, +): + logits = model( + idxs=idxs, + input_pos=input_pos, + offset=input_pos[0], + pad_idxs=pad_idxs, + )[:, -1] + + return logits + + +def update_seq_ids_( + seq: mx.array, + idx: int, + next_token_ids: mx.array, + dim_tok_inserted: list, + eos_tok_seen: list, + max_len: int, + force_end: bool, + tokenizer: Tokenizer, +): + # Insert dim and pad toks + for _idx in range(seq.shape[0]): + if eos_tok_seen[_idx] == True: + next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.pad_tok] + elif ( + force_end + and idx >= max_len - 130 + and dim_tok_inserted[_idx] is False + and tokenizer.id_to_tok[next_token_ids[_idx].item()][0] + not in ("dur", "onset") + ): + next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok] + + # Update dim_tok_inserted and eos_tok_seen + if next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: + dim_tok_inserted[_idx] = True + elif next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.eos_tok]: + eos_tok_seen[_idx] = True + + seq[:, idx] = next_token_ids + + +def sample_batch( + model: TransformerLM, + tokenizer: Tokenizer, + prompts: List[list], + max_new_tokens: int, + force_end=False, + temp: float = 0.95, + min_p: float | None = None, + # compile: bool = False, +): + if min_p is not None: + assert 0.0 <= min_p <= 1.0 + if temp is not None: + assert 0.0 <= temp <= 2.0 + if force_end: + assert max_new_tokens > 130, "prompt too long to use force_end=True" + + prompt_len = len(prompts[0]) + num_prompts = len(prompts) + assert all([len(p) == prompt_len for p in prompts]) + + model.eval() + dim_tok_inserted = [False for _ in range(num_prompts)] + eos_tok_seen = [False for _ in range(num_prompts)] + total_len = prompt_len + max_new_tokens + seq = mx.stack( + [ + mx.array( + tokenizer.encode(p 
+ [tokenizer.pad_tok] * (total_len - len(p))) + ) + for p in prompts + ] + ) + model.setup_cache(batch_size=num_prompts, max_seq_len=total_len) + print( + f"Using hyperparams: temp={temp}, min_p={min_p}, gen_len={max_new_tokens}" + ) + + for idx in ( + pbar := tqdm( + range(prompt_len, total_len), + total=total_len - prompt_len, + leave=False, + ) + ): + if idx == prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=mx.arange(0, idx), + ) + else: + logits = decode_one( + model, + idxs=seq[:, idx - 1 : idx], + input_pos=mx.array( + [idx - 1], + dtype=mx.int32, + ), + ) + + if temp > 0.0: + probs = mx.softmax(logits / temp, axis=-1) + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = mx.argmax(logits, axis=-1).flatten() + + print(tokenizer.id_to_tok[next_token_ids[0].item()]) + + update_seq_ids_( + seq=seq, + idx=idx, + next_token_ids=next_token_ids, + dim_tok_inserted=dim_tok_inserted, + eos_tok_seen=eos_tok_seen, + max_len=total_len, + force_end=force_end, + tokenizer=tokenizer, + ) + + if all(seen_eos is True for seen_eos in eos_tok_seen): + break + + decoded_results = [tokenizer.decode(s) for s in seq.tolist()] + decoded_results = [ + ( + res[: res.index(tokenizer.eos_tok) + 1] + if tokenizer.eos_tok in res + else res + ) + for res in decoded_results + ] + + return decoded_results + + +# TODO: Broken +# def sample_min_p(probs: mx.array, p_base: float): # Added type hint +# """See - https://arxiv.org/pdf/2407.01082""" +# p_max = mx.max(probs, axis=-1, keepdims=True) +# p_scaled = p_base * p_max +# mask = probs >= p_scaled + +# masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) +# sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) +# sum_masked_probs = mx.where(sum_masked_probs == 0, 1e-9, sum_masked_probs) +# masked_probs_normalized = masked_probs / sum_masked_probs + +# next_token = mx.random.categorical(masked_probs_normalized, num_samples=1) + +# return next_token + + +def sample_min_p(probs: mx.array, p_base: float): # Added type hint + """See - https://arxiv.org/pdf/2407.01082""" + p_max = mx.max(probs, axis=-1, keepdims=True) + p_scaled = p_base * p_max + mask = probs >= p_scaled + + masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) + sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) + masked_probs_normalized = masked_probs / sum_masked_probs + + # Dumb workaround for mlx not having categorical probs sampler + next_token = mx.array( + torch.multinomial( + torch.from_numpy(np.array(masked_probs_normalized)), num_samples=1 + ), + dtype=mx.int32, + ) + + return next_token + + +def sample(): + import os + import torch + + from aria.model import ModelConfig + from aria.config import load_model_config + + from ariautils.midi import MidiDict + from ariautils.tokenizer import AbsTokenizer + from aria.sample import get_inference_prompt + + CHECKPOINT_PATH = "/mnt/ssd1/aria/v2/medium-75-annealed.safetensors" # Or ".pt" if you're loading a converted PyTorch model + PROMPT_MIDI_PATH = ( + "/home/loubb/Dropbox/shared/demo.mid" # Example: "my_melody_prompt.mid" + ) + + NUM_VARIATIONS = 2 # Number of samples (e.g., 2 variations) + TRUNCATE_LEN_MS = 1000 # Prompt length in milliseconds (e.g., 10 seconds) + GEN_LENGTH = 256 # Number of new tokens to generate (args.l) + FORCE_END = False # Whether to force sequence end (args.e) + TEMPERATURE = 0.98 # Sampling temperature (args.temp) + MIN_P = 0.04 # Min-p sampling (args.min_p) + + SAMPLES_DIR = os.path.join(os.getcwd(), "/home/loubb/Dropbox/shared") + + 
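    # Note: os.path.join drops every component before an absolute one, e.g.
    # os.path.join("/tmp", "/abs/path") == "/abs/path", so the os.getcwd()
    # above is a no-op and SAMPLES_DIR is simply the absolute shared path.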
tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config("medium-emb")) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerLM(model_config) + model.load_weights(CHECKPOINT_PATH) + + midi_dict = MidiDict.from_midi(mid_path=PROMPT_MIDI_PATH) + prompt_seq = get_inference_prompt( + tokenizer=tokenizer, + midi_dict=midi_dict, + prompt_len_ms=TRUNCATE_LEN_MS, + ) + + print(prompt_seq) + print(f"Prompt sequence length: {len(prompt_seq)} tokens") + prompts = [prompt_seq for _ in range(NUM_VARIATIONS)] + + results = sample_batch( + model=model, + tokenizer=tokenizer, + prompts=prompts, + max_new_tokens=GEN_LENGTH, + force_end=FORCE_END, + temp=TEMPERATURE, + min_p=MIN_P, + ) + + for idx, tokenized_seq in enumerate(results): + res_midi_dict = tokenizer.detokenize(tokenized_seq) + res_midi = res_midi_dict.to_midi() + output_file_path = os.path.join(SAMPLES_DIR, f"res_{idx + 1}.mid") + res_midi.save(output_file_path) + print(f"Saved result {idx + 1} to {output_file_path}") + + +if __name__ == "__main__": + sample() From 445d48483cdffbfba1da7c8bdce8134d465d6ae5 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 26 May 2025 15:13:04 +0100 Subject: [PATCH 42/72] fix sample script --- aria/inference/sample_mlx.py | 54 ++++++++++++------------------------ 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/aria/inference/sample_mlx.py b/aria/inference/sample_mlx.py index 578c804..65a33ae 100644 --- a/aria/inference/sample_mlx.py +++ b/aria/inference/sample_mlx.py @@ -3,6 +3,7 @@ import torch import numpy as np import mlx.core as mx +import mlx.nn as nn from typing import List from tqdm import tqdm @@ -19,18 +20,12 @@ def decode_one( ): assert input_pos.shape[-1] == 1 - compiled_forward = mx.compile(model.__call__) - logits = compiled_forward( + logits = model( idxs=idxs, input_pos=input_pos, offset=input_pos[0], - # pad_idxs=pad_idxs, + pad_idxs=pad_idxs, )[:, -1] - # logits = model( - # idxs=idxs, - # input_pos=input_pos, - # pad_idxs=pad_idxs, - # )[:, -1] return logits @@ -116,7 +111,9 @@ def sample_batch( for p in prompts ] ) - model.setup_cache(batch_size=num_prompts, max_seq_len=total_len) + model.setup_cache( + batch_size=num_prompts, max_seq_len=total_len, dtype=mx.float32 + ) print( f"Using hyperparams: temp={temp}, min_p={min_p}, gen_len={max_new_tokens}" ) @@ -179,29 +176,14 @@ def sample_batch( return decoded_results -# TODO: Broken -# def sample_min_p(probs: mx.array, p_base: float): # Added type hint -# """See - https://arxiv.org/pdf/2407.01082""" -# p_max = mx.max(probs, axis=-1, keepdims=True) -# p_scaled = p_base * p_max -# mask = probs >= p_scaled - -# masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) -# sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) -# sum_masked_probs = mx.where(sum_masked_probs == 0, 1e-9, sum_masked_probs) -# masked_probs_normalized = masked_probs / sum_masked_probs - -# next_token = mx.random.categorical(masked_probs_normalized, num_samples=1) - -# return next_token - - def sample_min_p(probs: mx.array, p_base: float): # Added type hint """See - https://arxiv.org/pdf/2407.01082""" p_max = mx.max(probs, axis=-1, keepdims=True) p_scaled = p_base * p_max mask = probs >= p_scaled + print(mx.sum(mask).item()) + masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) masked_probs_normalized = masked_probs / sum_masked_probs @@ -219,7 +201,6 @@ def sample_min_p(probs: mx.array, p_base: float): # Added type hint def sample(): import 
os - import torch from aria.model import ModelConfig from aria.config import load_model_config @@ -228,25 +209,26 @@ def sample(): from ariautils.tokenizer import AbsTokenizer from aria.sample import get_inference_prompt - CHECKPOINT_PATH = "/mnt/ssd1/aria/v2/medium-75-annealed.safetensors" # Or ".pt" if you're loading a converted PyTorch model - PROMPT_MIDI_PATH = ( - "/home/loubb/Dropbox/shared/demo.mid" # Example: "my_melody_prompt.mid" + CHECKPOINT_PATH = ( + "/Users/louis/work/aria/models/medium-75-annealed.safetensors" ) + PROMPT_MIDI_PATH = "/Users/louis/Dropbox/shared/audio.mid" - NUM_VARIATIONS = 2 # Number of samples (e.g., 2 variations) - TRUNCATE_LEN_MS = 1000 # Prompt length in milliseconds (e.g., 10 seconds) - GEN_LENGTH = 256 # Number of new tokens to generate (args.l) + NUM_VARIATIONS = 1 # Number of samples (e.g., 2 variations) + TRUNCATE_LEN_MS = 15000 # Prompt length in milliseconds (e.g., 10 seconds) + GEN_LENGTH = 1024 # Number of new tokens to generate (args.l) FORCE_END = False # Whether to force sequence end (args.e) - TEMPERATURE = 0.98 # Sampling temperature (args.temp) - MIN_P = 0.04 # Min-p sampling (args.min_p) + TEMPERATURE = 0.95 # Sampling temperature (args.temp) + MIN_P = 0.05 # Min-p sampling (args.min_p) - SAMPLES_DIR = os.path.join(os.getcwd(), "/home/loubb/Dropbox/shared") + SAMPLES_DIR = os.path.join(os.getcwd(), "/Users/louis/Dropbox/shared") tokenizer = AbsTokenizer() model_config = ModelConfig(**load_model_config("medium-emb")) model_config.set_vocab_size(tokenizer.vocab_size) model = TransformerLM(model_config) model.load_weights(CHECKPOINT_PATH) + nn.quantize(model.model, group_size=128, bits=8) midi_dict = MidiDict.from_midi(mid_path=PROMPT_MIDI_PATH) prompt_seq = get_inference_prompt( From dc9fdcb9f32bb88f9d32e86bf5d653fd105853ca Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 27 May 2025 20:58:21 +0000 Subject: [PATCH 43/72] add continuous prefill and speculative duration calculation --- demo/demo.py | 538 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 392 insertions(+), 146 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 3774525..e6aaa87 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -21,15 +21,21 @@ from aria.inference import TransformerLM from aria.model import ModelConfig from aria.config import load_model_config -from aria.sample import prefill, decode_one, sample_min_p +from aria.sample import sample_min_p torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.fx_graph_cache = True DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 -MAX_SEQ_LEN = 8192 -PREFILL_COMPILE_SEQ_LEN = 1024 +MAX_SEQ_LEN = 2048 +PREFILL_CHUNK_SIZE = 16 +RECALC_DUR_PREFILL_CHUNK_SIZE = 8 +RECALC_DUR_BUFFER_MS = 50 + +# Decode first +BEAM_WIDTH = 5 +TIME_TOK_WEIGHTING = -6 # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg @@ -81,53 +87,90 @@ def get_epoch_time_ms() -> int: return round(time.time() * 1000) -def compiled_prefill( +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def prefill( model: TransformerLM, - enc_seq: torch.Tensor, -): - return prefill( - model=model, - idxs=enc_seq[:, :PREFILL_COMPILE_SEQ_LEN], - input_pos=torch.arange(0, PREFILL_COMPILE_SEQ_LEN, device="cuda"), + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +) -> torch.Tensor: + logits = model.forward( + 
idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, ) + return logits + + +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def decode_one( + model: TransformerLM, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +) -> torch.Tensor: + assert input_pos.shape[-1] == 1 + + logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + )[:, -1] + + return logits + def _compile_prefill( model: TransformerLM, logger: logging.Logger, + chunk_size: int, ): - global compiled_prefill - compiled_prefill = torch.compile( - compiled_prefill, + assert chunk_size > 1 + + global prefill + prefill = torch.compile( + prefill, mode="reduce-overhead", fullgraph=True, ) - start_compile_time_s = time.time() - logger.info(f"Compiling prefill") - compiled_prefill( - model, enc_seq=torch.ones(1, 8192, device="cuda", dtype=torch.int) + logger.info(f"Compiling prefill (chunk_size={chunk_size})") + prefill( + model, + idxs=torch.ones(1, chunk_size, device="cuda", dtype=torch.int), + input_pos=torch.arange(0, chunk_size, device="cuda", dtype=torch.int), ) logger.info( f"Finished compiling - took {time.time() - start_compile_time_s:.4f} seconds" ) for _ in range(5): - compiled_prefill( + prefill( model, - enc_seq=torch.ones(1, 8192, device="cuda", dtype=torch.int), + idxs=torch.ones(1, chunk_size, device="cuda", dtype=torch.int), + input_pos=torch.arange( + 0, chunk_size, device="cuda", dtype=torch.int + ), ) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() - compiled_prefill( - model, enc_seq=torch.ones(1, 8192, device="cuda", dtype=torch.int) + prefill( + model, + idxs=torch.ones(1, chunk_size, device="cuda", dtype=torch.int), + input_pos=torch.arange(0, chunk_size, device="cuda", dtype=torch.int), ) end_event.record() end_event.synchronize() compiled_prefill_ms = start_event.elapsed_time(end_event) - logger.info(f"Compiled prefill benchmark: {compiled_prefill_ms:.2f}ms") + compiled_prefill_its = 1000 / compiled_prefill_ms + logger.info( + f"Compiled prefill benchmark: {compiled_prefill_ms:.2f} ms/it ({compiled_prefill_its:.2f} it/s)" + ) return model @@ -142,7 +185,7 @@ def _compile_decode_one(model: TransformerLM, logger: logging.Logger): with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): start_compile_time_s = time.time() - logger.info(f"Compiling forward pass") + logger.info(f"Compiling decode_one") decode_one( model, idxs=torch.tensor([[0]], device="cuda", dtype=torch.int), @@ -174,13 +217,12 @@ def _compile_decode_one(model: TransformerLM, logger: logging.Logger): compiled_forward_ms = start_event.elapsed_time(end_event) compiled_forward_its = 1000 / compiled_forward_ms logger.info( - f"Compiled forward pass benchmark: {compiled_forward_ms:.2f} ms/it ({compiled_forward_its:.2f} it/s)" + f"Compiled decode_one benchmark: {compiled_forward_ms:.2f} ms/it ({compiled_forward_its:.2f} it/s)" ) return model -@torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def compile_model(model: TransformerLM, max_seq_len: int): logger = get_logger() @@ -194,7 +236,10 @@ def compile_model(model: TransformerLM, max_seq_len: int): ) model = _compile_decode_one(model=model, logger=logger) - model = _compile_prefill(model=model, logger=logger) + for chunk_size in list({PREFILL_CHUNK_SIZE, RECALC_DUR_PREFILL_CHUNK_SIZE}): + model = _compile_prefill( + model=model, logger=logger, chunk_size=chunk_size + ) return model @@ -232,92 +277,113 @@ def load_model( 
return model -@torch.autocast("cuda", dtype=DTYPE) +def _first_bad_dur_index( + tokenizer: AbsTokenizer, + priming_seq: list, + pred_ids: list, + chunk_start: int, + last_offset_ms: int, + logger: logging.Logger, +): + num_time_toks = priming_seq[:chunk_start].count(tokenizer.time_tok) + local_onset_ms = tokenizer.calc_length_ms( + priming_seq[:chunk_start], onset=True + ) + logger.debug(f"Starting from local onset {local_onset_ms}") + + for pos, tok_id in enumerate( + pred_ids[: len(priming_seq) - chunk_start], start=chunk_start + ): + prim_tok = priming_seq[pos] # Should never error? + pred_tok = tokenizer.id_to_tok[tok_id] + logger.debug(f"prim={prim_tok}, pred={pred_tok}") + + if isinstance(prim_tok, tuple) and prim_tok[0] == "onset": + local_onset_ms = num_time_toks * 5000 + prim_tok[1] + elif prim_tok == tokenizer.time_tok: + num_time_toks += 1 + elif isinstance(prim_tok, tuple) and prim_tok[0] == "dur": + dur_true = prim_tok[1] + dur_pred = pred_tok[1] + if dur_pred > dur_true and ( + local_onset_ms + dur_true + > last_offset_ms - RECALC_DUR_BUFFER_MS + ): + logger.info( + f"Found token to resample at {pos}: {prim_tok} -> {pred_tok}" + ) + return pos + + return None + + +# TODO: I'm still not 100% sure this is bug free. +# A good debugging strat would be to run it over and over again until we +# cover all of the edge cases @torch.inference_mode() -def recalculate_dur_tokens( +def recalc_dur_tokens_chunked( model: TransformerLM, priming_seq: list, enc_seq: torch.Tensor, tokenizer: AbsTokenizer, start_idx: int, ): - logger = get_logger("GENERATE") + """Speculative-decoding inspired duration re-calculation""" assert start_idx > 0 + logger = get_logger("GENERATE") - priming_seq_len = len(priming_seq) - num_time_toks_seen = priming_seq[:start_idx].count(tokenizer.time_tok) - curr_onset = num_time_toks_seen * 5000 + priming_len = len(priming_seq) last_offset = tokenizer.calc_length_ms(priming_seq) - LAST_OFFSET_BUFFER_MS = 50 - for idx in range(start_idx, priming_seq_len): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - prev_tok_id = enc_seq[0, idx - 1] - prev_tok = tokenizer.id_to_tok[prev_tok_id.item()] - logits = decode_one( - model, - idxs=torch.tensor( - [[prev_tok_id]], device="cuda", dtype=torch.int - ), - input_pos=torch.tensor( - [idx - 1], device="cuda", dtype=torch.int - ), - ) - logger.debug( - f"Sampled logits for position {idx} by inserting {prev_tok} at position {idx-1}" - ) - - next_token_ids = torch.argmax(logits, dim=-1).flatten() - priming_tok = tokenizer.id_to_tok[enc_seq[0, idx].item()] - predicted_tok = tokenizer.id_to_tok[next_token_ids[0].item()] + idx = start_idx + while idx <= priming_len: + end_idx = idx + RECALC_DUR_PREFILL_CHUNK_SIZE - logger.debug( - f"Ground truth token: {priming_tok}, resampled token: {predicted_tok}" + window_ids = torch.tensor( + enc_seq[:, idx - 1 : end_idx - 1].tolist(), + device="cuda", + dtype=torch.int, + ) + window_pos = torch.arange( + idx - 1, end_idx - 1, device="cuda", dtype=torch.int ) - resample = False - if isinstance(priming_tok, tuple) and priming_tok[0] == "onset": - curr_onset = (num_time_toks_seen * 5000) + priming_tok[1] - elif priming_tok == tokenizer.time_tok: - num_time_toks_seen += 1 - curr_onset = num_time_toks_seen * 5000 - elif isinstance(priming_tok, tuple) and priming_tok[0] == "dur": - assert ( - isinstance(predicted_tok, tuple) and predicted_tok[0] == "dur" - ) - - priming_dur = priming_tok[1] - predicted_dur = predicted_tok[1] - - if (predicted_dur > priming_dur) and ( - 
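                    # i.e. only overwrite when the model wants a longer
                    # duration AND the primed note's offset runs up against
                    # the end of the captured material (meaning its true
                    # release was likely cut off by the takeover):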
curr_onset + priming_dur > last_offset - LAST_OFFSET_BUFFER_MS - ): - resample = True - - if resample is True: - logger.info( - f"Replaced ground truth for position {idx}: {tokenizer.id_to_tok[enc_seq[:, idx].item()]} -> {tokenizer.id_to_tok[next_token_ids[0].item()]}" - ) - enc_seq[:, idx] = next_token_ids - priming_seq[idx] = predicted_tok + logger.info( + f"Recalculating chunked durations for positions: {idx-1} - {end_idx-2}" + ) + logger.debug(f"Inserted: {tokenizer.decode(window_ids[0].tolist())}") + logger.debug(f"Positions: {window_pos.tolist()}") - # TODO: There is a bug here if no-notes are active when the signal is pressed - last_tok_id = enc_seq[0, idx] - last_tok = tokenizer.id_to_tok[last_tok_id.item()] + logits = prefill(model, idxs=window_ids, input_pos=window_pos) + pred_ids = logits.argmax(dim=-1).flatten().tolist() - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - next_token_logits = decode_one( - model, - idxs=torch.tensor([[last_tok_id]], device="cuda", dtype=torch.int), - input_pos=torch.tensor([idx], device="cuda", dtype=torch.int), + bad_pos = _first_bad_dur_index( + tokenizer=tokenizer, + priming_seq=priming_seq, + pred_ids=pred_ids, + chunk_start=idx, + last_offset_ms=last_offset, + logger=logger, ) - logger.info(f"Updated KV-Cache by inserting {last_tok} at position {idx}") + if bad_pos is None: + idx = end_idx + else: + new_id = pred_ids[bad_pos - idx] + enc_seq[0, bad_pos] = new_id + priming_seq[bad_pos] = tokenizer.id_to_tok[new_id] + idx = bad_pos - return enc_seq, priming_seq, next_token_logits + next_logits = logits[:, priming_len - idx] + return enc_seq, priming_seq, next_logits -@torch.autocast("cuda", dtype=DTYPE) + +# TODO: This is now the latency bottleneck. +# Ideas for reducing it: +# - Get rid of the manual time_tok insert stuff, instead just mask logits +# for all invalid tokens, this should force the model to sample a time tok +# if there aren't any other valid options @torch.inference_mode() def decode_first_tokens( model: TransformerLM, @@ -330,10 +396,8 @@ def decode_first_tokens( ): logger = get_logger("GENERATE") - BEAM_WIDTH = 5 - BUFFER_MS = 100 + HARDWARE_LATENCY_MS + BUFFER_MS = 50 + HARDWARE_LATENCY_MS TIME_TOK_ID = tokenizer.tok_to_id[tokenizer.time_tok] - TIME_TOK_WEIGHTING = -6 logits = first_token_logits time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms @@ -525,12 +589,12 @@ def decode_tokens( generated_tokens_queue.put(None) -@torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def generate_tokens( priming_seq: list, tokenizer: AbsTokenizer, model: TransformerLM, + prev_context: list[int], control_sentinel: threading.Event, generated_tokens_queue: queue.Queue, num_preceding_active_pitches: int, @@ -542,7 +606,7 @@ def generate_tokens( generate_start_s = time.time() priming_seq_len = len(priming_seq) - start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches - 1) + start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches) enc_seq = torch.tensor( [ tokenizer.encode( @@ -556,33 +620,26 @@ def generate_tokens( logger.debug(f"Priming sequence {priming_seq}") logger.info(f"Priming sequence length: {priming_seq_len}") - logger.info(f"Prefilling up to (and including) position: {start_idx-2}") + logger.info(f"Prefilling up to (and including) position: {start_idx-1}") # In theory we could reuse the logits from prefill prefill_start_s = time.time() - if start_idx < PREFILL_COMPILE_SEQ_LEN: - logger.info( - f"Using compiled prefill for sequence length: 
{PREFILL_COMPILE_SEQ_LEN}" - ) - compiled_prefill( - model=model, - enc_seq=enc_seq, - ) - else: - prefill( - model, - idxs=enc_seq[:, : start_idx - 1], - input_pos=torch.arange(0, start_idx - 1, device="cuda"), - ) + chunked_prefill( + model=model, + tokenizer=tokenizer, + prev_context=prev_context, + curr_context=enc_seq[0, :start_idx].tolist(), + full=True, + ) torch.cuda.synchronize() logger.info( f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" ) - logger.info(f"Starting duration recalculation from: {start_idx}") + logger.info(f"Starting duration recalculation from position: {start_idx-1}") recalculate_dur_start_s = time.time() - enc_seq, priming_seq, next_token_logits = recalculate_dur_tokens( + enc_seq, priming_seq, next_token_logits = recalc_dur_tokens_chunked( model=model, priming_seq=priming_seq, enc_seq=enc_seq, @@ -756,6 +813,8 @@ def stream_midi( ) msg = midi_messages[0] + # TODO: Fix this (tomorrow) so it works for off-messages too + # (e.g., an off message which happens just before an on message) if ( (msg["vel"] > 0) and ( @@ -869,6 +928,7 @@ def stream_msgs( model: TransformerLM, tokenizer: AbsTokenizer, msgs: list[mido.Message], + prev_context: list[int], midi_output_port: str, first_on_msg_epoch_ms: int, control_sentinel: threading.Event, @@ -880,11 +940,11 @@ def stream_msgs( ): midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) - priming_seq = tokenizer.tokenize(midi_dict=midi_dict) + priming_seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] - if tokenizer.dim_tok in priming_seq: - priming_seq.remove(tokenizer.dim_tok) + # if tokenizer.dim_tok in priming_seq: + # priming_seq.remove(tokenizer.dim_tok) if is_ending is True: priming_seq.append(tokenizer.dim_tok) @@ -897,6 +957,7 @@ def stream_msgs( "priming_seq": priming_seq, "tokenizer": tokenizer, "model": model, + "prev_context": prev_context, "control_sentinel": control_sentinel, "generated_tokens_queue": generated_tokens_queue, "temperature": temperature, @@ -964,16 +1025,202 @@ def convert_msgs_to_midi(msgs: list[mido.Message]): return mid +def _find_divergence( + prev_context: list, + curr_context: list, + logger: logging.Logger, +): + agreement_index = 0 + for prev_val, curr_val in zip(prev_context, curr_context): + if prev_val == curr_val: + agreement_index += 1 + else: + logger.info( + f"Found divergence at position {agreement_index + 1}: {curr_val}, {prev_val}" + ) + break + + return agreement_index, curr_context[agreement_index:] + + +# There is an error here if curr_context < prev_context +@torch.inference_mode() +def chunked_prefill( + model: TransformerLM, + tokenizer: AbsTokenizer, + prev_context: list, + curr_context: list, + full: bool = False, +): + + # prev_context = 124 (last thing that was prefilled) + # curr_context = 100 (what we are trying to make sure of now) + + assert isinstance(curr_context[0], int) + assert tokenizer.pad_id not in prev_context + assert tokenizer.pad_id not in curr_context + + logger = get_logger("PREFILL") + while True: + prefill_idx, prefill_toks = _find_divergence( + prev_context, curr_context, logger=logger + ) + num_prefill_toks = len(prefill_toks) + logger.info(f"Tokens to prefill: {len(prefill_toks)}") + + if num_prefill_toks > PREFILL_CHUNK_SIZE: + logger.info( + f"Prefilling {PREFILL_CHUNK_SIZE} tokens from idx={prefill_idx}" + ) + + prefill( + model, + idxs=torch.tensor( + [prefill_toks[:PREFILL_CHUNK_SIZE]], + device="cuda", + 
dtype=torch.int, + ), + input_pos=torch.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + device="cuda", + dtype=torch.int, + ), + ) + prev_context = curr_context[: prefill_idx + PREFILL_CHUNK_SIZE] + + elif num_prefill_toks > 0 and full is True: + logger.info( + f"Prefilling (force) {num_prefill_toks} tokens from idx={prefill_idx}" + ) + prefill_toks += (PREFILL_CHUNK_SIZE - len(prefill_toks)) * [ + tokenizer.pad_id + ] + prefill( + model, + idxs=torch.tensor( + [prefill_toks], device="cuda", dtype=torch.int + ), + input_pos=torch.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + device="cuda", + dtype=torch.int, + ), + ) + prev_context = curr_context + break + else: + break + + # TODO: This appears as -1 sometimes?? + logger.info( + f"KV stored up to idx={len(prev_context)- 1} (curr_context_len={len(curr_context)})" + ) + + return prev_context + + +def continuous_prefill( + model: TransformerLM, + msgs: list, + received_messages_queue: queue.Queue, + prev_context: list[int], +): + tokenizer = AbsTokenizer() + logger = get_logger("PREFILL") + msg_cnt = 0 + seen_sentinel = False + + while seen_sentinel is False: + while seen_sentinel is False: + try: + msg = received_messages_queue.get_nowait() + except queue.Empty: + break + else: + if msg is None: + logger.info("Seen sentinel in message received messages") + seen_sentinel = True + else: + msgs.append(msg) + msg_cnt += 1 + + if (msg_cnt >= 5 or seen_sentinel) and len(msgs) > 10: + midi = convert_msgs_to_midi(msgs=msgs) + midi_dict = MidiDict(**midi_to_dict(midi)) + curr_context = tokenizer.encode( + tokenizer.tokenize(midi_dict, add_dim_tok=False) + ) + prev_context = chunked_prefill( + model=model, + tokenizer=tokenizer, + prev_context=prev_context, + curr_context=curr_context, + full=False, + ) + msg_cnt = 0 + else: + time.sleep(0.01) + + return msgs, prev_context + + +def capture_and_update_kv( + model: TransformerLM, + msgs: list, + prev_context: list, + control_sentinel: threading.Event, + midi_input_port: str, + midi_capture_channel: int, + midi_control_signal: int | None = None, + midi_through_port: str | None = None, + first_msg_epoch_time_ms: int | None = None, +): + # Start capture_midi_input in threadpool + # Run continiously update kv in main thread + # When control sentinel is seen, return msgs and prev_context continiously update kv (check for None sentinel here) + + received_messages_queue = queue.Queue() + results_queue = queue.Queue() + capture_midi_thread = threading.Thread( + target=capture_midi_input, + kwargs={ + "midi_input_port": midi_input_port, + "control_sentinel": control_sentinel, + "received_messages_queue": received_messages_queue, + "midi_capture_channel": midi_capture_channel, + "midi_control_signal": midi_control_signal, + "midi_through_port": midi_through_port, + "first_msg_epoch_time_ms": first_msg_epoch_time_ms, + "results_queue": results_queue, + }, + ) + capture_midi_thread.start() + + msgs, prev_context = continuous_prefill( + model=model, + msgs=msgs, + received_messages_queue=received_messages_queue, + prev_context=prev_context, + ) + capture_midi_thread.join() + first_on_msg_epoch_ms, num_active_pitches = results_queue.get() + + return msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches + + def capture_midi_input( midi_input_port: str, control_sentinel: threading.Event, + received_messages_queue: queue.Queue, midi_capture_channel: int, midi_control_signal: int | None = None, midi_through_port: str | None = None, first_msg_epoch_time_ms: int | None = None, + 
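    # received_messages_queue streams captured messages out as they arrive
    # (terminated by a None sentinel) so the consumer can prefill while
    # capture continues; results_queue, when given, hands back
    # (first_on_msg_epoch_ms, num_active_pitches) to the owning thread.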
results_queue: queue.Queue | None = None, ): logger = get_logger("CAPTURE") - received_messages = [] active_pitches = set() first_on_msg_epoch_ms = None prev_msg_epoch_time_ms = first_msg_epoch_time_ms # @@ -1015,7 +1262,7 @@ def capture_midi_input( msg.type == "note_on" and msg.velocity == 0 ) or msg.type == "note_off": active_pitches.discard(msg.note) - received_messages.append(msg) + received_messages_queue.put(msg) if midi_through is not None: midi_through.send(msg) elif msg.type == "note_on" and msg.velocity > 0: @@ -1023,11 +1270,11 @@ def capture_midi_input( first_on_msg_epoch_ms = get_epoch_time_ms() active_pitches.add(msg.note) - received_messages.append(msg) + received_messages_queue.put(msg) if midi_through is not None: midi_through.send(msg) elif msg.type == "control_change" and msg.control == 64: - received_messages.append(msg) + received_messages_queue.put(msg) elif ( msg.type == "control_change" and msg.control == midi_control_signal @@ -1048,7 +1295,7 @@ def capture_midi_input( channel=midi_capture_channel, time=get_epoch_time_ms() - prev_msg_epoch_time_ms, ) - received_messages.append(msg) + received_messages_queue.put(msg) if midi_through is not None: midi_through.send(msg) @@ -1061,10 +1308,11 @@ def capture_midi_input( channel=midi_capture_channel, time=0, ) - received_messages.append(msg) + received_messages_queue.put(msg) if midi_through is not None: midi_through.send(msg) + # Turn off pedal msg = mido.Message( type="control_change", control=64, @@ -1072,23 +1320,16 @@ def capture_midi_input( channel=midi_capture_channel, time=0, ) - received_messages.append(msg) + received_messages_queue.put(msg) if midi_through is not None: midi_through.send(msg) - # TODO: Need to figure out what the hell this does - is it needed? - # Workaround for the way that file-playback is implemented - delete - msg = mido.Message( - type="control_change", - control=midi_control_signal, - value=0, - channel=midi_capture_channel, - time=0, - ) - if midi_through is not None: - midi_through.send(msg) + received_messages_queue.put(None) # Sentinel + + if results_queue is not None: + results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) - return received_messages, first_on_msg_epoch_ms, num_active_pitches + return first_on_msg_epoch_ms, num_active_pitches def play_midi_file(midi_port: str, midi_path: str): @@ -1220,11 +1461,8 @@ def parse_args(): return argp.parse_args() -# TODO: Test demo on real instrument with real MIDI interface -# TODO: Possibly a problem with endings being truncated # TODO: Need functionality for handing case where we run out of model context - - +# TODO: Make sure channel=9 (drum) case is covered def main(): args = parse_args() logger = get_logger() @@ -1265,11 +1503,13 @@ def main(): keypress_thread.start() midi_control_thread.start() - msgs = [] - captured_msgs, first_on_msg_epoch_ms, num_active_pitches = ( - capture_midi_input( - midi_input_port=midi_input_port, + msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches = ( + capture_and_update_kv( + model=model, + msgs=[], + prev_context=[], control_sentinel=control_sentinel, + midi_input_port=midi_input_port, midi_control_signal=args.midi_control_signal, midi_through_port=args.midi_through, midi_capture_channel=0, @@ -1282,7 +1522,8 @@ def main(): msgs = stream_msgs( model=model, tokenizer=tokenizer, - msgs=msgs + captured_msgs, + msgs=msgs, + prev_context=prev_context, midi_output_port=args.midi_out, first_on_msg_epoch_ms=first_on_msg_epoch_ms, control_sentinel=control_sentinel, @@ -1299,19 +1540,24 @@ 
def main(): else: itt += 1 - captured_msgs, _, num_active_pitches = capture_midi_input( - midi_input_port=midi_input_port, + msgs, prev_context, _, num_active_pitches = capture_and_update_kv( + model=model, + msgs=msgs, + prev_context=prev_context, control_sentinel=control_sentinel, + midi_input_port=midi_input_port, midi_control_signal=args.midi_control_signal, midi_through_port=args.midi_through, midi_capture_channel=itt, first_msg_epoch_time_ms=first_on_msg_epoch_ms, ) + # TODO: There is a bug with the token somewhere? msgs = stream_msgs( model=model, tokenizer=tokenizer, msgs=msgs, + prev_context=prev_context, midi_output_port=args.midi_out, first_on_msg_epoch_ms=first_on_msg_epoch_ms, control_sentinel=control_sentinel, From 4f54e417d505b8f3e444293020abfd7cf8b87976 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 28 May 2025 14:57:39 +0000 Subject: [PATCH 44/72] add off-msg streaming and fix timing alignment --- demo/demo.py | 117 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 69 insertions(+), 48 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index e6aaa87..dbe4518 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -8,6 +8,7 @@ import logging import threading import queue +import copy import torch import mido import torch._inductor.config @@ -28,14 +29,14 @@ torch._inductor.config.fx_graph_cache = True DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 -MAX_SEQ_LEN = 2048 -PREFILL_CHUNK_SIZE = 16 +MAX_SEQ_LEN = 4096 +PREFILL_CHUNK_SIZE = 32 RECALC_DUR_PREFILL_CHUNK_SIZE = 8 RECALC_DUR_BUFFER_MS = 50 # Decode first -BEAM_WIDTH = 5 -TIME_TOK_WEIGHTING = -6 +BEAM_WIDTH = 3 +TIME_TOK_WEIGHTING = -5 # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg @@ -558,7 +559,6 @@ def decode_tokens( ) logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") - # logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") for dur_ms in range(0, MIN_NOTE_LEN_MS, 10): logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") @@ -606,7 +606,7 @@ def generate_tokens( generate_start_s = time.time() priming_seq_len = len(priming_seq) - start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches) + start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches - 1) enc_seq = torch.tensor( [ tokenizer.encode( @@ -772,6 +772,7 @@ def stream_midi( midi_output_port: str, control_sentinel: threading.Event, midi_stream_channel: int, + results_queue: queue.Queue, ): logger = get_logger("STREAM") logger.info( @@ -796,9 +797,6 @@ def stream_midi( logger.debug(f"Received message: {msg}") midi_messages.append(msg) - if control_sentinel.is_set(): - break - midi_messages = sorted( midi_messages, key=lambda msg: ( @@ -807,6 +805,9 @@ def stream_midi( ), ) + if control_sentinel.is_set(): + break + while midi_messages: latency_adjusted_epoch_time_ms = ( get_epoch_time_ms() + HARDWARE_LATENCY_MS @@ -833,7 +834,7 @@ def stream_midi( ) ) pitch_active[msg["pitch"]] = False - logger.info(f"Sent early off for {msg}") + logger.debug(f"Sent early off for {msg}") if ( 0 @@ -873,11 +874,9 @@ def stream_midi( msgs.append(mido_msg_with_time) pitch_active[msg["pitch"]] = msg["vel"] != 0 - logger.info( - f"[{get_epoch_time_ms() - msg["epoch_time_ms"]}ms] Sent message: {msg}" - ) + logger.info(f"Sent message: {mido_msg}") else: - logger.info( + logger.debug( f"Skipping note_off message due to uuid mismatch: {msg}" ) midi_messages.pop(0) @@ -886,7 +885,7 @@ 
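Moving the sort ahead of the sentinel check also pins down ordering at equal timestamps: the (epoch_time_ms, vel) key sorts note-offs (vel == 0) before note-ons scheduled for the same millisecond, so a re-struck pitch always sees off-then-on. A quick illustration of the key with invented values:

msgs = [
    {"pitch": 60, "vel": 80, "epoch_time_ms": 1000},  # retrigger of pitch 60
    {"pitch": 60, "vel": 0, "epoch_time_ms": 1000},   # off for the prior note
    {"pitch": 64, "vel": 70, "epoch_time_ms": 990},
]

for m in sorted(msgs, key=lambda m: (m["epoch_time_ms"], m["vel"])):
    print(m)
# pitch 64 on @990, then pitch 60 off @1000, then pitch 60 on @1000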
def stream_midi( latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] > 100 ): # Message occurs too far in the past - logger.info( + logger.debug( f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]}ms) in the past: {msg}" ) midi_messages.pop(0) @@ -897,7 +896,6 @@ def stream_midi( time.sleep(0.005) logger.info("Processing remaining note_off messages") - logger.debug(midi_messages) remaining_note_off_messages = [ @@ -907,8 +905,7 @@ def stream_midi( and last_pitch_uuid.get(msg["pitch"]) == msg["uuid"] ] - while remaining_note_off_messages: - msg = remaining_note_off_messages.pop(0) + for msg in remaining_note_off_messages: mido_msg = mido.Message( "note_on", note=msg["pitch"], @@ -917,11 +914,34 @@ def stream_midi( time=msg["epoch_time_ms"] - prev_msg_epoch_time_ms, ) prev_msg_epoch_time_ms = msg["epoch_time_ms"] - midi_out.send(mido_msg) - logger.info(f"Sent message: {msg}") msgs.append(mido_msg) - return msgs + results_queue.put(msgs) + + while remaining_note_off_messages: + msg = remaining_note_off_messages.pop(0) + while True: + latency_adjusted_epoch_time_ms = ( + get_epoch_time_ms() + HARDWARE_LATENCY_MS + ) + + if ( + 0 + < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] + <= 50 + ): + mido_msg = mido.Message( + "note_on", + note=msg["pitch"], + velocity=0, + channel=midi_stream_channel, + time=0, # Does not matter as only used for streaming + ) + midi_out.send(mido_msg) + logger.info(f"Sent message: {mido_msg}") + break + else: + time.sleep(0.01) def stream_msgs( @@ -965,7 +985,6 @@ def stream_msgs( "num_preceding_active_pitches": num_preceding_active_pitches, "first_on_msg_epoch_ms": first_on_msg_epoch_ms, }, - daemon=True, ) generate_tokens_thread.start() @@ -980,29 +999,35 @@ def stream_msgs( priming_seq, onset=True ), }, - daemon=True, ) decode_tokens_to_midi_thread.start() - msgs = stream_midi( - midi_messages_queue=midi_messages_queue, - msgs=msgs, - prev_msg_epoch_time_ms=first_on_msg_epoch_ms - + tokenizer.calc_length_ms(priming_seq, onset=False), - midi_output_port=midi_output_port, - control_sentinel=control_sentinel, - midi_stream_channel=midi_stream_channel, + stream_midi_results_queue = queue.Queue() + stream_midi_thread = threading.Thread( + target=stream_midi, + kwargs={ + "midi_messages_queue": midi_messages_queue, + "msgs": msgs, + "prev_msg_epoch_time_ms": first_on_msg_epoch_ms + + tokenizer.calc_length_ms(priming_seq, onset=False), + "midi_output_port": midi_output_port, + "control_sentinel": control_sentinel, + "midi_stream_channel": midi_stream_channel, + "results_queue": stream_midi_results_queue, + }, + daemon=True, ) + stream_midi_thread.start() generate_tokens_thread.join() decode_tokens_to_midi_thread.join() + msgs = stream_midi_results_queue.get() return msgs +# TODO: Channel 9 issues here? 
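The last_pitch_uuid bookkeeping above is the core of the stale-off filter: every on/off pair shares a uuid, and an off is only honoured while its on is still the most recent one for that pitch. A self-contained sketch of the gating rule (event values invented):

from uuid import uuid4

a, b = uuid4(), uuid4()
events = [
    {"pitch": 60, "vel": 90, "uuid": a},  # on: note A
    {"pitch": 60, "vel": 80, "uuid": b},  # on: note B retriggers pitch 60
    {"pitch": 60, "vel": 0, "uuid": a},   # off: stale, belongs to note A
    {"pitch": 60, "vel": 0, "uuid": b},   # off: valid, silences note B
]

last_pitch_uuid = {}
for msg in events:
    if msg["vel"] > 0:
        last_pitch_uuid[msg["pitch"]] = msg["uuid"]
        print("send on:  ", msg)
    elif last_pitch_uuid.get(msg["pitch"]) == msg["uuid"]:
        print("send off: ", msg)
    else:
        print("drop stale off:", msg)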
def convert_msgs_to_midi(msgs: list[mido.Message]): - logger = get_logger("convert") - channel_to_track = { chan: mido.MidiTrack() for chan in list(set([msg.channel for msg in msgs])) @@ -1011,6 +1036,14 @@ def convert_msgs_to_midi(msgs: list[mido.Message]): for msg in msgs: channel_to_track[msg.channel].append(msg) + # Workaround for possibility that track_0 start time != first_on_msg_epoch_ms + for msg in channel_to_track[0]: + if msg.type == "note_on" and msg.velocity > 0: + msg.time = 0 + break + else: + msg.time = 0 + mid = mido.MidiFile(type=1) mid.ticks_per_beat = 500 @@ -1053,9 +1086,6 @@ def chunked_prefill( full: bool = False, ): - # prev_context = 124 (last thing that was prefilled) - # curr_context = 100 (what we are trying to make sure of now) - assert isinstance(curr_context[0], int) assert tokenizer.pad_id not in prev_context assert tokenizer.pad_id not in curr_context @@ -1113,9 +1143,8 @@ def chunked_prefill( else: break - # TODO: This appears as -1 sometimes?? logger.info( - f"KV stored up to idx={len(prev_context)- 1} (curr_context_len={len(curr_context)})" + f"KV stored up to idx={max(0, len(prev_context)- 1)} (curr_context_len={len(curr_context)})" ) return prev_context @@ -1177,10 +1206,6 @@ def capture_and_update_kv( midi_through_port: str | None = None, first_msg_epoch_time_ms: int | None = None, ): - # Start capture_midi_input in threadpool - # Run continiously update kv in main thread - # When control sentinel is seen, return msgs and prev_context continiously update kv (check for None sentinel here) - received_messages_queue = queue.Queue() results_queue = queue.Queue() capture_midi_thread = threading.Thread( @@ -1215,10 +1240,10 @@ def capture_midi_input( control_sentinel: threading.Event, received_messages_queue: queue.Queue, midi_capture_channel: int, + results_queue: queue.Queue, midi_control_signal: int | None = None, midi_through_port: str | None = None, first_msg_epoch_time_ms: int | None = None, - results_queue: queue.Queue | None = None, ): logger = get_logger("CAPTURE") active_pitches = set() @@ -1325,11 +1350,7 @@ def capture_midi_input( midi_through.send(msg) received_messages_queue.put(None) # Sentinel - - if results_queue is not None: - results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) - - return first_on_msg_epoch_ms, num_active_pitches + results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) def play_midi_file(midi_port: str, midi_path: str): From b16394cc57e5ffe34055bc4ab03a6cc41abaa69d Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 28 May 2025 17:42:19 +0000 Subject: [PATCH 45/72] fix early-off logic with dumb hack --- demo/demo.py | 46 ++++++++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index dbe4518..0b75112 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -41,8 +41,8 @@ # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg # HARDWARE: All messages are sent HARDWARE_LATENCY_MS early -MIN_NOTE_DELTA_MS = 50 -MIN_NOTE_LEN_MS = 50 +MIN_NOTE_DELTA_MS = 100 +MIN_NOTE_LEN_MS = 200 HARDWARE_LATENCY_MS = 0 # TODO: @@ -699,6 +699,7 @@ def decode_tokens_to_midi( f"Total time elapsed since first onset: {get_epoch_time_ms() - first_on_msg_epoch_ms}" ) + pitch_to_prev_msg = {} note_buffer = [] num_time_toks = priming_seq_last_onset_ms // 5000 @@ -756,6 +757,24 @@ def decode_tokens_to_midi( "uuid": _uuid, } + # Not thread safe but in theory should be ok? 
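+        # (pitch_to_prev_msg keeps references to dicts that were already put
+        # on the queue to stream_midi, so the in-place epoch_time_ms update
+        # below is visible downstream provided the off has not been sent yet;
+        # that is why going lock-free is tolerable here.)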
+ if pitch_to_prev_msg.get(pitch) is not None and MIN_NOTE_DELTA_MS > 0: + prev_on, prev_off = pitch_to_prev_msg.get(pitch) + new_prev_off = max( + min( + prev_off["epoch_time_ms"], + onset_epoch_ms - MIN_NOTE_DELTA_MS, + ), + prev_on["epoch_time_ms"], + ) + if new_prev_off != prev_off["epoch_time_ms"]: + logger.info( + f"Adjusting prev_off['epoch_time_ms'] -> new_prev_off" + ) + prev_off["epoch_time_ms"] = new_prev_off + + pitch_to_prev_msg[pitch] = {"on": on_msg, "off": off_msg} + midi_messages_queue.put(on_msg) midi_messages_queue.put(off_msg) logger.debug(f"Put message: {on_msg}") @@ -765,6 +784,7 @@ def decode_tokens_to_midi( note_buffer = [] +# TODO: Test the new changes in decode_tokens_to_midi and clean this fn up. def stream_midi( midi_messages_queue: queue.Queue, msgs: list[mido.Message], @@ -814,28 +834,6 @@ def stream_midi( ) msg = midi_messages[0] - # TODO: Fix this (tomorrow) so it works for off-messages too - # (e.g., an off message which happens just before an on message) - if ( - (msg["vel"] > 0) - and ( - msg["epoch_time_ms"] - latency_adjusted_epoch_time_ms - <= MIN_NOTE_DELTA_MS - ) - and pitch_active.get(msg["pitch"], False) - ): - midi_out.send( - mido.Message( - "note_on", - note=msg["pitch"], - velocity=0, - channel=0, - time=0, - ) - ) - pitch_active[msg["pitch"]] = False - logger.debug(f"Sent early off for {msg}") - if ( 0 < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] From fc43f709f946d4bef8046b1f1d538ed52ea18420 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 29 May 2025 14:26:09 +0000 Subject: [PATCH 46/72] fix stream_midi logic --- demo/demo.py | 153 ++++++++++++++++++++++++++------------------------- 1 file changed, 77 insertions(+), 76 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 0b75112..46a3a0b 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -37,6 +37,7 @@ # Decode first BEAM_WIDTH = 3 TIME_TOK_WEIGHTING = -5 +FIRST_ONSET_BUFFER_MS = 25 # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg @@ -45,10 +46,6 @@ MIN_NOTE_LEN_MS = 200 HARDWARE_LATENCY_MS = 0 -# TODO: -# - Add CFG support (eek) - - file_handler = logging.FileHandler("./demo.log", mode="w") file_handler.setLevel(logging.DEBUG) @@ -397,14 +394,14 @@ def decode_first_tokens( ): logger = get_logger("GENERATE") - BUFFER_MS = 50 + HARDWARE_LATENCY_MS - TIME_TOK_ID = tokenizer.tok_to_id[tokenizer.time_tok] + buffer_ms = FIRST_ONSET_BUFFER_MS + HARDWARE_LATENCY_MS + time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] logits = first_token_logits time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms idx = len(priming_seq) + 1 - num_time_toks_required = (time_since_first_onset_ms + BUFFER_MS) // 5000 + num_time_toks_required = (time_since_first_onset_ms + buffer_ms) // 5000 num_time_toks_in_priming_seq = priming_seq.count(tokenizer.time_tok) num_time_toks_to_add = num_time_toks_required - num_time_toks_in_priming_seq @@ -416,7 +413,7 @@ def decode_first_tokens( logits = decode_one( model, idxs=torch.tensor( - [[TIME_TOK_ID]], device="cuda", dtype=torch.int + [[time_tok_id]], device="cuda", dtype=torch.int ), input_pos=torch.tensor( [idx - 1], device="cuda", dtype=torch.int @@ -425,7 +422,7 @@ def decode_first_tokens( logger.info(f"Inserted time_tok at position {idx-1}") num_time_toks_to_add -= 1 - enc_seq[:, idx - 1] = torch.tensor([[TIME_TOK_ID]]).cuda() + enc_seq[:, idx - 1] = torch.tensor([[time_tok_id]]).cuda() idx += 1 logits[:, 
tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") @@ -434,9 +431,9 @@ def decode_first_tokens( log_probs = torch.log_softmax(logits, dim=-1) top_log_probs, top_ids = torch.topk(log_probs, k=BEAM_WIDTH, dim=-1) - if TIME_TOK_ID not in top_ids[0].tolist(): - top_ids[0, -1] = TIME_TOK_ID - top_log_probs[0, -1] = log_probs[0, TIME_TOK_ID] + TIME_TOK_WEIGHTING + if time_tok_id not in top_ids[0].tolist(): + top_ids[0, -1] = time_tok_id + top_log_probs[0, -1] = log_probs[0, time_tok_id] + TIME_TOK_WEIGHTING top_toks = [tokenizer.id_to_tok[id] for id in top_ids[0].tolist()] @@ -448,11 +445,11 @@ def decode_first_tokens( masked_onset_ids = [ tokenizer.tok_to_id[tok] for tok in tokenizer.onset_tokens - if tok[1] < ((time_since_first_onset_ms + BUFFER_MS) % 5000) + if tok[1] < ((time_since_first_onset_ms + buffer_ms) % 5000) ] logger.debug( - f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms + BUFFER_MS})" + f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms + buffer_ms})" ) best_score = float("-inf") @@ -475,8 +472,8 @@ def decode_first_tokens( next_log_probs = torch.log_softmax(next_logits, dim=-1) next_log_probs[:, masked_onset_ids] = float("-inf") - if tok_id == TIME_TOK_ID: - next_log_probs[:, TIME_TOK_ID] = float("-inf") + if tok_id == time_tok_id: + next_log_probs[:, time_tok_id] = float("-inf") next_tok_log_prob, next_tok_id = torch.max(next_log_probs, dim=-1) next_tok = tokenizer.id_to_tok[next_tok_id.item()] @@ -683,7 +680,7 @@ def generate_tokens( def decode_tokens_to_midi( generated_tokens_queue: queue.Queue, - midi_messages_queue: queue.Queue, + outbound_midi_msg_queue: queue.Queue, tokenizer: AbsTokenizer, first_on_msg_epoch_ms: int, priming_seq_last_onset_ms: int, @@ -714,7 +711,7 @@ def decode_tokens_to_midi( "epoch_time_ms": offset_epoch_ms + 250, # Last note offset "uuid": _uuid, } # pitch=-1 denotes end_msg - midi_messages_queue.put(end_msg) + outbound_midi_msg_queue.put(end_msg) logger.info(f"Seen exit signal: EOS token") logger.debug(f"Put message: {end_msg}") return @@ -760,23 +757,22 @@ def decode_tokens_to_midi( # Not thread safe but in theory should be ok? if pitch_to_prev_msg.get(pitch) is not None and MIN_NOTE_DELTA_MS > 0: prev_on, prev_off = pitch_to_prev_msg.get(pitch) - new_prev_off = max( + adj_off_time = max( min( prev_off["epoch_time_ms"], onset_epoch_ms - MIN_NOTE_DELTA_MS, ), prev_on["epoch_time_ms"], ) - if new_prev_off != prev_off["epoch_time_ms"]: - logger.info( - f"Adjusting prev_off['epoch_time_ms'] -> new_prev_off" - ) - prev_off["epoch_time_ms"] = new_prev_off + if adj_off_time != prev_off["epoch_time_ms"]: + logger.debug(f"Adjusting {prev_off}: t={adj_off_time}") + prev_off["epoch_time_ms"] = adj_off_time + prev_off["adjusted"] = True - pitch_to_prev_msg[pitch] = {"on": on_msg, "off": off_msg} + pitch_to_prev_msg[pitch] = [on_msg, off_msg] - midi_messages_queue.put(on_msg) - midi_messages_queue.put(off_msg) + outbound_midi_msg_queue.put(on_msg) + outbound_midi_msg_queue.put(off_msg) logger.debug(f"Put message: {on_msg}") logger.debug(f"Put message: {off_msg}") logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") @@ -786,7 +782,7 @@ def decode_tokens_to_midi( # TODO: Test the new changes in decode_tokens_to_midi and clean this fn up. 
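The adj_off_time expression above is easiest to read as a clamp: shorten the previous note so it ends at least MIN_NOTE_DELTA_MS before the new onset on the same pitch, but never before its own note-on. A standalone restatement with worked cases (adjust_prev_off is a hypothetical helper mirroring the max(min(...)) expression, not a function in the demo):

MIN_NOTE_DELTA_MS = 100  # Mirrors the module constant

def adjust_prev_off(prev_on_ms: int, prev_off_ms: int, next_onset_ms: int) -> int:
    # Pull the previous off back to at least MIN_NOTE_DELTA_MS before the
    # next onset, but never before the previous note's own onset.
    return max(min(prev_off_ms, next_onset_ms - MIN_NOTE_DELTA_MS), prev_on_ms)

assert adjust_prev_off(0, 450, 500) == 400    # shortened to clear the retrigger
assert adjust_prev_off(0, 300, 500) == 300    # already clear: left untouched
assert adjust_prev_off(380, 450, 400) == 380  # clamped to its own note-on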
def stream_midi( - midi_messages_queue: queue.Queue, + inbound_midi_msg_queue: queue.Queue, msgs: list[mido.Message], prev_msg_epoch_time_ms: float, midi_output_port: str, @@ -801,24 +797,25 @@ def stream_midi( logger.info( f"Applying hardware latency adjustment: {HARDWARE_LATENCY_MS}ms" ) + MAX_DELAY_MS = 50 - last_pitch_uuid = {} - pitch_active = {} - midi_messages = [] + active_pitch_uuid = {} + is_pitch_active = {} + midi_msgs = [] with mido.open_output(midi_output_port) as midi_out: while not control_sentinel.is_set(): while True: try: - msg = midi_messages_queue.get_nowait() + msg = inbound_midi_msg_queue.get_nowait() except queue.Empty: break else: logger.debug(f"Received message: {msg}") - midi_messages.append(msg) + midi_msgs.append(msg) - midi_messages = sorted( - midi_messages, + midi_msgs = sorted( + midi_msgs, key=lambda msg: ( msg["epoch_time_ms"], msg["vel"], @@ -828,16 +825,16 @@ def stream_midi( if control_sentinel.is_set(): break - while midi_messages: + while midi_msgs: latency_adjusted_epoch_time_ms = ( get_epoch_time_ms() + HARDWARE_LATENCY_MS ) - msg = midi_messages[0] + msg = midi_msgs[0] if ( 0 < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] - <= 50 + <= MAX_DELAY_MS ): if msg["pitch"] == -1: # End msg control_sentinel.set() @@ -852,57 +849,59 @@ def stream_midi( ) if msg["vel"] > 0: - last_pitch_uuid[msg["pitch"]] = msg["uuid"] - should_send = True + active_pitch_uuid[msg["pitch"]] = msg["uuid"] + should_send_midi_out = True + should_append_to_msgs = True + elif msg.get("adjusted", False) is True: + should_send_midi_out = True + should_append_to_msgs = False else: - # Only send note_off if it matches the last note_on UUID - should_send = ( - last_pitch_uuid.get(msg["pitch"]) == msg["uuid"] + should_send_midi_out = ( + active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] ) + should_append_to_msgs = should_send_midi_out - if should_send is True: + if should_send_midi_out is True: + midi_out.send(mido_msg) + is_pitch_active[msg["pitch"]] = msg["vel"] != 0 + logger.info(f"Sent message: {mido_msg}") + if should_append_to_msgs is True: mido_msg_with_time = copy.deepcopy(mido_msg) mido_msg_with_time.channel = midi_stream_channel mido_msg_with_time.time = max( 0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms ) prev_msg_epoch_time_ms = msg["epoch_time_ms"] - - midi_out.send(mido_msg) msgs.append(mido_msg_with_time) - pitch_active[msg["pitch"]] = msg["vel"] != 0 - logger.info(f"Sent message: {mido_msg}") - else: - logger.debug( - f"Skipping note_off message due to uuid mismatch: {msg}" - ) - midi_messages.pop(0) + midi_msgs.pop(0) elif ( - latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] > 100 + latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] + > MAX_DELAY_MS ): # Message occurs too far in the past logger.debug( f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]}ms) in the past: {msg}" ) - midi_messages.pop(0) + midi_msgs.pop(0) else: # Message occurs in the future break time.sleep(0.005) - logger.info("Processing remaining note_off messages") - logger.debug(midi_messages) - remaining_note_off_messages = [ msg - for msg in midi_messages + for msg in midi_msgs if msg["vel"] == 0 - and last_pitch_uuid.get(msg["pitch"]) == msg["uuid"] + and active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] ] + logger.info("Processing remaining note_off messages") + for __msg in remaining_note_off_messages: + logger.debug(remaining_note_off_messages) + for msg in remaining_note_off_messages: mido_msg = mido.Message( "note_on", @@ 
-923,11 +922,7 @@ def stream_midi( get_epoch_time_ms() + HARDWARE_LATENCY_MS ) - if ( - 0 - < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] - <= 50 - ): + if 0 < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]: mido_msg = mido.Message( "note_on", note=msg["pitch"], @@ -961,8 +956,6 @@ def stream_msgs( priming_seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] - # if tokenizer.dim_tok in priming_seq: - # priming_seq.remove(tokenizer.dim_tok) if is_ending is True: priming_seq.append(tokenizer.dim_tok) @@ -990,7 +983,7 @@ def stream_msgs( target=decode_tokens_to_midi, kwargs={ "generated_tokens_queue": generated_tokens_queue, - "midi_messages_queue": midi_messages_queue, + "outbound_midi_msg_queue": midi_messages_queue, "tokenizer": tokenizer, "first_on_msg_epoch_ms": first_on_msg_epoch_ms, "priming_seq_last_onset_ms": tokenizer.calc_length_ms( @@ -1000,14 +993,20 @@ def stream_msgs( ) decode_tokens_to_midi_thread.start() + prev_ms_epoch_time_ms = ( + first_on_msg_epoch_ms + + tokenizer.calc_length_ms(priming_seq, onset=False) + if is_ending is False + else first_on_msg_epoch_ms + ) + stream_midi_results_queue = queue.Queue() stream_midi_thread = threading.Thread( target=stream_midi, kwargs={ - "midi_messages_queue": midi_messages_queue, + "inbound_midi_msg_queue": midi_messages_queue, "msgs": msgs, - "prev_msg_epoch_time_ms": first_on_msg_epoch_ms - + tokenizer.calc_length_ms(priming_seq, onset=False), + "prev_msg_epoch_time_ms": prev_ms_epoch_time_ms, "midi_output_port": midi_output_port, "control_sentinel": control_sentinel, "midi_stream_channel": midi_stream_channel, @@ -1021,6 +1020,9 @@ def stream_msgs( decode_tokens_to_midi_thread.join() msgs = stream_midi_results_queue.get() + if is_ending is True: + stream_midi_thread.join() + return msgs @@ -1094,10 +1096,10 @@ def chunked_prefill( prev_context, curr_context, logger=logger ) num_prefill_toks = len(prefill_toks) - logger.info(f"Tokens to prefill: {len(prefill_toks)}") + logger.debug(f"Tokens to prefill: {len(prefill_toks)}") if num_prefill_toks > PREFILL_CHUNK_SIZE: - logger.info( + logger.debug( f"Prefilling {PREFILL_CHUNK_SIZE} tokens from idx={prefill_idx}" ) @@ -1118,7 +1120,7 @@ def chunked_prefill( prev_context = curr_context[: prefill_idx + PREFILL_CHUNK_SIZE] elif num_prefill_toks > 0 and full is True: - logger.info( + logger.debug( f"Prefilling (force) {num_prefill_toks} tokens from idx={prefill_idx}" ) prefill_toks += (PREFILL_CHUNK_SIZE - len(prefill_toks)) * [ @@ -1553,11 +1555,10 @@ def main(): is_ending=False, ) + itt += 1 control_sentinel.clear() if end_sentinel.is_set(): break - else: - itt += 1 msgs, prev_context, _, num_active_pitches = capture_and_update_kv( model=model, From ba835ff7291b6c88b8afd53bded6ff97da056f6d Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 29 May 2025 15:29:45 +0000 Subject: [PATCH 47/72] port demo to mlx --- aria/inference/mlx.py | 350 -------- aria/inference/model_mlx.py | 1 - demo/demo_mlx.py | 1554 +++++++++++++++++++++++++++++++++++ 3 files changed, 1554 insertions(+), 351 deletions(-) delete mode 100644 aria/inference/mlx.py create mode 100644 demo/demo_mlx.py diff --git a/aria/inference/mlx.py b/aria/inference/mlx.py deleted file mode 100644 index 7ec1b03..0000000 --- a/aria/inference/mlx.py +++ /dev/null @@ -1,350 +0,0 @@ -"""Inference implementation for mlx backend""" - -from dataclasses import dataclass -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn 
as nn - -from .base import ( - BaseModelArgs, - create_attention_mask, - scaled_dot_product_attention, -) -from .cache import ChunkedKVCache, KVCache -from .rope_utils import initialize_rope -from .switch_layers import SwitchGLU - - -@dataclass -class TextArgs(BaseModelArgs): - attention_bias: bool - attention_chunk_size: int - head_dim: int - hidden_act: str - hidden_size: int - interleave_moe_layer_step: int - intermediate_size: int - intermediate_size_mlp: int - max_position_embeddings: int - model_type: str - num_attention_heads: int - num_experts_per_tok: int - num_hidden_layers: int - num_key_value_heads: int - num_local_experts: int - rms_norm_eps: float - rope_scaling: Any - rope_theta: float - use_qk_norm: bool - vocab_size: int - attn_temperature_tuning: int = 4 - floor_scale: int = 8192 - attn_scale: float = 0.1 - - -@dataclass -class ModelArgs(BaseModelArgs): - text_config: Union[TextArgs, dict] - model_type: str - - def __post_init__(self): - self.text_config = TextArgs.from_dict(self.text_config) - - -class Attention(nn.Module): - def __init__(self, args: TextArgs, layer_idx: int): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.use_rope = int( - (layer_idx + 1) % 4 != 0 - ) # rope unused for dense layers - self.attn_temperature_tuning = args.attn_temperature_tuning - self.floor_scale = args.floor_scale - self.attn_scale = args.attn_scale - - self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads - - self.scale = head_dim**-0.5 - if hasattr(args, "attention_bias"): - attention_bias = args.attention_bias - else: - attention_bias = False - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) - - self.use_qk_norm = args.use_qk_norm and self.use_rope - - if self.use_rope: - self.rope = initialize_rope( - head_dim, - args.rope_theta, - traditional=True, - scaling_config=args.rope_scaling, - max_position_embeddings=args.max_position_embeddings, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - offset = cache.offset - else: - offset = 0 - - if self.use_rope: - queries = self.rope(queries, offset=offset) - keys = self.rope(keys, offset=offset) - - if self.use_qk_norm: - queries = mx.fast.rms_norm(queries, weight=None, eps=1e-6) - keys = mx.fast.rms_norm(keys, weight=None, eps=1e-6) - - if self.attn_temperature_tuning and not self.use_rope: - attn_scales = ( - mx.log( - mx.floor( - mx.arange(offset + 1, offset + L + 1) / self.floor_scale - ) - + 1.0 - ) - * self.attn_scale - + 1.0 - ) - attn_scales = attn_scales[:, None] - queries = (queries * attn_scales).astype(queries.dtype) - - if cache is not None: - keys, values = cache.update_and_fetch(keys, values) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 
1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs, intermediate_size: int = None): - super().__init__() - - dim = args.hidden_size - hidden_dim = intermediate_size or args.intermediate_size - - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class MoE(nn.Module): - def __init__(self, args): - super().__init__() - self.top_k = args.num_experts_per_tok - self.num_experts = args.num_local_experts - self.experts = SwitchGLU( - args.hidden_size, args.intermediate_size, self.num_experts - ) - self.router = nn.Linear( - args.hidden_size, args.num_local_experts, bias=False - ) - self.shared_expert = MLP(args) - - def __call__(self, x) -> mx.array: - logits = self.router(x) - k = self.top_k - indices = mx.argpartition(-logits, kth=k - 1, axis=-1)[..., :k] - scores = mx.take_along_axis(logits, indices, axis=-1) - scores = mx.sigmoid(scores.astype(mx.float32)).astype(x.dtype) - - out = self.experts(x * scores, indices).squeeze(2) - return out + self.shared_expert(x) - - -class TransformerBlock(nn.Module): - def __init__(self, args: TextArgs, layer_idx: int): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args, layer_idx) - self.is_moe_layer = (layer_idx % args.interleave_moe_layer_step) == ( - args.interleave_moe_layer_step - 1 - ) - if self.is_moe_layer: - self.feed_forward = MoE(args) - else: - self.feed_forward = MLP(args, args.intermediate_size_mlp) - - self.input_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.feed_forward(self.post_attention_layernorm(h)) - out = h + r - return out - - -class LlamaModel(nn.Module): - def __init__(self, args: TextArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args, i) for i in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.attention_chunk_size = args.attention_chunk_size - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if cache is not None: - for idx, c in enumerate(cache): - if (idx + 1) % 4 != 0: - c.maybe_trim_front() - start = cache[0].start_position - offset = cache[0].offset - else: - start = 0 - offset = 0 - end = offset + h.shape[1] - linds = mx.arange(start, end) - rinds = mx.arange(offset, end)[:, None] - block_pos = mx.abs( - (linds // self.attention_chunk_size) - - (rinds // self.attention_chunk_size) - ) - token_pos = linds <= rinds - chunk_mask = (block_pos == 0) & token_pos - - if mask is None: - mask = create_attention_mask(h, cache) - else: - chunk_mask &= mask - - if cache is None: - cache = [None] * len(self.layers) - - for idx, (layer, c) in enumerate(zip(self.layers, 
cache)): - use_chunked_attention = (idx + 1) % 4 != 0 - if use_chunked_attention: - local_mask = chunk_mask - else: - local_mask = mask - h = layer(h, local_mask, cache=c) - - return self.norm(h) - - -class LanguageModel(nn.Module): - def __init__(self, args: TextArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = LlamaModel(self.args) - self.lm_head = nn.Linear( - self.args.hidden_size, self.args.vocab_size, bias=False - ) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - return self.lm_head(out) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.language_model = LanguageModel(args.text_config) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - return self.language_model(inputs, mask, cache) - - def sanitize(self, weights): - def to_remove(k): - return "vision_model" in k or "multi_modal_projector" in k - - # Remove vision weights - weights = {k: v for k, v in weights.items() if not to_remove(k)} - - # Rename expert weights for SwitchGLU - for l in range(self.args.text_config.num_hidden_layers): - prefix = f"language_model.model.layers.{l}.feed_forward.experts" - if f"{prefix}.gate_up_proj" in weights: - v = weights.pop(f"{prefix}.gate_up_proj") - gate_k = f"{prefix}.gate_proj.weight" - up_k = f"{prefix}.up_proj.weight" - gate_proj, up_proj = mx.split(v, 2, axis=-1) - weights[gate_k] = mx.swapaxes(gate_proj, 1, 2) - weights[up_k] = mx.swapaxes(up_proj, 1, 2) - if f"{prefix}.down_proj" in weights: - down_proj = weights.pop(f"{prefix}.down_proj") - weights[f"{prefix}.down_proj.weight"] = mx.swapaxes( - down_proj, 1, 2 - ) - return weights - - @property - def layers(self): - return self.language_model.model.layers - - def make_cache(self): - chunk_size = self.args.text_config.attention_chunk_size - caches = [] - for i in range(len(self.layers)): - if (i + 1) % 4 != 0: - caches.append(ChunkedKVCache(chunk_size)) - else: - caches.append(KVCache()) - return caches diff --git a/aria/inference/model_mlx.py b/aria/inference/model_mlx.py index 5895351..22a211e 100644 --- a/aria/inference/model_mlx.py +++ b/aria/inference/model_mlx.py @@ -1,6 +1,5 @@ """Inference implementation for mlx backend""" -from dataclasses import dataclass from aria.model import ModelConfig import mlx.core as mx diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py new file mode 100644 index 0000000..0ac5d1f --- /dev/null +++ b/demo/demo_mlx.py @@ -0,0 +1,1554 @@ +#!/usr/bin/env python3 + +import argparse +import os +import time +import uuid +import copy +import logging +import threading +import queue +import copy +import mido +import torch + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from contextlib import ExitStack + +from ariautils.midi import MidiDict, midi_to_dict +from ariautils.tokenizer import AbsTokenizer +from aria.inference.model_mlx import TransformerLM +from aria.model import ModelConfig +from aria.config import load_model_config + +DTYPE = mx.float32 +MAX_SEQ_LEN = 2048 +PREFILL_CHUNK_SIZE = 32 +RECALC_DUR_PREFILL_CHUNK_SIZE = 8 +RECALC_DUR_BUFFER_MS = 50 + +# Decode first +BEAM_WIDTH = 3 +TIME_TOK_WEIGHTING = -5 +FIRST_ONSET_BUFFER_MS = 25 + +# HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS +# HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg +# HARDWARE: All messages are sent 
HARDWARE_LATENCY_MS early +MIN_NOTE_DELTA_MS = 100 +MIN_NOTE_LEN_MS = 200 +HARDWARE_LATENCY_MS = 0 + +file_handler = logging.FileHandler("./demo.log", mode="w") +file_handler.setLevel(logging.DEBUG) + + +def get_logger(name: str | None = None) -> logging.Logger: + logger = logging.getLogger(name) + if not logger.handlers: + logger.propagate = False + logger.setLevel(logging.DEBUG) + + class MillisecondFormatter(logging.Formatter): + def formatTime(self, record, datefmt=None): + created_ms = int(record.created * 1000) + return str(created_ms) + + if name is not None: + formatter = MillisecondFormatter( + "%(asctime)s: [%(levelname)s] [%(name)s] %(message)s" + ) + else: + formatter = MillisecondFormatter( + "%(asctime)s: [%(levelname)s] %(message)s" + ) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +def get_epoch_time_ms() -> int: + return round(time.time() * 1000) + + +def prefill( + model: TransformerLM, + idxs: mx.array, + input_pos: mx.array, + pad_idxs: mx.array | None = None, +) -> mx.array: + logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + ) + + return logits + + +def decode_one( + model: TransformerLM, + idxs: mx.array, + input_pos: mx.array, + pad_idxs: mx.array | None = None, +) -> mx.array: + assert input_pos.shape[-1] == 1 + + logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + )[:, -1] + + return logits + + +def sample_min_p(probs: mx.array, p_base: float): + """See - https://arxiv.org/pdf/2407.01082""" + + p_max = mx.max(probs, axis=-1, keepdims=True) + p_scaled = p_base * p_max + mask = probs >= p_scaled + + masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) + sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) + masked_probs_normalized = masked_probs / sum_masked_probs + + # Dumb workaround for mlx not having categorical probs sampler + next_token = mx.array( + torch.multinomial( + torch.from_numpy(np.array(masked_probs_normalized)), num_samples=1 + ), + dtype=mx.int32, + ) + + return next_token + + +def _compile_prefill( + model: TransformerLM, + logger: logging.Logger, + chunk_size: int, +): + assert chunk_size > 1 + + compile_start_time_s = time.time() + logger.info(f"Compiling prefill (chunk_size={chunk_size})") + res = prefill( + model, + idxs=mx.ones([1, chunk_size], dtype=mx.int32), + input_pos=mx.arange(0, chunk_size, dtype=mx.int32), + ) + mx.eval(res) + logger.info( + f"Finished compiling - took {time.time() - compile_start_time_s:.4f} seconds" + ) + + for _ in range(5): + res = prefill( + model, + idxs=mx.ones([1, chunk_size], dtype=mx.int32), + input_pos=mx.arange(0, chunk_size, dtype=mx.int32), + ) + mx.eval(res) + + bench_start_time_s = time.time() + prefill( + model, + idxs=mx.ones([1, chunk_size], dtype=mx.int32), + input_pos=mx.arange(0, chunk_size, dtype=mx.int32), + ) + mx.eval(res) + bench_end_time_s = time.time() + bench_ms = 1e3 * (bench_end_time_s - bench_start_time_s) + bench_its = 1000 / bench_ms + logger.info( + f"Compiled prefill benchmark: {bench_ms:.2f} ms/it ({bench_its:.2f} it/s)" + ) + + return model + + +def _compile_decode_one( + model: TransformerLM, + logger: logging.Logger, +): + # Don't need to explicitly compile with mlx, instead we are just precalculating + # the computation graphs for different shapes + compile_start_time_s = time.time() + res = decode_one( + model, + idxs=mx.array([[0]], 
dtype=mx.int32), + input_pos=mx.array([0], dtype=mx.int32), + ) + mx.eval(res) + logger.info( + f"Finished compiling - took {time.time() - compile_start_time_s:.4f} seconds" + ) + + for _ in range(5): + res = decode_one( + model, + idxs=mx.array([[0]], dtype=mx.int32), + input_pos=mx.array([0], dtype=mx.int32), + ) + mx.eval(res) + + bench_start_time_s = time.time() + decode_one( + model, + idxs=mx.array([[0]], dtype=mx.int32), + input_pos=mx.array([0], dtype=mx.int32), + ) + mx.eval(res) + bench_end_time_s = time.time() + bench_ms = 1e3 * (bench_end_time_s - bench_start_time_s) + bench_its = 1000 / bench_ms + logger.info( + f"Compiled decode_one benchmark: {bench_ms:.2f} ms/it ({bench_its:.2f} it/s)" + ) + + return model + + +def compile_model(model: TransformerLM): + logger = get_logger() + + model.eval() + model.setup_cache( + batch_size=1, + max_seq_len=MAX_SEQ_LEN, + dtype=DTYPE, + ) + + model = _compile_decode_one(model=model, logger=logger) + for chunk_size in list({PREFILL_CHUNK_SIZE, RECALC_DUR_PREFILL_CHUNK_SIZE}): + model = _compile_prefill( + model=model, logger=logger, chunk_size=chunk_size + ) + + return model + + +def load_model( + checkpoint_path: str, +): + logger = get_logger() + + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config("medium-emb")) + model_config.set_vocab_size(tokenizer.vocab_size) + + logging.info(f"Loading model weights from {checkpoint_path}") + + init_start_time_s = time.time() + model = TransformerLM(model_config) + model.load_weights(checkpoint_path) + nn.quantize(model.model, group_size=128, bits=8) + model.eval() + + logger.info( + f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds" + ) + + return model + + +def _first_bad_dur_index( + tokenizer: AbsTokenizer, + priming_seq: list, + pred_ids: list, + chunk_start: int, + last_offset_ms: int, + logger: logging.Logger, +): + num_time_toks = priming_seq[:chunk_start].count(tokenizer.time_tok) + local_onset_ms = tokenizer.calc_length_ms( + priming_seq[:chunk_start], onset=True + ) + logger.debug(f"Starting from local onset {local_onset_ms}") + + for pos, tok_id in enumerate( + pred_ids[: len(priming_seq) - chunk_start], start=chunk_start + ): + prim_tok = priming_seq[pos] # Should never error? 
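+        # Speculative-decoding-style check: walk the primed tokens alongside
+        # the model's one-step predictions; only a duration token that the
+        # model would lengthen, on a note still sounding near the final
+        # offset, is worth resampling.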
+ pred_tok = tokenizer.id_to_tok[tok_id] + logger.debug(f"prim={prim_tok}, pred={pred_tok}") + + if isinstance(prim_tok, tuple) and prim_tok[0] == "onset": + local_onset_ms = num_time_toks * 5000 + prim_tok[1] + elif prim_tok == tokenizer.time_tok: + num_time_toks += 1 + elif isinstance(prim_tok, tuple) and prim_tok[0] == "dur": + dur_true = prim_tok[1] + dur_pred = pred_tok[1] + if dur_pred > dur_true and ( + local_onset_ms + dur_true + > last_offset_ms - RECALC_DUR_BUFFER_MS + ): + logger.info( + f"Found token to resample at {pos}: {prim_tok} -> {pred_tok}" + ) + return pos + + return None + + +def recalc_dur_tokens_chunked( + model: TransformerLM, + priming_seq: list, + enc_seq: mx.array, + tokenizer: AbsTokenizer, + start_idx: int, +): + """Speculative-decoding inspired duration re-calculation""" + assert start_idx > 0 + logger = get_logger("GENERATE") + + priming_len = len(priming_seq) + last_offset = tokenizer.calc_length_ms(priming_seq) + + idx = start_idx + while idx <= priming_len: + end_idx = idx + RECALC_DUR_PREFILL_CHUNK_SIZE + + window_ids = mx.array( + enc_seq[:, idx - 1 : end_idx - 1].tolist(), + dtype=mx.int32, + ) + window_pos = mx.arange(idx - 1, end_idx - 1, dtype=mx.int32) + + logger.info( + f"Recalculating chunked durations for positions: {idx-1} - {end_idx-2}" + ) + logger.debug(f"Inserted: {tokenizer.decode(window_ids[0].tolist())}") + logger.debug(f"Positions: {window_pos.tolist()}") + + logits = prefill(model, idxs=window_ids, input_pos=window_pos) + pred_ids = mx.argmax(logits, axis=-1).flatten().tolist() + + bad_pos = _first_bad_dur_index( + tokenizer=tokenizer, + priming_seq=priming_seq, + pred_ids=pred_ids, + chunk_start=idx, + last_offset_ms=last_offset, + logger=logger, + ) + + if bad_pos is None: + idx = end_idx + else: + new_id = pred_ids[bad_pos - idx] + enc_seq[0, bad_pos] = new_id + priming_seq[bad_pos] = tokenizer.id_to_tok[new_id] + idx = bad_pos + + next_logits = logits[:, priming_len - idx] + + return enc_seq, priming_seq, next_logits + + +# TODO: This is now the latency bottleneck. 
+# Ideas for reducing it: +# - Get rid of the manual time_tok insert stuff, instead just mask logits +# for all invalid tokens, this should force the model to sample a time tok +# if there aren't any other valid options +def decode_first_tokens( + model: TransformerLM, + first_token_logits: mx.array, + enc_seq: mx.array, + priming_seq: list, + tokenizer: AbsTokenizer, + generated_tokens_queue: queue.Queue, + first_on_msg_epoch_ms: int, +): + logger = get_logger("GENERATE") + + buffer_ms = FIRST_ONSET_BUFFER_MS + HARDWARE_LATENCY_MS + time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] + + logits = first_token_logits + time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms + idx = len(priming_seq) + 1 + + num_time_toks_required = (time_since_first_onset_ms + buffer_ms) // 5000 + num_time_toks_in_priming_seq = priming_seq.count(tokenizer.time_tok) + num_time_toks_to_add = num_time_toks_required - num_time_toks_in_priming_seq + + logger.info(f"Time since first onset: {time_since_first_onset_ms}ms") + + while num_time_toks_to_add > 0: + generated_tokens_queue.put(tokenizer.time_tok) + logits = decode_one( + model, + idxs=mx.array([[time_tok_id]], dtype=mx.int32), + input_pos=mx.array([idx - 1], dtype=mx.int32), + ) + + logger.info(f"Inserted time_tok at position {idx-1}") + num_time_toks_to_add -= 1 + enc_seq[:, idx - 1] = mx.array([[time_tok_id]], dtype=mx.int32) + idx += 1 + + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + + log_probs = nn.log_softmax(logits, axis=-1) + top_log_probs = mx.topk(log_probs, k=BEAM_WIDTH, axis=-1) + top_ids = mx.argsort(log_probs, axis=-1)[..., -BEAM_WIDTH:] + + if time_tok_id not in top_ids[0].tolist(): + top_ids[0, -1] = time_tok_id + top_log_probs[0, -1] = log_probs[0, time_tok_id] + TIME_TOK_WEIGHTING + + top_toks = [tokenizer.id_to_tok[id] for id in top_ids[0].tolist()] + + logger.debug(f"Calculated top {BEAM_WIDTH} tokens={top_toks}") + logger.debug( + f"Calculated top {BEAM_WIDTH} scores={top_log_probs[0].tolist()}" + ) + + masked_onset_ids = [ + tokenizer.tok_to_id[tok] + for tok in tokenizer.onset_tokens + if tok[1] < ((time_since_first_onset_ms + buffer_ms) % 5000) + ] + + logger.debug( + f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms + buffer_ms})" + ) + + best_score = float("-inf") + for i in range(BEAM_WIDTH): + tok = top_toks[i] + tok_id = top_ids[0, i].item() + tok_log_prob = top_log_probs[0, i] + + next_logits = decode_one( + model, + idxs=mx.array([[tok_id]], dtype=mx.int32), + input_pos=mx.array([idx - 1], dtype=mx.int32), + ) + logger.debug( + f"Sampled logits for positions {idx} by inserting {tok} at position {idx-1}" + ) + + # Is float("-inf") masking ok in mlx? 
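+        # The masked log-probs only feed the mx.max/mx.argmax below and are
+        # never renormalized, so assigning float("-inf") should behave the
+        # same as the torch version of this function.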
+ next_log_probs = nn.log_softmax(next_logits, axis=-1) + next_log_probs[:, masked_onset_ids] = float("-inf") + if tok_id == time_tok_id: + next_log_probs[:, time_tok_id] = float("-inf") + + next_tok_log_prob = mx.max(next_log_probs, axis=-1) + next_tok_id = mx.argmax(next_log_probs, axis=-1) + next_tok = tokenizer.id_to_tok[next_tok_id.item()] + score = tok_log_prob + next_tok_log_prob + + logger.info( + f"Calculated tuple {(tok, next_tok)} with scores {(tok_log_prob.item(), next_tok_log_prob.item())} (combined={score.item()})" + ) + + if score > best_score: + best_tok_id_1, best_tok_id_2 = tok_id, next_tok_id.item() + best_tok_1, best_tok_2 = ( + tokenizer.id_to_tok[best_tok_id_1], + tokenizer.id_to_tok[best_tok_id_2], + ) + best_score = score + + logger.info( + f"Chose tuple {(best_tok_1, best_tok_2)} with score {best_score.item()}" + ) + + enc_seq[:, idx - 1] = best_tok_id_1 + enc_seq[:, idx] = best_tok_id_2 + generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_1]) + generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_2]) + + decode_one( + model, + idxs=mx.array([[best_tok_id_1]], dtype=mx.int32), + input_pos=mx.array([idx - 1], dtype=mx.int32), + ) + + logger.info( + f"Updated KV-Cache by re-inserting {best_tok_1} at position {idx-1}" + ) + logger.info( + f"Inserted {best_tok_2} at position {idx} without updating KV-Cache" + ) + + return enc_seq, idx + 1 + + +def decode_tokens( + model: TransformerLM, + enc_seq: mx.array, + tokenizer: AbsTokenizer, + control_sentinel: threading.Event, + generated_tokens_queue: queue.Queue, + idx: int, + temperature: float, + min_p: float, +): + logger = get_logger("GENERATE") + logger.info( + f"Using sampling parameters: temperature={temperature}, min_p={min_p}" + ) + + while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: + decode_one_start_time_s = time.time() + prev_tok_id = enc_seq[0, idx - 1] + prev_tok = tokenizer.id_to_tok[prev_tok_id.item()] + + logits = decode_one( + model, + idxs=mx.array([[prev_tok_id]], dtype=mx.int32), + input_pos=mx.array([idx - 1], dtype=mx.int32), + ) + + logger.debug( + f"Sampled logits for positions {idx} by inserting {prev_tok} at position {idx-1}" + ) + + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + for dur_ms in range(0, MIN_NOTE_LEN_MS, 10): + logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") + + if temperature > 0.0: + probs = mx.softmax(logits / temperature, axis=-1) + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = mx.argmax(logits, axis=-1).flatten() + + enc_seq[:, idx] = next_token_ids + next_token = tokenizer.id_to_tok[next_token_ids[0].item()] + logger.debug( + f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" + ) + + if next_token == tokenizer.eos_tok: + logger.info("EOS token produced, exiting...") + generated_tokens_queue.put(next_token) + return + else: + generated_tokens_queue.put(next_token) + idx += 1 + + while not control_sentinel.is_set(): + time.sleep(0.1) + + logger.info("Seen exit signal") + generated_tokens_queue.put(None) + + +def generate_tokens( + priming_seq: list, + tokenizer: AbsTokenizer, + model: TransformerLM, + prev_context: list[int], + control_sentinel: threading.Event, + generated_tokens_queue: queue.Queue, + num_preceding_active_pitches: int, + first_on_msg_epoch_ms: int, + temperature: float = 0.97, + min_p: float = 0.03, +): + logger = get_logger("GENERATE") + + generate_start_s = time.time() + priming_seq_len = len(priming_seq) + start_idx = max(2, 
priming_seq_len - 4 * num_preceding_active_pitches - 1) + enc_seq = mx.array( + [ + tokenizer.encode( + priming_seq + + [tokenizer.pad_tok] * (MAX_SEQ_LEN - len(priming_seq)) + ) + ], + dtype=mx.int32, + ) + + logger.debug(f"Priming sequence {priming_seq}") + logger.info(f"Priming sequence length: {priming_seq_len}") + logger.info(f"Prefilling up to (and including) position: {start_idx-1}") + + prefill_start_s = time.time() + chunked_prefill( + model=model, + tokenizer=tokenizer, + prev_context=prev_context, + curr_context=enc_seq[0, :start_idx].tolist(), + full=True, + ) + + logger.info( + f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" + ) + logger.info(f"Starting duration recalculation from position: {start_idx-1}") + + recalculate_dur_start_s = time.time() + enc_seq, priming_seq, next_token_logits = recalc_dur_tokens_chunked( + model=model, + priming_seq=priming_seq, + enc_seq=enc_seq, + tokenizer=tokenizer, + start_idx=start_idx, + ) + + logger.info( + f"Recalculating durations took {(time.time() - recalculate_dur_start_s) * 1000:.2f} milliseconds" + ) + + decode_first_s = time.time() + enc_seq, idx = decode_first_tokens( + model=model, + first_token_logits=next_token_logits, + enc_seq=enc_seq, + priming_seq=priming_seq, + tokenizer=tokenizer, + generated_tokens_queue=generated_tokens_queue, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + ) + + logger.info( + f"Decode first two tokens took {(time.time() - decode_first_s) * 1000:.2f} milliseconds" + ) + logger.info( + f"Time to first token took {(time.time() - generate_start_s) * 1000:.2f} milliseconds" + ) + + decode_tokens( + model=model, + enc_seq=enc_seq, + tokenizer=tokenizer, + control_sentinel=control_sentinel, + generated_tokens_queue=generated_tokens_queue, + idx=idx, + temperature=temperature, + min_p=min_p, + ) + + +def decode_tokens_to_midi( + generated_tokens_queue: queue.Queue, + outbound_midi_msg_queue: queue.Queue, + tokenizer: AbsTokenizer, + first_on_msg_epoch_ms: int, + priming_seq_last_onset_ms: int, +): + logger = get_logger("DECODE") + + assert ( + first_on_msg_epoch_ms + priming_seq_last_onset_ms < get_epoch_time_ms() + ) + + logger.info(f"Priming sequence last onset: {priming_seq_last_onset_ms}") + logger.info( + f"Total time elapsed since first onset: {get_epoch_time_ms() - first_on_msg_epoch_ms}" + ) + + pitch_to_prev_msg = {} + note_buffer = [] + num_time_toks = priming_seq_last_onset_ms // 5000 + + while True: + while True: + tok = generated_tokens_queue.get() + if tok is tokenizer.eos_tok: + _uuid = uuid.uuid4() + end_msg = { + "pitch": -1, + "vel": -1, + "epoch_time_ms": offset_epoch_ms + 250, # Last note offset + "uuid": _uuid, + } # pitch=-1 denotes end_msg + outbound_midi_msg_queue.put(end_msg) + logger.info(f"Seen exit signal: EOS token") + logger.debug(f"Put message: {end_msg}") + return + + elif tok is None: + logger.info(f"Seen exit signal") + return + + logger.debug(f"Seen token: {tok}") + note_buffer.append(tok) + + if isinstance(tok, tuple) and tok[0] == "dur": + break + + while note_buffer and note_buffer[0] == tokenizer.time_tok: + logger.debug("Popping time_tok") + num_time_toks += 1 + note_buffer.pop(0) + + assert len(note_buffer) == 3 + logger.debug(f"Decoded note: {note_buffer}") + note_tok, onset_tok, dur_tok = note_buffer + _, pitch, vel = note_tok + _, onset = onset_tok + _, dur = dur_tok + + _uuid = uuid.uuid4() + onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset + offset_epoch_ms = onset_epoch_ms + dur + on_msg = { + "pitch": pitch, + 
"vel": vel, + "epoch_time_ms": onset_epoch_ms, + "uuid": _uuid, + } + off_msg = { + "pitch": pitch, + "vel": 0, + "epoch_time_ms": offset_epoch_ms, + "uuid": _uuid, + } + + # Not thread safe but in theory should be ok? + if pitch_to_prev_msg.get(pitch) is not None and MIN_NOTE_DELTA_MS > 0: + prev_on, prev_off = pitch_to_prev_msg.get(pitch) + adj_off_time = max( + min( + prev_off["epoch_time_ms"], + onset_epoch_ms - MIN_NOTE_DELTA_MS, + ), + prev_on["epoch_time_ms"], + ) + if adj_off_time != prev_off["epoch_time_ms"]: + logger.debug(f"Adjusting {prev_off}: t={adj_off_time}") + prev_off["epoch_time_ms"] = adj_off_time + prev_off["adjusted"] = True + + pitch_to_prev_msg[pitch] = [on_msg, off_msg] + + outbound_midi_msg_queue.put(on_msg) + outbound_midi_msg_queue.put(off_msg) + logger.debug(f"Put message: {on_msg}") + logger.debug(f"Put message: {off_msg}") + logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") + + note_buffer = [] + + +# TODO: Test the new changes in decode_tokens_to_midi and clean this fn up. +def stream_midi( + inbound_midi_msg_queue: queue.Queue, + msgs: list[mido.Message], + prev_msg_epoch_time_ms: float, + midi_output_port: str, + control_sentinel: threading.Event, + midi_stream_channel: int, + results_queue: queue.Queue, +): + logger = get_logger("STREAM") + logger.info( + f"Sending generated messages on MIDI port: '{midi_output_port}'" + ) + logger.info( + f"Applying hardware latency adjustment: {HARDWARE_LATENCY_MS}ms" + ) + MAX_DELAY_MS = 50 + + active_pitch_uuid = {} + is_pitch_active = {} + midi_msgs = [] + + with mido.open_output(midi_output_port) as midi_out: + while not control_sentinel.is_set(): + while True: + try: + msg = inbound_midi_msg_queue.get_nowait() + except queue.Empty: + break + else: + logger.debug(f"Received message: {msg}") + midi_msgs.append(msg) + + midi_msgs = sorted( + midi_msgs, + key=lambda msg: ( + msg["epoch_time_ms"], + msg["vel"], + ), + ) + + if control_sentinel.is_set(): + break + + while midi_msgs: + latency_adjusted_epoch_time_ms = ( + get_epoch_time_ms() + HARDWARE_LATENCY_MS + ) + msg = midi_msgs[0] + + if ( + 0 + < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] + <= MAX_DELAY_MS + ): + if msg["pitch"] == -1: # End msg + control_sentinel.set() + break + + mido_msg = mido.Message( + "note_on", + note=msg["pitch"], + velocity=msg["vel"], + channel=0, + time=0, + ) + + if msg["vel"] > 0: + active_pitch_uuid[msg["pitch"]] = msg["uuid"] + should_send_midi_out = True + should_append_to_msgs = True + elif msg.get("adjusted", False) is True: + should_send_midi_out = True + should_append_to_msgs = False + else: + should_send_midi_out = ( + active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] + ) + should_append_to_msgs = should_send_midi_out + + if should_send_midi_out is True: + midi_out.send(mido_msg) + is_pitch_active[msg["pitch"]] = msg["vel"] != 0 + logger.info(f"Sent message: {mido_msg}") + if should_append_to_msgs is True: + mido_msg_with_time = copy.deepcopy(mido_msg) + mido_msg_with_time.channel = midi_stream_channel + mido_msg_with_time.time = max( + 0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms + ) + prev_msg_epoch_time_ms = msg["epoch_time_ms"] + msgs.append(mido_msg_with_time) + + midi_msgs.pop(0) + + elif ( + latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] + > MAX_DELAY_MS + ): + # Message occurs too far in the past + logger.debug( + f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]}ms) in the past: {msg}" + ) + midi_msgs.pop(0) + else: + # Message occurs 
in the future + break + + time.sleep(0.005) + + remaining_note_off_messages = [ + msg + for msg in midi_msgs + if msg["vel"] == 0 + and active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] + ] + + logger.info("Processing remaining note_off messages") + for __msg in remaining_note_off_messages: + logger.debug(remaining_note_off_messages) + + for msg in remaining_note_off_messages: + mido_msg = mido.Message( + "note_on", + note=msg["pitch"], + velocity=0, + channel=midi_stream_channel, + time=msg["epoch_time_ms"] - prev_msg_epoch_time_ms, + ) + prev_msg_epoch_time_ms = msg["epoch_time_ms"] + msgs.append(mido_msg) + + results_queue.put(msgs) + + while remaining_note_off_messages: + msg = remaining_note_off_messages.pop(0) + while True: + latency_adjusted_epoch_time_ms = ( + get_epoch_time_ms() + HARDWARE_LATENCY_MS + ) + + if 0 < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]: + mido_msg = mido.Message( + "note_on", + note=msg["pitch"], + velocity=0, + channel=midi_stream_channel, + time=0, # Does not matter as only used for streaming + ) + midi_out.send(mido_msg) + logger.info(f"Sent message: {mido_msg}") + break + else: + time.sleep(0.01) + + +def stream_msgs( + model: TransformerLM, + tokenizer: AbsTokenizer, + msgs: list[mido.Message], + prev_context: list[int], + midi_output_port: str, + first_on_msg_epoch_ms: int, + control_sentinel: threading.Event, + temperature: float, + min_p: float, + num_preceding_active_pitches: int, + midi_stream_channel: int, + is_ending: bool = False, +): + midi = convert_msgs_to_midi(msgs=msgs) + midi_dict = MidiDict(**midi_to_dict(midi)) + priming_seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) + priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] + + if is_ending is True: + priming_seq.append(tokenizer.dim_tok) + + generated_tokens_queue = queue.Queue() + midi_messages_queue = queue.Queue() + + generate_tokens_thread = threading.Thread( + target=generate_tokens, + kwargs={ + "priming_seq": priming_seq, + "tokenizer": tokenizer, + "model": model, + "prev_context": prev_context, + "control_sentinel": control_sentinel, + "generated_tokens_queue": generated_tokens_queue, + "temperature": temperature, + "min_p": min_p, + "num_preceding_active_pitches": num_preceding_active_pitches, + "first_on_msg_epoch_ms": first_on_msg_epoch_ms, + }, + ) + generate_tokens_thread.start() + + decode_tokens_to_midi_thread = threading.Thread( + target=decode_tokens_to_midi, + kwargs={ + "generated_tokens_queue": generated_tokens_queue, + "outbound_midi_msg_queue": midi_messages_queue, + "tokenizer": tokenizer, + "first_on_msg_epoch_ms": first_on_msg_epoch_ms, + "priming_seq_last_onset_ms": tokenizer.calc_length_ms( + priming_seq, onset=True + ), + }, + ) + decode_tokens_to_midi_thread.start() + + prev_ms_epoch_time_ms = ( + first_on_msg_epoch_ms + + tokenizer.calc_length_ms(priming_seq, onset=False) + if is_ending is False + else first_on_msg_epoch_ms + ) + + stream_midi_results_queue = queue.Queue() + stream_midi_thread = threading.Thread( + target=stream_midi, + kwargs={ + "inbound_midi_msg_queue": midi_messages_queue, + "msgs": msgs, + "prev_msg_epoch_time_ms": prev_ms_epoch_time_ms, + "midi_output_port": midi_output_port, + "control_sentinel": control_sentinel, + "midi_stream_channel": midi_stream_channel, + "results_queue": stream_midi_results_queue, + }, + daemon=True, + ) + stream_midi_thread.start() + + generate_tokens_thread.join() + decode_tokens_to_midi_thread.join() + msgs = stream_midi_results_queue.get() + + if is_ending is True: + 
stream_midi_thread.join() + + return msgs + + +# TODO: Channel 9 issues here? +def convert_msgs_to_midi(msgs: list[mido.Message]): + channel_to_track = { + chan: mido.MidiTrack() + for chan in list(set([msg.channel for msg in msgs])) + } + + for msg in msgs: + channel_to_track[msg.channel].append(msg) + + # Workaround for possibility that track_0 start time != first_on_msg_epoch_ms + for msg in channel_to_track[0]: + if msg.type == "note_on" and msg.velocity > 0: + msg.time = 0 + break + else: + msg.time = 0 + + mid = mido.MidiFile(type=1) + mid.ticks_per_beat = 500 + + for channel, track in channel_to_track.items(): + track.insert(0, mido.MetaMessage("set_tempo", tempo=500000, time=0)) + track.insert( + 0, + mido.Message("program_change", program=0, channel=channel, time=0), + ) + mid.tracks.append(track) + + return mid + + +def _find_divergence( + prev_context: list, + curr_context: list, + logger: logging.Logger, +): + agreement_index = 0 + for prev_val, curr_val in zip(prev_context, curr_context): + if prev_val == curr_val: + agreement_index += 1 + else: + logger.info( + f"Found divergence at position {agreement_index + 1}: {curr_val}, {prev_val}" + ) + break + + return agreement_index, curr_context[agreement_index:] + + +def chunked_prefill( + model: TransformerLM, + tokenizer: AbsTokenizer, + prev_context: list, + curr_context: list, + full: bool = False, +): + + assert isinstance(curr_context[0], int) + assert tokenizer.pad_id not in prev_context + assert tokenizer.pad_id not in curr_context + + logger = get_logger("PREFILL") + while True: + prefill_idx, prefill_toks = _find_divergence( + prev_context, curr_context, logger=logger + ) + num_prefill_toks = len(prefill_toks) + logger.debug(f"Tokens to prefill: {len(prefill_toks)}") + + if num_prefill_toks > PREFILL_CHUNK_SIZE: + logger.debug( + f"Prefilling {PREFILL_CHUNK_SIZE} tokens from idx={prefill_idx}" + ) + + prefill( + model, + idxs=mx.array( + [prefill_toks[:PREFILL_CHUNK_SIZE]], + dtype=mx.int32, + ), + input_pos=mx.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + dtype=mx.int32, + ), + ) + prev_context = curr_context[: prefill_idx + PREFILL_CHUNK_SIZE] + + elif num_prefill_toks > 0 and full is True: + logger.debug( + f"Prefilling (force) {num_prefill_toks} tokens from idx={prefill_idx}" + ) + prefill_toks += (PREFILL_CHUNK_SIZE - len(prefill_toks)) * [ + tokenizer.pad_id + ] + prefill( + model, + idxs=mx.array([prefill_toks], dtype=mx.int32), + input_pos=mx.array( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + dtype=mx.int32, + ), + ) + prev_context = curr_context + break + else: + break + + logger.info( + f"KV stored up to idx={max(0, len(prev_context)- 1)} (curr_context_len={len(curr_context)})" + ) + + return prev_context + + +def continuous_prefill( + model: TransformerLM, + msgs: list, + received_messages_queue: queue.Queue, + prev_context: list[int], +): + tokenizer = AbsTokenizer() + logger = get_logger("PREFILL") + msg_cnt = 0 + seen_sentinel = False + + while seen_sentinel is False: + while seen_sentinel is False: + try: + msg = received_messages_queue.get_nowait() + except queue.Empty: + break + else: + if msg is None: + logger.info("Seen sentinel in message received messages") + seen_sentinel = True + else: + msgs.append(msg) + msg_cnt += 1 + + if (msg_cnt >= 5 or seen_sentinel) and len(msgs) > 10: + midi = convert_msgs_to_midi(msgs=msgs) + midi_dict = MidiDict(**midi_to_dict(midi)) + curr_context = tokenizer.encode( + tokenizer.tokenize(midi_dict, add_dim_tok=False) + ) + prev_context = 
chunked_prefill( + model=model, + tokenizer=tokenizer, + prev_context=prev_context, + curr_context=curr_context, + full=False, + ) + msg_cnt = 0 + else: + time.sleep(0.01) + + return msgs, prev_context + + +def capture_and_update_kv( + model: TransformerLM, + msgs: list, + prev_context: list, + control_sentinel: threading.Event, + midi_input_port: str, + midi_capture_channel: int, + midi_control_signal: int | None = None, + midi_through_port: str | None = None, + first_msg_epoch_time_ms: int | None = None, +): + received_messages_queue = queue.Queue() + results_queue = queue.Queue() + capture_midi_thread = threading.Thread( + target=capture_midi_input, + kwargs={ + "midi_input_port": midi_input_port, + "control_sentinel": control_sentinel, + "received_messages_queue": received_messages_queue, + "midi_capture_channel": midi_capture_channel, + "midi_control_signal": midi_control_signal, + "midi_through_port": midi_through_port, + "first_msg_epoch_time_ms": first_msg_epoch_time_ms, + "results_queue": results_queue, + }, + ) + capture_midi_thread.start() + + msgs, prev_context = continuous_prefill( + model=model, + msgs=msgs, + received_messages_queue=received_messages_queue, + prev_context=prev_context, + ) + capture_midi_thread.join() + first_on_msg_epoch_ms, num_active_pitches = results_queue.get() + + return msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches + + +def capture_midi_input( + midi_input_port: str, + control_sentinel: threading.Event, + received_messages_queue: queue.Queue, + midi_capture_channel: int, + results_queue: queue.Queue, + midi_control_signal: int | None = None, + midi_through_port: str | None = None, + first_msg_epoch_time_ms: int | None = None, +): + logger = get_logger("CAPTURE") + active_pitches = set() + first_on_msg_epoch_ms = None + prev_msg_epoch_time_ms = first_msg_epoch_time_ms # + + logger.info(f"Listening on MIDI port: '{midi_input_port}'") + logger.info(f"Using MIDI control signal: {midi_control_signal}") + if midi_through_port is not None: + logger.info(f"Sending through on MIDI port: '{midi_through_port}'") + + with ExitStack() as stack: + midi_input = stack.enter_context(mido.open_input(midi_input_port)) + midi_through = ( + stack.enter_context(mido.open_output(midi_through_port)) + if midi_through_port + else None + ) + + while not control_sentinel.is_set(): + msg = midi_input.receive(block=False) + + if msg is None: + time.sleep(0.001) + continue + + if prev_msg_epoch_time_ms is None: + msg_time_ms = 0 + else: + msg_time_ms = get_epoch_time_ms() - prev_msg_epoch_time_ms + + prev_msg_epoch_time_ms = get_epoch_time_ms() + msg.time = msg_time_ms + msg.channel = midi_capture_channel + logger.info(f"Received message: [{msg}]") + + if msg.is_meta is True or msg.type == "program_change": + continue + + if ( + msg.type == "note_on" and msg.velocity == 0 + ) or msg.type == "note_off": + active_pitches.discard(msg.note) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + elif msg.type == "note_on" and msg.velocity > 0: + if first_on_msg_epoch_ms is None: + first_on_msg_epoch_ms = get_epoch_time_ms() + + active_pitches.add(msg.note) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + elif msg.type == "control_change" and msg.control == 64: + received_messages_queue.put(msg) + elif ( + msg.type == "control_change" + and msg.control == midi_control_signal + and msg.value > 0 + ): + control_sentinel.set() + + logger.info("Control signal seen") + logger.info(f"Active 
pitches: {active_pitches}") + num_active_pitches = len(active_pitches) + + if active_pitches: + pitch = active_pitches.pop() + msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=midi_capture_channel, + time=get_epoch_time_ms() - prev_msg_epoch_time_ms, + ) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + + while active_pitches: + pitch = active_pitches.pop() + msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=midi_capture_channel, + time=0, + ) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + + # Turn off pedal + msg = mido.Message( + type="control_change", + control=64, + value=0, + channel=midi_capture_channel, + time=0, + ) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + + received_messages_queue.put(None) # Sentinel + results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) + + +def play_midi_file(midi_port: str, midi_path: str): + logger = get_logger("FILE") + logger.info(f"Playing file at {midi_path} on MIDI port '{midi_port}'") + time.sleep(1) + active_pitches = [] + with mido.open_output(midi_port) as output_port: + for msg in mido.MidiFile(midi_path).play(): + if msg.type == "note_on" and msg.velocity > 0: + if msg.note in active_pitches: + _off_msg = copy.deepcopy(msg) + _off_msg.velocity = 0 + output_port.send(_off_msg) + else: + active_pitches.append(msg.note) + elif msg.type == "note_off" or ( + msg.type == "note_on" and msg.velocity == 0 + ): + if msg.note in active_pitches: + active_pitches.remove(msg.note) + + logger.debug(f"{msg}") + output_port.send(msg) + + +def listen_for_keypress_control_signal( + control_sentinel: threading.Event, + end_sentinel: threading.Event, +): + logger = get_logger("KEYBOARD") + while True: + time.sleep(1) + _input = input() + logger.info(f'Keypress seen "{_input}"') + control_sentinel.set() + + if _input == "e": + end_sentinel.set() + + +# TODO: Not tested +def listen_for_midi_control_signal( + midi_input_port: str, + control_sentinel: threading.Event, + end_sentinel: threading.Event, + midi_control_signal: int | None = None, + midi_end_signal: int | None = None, +): + with mido.open_input(midi_input_port) as midi_input: + while True: + msg = midi_input.receive(block=False) + if msg is None: + time.sleep(0.01) + elif ( + msg.type == "control_change" + and msg.control == midi_control_signal + and msg.value > 0 + ): + control_sentinel.set() + elif ( + msg.type == "control_change" + and msg.control == midi_end_signal + and msg.value > 0 + ): + control_sentinel.set() + end_sentinel.set() + + +def parse_args(): + argp = argparse.ArgumentParser() + argp.add_argument("-cp", help="path to model checkpoint") + argp.add_argument("-midi_in", required=False, help="MIDI input port") + argp.add_argument("-midi_out", required=True, help="MIDI output port") + argp.add_argument( + "-midi_through", + required=False, + help="MIDI through port for received input", + ) + argp.add_argument( + "-midi_path", + required=False, + help="Use MIDI file instead of MIDI input port", + ) + argp.add_argument( + "-midi_control_signal", + type=int, + help="MIDI control change message for AI takeover", + ) + argp.add_argument( + "-midi_end_signal", + type=int, + help="MIDI control change message to generate ending", + ) + argp.add_argument( + "-temp", + help="sampling temperature value", + type=float, + required=False, + default=0.95, + ) + argp.add_argument( + "-min_p", + help="sampling 
min_p value", + type=float, + required=False, + default=0.03, + ) + argp.add_argument( + "-cfg", + help="sampling cfg gamma value", + type=float, + required=False, + ) + argp.add_argument( + "-metadata", + nargs=2, + metavar=("KEY", "VALUE"), + action="append", + help="manually add metadata key-value pair when sampling", + ) + argp.add_argument( + "-save_path", + type=str, + required=False, + help="Path to save complete MIDI file", + ) + + return argp.parse_args() + + +# TODO: Need functionality for handing case where we run out of model context +# TODO: Make sure channel=9 (drum) case is covered +def main(): + args = parse_args() + logger = get_logger() + tokenizer = AbsTokenizer() + model = load_model(checkpoint_path=args.cp) + model = compile_model(model=model) + + assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in + if args.midi_path: + midi_input_port = "Midi Through:Midi Through Port-0" + play_file_thread = threading.Thread( + target=play_midi_file, + args=(midi_input_port, args.midi_path), + daemon=True, + ) + play_file_thread.start() + else: + midi_input_port = args.midi_in + + control_sentinel = threading.Event() + end_sentinel = threading.Event() + keypress_thread = threading.Thread( + target=listen_for_keypress_control_signal, + args=[control_sentinel, end_sentinel], + daemon=True, + ) + midi_control_thread = threading.Thread( + target=listen_for_midi_control_signal, + kwargs={ + "midi_input_port": midi_input_port, + "control_sentinel": control_sentinel, + "end_sentinel": end_sentinel, + "midi_control_signal": args.midi_control_signal, + "midi_end_signal": args.midi_end_signal, + }, + daemon=True, + ) + keypress_thread.start() + midi_control_thread.start() + + msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches = ( + capture_and_update_kv( + model=model, + msgs=[], + prev_context=[], + control_sentinel=control_sentinel, + midi_input_port=midi_input_port, + midi_control_signal=args.midi_control_signal, + midi_through_port=args.midi_through, + midi_capture_channel=0, + ) + ) + + itt = 0 + while True: + control_sentinel.clear() + msgs = stream_msgs( + model=model, + tokenizer=tokenizer, + msgs=msgs, + prev_context=prev_context, + midi_output_port=args.midi_out, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + control_sentinel=control_sentinel, + temperature=args.temp, + min_p=args.min_p, + num_preceding_active_pitches=num_active_pitches, + midi_stream_channel=itt, + is_ending=False, + ) + + itt += 1 + control_sentinel.clear() + if end_sentinel.is_set(): + break + + msgs, prev_context, _, num_active_pitches = capture_and_update_kv( + model=model, + msgs=msgs, + prev_context=prev_context, + control_sentinel=control_sentinel, + midi_input_port=midi_input_port, + midi_control_signal=args.midi_control_signal, + midi_through_port=args.midi_through, + midi_capture_channel=itt, + first_msg_epoch_time_ms=first_on_msg_epoch_ms, + ) + + # TODO: There is a bug with the token somewhere? 
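+    # [Editor's sketch, not from the original patch] The -min_p flag applies
+    # standard min-p filtering: keep tokens whose probability is at least
+    # min_p * p_max, renormalize, then sample. aria.sample's actual
+    # implementation may differ from this illustration.
+    def _min_p_sample_sketch(logits: torch.Tensor, min_p: float) -> torch.Tensor:
+        # Filter against a threshold scaled by the most likely token's probability
+        probs = torch.softmax(logits, dim=-1)
+        keep = probs >= min_p * probs.max(dim=-1, keepdim=True).values
+        probs = torch.where(keep, probs, torch.zeros_like(probs))
+        probs = probs / probs.sum(dim=-1, keepdim=True)
+        return torch.multinomial(probs, num_samples=1)
+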
+ msgs = stream_msgs( + model=model, + tokenizer=tokenizer, + msgs=msgs, + prev_context=prev_context, + midi_output_port=args.midi_out, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + control_sentinel=control_sentinel, + temperature=args.temp / 2, + min_p=args.min_p, + num_preceding_active_pitches=num_active_pitches, + midi_stream_channel=itt, + is_ending=True, + ) + + if args.save_path: + logger.info(f"Saving result to {args.save_path}") + midi = convert_msgs_to_midi(msgs=msgs) + midi.save(args.save_path) + + +if __name__ == "__main__": + main() From 571c0a63c942645997af264d3edfefe61f2becd3 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 29 May 2025 15:32:34 +0000 Subject: [PATCH 48/72] add script --- demo/demo.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/demo/demo.sh b/demo/demo.sh index 05a0328..03f2d89 100644 --- a/demo/demo.sh +++ b/demo/demo.sh @@ -1,8 +1,12 @@ +MID_PATH="/home/loubb/Dropbox/shared/demo.mid" + python /home/loubb/work/aria/demo/demo.py \ - -cp /mnt/ssd1/aria/v2/medium-75-ft.safetensors \ - -midi_path /home/loubb/Dropbox/shared/prompt/nocturne.mid \ + -cp /mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors \ + -midi_path ${MID_PATH} \ -midi_out "Midi Through:Midi Through Port-1" \ -midi_through "Midi Through:Midi Through Port-2" \ -save_path /home/loubb/Dropbox/shared/output.mid \ -midi_control_signal 66 \ - -temp 0.95 + -midi_end_signal 67 \ + -temp 0.98 \ + -min_p 0.02 \ No newline at end of file From a0944a4e55f994564229b6ee82465cfd1bf389b1 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 2 Jun 2025 17:39:11 +0100 Subject: [PATCH 49/72] update mlx demo --- demo/demo_mlx.py | 427 ++++++++++++++++++++++++----------------------- demo/dump.py | 75 --------- requirements.txt | 2 +- 3 files changed, 220 insertions(+), 284 deletions(-) delete mode 100644 demo/dump.py diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index 0ac5d1f..615dfc2 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -5,6 +5,7 @@ import time import uuid import copy +import random import logging import threading import queue @@ -24,23 +25,24 @@ from aria.model import ModelConfig from aria.config import load_model_config +# TODO: Investigate DTYPE=mx.float16 (speedup?) 
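+# [Editor's sketch, not from the original patch] One way to test the dtype TODO
+# above: time a batch of matmuls in each dtype (shapes and iteration counts are
+# hypothetical). mx.eval forces MLX's lazy graph to execute, so the timing is real.
+def _bench_dtype(n: int = 2048, iters: int = 50):
+    for dtype in (mx.float32, mx.float16):
+        a = mx.random.normal((n, n)).astype(dtype)
+        mx.eval(a)  # Materialize the input before starting the clock
+        start_s = time.time()
+        mx.eval([a @ a for _ in range(iters)])
+        print(f"{dtype}: {time.time() - start_s:.3f}s")
+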
DTYPE = mx.float32 MAX_SEQ_LEN = 2048 PREFILL_CHUNK_SIZE = 32 RECALC_DUR_PREFILL_CHUNK_SIZE = 8 RECALC_DUR_BUFFER_MS = 50 -# Decode first BEAM_WIDTH = 3 TIME_TOK_WEIGHTING = -5 -FIRST_ONSET_BUFFER_MS = 25 +FIRST_ONSET_BUFFER_MS = -150 # Controls onset timing for first generated note # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg # HARDWARE: All messages are sent HARDWARE_LATENCY_MS early -MIN_NOTE_DELTA_MS = 100 -MIN_NOTE_LEN_MS = 200 -HARDWARE_LATENCY_MS = 0 +MIN_NOTE_DELTA_MS = 50 +MIN_NOTE_LEN_MS = 100 +HARDWARE_LATENCY_MS = 100 +MAX_STREAM_DELAY_MS = 50 file_handler = logging.FileHandler("./demo.log", mode="w") file_handler.setLevel(logging.DEBUG) @@ -87,9 +89,11 @@ def prefill( input_pos: mx.array, pad_idxs: mx.array | None = None, ) -> mx.array: - logits = model.forward( + # pad_idxs is only needed for prepended pad tokens + logits = model( idxs=idxs, input_pos=input_pos, + offset=input_pos[0], pad_idxs=pad_idxs, ) @@ -102,11 +106,13 @@ def decode_one( input_pos: mx.array, pad_idxs: mx.array | None = None, ) -> mx.array: + # pad_idxs is only needed for prepended pad tokens assert input_pos.shape[-1] == 1 - logits = model.forward( + logits = model( idxs=idxs, input_pos=input_pos, + offset=input_pos[0], pad_idxs=pad_idxs, )[:, -1] @@ -144,31 +150,29 @@ def _compile_prefill( compile_start_time_s = time.time() logger.info(f"Compiling prefill (chunk_size={chunk_size})") - res = prefill( - model, - idxs=mx.ones([1, chunk_size], dtype=mx.int32), - input_pos=mx.arange(0, chunk_size, dtype=mx.int32), - ) - mx.eval(res) + for _start_idx in range(0, MAX_SEQ_LEN, chunk_size * 4): + mx.eval( + prefill( + model, + idxs=mx.ones([1, chunk_size], dtype=mx.int32), + input_pos=mx.arange( + _start_idx, _start_idx + chunk_size, dtype=mx.int32 + ), + ) + ) + logger.info( f"Finished compiling - took {time.time() - compile_start_time_s:.4f} seconds" ) - for _ in range(5): - res = prefill( + bench_start_time_s = time.time() + mx.eval( + prefill( model, idxs=mx.ones([1, chunk_size], dtype=mx.int32), input_pos=mx.arange(0, chunk_size, dtype=mx.int32), ) - mx.eval(res) - - bench_start_time_s = time.time() - prefill( - model, - idxs=mx.ones([1, chunk_size], dtype=mx.int32), - input_pos=mx.arange(0, chunk_size, dtype=mx.int32), ) - mx.eval(res) bench_end_time_s = time.time() bench_ms = 1e3 * (bench_end_time_s - bench_start_time_s) bench_its = 1000 / bench_ms @@ -186,31 +190,26 @@ def _compile_decode_one( # Don't need to explicitly compile with mlx, instead we are just precalculating # the computation graphs for different shapes compile_start_time_s = time.time() - res = decode_one( - model, - idxs=mx.array([[0]], dtype=mx.int32), - input_pos=mx.array([0], dtype=mx.int32), - ) - mx.eval(res) + for _start_idx in range(0, MAX_SEQ_LEN, 4): + mx.eval( + decode_one( + model, + idxs=mx.array([[random.randint(0, 20)]], dtype=mx.int32), + input_pos=mx.array([_start_idx], dtype=mx.int32), + ), + ) logger.info( f"Finished compiling - took {time.time() - compile_start_time_s:.4f} seconds" ) - for _ in range(5): - res = decode_one( + bench_start_time_s = time.time() + mx.eval( + decode_one( model, idxs=mx.array([[0]], dtype=mx.int32), input_pos=mx.array([0], dtype=mx.int32), ) - mx.eval(res) - - bench_start_time_s = time.time() - decode_one( - model, - idxs=mx.array([[0]], dtype=mx.int32), - input_pos=mx.array([0], dtype=mx.int32), ) - mx.eval(res) bench_end_time_s = time.time() bench_ms = 1e3 * (bench_end_time_s - 
bench_start_time_s) bench_its = 1000 / bench_ms @@ -253,8 +252,8 @@ def load_model( init_start_time_s = time.time() model = TransformerLM(model_config) - model.load_weights(checkpoint_path) - nn.quantize(model.model, group_size=128, bits=8) + model.load_weights(checkpoint_path, strict=False) + nn.quantize(model.model, group_size=64, bits=8) model.eval() logger.info( @@ -274,8 +273,8 @@ def _first_bad_dur_index( ): num_time_toks = priming_seq[:chunk_start].count(tokenizer.time_tok) local_onset_ms = tokenizer.calc_length_ms( - priming_seq[:chunk_start], onset=True - ) + priming_seq[: chunk_start + 1], onset=True + ) # chunk_start + 1 to account for possibly truncated dur token logger.debug(f"Starting from local onset {local_onset_ms}") for pos, tok_id in enumerate( @@ -294,7 +293,7 @@ def _first_bad_dur_index( dur_pred = pred_tok[1] if dur_pred > dur_true and ( local_onset_ms + dur_true - > last_offset_ms - RECALC_DUR_BUFFER_MS + >= last_offset_ms - RECALC_DUR_BUFFER_MS ): logger.info( f"Found token to resample at {pos}: {prim_tok} -> {pred_tok}" @@ -311,12 +310,15 @@ def recalc_dur_tokens_chunked( tokenizer: AbsTokenizer, start_idx: int, ): - """Speculative-decoding inspired duration re-calculation""" + # Speculative-decoding inspired duration re-calculation assert start_idx > 0 logger = get_logger("GENERATE") priming_len = len(priming_seq) - last_offset = tokenizer.calc_length_ms(priming_seq) + last_offset = tokenizer.calc_length_ms(priming_seq, onset=False) + logger.debug( + f"Using threshold for duration recalculation: {last_offset - RECALC_DUR_BUFFER_MS}" + ) idx = start_idx while idx <= priming_len: @@ -331,12 +333,14 @@ def recalc_dur_tokens_chunked( logger.info( f"Recalculating chunked durations for positions: {idx-1} - {end_idx-2}" ) - logger.debug(f"Inserted: {tokenizer.decode(window_ids[0].tolist())}") - logger.debug(f"Positions: {window_pos.tolist()}") logits = prefill(model, idxs=window_ids, input_pos=window_pos) pred_ids = mx.argmax(logits, axis=-1).flatten().tolist() + logger.debug(f"Inserted: {tokenizer.decode(window_ids[0].tolist())}") + logger.debug(f"Positions: {window_pos.tolist()}") + logger.debug(f"Predictions: {tokenizer.decode(pred_ids)}") + bad_pos = _first_bad_dur_index( tokenizer=tokenizer, priming_seq=priming_seq, @@ -352,18 +356,13 @@ def recalc_dur_tokens_chunked( new_id = pred_ids[bad_pos - idx] enc_seq[0, bad_pos] = new_id priming_seq[bad_pos] = tokenizer.id_to_tok[new_id] - idx = bad_pos + idx = bad_pos + 1 next_logits = logits[:, priming_len - idx] return enc_seq, priming_seq, next_logits -# TODO: This is now the latency bottleneck. 
-# Ideas for reducing it: -# - Get rid of the manual time_tok insert stuff, instead just mask logits -# for all invalid tokens, this should force the model to sample a time tok -# if there aren't any other valid options def decode_first_tokens( model: TransformerLM, first_token_logits: mx.array, @@ -375,8 +374,10 @@ def decode_first_tokens( ): logger = get_logger("GENERATE") - buffer_ms = FIRST_ONSET_BUFFER_MS + HARDWARE_LATENCY_MS + buffer_ms = FIRST_ONSET_BUFFER_MS time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] + eos_tok_id = tokenizer.tok_to_id[tokenizer.eos_tok] + dim_tok_id = tokenizer.tok_to_id[tokenizer.dim_tok] logits = first_token_logits time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms @@ -398,26 +399,29 @@ def decode_first_tokens( logger.info(f"Inserted time_tok at position {idx-1}") num_time_toks_to_add -= 1 - enc_seq[:, idx - 1] = mx.array([[time_tok_id]], dtype=mx.int32) + enc_seq[:, idx - 1] = time_tok_id idx += 1 logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + # MLX doesn't have a equivalent of torch topk log_probs = nn.log_softmax(logits, axis=-1) - top_log_probs = mx.topk(log_probs, k=BEAM_WIDTH, axis=-1) - top_ids = mx.argsort(log_probs, axis=-1)[..., -BEAM_WIDTH:] + top_ids = mx.argsort(log_probs, axis=-1)[0, -BEAM_WIDTH:] + top_log_probs = log_probs[0, top_ids] - if time_tok_id not in top_ids[0].tolist(): - top_ids[0, -1] = time_tok_id - top_log_probs[0, -1] = log_probs[0, time_tok_id] + TIME_TOK_WEIGHTING + # top_log_probs are sorted in ascending order + if time_tok_id not in top_ids.tolist(): + top_ids[0] = time_tok_id + top_log_probs[0] = log_probs[0, time_tok_id] - top_toks = [tokenizer.id_to_tok[id] for id in top_ids[0].tolist()] + _time_tok_idx = top_ids.tolist().index(time_tok_id) + top_log_probs[_time_tok_idx] += TIME_TOK_WEIGHTING + + top_toks = [tokenizer.id_to_tok[id] for id in top_ids.tolist()] logger.debug(f"Calculated top {BEAM_WIDTH} tokens={top_toks}") - logger.debug( - f"Calculated top {BEAM_WIDTH} scores={top_log_probs[0].tolist()}" - ) + logger.debug(f"Calculated top {BEAM_WIDTH} scores={top_log_probs.tolist()}") masked_onset_ids = [ tokenizer.tok_to_id[tok] @@ -432,8 +436,8 @@ def decode_first_tokens( best_score = float("-inf") for i in range(BEAM_WIDTH): tok = top_toks[i] - tok_id = top_ids[0, i].item() - tok_log_prob = top_log_probs[0, i] + tok_id = top_ids[i].item() + tok_log_prob = top_log_probs[i] next_logits = decode_one( model, @@ -444,9 +448,10 @@ def decode_first_tokens( f"Sampled logits for positions {idx} by inserting {tok} at position {idx-1}" ) - # Is float("-inf") masking ok in mlx? 
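+        # [Editor's note] float("-inf") masking works in MLX as it does in
+        # torch: mx.array supports item assignment, and -inf entries leave
+        # softmax with probability 0, e.g. softmax([-inf, 0, 0, 0]) gives
+        # [0, 1/3, 1/3, 1/3].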
next_log_probs = nn.log_softmax(next_logits, axis=-1) next_log_probs[:, masked_onset_ids] = float("-inf") + next_log_probs[:, eos_tok_id] = float("-inf") + next_log_probs[:, dim_tok_id] = float("-inf") if tok_id == time_tok_id: next_log_probs[:, time_tok_id] = float("-inf") @@ -476,10 +481,12 @@ def decode_first_tokens( generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_1]) generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_2]) - decode_one( - model, - idxs=mx.array([[best_tok_id_1]], dtype=mx.int32), - input_pos=mx.array([idx - 1], dtype=mx.int32), + mx.eval( + decode_one( + model, + idxs=mx.array([[best_tok_id_1]], dtype=mx.int32), + input_pos=mx.array([idx - 1], dtype=mx.int32), + ) ) logger.info( @@ -501,6 +508,7 @@ def decode_tokens( idx: int, temperature: float, min_p: float, + is_ending: bool, ): logger = get_logger("GENERATE") logger.info( @@ -523,6 +531,9 @@ def decode_tokens( ) logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + if is_ending is False: + logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + for dur_ms in range(0, MIN_NOTE_LEN_MS, 10): logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") @@ -546,9 +557,6 @@ def decode_tokens( generated_tokens_queue.put(next_token) idx += 1 - while not control_sentinel.is_set(): - time.sleep(0.1) - logger.info("Seen exit signal") generated_tokens_queue.put(None) @@ -562,13 +570,15 @@ def generate_tokens( generated_tokens_queue: queue.Queue, num_preceding_active_pitches: int, first_on_msg_epoch_ms: int, - temperature: float = 0.97, + temperature: float = 0.98, min_p: float = 0.03, + is_ending: bool = False, ): logger = get_logger("GENERATE") generate_start_s = time.time() priming_seq_len = len(priming_seq) + start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches - 1) enc_seq = mx.array( [ @@ -638,6 +648,7 @@ def generate_tokens( idx=idx, temperature=temperature, min_p=min_p, + is_ending=is_ending, ) @@ -671,7 +682,7 @@ def decode_tokens_to_midi( end_msg = { "pitch": -1, "vel": -1, - "epoch_time_ms": offset_epoch_ms + 250, # Last note offset + "epoch_time_ms": offset_epoch_ms + 100, # Last note offset "uuid": _uuid, } # pitch=-1 denotes end_msg outbound_midi_msg_queue.put(end_msg) @@ -743,11 +754,11 @@ def decode_tokens_to_midi( note_buffer = [] -# TODO: Test the new changes in decode_tokens_to_midi and clean this fn up. 
+# TODO: Refactor for readability def stream_midi( inbound_midi_msg_queue: queue.Queue, msgs: list[mido.Message], - prev_msg_epoch_time_ms: float, + last_channel_msg_epoch_time_ms: float, midi_output_port: str, control_sentinel: threading.Event, midi_stream_channel: int, @@ -760,7 +771,6 @@ def stream_midi( logger.info( f"Applying hardware latency adjustment: {HARDWARE_LATENCY_MS}ms" ) - MAX_DELAY_MS = 50 active_pitch_uuid = {} is_pitch_active = {} @@ -797,7 +807,7 @@ def stream_midi( if ( 0 < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] - <= MAX_DELAY_MS + <= MAX_STREAM_DELAY_MS ): if msg["pitch"] == -1: # End msg control_sentinel.set() @@ -832,20 +842,22 @@ def stream_midi( mido_msg_with_time = copy.deepcopy(mido_msg) mido_msg_with_time.channel = midi_stream_channel mido_msg_with_time.time = max( - 0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms + 0, + msg["epoch_time_ms"] + - last_channel_msg_epoch_time_ms, ) - prev_msg_epoch_time_ms = msg["epoch_time_ms"] + last_channel_msg_epoch_time_ms = msg["epoch_time_ms"] msgs.append(mido_msg_with_time) midi_msgs.pop(0) elif ( latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] - > MAX_DELAY_MS + > MAX_STREAM_DELAY_MS ): # Message occurs too far in the past logger.debug( - f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]}ms) in the past: {msg}" + f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg['epoch_time_ms']}ms) in the past: {msg}" ) midi_msgs.pop(0) else: @@ -862,43 +874,20 @@ def stream_midi( ] logger.info("Processing remaining note_off messages") - for __msg in remaining_note_off_messages: - logger.debug(remaining_note_off_messages) - for msg in remaining_note_off_messages: mido_msg = mido.Message( "note_on", note=msg["pitch"], velocity=0, channel=midi_stream_channel, - time=msg["epoch_time_ms"] - prev_msg_epoch_time_ms, + time=msg["epoch_time_ms"] - last_channel_msg_epoch_time_ms, ) - prev_msg_epoch_time_ms = msg["epoch_time_ms"] + midi_out.send(mido_msg) + last_channel_msg_epoch_time_ms = msg["epoch_time_ms"] msgs.append(mido_msg) results_queue.put(msgs) - while remaining_note_off_messages: - msg = remaining_note_off_messages.pop(0) - while True: - latency_adjusted_epoch_time_ms = ( - get_epoch_time_ms() + HARDWARE_LATENCY_MS - ) - - if 0 < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]: - mido_msg = mido.Message( - "note_on", - note=msg["pitch"], - velocity=0, - channel=midi_stream_channel, - time=0, # Does not matter as only used for streaming - ) - midi_out.send(mido_msg) - logger.info(f"Sent message: {mido_msg}") - break - else: - time.sleep(0.01) - def stream_msgs( model: TransformerLM, @@ -938,6 +927,7 @@ def stream_msgs( "min_p": min_p, "num_preceding_active_pitches": num_preceding_active_pitches, "first_on_msg_epoch_ms": first_on_msg_epoch_ms, + "is_ending": is_ending, }, ) generate_tokens_thread.start() @@ -956,7 +946,9 @@ def stream_msgs( ) decode_tokens_to_midi_thread.start() - prev_ms_epoch_time_ms = ( + # If ending==True then previous MIDI message on midi_stream_channel occurs + # at first_on_msg_epoch_ms. 
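+    # [Editor's example] With first_on_msg_epoch_ms = 1_750_000_000_000 and a
+    # priming sequence spanning 12_500 ms, a normal turn resumes streaming
+    # relative to epoch 1_750_000_012_500 ms, while an ending is timed from the
+    # first onset itself.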
+ prev_channel_msg_epoch_time_ms = ( first_on_msg_epoch_ms + tokenizer.calc_length_ms(priming_seq, onset=False) if is_ending is False @@ -969,7 +961,7 @@ def stream_msgs( kwargs={ "inbound_midi_msg_queue": midi_messages_queue, "msgs": msgs, - "prev_msg_epoch_time_ms": prev_ms_epoch_time_ms, + "last_channel_msg_epoch_time_ms": prev_channel_msg_epoch_time_ms, "midi_output_port": midi_output_port, "control_sentinel": control_sentinel, "midi_stream_channel": midi_stream_channel, @@ -989,7 +981,6 @@ def stream_msgs( return msgs -# TODO: Channel 9 issues here? def convert_msgs_to_midi(msgs: list[mido.Message]): channel_to_track = { chan: mido.MidiTrack() @@ -1025,6 +1016,7 @@ def _find_divergence( prev_context: list, curr_context: list, logger: logging.Logger, + tokenizer: AbsTokenizer, ): agreement_index = 0 for prev_val, curr_val in zip(prev_context, curr_context): @@ -1032,7 +1024,7 @@ def _find_divergence( agreement_index += 1 else: logger.info( - f"Found divergence at position {agreement_index + 1}: {curr_val}, {prev_val}" + f"Found divergence at idx {agreement_index}: {tokenizer.id_to_tok[curr_val]}, {tokenizer.id_to_tok[prev_val]}" ) break @@ -1052,9 +1044,13 @@ def chunked_prefill( assert tokenizer.pad_id not in curr_context logger = get_logger("PREFILL") + while True: prefill_idx, prefill_toks = _find_divergence( - prev_context, curr_context, logger=logger + prev_context, + curr_context, + logger=logger, + tokenizer=tokenizer, ) num_prefill_toks = len(prefill_toks) logger.debug(f"Tokens to prefill: {len(prefill_toks)}") @@ -1064,17 +1060,19 @@ def chunked_prefill( f"Prefilling {PREFILL_CHUNK_SIZE} tokens from idx={prefill_idx}" ) - prefill( - model, - idxs=mx.array( - [prefill_toks[:PREFILL_CHUNK_SIZE]], - dtype=mx.int32, - ), - input_pos=mx.arange( - prefill_idx, - prefill_idx + PREFILL_CHUNK_SIZE, - dtype=mx.int32, - ), + mx.eval( + prefill( + model, + idxs=mx.array( + [prefill_toks[:PREFILL_CHUNK_SIZE]], + dtype=mx.int32, + ), + input_pos=mx.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + dtype=mx.int32, + ), + ) ) prev_context = curr_context[: prefill_idx + PREFILL_CHUNK_SIZE] @@ -1085,14 +1083,16 @@ def chunked_prefill( prefill_toks += (PREFILL_CHUNK_SIZE - len(prefill_toks)) * [ tokenizer.pad_id ] - prefill( - model, - idxs=mx.array([prefill_toks], dtype=mx.int32), - input_pos=mx.array( - prefill_idx, - prefill_idx + PREFILL_CHUNK_SIZE, - dtype=mx.int32, - ), + mx.eval( + prefill( + model, + idxs=mx.array([prefill_toks], dtype=mx.int32), + input_pos=mx.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + dtype=mx.int32, + ), + ) ) prev_context = curr_context break @@ -1131,7 +1131,7 @@ def continuous_prefill( msgs.append(msg) msg_cnt += 1 - if (msg_cnt >= 5 or seen_sentinel) and len(msgs) > 10: + if (msg_cnt >= 10 or seen_sentinel) and len(msgs) > 30: midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) curr_context = tokenizer.encode( @@ -1156,6 +1156,7 @@ def capture_and_update_kv( msgs: list, prev_context: list, control_sentinel: threading.Event, + wait_for_close: bool, midi_input_port: str, midi_capture_channel: int, midi_control_signal: int | None = None, @@ -1175,6 +1176,7 @@ def capture_and_update_kv( "midi_through_port": midi_through_port, "first_msg_epoch_time_ms": first_msg_epoch_time_ms, "results_queue": results_queue, + "wait_for_close": wait_for_close, }, ) capture_midi_thread.start() @@ -1200,6 +1202,7 @@ def capture_midi_input( midi_control_signal: int | None = None, midi_through_port: str | None = None, 
first_msg_epoch_time_ms: int | None = None, + wait_for_close: bool = False, ): logger = get_logger("CAPTURE") active_pitches = set() @@ -1219,7 +1222,9 @@ def capture_midi_input( else None ) - while not control_sentinel.is_set(): + while not control_sentinel.is_set() or ( + wait_for_close and active_pitches + ): msg = midi_input.receive(block=False) if msg is None: @@ -1262,8 +1267,8 @@ def capture_midi_input( and msg.value > 0 ): control_sentinel.set() + logger.info("Control signal seen") - logger.info("Control signal seen") logger.info(f"Active pitches: {active_pitches}") num_active_pitches = len(active_pitches) @@ -1312,10 +1317,18 @@ def capture_midi_input( def play_midi_file(midi_port: str, midi_path: str): logger = get_logger("FILE") logger.info(f"Playing file at {midi_path} on MIDI port '{midi_port}'") + + midi_dict = MidiDict.from_midi(midi_path) + + if MIN_NOTE_DELTA_MS: + midi_dict.enforce_gaps(min_gap_ms=MIN_NOTE_DELTA_MS) + + mid = midi_dict.to_midi() + time.sleep(1) active_pitches = [] with mido.open_output(midi_port) as output_port: - for msg in mido.MidiFile(midi_path).play(): + for msg in mid.play(): if msg.type == "note_on" and msg.velocity > 0: if msg.note in active_pitches: _off_msg = copy.deepcopy(msg) @@ -1335,26 +1348,26 @@ def play_midi_file(midi_port: str, midi_path: str): def listen_for_keypress_control_signal( control_sentinel: threading.Event, - end_sentinel: threading.Event, + generate_ending_sentinel: threading.Event, ): logger = get_logger("KEYBOARD") while True: - time.sleep(1) + time.sleep(3) _input = input() logger.info(f'Keypress seen "{_input}"') - control_sentinel.set() - - if _input == "e": - end_sentinel.set() + if _input == "": + control_sentinel.set() + else: + control_sentinel.set() + generate_ending_sentinel.set() + return -# TODO: Not tested +# TODO: Get rid of logic for end sentinel def listen_for_midi_control_signal( midi_input_port: str, control_sentinel: threading.Event, - end_sentinel: threading.Event, midi_control_signal: int | None = None, - midi_end_signal: int | None = None, ): with mido.open_input(midi_input_port) as midi_input: while True: @@ -1364,72 +1377,52 @@ def listen_for_midi_control_signal( elif ( msg.type == "control_change" and msg.control == midi_control_signal - and msg.value > 0 + and msg.value >= 64 ): control_sentinel.set() - elif ( - msg.type == "control_change" - and msg.control == midi_end_signal - and msg.value > 0 - ): - control_sentinel.set() - end_sentinel.set() def parse_args(): argp = argparse.ArgumentParser() - argp.add_argument("-cp", help="path to model checkpoint") - argp.add_argument("-midi_in", required=False, help="MIDI input port") - argp.add_argument("-midi_out", required=True, help="MIDI output port") + argp.add_argument("--checkpoint", help="path to model checkpoint") + argp.add_argument("--midi_in", required=False, help="MIDI input port") + argp.add_argument("--midi_out", required=True, help="MIDI output port") argp.add_argument( - "-midi_through", + "--midi_through", required=False, help="MIDI through port for received input", ) argp.add_argument( - "-midi_path", + "--midi_path", required=False, help="Use MIDI file instead of MIDI input port", ) argp.add_argument( - "-midi_control_signal", + "--midi_control_signal", type=int, help="MIDI control change message for AI takeover", ) argp.add_argument( - "-midi_end_signal", - type=int, - help="MIDI control change message to generate ending", - ) - argp.add_argument( - "-temp", + "--temp", help="sampling temperature value", type=float, required=False, 
default=0.95, ) argp.add_argument( - "-min_p", + "--min_p", help="sampling min_p value", type=float, required=False, default=0.03, ) argp.add_argument( - "-cfg", - help="sampling cfg gamma value", - type=float, - required=False, + "--wait_for_close", + help="wait for note-offs before generating", + action="store_true", ) argp.add_argument( - "-metadata", - nargs=2, - metavar=("KEY", "VALUE"), - action="append", - help="manually add metadata key-value pair when sampling", - ) - argp.add_argument( - "-save_path", + "--save_path", type=str, required=False, help="Path to save complete MIDI file", @@ -1439,17 +1432,19 @@ def parse_args(): # TODO: Need functionality for handing case where we run out of model context -# TODO: Make sure channel=9 (drum) case is covered -def main(): + + +def main(args): args = parse_args() logger = get_logger() tokenizer = AbsTokenizer() - model = load_model(checkpoint_path=args.cp) + model = load_model(checkpoint_path=args.checkpoint) model = compile_model(model=model) assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in if args.midi_path: - midi_input_port = "Midi Through:Midi Through Port-0" + # TODO: Don't hardcode this + midi_input_port = "IAC Driver Bus 1" play_file_thread = threading.Thread( target=play_midi_file, args=(midi_input_port, args.midi_path), @@ -1460,10 +1455,10 @@ def main(): midi_input_port = args.midi_in control_sentinel = threading.Event() - end_sentinel = threading.Event() + generate_ending_sentinel = threading.Event() keypress_thread = threading.Thread( target=listen_for_keypress_control_signal, - args=[control_sentinel, end_sentinel], + args=[control_sentinel, generate_ending_sentinel], daemon=True, ) midi_control_thread = threading.Thread( @@ -1471,9 +1466,7 @@ def main(): kwargs={ "midi_input_port": midi_input_port, "control_sentinel": control_sentinel, - "end_sentinel": end_sentinel, "midi_control_signal": args.midi_control_signal, - "midi_end_signal": args.midi_end_signal, }, daemon=True, ) @@ -1486,6 +1479,7 @@ def main(): msgs=[], prev_context=[], control_sentinel=control_sentinel, + wait_for_close=args.wait_for_close, midi_input_port=midi_input_port, midi_control_signal=args.midi_control_signal, midi_through_port=args.midi_through, @@ -1493,7 +1487,7 @@ def main(): ) ) - itt = 0 + curr_midi_channel = 0 while True: control_sentinel.clear() msgs = stream_msgs( @@ -1507,28 +1501,32 @@ def main(): temperature=args.temp, min_p=args.min_p, num_preceding_active_pitches=num_active_pitches, - midi_stream_channel=itt, + midi_stream_channel=curr_midi_channel, is_ending=False, ) - itt += 1 + curr_midi_channel += 1 + if curr_midi_channel == 9: + curr_midi_channel += 1 + control_sentinel.clear() - if end_sentinel.is_set(): + if generate_ending_sentinel.is_set(): break + else: + msgs, prev_context, _, num_active_pitches = capture_and_update_kv( + model=model, + msgs=msgs, + prev_context=prev_context, + control_sentinel=control_sentinel, + wait_for_close=args.wait_for_close, + midi_input_port=midi_input_port, + midi_control_signal=args.midi_control_signal, + midi_through_port=args.midi_through, + midi_capture_channel=curr_midi_channel, + first_msg_epoch_time_ms=first_on_msg_epoch_ms, + ) - msgs, prev_context, _, num_active_pitches = capture_and_update_kv( - model=model, - msgs=msgs, - prev_context=prev_context, - control_sentinel=control_sentinel, - midi_input_port=midi_input_port, - midi_control_signal=args.midi_control_signal, - midi_through_port=args.midi_through, - midi_capture_channel=itt, - 
first_msg_epoch_time_ms=first_on_msg_epoch_ms, - ) - - # TODO: There is a bug with the token somewhere? + # Generate ending msgs = stream_msgs( model=model, tokenizer=tokenizer, @@ -1540,7 +1538,7 @@ def main(): temperature=args.temp / 2, min_p=args.min_p, num_preceding_active_pitches=num_active_pitches, - midi_stream_channel=itt, + midi_stream_channel=curr_midi_channel, is_ending=True, ) @@ -1550,5 +1548,18 @@ def main(): midi.save(args.save_path) +def exit(midi_out_port: str): + with mido.open_output(midi_out_port) as out: + for note in range(128): + out.send(mido.Message("note_off", note=note, velocity=0)) + + if __name__ == "__main__": - main() + args = parse_args() + + try: + main(args) + except KeyboardInterrupt: + if args.midi_out: + exit(args.midi_out) + raise diff --git a/demo/dump.py b/demo/dump.py deleted file mode 100644 index d1a47d2..0000000 --- a/demo/dump.py +++ /dev/null @@ -1,75 +0,0 @@ -import mido -import sys -from mido import tick2second, second2tick - -# Check if correct number of arguments is provided -if len(sys.argv) != 3: - print("Usage: python script.py ") - sys.exit(1) - -# Get command line arguments -input_file = sys.argv[1] -target_seconds = float(sys.argv[2]) - -try: - mid = mido.MidiFile(input_file) -except Exception as e: - print(f"Error loading MIDI file: {e}") - sys.exit(1) - -curr_tick = 0 -idx = 0 -tempo = None - -# First get the tempo -for msg in mid.tracks[0]: - if msg.type == "set_tempo": - tempo = msg.tempo - break - -print(f"Found tempo: {tempo}") - -# Then find the right index -curr_tick = 0 -for idx, msg in enumerate(mid.tracks[0]): - curr_tick += msg.time - seconds = tick2second( - tick=curr_tick, - ticks_per_beat=mid.ticks_per_beat, - tempo=tempo, - ) - print(f"At index {idx}, time: {seconds:.2f} seconds") - if seconds > target_seconds: - print(f"Breaking at index {idx}") - break - -print(f"Inserting at index {idx}") - -# Insert the messages at the found index -mid.tracks[0].insert( - idx, - mido.Message( - type="control_change", - control=66, - value=127, - time=0, - ), -) -mid.tracks[0].insert( - idx + 1, - mido.Message( - type="control_change", - control=66, - value=0, - time=second2tick( - second=0.01, - ticks_per_beat=mid.ticks_per_beat, - tempo=tempo, - ), - ), -) - -# Generate output filename based on input filename -output_path = "/home/loubb/Dropbox/shared/test.mid" -mid.save(output_path) -print(f"Saved modified MIDI file to: {output_path}") diff --git a/requirements.txt b/requirements.txt index 5d6ef6d..6c397c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ torch >= 2.3 accelerate jsonlines tqdm -safetensors \ No newline at end of file +safetensors From 6e8aeab309df087b2dacfc3448ed54eb1b875b9c Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 3 Jun 2025 11:48:27 +0000 Subject: [PATCH 50/72] partial tree refactor for release --- aria/embedding.py | 78 ++ aria/embeddings/explore_midi.py | 112 --- aria/embeddings/pca.py | 120 --- aria/{embeddings => eval}/__init__.py | 0 .../evaluate.py => eval/linear_probe.py} | 76 +- aria/{embeddings => eval}/m3/__init__.py | 0 aria/{embeddings => eval}/m3/config.py | 0 aria/{embeddings => eval}/m3/emb.py | 4 +- aria/{embeddings => eval}/m3/utils.py | 0 aria/{embeddings => eval}/mert/__init__.py | 0 aria/{embeddings => eval}/mert/emb.py | 0 aria/inference/model.py | 29 +- aria/tokenizer.py | 309 ------ {tests => aria/training}/__init__.py | 0 .../classifier_finetune.py | 579 ++++++----- .../contrastive_finetune.py | 13 +- aria/training/train.py | 903 ++++++++++++++++++ 
config/accelerate.yaml | 16 - config/models/large.json | 9 - config/models/medium-composer.json | 25 +- config/models/medium-emotion.json | 11 + config/models/medium-form.json | 11 + config/models/medium-genre.json | 8 +- config/models/medium-music_period.json | 11 + config/models/medium-pianist.json | 11 + paper/scripts/build_aria_ft_emb_dataset.sh | 16 - .../build_dataset/build_aria_dataset.sh | 15 - .../build_dataset/build_clamp_dataset.sh | 13 - .../scripts/build_dataset/build_m3_dataset.sh | 15 - .../build_dataset/build_mert_dataset.sh | 15 - .../scripts/build_embedding_eval_datasets.py | 242 ----- .../scripts/evaluate_embedding_with_probe.py | 94 -- paper/scripts/evaluate_embeddings.sh | 10 - paper/scripts/make_eval_split.py | 133 --- paper/scripts/make_pianist8_dataset.py | 118 --- tests/test_data.py | 305 ------ tests/test_data/arabesque.mid | Bin 16975 -> 0 bytes tests/test_data/bach.mid | Bin 2790 -> 0 bytes tests/test_data/basic.mid | Bin 762 -> 0 bytes tests/test_data/beethoven_moonlight.mid | Bin 15362 -> 0 bytes tests/test_data/beethoven_sonata.mid | Bin 91825 -> 0 bytes tests/test_data/clean/1.mid | Bin 49388 -> 0 bytes tests/test_data/clean/2.mid | Bin 8447 -> 0 bytes tests/test_data/expressive.mid | Bin 4786 -> 0 bytes tests/test_data/noisy/1.mid | Bin 75600 -> 0 bytes tests/test_data/noisy/2.mid | Bin 25503 -> 0 bytes tests/test_data/pop.mid | Bin 12527 -> 0 bytes tests/test_data/pop_copy.mid | Bin 12527 -> 0 bytes tests/test_tokenizers.py | 535 ----------- 49 files changed, 1408 insertions(+), 2428 deletions(-) create mode 100644 aria/embedding.py delete mode 100644 aria/embeddings/explore_midi.py delete mode 100644 aria/embeddings/pca.py rename aria/{embeddings => eval}/__init__.py (100%) rename aria/{embeddings/evaluate.py => eval/linear_probe.py} (91%) rename aria/{embeddings => eval}/m3/__init__.py (100%) rename aria/{embeddings => eval}/m3/config.py (100%) rename aria/{embeddings => eval}/m3/emb.py (98%) rename aria/{embeddings => eval}/m3/utils.py (100%) rename aria/{embeddings => eval}/mert/__init__.py (100%) rename aria/{embeddings => eval}/mert/emb.py (100%) delete mode 100644 aria/tokenizer.py rename {tests => aria/training}/__init__.py (100%) rename aria/{embeddings => training}/classifier_finetune.py (52%) rename aria/{embeddings => training}/contrastive_finetune.py (98%) create mode 100644 aria/training/train.py delete mode 100644 config/accelerate.yaml delete mode 100644 config/models/large.json create mode 100644 config/models/medium-emotion.json create mode 100644 config/models/medium-form.json create mode 100644 config/models/medium-music_period.json create mode 100644 config/models/medium-pianist.json delete mode 100644 paper/scripts/build_aria_ft_emb_dataset.sh delete mode 100644 paper/scripts/build_dataset/build_aria_dataset.sh delete mode 100644 paper/scripts/build_dataset/build_clamp_dataset.sh delete mode 100644 paper/scripts/build_dataset/build_m3_dataset.sh delete mode 100644 paper/scripts/build_dataset/build_mert_dataset.sh delete mode 100644 paper/scripts/build_embedding_eval_datasets.py delete mode 100644 paper/scripts/evaluate_embedding_with_probe.py delete mode 100644 paper/scripts/evaluate_embeddings.sh delete mode 100644 paper/scripts/make_eval_split.py delete mode 100644 paper/scripts/make_pianist8_dataset.py delete mode 100644 tests/test_data.py delete mode 100644 tests/test_data/arabesque.mid delete mode 100644 tests/test_data/bach.mid delete mode 100644 tests/test_data/basic.mid delete mode 100644 
tests/test_data/beethoven_moonlight.mid delete mode 100644 tests/test_data/beethoven_sonata.mid delete mode 100644 tests/test_data/clean/1.mid delete mode 100644 tests/test_data/clean/2.mid delete mode 100644 tests/test_data/expressive.mid delete mode 100644 tests/test_data/noisy/1.mid delete mode 100644 tests/test_data/noisy/2.mid delete mode 100644 tests/test_data/pop.mid delete mode 100644 tests/test_data/pop_copy.mid delete mode 100644 tests/test_tokenizers.py diff --git a/aria/embedding.py b/aria/embedding.py new file mode 100644 index 0000000..81a1cdd --- /dev/null +++ b/aria/embedding.py @@ -0,0 +1,78 @@ +import torch +import copy + +from ariautils.midi import MidiDict +from ariautils.tokenizer import AbsTokenizer +from ariautils.tokenizer._base import Token + +from aria.model import TransformerEMB + + +def _validate_midi_for_emb(midi_dict: MidiDict): + present_instruments = { + midi_dict.program_to_instrument[msg["data"]] + for msg in midi_dict.instrument_msgs + } + assert present_instruments == {"piano"}, "Only piano MIDIs supported" + assert len(midi_dict.note_msgs) > 0 + + +def _get_chunks(midi_dict: MidiDict, notes_per_chunk: int): + res = [] + + for note_msg_chunk in [ + midi_dict.note_msgs[idx : idx + notes_per_chunk] + for idx in range(0, len(midi_dict.note_msgs), notes_per_chunk) + ]: + if len(note_msg_chunk) == 0: + break + + chunked_midi_dict = copy.deepcopy(midi_dict) + chunked_midi_dict.note_msgs = note_msg_chunk + chunked_midi_dict.metadata = {} + res.append(chunked_midi_dict) + + return res + + +@torch.inference_mode() +def get_embedding_from_seq( + model: TransformerEMB, seq: list[Token], device="cuda" +): + tokenizer = AbsTokenizer() + + assert len(seq) <= 2048, "Sequence lengths above 2048 not supported" + _validate_midi_for_emb(tokenizer.detokenize(seq)) + + eos_pos = seq.index(tokenizer.eos_tok) + seq_enc = torch.tensor(tokenizer.encode(seq), device=device) + emb = model.forward(seq_enc.view(1, -1))[0, eos_pos] + + return emb + + +def get_global_embedding_from_midi( + model: TransformerEMB, + midi_dict: MidiDict | None = None, + midi_path: str | None = None, + notes_per_chunk: int = 300, + device="cuda", +): + """Calculates global contrastive embedding by calculating an unweighted + average of chunk embeddings of notes_per_chunk notes.""" + + assert midi_dict or midi_path + + if midi_path: + midi_dict = MidiDict.from_midi(mid_path=midi_path) + + tokenizer = AbsTokenizer() + _validate_midi_for_emb(midi_dict) + + chunks = _get_chunks(midi_dict=midi_dict, notes_per_chunk=notes_per_chunk) + seqs = [tokenizer.tokenize(c, add_dim_tok=False)[:2048] for c in chunks] + embs = [ + get_embedding_from_seq(model=model, seq=s, device=device) for s in seqs + ] + + return torch.mean(torch.stack(embs), dim=0) diff --git a/aria/embeddings/explore_midi.py b/aria/embeddings/explore_midi.py deleted file mode 100644 index 13d38a0..0000000 --- a/aria/embeddings/explore_midi.py +++ /dev/null @@ -1,112 +0,0 @@ -import copy -import torch - -from aria.config import load_model_config -from aria.utils import _load_weight -from ariautils.midi import MidiDict -from ariautils.tokenizer import AbsTokenizer -from aria.model import TransformerCL, ModelConfig - -TAG_IDS = { - "chopin": 0, - "bach": 1, - "beethoven": 2, - "liszt": 3, - "mozart": 4, - "debussy": 5, - "schumann": 6, - "schubert": 7, - "rachmaninoff": 8, - "brahms": 9, - "tchaikovsky": 10, - "haydn": 11, - "scriabin": 12, - "mendelssohn": 13, - "czerny": 14, - "ravel": 15, - "scarlatti": 16, - "other": 17, -} -ID_TO_TAG = {v: k for k, 
v in TAG_IDS.items()} - - -def explore_midi( - midi_path: str, - checkpoint_path: str, - metadata_category: str, - slice_len_notes: int = 500, - max_seq_len: int = 2048, -): - midi_dict = MidiDict.from_midi(midi_path) - print(midi_dict.instrument_msgs) - - tag = midi_dict.metadata.get(metadata_category, None) - if tag is not None and tag not in TAG_IDS: - tag = "other" - - note_msgs = midi_dict.note_msgs - slices = [ - note_msgs[i : i + slice_len_notes] - for i in range(0, len(note_msgs), slice_len_notes) - ] - slices = [s for s in slices if len(s) >= 20] - - print(f"Found {len(slices)} slices in the MIDI file.") - - tokenizer = AbsTokenizer() - model_config = ModelConfig(**load_model_config("medium-composer")) - model_config.set_vocab_size(tokenizer.vocab_size) - model_config.grad_checkpoint = False - model_state = _load_weight(checkpoint_path, device="cuda") - model = TransformerCL(model_config) - model.load_state_dict(model_state) - model.eval() - model.cuda() - - for idx, note_slice in enumerate(slices): - slice_midi = copy.deepcopy(midi_dict) - slice_midi.note_msgs = note_slice - slice_midi.metadata = {} - - tokenized_seq = tokenizer.tokenize(slice_midi) - tokenizer.detokenize(tokenized_seq).to_midi().save( - "/home/loubb/Dropbox/shared/test.mid" - ) - if tokenizer.dim_tok in tokenized_seq: - tokenized_seq.remove(tokenizer.dim_tok) - tokenized_seq = tokenized_seq[:max_seq_len] - if tokenizer.eos_tok not in tokenized_seq: - tokenized_seq[-1] = tokenizer.eos_tok - - tokenizer - encoded_seq = tokenizer.encode(tokenized_seq) - input_tensor = torch.tensor([encoded_seq]).cuda() - - # Forward pass - with torch.inference_mode(): - logits = model(input_tensor)[0, -1, :] - probs = torch.softmax(logits, dim=-1) - # Get the top 5 probabilities and their corresponding indices - top_probs, top_indices = torch.topk(probs, k=5) - formatted_top_probs = [ - float(f"{p:.4f}") for p in top_probs.tolist() - ] - top_tags = [ - ID_TO_TAG.get(idx.item(), "unknown") for idx in top_indices - ] - - print("Top 5 Predictions:") - for tag, prob in zip(top_tags, formatted_top_probs): - print(f"{tag}: {prob}") - - input("\nPress Enter to continue to the next slice...") - - -if __name__ == "__main__": - explore_midi( - midi_path="/home/loubb/Dropbox/shared/audio.mid", - checkpoint_path="/home/loubb/work/aria/models/medium-composer.safetensors", - metadata_category="composer", - slice_len_notes=150, - max_seq_len=512, - ) diff --git a/aria/embeddings/pca.py b/aria/embeddings/pca.py deleted file mode 100644 index 2d1c309..0000000 --- a/aria/embeddings/pca.py +++ /dev/null @@ -1,120 +0,0 @@ -import matplotlib.pyplot as plt -import json -import numpy as np -import pandas as pd -import seaborn as sns -from sklearn.decomposition import PCA -from sklearn.manifold import TSNE - - -# Flag to choose between t-SNE and PCA -use_tsne = True - -# Load data from the JSON file -with open("aria_embeddings.json", "r") as f: - data = json.load(f) - -# Define the set of top composers -top_composers = { - "chopin", - "bach", - "handel", - "haydn", - "tchaikovsky", - "scriabin", - "beethoven", - "liszt", - "mozart", - "debussy", - "schumann", - "schubert", - "satie", - "rachmaninoff", - "brahms", - "ravel", -} - - -# Filter the data to include only entries for the top composers -filtered_data = [entry for entry in data if entry["composer"] in top_composers] - -# Extract embeddings and composers from the filtered data -embeddings = np.array([entry["emb"] for entry in filtered_data]) -embeddings = embeddings / np.linalg.norm(embeddings, 
axis=1, keepdims=True) -composers = [entry["composer"].capitalize() for entry in filtered_data] - -# Perform dimensionality reduction based on the flag -if use_tsne: - reducer = TSNE( - n_components=2, perplexity=50, max_iter=2500, random_state=43 - ) - title = "t-SNE Visualization of Composer Embeddings" - filename = "/home/loubb/work/aria/tsne_plot.png" -else: - reducer = PCA(n_components=2) - filename = "/home/loubb/work/aria/pca_plot.png" - -embeddings_2d = reducer.fit_transform(embeddings) - -# Create a DataFrame for plotting -df = pd.DataFrame( - { - "Dimension 1": embeddings_2d[:, 0], - "Dimension 2": embeddings_2d[:, 1], - "Composer": composers, - } -) - -# Set the aesthetic style of the plots -sns.set_theme(style="whitegrid", font="Helvetica") - -# Create the scatter plot -plt.figure(figsize=(12, 8)) -scatter_plot = sns.scatterplot( - data=df, - x="Dimension 1", - y="Dimension 2", - hue="Composer", - palette="tab20", - s=50, # Marker size - edgecolor="w", - linewidth=0.5, -) - -plt.xlabel(None) # Remove x-axis label -plt.ylabel(None) # Remove y-axis label - -plt.xticks([]) # Remove numerical x-axis ticks -plt.yticks([]) # Remove numerical y-axis ticks - -plt.grid(True, linestyle="--", linewidth=0.5) # Keep the grid visible - -# Ensure grid is properly aligned -plt.gca().set_aspect("auto") # Prevent distortion -plt.gca().set_frame_on(True) # Keep figure frame -# plt.gca().set_xticks( -# np.linspace(df["Dimension 1"].min(), df["Dimension 1"].max(), num=6) -# ) -# plt.gca().set_yticks( -# np.linspace(df["Dimension 2"].min(), df["Dimension 2"].max(), num=6) -# ) - -# Move the legend outside the plot -plt.legend( - bbox_to_anchor=(0, -0.38, 1, 0), - loc="lower center", - ncol=4, # Arrange in multiple columns - fontsize=20, # Increase font size - columnspacing=1.05, # Increase the space between columns - title_fontsize=20, - title="Composer", -) - -# Set plot title and labels -# plt.title(title) - -# Save the plot as a high-resolution PNG file -plt.savefig(filename, dpi=300, bbox_inches="tight") - -# Display the plot -plt.show() diff --git a/aria/embeddings/__init__.py b/aria/eval/__init__.py similarity index 100% rename from aria/embeddings/__init__.py rename to aria/eval/__init__.py diff --git a/aria/embeddings/evaluate.py b/aria/eval/linear_probe.py similarity index 91% rename from aria/embeddings/evaluate.py rename to aria/eval/linear_probe.py index bd1ab2b..2db8c25 100644 --- a/aria/embeddings/evaluate.py +++ b/aria/eval/linear_probe.py @@ -16,9 +16,6 @@ from typing import Callable from concurrent.futures import ThreadPoolExecutor -from aria.model import ModelConfig, TransformerLM -from aria.config import load_model_config -from aria.utils import _load_weight from ariautils.midi import MidiDict from ariautils.tokenizer import AbsTokenizer @@ -63,6 +60,12 @@ "yiruma": 6, "hillsong": 7, }, + "emotion": { + "happy": 0, + "sad": 1, + "calm": 2, + "tense": 3, + }, } LEARNING_RATE = 3e-4 @@ -101,7 +104,7 @@ def process_entry( for slice_note_msgs in get_chunks( note_msgs=midi_dict.note_msgs, chunk_len=slice_len_notes ): - if len(slice_note_msgs) < 20: + if len(slice_note_msgs) == 0: break slice_midi_dict = copy.deepcopy(midi_dict) @@ -598,9 +601,6 @@ def _train( scheduler.step() lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) - if accelerator.is_main_process: - accelerator.save_state("/mnt/ssd1/aria/test") - return model @@ -633,7 +633,7 @@ def train_classifier( model=model, total_steps=num_epochs * len(train_dataloader), ) - accelerator = accelerate.Accelerator() + accelerator 
= accelerate.Accelerator(cpu=True) model, train_dataloader, optimizer, scheduler = accelerator.prepare( model, @@ -686,8 +686,10 @@ def evaluate_classifier( total_correct = sum(v["correct"] for v in dist.values()) total_samples = sum(v["total"] for v in dist.values()) - print(f"Total accuracy: {total_correct/total_samples}") + overall_accuracy = total_correct / total_samples + class_metrics = {} + f1_scores = [] for tag in tag_to_id.keys(): TP = dist[tag]["correct"] FN = dist[tag]["total"] - TP @@ -699,45 +701,21 @@ def evaluate_classifier( if (precision + recall) > 0 else 0 ) - print( - f"{tag} -- Accuracy: {TP/dist[tag]['total']}, Precision: {precision}, Recall: {recall}, F1: {f1}" - ) - - -# TODO: Move this to the build_embedding_eval_datasets.py script -def build_baseline_dataset(): - MAX_SEQ_LEN = 512 - MODEL_PATH = "/mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors" - - tokenizer = AbsTokenizer() - model_state = _load_weight(MODEL_PATH, "cuda") - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() + tag_accuracy = TP / dist[tag]["total"] if dist[tag]["total"] > 0 else 0 + class_metrics[tag] = { + "accuracy": tag_accuracy, + "precision": precision, + "recall": recall, + "F1": f1, + } + f1_scores.append(f1) + + macro_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0 + + results = { + "accuracy": overall_accuracy, + "F1-macro": macro_f1, + "class_wise": class_metrics, } - pretrained_model_config = ModelConfig(**load_model_config("medium")) - pretrained_model_config.set_vocab_size(tokenizer.vocab_size) - pretrained_model_config.grad_checkpoint = False - pretrained_model = TransformerLM(pretrained_model_config) - pretrained_model.load_state_dict(model_state) - pretrained_model.eval() - - global model_forward - model_forward = torch.compile( - model_forward, - mode="reduce-overhead", - fullgraph=True, - ) - EvaluationDataset.build( - midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_train.jsonl", - save_path="/mnt/ssd1/aria/data/train.jsonl", - max_seq_len=MAX_SEQ_LEN, - slice_len_notes=165, - batch_size=128, - embedding_hook=functools.partial( - get_baseline_embedding, pool_mode="mean" - ), - hook_model=pretrained_model.model.cuda(), - hook_max_seq_len=MAX_SEQ_LEN, - hook_tokenizer=tokenizer, - ) + return results diff --git a/aria/embeddings/m3/__init__.py b/aria/eval/m3/__init__.py similarity index 100% rename from aria/embeddings/m3/__init__.py rename to aria/eval/m3/__init__.py diff --git a/aria/embeddings/m3/config.py b/aria/eval/m3/config.py similarity index 100% rename from aria/embeddings/m3/config.py rename to aria/eval/m3/config.py diff --git a/aria/embeddings/m3/emb.py b/aria/eval/m3/emb.py similarity index 98% rename from aria/embeddings/m3/emb.py rename to aria/eval/m3/emb.py index cd68aea..e3af968 100644 --- a/aria/embeddings/m3/emb.py +++ b/aria/eval/m3/emb.py @@ -3,7 +3,7 @@ import mido from transformers import BertConfig, GPT2Config -from aria.embeddings.m3.config import ( +from aria.eval.m3.config import ( AUDIO_HIDDEN_SIZE, AUDIO_NUM_LAYERS, MAX_AUDIO_LENGTH, @@ -16,7 +16,7 @@ TOKEN_NUM_LAYERS, ) -from aria.embeddings.m3.utils import CLaMP3Model, M3Patchilizer, M3Model +from aria.eval.m3.utils import CLaMP3Model, M3Patchilizer, M3Model def msg_to_str(msg): diff --git a/aria/embeddings/m3/utils.py b/aria/eval/m3/utils.py similarity index 100% rename from aria/embeddings/m3/utils.py rename to aria/eval/m3/utils.py diff --git a/aria/embeddings/mert/__init__.py b/aria/eval/mert/__init__.py 
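similarity index 100%

Aside on the linear_probe change above: evaluate_classifier now returns per-class precision, recall, and F1 plus a macro-F1 in a results dict instead of printing them. For reference, the arithmetic follows the standard definitions; a minimal, self-contained sketch with hypothetical counts (not the repository's data):

    def f1_from_counts(TP: int, FP: int, FN: int) -> float:
        # Precision: fraction of predictions for this tag that were correct
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        # Recall: fraction of true instances of this tag that were found
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
        return (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0
        )

    # Macro-F1 averages per-tag scores, weighting every class equally
    f1_scores = [f1_from_counts(8, 2, 1), f1_from_counts(5, 1, 4)]
    macro_f1 = sum(f1_scores) / len(f1_scores)
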
similarity index 100% rename from aria/embeddings/mert/__init__.py rename to aria/eval/mert/__init__.py diff --git a/aria/embeddings/mert/emb.py b/aria/eval/mert/emb.py similarity index 100% rename from aria/embeddings/mert/emb.py rename to aria/eval/mert/emb.py diff --git a/aria/inference/model.py b/aria/inference/model.py index d4bf439..707e200 100644 --- a/aria/inference/model.py +++ b/aria/inference/model.py @@ -43,6 +43,9 @@ def __init__(self, model_config: ModelConfig): self.lm_head = nn.Linear( model_config.d_model, model_config.vocab_size, bias=False ) + self.embedding_adapter = nn.Linear( + model_config.emb_size, model_config.d_model, bias=False + ) def forward( self, @@ -59,13 +62,17 @@ def forward( return logits + def fill_condition_kv(self, cond_emb: torch.Tensor): + adapted_emb = self.embedding_adapter(cond_emb) + self.model.fill_condition_kv(emb=adapted_emb) + def setup_cache( self, - batch_size, + batch_size: int, max_seq_len=4096, dtype=torch.bfloat16, ): - # Init cache + assert batch_size >= 1 for b in self.model.encode_layers: b.kv_cache = KVCache( max_batch_size=batch_size, @@ -101,7 +108,19 @@ def __init__(self, model_config: ModelConfig) -> None: self.out_layer_norm = nn.LayerNorm(model_config.d_model) self.freqs_cis = None - self.casual_mask = None + self.causal_mask = None + + def fill_condition_kv(self, emb: torch.Tensor): + assert self.freqs_cis is not None, "Caches must be initialized first" + + input_pos = torch.tensor([0], device=emb.device) + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + + x = emb.unsqueeze(dim=1) + + for layer in self.encode_layers: + x = layer(x, input_pos, freqs_cis, mask) def forward( self, @@ -169,7 +188,6 @@ def __init__(self, model_config: ModelConfig) -> None: self.norm1 = nn.LayerNorm(model_config.d_model) self.norm2 = nn.LayerNorm(model_config.d_model) - # TODO: Fill in args self.kv_cache = None def forward( @@ -255,7 +273,8 @@ def precompute_freqs_cis( return cache.to(dtype=dtype) -@torch.jit.script +# TODO: Fix +# @torch.jit.script def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: """ In-place RoPE. 
Credits to Katherine Crowson: diff --git a/aria/tokenizer.py b/aria/tokenizer.py deleted file mode 100644 index 55ca846..0000000 --- a/aria/tokenizer.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Tokenizer for MIDI conditioned completions""" - -import copy -import random -import functools - -from typing import Callable - -from aria.config import load_config -from ariautils.midi import MidiDict -from ariautils.tokenizer import AbsTokenizer as _AbsTokenizer - - -class InferenceAbsTokenizer(_AbsTokenizer): - def __init__(self): - super().__init__() - - self.name = "inference_abs" - self._config = load_config()["tokenizer"]["inference_abs"] - - self.prompt_start_tok = "" - self.prompt_end_tok = "" - self.guidance_start_tok = "" - self.guidance_end_tok = "" - - self.add_tokens_to_vocab( - [ - self.prompt_start_tok, - self.prompt_end_tok, - self.guidance_start_tok, - self.guidance_end_tok, - ] - ) - self.special_tokens.append(self.prompt_start_tok) - self.special_tokens.append(self.prompt_end_tok) - self.special_tokens.append(self.guidance_start_tok) - self.special_tokens.append(self.guidance_end_tok) - - def _get_guidance_interval_ms(self, guidance_midi_dict: MidiDict): - first_note_onset_ms = guidance_midi_dict.tick_to_ms( - guidance_midi_dict.note_msgs[0]["tick"] - ) - last_note_onset_ms = guidance_midi_dict.tick_to_ms( - guidance_midi_dict.note_msgs[-1]["tick"] - ) - guidance_segment_length_ms = random.randint( - self._config["guidance"]["min_ms"], - min(self._config["guidance"]["max_ms"], last_note_onset_ms), - ) - guidance_start_ms = random.randint( - first_note_onset_ms, - last_note_onset_ms - guidance_segment_length_ms, - ) - guidance_end_ms = guidance_start_ms + guidance_segment_length_ms - - return guidance_start_ms, guidance_end_ms - - def _get_guidance_seq( - self, - guidance_midi_dict: MidiDict, - guidance_start_ms: int | None = None, - guidance_end_ms: int | None = None, - ): - assert guidance_midi_dict.note_msgs is not None - - # Need to validate these numbers - if guidance_start_ms is None: - assert guidance_end_ms is None - guidance_start_ms, guidance_end_ms = self._get_guidance_interval_ms( - guidance_midi_dict=guidance_midi_dict - ) - - slice_note_msgs = [] - for note_msg in guidance_midi_dict.note_msgs: - start_ms = guidance_midi_dict.tick_to_ms(note_msg["data"]["start"]) - if guidance_start_ms <= start_ms <= guidance_end_ms: - slice_note_msgs.append(note_msg) - - slice_midi_dict = copy.deepcopy(guidance_midi_dict) - slice_midi_dict.note_msgs = slice_note_msgs - - if len(slice_midi_dict.note_msgs) == 0: - # Catches not note in interval - return [] - - guidance_seq = self._tokenize_midi_dict( - midi_dict=slice_midi_dict, - remove_preceding_silence=True, - ) - - if self.dim_tok in guidance_seq: - guidance_seq.remove(self.dim_tok) - - guidance_seq = guidance_seq[ - guidance_seq.index(self.bos_tok) - + 1 : guidance_seq.index(self.eos_tok) - ] - - return ( - [self.guidance_start_tok] + guidance_seq + [self.guidance_end_tok] - ) - - def _add_prompt_tokens( - self, seq: list, prompt_start_ms: int, prompt_end_ms: int - ): - res = copy.deepcopy(seq) - prompt_tok_inserted = False - time_tok_cnt = 0 - curr_time_ms = 0 - for idx, (tok_1, tok_2) in enumerate(zip(seq, seq[1:])): - if tok_1 == self.time_tok: - time_tok_cnt += 1 - elif isinstance(tok_1, tuple) and tok_1[0] in self.instruments_wd: - assert isinstance(tok_2, tuple) and tok_2[0] == "onset" - - # Adjust time - curr_time_ms = (self.abs_time_step_ms * time_tok_cnt) + tok_2[1] - - if ( - curr_time_ms >= prompt_start_ms - and 
prompt_tok_inserted == False - ): - res.insert(idx, self.prompt_start_tok) - prompt_tok_inserted = True - elif ( - curr_time_ms > prompt_end_ms and prompt_tok_inserted == True - ): - # Res has already been shifted +1 when inserting prompt_tok - res.insert(idx + 1, self.prompt_end_tok) - break - - if prompt_tok_inserted and self.prompt_end_tok not in res: - res.insert(-1, self.prompt_end_tok) - assert res[-1] == self.eos_tok - assert res[-2] == self.prompt_end_tok - - return res - - def tokenize( - self, - midi_dict: MidiDict, - prompt_intervals_ms: list[tuple[int, int]], - guidance_midi_dict: MidiDict | None = None, - guidance_start_ms: int | None = None, - guidance_end_ms: int | None = None, - ): - seq = self._tokenize_midi_dict( - midi_dict=midi_dict, remove_preceding_silence=True - ) - first_note_ms = midi_dict.tick_to_ms( - midi_dict.note_msgs[0]["data"]["start"] - ) - - for prompt_start_ms, prompt_end_ms in prompt_intervals_ms: - if prompt_end_ms > first_note_ms: - seq = self._add_prompt_tokens( - seq, - prompt_start_ms=prompt_start_ms - first_note_ms, - prompt_end_ms=prompt_end_ms - first_note_ms, - ) - - if guidance_midi_dict is not None: - guidance_seq = self._get_guidance_seq( - guidance_midi_dict=guidance_midi_dict, - guidance_start_ms=guidance_start_ms, - guidance_end_ms=guidance_end_ms, - ) - else: - guidance_seq = [] - - return guidance_seq + seq - - def detokenize(self, tokenized_seq: list, **kwargs): - if self.guidance_end_tok in tokenized_seq: - seq = tokenized_seq[tokenized_seq.index(self.guidance_end_tok) :] - else: - seq = tokenized_seq - - return super()._detokenize_midi_dict(seq, **kwargs) - - def export_data_aug(self): - return [ - self.export_guidance_tempo_aug(max_tempo_aug=0.2, mixup=True), - self.export_guidance_pitch_aug(3), - self.export_guidance_velocity_aug(2), - ] - - def export_guidance_aug_fn(self, aug_fn): - """Transforms augmentation function to only apply to guidance seq""" - - def _guidance_seq_aug_fn( - src: list, - _aug_fn: Callable, - pad_tok: str, - **kwargs, - ) -> list: - - initial_seq_len = len(src) - if self.guidance_start_tok in src and self.guidance_end_tok in src: - guidance_seq = src[ - src.index(self.guidance_start_tok) - + 1 : src.index(self.guidance_end_tok) - ] - seq = src[src.index(self.guidance_end_tok) + 1 :] - - if len(guidance_seq) == 0: - return src - else: - return src - - augmented_guidance_seq = _aug_fn(guidance_seq) - res = ( - [self.guidance_start_tok] - + augmented_guidance_seq - + [self.guidance_end_tok] - + seq - ) - - # Pad or truncate to original sequence length as necessary - res = res[:initial_seq_len] - res += [pad_tok] * (initial_seq_len - len(res)) - - return res - - return functools.partial( - _guidance_seq_aug_fn, - _aug_fn=aug_fn, - pad_tok=self.pad_tok, - ) - - def export_guidance_pitch_aug(self, max_pitch_aug: int): - """Apply pitch augmentation to the guidance sequence""" - - return self.export_guidance_aug_fn( - self.export_pitch_aug(max_pitch_aug=max_pitch_aug) - ) - - def export_guidance_velocity_aug(self, max_num_aug_steps: int): - """Apply velocity augmentation to the guidance sequence""" - - return self.export_guidance_aug_fn( - self.export_velocity_aug(max_num_aug_steps=max_num_aug_steps) - ) - - def export_guidance_tempo_aug(self, max_tempo_aug: int, mixup: bool): - """Apply tempo augmentation to the guidance sequence""" - - return self.export_guidance_aug_fn( - self.export_tempo_aug(max_tempo_aug=max_tempo_aug, mixup=mixup) - ) - - def split(self, seq: list, seq_len: int): - def _process_chunk(_chunk: 
list): - # Ensure first token is note token - while True: - if _chunk[0] == self.bos_tok: - break - elif ( - isinstance(_chunk[0], tuple) - and _chunk[0][0] in self.instruments_wd - ): - break - else: - _chunk.pop(0) - - # Insert prompt_start_tok if it is missing (but required) - for idx in range(len(_chunk)): - tok = _chunk[idx] - - if tok == self.prompt_start_tok: - break - elif tok == self.prompt_end_tok: - if _chunk[0] == self.bos_tok: - _chunk.insert(1, self.prompt_start_tok) - else: - _chunk.insert(0, self.prompt_start_tok) - break - - return _chunk - - guidance = [] - if self.guidance_start_tok in seq: - guidance_start = seq.index(self.guidance_start_tok) - guidance_end = seq.index(self.guidance_end_tok) - guidance = seq[guidance_start : guidance_end + 1] - seq = seq[guidance_end + 1 :] - - prefix = [] - while seq: - tok = seq[0] - if tok != self.bos_tok and tok[0] == "prefix": - prefix.append(seq.pop(0)) - else: - break - - chunks = [ - _process_chunk(seq[idx : idx + seq_len]) - for idx in range(0, len(seq) - 100, seq_len) - ] - - res = [] - for chunk in chunks: - sub_seq = guidance + prefix + chunk - sub_seq = sub_seq[:seq_len] - sub_seq += [self.pad_tok] * (seq_len - len(sub_seq)) - - res.append(sub_seq) - - return res diff --git a/tests/__init__.py b/aria/training/__init__.py similarity index 100% rename from tests/__init__.py rename to aria/training/__init__.py diff --git a/aria/embeddings/classifier_finetune.py b/aria/training/classifier_finetune.py similarity index 52% rename from aria/embeddings/classifier_finetune.py rename to aria/training/classifier_finetune.py index ac4b68a..ccc7712 100644 --- a/aria/embeddings/classifier_finetune.py +++ b/aria/training/classifier_finetune.py @@ -3,48 +3,70 @@ import mmap import argparse import logging -import random -import copy -import functools import accelerate -import multiprocessing import json -import jsonlines from aria.config import load_model_config from aria.utils import _load_weight from ariautils.tokenizer import AbsTokenizer -from ariautils.midi import MidiDict from aria.model import TransformerCL, ModelConfig from torch import nn from torch.utils.data import DataLoader, Dataset from accelerate.logging import get_logger +from typing import Callable from logging.handlers import RotatingFileHandler from tqdm import tqdm -TAG_IDS = { - "chopin": 0, - "bach": 1, - "beethoven": 2, - "liszt": 3, - "mozart": 4, - "debussy": 5, - "schumann": 6, - "schubert": 7, - "rachmaninoff": 8, - "brahms": 9, - "tchaikovsky": 10, - "haydn": 11, - "scriabin": 12, - "mendelssohn": 13, - "czerny": 14, - "ravel": 15, - "scarlatti": 16, - "other": 17, +CATEGORY_TAGS = { + "genre": { + "classical": 0, + "jazz": 1, + }, + "music_period": { + "baroque": 0, + "classical": 1, + "romantic": 2, + "impressionist": 3, + }, + "composer": { + "beethoven": 0, + "debussy": 1, + "brahms": 2, + "rachmaninoff": 3, + "schumann": 4, + "mozart": 5, + "liszt": 6, + "bach": 7, + "chopin": 8, + "schubert": 9, + }, + "form": { + "nocturne": 0, + "sonata": 1, + "improvisation": 2, + "etude": 3, + "fugue": 4, + "waltz": 5, + }, + "pianist": { + "hisaishi": 0, + "hancock": 1, + "bethel": 2, + "einaudi": 3, + "clayderman": 4, + "ryuichi": 5, + "yiruma": 6, + "hillsong": 7, + }, + "emotion": { + "happy": 0, + "sad": 1, + "calm": 2, + "tense": 3, + }, } -METADATA_CATEGORY = "composer" def setup_logger(project_dir: str): @@ -121,72 +143,30 @@ def setup_project_dir(project_dir: str | None): return project_dir_abs -def process_entry( - entry, - metadata_category: str, - tag_ids: 
dict, - min_slice_notes: int, - max_slice_notes: int, - max_seq_len: int, - tokenizer: AbsTokenizer, -): - midi_dict = MidiDict.from_msg_dict(entry) - metadata_tag = midi_dict.metadata.get(metadata_category, None) - - # Skip if metadata tag is missing or not in tag_ids. - if metadata_tag is None: - return [] - elif metadata_tag not in tag_ids: - metadata_tag = "other" - - outputs = [] - note_msgs = midi_dict.note_msgs - idx = 0 - - while idx < len(note_msgs): - slice_length = random.randint(min_slice_notes, max_slice_notes) - chunk = note_msgs[idx : idx + slice_length] - - # If the chunk is too short, break out of the loop. - if len(chunk) < min_slice_notes: - break - - idx += slice_length - - # Create slice - slice_midi_dict = copy.deepcopy(midi_dict) - slice_midi_dict.note_msgs = chunk - slice_midi_dict.metadata = {} - - # Format - tokenized_slice = tokenizer.tokenize(slice_midi_dict) - if tokenizer.dim_tok in tokenized_slice: - tokenized_slice.remove(tokenizer.dim_tok) - - # Use EOS tok for classification head - tokenized_slice = tokenized_slice[:max_seq_len] - tokenized_slice += [tokenizer.pad_tok] * ( - max_seq_len - len(tokenized_slice) - ) - if tokenizer.eos_tok not in tokenized_slice: - tokenized_slice[-1] = tokenizer.eos_tok - - pos = tokenized_slice.index(tokenizer.eos_tok) - - outputs.append( - {"seq": tokenized_slice, "tag": metadata_tag, "pos": pos} - ) - - return outputs - - class FinetuningDataset(Dataset): - def __init__(self, load_path: str, tag_ids: dict): + def __init__( + self, + load_path: str, + tag_to_id: dict, + metadata_category: str, + max_seq_len: int, + per_file: bool = False, + ): self.load_path = load_path - self.tag_ids = tag_ids + self.tag_to_id = tag_to_id + self.metadata_category = metadata_category + self.max_seq_len = max_seq_len + self.per_file = per_file + self._transform = None self.tokenizer = AbsTokenizer() self.index = [] + assert metadata_category in CATEGORY_TAGS.keys() + assert all( + tag_to_id[_t] == _id + for _t, _id in CATEGORY_TAGS[metadata_category].items() + ) + self.file_buff = open(self.load_path, "rb") self.mmap_obj = mmap.mmap( self.file_buff.fileno(), 0, access=mmap.ACCESS_READ @@ -199,6 +179,24 @@ def __init__(self, load_path: str, tag_ids: dict): break self.index.append(pos) + def set_transform(self, transform: Callable | list[Callable]): + if isinstance(transform, Callable): + self._transform = transform + elif isinstance(transform, list): + # Check validity + for fn in transform: + assert isinstance(fn, Callable), "Invalid function" + + # Define new transformation function (apply fn in order) + def _new_transform(x): + for fn in transform: + x = fn(x) + return x + + self._transform = _new_transform + else: + raise ValueError("Must provide function or list of functions.") + def __getitem__(self, idx: int): def _format(tok): # Required because json formats tuples into lists @@ -206,24 +204,53 @@ def _format(tok): return tuple(tok) return tok - file_pos = self.index[idx] - self.mmap_obj.seek(file_pos) - + pos = self.index[idx] + self.mmap_obj.seek(pos) raw_data = self.mmap_obj.readline().decode("utf-8") json_data = json.loads(raw_data) - seq, tag, pos = json_data["seq"], json_data["tag"], json_data["pos"] - assert tag in self.tag_ids.keys() - assert pos < len(seq) + metadata = json_data["metadata"] + tag = metadata[self.metadata_category] - seq = [_format(tok) for tok in seq] - seq_enc = torch.tensor(self.tokenizer.encode(seq)) - tag_enc = torch.tensor(self.tag_ids[tag]) - pos_enc = torch.tensor(pos) + assert tag in self.tag_to_id, 
metadata + tag_tensor = torch.tensor(self.tag_to_id[tag]) + + if self.per_file: + seq_list = json_data["seqs"] + else: + seq_list = [json_data["seq"]] - assert seq_enc[pos_enc.item()].item() == 1 # EOS ID + seq_tensors = [] + pos_tensors = [] + for seq in seq_list: + seq = [_format(tok) for tok in seq] - return seq_enc, tag_enc, pos_enc + if self._transform: + seq = self._transform(seq) + + seq = seq[: self.max_seq_len] + if self.tokenizer.eos_tok not in seq: + assert self._transform is not None + seq[-1] = self.tokenizer.eos_tok + + eos_index = seq.index(self.tokenizer.eos_tok) + pos_tensor = torch.tensor(eos_index) + + assert len(seq) <= self.max_seq_len + + seq = seq + [self.tokenizer.pad_tok] * (self.max_seq_len - len(seq)) + encoded_seq = self.tokenizer.encode(seq) + seq_tensor = torch.tensor(encoded_seq) + + assert seq_tensor[pos_tensor.item()].item() == 1 # EOS ID check + + seq_tensors.append(seq_tensor) + pos_tensors.append(pos_tensor) + + seq_tensor = torch.stack(seq_tensors) + pos_tensor = torch.stack(pos_tensors) + + return seq_tensor, pos_tensor, tag_tensor def __len__(self): return len(self.index) @@ -242,48 +269,6 @@ def worker_init_fn(worker_id: int): return worker_init_fn - @classmethod - def build( - cls, - midi_dataset_load_path: str, - save_path: str, - min_slice_notes: int, - max_slice_notes: int, - max_seq_len: int, - metadata_category: str, - tag_ids: dict, - ): - assert os.path.isfile(midi_dataset_load_path) - assert os.path.isfile(save_path) is False - - tokenizer = AbsTokenizer() - - with jsonlines.open( - midi_dataset_load_path, "r" - ) as midi_dataset, jsonlines.open(save_path, "w") as writer: - - cnt = 0 - with multiprocessing.Pool() as pool: - for result in pool.imap_unordered( - functools.partial( - process_entry, - metadata_category=metadata_category, - tag_ids=tag_ids, - min_slice_notes=min_slice_notes, - max_slice_notes=max_slice_notes, - max_seq_len=max_seq_len, - tokenizer=tokenizer, - ), - midi_dataset, - chunksize=10, - ): - cnt += 1 - if cnt % 500 == 0: - print(f"Completed {cnt}") - - for chunk in result: - writer.write(chunk) - def _get_optim( lr: float, @@ -291,7 +276,7 @@ def _get_optim( num_epochs: int, steps_per_epoch: int, warmup: int = 100, - end_ratio: int = 0.1, + end_ratio: float = 0.1, ): optimizer = torch.optim.AdamW( model.parameters(), @@ -301,24 +286,33 @@ def _get_optim( eps=1e-5, ) - warmup_lrs = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=0.000001, - end_factor=1, - total_iters=warmup, - ) - linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=1, - end_factor=end_ratio, - total_iters=(num_epochs * steps_per_epoch) - warmup, - ) + total_steps = num_epochs * steps_per_epoch - lr_scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, - schedulers=[warmup_lrs, linear_decay_lrs], - milestones=[warmup], - ) + if warmup > 0: + warmup_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.000001, + end_factor=1, + total_iters=warmup, + ) + linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=total_steps - warmup, + ) + lr_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_lrs, linear_decay_lrs], + milestones=[warmup], + ) + else: + lr_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=total_steps, + ) return optimizer, lr_scheduler @@ -330,7 +324,7 @@ def get_optim( ): LR = 1e-5 END_RATIO = 0.1 - 
WARMUP_STEPS = 1000 + WARMUP_STEPS = 0 return _get_optim( lr=LR, @@ -345,31 +339,46 @@ def get_optim( def get_dataloaders( train_data_path: str, val_data_path: str, + metadata_category: str, + tag_to_id: dict, batch_size: int, num_workers: int, - apply_aug=True, + apply_aug: bool = False, + max_seq_len: int = 1024, ): train_dataset = FinetuningDataset( load_path=train_data_path, - tag_ids=TAG_IDS, + tag_to_id=tag_to_id, + metadata_category=metadata_category, + max_seq_len=max_seq_len, ) val_dataset = FinetuningDataset( load_path=val_data_path, - tag_ids=TAG_IDS, + tag_to_id=tag_to_id, + metadata_category=metadata_category, + max_seq_len=max_seq_len, + per_file=True, ) + if apply_aug: + print("Applying dataset augmentation") + train_dataset.set_transform(AbsTokenizer().export_data_aug()) + train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, + worker_init_fn=FinetuningDataset.export_worker_init_fn(), ) val_loader = DataLoader( val_dataset, - batch_size=batch_size, + batch_size=1, shuffle=False, num_workers=num_workers, + worker_init_fn=FinetuningDataset.export_worker_init_fn(), ) + return train_loader, val_loader @@ -380,6 +389,7 @@ def _train( train_dataloader: DataLoader, val_dataloader: DataLoader, optimizer: torch.optim.Optimizer, + tag_to_id: dict, scheduler: torch.optim.lr_scheduler.LRScheduler = None, project_dir: str | None = None, ): @@ -398,9 +408,7 @@ def make_checkpoint( ) _accelerator.save_state(checkpoint_dir) - def train_loop( - dataloader: DataLoader, _epoch: int, steps_per_checkpoint: int - ): + def train_loop(dataloader: DataLoader, _epoch: int): loss = torch.tensor([0.0]) avg_train_loss = 0 trailing_loss = 0 @@ -409,8 +417,6 @@ def train_loop( try: lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) except Exception: - pass - else: lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) model.train() @@ -431,7 +437,10 @@ def train_loop( with accelerator.accumulate(model): step = __step + 1 - seqs, labels, eos_pos = batch + seqs, eos_pos, labels = batch + seqs = seqs.squeeze(1) + eos_pos = eos_pos.squeeze(1) + logits = model(seqs) # (b_sz, s_len, class_size) logits = logits[ torch.arange(logits.shape[0], device=logits.device), eos_pos @@ -461,70 +470,113 @@ def train_loop( scheduler.step() lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) - if steps_per_checkpoint: - if step % steps_per_checkpoint == 0: - make_checkpoint( - _accelerator=accelerator, - _epoch=_epoch, - _step=step, - ) - return avg_train_loss - def val_loop(dataloader: DataLoader, _epoch: int): + def val_loop(dataloader: DataLoader, _epoch: int, tag_to_id: dict): model.eval() - val_loss_buffer = [] - total_correct = 0 - total_samples = 0 + pad_id = AbsTokenizer().pad_id + preds = [] + labels = [] - with torch.no_grad(): + with torch.inference_mode(): pbar = tqdm( dataloader, desc=f"Validation Epoch {_epoch}", leave=False ) for batch in pbar: - seqs, labels, eos_pos = batch - logits = model(seqs) # (b_sz, s_len, class_size) + seqs, pos, tag = batch + seqs = seqs.squeeze(0) # (n, max_seq_len) + pos = pos.squeeze(0) # (n,) + + logits = model(seqs) # (n, seq_len, class_size) logits = logits[ - torch.arange(logits.shape[0], device=logits.device), eos_pos + torch.arange(logits.shape[0], device=logits.device), pos ] - loss = loss_fn(logits, labels) - # Gather loss from all devices (if applicable) - val_loss_buffer.append( - accelerator.gather(loss).mean(dim=0).item() - ) + probs = torch.softmax(logits, dim=-1) # (n, class_size) - # Compute 
predictions and update accuracy stats - preds = torch.argmax(logits, dim=-1) - total_correct += (preds == labels).sum().item() - total_samples += labels.size(0) - current_accuracy = ( - total_correct / total_samples if total_samples > 0 else 0.0 + non_pad_counts = ( + (seqs != pad_id).sum(dim=1, keepdim=True).float() ) - current_avg_loss = sum(val_loss_buffer) / len(val_loss_buffer) + weighted_probs = probs * non_pad_counts + aggregated_probs = weighted_probs.sum(dim=0) + predicted_label = aggregated_probs.argmax().item() - pbar.set_postfix_str( - f"loss={round(current_avg_loss,4)}, acc={round(current_accuracy,4)}" + preds.append(predicted_label) + labels.append(tag.item()) + + tmp_acc = sum(p == t for p, t in zip(preds, labels)) / len( + preds ) + pbar.set_postfix_str(f"acc={round(tmp_acc, 4)}") + + accuracy = sum(p == t for p, t in zip(preds, labels)) / len(labels) + + # Compute per-class F1 scores + id_to_tag = {v: k for k, v in tag_to_id.items()} + # Initialize counts per class + metrics = {tag: {"TP": 0, "FP": 0, "FN": 0} for tag in tag_to_id.keys()} + for true_id, pred_id in zip(labels, preds): + true_tag = id_to_tag[true_id] + pred_tag = id_to_tag[pred_id] + if true_id == pred_id: + metrics[true_tag]["TP"] += 1 + else: + metrics[true_tag]["FN"] += 1 + metrics[pred_tag]["FP"] += 1 + + class_metrics = {} + f1_scores = [] + for tag, counts in metrics.items(): + TP = counts["TP"] + FP = counts["FP"] + FN = counts["FN"] + precision = TP / (TP + FP) if (TP + FP) > 0 else 0 + recall = TP / (TP + FN) if (TP + FN) > 0 else 0 + f1 = ( + (2 * precision * recall / (precision + recall)) + if (precision + recall) > 0 + else 0 + ) + class_metrics[tag] = { + "precision": precision, + "recall": recall, + "F1": f1, + } + f1_scores.append(f1) - avg_val_loss = sum(val_loss_buffer) / len(val_loss_buffer) - accuracy = total_correct / total_samples if total_samples > 0 else 0.0 + macro_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0 logger.info( - f"Validation Epoch {_epoch}: average_loss={round(avg_val_loss, 4)}, accuracy={round(accuracy, 4)}" + f"Validation Epoch {_epoch}: accuracy={round(accuracy, 4)}, macro-F1={round(macro_f1, 4)}" ) - return avg_val_loss, accuracy + logger.info(f"Class metrics: {class_metrics}") + + return accuracy, macro_f1, class_metrics logger = get_logger(__name__) loss_fn = nn.CrossEntropyLoss() - TRAILING_LOSS_STEPS = 100 + TRAILING_LOSS_STEPS = 20 - train_loop(dataloader=train_dataloader, _epoch=0, steps_per_checkpoint=2000) - make_checkpoint(_accelerator=accelerator, _epoch=1, _step=0) - val_loop(dataloader=val_dataloader, _epoch=0) + epoch_metrics = [] + for __epoch in range(num_epochs): + train_loop(dataloader=train_dataloader, _epoch=__epoch) + acc, macro_f1, class_metrics = val_loop( + dataloader=val_dataloader, _epoch=__epoch, tag_to_id=tag_to_id + ) + epoch_metrics.append( + { + "accuracy": acc, + "macro_f1": macro_f1, + "class_metrics": class_metrics, + } + ) + + return epoch_metrics def train( model_name: str, + metadata_category: str, + apply_aug: bool, train_data_path: str, val_data_path: str, num_workers: int, @@ -533,12 +585,15 @@ def train( grad_acc_steps: int, project_dir: str | None = None, checkpoint_path: str | None = None, + dataset_size: int | None = None, ): accelerator = accelerate.Accelerator( project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps, ) + tag_to_id = CATEGORY_TAGS[metadata_category] + if accelerator.is_main_process: project_dir = setup_project_dir(project_dir) logger = setup_logger(os.path.join(project_dir)) @@ -548,8 
+603,11 @@ def train(
         logger = get_logger(__name__)
 
     logger.info(f"Project directory: {project_dir}")
+    logger.info(f"Metadata category: {metadata_category}")
+    logger.info(f"Dataset size: {dataset_size}")
+    logger.info(f"Applying aug: {apply_aug}")
     logger.info(
-        f"Training config: epochs={num_epochs}, batch_size={batch_size}, num_workers={num_workers}"
+        f"Training config: epochs={num_epochs}, batch_size={batch_size}, num_workers={num_workers}"
     )
 
     tokenizer = AbsTokenizer()
@@ -557,19 +615,19 @@
     model_config.set_vocab_size(tokenizer.vocab_size)
     model = TransformerCL(model_config)
 
+    assert model_config.class_size == len(tag_to_id.keys())
+
     if checkpoint_path is not None:
         logger.info(f"Loading checkpoint from {checkpoint_path}")
         model_state = _load_weight(checkpoint_path)
         model_state = {
             k.replace("_orig_mod.", ""): v for k, v in model_state.items()
         }
-        if "lm_head.weight" in model_state.keys():
-            del model_state["lm_head.weight"]
+        model.load_state_dict(model_state, strict=False)
+        torch.nn.init.normal_(
+            model.model.tok_embeddings.weight.data[1:2], mean=0.0, std=0.02
+        )  # Re-init EOS tok
-        model_state = {
-            k.replace("model.", ""): v for k, v in model_state.items()
-        }
-        model.model.load_state_dict(model_state)
     else:
         logger.info("No checkpoint path provided")
 
@@ -578,9 +636,11 @@
     train_dataloader, val_dataloader = get_dataloaders(
         train_data_path=train_data_path,
         val_data_path=val_data_path,
+        metadata_category=metadata_category,
+        tag_to_id=tag_to_id,
         batch_size=batch_size,
         num_workers=num_workers,
-        apply_aug=True,
+        apply_aug=apply_aug,
     )
 
     optimizer, scheduler = get_optim(
@@ -603,50 +663,50 @@
         scheduler,
     )
 
-    _train(
+    epoch_metrics = _train(
         num_epochs=num_epochs,
         accelerator=accelerator,
         model=model,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         optimizer=optimizer,
+        tag_to_id=tag_to_id,
         scheduler=scheduler,
         project_dir=project_dir,
     )
 
-
-def test_build_dataset():
-    FinetuningDataset.build(
-        midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl",
-        save_path="/mnt/ssd1/aria/data/train.jsonl",
-        min_slice_notes=100,
-        max_slice_notes=165,
-        max_seq_len=512,
-        metadata_category=METADATA_CATEGORY,
-        tag_ids=TAG_IDS,
+    max_accuracy = (
+        max(metric["accuracy"] for metric in epoch_metrics)
+        if epoch_metrics
+        else 0.0
     )
-
-    # FinetuningDataset.build(
-    #     midi_dataset_load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl",
-    #     save_path="/mnt/ssd1/aria/data/val.jsonl",
-    #     min_slice_notes=100,
-    #     max_slice_notes=165,
-    #     max_seq_len=512,
-    #     metadata_category=METADATA_CATEGORY,
-    #     tag_ids=TAG_IDS,
-    # )
+    logger.info(f"Max accuracy: {max_accuracy}")
+    results = {
+        "metadata_category": metadata_category,
+        "dataset_size": dataset_size,
+        "epoch_metrics": epoch_metrics,
+        "max_accuracy": max_accuracy,
+    }
+    with open(os.path.join(project_dir, "results.json"), "w") as f:
+        json.dump(results, f, indent=4)
 
 
 def test_dataset():
+    tokenizer = AbsTokenizer()
     dataset = FinetuningDataset(
-        load_path="/mnt/ssd1/aria/data/test.jsonl",
-        tag_ids=TAG_IDS,
+        load_path="/mnt/ssd1/aria/data/class_eval/genre/classifier_finetune/test.jsonl",
+        metadata_category="genre",
+        tag_to_id=CATEGORY_TAGS["genre"],
+        max_seq_len=1024,
+        per_file=True,
    )
 
-    for idx, entry in enumerate(dataset):
-        print(idx)
-        # print(entry)
-        # input("")
+    for seq, pos, tag in dataset:
+        print(seq.shape)
+        print(pos.shape)
+        print(tag)
+
+        input("")
 
 
 def parse_args():
@@ -654,6 +714,9 @@
         description="Finetune a model for classification."
) parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--metadata_category", type=str, required=True) + parser.add_argument("--dataset_size", type=int, required=False) + parser.add_argument("--apply_aug", action="store_true") parser.add_argument("--checkpoint_path", type=str, default=None) parser.add_argument("--train_data_path", type=str, required=True) parser.add_argument("--val_data_path", type=str, required=True) @@ -666,18 +729,20 @@ def parse_args(): if __name__ == "__main__": - # args = parse_args() - # train( - # model_name=args.model_name, - # checkpoint_path=args.checkpoint_path, - # train_data_path=args.train_data_path, - # val_data_path=args.val_data_path, - # batch_size=args.batch_size, - # num_epochs=args.num_epochs, - # num_workers=args.num_workers, - # grad_acc_steps=args.grad_acc_steps, - # project_dir=args.project_dir, - # ) - - test_build_dataset() + args = parse_args() + train( + model_name=args.model_name, + metadata_category=args.metadata_category, + dataset_size=args.dataset_size, + apply_aug=args.apply_aug, + checkpoint_path=args.checkpoint_path, + train_data_path=args.train_data_path, + val_data_path=args.val_data_path, + batch_size=args.batch_size, + num_epochs=args.num_epochs, + num_workers=args.num_workers, + grad_acc_steps=args.grad_acc_steps, + project_dir=args.project_dir, + ) + # test_dataset() diff --git a/aria/embeddings/contrastive_finetune.py b/aria/training/contrastive_finetune.py similarity index 98% rename from aria/embeddings/contrastive_finetune.py rename to aria/training/contrastive_finetune.py index 9bf8e2a..2ac32ea 100644 --- a/aria/embeddings/contrastive_finetune.py +++ b/aria/training/contrastive_finetune.py @@ -24,7 +24,6 @@ def setup_logger(project_dir: str): - # Get logger and reset all handlers logger = logging.getLogger(__name__) for h in logger.handlers[:]: logger.removeHandler(h) @@ -299,8 +298,8 @@ def get_dataloaders( batch_size: int, num_workers: int, min_number_slice_notes: int = 100, - max_number_slice_notes: int = 300, - max_seq_len: int = 1024, + max_number_slice_notes: int = 650, + max_seq_len: int = 2048, ): train_dataset = ContrastiveDataset( load_path=train_data_path, @@ -560,13 +559,7 @@ def train( model_state = { k.replace("_orig_mod.", ""): v for k, v in model_state.items() } - if "lm_head.weight" in model_state.keys(): - del model_state["lm_head.weight"] - - model_state = { - k.replace("model.", ""): v for k, v in model_state.items() - } - model.model.load_state_dict(model_state) + model.load_state_dict(model_state, strict=False) else: logger.info("No checkpoint path provided") diff --git a/aria/training/train.py b/aria/training/train.py new file mode 100644 index 0000000..4b2a074 --- /dev/null +++ b/aria/training/train.py @@ -0,0 +1,903 @@ +import os +import sys +import csv +import argparse +import logging +import random +import torch +import accelerate + +from torch import nn as nn +from torch.utils.data import DataLoader + +from accelerate.logging import get_logger +from safetensors.torch import load_file +from logging.handlers import RotatingFileHandler +from tqdm import tqdm +from typing import List + +from aria.config import load_model_config +from aria.model import ModelConfig, TransformerLM, TransformerLM_CND +from ariautils.tokenizer import Tokenizer, AbsTokenizer, RelTokenizer +from aria.datasets import ( + TrainingDataset, + PretrainingDataset, +) +from aria.utils import _load_weight + +torch._dynamo.config.optimize_ddp = False + + +# ----- USAGE ----- +# +# This script is meant to 
be run using the huggingface accelerate cli, see: +# +# https://huggingface.co/docs/accelerate/basic_tutorials/launch +# https://huggingface.co/docs/accelerate/package_reference/cli +# +# For example usage you could run the pre-training script with: +# +# accelerate launch [arguments] aria/train.py train \ +# small \ +# -train_data data/train \ +# -val_data data/val \ +# -epochs 10 \ +# -bs 32 \ +# -workers 8 +# +# You could resume a run from an accelerate checkpoint with: +# +# accelerate launch [arguments] aria/train.py resume \ +# small \ +# -train_data data/train \ +# -val_data data/val \ +# -cp_dir models/epoch5_step0 \ +# -r_step 0 \ +# -r_epoch 5 \ +# -epochs 5 \ +# -bs 32 \ +# -workers 8 + + +def setup_logger(project_dir: str): + # Get logger and reset all handlers + logger = logging.getLogger(__name__) + for h in logger.handlers[:]: + logger.removeHandler(h) + + logger.propagate = False + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", + ) + + fh = RotatingFileHandler( + os.path.join(project_dir, "logs.txt"), backupCount=5, maxBytes=1024**3 + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return get_logger(__name__) # using accelerate.logging.get_logger() + + +def get_tokenizer_name( + train_data_paths: str, + val_data_path: str, +): + """This will throw an error if there is a tokenizer mismatch""" + train_config = TrainingDataset.get_config_from_path(train_data_paths[0]) + val_config = TrainingDataset.get_config_from_path(val_data_path) + + assert ( + train_config["tokenizer_name"] == val_config["tokenizer_name"] + ), "Dataset tokenizers don't match" + + return train_config["tokenizer_name"] + + +def setup_project_dir(project_dir: str | None): + if not project_dir: + # Create project directory + if not os.path.isdir("./experiments"): + os.mkdir("./experiments") + + project_dirs = [ + _dir + for _dir in os.listdir("./experiments") + if os.path.isdir(os.path.join("experiments", _dir)) + ] + + ind = 0 + while True: + if str(ind) not in project_dirs: + break + else: + ind += 1 + + project_dir_abs = os.path.abspath(os.path.join("experiments", str(ind))) + assert not os.path.isdir(project_dir_abs) + os.mkdir(project_dir_abs) + + elif project_dir: + # Run checks on project directory + if os.path.isdir(project_dir): + assert ( + len(os.listdir(project_dir)) == 0 + ), "Provided project directory is not empty" + project_dir_abs = os.path.abspath(project_dir) + elif os.path.isfile(project_dir): + raise FileExistsError( + "The provided path points toward an existing file" + ) + else: + try: + os.mkdir(project_dir) + except Exception as e: + raise e(f"Failed to create project directory at {project_dir}") + project_dir_abs = os.path.abspath(project_dir) + + os.mkdir(os.path.join(project_dir_abs, "checkpoints")) + + return project_dir_abs + + +def _get_optim( + lr: float, + model: nn.Module, + num_epochs: int, + steps_per_epoch: int, + warmup: int = 100, + end_ratio: int = 0.1, +): + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-5, + ) + + warmup_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.000001, + end_factor=1, + total_iters=warmup, + ) + linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=(num_epochs * 
steps_per_epoch) - warmup, + ) + + lr_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_lrs, linear_decay_lrs], + milestones=[warmup], + ) + + return optimizer, lr_scheduler + + +def get_optim( + model: nn.Module, + num_epochs: int, + steps_per_epoch: int, +): + LR = 3e-4 + END_RATIO = 0.1 + WARMUP_STEPS = 200 + + return _get_optim( + lr=LR, + model=model, + num_epochs=num_epochs, + steps_per_epoch=steps_per_epoch, + warmup=WARMUP_STEPS, + end_ratio=END_RATIO, + ) + + +def get_dataloaders( + train_data_dirs: List[str], + val_data_dir: str, + tokenizer: Tokenizer, + batch_size: int, + num_workers: int, + use_embeddings: bool, + init_epoch: int | None = None, + apply_aug: bool = True, +): + train_dataset = PretrainingDataset( + dir_paths=train_data_dirs, + tokenizer=tokenizer, + ) + val_dataset = PretrainingDataset( + dir_paths=val_data_dir, + tokenizer=tokenizer, + ) + + if init_epoch: + train_dataset.init_epoch(idx=init_epoch) + + assert ( + len(val_dataset.epoch_files_by_dir[0]) == 1 + ), "val-data directory should only contain one epoch" + + if apply_aug: + train_dataset.set_transform(tokenizer.export_data_aug()) + + train_dataloader = DataLoader( + train_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=True, + ) + val_dataloader = DataLoader( + val_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + ) + + if use_embeddings is True: + _src, _tgt, _mask, _emb = train_dataset[0] + _src, _tgt, _mask, __emb = val_dataset[0] + assert _emb.numel() != 0, "Embeddings not present in train dataset" + assert __emb.numel() != 0, "Embeddings not present in val dataset" + + return train_dataloader, val_dataloader + + +def _train( + epochs: int, + accelerator: accelerate.Accelerator, + model: TransformerLM, + train_dataloader: DataLoader, + val_dataloader: DataLoader, + use_embeddings: bool, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler = None, + steps_per_checkpoint: int | None = None, + resume_step: int | None = None, + resume_epoch: int | None = None, + project_dir: str | None = None, +): + def make_checkpoint( + _accelerator: accelerate.Accelerator, _epoch: int, _step: int + ): + if accelerator.is_main_process: + checkpoint_dir = os.path.join( + project_dir, + "checkpoints", + f"epoch{_epoch}_step{_step}", + ) + + logger.info( + f"EPOCH {_epoch}/{epochs + start_epoch}: Saving checkpoint - {checkpoint_dir}" + ) + _accelerator.save_state(checkpoint_dir) + + # This is all slightly messy as train_loop and val_loop make use of the + # variables in the wider scope. Perhaps refactor this at some point. 
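
The make_checkpoint helper above writes accelerate state into <project_dir>/checkpoints/epoch<E>_step<S>, which is the directory resume_train later points -cp_dir at. A minimal sketch of that round trip, assuming a single-process run and an illustrative path:

    import os
    import accelerate

    accelerator = accelerate.Accelerator()
    # Layout produced by make_checkpoint: <project_dir>/checkpoints/epoch<E>_step<S>
    checkpoint_dir = os.path.join("experiments", "0", "checkpoints", "epoch5_step0")
    accelerator.save_state(checkpoint_dir)  # model, optimizer, scheduler, RNG state
    # ...and restoring it later, as resume_train does:
    accelerator.load_state(checkpoint_dir)
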
+    def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0):
+        loss = torch.tensor([0.0])
+        avg_train_loss = 0
+        trailing_loss = 0
+        loss_buffer = []
+
+        try:
+            lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0])
+        except Exception:
+            lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])
+
+        model.train()
+        for __step, batch in (
+            pbar := tqdm(
+                enumerate(dataloader),
+                total=len(dataloader) + _resume_step,
+                initial=_resume_step,
+                leave=False,
+            )
+        ):
+            pbar.set_postfix_str(
+                f"lr={lr_for_print}, "
+                f"loss={round(loss.item(), 4)}, "
+                f"trailing={round(trailing_loss, 4)}"
+            )
+
+            with accelerator.accumulate(model):
+                step = __step + _resume_step + 1
+                src, tgt, mask, emb = (
+                    batch  # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb)
+                )
+
+                use_embeddings_cond = use_embeddings and (random.random() > 0.5)
+
+                if use_embeddings_cond is True:
+                    logits = model(src=src, emb=emb)  # (b_sz, s_len - 1, v_sz)
+                    tgt = tgt[:, :-1]  # (b_sz, s_len - 1)
+                    mask = mask[:, :-1]  # (b_sz, s_len - 1)
+                else:
+                    logits = model(src)  # (b_sz, s_len, v_sz)
+
+                logits = logits.transpose(
+                    1, 2
+                )  # Transpose for CrossEntropyLoss
+                loss = loss_fn(logits, tgt)
+
+                if mask.sum() == 0:
+                    loss = (loss * 0).sum()
+                else:
+                    loss = loss * mask
+                    loss = loss[loss != 0.0].mean()
+
+                # Calculate statistics
+                loss_buffer.append(accelerator.gather(loss).mean(dim=0).item())
+                trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len(
+                    loss_buffer[-TRAILING_LOSS_STEPS:]
+                )
+                avg_train_loss = sum(loss_buffer) / len(loss_buffer)
+
+                # Logging
+                logger.debug(
+                    f"EPOCH {_epoch} STEP {step}: "
+                    f"lr={lr_for_print}, "
+                    f"loss={round(loss.item(), 4)}, "
+                    f"trailing_loss={round(trailing_loss, 4)}, "
+                    f"average_loss={round(avg_train_loss, 4)}"
+                )
+
+                if accelerator.is_main_process:
+                    loss_writer.writerow([_epoch, step, loss.item()])
+
+                accelerator.backward(loss)
+                optimizer.step()
+                optimizer.zero_grad()
+                if scheduler:
+                    scheduler.step()
+                    lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0])
+
+                if steps_per_checkpoint:
+                    if step % steps_per_checkpoint == 0:
+                        make_checkpoint(
+                            _accelerator=accelerator,
+                            _epoch=_epoch,
+                            _step=step,
+                        )
+
+        logger.info(
+            f"EPOCH {_epoch}/{epochs + start_epoch}: Finished training - "
+            f"average_loss={round(avg_train_loss, 4)}"
+        )
+
+        return avg_train_loss
+
+    @torch.no_grad()
+    def val_loop(dataloader, _epoch: int):
+        loss_buffer = []
+        model.eval()
+        for step, batch in (
+            pbar := tqdm(
+                enumerate(dataloader),
+                total=len(dataloader),
+                leave=False,
+            )
+        ):
+            src, tgt, mask, emb = (
+                batch  # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb)
+            )
+            use_embeddings_cond = use_embeddings and (random.random() > 0.5)
+
+            if use_embeddings_cond is True:
+                logits = model(src=src, emb=emb)  # (b_sz, s_len - 1, v_sz)
+                tgt = tgt[:, :-1]  # (b_sz, s_len - 1)
+                mask = mask[:, :-1]  # (b_sz, s_len - 1)
+            else:
+                logits = model(src)  # (b_sz, s_len, v_sz)
+
+            logits = logits.transpose(1, 2)  # Transpose for CrossEntropyLoss
+            loss = loss_fn(logits, tgt)
+
+            if mask.sum() == 0:
+                loss = (loss * 0).sum()
+            else:
+                loss = loss * mask
+                loss = loss[loss != 0.0].mean()
+
+            # Logging
+            loss_buffer.append(accelerator.gather(loss).mean(dim=0).item())
+            avg_val_loss = sum(loss_buffer) / len(loss_buffer)
+            pbar.set_postfix_str(f"average_loss={round(avg_val_loss, 4)}")
+
+        # EPOCH
+        logger.info(
+            f"EPOCH {_epoch}/{epochs + start_epoch}: Finished evaluation - "
+            f"average_loss={round(avg_val_loss, 4)}"
+        )
+
+        return avg_val_loss
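
Both loops above condition on the embedding for only about half of the batches (use_embeddings_cond), so a single model learns to predict tokens both with and without an embedding prefix. A minimal sketch of that pattern in isolation (the helper name is illustrative, not part of the codebase):

    import random

    def forward_with_optional_cond(model, src, tgt, mask, emb, use_embeddings: bool):
        # Condition on the embedding for roughly 50% of batches
        if use_embeddings and random.random() > 0.5:
            logits = model(src=src, emb=emb)
            # Per the shape comments above, the conditioned forward yields one
            # fewer logit position, so targets and mask are trimmed to match
            tgt, mask = tgt[:, :-1], mask[:, :-1]
        else:
            logits = model(src)
        return logits, tgt, mask
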
+ if steps_per_checkpoint: + assert ( + steps_per_checkpoint > 1 + ), "Invalid checkpoint mode value (too small)" + + TRAILING_LOSS_STEPS = 200 + PAD_ID = train_dataloader.dataset.tokenizer.pad_id + logger = get_logger(__name__) # Accelerate logger + loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID, reduction="none") + + logger.info( + f"Model has " + f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " + "parameters" + ) + + if accelerator.is_main_process: + loss_csv = open(os.path.join(project_dir, "loss.csv"), "w") + loss_writer = csv.writer(loss_csv) + loss_writer.writerow(["epoch", "step", "loss"]) + epoch_csv = open(os.path.join(project_dir, "epoch.csv"), "w") + epoch_writer = csv.writer(epoch_csv) + epoch_writer.writerow(["epoch", "avg_train_loss", "avg_val_loss"]) + + if resume_epoch is not None: + start_epoch = resume_epoch + 1 + else: + start_epoch = 0 + + if resume_step is not None: + assert resume_epoch is not None, "Must provide resume epoch" + logger.info( + f"Resuming training from step {resume_step} - logging as EPOCH {resume_epoch}" + ) + skipped_dataloader = accelerator.skip_first_batches( + dataloader=train_dataloader, + num_batches=resume_step, + ) + + avg_train_loss = train_loop( + dataloader=skipped_dataloader, + _epoch=resume_epoch, + _resume_step=resume_step, + ) + avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=resume_epoch) + if accelerator.is_main_process: + epoch_writer.writerow([resume_epoch, avg_train_loss, avg_val_loss]) + epoch_csv.flush() + make_checkpoint( + _accelerator=accelerator, _epoch=start_epoch, _step=0 + ) + + for epoch in range(start_epoch, epochs + start_epoch): + train_dataloader.dataset.init_epoch(epoch) + avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch) + avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=epoch) + if accelerator.is_main_process: + epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss]) + epoch_csv.flush() + make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0) + + logging.shutdown() + if accelerator.is_main_process: + loss_csv.close() + epoch_csv.close() + + +# TODO: Add use_embeddings logic to this code path +def resume_train( + model_name: str, + train_data_paths: str, + val_data_path: str, + use_embeddings: bool, + num_workers: int, + batch_size: int, + grad_acc_steps: int, + epochs: int, + checkpoint_dir: str, + resume_epoch: int, + resume_step: int, + steps_per_checkpoint: int | None = None, + project_dir: str = None, +): + # Validate inputs + assert 0 < num_workers <= 128, "Too many workers" + assert epochs > 0, "Invalid number of epochs" + assert batch_size > 0, "Invalid batch size" + assert torch.cuda.is_available() is True, "CUDA not available" + assert os.path.isdir(checkpoint_dir), f"No dir at {checkpoint_dir}" + for train_data_path in train_data_paths: + assert os.path.isdir( + train_data_path + ), f"No dir found at {train_data_path}" + assert os.path.isdir(val_data_path), f"No dir found at {val_data_path}" + + tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) + if tokenizer_name == "abs": + tokenizer = AbsTokenizer() + elif tokenizer_name == "rel": + tokenizer = RelTokenizer() + else: + raise Exception("Invalid tokenizer name") + + accelerator = accelerate.Accelerator( + project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps + ) + if accelerator.is_main_process: + project_dir = setup_project_dir(project_dir) + logger = setup_logger(project_dir) + + logger = get_logger(__name__) + 
logger.info(f"Using project directory {project_dir} ") + logger.warning( + "Please insure that the training config and resume step are set " + "correctly, the script does not currently check that this is the case. " + "If the previous checkpoint was saved at step n, then resume_step " + "should be n. If there is a mismatch between the batch size then the " + "script will resume at the wrong step. It is also important that the " + "same distributed setup is used for training." + ) + logger.info( + f"Using training config: " + f"model_name={model_name}, " + f"use_embeddings={use_embeddings}, " + f"epochs={epochs}, " + f"batch_size={batch_size}, " + f"grad_acc_steps={grad_acc_steps}, " + f"num_workers={num_workers}, " + f"checkpoint_dir={checkpoint_dir}, " + f"resume_step={resume_step}, " + f"resume_epoch={resume_epoch}" + ) + + if steps_per_checkpoint: + logger.info(f"Creating checkpoints every {steps_per_checkpoint}") + + # Init model + model_config = ModelConfig(**load_model_config(model_name)) + model_config.set_vocab_size(tokenizer.vocab_size) + + if use_embeddings: + model = TransformerLM_CND(model_config) + else: + model = TransformerLM(model_config) + + model.compile() + + train_dataloader, val_dataloader = get_dataloaders( + train_data_dirs=train_data_paths, + val_data_dir=val_data_path, + tokenizer=tokenizer, + init_epoch=resume_epoch, + batch_size=batch_size, + num_workers=num_workers, + apply_aug=True, + use_embeddings=use_embeddings, + ) + optimizer, scheduler = get_optim( + model, + num_epochs=epochs, + steps_per_epoch=len(train_dataloader), + ) + + ( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) = accelerator.prepare( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) + + try: + accelerator.load_state(checkpoint_dir) + except Exception as e: + raise Exception( + f"Failed to load checkpoint: {e}\n" + "This could be due to a mismatch between the tokenizer used " + "to build the pre-training and fine-tuning datasets" + ) + logger.info(f"Loaded checkpoint at {checkpoint_dir}") + logger.info("Starting train job") + + _train( + epochs=epochs, + accelerator=accelerator, + model=model, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + use_embeddings=use_embeddings, + optimizer=optimizer, + scheduler=scheduler, + steps_per_checkpoint=steps_per_checkpoint, + resume_step=resume_step, + resume_epoch=resume_epoch, + project_dir=project_dir, + ) + + +def train( + model_name: str, + train_data_paths: List[str], + val_data_path: str, + use_embeddings: bool, + num_workers: int, + batch_size: int, + grad_acc_steps: int, + epochs: int, + checkpoint_path: str | None = None, + steps_per_checkpoint: int | None = None, + project_dir: str = None, +): + # Validate inputs + assert 0 < num_workers <= 128, "Too many workers" + assert epochs > 0, "Invalid number of epochs" + assert batch_size > 0, "Invalid batch size" + assert torch.cuda.is_available() is True, "CUDA not available" + for train_data_path in train_data_paths: + assert os.path.isdir( + train_data_path + ), f"No dir found at {train_data_path}" + assert os.path.isdir(val_data_path), f"No dir found at {val_data_path}" + + tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) + if tokenizer_name == "abs": + tokenizer = AbsTokenizer() + elif tokenizer_name == "rel": + tokenizer = RelTokenizer() + else: + raise Exception("Invalid tokenizer name") + + accelerator = accelerate.Accelerator( + project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps 
+    )
+    if accelerator.is_main_process:
+        project_dir = setup_project_dir(project_dir)
+        logger = setup_logger(project_dir)
+
+    logger = get_logger(__name__)
+    logger.info(f"Using project directory {project_dir}")
+    logger.info(
+        f"Using training config: "
+        f"model_name={model_name}, "
+        f"use_embeddings={use_embeddings}, "
+        + (f"checkpoint_path={checkpoint_path}, " if checkpoint_path else "")
+        + f"epochs={epochs}, "
+        f"batch_size={batch_size}, "
+        f"grad_acc_steps={grad_acc_steps}, "
+        f"num_workers={num_workers}"
+    )
+
+    if steps_per_checkpoint:
+        logger.info(f"Creating checkpoints every {steps_per_checkpoint}")
+
+    # Init model
+    model_config = ModelConfig(**load_model_config(model_name))
+    model_config.set_vocab_size(tokenizer.vocab_size)
+
+    if use_embeddings is True:
+        model = TransformerLM_CND(model_config)
+    else:
+        model = TransformerLM(model_config)
+
+    model.compile()
+    logger.info(f"Loaded model with config: {load_model_config(model_name)}")
+    if checkpoint_path:
+        try:
+            model.load_state_dict(_load_weight(checkpoint_path))
+        except RuntimeError as e:
+            print(e)
+            logger.info(
+                f"Failed to load {checkpoint_path} into {model_name}, attempting with strict=False"
+            )
+            model.load_state_dict(_load_weight(checkpoint_path), strict=False)
+
+        logger.info(f"Loaded finetune checkpoint located at: {checkpoint_path}")
+
+    train_dataloader, val_dataloader = get_dataloaders(
+        train_data_dirs=train_data_paths,
+        val_data_dir=val_data_path,
+        tokenizer=tokenizer,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        apply_aug=True,
+        use_embeddings=use_embeddings,
+    )
+
+    assert (
+        train_dataloader.dataset.config["max_seq_len"]
+        == model_config.max_seq_len
+    )
+    assert (
+        val_dataloader.dataset.config["max_seq_len"] == model_config.max_seq_len
+    )
+
+    optimizer, scheduler = get_optim(
+        model,
+        num_epochs=epochs,
+        steps_per_epoch=len(train_dataloader),
+    )
+
+    (
+        model,
+        train_dataloader,
+        val_dataloader,
+        optimizer,
+        scheduler,
+    ) = accelerator.prepare(
+        model,
+        train_dataloader,
+        val_dataloader,
+        optimizer,
+        scheduler,
+    )
+
+    logger.info(f"Starting {'finetune' if checkpoint_path else 'pretrain'} job")
+    _train(
+        epochs=epochs,
+        accelerator=accelerator,
+        model=model,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        use_embeddings=use_embeddings,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        steps_per_checkpoint=steps_per_checkpoint,
+        project_dir=project_dir,
+    )
+
+
+def convert_cp_from_safetensors(checkpoint_path: str, save_path: str):
+    d = load_file(checkpoint_path)
+    key = list(d.keys())[0]
+    gap = len(key.split(".")[0])
+    d = {s[gap + 1 :]: v for s, v in d.items()}
+    torch.save(d, save_path)
+
+
+def convert_cp_from_accelerate(
+    model_name: str, tokenizer_name: str, checkpoint_dir: str, save_path: str
+):
+    def _load_state_dict(_tokenizer: Tokenizer):
+        model_config = ModelConfig(**load_model_config(model_name))
+        model_config.set_vocab_size(_tokenizer.vocab_size)
+        model = TransformerLM(model_config)
+        model = accelerator.prepare(model)
+        accelerator.load_state(checkpoint_dir)
+
+        return model.state_dict()
+
+    accelerator = accelerate.Accelerator()
+
+    # Try both
+    if tokenizer_name == "abs":
+        state_dict = _load_state_dict(_tokenizer=AbsTokenizer())
+    elif tokenizer_name == "rel":
+        state_dict = _load_state_dict(_tokenizer=RelTokenizer())
+    else:
+        print("Invalid choice of tokenizer")
+
+    torch.save(state_dict, save_path)
+
+
+def parse_resume_args():
+    argp = argparse.ArgumentParser(prog="python aria/train.py resume")
+
argp.add_argument("model", help="name of model config file") + argp.add_argument("-train_data", nargs="+", help="path to train dir") + argp.add_argument("-val_data", help="path to val dir") + argp.add_argument("-cp_dir", help="checkpoint dir", type=str, required=True) + argp.add_argument( + "-use_embeddings", help="prepend embeddings", action="store_true" + ) + argp.add_argument("-r_step", help="resume step", type=int, required=True) + argp.add_argument("-r_epoch", help="resume epoch", type=int, required=True) + argp.add_argument("-epochs", help="train epochs", type=int, required=True) + argp.add_argument("-bs", help="batch size", type=int, default=32) + argp.add_argument( + "-grad_acc_steps", + help="gradient accumulation steps", + type=int, + default=1, + ) + argp.add_argument("-workers", help="number workers", type=int, default=1) + argp.add_argument("-pdir", help="project dir", type=str, required=False) + argp.add_argument( + "-spc", help="steps per checkpoint", type=int, required=False + ) + + return argp.parse_args(sys.argv[2:]) + + +def parse_train_args(): + argp = argparse.ArgumentParser(prog="python aria/train.py train") + argp.add_argument("model", help="name of model config file") + argp.add_argument("-train_data", nargs="+", help="path to train dir") + argp.add_argument("-val_data", help="path to val dir") + argp.add_argument( + "-cp_path", help="path to checkpoint", required=False, default=None + ) + argp.add_argument( + "-use_embeddings", help="prepend embeddings", action="store_true" + ) + argp.add_argument("-epochs", help="train epochs", type=int, required=True) + argp.add_argument("-bs", help="batch size", type=int, default=32) + argp.add_argument( + "-grad_acc_steps", + help="gradient accumulation steps", + type=int, + default=1, + ) + argp.add_argument("-workers", help="number workers", type=int, default=1) + argp.add_argument("-pdir", help="project dir", type=str, required=False) + argp.add_argument( + "-spc", help="steps per checkpoint", type=int, required=False + ) + + return argp.parse_args(sys.argv[2:]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + usage="python aria/train.py []" + ) + parser.add_argument( + "mode", help="training function", choices=("train", "resume") + ) + + args = parser.parse_args(sys.argv[1:2]) + if not hasattr(args, "mode"): + parser.print_help() + print("Unrecognized command") + exit(1) + elif args.mode == "train": + train_args = parse_train_args() + train( + model_name=train_args.model, + train_data_paths=train_args.train_data, + use_embeddings=train_args.use_embeddings, + val_data_path=train_args.val_data, + num_workers=train_args.workers, + batch_size=train_args.bs, + grad_acc_steps=train_args.grad_acc_steps, + epochs=train_args.epochs, + checkpoint_path=train_args.cp_path, + steps_per_checkpoint=train_args.spc, + project_dir=train_args.pdir, + ) + elif args.mode == "resume": + resume_args = parse_resume_args() + resume_train( + model_name=resume_args.model, + train_data_paths=resume_args.train_data, + val_data_path=resume_args.val_data, + use_embeddings=resume_args.use_embeddings, + num_workers=resume_args.workers, + batch_size=resume_args.bs, + grad_acc_steps=resume_args.grad_acc_steps, + epochs=resume_args.epochs, + checkpoint_dir=resume_args.cp_dir, + resume_step=resume_args.r_step, + resume_epoch=resume_args.r_epoch, + steps_per_checkpoint=resume_args.spc, + project_dir=resume_args.pdir, + ) + else: + print("Unrecognized command") + parser.print_help() + exit(1) diff --git a/config/accelerate.yaml 
b/config/accelerate.yaml deleted file mode 100644 index 066be39..0000000 --- a/config/accelerate.yaml +++ /dev/null @@ -1,16 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: 'NO' -downcast_bf16: 'no' -gpu_ids: all -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: 1 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/config/models/large.json b/config/models/large.json deleted file mode 100644 index 44014f3..0000000 --- a/config/models/large.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "d_model": 2048, - "n_heads": 32, - "n_layers": 16, - "ff_mult": 4, - "drop_p": 0.0, - "max_seq_len": 8192, - "grad_checkpoint": true -} \ No newline at end of file diff --git a/config/models/medium-composer.json b/config/models/medium-composer.json index ece2209..e3245c4 100644 --- a/config/models/medium-composer.json +++ b/config/models/medium-composer.json @@ -6,25 +6,6 @@ "drop_p": 0.0, "max_seq_len": 8192, "grad_checkpoint": true, - "class_size": 18, - "tag_to_id": { - "chopin": 0, - "bach": 1, - "beethoven": 2, - "liszt": 3, - "mozart": 4, - "debussy": 5, - "schumann": 6, - "schubert": 7, - "rachmaninoff": 8, - "brahms": 9, - "tchaikovsky": 10, - "haydn": 11, - "scriabin": 12, - "mendelssohn": 13, - "czerny": 14, - "ravel": 15, - "scarlatti": 16, - "other": 17 - } -} \ No newline at end of file + "class_size": 10, + "resid_dropout": 0.2 +} diff --git a/config/models/medium-emotion.json b/config/models/medium-emotion.json new file mode 100644 index 0000000..27a896b --- /dev/null +++ b/config/models/medium-emotion.json @@ -0,0 +1,11 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 4, + "resid_dropout": 0.2 +} diff --git a/config/models/medium-form.json b/config/models/medium-form.json new file mode 100644 index 0000000..2d9a656 --- /dev/null +++ b/config/models/medium-form.json @@ -0,0 +1,11 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 6, + "resid_dropout": 0.2 +} diff --git a/config/models/medium-genre.json b/config/models/medium-genre.json index 97a4b89..31d2bdc 100644 --- a/config/models/medium-genre.json +++ b/config/models/medium-genre.json @@ -6,10 +6,6 @@ "drop_p": 0.0, "max_seq_len": 8192, "grad_checkpoint": true, - "class_size": 3, - "tag_to_id": { - "classical": 0, - "jazz": 1, - "other": 2 - } + "class_size": 2, + "resid_dropout": 0.2 } diff --git a/config/models/medium-music_period.json b/config/models/medium-music_period.json new file mode 100644 index 0000000..27a896b --- /dev/null +++ b/config/models/medium-music_period.json @@ -0,0 +1,11 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 4, + "resid_dropout": 0.2 +} diff --git a/config/models/medium-pianist.json b/config/models/medium-pianist.json new file mode 100644 index 0000000..73b179b --- /dev/null +++ b/config/models/medium-pianist.json @@ -0,0 +1,11 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 8, + "resid_dropout": 0.2 +} diff --git a/paper/scripts/build_aria_ft_emb_dataset.sh b/paper/scripts/build_aria_ft_emb_dataset.sh deleted file mode 100644 
index 7444685..0000000 --- a/paper/scripts/build_aria_ft_emb_dataset.sh +++ /dev/null @@ -1,16 +0,0 @@ -python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ - --model aria \ - --model_cp_path /home/loubb/work/aria/models/emb-t0.1-s2048-e25.safetensors \ - --dataset_load_path /mnt/ssd1/aria/data/mididict-ft_val.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/finetune/ft-val_emb.jsonl \ - --compute_per_file_embeddings \ - --aria_max_batch_size 128 - -aria pretrain-dataset \ - -tokenizer_name abs \ - -load_path /mnt/ssd1/aria/data/mididict-ft_val.jsonl \ - -embedding_dataset_path /mnt/ssd1/aria/data/finetune/ft-val_emb.jsonl \ - -save_dir /mnt/ssd1/aria/data/finetune/val \ - -l 8192 \ - -e 1 \ - -sep_sequences \ No newline at end of file diff --git a/paper/scripts/build_dataset/build_aria_dataset.sh b/paper/scripts/build_dataset/build_aria_dataset.sh deleted file mode 100644 index a46a273..0000000 --- a/paper/scripts/build_dataset/build_aria_dataset.sh +++ /dev/null @@ -1,15 +0,0 @@ -python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ - --model aria \ - --model_cp_path /home/loubb/work/aria/models/emb-t0.1-s2048-e25.safetensors \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/train-aria.jsonl \ - --compute_per_file_embeddings \ - --aria_max_batch_size 128 - -python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ - --model aria \ - --model_cp_path /home/loubb/work/aria/models/emb-t0.1-s2048-e25.safetensors \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/test-aria.jsonl \ - --compute_per_file_embeddings \ - --aria_max_batch_size 128 diff --git a/paper/scripts/build_dataset/build_clamp_dataset.sh b/paper/scripts/build_dataset/build_clamp_dataset.sh deleted file mode 100644 index 8c06845..0000000 --- a/paper/scripts/build_dataset/build_clamp_dataset.sh +++ /dev/null @@ -1,13 +0,0 @@ -python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ - --model m3 \ - --model_cp_path /home/loubb/work/clamp3/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/train-clamp.jsonl \ - --compute_per_file_embeddings - -python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ - --model m3 \ - --model_cp_path /home/loubb/work/clamp3/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/test-clamp.jsonl \ - --compute_per_file_embeddings diff --git a/paper/scripts/build_dataset/build_m3_dataset.sh b/paper/scripts/build_dataset/build_m3_dataset.sh deleted file mode 100644 index ee65960..0000000 --- a/paper/scripts/build_dataset/build_m3_dataset.sh +++ /dev/null @@ -1,15 +0,0 @@ -python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ - --model m3 \ - --model_cp_path 
/home/loubb/work/clamp3/weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/train-m3.jsonl \ - --compute_per_file_embeddings \ - --m3_is_encoder_checkpoint - -python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ - --model m3 \ - --model_cp_path /home/loubb/work/clamp3/weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/test-m3.jsonl \ - --compute_per_file_embeddings \ - --m3_is_encoder_checkpoint diff --git a/paper/scripts/build_dataset/build_mert_dataset.sh b/paper/scripts/build_dataset/build_mert_dataset.sh deleted file mode 100644 index 45f1215..0000000 --- a/paper/scripts/build_dataset/build_mert_dataset.sh +++ /dev/null @@ -1,15 +0,0 @@ -python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ - --model mert \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/train-mert.jsonl \ - --mert_pianoteq_exec_path "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE" \ - --mert_pianoteq_num_procs 16 \ - --compute_per_file_embeddings - -python /home/loubb/work/aria/paper/scripts/build_embedding_eval_datasets.py \ - --model mert \ - --dataset_load_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mididict.jsonl \ - --dataset_save_path /mnt/ssd1/aria/data/paper/clas/pianist/test-mert.jsonl \ - --mert_pianoteq_exec_path "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE" \ - --mert_pianoteq_num_procs 16 \ - --compute_per_file_embeddings \ No newline at end of file diff --git a/paper/scripts/build_embedding_eval_datasets.py b/paper/scripts/build_embedding_eval_datasets.py deleted file mode 100644 index 3f2dad6..0000000 --- a/paper/scripts/build_embedding_eval_datasets.py +++ /dev/null @@ -1,242 +0,0 @@ -import os -import argparse -import torch -import torch.nn as nn - -from ariautils.tokenizer import AbsTokenizer -from aria.embeddings.evaluate import ( - EvaluationDataset, - get_aria_contrastive_embedding, - get_clamp3_embedding, - get_mert_embedding, -) - -MAX_SEQ_LEN = 1024 -NUM_SLICE_NOTES = 300 -SEQS_BATCH_SIZE = 128 - - -def aria_model_forward( - model: nn.Module, - idxs: torch.Tensor, -): - return model(idxs) - - -def build_aria_dataset( - midi_dataset_load_path: str, - embedding_dataset_save_path: str, - checkpoint_path: str, - per_file_embeddings: bool, - max_batch_size: int, - compile: bool, -): - from aria.config import load_model_config - from aria.utils import _load_weight - from aria.model import ModelConfig, TransformerEMB - - assert os.path.isfile(midi_dataset_load_path) - assert os.path.isfile(checkpoint_path) - assert not os.path.isfile(embedding_dataset_save_path) - - tokenizer = AbsTokenizer() - model_state = _load_weight(checkpoint_path, "cuda") - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() - } - pretrained_model_config = ModelConfig(**load_model_config("medium-emb")) - pretrained_model_config.set_vocab_size(tokenizer.vocab_size) - pretrained_model_config.grad_checkpoint = False - pretrained_model = TransformerEMB(pretrained_model_config) - pretrained_model.load_state_dict(model_state) - pretrained_model.eval() - - if compile 
is True: - hook_model_forward = torch.compile( - aria_model_forward, - mode="reduce-overhead", - fullgraph=True, - ) - else: - hook_model_forward = aria_model_forward - - EvaluationDataset.build( - midi_dataset_load_path=midi_dataset_load_path, - save_path=embedding_dataset_save_path, - max_seq_len=MAX_SEQ_LEN, - slice_len_notes=NUM_SLICE_NOTES, - batch_size=SEQS_BATCH_SIZE, - per_file_embeddings=per_file_embeddings, - embedding_hook=get_aria_contrastive_embedding, - hook_model=pretrained_model.cuda(), - hook_max_seq_len=MAX_SEQ_LEN, - hook_tokenizer=tokenizer, - hook_model_forward=hook_model_forward, - hook_max_batch_size=max_batch_size, - ) - - -def build_m3_dataset( - midi_dataset_load_path: str, - embedding_dataset_save_path: str, - checkpoint_path: str, - is_encoder_checkpoint: bool, - per_file_embeddings: bool, -): - from aria.embeddings.m3.emb import load_clamp3_model - - assert os.path.isfile(midi_dataset_load_path) - assert os.path.isfile(checkpoint_path) - assert not os.path.isfile(embedding_dataset_save_path) - - tokenizer = AbsTokenizer() - model, patchilizer = load_clamp3_model( - checkpoint_path=checkpoint_path, m3_only=is_encoder_checkpoint - ) - - # Workaround to outsource global_emb calculation to model - slice_len_notes = NUM_SLICE_NOTES if per_file_embeddings is False else 10000 - max_seq_len = MAX_SEQ_LEN if per_file_embeddings is False else 100000 - - EvaluationDataset.build( - midi_dataset_load_path=midi_dataset_load_path, - save_path=embedding_dataset_save_path, - max_seq_len=max_seq_len, - slice_len_notes=slice_len_notes, - batch_size=SEQS_BATCH_SIZE, - per_file_embeddings=per_file_embeddings, - embedding_hook=get_clamp3_embedding, - hook_model=model, - hook_patchilizer=patchilizer, - hook_tokenizer=tokenizer, - ) - - -def build_mert_dataset( - midi_dataset_load_path: str, - embedding_dataset_save_path: str, - per_file_embeddings: bool, - pianoteq_exec_path: str, - pianoteq_num_procs: int, -): - from aria.embeddings.mert.emb import load_mert_model - - assert pianoteq_num_procs > 0 - assert os.path.isfile(midi_dataset_load_path) - assert not os.path.isfile(embedding_dataset_save_path) - - tokenizer = AbsTokenizer() - model, processor = load_mert_model() - - EvaluationDataset.build( - midi_dataset_load_path=midi_dataset_load_path, - save_path=embedding_dataset_save_path, - max_seq_len=MAX_SEQ_LEN, - slice_len_notes=NUM_SLICE_NOTES, - batch_size=SEQS_BATCH_SIZE, - per_file_embeddings=per_file_embeddings, - embedding_hook=get_mert_embedding, - hook_model=model, - hook_processor=processor, - hook_tokenizer=tokenizer, - hook_pianoteq_exec_path=pianoteq_exec_path, - hook_pianoteq_num_procs=pianoteq_num_procs, - ) - - -def main(): - parser = argparse.ArgumentParser( - description="Process model and dataset paths." 
- ) - parser.add_argument( - "--model", - type=str, - choices=["aria", "mert", "m3"], - required=True, - ) - parser.add_argument( - "--model_cp_path", - type=str, - required=False, - help="Path from which to load the model.", - ) - parser.add_argument( - "--dataset_load_path", - type=str, - required=True, - help="Path from which to load the dataset.", - ) - parser.add_argument( - "--dataset_save_path", - type=str, - required=True, - help="Path where the dataset will be saved.", - ) - parser.add_argument( - "--compute_per_file_embeddings", - action="store_true", - help="Compute embeddings on a per-file basis", - ) - parser.add_argument( - "--aria_max_batch_size", - type=int, - default=128, - help="Max batch size for aria embedding forward pass", - ) - parser.add_argument( - "--aria_compile", - action="store_true", - help="Compile forward pass", - ) - parser.add_argument( - "--m3_is_encoder_checkpoint", - action="store_true", - help="Checkpoint is for entire clamp model.", - ) - parser.add_argument( - "--mert_pianoteq_exec_path", - type=str, - required=False, - help="Path to pianoteq executable", - ) - parser.add_argument( - "--mert_pianoteq_num_procs", - type=int, - default=16, - help="Num of procs to use for audio synthesis", - ) - - args = parser.parse_args() - - if args.model == "aria": - assert args.aria_max_batch_size > 0 - build_aria_dataset( - midi_dataset_load_path=args.dataset_load_path, - embedding_dataset_save_path=args.dataset_save_path, - checkpoint_path=args.model_cp_path, - per_file_embeddings=args.compute_per_file_embeddings, - max_batch_size=args.aria_max_batch_size, - compile=args.aria_compile, - ) - elif args.model == "m3": - build_m3_dataset( - midi_dataset_load_path=args.dataset_load_path, - embedding_dataset_save_path=args.dataset_save_path, - checkpoint_path=args.model_cp_path, - is_encoder_checkpoint=args.m3_is_encoder_checkpoint, - per_file_embeddings=args.compute_per_file_embeddings, - ) - elif args.model == "mert": - assert args.mert_pianoteq_exec_path - assert args.mert_pianoteq_num_procs > 0 - build_mert_dataset( - midi_dataset_load_path=args.dataset_load_path, - embedding_dataset_save_path=args.dataset_save_path, - per_file_embeddings=args.compute_per_file_embeddings, - pianoteq_exec_path=args.mert_pianoteq_exec_path, - pianoteq_num_procs=args.mert_pianoteq_num_procs, - ) - - -if __name__ == "__main__": - main() diff --git a/paper/scripts/evaluate_embedding_with_probe.py b/paper/scripts/evaluate_embedding_with_probe.py deleted file mode 100644 index 9753b26..0000000 --- a/paper/scripts/evaluate_embedding_with_probe.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse - - -from aria.embeddings.evaluate import ( - train_classifier, - evaluate_classifier, - CATEGORY_TAGS, -) - -EMBEDDING_SIZE = { - "aria": 512, - "m3": 768, - "mert": 1024, -} - - -def evaluate_embeddings( - model_name: str, - metadata_category: str, - train_dataset_path: str, - test_dataset_path: str, - num_epochs: str, - batch_size: str, -): - embedding_size = EMBEDDING_SIZE[model_name] - tag_to_id = CATEGORY_TAGS[metadata_category] - - model = train_classifier( - embedding_dimension=embedding_size, - train_dataset_path=train_dataset_path, - metadata_category=metadata_category, - tag_to_id=tag_to_id, - batch_size=batch_size, - num_epochs=num_epochs, - ) - evaluate_classifier( - model=model, - evaluation_dataset_path=test_dataset_path, - metadata_category=metadata_category, - tag_to_id=tag_to_id, - ) - - -def main(): - parser = argparse.ArgumentParser( - description="Train and evaluate embeddings 
with linear prob" - ) - parser.add_argument( - "--model", - type=str, - choices=["aria", "mert", "m3"], - required=True, - ) - parser.add_argument( - "--metadata_category", - type=str, - choices=["genre", "music_period", "composer", "form", "pianist"], - required=True, - ) - parser.add_argument( - "--train_dataset_path", - type=str, - required=True, - ) - parser.add_argument( - "--test_dataset_path", - type=str, - required=True, - ) - parser.add_argument( - "--num_epochs", - type=int, - default=1, - ) - parser.add_argument( - "--batch_size", - type=int, - default=8, - help="batch_size for training classifier", - ) - args = parser.parse_args() - - evaluate_embeddings( - model_name=args.model, - metadata_category=args.metadata_category, - train_dataset_path=args.train_dataset_path, - test_dataset_path=args.test_dataset_path, - num_epochs=args.num_epochs, - batch_size=args.batch_size, - ) - - -if __name__ == "__main__": - main() diff --git a/paper/scripts/evaluate_embeddings.sh b/paper/scripts/evaluate_embeddings.sh deleted file mode 100644 index d992cf1..0000000 --- a/paper/scripts/evaluate_embeddings.sh +++ /dev/null @@ -1,10 +0,0 @@ -MODEL="m3" -CATEGORY="pianist" -echo "Evaluating model ${MODEL} on category: ${CATEGORY}" - -python /home/loubb/work/aria/paper/scripts/evaluate_embedding_with_probe.py \ - --model $MODEL \ - --metadata_category $CATEGORY \ - --train_dataset_path "/mnt/ssd1/aria/data/paper/clas/${CATEGORY}/train-${MODEL}.jsonl" \ - --test_dataset_path "/mnt/ssd1/aria/data/paper/clas/${CATEGORY}/test-${MODEL}.jsonl" \ - --num_epochs 50 diff --git a/paper/scripts/make_eval_split.py b/paper/scripts/make_eval_split.py deleted file mode 100644 index 6d60ee0..0000000 --- a/paper/scripts/make_eval_split.py +++ /dev/null @@ -1,133 +0,0 @@ -import json -import random -import argparse -from collections import Counter -from pathlib import Path - -from aria.datasets import build_mididict_dataset -from aria.embeddings.evaluate import CATEGORY_TAGS - -random.seed(42) - -MIDI_DATASET_TRAIN_SIZE = 10000 -MIDI_DATASET_TEST_SIZE = 1000 - - -def get_midi_paths( - dataset_dir: str, - metadata_path: str, - metadata_category: str, -): - metadata_tags = list(CATEGORY_TAGS[metadata_category].keys()) - with open(metadata_path, "r") as f: - metadata_dict = json.load(f) - metadata_dict = {k: v["metadata"] for k, v in metadata_dict.items()} - - midi_paths = list(Path(dataset_dir).rglob("*.mid")) - buckets = {tag: [] for tag in metadata_tags} - - for midi_file in midi_paths: - # Extract metadata key from file name (e.g., "000001_0" -> "1") - key = str(int(midi_file.stem.split("_")[0])) - metadata = metadata_dict.get(key) - if not metadata: - continue - tag = metadata.get(metadata_category) - if tag in metadata_tags: - buckets[tag].append(midi_file) - - # Calculate the desired count per tag for both splits - num_tags = len(metadata_tags) - desired_train_per_tag = MIDI_DATASET_TRAIN_SIZE // num_tags - desired_test_per_tag = MIDI_DATASET_TEST_SIZE // num_tags - - train_paths = [] - test_paths = [] - for tag, files in buckets.items(): - random.shuffle(files) - total_files = len(files) - total_desired = desired_train_per_tag + desired_test_per_tag - - if total_files >= total_desired: - # Enough files: use fixed numbers. - train_paths.extend(files[:desired_train_per_tag]) - test_paths.extend( - files[ - desired_train_per_tag : desired_train_per_tag - + desired_test_per_tag - ] - ) - else: - # Not enough files: split based on the desired ratio. 
- train_ratio = desired_train_per_tag / total_desired - train_count = round(total_files * train_ratio) - test_count = total_files - train_count # all remaining go to test - train_paths.extend(files[:train_count]) - test_paths.extend(files[train_count : train_count + test_count]) - - def _extract_tag(midi_file): - key = str(int(midi_file.stem.split("_")[0])) - return metadata_dict.get(key, {}).get(metadata_category, "unknown") - - train_distribution = Counter(_extract_tag(mp) for mp in train_paths) - test_distribution = Counter(_extract_tag(mp) for mp in test_paths) - - print( - f"Finished with splits: train={len(train_paths)}, test={len(test_paths)}" - ) - print("Train distribution:", dict(train_distribution)) - print("Test distribution:", dict(test_distribution)) - - return train_paths, test_paths - - -def main(): - parser = argparse.ArgumentParser( - description="Train and evaluate embeddings with linear prob" - ) - parser.add_argument( - "--dataset_dir", - type=str, - required=True, - ) - parser.add_argument( - "--metadata_path", - type=str, - required=True, - ) - parser.add_argument( - "--metadata_category", - type=str, - choices=["genre", "music_period", "composer", "form"], - required=True, - ) - parser.add_argument( - "--train_save_path", - type=str, - required=True, - ) - parser.add_argument( - "--test_save_path", - type=str, - required=True, - ) - args = parser.parse_args() - - train_paths, test_paths = get_midi_paths( - dataset_dir=args.dataset_dir, - metadata_path=args.metadata_path, - metadata_category=args.metadata_category, - ) - - build_mididict_dataset( - mid_paths=train_paths, - stream_save_path=args.train_save_path, - ) - build_mididict_dataset( - mid_paths=test_paths, - stream_save_path=args.test_save_path, - ) - - -if __name__ == "__main__": - main() diff --git a/paper/scripts/make_pianist8_dataset.py b/paper/scripts/make_pianist8_dataset.py deleted file mode 100644 index 79e4163..0000000 --- a/paper/scripts/make_pianist8_dataset.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import re -import json -import random -import argparse -from pathlib import Path - -from ariautils.midi import MidiDict -from aria.datasets import MidiDataset - -random.seed(43) - -SPLIT_RATIO = 0.9 - - -def get_midi_paths(dataset_dir: str, test_split_file: str = None): - train_paths = [] - test_paths = [] - - test_pairs = set() - if test_split_file: - with open(test_split_file, "r") as f: - test_files = json.load(f) - for entry in test_files: - parts = re.split(r"[\\/]", entry) - assert len(parts) == 3 - - pianist = parts[1].lower() - file_name = parts[2].replace(".npy", ".mid") - test_pairs.add((pianist, file_name)) - - pianist_categories = os.listdir(dataset_dir) - for pianist in pianist_categories: - pianist_dir = os.path.join(dataset_dir, pianist) - mid_paths = list(Path(pianist_dir).glob("*.mid")) - random.shuffle(mid_paths) - print(f"Found {len(mid_paths)} for {pianist}") - - if test_pairs: - for path in mid_paths: - if (pianist.lower(), path.name) in test_pairs: - test_paths.append( - {"path": path, "pianist": pianist.lower()} - ) - else: - train_paths.append( - {"path": path, "pianist": pianist.lower()} - ) - else: - split_idx = int(len(mid_paths) * SPLIT_RATIO) - train_paths += [ - {"path": path, "pianist": pianist.lower()} - for path in mid_paths[:split_idx] - ] - test_paths += [ - {"path": path, "pianist": pianist.lower()} - for path in mid_paths[split_idx:] - ] - - train_mididicts = [] - for path_entry in train_paths: - _mid_dict = MidiDict.from_midi(mid_path=path_entry["path"]) - 
_mid_dict.metadata["pianist"] = path_entry["pianist"] - train_mididicts.append(_mid_dict) - - test_mididicts = [] - for path_entry in test_paths: - _mid_dict = MidiDict.from_midi(mid_path=path_entry["path"]) - _mid_dict.metadata["pianist"] = path_entry["pianist"] - test_mididicts.append(_mid_dict) - - return train_mididicts, test_mididicts - - -def main(): - parser = argparse.ArgumentParser( - description="Create pianist8 dataset train-test split" - ) - parser.add_argument( - "--dataset_dir", - type=str, - required=True, - ) - parser.add_argument( - "--train_save_path", - type=str, - required=True, - ) - parser.add_argument( - "--test_save_path", - type=str, - required=True, - ) - parser.add_argument( - "--test_split", - type=str, - default=None, - help="Path to JSON file listing test split files (paths like 'pianist8//.npy')", - ) - args = parser.parse_args() - - assert os.path.isdir(args.dataset_dir) - assert not os.path.isfile(args.train_save_path) - assert not os.path.isfile(args.test_save_path) - - train_mididicts, test_mididicts = get_midi_paths( - dataset_dir=args.dataset_dir, - test_split_file=args.test_split, - ) - - TrainDataset = MidiDataset(entries=train_mididicts).save( - args.train_save_path - ) - TestDataset = MidiDataset(entries=test_mididicts).save(args.test_save_path) - - -if __name__ == "__main__": - main() diff --git a/tests/test_data.py b/tests/test_data.py deleted file mode 100644 index bd477c6..0000000 --- a/tests/test_data.py +++ /dev/null @@ -1,305 +0,0 @@ -import unittest -import os -import shutil -import logging - -from aria import tokenizer -from aria.config import load_config -from aria.data import datasets -from aria.data.datasets import _noise_midi_dict -from ariautils.midi import MidiDict - -logger = logging.getLogger(__name__) -if not os.path.isdir("tests/test_results"): - os.makedirs("tests/test_results") - - -def setup_logger(): - logger = logging.getLogger(__name__) - for h in logger.handlers[:]: - logger.removeHandler(h) - logger.propagate = False - logger.setLevel(logging.INFO) - formatter = logging.Formatter( - "[%(asctime)s] tests.test_data: [%(levelname)s] %(message)s" - ) - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - ch.setFormatter(formatter) - logger.addHandler(ch) - - -def get_short_seq(): - return [ - ("prefix", "instrument", "piano"), - ("prefix", "instrument", "drum"), - ("prefix", "composer", "bach"), - "", - ("piano", 62, 50), - ("dur", 50), - ("wait", 100), - ("drum", 50), - ("piano", 64, 70), - ("dur", 100), - ("wait", 100), - "", - ] - - -class TestMidiDict(unittest.TestCase): - def test_resolve_pedal(self): - midi_dict = MidiDict.from_midi("tests/test_data/maestro.mid") - midi_dict.resolve_pedal() - self.assertListEqual(midi_dict.pedal_msgs, []) - mid = midi_dict.to_midi() - mid.save("tests/test_results/maestro_npedal.mid") - - -class TestMidiDataset(unittest.TestCase): - def test_build(self): - dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - - self.assertEqual(len(dataset), 7) - self.assertEqual(type(dataset[0]), MidiDict) - - def test_save_load(self): - dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - dataset.save("tests/test_results/mididict_dataset.jsonl") - - dataset_reloaded = datasets.MidiDataset.load( - "tests/test_results/mididict_dataset.jsonl" - ) - self.assertEqual(len(dataset_reloaded), 7) - self.assertEqual(type(dataset[0]), type(dataset_reloaded[0])) - - def test_build_to_file(self): - datasets.MidiDataset.build_to_file( - 
dir="tests/test_data", - save_path="tests/test_results/mididict_dataset_direct.jsonl", - recur=False, - overwrite=True, - ) - - dataset_reloaded = datasets.MidiDataset.load( - load_path="tests/test_results/mididict_dataset_direct.jsonl", - ) - self.assertEqual(len(dataset_reloaded), 7) - self.assertEqual(type(dataset_reloaded[0]), MidiDict) - - def test_split_from_file(self): - datasets.MidiDataset.build_to_file( - dir="tests/test_data", - save_path="tests/test_results/mididict_dataset.jsonl", - recur=False, - overwrite=True, - ) - - datasets.MidiDataset.split_from_file( - load_path="tests/test_results/mididict_dataset.jsonl", - train_val_ratio=0.7, - repeatable=True, - overwrite=True, - ) - - self.assertTrue( - os.path.isfile("tests/test_results/mididict_dataset_train.jsonl") - ) - self.assertTrue( - os.path.isfile("tests/test_results/mididict_dataset_val.jsonl") - ) - - def test_data_hash(self): - mid_1 = MidiDict.from_midi("tests/test_data/pop.mid") - mid_2 = MidiDict.from_midi("tests/test_data/pop_copy.mid") - - self.assertEqual(mid_1.calculate_hash(), mid_2.calculate_hash()) - - def test_concat(self): - if ( - os.path.exists("tests/test_results/mididict_dataset_train.jsonl") - and os.path.exists("tests/test_results/mididict_dataset_val.jsonl") - and os.path.exists("tests/test_results/mididict_dataset.jsonl") - ): - datasets.MidiDataset.combine_datasets_from_file( - "tests/test_results/mididict_dataset_train.jsonl", - "tests/test_results/mididict_dataset_val.jsonl", - "tests/test_results/mididict_dataset.jsonl", - output_path="tests/test_results/mididict_dataset_concat.jsonl", - ) - - self.assertAlmostEqual( - len( - datasets.MidiDataset.load( - "tests/test_results/mididict_dataset_concat.jsonl" - ) - ), - len( - datasets.MidiDataset.load( - "tests/test_results/mididict_dataset.jsonl" - ) - ), - ) - - -class TestPretrainingDataset(unittest.TestCase): - def test_build(self): - MAX_SEQ_LEN = 4096 - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - mididict_dataset.save("tests/test_results/mididict_dataset.jsonl") - - if os.path.exists("tests/test_results/pretrain_dataset_buff_1"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff_1") - if os.path.exists("tests/test_results/pretrain_dataset_buff_2"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff_2") - - dataset_from_file = datasets.PretrainingDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff_1", - max_seq_len=MAX_SEQ_LEN, - num_epochs=3, - midi_dataset_path="tests/test_results/mididict_dataset.jsonl", - ) - dataset_from_mdset = datasets.PretrainingDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff_2", - max_seq_len=MAX_SEQ_LEN, - num_epochs=2, - midi_dataset=mididict_dataset, - ) - - def test_multiple_paths(self): - MAX_SEQ_LEN = 4096 - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - mididict_dataset.save("tests/test_results/mididict_dataset_1.jsonl") - - if os.path.exists("tests/test_results/pretrain_dataset_buff_1"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff_1") - if os.path.exists("tests/test_results/pretrain_dataset_buff_2"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff_2") - - datasets.PretrainingDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff_1", - max_seq_len=MAX_SEQ_LEN, - 
num_epochs=3, - midi_dataset_path="tests/test_results/mididict_dataset.jsonl", - ) - datasets.PretrainingDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff_2", - max_seq_len=MAX_SEQ_LEN, - num_epochs=5, - midi_dataset_path="tests/test_results/mididict_dataset.jsonl", - ) - - dataset = datasets.PretrainingDataset( - dir_paths=[ - "tests/test_results/pretrain_dataset_buff_1", - "tests/test_results/pretrain_dataset_buff_2", - ], - tokenizer=tknzr, - ) - - for epoch in range(11): - for idx, _ in enumerate(dataset): - pass - - print("-------------") - dataset.init_epoch() - - def test_aug(self): - MAX_SEQ_LEN = 512 - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - if os.path.exists("tests/test_results/pretrain_dataset_buff"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff") - pretrain_dataset = datasets.PretrainingDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff", - max_seq_len=MAX_SEQ_LEN, - num_epochs=1, - midi_dataset=mididict_dataset, - ) - pretrain_dataset.set_transform(tknzr.export_data_aug()) - for idx, seq in enumerate(tknzr.decode(pretrain_dataset[0][0])): - for _idx, tok in enumerate(seq): - if tok == tknzr.unk_tok: - logger.warning(f"unk_tok seen at seq={idx}, idx={_idx}") - - logger.info(f"data_aug_1: {tknzr.decode(pretrain_dataset[0][0][:50])}") - logger.info(f"data_aug_2: {tknzr.decode(pretrain_dataset[0][0][:50])}") - - -class TestFinetuningDataset(unittest.TestCase): - def test_noise(self): - config = load_config()["data"]["finetuning"]["noising"] - midi_dict = MidiDict.from_midi("tests/test_data/clean/1.mid") - noisy_midi_dict = _noise_midi_dict(midi_dict, config) - noisy_midi = noisy_midi_dict.to_midi() - noisy_midi.save("tests/test_results/noisy.mid") - - def test_build(self): - MAX_SEQ_LEN = 4096 - tknzr = tokenizer.SeparatedAbsTokenizer(return_tensors=False) - clean_mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data/clean", - recur=True, - shuffle=False, - ) - noisy_mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data/noisy", - recur=True, - shuffle=False, - ) - if os.path.exists("tests/test_results/clean.jsonl"): - os.remove("tests/test_results/clean.jsonl") - if os.path.exists("tests/test_results/noisy.jsonl"): - os.remove("tests/test_results/noisy.jsonl") - clean_mididict_dataset.save("tests/test_results/clean.jsonl") - noisy_mididict_dataset.save("tests/test_results/noisy.jsonl") - - if os.path.exists("tests/test_results/comb"): - shutil.rmtree("tests/test_results/comb") - - finetuning_dataset = datasets.FinetuningDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/comb", - max_seq_len=MAX_SEQ_LEN, - num_epochs=2, - clean_dataset_path="tests/test_results/clean.jsonl", - noisy_dataset_paths=["tests/test_results/noisy.jsonl"], - ) - - finetuning_dataset.init_epoch(0) - for seq, tgt, mask in finetuning_dataset: - tokenized_seq = tknzr.decode(seq) - if ( - tknzr.inst_start_tok in tokenized_seq - and tknzr.bos_tok not in tokenized_seq - ): - detokenized_midi_dict = tknzr.detokenize(tokenized_seq) - res = detokenized_midi_dict.to_midi() - res.save(f"tests/test_results/comb.mid") - break - - -setup_logger() -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_data/arabesque.mid b/tests/test_data/arabesque.mid deleted file mode 100644 index 38a0f85844d27d0a1a064115f248a34dc73f5f64..0000000000000000000000000000000000000000 GIT 
binary patch literal 0 HcmV?d00001 literal 16975 [16975 bytes of base85-encoded binary MIDI data omitted]
zORl8C5zV-gHerUUZhc1-aU{vRl6FRsd(H8?3S->kINYnQq#Z@%9!FJo&Xu&UirG_5 zWnD>~j*{4Sir_40>B@O=v}HDQwCgHrL0lYdg)Dot>oRn->+;;AjnqBaG&NpWb&vK{ z#y#3s89F6NX|%AXE9dsimu|$>^k@^NZ$1_Pv3$CCj%8MIQ@HJOZD2wR4D=5$zIyf)RyWAx^vX&8X zftLZ}QpAMjk>lkgF9WFClH!8P%_dZyD`dc8iN%S$MVGlU%1gN%7>DE5Qp#nnPUL}E z9y0MbKMY2kmoNSEN^8ZZ^h-fv54hzZW%0=<1b8de_?q8o#N!Uw}_m1w+t3JS4VmD zbIs+sso+pv8sRV^#oYu4=rUd9fmz<*yi#y$Tr6-Upq_CUkOG5KX91O>TjQp}^0;?d zcX=?%sTNKVQruJoPR>s9V3s$~aYc@M3*?kzwaShk$c>!J1GBurd86X8+(~e$rYW6w z-YKw9Cy(O!yag@(foO<=2F~H1kR1R%Ya$l;C!>;(%ew6>f$2myxVYd z+(~et&EdG^fmt5+O&c!74MnJ>+jMhafwRTmyyS+$@&@PaO}7ZPXpgAQyPIx~8;Zc$ zZn;G;%Nv~Y?s|b63a;NcMnC8I{Y9{g_5|DVxS_B-PRHBs3O5v7EJ(6`$6Wxs;9Or) zcia{3D=ZJr?GXm&w!6p;1$7!{*IfY%waje}{8|4%C-@yM0r-gRyu(F-rCcU-pyzf5 zZSEX+NX;;ux4Db}Jonduq+DtUFXB~W{)<&P|E^E?@A@e3&(43NkKWIJ@z_5F|HVgxKi%uU(MRv+zj*AQ zg8$;9Z|T3$NAa$E{1+b`^WW&B-haW~iU0CP-0;}@`|pI0_ID=ycgJ1fPI1hC?Me=( z-TM%Ixn+LKm7Aq(pv->{25qj`%xQ0a2phlU!9Y4vdi-XETejnWQhDsBxi;kkyA1y{ zz=lZI=lq>wPss@|3ZLf5WzKT~5(LeXKW?wHpE7>`a{ z0Yv0)qu$nEq4i~SjbWER29tA)|{><}jeckAVOtc_P(KV??T1y4|JR{xW zBj>3AsQ+2uCE5uu^&Pipz%%+A{D8(R!4-c7<^5>$BXe#A{G0_Sz$0kLAu0 zqZ2gnaNU}R$cEGMXw@p3cQ)tEMNh+X@FP!<9vER&(u`ghq`;IjrUv-F_=947cB)a)=@3R4E1d!7 zXI)w<1lCA0bZHG;Ky%MvFgf>+Xa3F4)*3p@8$)>8**b*wcQsqf&s8^nE8xL6j+v8m~rKqvD}Zk!bit)KkCXtW7!L?aA7R>X;*%F zEcaqgD8@(|#ynD;&?!w8M$A%sVN|oBqjv9umM^-3 zR2MWkrK`fI`&M?O97AVadDeHqPsB;&E6%Z}lpC|;SiIs^r1ml5RkwK6t!QXQ#6_)$-HJvN zB4w%Je8q^1BKiu+I$zZ4*saV@a4uhSi`QKK8bm@F%sQ7`TyV>_@--q9)w#IH@rKJU zdd>$DQygfxd}@ModC4sZL(6H$Z!76QmsRQ9*IyS{Bp*vWE{)dan9vs zx3cWkmM4i>x02R#rHgoFL{oR9$fe5g9caQZQf6Ier!@pMwbM>F3LNSI9d) z4j(5ju7a9o5@0P*_-LR|0H(O=%%@#JeGF(}qz5SHg^^Z;^0_`o32(3s)W6 z;RI!YP>|)shDPC3?G~PKc|}Oyk@Ng0P@7bvm~YH8F0UCOFc+S5`6tKFr(FK2F?5P~ z63tK5e8J@vPGikayS(O$fKEwOxac^080C)U3lmbyZe=5uI_X(qu%Ra;s8Z zM5h!xg}xg-GuPzRZdB|P`arc4D#u4*O6^mYW2j3@KYs0L+SXWTLHhQ@%PxP}rKS9_ z;yOpr^7ieCS`@pq270Wx9KrFMQEyMw(Sb|PO>)YuTw1|!tj_xq6&_t$E;+_|U!r2B zOD{}v%DY`!VRNkbeTj;hF0HUR#;I80(h8ezM@&27Oj>bfjPt(43|Boatynt7x$M#k zq);@td}NYec1xnLBgVjp&*y7dm(RK-h1N0R3YR=?Y2_Ud6<6JoLhD#@#Z|YY&^kfP zjfj3BalXch`l?$}u!Zx1M9$aO+|ruMC?GQ@j`4H6UT{kVu7=)m*71y%_>PJIo-yFK z5+O*J64K?!mkL%({*s zF~dY;fJh#K#IK@|2Z<@uF%8c|KqP)C#gjZBlK73w10ty|BY8k1wP+*{h$L~0`3vTeOS${t6RZO@ZO*>;LBR=FKp%=oqu#n7>AJA)XNwiAj`c9VNMZ#o+3mcW=NEc}x5=JBRrxA;CX|?XCJ)I&Gv`g!<4J13xI7>(iDV=Xh$M?Ql6qqW z`Itpon&AnA0dVoTOCWhbBnW$ zNizR<{KrYY8#?p)Vh?tW$V&YXMfvit9U^lqtq zdg~6yiFYdBIH`hDx|JASsh>vzE;be0?L;}Lq*?m-gf zIc~1Ua{X;alIsuJBgz^>FWe*tll9%A?dfjc0fusyr4H6L%C${I$o+EaZu*kT+QQmM z97nMTxSPK0vRb!Mj9Ma5ze{E}J>x_juWMT}!Kw3jyR#YR)Dga2*^CjhS~J-_i-i{ zcM0m>^(Cl(sh6PsSzm${9Dg`zRF(gO95J_4ba_4-Qh#J5`~{IbkGCSqC#n2}Br6NU z-F(@Vc=@HWLbDZDsDQExlOnSElOkhi&GF}tM!9vzpE*iUJ|fLGK)oLCrZ-%EV+?J& zd~*!lbotE@y4Z5r7ATFfslB_n<+592XxnAmV`#@^J7eg!%Wl6l+I89PSniI??u^h> z&t-ca<=FIYs_(M>F?82ucgN5@r@z%LA;rA7@3Q-MUhp!%R=wLU96xb~Kka*$KSt(( z*6OSMJ1uv&f0svxyeEA^9#VpXCy6|IOHli%QSZlP_t$5>!*37XOCE5`zl(b}o%LyEWpF10p2_<(yE;Ga@_f}ZO;4W{nEL~Y1+*tu%p?;;j8}t0gr*cXEqaI>9t&fOL>m%Y( zA2E&km`#f6 ze3nk=Se}VD-YqRU{b^%)rg`-^3O8>NbrE!j6`x}6H0rOCVy=D;ido_- zvv=pw9-l^sRIU1-B8ja|qyCpke6CvcUmC}Hxr5^vcqGIms_VJX>c!CNGd`UookNPb zOCh~HmCgn{rdI-vO<1{@(i_8LiVYaf*iWR4t>Gb^XN*k?jPXRuSQ{Qw#@ukmd?ID+ z3}@^oQpQ&JHPUNJu?0&R`|)%En^+(}N%0}W@u7*d81R^`2fQ?umIEHsO2DgAX)WL} ztp~g@m2L#QIhAe-cXt;Q$rs$|KXrHO{@t-GqE5~<>K`flTPvP~Ybb|Va#U2&+XudV>Dbzo1mnG{#^j&_*LSKBVG*kZFod4SMz9`%o zUJm+~$cpF`U`KH`qsPeh^QJhwywHextze%6B%GAb+7XPNbr3r=@!Dj2U!tBWB4mcAwE1C(&s|@vB7sLUvk6_ zclYY^uUD9;weTMQ@6J{7=008&-{gOXADFbPEOe39v75(Al9t;fq{JBi-Q?vQhTt9*0?z~qjYH}y05TEWruQ@jKN$vwC 
zcRnO9XrmO@5L4(CSKyU`&sD4b%}mm&e^K)m=)=kV--5}V?||ysN;0#X9ZsSAJel^L zmMi-y@RLmu$7=Q0aNYtHxBoUuhZk12|28T?{kKsbg*$x2CEQlS{-+!g)PM2iQPs_( zd`<35%T~94v3Oq;?xIcZoGaSo?5A?JCz{+jdo-e-^zZ7E13RxAaZ)~G^SJ)#2pa)S z?*CtVXV)9ok>>ZB!A!;w0*sK&Vl`t+JfbKzo2S`4WV6}LX0u<|q(thC6h&FIWvwXN zvZakIdn^rxa}fj*)uWy9?szZy?BxR4AQ!nvZZZZoSGn4&e1Lp{e24%6lK=0iQ>XbN zX+4WYvKyEHQU7}CJ~M|Nx5%$bYEpGF)_D5aX%UT9Sd_zFV(Y{R7Njip zq7Iknd}i?0s%{alC06VOe{>FcjlD^|dVwx2EXrXov4={kS*L~;J@XDHpas+gsTrn^ zNqsO^d@xshFjxAYx#EMl;)A);^Jb~{k+Lleb!q$BK`*fxdn$TXlQ$Jrv1?`5U_@!) zz(ncE?8Amjq8#=rw#HSBK^20B#ypFC&?x%srBCTuXEqY!e8f}&ca{k zj%4AZnHyHlVZ|I)!cr7%g{zh_S1lF8`FLJ-CiK_15x{f++DYgmhsqtT#L+^BLY~P! z65LS8;6_>a*&}YKL~ze#;lY)ezzQGC0%7?dDk&}b;4(gzR_iNS4^`?0&j zdvJp~AKXCy$D-DUdImSt^Qar@8{AOe;D&kyH`Mp2o7FoBuIx0n?8)32$*CF1ui)nE z*laypm1{OUt!r`T<(h3yat$}PVFp`Xa5I};{fIjoYbWFL-O>Ht97*bgDW|Ue@?E=H_APo|DOeib|9u3X^u@aM`1Lu z=FkqrA@D(SXa};`4aeJWdF-KC$j2W3I6`dQ2EdgFm~zZJ{9(mnYt{_EyD-Nd^$zVo zioL1#%XM1MZ+h<^40=3Lng@2S zyPh7+VJkA&xO+OFGH1g&QX)lJXMro6&5o~eoQ3e$z9 z={2h%^y>r77^a&bUXM0oXg-NSV{>S}&}-(DeSi$~{RaPwG(AL2FikrD&}pV42SakW+|bL8LN}1b`A|6uttePYF=k zdMN`HieyrVLIU(k*RBM*02Kk$3j4d-KNM-2d)^_N8NZ?3`~w}u3?VN&gY487AE5?K zes5Pcnb%HD%BxI_4JgJ#ul=SkuM%1rSsx6&H8%CgD^X7wbW2AqtFbM=*%qi!B{VOK zb9ogL@5;QZqNA!d8ZCn=PbkPMO#ONb!WvhDUOfOpjlz&HKyC&q0q{8}EAX3aS12B! zO|h*Q@Y?{@1ax7u2IQa%n*lwsSvy%GO?J;aQ zDG}1)x7&dpF(xmKN!vLwSOgr1q;xv=kcAz|Twygd0;^tT0v>Sb zG&PY{MJ=qx0J=19487ov;sMX7y^Dwq&;Y|4kT+HA{IsIc%<0!Y_i9Nur9k>EYBW0S*#Iz=(xNS zK3HUxXq~OOA!OyMZ3$V0lkMjZG)w$PY0F3-vg3_550E}anUXcB>Hd9Z&_LLDL3h>d*uvK%e6DbruGb$dgBato^m6HiY?D>ZbrpM+QDmkCcwg z={MO|_#R!CJb6OOxYYpobiPzd2e5RstPi8-c|^}UiWrHWMnCb0K6jKhj5o}0$cte~ zJXa|1kQ>9g(`S6ntQ!n1u-;gJqYHGF1sb^Juk53iCKc>SDG9D+W-5j2NtjkPl6}YH zrJ)cm&ia?Ld~yjr&1}4{gnvbM84UCWPYa{B+S|S$>fc0zB^FL1iG54j3pB7-YUG#n zv5&6!w@iI5Dnt)y++zxA=ovdzM$1FJH|K#~kn=#V!+B8sbxq*ktq*u+PVK&@juU`h zbMusi478WNcGiwj!>{dbM7WPiO$IS0I5* zuK?J@o6n>Gq#kcOFQnbJ$V%Dbp zm2#*UAlyuQK=`~)vp^PF*id^wP?}}|u+r%0Wu{pGEDT^_!7P4Eji7@at1}NJW|{?v z$!(}-!98D2&qv8dguB(jYZm`a{12027a)l^q%1>OM&$OGCT;fZu`Kck72YEzSdW6n ze@URn;zK;Xq&(LYrAK!1uUU@6)WaNyUtuxu>aQp-@>x@+J)egtEFz$+lGm!hIvK1( zAx|im5ZDxDl||dC`(4LRByN$?s}01KE;B)s%T09Ii>T^nnsb}X2OTU@B8kHtn}o`) z#pf;6{_dc|cRI`my>W?ZJ@+~2>FHAs^z^9*divCZSPDqb{o7sX*CLrQkk|%@ZGhMY zh+R88*XXrVKeLn+#+$O*Hz1PUBvl6pfC&jo3w?)X56eTT9~otZN~|cTr$<;-B(2e? zrctVa0()(lae={oL6I&J>7oD)Ew=9O4lck)3xjnQM&3S*cCj$p#lj|kwWhGW!G%Hy z8rK}oE5`$F`|*RF@vBA%V?klz_t`=Cb!os=lbMiU@R^VRSUQHP4GMq>2|swnilAh1 z3#=b{JUnMC;K>rbQP)ulJ)Gs0wceckE$`-`D2q1n)KC@;aV@FpMeX^bN96Aj znRrABkGw&g6s$o!jWviztUWwp?cot?50BOj?Q7+K%O$tuQBkc<@R|Frv`UtruF6#nGv-iw0!)6ag_u}|pjWNXhEy}@jDi3&o2$Ll4!M50eFO$eA4 zF|M$t`w!7z`l*|J`rAMDS&t41pSk!mg{f=F(3n3xy!zJ%PxpM|tr>so-2afY{^omT z5We#c=X`$juic5-*Df8~bjjv#Ci-fUG3QeKMAR2^(%I(o2YuBd)s3HorSt8Vo&LN) zP2#Pu*WOR2r8}GT_motsOSrvu2QIq zX}j?g5$^(9db+$Xfsi)4dwq?~w7H>=0T$Kg`+C(n)w5}qs?`orSC2+*t6Y*DY2&hh zrmM7mbAL~T+h0pZ@4525zwGo!MJ=>lrTwFR**$wmMY+GmYRn?S1E~ZaNR#kDa@Ye2 z2oEGrj(|K!BEtt|$<;jR>}@xGTl8_;JClyT?`GeZR}2;cf@SnX}se`s7U5L^RAmzMdl@jtm`1%Jdil@ zfcB+V%}bwfuvpFY_#G^5>M*lE-gtqd7zQ+fus~5 zQs{^-%ao|ek9)RRnp64|e-*~a39QH;eLQ-vO zd#B%X+o?9Rd9b3BKK&J)^hv8a<|9sQr&S&BF~BbZ8joK$zmYUb<9RoWiMt8A)c;}~ z!}vxPsdS80bpc+6u0GUdCs+=+(D2)KR0kkquC&(0@W&p%7 z%2EJ%98y`(g#bDM@O2ByDXW1JQZ2P4xTOT9i`qu!Rht#Xv6@~ApqGJmPjT9cytJC7 zYF<>3X;YU2Sj#|{@^s-CsmO3~Ir~JD@{p7X``! 
zrS<*WD&r3+<877khiUPOKvAG@O`ss~dzSIp$13bUPo*Y#5WTmbsIV_irlZ@gye*Gu zTJjfexyr4V356}VlS3uAepO9EUulF^uh3WEdvZ+xow2kofV?i#MbPF7h%uADgB*Mt z-5{Uta}ehJu^iBn_KD+wwzZ&R;5*eSuqNl@F&6*7RjuNx8$q>_t`SSM(gD%QRS3MnOvzf-NgQ>{jG9V^$Ua`2sMB{|;sPPLNO^-oi^DsQ_{R4c6*%BWhhQWsFI zq-0dSQ?34~sa7%r$~2?R(lR?_DpkDcMr5{b$Y8AHi%o4VUFbf4(ATcrw6?Nm7U(6E zVOCJBd8zy@_xwRx+1hJw1rv1h$I_x7wu!K$qz6sVjb*JL8kS;bue}peEdEF-e)h@~ z&EEbl+MqXFvwFfBy~sKdFwLsaOt-8z0q0n6Vz)h~9{>ctc4mm1z!Wec@&pRE0rYGG zEYA&~RU1H=Hh}7E04WZbTXPd@FYEm|_m+8C6VE|sO$?~a2(XG6N7peW8^zE7+P?uT z3L(8aGbI6A9Zou#8RyVOb?`5pLF+Pl z%&!f5X_fC|>eZugD(hal1c8=pma+TqrCpXBEMYsi#?Pa19G{b&209^c`kJ`;Pq-v6psOQ*;!gyQFNpWXn%j zw&-9k&wvI3!|)eAlwcI5efKnTs>+7_ULb+7L!8}-E2wLRY@7q9FoiYT>px~;&{ z1vf)8s;Ii;?u)ds;>y@Yw60^fSX%4qMK$?lRpZCMbgu0tFvh5z=kQmXuE+vRw!JC_ z7RI&UayUgnYT z_2?MLtCNG#5GSu-PQl`7wKyL@Zq__r_uN3q(6Ur>q3&3d5IbXQq3%$u?#&SAa)`4X z;#>`JNM$^{72@oMIPZlx*DcQ9MBn|gfDAWdIRXsut$$fKZV<>R-{!}Hd?9-rd>hNa z?6U_N7jLCwIj96m0OtYc73ch11`2?JDEeW>gCd|PisX+663P$6h_}+oS_X8>=>yrW z-b#m{9w<3j2$TTE8={OWRSpV^qS)61C<5&3%?wNe?CVPzC;{y2tqha_W#wFMzaDf# zPyiEMQ6`is2gPMk?CSx&IhZ>7nF3}|wApvmE_ zGyofC>|9VhS;U)I#Le`JMce?hzX4`{1I*3_P(uu0=rwRW@O_%1piY z@()z7sWdNYje1!=D(W|y%gvK$pO?+z34#$DYI$A;;<7BfWhMn|sHG*@u1grP+rfz4 zm8IB>*gXL=Vq-)$BQ{24Gh(j`m=PO8H%9DpR|L)pj9wEkYxbr9*6h<;0;dH=v|(Rz z^~FvaUlWQ!Zz?0A`X8c<*x;><*x>(;GQP9D@A~h!7o9W)sf?X8x#h-KTA3<$GOg?6 zq4kR@R%@E5hSeJOC$u??E^P~7tS++8!+wAy4!;q;;>YJ)>K(Jljpc$YH)tYdna;`{*C}@9xQO4bCey#GB6qiO=t?jov ztkwq53=N=R0!FY}o3$CMwE?Ww2C!NK&S({6iwve~?f{c8vrheTwBycnP#)XTT-yq5 zu^MU#G_@MCVa=ccELPNa7Ts^(N>{W+XGN4A%b%V=SBcT8y9`1=yR!eevitL2rcG@P zvTRz43()aw8+!HY8cB$P#;uP}K5uarNn^_9O%0v5FH07S{Dx)?tvQCGZIRKy4Y2dWK>pJMZD*9L0if_3 z<$3@&0zmU8=*;`iJ5say+U;^jXdYDC6{>+oJ`mFO4pr%{(3l%&piHV27$`HbfGu3K zY0TIs8?GNQC0IwCn2lMCp(3g*qpAw8tr?KfG(!RDFC$2LepcYTzzkD`<;)_DhAHw1 zl8^FH`gMU~kSdzk-z}ZpZ}cNK&L^Yu`hb3T+f7TX>>yRhhEFBTFXSgXbym_j#kP{7 zhFeMN0(`32$Us3yK70YSIn4@~%)LCRd=gbl`4G8Gk?kigctE*I+SjX`wX)&SJ>N^; z0MM4LCm#cIP;8ORFEy_QyryVe<;lg@0Pe2#Qx09&=AL7K$2)D&@5<`4gAYVj<4Y!=ZAsUER zBn?#5anvimiqxspuD(uNmx_(}%Ej$sJhq6{z!Y=lw((9)yYA!*6W^fT)OrN8n$ zCMvnm3z4QXbdEzEJEljHTZsW;9g!Ys9g!YsLxvt{Lk0tm4H#3<{ zQGZGm8-jA61!H3L>nF%!V-F;EGVxeV|ki!q*8SUi|N25Ny)&%iXm3~vho zG%|1wU>3N=qhPM-CYtK1W>S$IYRXMa4ZZViF?3!Gou{ql#gQuPgZ=Zg6_k0}3NQyB zjkEBPmeISc8E0uUI6F6?%w@ZY($hNmEQK(VuoZ{|&cXF_0^=kb7bucZQ6Llknlw%m z$&wMtlF>yh8Q-^Lw5%ob13<}SW2BDCN$)|ZMbyU3%i0|qr3E@mL8KVKua0G=a|T^W z)nOE9pwWsm`G{aGk%$sw>BPnwqt7Hsm5V(-TvMm6Sx5AG3*-l(Jh%tBP$U&8Mdz6^ga2 z$YM!k1&rxbq$<+`VTlG(WR=is#F2@7UU4&7+2-r0Wao)PwR9>VOHLe;&oB@un2t4M z$r(O~+9f9BC4na5qA5_-5fS5D`J4dB%V%BTY?f6;NB6CpaVj4GM++m*qKf{lh=Y>L(@lqqOjbLX3uDoeN;>7|@LezUaE~ zzyrp;RfchY2$XbWOFeUcH|2nU^`!c!NZQ-Jn^tt=feNnK`t3j}=*9!(ptk)(O`9bw z!qF+A$j~1tFmn2_!P$=Gzjf{#p8mM)bTFXf-1qFTfKCP^6apw3c=<)zjQ3Be zC$6Q;zz+MdyLToH(*%xnd+??|(ng=2b66Vw7e}4j}p5^vel}_EVZ7#J;LD z+1k7)HBq#ikmD1l(poE{R67DyC@cGG7N>S015K78oW9YD+R7o6YzT2uCXKg*xHB#t z!;moM{!E(Hh8x}R>J~CWPmPu8){~@gsDbC01x&|H0+`M~6@}UsXFA~$T~Z#Iq}SCd zF_Bb5Z7J8fZc-kYJR2Vc`AVRU_(Pwpu?Z7pcovyUstQurIxkY);vk?c&h&nwI7vD? 
zl>wcTVRM}J_DyC$_vR!@m8A8F49o-OMVTiEpn@IEM-`Z-%JcjhEZHeO%+C4Z!>4IY z1eKcn3Pf*6G$U46pvX^ymPt!Qt|gF!7S9{ZR}F@v9uf<*j?okigF}N(+Q87@5wl^# zbTig(I(%f=d4aFnCL@#cx>OkU9l#J_$QXckHR+=6FL~7LpT2KI%te(Yb z*Y`=bFHyq8Y^bs1YRt*vx|OIjsmM%(dR{E@BLJYY>MozV< zR9#0e6>m=o0?>JXPtjcf25Qw`xAo0u_fl@>u{Li`_cv8uF?6vc#EIv$QT%nH?8Uu$ zo`tRq@OkOV01r5G#P@2fAL%}IT@)aS^z1UI?D{%HH1RA_`DEQF3Siyn0nF|}ouTM1 z79VGY@w-AC=3(OKG+~w;V9Ap#s0*;EKK+@&t{o0xz(BD5rs3XAx$NOLZ?muM6OQmd)2v&X{JmN6oO`A0aE1> znTxw@^cDcteS)#)iov2_1nm?f-MTvDx!lveSFayCKWTHOFvw;2JOKmc3@mFh0U zCSy)tg?Zgsw}e-EGviYGCBx|e7}?}~4TWl?Qpzgx5udD8m7R$X0QoEc76cme0u;J3 zCwZc0BGNV##(JwJfV^!3pezHH1s0IEmb{`=4fm&Yt-(byVLB&Z!nDln0v{`6$w}1o za#P2*F6$iERlrq&4Zw!LCSda*C69A?DND5uSXZ1iz?#4Xzy*O7z=}YxCBQMVE}$#W z2DAm1080Y8r|2EAbUSST8UppUKwY2)s0ma796751DgrZr8G!_l2$VYlWq~Qc6qIg8 zNdUx|Sr)KVy+8p#nOzZ>6{rCCM0HhwkIn<;1!@-rC{Ia0Cd^!Y`HuSP7@Z~Fc-cv| zKkdC{HWCyF!dijYCJ)88Dg?P=w#{Q|Z3CEw0q zDj2EtHwz1UHjc&?j64Rp(cdh}LQcNpO(TQz3<}tZuY*W7=2m9>cJ_-OUMM_3G*_UX`*R=P1Qc7JwzE;1T3P6PYal| zpUI@X7A;;!ucE%AkypzZs7wh|NEHiMs!0J9`dZPzbNz14>Vu9^VD&-s$4vFq_V~i0 ziXv(9hu^PE{$TDv{>-B5=<;&x3w{-%6P{V*@VjP$8pO7#Ouv+TKiym&>p10bVi6h0z#08n(voR7mzR#~z*BXt0ltR!H`0G15E z5aS{kCEMSXq8LziMqY|=G9j9-8W02d4o$l-DZrGVtLP%t_W-2^g*Z*BpmnX1@f}MB zuw(#B7ARWIX33`ZNw&XB7the#w9p(%d6C*6pQZp)0tq=VNU9`c6ltOq(((RXt=|L6 zN~0{gx}`CS!+lEwYq#mcy9 zu-|}F4U{BANv$S8Z>?VtKyFl<0*JOoE5zx9I9EcPl@Nzimd{#0DI;Y=(b!>lCn|0-%U^0`xvgWkhvWdgg!RjSfyduAbuZ7pHx|jt#dzmTPUfx_?HLMo|MvgxP|-^?|E>;1)mdJzY(RY^&XJO%5gM zLm?OC-Ms5+cU|SK&s$L$BhA~ce%n=VyVh;RR$`yMU(A2#S|7Ta^aA>Z!ss)#UDw=o zmEFU6*B-g%BUgEJsND5W-O{J7`l+HSSIuVD0hYdW^)FrZOHBy=TTFBHU%I7Vy6P{F zME!+p{KD0K;o86WD^{zHk*cDotC48rPh6}^T43gg?rHg^LkpCqrCzA{G-%Zqxjja| z*oj1^9{#ena_;#9Dbq5x_09_QK=D&o(@ibZHFl%v*k>;J%(CEymHeib@8A92|Nj1q zckllGew%e{@O(UtDj#nq^HrF-MXdarvWy1Kry9`!b^Tv_dH zuWwvCv2kq}w)|eyy|Q}aM)%sv)<$$={p!ZG)$Q$#Qnb0XvAMOry1lV=;`cAq-4iFy zM%|}R?>~O}@bkxw=;?!pPohsAKDr;t{rTgM?|&J6^!XRR{PN+a51vM6Pn?+esQCHg zj~?B7^5o%1_a2oZ`IX|{XPfBgB=`%lUbo_NV diff --git a/tests/test_data/clean/1.mid b/tests/test_data/clean/1.mid deleted file mode 100644 index 5a221b285102db9560f7c7c8b84a38196d4f55c6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 49388 zcmZs^U2h~un(r5f7UV^uAq^G@LKF^ic~m``Y__sRWfz%zi7r->tYopWij}Np^~@w1 zqOd3kI%|Uvgn>Z_!VI#iXLsRa*a`UJ+}gf5cgB2#^C|pEd~g5$5k4dE$x5&36l9%!mB{@BZEYJGuGce=z31{y*mbyZ-O~`~P#>{8#h8pZ~4@m;dX3 zZ1c3(+gk1)M$OR1&G68e!%ow)#@NIRie_66l4iJP%w7}@%Ywgo`M16a+JOmj(`?&L zNL9^tdy+I=eHb@80~?(f!-H*CQ`2+b2St9*-vizlGjz4ww@0z%^F`ZBn%4S;ZFcl^Cz#|BfAgkTcVKFo&51Et8jNi^HjB(eM@cZw3iW5B%@~ibaWkV- z*32e0R8?CckHJT3s}I{GrsCz-X&fbH`NLTnK98c@1gW{><2;RX6ZX@%Z%l4&Zlc}T^lkjy z-1kk?OPlSUiTiFE=yj4bvw;orG)#=)i`{i6pLK%0XnC_|qM{6*QFjsz%yK_9@ryX5 z&xKjgiaPmGn{=X)F(Vtka353a$OMTAhDFmWifCx=3lj`%vpq1;nGMp^rlvkKVV*T_ z&$8e}66Pj-VU}Oe`yz`Av;6KTGyN>sjah<##osln5A!5E>;%ukU~jAGso^hLsDla4 zjd|u8UXI$E+i+mbu#8mTllCk(MPX+1*w5b0dT?hZU{MF9xxNi%RFE?^b1nK>es?_& z26>R%kiU7#qpz=-Gvs_R&6Jlv_R;bOQy)NsAsr55&SQ|gJQTx#CfO%v#S+a=<{p{Wxre;z?!V#0` zM#AjoAtbHY{+L!%w%2Hy1!tYWcAD>Lt?9P2eD&$rMn%Gun6pH)Vxy4$(JG73d`p|q zVq{G*fkHzw=mcX^FiE*#Y5yAE%;O>r`pm4`Ov~@u0n4ndp^XHs<*{E)Ms78D`t#3E zhG}>d2i+*x-CF$y>h79&NR<&}9UO*G`eIL8i?O`>{M9`CstSLb-_c{XeQ$2lEciUD zerwDtW4`kJ+-N^{i)PzB+kAUiG~bKqSf;2NRAQiqMKGE--|8u|4&7H)+jXw#@+1tC zW5j7YnFo6&;AP5xg<92KF2g;WLYb3!by%3|dGnS6TbXG5KUd@A}GRShK#B_67DQZ?@TGf-Pf~3du z3O>pcD)iixg`!SMGa0clmu?h6JDL}lHGQb#(#&pyZV_I>zcPmS7uJZO@dd>hanv3a z=2aPV^B}W9#!>a_3Ks{%+DL9nvgjP_LAVzG&3rR+C!K>jCPi1 zFaUFy7K5Tld|sJ~xhl*>Nwp~rt)8|ARd8G( zO`sQ(md#r&fm*H0xfb0ofnU?ljkz%9z?jOIV>hR^Eg_+~hBcfQu#1cN-RG*uRJZYk 
z(JZfCv>KRnBX^9M0>D3PSvKLyxVL>xCYgdN>bWrqP9IhSVFjS=xMS z)B3DfAS;Je-~=(E#++5P_IDPPbN6aq>6Hzss;ny&p5MJR<|UkJu97^V03SZIRn~CY ziod4wFw2YbHk}%KQ`16VZB)W9=|I|Mwwbwx8dln5Mf2lu9`B~j6CUyara&t!Zr##t zzG#N+K~|sVwYpG->OxkX8*^SQ818{J<^0a|uBK_dH8@Nu_r^cL?q77GtpUlb^a6NNls~1;w@P#p7vOP z4tH%;-o{s^U~YE}eWmwIuew{eyS5o#t=nmOY61=8xed-rdr=glDq_KS>Svk8u3YES zb~=Yd%$K$p8BKZ(j1L$>|0yyv5;P1$g&`${{?>{joqrxv!JZ=`5xc-O=>GiR29lKkT$ z9-Gx~pzg8s4m6CM2h9#zTrZ9(lXs$vBh1zHjpwSuS43^Prt(?D24tBkaw% zEYO`M$f_zpao;#}mBcOdFk?;85RVfy1+gr?joRGYPs{k4F7f0n4X&^;FjTIXT{!jB zMA!z$ChVDe2w;R}0tZ8dD_J4wXU!9s=s0hNG|FR?ibHsaqscKOFTwE8$voJLsTQD%StP>qa&E}Zd%$gmlA8PU7)}rL_B5S%#-XLo>2dpT4l%^F2 zWqxO6CO|F`Yl=8R=ALB>bRy41t%>>>Ts;_e$ z!In*URYu3@-LdQtQUA1{pAZ9|>)Wd`I8N`6P0&kW0l~Cj0iYe)EzXRcSpU}K8B2mJ zl$`_NoF!1{f~BWL1SG4_Hw$(;7N(<|+?)710Euhn$4X!Yld3TInT58Yju;&odqk<3J8?9Ex$Q;d$ zO9`?rUD!*_lrnZf;}sjlS>!APogK~9c-S(etT4G_dbYi&D*6vu`Hg(vvg?NPuwO_z zLIBp>4RJx*MtwBl6F&WSbWHa#(co>AkA z!t~1Evt071ZZYMl@Uu0SftjGsoagLTRlA*LF{9r;w_z`i&e^Z@eJ_FS@C;6uuAgn5 z)xqy_Rr}Jq?ryZ&%}%TMovS)-dziXz8mOa}NoFt&DI85h=Gpd98XwVJm9w-jR=J6I z+Ro7>jP=v7(Ly`a`-L&vu3n4jFVdq9jrAZ;4_JwLV2c#Cc4qE!2){IG6?~pYsR=$; zIf$oug3it+r86N0>!uEG>|4mVK4A>->IQ!0Y%U|R7yqIJm?|te-=lyV8VR{ z$_A2(uTtH8?OM}grc+^W@Dfcp)xFBYBh*GrL|$TVAEz)9@UbFzm5IK3ulR$P7|gIG zYSsoXgFfHD30O&R3M3w6u@}=N)(dTlLRiwv5#R>$CPxN4>CaQ_XI}h}TLB-|>LX|S z`xmQE2WiB7TF|vEC(z&hW6KWqrXO@ux6FRO6YXwk8!ks{x_$)aQqgsmiBCRf(z1xz z%=s;O&dy*!bZyz^5ha+dEH=*tugUmh)RcZs@1DoVP!t2^J~(zPjn}@8%&cCp=Fc7G zcpL>;C^P+^H=$?M;VM{LdU1<-kc6NnQhdR%V5#hz#K9j2g7fPPj`2h;pfe;{u!(!&oXDnphDd0)Z8zl-fuPY5BDitd(X+${(cNS$!&5 zTd1E%6+Mbp|D;D&+bWM}*H{NZ(hybLw6C)}|Ngb(Jt1|*(=_ULkcqC94_Zn2bm2z0 z7N1#o0X;=QZwtXJISZcY)C3VN(~QB=i_Jx1vSjr_FG{`8)5Kh`coa%jZ`>ClPcIV9 zQjZ1WMZ~OtF8BZP?pzF)LcT&bMdLUaGSnylgx~P(PJy)1UOK?uU~1AYP-a|VXn4$E zobokF!Q2q=&4$|m(NK6J3m0PKV+1_CK>bPss&L9U9TCzFSwP>^6}EtB-9lytdWM5w z_-d`-u+Q$>F}P~Zf}s*M+XtBwT$4N)S(yPW;P#;qK2{Y|zzT?zp|O`nNSz(_+bCzr zQ1ua1o2v%2lAQLDEhv?tl4;u zQBxh<247Sd(%7%k#!`&68a-#?Y}Z<*+r&U`_OK2Tp=!duT{)19@fuyw1TzZ}6k0q3 z-4J}n1`v6~ZC+KQ63}n0Js~_w#$sg#R$_zb60Tgo1n{U~2K@@v)~`_C*yp9m=5PYa zX*(cO%f&^B9DwDCHr(v@%wpHsQ2jhOvuHq3H?>_@LRMt#r5w4akE3?CxbNmgR#etn zHjpOODf&DE7&`vYCWl$Ss<1eWtvxfjX0L>U9*AW-8=wD z89KrbTi0G;%X_rv-Ac+Z#rc|!$8EQyMs*E%r>t#Jye{KigqzbweL42sP3cPmOTV5+ zdpX|(dRz-xd+p$Ao5$BB4>eM#@ETT;hd2zTMi_kBb_<(X8gSr7`HNmy^@2i{E>t!z z7_PeNdf}($=Apv@>N4ymfONVAcD3xLNgabVJxmIEJ?4n)P{fIw5e5?phq#4W#cm7w zl4d!5K@VL+2-8~)`Egz-;(7$+R8X#dUBdY2k+>6t`#3L5VbI=`%t+*C0gHzgBGCQV z@=TMHURPqiOcb9{3uEz|EERWTs^C2qNm~fF1(#tUQ~@;n*w*9d@wrgu4~w8*NhqxT z)K<3vc;Xjjvjgd{rXHL4!k9m-*B2j5Qklbf4ChG<%c9DwzM-m{1@MfqrPw1?buH#3 z&ONBCl$l3|VMxrc?=zLV(N1$L6L^+rZc>RBC^fB9HXP>E1VhjoOh5K4}%f)q>h zR%l2zhl6O#P33HZ1ur4=AK>i|A<@ z>?Sy8O!QQb-~Hh%S)C5ksJpfNX`I{LEjO~KEnOI=YzH%xu4T@mFzY}Fu!}C+7WSnH z_y*pg5vJ`0C`l$4BsMc|misN}3ZsRX{C3A#fR%q1GYdhtLtBZMG{fEPU`dO_kbJID zWI=c)mnvkEArNS+>u1|H7PX@!xadHw(xqXmGl!>u7NDbdkReiJc9>c;518|}qdXme zy-ac}@a1R>aFM7SPEc`h3Q<%{d?eAlI=#-}q^qYr)QwJ%MXL{|NqTL-w!m#VtKZ0q zl=>oZ)sK*Hj1v47`wp|wGLrt5Y%s(}!=j2m<@j%5ow#>tEVaqo=~gECON4I%rvHtLf61pY;#*nn#ryMGbR}q3bj| z$ENO?yTc@W76s^B>DD4O;WLBRcKHPxrx&k2#Yx;_E}umVCP<sSpP#t@gVLf5O zXQ;>W{XzD*WeiW&wmnuY>M)v{(|TDEmP2^q+KBEQrvN(5Hoi&>YbZ)qR+gl)Ht0V+ zaW*_=3}y%hI%d2eNX+5x88(YE3Qx@a%cRDH$DTOLu%Xn-<}O2fN&iywXe~csBdQ^} zoJGB~wZnTH>8~7&(PCF z|5K$rf;)Yor_5n)^ah8s>!aRuGxXqQn+_7Ba9L`CdC8k?!8~Qae2#A65F!d>9kT+1 z_n{7HJHWeF$%TluQ}SBnw5F zni&`2Zy%h?F6zc>7^K0!Ia80u+Z*1A2y~N}0%ms3=20L2u1wJ~?B#H7 zX+VgLhhChS8WrCG0YDthHj-W7h7Des;D8Da5Y(vw`~U)BM~Nu{jvN5&u=4>voE+c8 z+Ub|Us~perx>UDU(qwRut0F3e`D~-isT83s4|P4kGXma4e^jYAvjDBW=I~#<_=)v7H%P(;|cl! 
zo0w995>ou_&5o1VbTO1rOsG!dUQY@w9!41Y%g;Q5KrhDdN(OS2CXUn zan(ez4Vx#|SvZ5qEB>U}hC%ZN<#&cd21YP5;c3TQGZfwsB!l#ufw~cCgVu8zHE(B8 zbQYzjd^&dC(bZUj1MWwiP}jSE-{wiW7d5lPj*}}hxg&E|`MCR3f zOF)E+xWna)w#tH`jBFCoakOqe4knT|5*Tnn+QtLN;*Qa%5_Z{|c|hw&xL=$L%k@vT zj26w+1Xv*(h8NZuYbh??JUv9sVSJ$?Is996w?Te4K@%DkMDze55P{eHh@S~<^uwXF zkaxR-PBTPnMmyN+pwuqEPGU5q<##UZZuKGl26V;TED#sc5~M){{nXNN$Bxt8j>O3} z!|b8B_7nNAD+S!W9Vj{}0x4meph&o}qR8E#8xw7pN*!B5$l(jznT3moqw0}KM_HIe zurU}C;uEHYI4cUEfaZxJ?DA$CJuW2@%(4@3w~1kV3!uo?FHn#8@a>sw9aZ_Yq5`wv z2%O3Y7bV2*X&%21v2(ty(XT`7Ff48=Jcw{6EeMv!tO@`SaLlsotT_aDvmOy$nn4XY zjw?E%_v;Wl6haqe7d?j9{d8R_#13zdDkW~O!1>2^7HfwPJ;MwL$dZ_Xy+CCy4oo9M zC=lND6LJC$;X~}e-g)%HBqP3#C_CcN~=3`1|tv%3@20q&4C4}W!|CPV^alFLJxovLku17#UxZLX48fkGlK}6L7j8`7bx5( z_5vUFz&-@j&BHfRXBcGU1L~YYg1)OEvIjJK#3FJN%%~I*Yw8ZP-uj4+>Po z2_W!~v1np%SQ~6h8_9ZQD-lWe0wl12lNV&%=uU;In z13OYF_qNPoB=iT%+1Z-V$Y1MqBwZn`VD;y=S7Q+dpWFJkKZ~SWtOM!p6-YNEqUw8M z8O1o#YETxaFXELv^{N;s#GrJ6baf5r#XJ9;iGzb7=+pH!NUe+C+H6Nf1PTcq66wSu zA~L6-8CtCxO-%|Pl`v~g0no&y{lDC`Xc32Hv#kOI2h1CVNb`p;L>qopqc*OBZe=d# zQW$(xU0uv>mmk0pf2pE&k#nUX*e>(x%nVm^MYBHY4HoIU0Hd6_+W8t$K6Y;30?;6O zjIX0T;_KovBEC+`_BCk5I$> zDsxe>}MayuH$x4ZojAY=%l!;Ynu>^4VDC%Zu=z_tZ<$pD)3B0C;PnAsy7N@BwQsH$9qC@jErNcih0jJdJ9%F3E&$k|gYJ8}_YJXRmm*Tdv zNr^GLx^MWFX7iN zsIjOkN!#Z|i(H$NW5b#7}Gjv{-BDarzNlNYs=pDetj0bW&ty!SG8 zCG6nNJt@AzqRYh0C{JM?@|N&4D5{o0Mocb^@DpOmD)*(D;SD9ib~e`@KT5S+ckN#4 zPdcyz)DkQa&%&T`e^L4D!Fp7mH7O{Khqj{Q+M9)AshPq>=s5A4+N%E5Tun0huvREk zh3AV&?Qm_2e)gWhQig8f?GqNw4#2!5M!9f*imb2e2qm3!2Q*%^dzs?39>kwxB0`a= zkJb5-D2Q8>WLLY_#R#O%*L>l?)_Jiw$2=pzjw0E5z6qXYQxa96SGBq=ZjxiJ*Z#?u z#QBxs(3 zX0HNHD&yhY5`4wqo?tzjpk+Sf2tl3R;DX78+-eOraa}Y$OE~WyVbOB|{%kRJ4LsZR zs;Z+K6cMLrqk0oLTG>h6Y~0+U8=BqH1$Xlao%1d66;6{%61#4*3R{54W)zt{S2;D` zJXD^x6~f2~iI9#32Z&4}9fyn1&VN=d%{F@*N{GZcQ{XF+EjR+Dju7nueV`6Nb$Bv{w+tWpSiRtn-P}CQMPyH< z7yZi@&td|+Qhd{Ugj#?S%U{HPMDbYugpLYE>V9QBXm)lZODGH;uOmR{tipgoa)FX> z9}z@Z2z75ZP<0$?y87@LwjL%hATBXa&L4iLm=Ym&pr(rV!V36dH>1KD?TBH47IJil z5kLz&gnQus-iulCd zEwL6Jg?I-nf&<&Z*$os9U>p1%miO2locf_EpnswpAw*69;45Ux+1z}6-ww+Z(}Q5j z-~#O!C_8SpFNm=OmjIkE1b(7^!Mx!<#B4S_{1k+g5kG_v_Q|<3ReE%g;D{71mal`S zu}5daoby1qwzY96L4Q-Y;c~&dmRFm=u=do*Cz3`%o?Fr*xOj|HX;5ms>*;cr@f?N z;m?H{g0#ofI>8_PPpjPrE8iI8A(0P!bUfbx|_19h_md zn}oA`!Dc>6i>cL4U@Nm-ZGr%$VFM+{jg>Xmh( z)-mTjD(D|@uM{I;V&t4h{~x2sgGXSKd}2Flq002Chl#`#?L`K3V3Wmi|864;anuE&@h}Lh#=@hdzF~s zA!e8J?SWr!_Z2$WY-61l>rK+^usg*93cQOkMwc4c9Dsp7TCY)SliS zB&-`-nPM5;a1nod@f|U0pp=9uyGQqI8>n_D41L|3ZGeEG`$9N71nHtM=;8X=eOIxw zO^?8l!Xi{$0CHjCXF0MnIG6W8MlAN16KT)hOU&~;Wchseb`C7Te0<9%niq9%6D}+& zMUOmp9s~s`z9MKm&lRuK zZ2$b(CVCw~jb}ImJfgjUd8OGxOtu5d{N$K833E@>Yi^sJ zbM~wP0v%=pPMd0j$Lxr6kW@!MlE4K=4AfCI8-Yfh`>Avl!ZVPv`izYegGimHxM3j+ z1Z{`Y#@d1mN)Ah8?|E#7fYLO}>=nFBYP`3;Vm{0hm}qe>As-Pc+)vi^Myiw`JP=Si zb?iDV=^oKs8&dV$)_Hq8kcVymE5D4x@CyHi2`L=HDy=PNV{6RMgW)(J8^C7iJoNGxxp@J+02x`*WsBQL;=q*_#BnH;0m zBNtnmC#yf{g=KAws^SYB0|xd*DFLx1s_pL}yR{iZQTUW&s{=HgD9A$})%Jm+CEyVt zV7vG_2;X!H1s}S?5E?7y@K?q_TgyI9zg-RyIBzv$_V0M{14t=SK|~`4`}l&R+VndO z`f<1moX=qYid!z$P##LDEyNU__KwqWIAel(K_61rjj(-@1hDWy2N0T_hMuT>#GGm$ zy@2AGcQ|D%@RW6W3(xFQ1}%^fNQyBZUlU>k7>mTsF%CUUzR^zX9iO<7P6($U{>Mq~ z&FhjAx6PB;90}Vzxv6W3#5Z-MiH==_!q@E^#ObYDkPNyd;&Jt<9@XxVz+_@IR&VsE zls;Lw0MdL*hiITKmi)8!4{5FUdHsiVC;01zBfv46-lDRI9e>;2VAsru0>C2zbhx!f zD!u%a&d2ot&YH7$^`U#=Fb5rdF@UkmIWNA^3vgZ@Meb22dKIdiI`(K_!QoOu zUZu{_IzutbIM6}wh>K1Ep__5H;7UgLE$SEHx>ZrejTY6A0bEKx8}capIh z(XPdI5|j~P#NRzD76}h>VFEy}?1+JO!rER=ob%<`nnF>;o-vaZGZ{K-&m>lnGsrNA zoO!}d93ZRZxyD+=p42|gbs{!;R?4LGhXA5B$zL~s{+-RWk$y*U_|3(Bh2Q^pdmg}&)p&Ki3>X56|}+V zv_BLSxW_OZnLCstBB=Fdfkhz_w{23=gT-EXCl&sdc{B=oeHbnkrF7C4-7W9-%K8uU 
z*!AdrA@e3JCqjuFAk2!us!srhaj5_-DYAO?Pc3J#QRm+l>;;_4%>9E9Vp8d2KoE(XZ|JSI z&~A_xQAS540*Uieyb3wtg3^R6J5p?Ch}Qse9F!P%*r;NeMDq~+Yu!=Mp+Hi}v!Mio zi!5^bB&4I-C9E}aI40h)I>5WcAy6bx@Vj#L2d&$Q9? zzkLIPzUiC)1fyP!`{rMM>i^|GKkfhJzwF9P-#H4K9;NOnwBI5^{;3zO-gs&9=TUYq zc^)a^8ScS&DA{Z{hM!|_r*~2S zP^I}=Z6i5JPZKS!lY98=G(+a-xxC=vB*S;D=tRDl0&!@)zHhGqymWd%#TKl>feTE@ z>U4|wcU5DOsA>(XN8PJ~y>RvEEKPZP6mqOU@oy$YyJcm0HhYjr>Wu8AX!Q+Wj8(kz z?thx5#Y-X&d|7(Jx2j$E(oUOoRxfCd!EG=mo>#7qmD8TPZbZuuIEkS$QNiDp4t1!Y zXhdsC9EiMQ{?lc@9i^!&Ea~m%NF_gMi3mN3#QG?yVZUg{ahz`5c_+hq$iRnyUK1=? zC*&!bwm9lVH~VcGt$xEOd}C2NO^82nI?0oEk}W?fQ-uPj(LGB}M}09j%r2cm9SD7Q z7~yh6cUOwwn$<& z<2Q7bfhcm9m#pslT_{r7;j7g*EdHAkJ&CvIcKRYAJpc$2Ctx_#121y{6~EYUGrrzz zQBnN3)#lgY#ASj;Aj)yL8m8T_gA=Ris*gX>De()Tyz72`-R4{PSNwDiV-S86RN%>U zkl+>|M-Xh_!MXr)%|nE?3Rh~cy-e9G=4cE5trPXUAg@QhRb1i`i*brYml?mY3TZNY z6`D&T4^WEe44bONj*V91ab^$Vhw1T9^ZR{D+$3dwL&Y+oafFM3vi)2uZ?(fRtZl7)S$XxiAn#Es+?a5Zm5$Da?NO(ylDjM1TX~H?o_)w~wb))M zR6g2|kUJ0N3h>tvb)tA;!YIgw;%qU#cwZ-Ms{^ChTAj{PXQ>fWJV}DDqC3wj(g(I8 zP)t!UFiurU?D2tCDt`C(OgE9fzBnS5-jI#PZC93LUM#;p+FHGV#k^7ZXmyJ8-rbTE zfpL3w6Fd68$x;@6MiAAv6){3_K@nv9%$T{WwGfOv<*jExw4;fYhEKPM+#`d|R~_Oy znr%L0?tJyMMoG929Y2UyPbWafQL{}Qey!ZZZnp=pK;mq+83to`zl=Mg|18qhju@48 zCKY|<0?moK_gv`)x{Zl|e=`LBkBAvFFCMfFPXS_nAGHUr8U&(=FB3<_8OS^d57;ju zoMJD8_LhAfT_({%r$R}jId%odl(AC9=r14ow(hpLJkZ8DcwHp_79*#p4e}V_NtDwO zxr!bjv!G>Vs1QFV*aoa%FN9dM1ej6!N`m7@V#Zddo<-3fRq_TeO(Gu1euvU6+l2I2 zdyz7D&kC($9wfCy25d!1dMx|5$N@0>B{_hQpBk zSVeX`y5^Wv`tBU(1G{dAz2^0uie`qb@*d?_+H1v(hlp(gVlSR>9h8dkTpT#>)O=BPZIA%j8p%iG7_#JJIL8J z=t#!+h9Mid+3+zrV8)NC=HaL%6X9v1N*{>|GeUM}M&v%Ro}8Bt5f@Bf8b$kC%g^Y} za=-tV|2aeo5=^F0lO88bV8KkJCjvfC`SyY1ifFrZ$@%($c{1trk1bkS3k$m-Y*r=& z$rlhUtO0~O%2qEZ%3#Q&M=}FXxC{xTi*8UU7Od)W2MgZvIaxt|MBai-_Bw!ccj%z_ zRRou%%MXA~N6C_`g;B16+Hf~mj%bJ@rzGS!+6wwx^4YWbAGzmWEA2H<2pR;hA~3bz zkt#~#tKUF=SjEsC0Re95A8{@-{PLOB+q0L2Fm zB1?&>>#RsQwdNfWDDbwKocCUH5j?QShzLZ|sdVd_KuaAR|ALX? 
zoC9yxM8NK%)i-E&hiH=;_WOyz)a01#gegEB8KR|X^gsNPiO}QaO)EewTm4xl7cx@m zXgva~Aa#(LCh5sYsXG1R$79ywBdJ3BU1=0kEesq9on^G8RC7p3UO8FzEu17kuD&n5 zP-i}C#a3dupIjxJ#nabyu$1g4hKs z=e@64T;&s;^*AjALp^d8_#S0|0mM}}R!>l`avPe%81-%Wg?i7D|1Y0-uEIfz2d)w> zJYDd0`ZJ~?eulV8u(G4VI#rvjb>bw--HB7LzlBQ`#B2Exr)O_gFyK1^8rC{BZj7ERZftF&(abHk4*2Vg-P8&T#zD@X zB>Ca@Rha_EUcri#9G1`nVB?Y(Zo`n|lWeZnB`6UQ%$>!}gUG`TrtJW#e!$`!#;b42 zY&9Nbi<1!DVXeZmq(DwtoW>#K+?mXdOQ?v7?!-0Yb#@+k4*?wcu1p2=fhPGyd@|Ny)MjUt5f9t!46|rA@J<+pjJ|LK8p=6Mdg+PUYqEz}S@6f* zaQVS??A{)d5@V7sKZB~IACY4L0R-&;vTKM0a_$ecgrkY(`(9eKOXQvEr| z_=+4>?l(qGhae#T%7e{=I6c#bK~jNMBjYOf+d-*5AbR~oPT>bes>XDp5$%B5$>!o} zPx#H&h>tq83mc*zuISM-=nuI{ZHz?nuH4nC$+7|8Z0EJ)qP zR#~)X@$!qk$YBZ`?LhuMc1a8KfuzI>W)$b-kPY>AO5z}y?`C45AmAr4ectoL(Ik*Q z$FX75V?0Vwli%(qLN&BZ8-V&qhCqUB4Kf>ya<3b5)Vwnv;U};ksXz^B&=i zWy^pRK~KxkECrrk_ODsb=ziwyvL0vHwun>_!Ub+WrvCQz0rCG8;thPjnZL4};(2K%P2icB0GJCz(#Cb_LiV%8B)=1-5^VKA zrXUI(iB>niAGD;c!JK`hR2g2Bv=ylDx1LTtV#-9_3Gf{Q_�ZIGPoYZj!GZK1=!8 zLo=dEb;ct<4wvEQRdN6F0rsIssT(&{qnb`_opqWYJ$S@yV8^UT50S;k1cA9RG(UnI zzH}!?s`Z2&yPN~4fz9_eG9?V)B~!>>qd0xZyt&ME%a4X2?w+DnpSII@k%rAS(+5$m z;b55t-?tM?j8GHj7K0;FBH$%mCsx8Mq5d3WtFZr;)}sEi{1qY%V(!OQ!$h(XDf zQfE#6WC2#?0*u3PB7siuWq`NcZ9f!#0mQLlVL%lfKuo$c49uU=qVEDK(q#%@)bvJ# zkc8eS{I>1k7%4pEN2@=p&!llfAr4kLGP&RMf7?&Rj4WPj- zcZM*QNymHMty{)gg^9sVHW)~kG@Kg@V^${QG89h%M8SfW_O(qC9oCp|{|E--HbuWh zbKYT#bUlZGd4fXVCG!0u2Hfkcf`?;bU^j?#!*0Uo{{ZQ3<2)+Ze&%M(alT z1=yrsagE4sIVk@-9>QPA{?Q?^k-%ZuZd3r_cmqYd#3_ODJ$=KQyz^I*q1F2+w!SDc z;6X*PmbdVM)hQtCzQQoc4hp@WMhezg2tu_@i$LgXspmJWkhX#*y;zTsg z`J*|fP3j{XUe$Ng`mU<)uIpxp%>QCOgo?8eY(qbG_WCanWu~xbv0yi5d8Oh-#n~18 zAlAL3qYivIa|@>_BeFSJ%L9ao7d{T_zS2+(xbM-b9<5Sz=4G)50F zG}+s>h#g&9D?v*rr_6iYM$Gs>5f|5YX;}Q$q74BzK&PDDMCm^F(?<3UArs(rxi~ta z`aKB!0A{%{i1Z(f1nTmdmwUM4Y&t3~P*F-`_&tBD+tMCPS8qr?V)2zZ>R`L*<@fY* z8zpr0=dzG$fW}AaU{0Su=xOM!QdNLpRCmUR4T1vTLVp5A61xds-jkIs=!K_#-ZhOL z9HMP#jnq0S=Solqh)ay`7}eevg8g+?e@0aHWUf8%5wUM1(f`7J#vVLPsLQII$fQf5DuB#(LLyKJ$CO8rgPtLp88?znzePa-=W%o zDF$#*x#p12L$r0~Mp6wmNW&+G1H(G(aPWCOe03Nb_A zT=r2N^EezN-0{+M(I5%R*W%++&k9Wef}f61*J?~J7Nh7MF@ohXq}1vu98QDqV^jid zS|nZs4QIrGbv!migJawe(_oMYVu7*E2wig(YV|(R31?#GiB9Aswda5D%}fgHgJcI_ zFc^?mE@+Wm3Tvbix&+11g_(^oc?tlsST-W83E0Lv-P9xYq3**Vm<-=z%#|ImI9z!d zK?RCgVsP|}0r*-yV*|MNXJtQX$(nrXQty#>mF4iwd&*)>u^fOTN`ci75wMF!_`Z1$ zJxttPAmnL64*#kOam3GB(jT)XOcDkQ{qYk@MGR4G0Edx0A|q=uVXbS+1v&B92_iNr zOc|sFWzI4sr{ne!vd03o{CMmtT7J;~%YW*@sX!J;ffWpur7=7SoEMhzFhwe%1{#&_ z&N0mNo+e6L;}3N|lnjV=nJY?C3neRj;aFP1V5|Ao)8Z6i0SPm+OqHjL3KI%*j|kpo z8$AM~i~QU=&cN%EC}ihyMRPJFGiMPYPMxklG2|Uj$pbya2tga%q1Y7sq5zlid9kJ( zJtnSx5P=9{hj9SK$B7ztJh{Fch(QWSw~rN9;dhhvON@OE7>5_H*}N1{GhIOuK@a%8 zS0Az=MBSe%=?!_Jpz)(v+tCxSCdj+*wA)Vr&lMW_emg@G}bpT z9>8$1150w1Zn2-q?z0@_9S&?i7;P}1PYKT;+*Ui<0m*zzv*3s_%Tk+@O+x)-JJUSs z`4-_%ICus16$7#+M;L=nnB2$#+v{i_ul?ogKJ-zUXdP%FhL!D*{{UkoS5i$Es7iDa z?lO%Z^P0gW8oysoiAUxTM(o%)`q9q@@1!bhnwf#@Js+k;z*G7k09XMqkS;w1AVf(; z*B0~S7gL&d@Q-?Nr6m$e6Gswc#`#VBa9=Rrl2FlgF1{ucxE&y$b(?) 
zECtsi6G8JmwY>ZLUvNOTRWZWX;URD4N;vYpOG`k1*4$6qMhZU{_uR}$^o5I>%VgZv zg8O*h2xlYPz+Cz4-wskhXcX`@?}0s11AKu<(lWZrmiw@D1UNO!e#QNu(z|xX(V?XI zb^vRM)rm*$Nhd0zoJb5pCrE=0AR*pp$xh$&PVq`T+`+`Gu=0H8nWOH$kjJEG;mFux z!!tNzEYyV0;SBHDAMK-@Q6UDKQpLpLpq_E0R3e#)Rzz+d^gFB6V(dO~Kk9MR;xA{R zy@Hb&T>6qL$exw*9im7~Es8CmmV(nj@0X^#@q8GqC?Njyn*2)hXO85rEA_ZL7Cr|h(4ZJTd0|Mz!~XvMMRdnx@Yrq8 zZLIQU2tOW@$ncS?#qdqJ=AkDa?!aa#I1-NgV6Y`wd6L&4?=O_cZUN&SlEHe{Ul)d6 z_Yz*}MH2O(Zlk=SldvE-1+6Br z(>82kV^+oZ@qr;6S0E6EnN6`j0MZzy2J^P9Tzn;Jz%<8d1ZZ4q|8ATc1l_#*J2ak` zoLhjZ_}FXl=r_6o0VA7LnUVEiny{-e0a6Hh^QGJiI^2qipvGBXqFQEUMggJ_!wVsW zVHO5}Bt7^F>2AB&t}W-wakX;y$$VuHnH7m*-et55BKOvAT zU)m{cP0jr^8*s}36ID9{gdRnTlg2mb>tf>k_Kuy(7;AHqeK%3y7%6*D7`3-j|ECi4 zj<#UA;4Cgn$5zE}T+YM9MVBUAO~h|4of!pOPQV}9`zH=)9HstsC)||q`?%~EimErm*w$`}- zT5&WUx$4|$KA0J{4=7YvlyKyjLn!B_A~Jbk?j|JD>j3aBe}D@Uw}ICV?z+ViFGptt z5O6)q7$>E=CMTs19Rm1Z6$M4kIb-~c!C7>7DuP;$%9LwdmiIp6*Lp!4n@sSn!+ z^vFFnFV;E9wZS-GGJ4gm)tcvy+TA04o#P2xqc1O;m+v( zu_l8LG5JDK69;cR(@7SR!_+hOY^_M<>b9rx}Vj@B?1Oi1X?1l}KU^NrcsH1q|kX z|B57Vz6E}OT>2-!6vRMDrbr~gZ!+mP_m3xuBIO*t6X5ZLEFkh#!F=&LdDaC<#H1b| zl4u}%IvVN0G>QZI_?NI5kjd)ssO8=jT;zDa(n zvtC@C$T2N4KPGw-r>n8VT^m;ucq zB84&Dldqh^ZJRozS#aPd0O~{e+veI4ve?1ujwoW;h>VZlkS*O`|A8ud$6Ot8$Uq>D zRCJZj%n{ehRJhME1lgDij(6?hJ=BFXaI=bRmE4x2ybb6{GlG^GZC$q_VO@)UXYR<^EsRiPTR1r-6BW^m=eO0lBK$2sF=^K;vLOn+j+aOhS zrW5TDA-+Yzzg$WUW(tj*PLtp&lj*}SD52EDD0{3|hSU%*tGJBv0$+gA8fpZZNILce z3+7M<1X)6^veLCCWCCU5_}6On(6vgx*&+HXRV_ctlC|1=Z4R7VK{_Z>W>K z@P78W3{~L^_rIL{d(Awwx29Ds&xUpg4&CpZW<5x^+5 z!Ccpw9(JvcY%`8K{U*hu21uMotm_3}^DcqUJ$D(rzb}g^XNJHNbkK@vL+2yW6icjk?7bKv0kG8d$NFinZuBSFPS&e?Uam-2__1dGl#Vnngx4RaS~12 zo?N>dq&`k7G#!qf&^b0v!M396W`nG*nvL>rEVIsvPNmRRB$JlCyqKyU!G4^P&qbnYRAFS*%=nw1(vK|vc_BA`>L zNMu4%NxY8F2-{4-shIAsNQj8irkY6~+w2gM0&I1bpx!848IYyfK{T>Dq`T;mD4&CK ziOUXF{CWoG3=?Ih**>d7I169>at0v}CFa-O&=d|x>;*Whh>vFDRT+Mj(SgX#A&G}& zfMY9?Ka;am8FL5|K+_{qTysdY^zx?}in8YLK=F=#4i&fxV6@h|uQ9r^YGT*nlwKkG z`SiJn6r4$8Q!AMw6vG5{G>I?+LKi(t;lejGp=-vDWr4Z}8jHsREq^TctROCyTl!Io z=y-QTJU?N)7`O8C0|XLyL~Z~eYPJF@BQ+m6L<^4uM{FtMQ26G1UMQof^O*Q$W=vAd z@iYTNk!G9Z0$tAZ%eaYkgNoZ*Yrdrf&G)3Y&Rn{PJ%_|`XAurVV&q7IyXTJTv(CU1 zOo5v&qC(}$#f@R1ou((51v&k)6(&L9G4zgaNcyNHktCK4;+whC)knAmyM(^tTwy-` zn`RWpsbN5?j0BEDogfH|gnHd)>t>iBjT%oQy;4RdZeLVcBC*I|M^;oHq0e>W9nT!s zQ>u^TLik#k8h}NB*{j2iN=>R5w zS3y$f#Gl`4pk`W35^PA>Gg)oMw~BIOQ=&)FB-o}ez~RjiCu=!vTHkq>;d4JKh1WlA z$$O&O;Rd=-FNkf$G(b(y-WPK1Km%R#K7~29wT;pa25!WWRu`B?7J_g;f(FBVsZCO+ zxIpkQW@d{5ik{p9O>TdpACaZv;5k$06DB+`%qJQRT_)?}NoVI2q@<=oLhUINnxIkPn-ct|0I z2DqJz?Yjm8tS`yd=Q1ayCxCNO*zY2dDu~c9t>r->r4zX3Lm(dIu;XaPyGDeLq{AEy zm{ckeS4APEU2-m|UuP-zA0cpY#KfR$06g3cNy_$PL+l>OJDBj8L^b5^;IJHAhsb57 z7wq;@d1rWN38-dz;WO-Jax;^Pg7szdFeVUaw`A7q;_lnn-NA*z$_l0~0vFSZ-{Dw+ z*|EEti@RZ21wx@oo0W2$8%csFx!BS@M&^*Vb8ehO1Wegl8)Is*rl#N!JA?M zWd6ELDRxYAT!Bh}q6k}Up{kE%)Uk26p{FLhGB*@WtC$1~^arJjD{OFgYsR>f9Cd+h z&kezj*`&l-I9O`VT#Bre*q-hRbm0Q13wDvxDUh_l#eyminNO!cu6K-;v|kh~Th9W# z24o&3uje*=%C=;I;eyVlC!WYPnU1&|(c%QzW?B%fQ07T?&mlzQJ;Wf_26UhcT;c0z zGGx6TDm`*6CFP3DCc4xGcsm#YykcdN4V+s`Bjsy@8cCgYoIq)y&Z+JKQ9z3>HHGI; zA9-yevNm&f992}h2!)_-lm{k7t#T%kBy<8cQ&RHlVnsJyxGV{-1DJt&xh;o@akzL$ z8O%0Z>aY}oF2wjx5(&tA?sIA03S~)gdnf>uJ*5pzf+xG??Zbk@g(Qf50O5QIZdmfW znOrYf(<2QQZj;%J*-Ol6HPReqj+YzbS%%R)`OM3 zZ;#ow_>p2V01_jK=9$8)33MiD4*ESu;Q&3PRiy1J4ngCn1p1r`TXd1Wob+)k!RDj6 zbugJ%Drk(Q`1ben^dMOuMDj~Pv*e~vv74IVDOhL5HERy@_Y;p>Eo$FjCgD!T?{-6S zIoJ&L9EF%!nt_O5C%^a01bS-Q1LZ^pM5%5)hPag zXQ31CR9FH8i$-%sc;L@(9Bxn2x zAAoEn@^nT-A_>@BAPWUbK94!11&9M1e|VB&*K=7m9NyI*-V_`V;eK^JajH@J+9E;Z zoTC4T^6P?=9PvdC_{E_Uwd=Gj>oC5^UCw0PZS3f8ki4(N_VW~dz^MnZyWN%=gs8d; 
z08@vPrpM9;;)F2mpOOUv^j{E{4T(kRo5LN|zTaR)o>oU1W%q9Juyk(p4MTjV~PNjKMJ-CciS z=jOCx-FO8Ajrbu~jm%7MZ4OD52Zn$Wk|l~XdUqw&irXQu8jyzU%bcsyaJ=FhLyw^Z z(ee=M|0Al99Z+;rAR$g+^IX?st*+cROKCH^?E+63N1TvC8xnY!;Bc}g>QqyA;4hKc3x2iw-Mg6*0F9iv3I{|G> zJHb&uN+gtuVSp63uX9JK0()4A^np@EXI_DGboM~0%5Q63hdrg@DU@L7D-U4mRLRs) zs`toRhrp!X#(1FA1!M|q#DL|z5@*LrXM~l|Y-1z9R#6~drUouKgyp4}hAA1G9HDlZ z-SeEg#ELI-5hw&enh+AQ!zB+2u7(F2Seph4-oHBstt)huw4+$k`HNz4=0{yCisgz~ zk?N+~j=XYrj_x?Ulj~ke;UhitrA(KaD;c27;&;0(7pG$E!vh~IW{y5N*ycrhMGE~= zrZ_D9=S0h$nNq7Fmr_eQ;T2a*8XN$gVG?%habe7#^%Jr_SGs2aNf$XcOhoGXO1Ca5 zSuMHw(O_TCxha~%f+1HMzBW)w2%N@UilC6(CdO(u+W>*cO|}=t9Law(wC@U&j?OjP zZj}%QRyV+kxxySl5E+mJlV_roPv7_tzbj-dY6;?uDdYObYZ57fc2 zZkZ1bA&N3sDlUoC?*683o4kff`>j)KQ%EUpO+%ZUU zJ}LD5ypnd!fngoE1B{|yXhDG^+7yWL5bZjT09m+9{yB#oNsrN zbYx^5g^>v+h08w8NhYS0S5Kk2m?mu)`*g=l({u7CnCLmW5=P>#F5f1Hb+h4mX}S?< zgzo^dG9}D|YmodCv89A96kPtqC3W3I8WrK-x=35*2^2dnp`zx;M@j7#T*?T206jo0 zAP##sDY{&;NH$_6HB==mEg9?z*7aXM{MEO7hzjDr1x|TdM!zeA=ecyXpxs^EO{=@B zTdyxHNFesQbO!q^`QS0!^eMLgYu7RdF&0P@l8T1EbMu8~fa|{?cdiUjxPerS1!drO zx;TL=6S$bmIAjQ%CpB2I7g_!9u(RHLPkNV42aZeMu}> zbW)+Z#@_vLQsL$f&k8ih@TA1K9KC!%?q9;l9rgBZ@`$s*&dos$d+sSrWcd_}JL-rSdoh81H8LALvI@c`tWtA|Q_8Oaa# zl>p~H8yb>T&H9R~S#umQY*e|@7?d=HVTl2W+zPnt@IX-}3XE`8Wq45g@-(5(x$jBk z@%2DU1-BR?{Za238gHXs8Sd+%Cf5NCEj*b=U`Pmk?lANf_Nr(fI8kNRPxetZxy@K* z=!xn4;qBr=b{Mh* zf`J=Wc&G&&==$y$XB+s%MNvEqnHw6&vm0L&G-`+3#8AT0*#3CNQN`dNLT>&zuYX?$ ze>4^8!hdf_4IbS%pv7Y|)f{^zo^sp-(ONL)urAa!;$9Y94W(B^BPjbOm{vr#iZQ5$ zO9Cl-0)M)2dD6{tG*|i=*cjeIL=rj3VHZRoQ|y)E^9cI`4&k^g-gzVjY7;uF4)jVk z=uSgsl_!|K4qD?%MN&orm8fPh3VaC1nP%oJe<%a^s)9h_}P5Xx3%5 zCIgWI{an~=I?HQrd+=(Xvu5azm(iLuhY@0eu!+zYg_1gRgxllr*INfCgA2(6o8x}u zmTGyOVcj7|m=KLeSs_1kRbhL>+W~s>ib$OR2F{^LU_39o6~V zszT`8n`KUd*s`FvTvdR+W`AE+zXuH{q2w6dY0XLIT;hRhl;YWglHJw@o0Kedh$Ee6 zGBF214q=Rh%b~(C3WaR}zco9I6=p1K8;FH|YJtG}NvJ~syj+LEax7U{h8CR}9*5&v z5xJXiR{S&D%qNk$+puN^zmL`pmpyd$BlizS!vA2M>`Zu8nKQzBc#sm7X#S8nyoVdp zL*$o(V96pDaqm`_xPu-^q$?CnV^QeAZPn#0t6nNG1O|H3WecIcqu`UCwBfKgQ4@0d zYaR?o$0}S0ld9Q47b71obWilpSREi-1Ci$Xj#A%6597WsMuzQ6A{E-9SE6JI8)FBDM{+zlwrOd9D_(O!_ib()TXe zpZA@UOOcrSOhinMBPGvQJlqCJ@6q$QQTozlpW;?6au@>4?&)?L83Z14QW{ZQJ!`m- zONTh|=-gf?{lmKQt~In*^fdfSLHn|7db{&xgJDt)MiDP!+=SeQZjKhg3zvq!;__VK zTRv_~LG5f$=)cC@Oh|@{@euf2a3_n2lG+^Jn&GXSLWwHGkRq=y8RKb}Mrfgsw1pkhx7Vw+;gX~s!M?FW(8*b7@1=sdtBxh3VA6jv|!`J+*G<#gUD_e=Y zwF+3l7^RH|>7|yfQqCB$QqBwlkKzQZr|>{Kk8PUI#u;u7{%RJ*;RAkWD2LpwY~T^N5qmdAo(q% z_@H#$S@w5;I^d-!q^MrBFlC;c&t!J|qUu}RnQJq>->YIjdmwEdrECb=yBp1f70f;r zV_?C^9Q7I4^bV}%3!4X`bOaK5a|n!7MV5Ct#yvnvK6kaXq|;H*gmF)cDD${0-=@|2 zS%KJ});M5&GdmMGx7s*w;{Fp9u#wN@?e-LD@$1d@F7F>|8I!`|vKJT@22e7@b(}hG z;6Y!xLz%auKPdpheQh`n#<`cgY5iO6jF}ugb#EJg=HY1N2kxFdUvF;-F3$a$d-y(x z2bR6R-e9VmeH+=NU!1QMN>&9*>gDx&l8)=K zg)(k`u~-U!d9O(GVK2;iAQ)aK$FyEs9=hEyOb?S{fda8vueUc;vp!o&f3ZPUBAsp$ zoQCnt_mw6=C4Em_`&@y77sel+CmLq|ES5HSBTAE6v;ngeN*qFy3Y%qHe#Q*Ir>P!mcRx~7yqMOs_bzqC6X?nhq>ugVZ{h zUHbm8MN%*TG;_d0ioHT{LtB5AAOX1K_J%tf-A>z^PEBN~cYOFQ*)bg&)BZ zf3Q2KG^-@a4*yDB4pq11AJZ>+U}iY+6IHOQhwelYl$^B9p|`o?BNPA-5PMD*7Y9gm zQmL+llq65NV=d3PnLCmJ+Z`G8_rXAAoPIF^7{$p``n2TQoJuoqH$+d5so_zhB<$_w z0)4$p;`bj$(s_HA$Zi`U>9ppj8*x5O)}$+0JpaW!56COY7Y>EEu>hFpkp&NVxOqQ5 z%k8cGI9llitJshW4^{2Hf<5sP#@eQ&% zjkKP7yrLOW?9UiulnavelDtfZdA}x zI~z50r>i%bXhNP6dO~_M-Me4L0_~3XIs;soI=M-0hA+lyKJHy_Fc6w$BxY4onICoV zFB(-KA>=YLw^(&lYkop?>YSK}paa&i-3g5=>8$uz3Y+O_<3mofj!FYZ0VWG>BH!tZ=E>hqreoCJ3cIbTrs)(N}eRod{b<8 zJm6A(+P_ID-~#o|0`d<1_Xpl#q3FvuRj?^ev2eI&-;j7A@?v$5)n=A9VRe_|ENvF5 zc1ggdVAr=3oFN_l_K$bwi@o{c{_S3t_wo+(uloKK%XrzLO}ysWTC$i^wBvLs04XF# zE2I&+U#13{gzb3TvzJ3r`&Xk-=HMA*PB1(9plU(HgxwkXFNd(dF|`e4LH#?DK6C$? 
zr&7Pp8u&jfGmedT^rWu^5js(GV)g35`M2tM3>21GDEVb;9z&I`dEBAB&(YH!B?E>w zje~;6ZN5V%jDgfQ8%j1i`*A6FufC_q;{ zZg<%RZ9&l3AOY!iNIpOIH(vJDg4pr0DT1bIUJF7Se`Wx|K50Mn#p6bwQXtNU6Zu_X z68)2v6;O^Lr)PxUZ=YT9xGh!c0n_`pLzKZk6{RgEbjQ&B?K3qHDNTnfl4(2|fK?#O zU>K#pF1}lhs(c|Umw^Zien(-Y`_Y2DTSQ~5S@r=uIgXgV93;hgiGOWrE4aPzpJh zTS0Jku6mL2TSYd&`1uqDB$?Y7#Rlv>R;?E)tH?a5jTe0tBc`sf_xv5xyiul$6HSVs zFR;zSgI4q>uRH=tyfP*vsgu*45b%e0D$+K*r+4i6#iT=d(AwJ1sDxUC#?86zEm?6C zZ_T(yQ(2Wx8`8GSio0$@+8*k4-a)%g#Fy@=R_K6%hFrj*9*GPMz4e`ZosW`Qf#K?> z;eLneyFk{gclAlv!%N-k;e5OC&5K|0;3~%J6#u7%2>XyafTkGsn~CzVJ3wV?d|XPtkiej3UKH&^JJ+;EVEp( zZr0V6qoY-0KS!TpDN~ggL^)viOS8aZWMk%MaDqJ1Aetj_| zu9chqN^*Y3qihA|7t^m*cHVH`Ol6;FmNvbWeXwpyYtoYZdm_8&^h^)R%h&b8aQWTZ zVVY#o2%kq)B0EjSn3T?0wnZqMKZh6UI^FF%S{M&=bwlgfHF%LisTv%*aXtNFNc^`w z#Ujuk977yWUOp{Oyy!|Wj{UaFj$=P-3FySfpi59Aj78~3 zV6=}bkFnIwKr!+rYIRGT=`NwPEKw&+A+T!gDwD}xbQZ~m0TW05oPECn}g&;QH0~5C<6Tk8~zGQr?a6g3ABM84rM@p2w%VdL4)3i zr!?LPI%a&hHqABv63s4ZP`UE^I?d-Y1fIemSIwK@-}Hezu8`n~1fXY=sTxrB&bgSO zrLs_VjKjqm8V?5-tp+IOR1cYj!8*_b1AN%W*Bn66S*bq^d?pb3sCtIO8y3dgoengU z86P2@=-yH#ReViuT|DFGvUrz%};GPF~!v31RFR6TwPv{dhD1k}%!&p5iB9Rijd zQ*tYQp{mPXYJ&qa`qDi!J}&MB`v>EZ)N{t^@d~u-aa`SDn zVad*;$}8^3CXGvBaID6BIT%2w;2pKF3(1kyNkKq(>|H|3dB)RY`A$>F`tFmckkP{* zhKG6au=FygoeHbY0pw)+_m&NN0F;1>DBXy&nDBL*0W;r|_ZKN7da&o2gD8R{f?KUo zuP4ypCPD&}bDv~`8-bL&eaJF@mIVZd#dA|elcu89pVLy6J>@=g+91ZI5+l> zH&cA^VU_$9|6@GuU1LpMzOS7oIwWh;xf-G@A`Ymre}{~ z^gxTR_TuBY239nDr`CPmYIb>`JN($YUrMeKPjlX?b%x9SV1cwBBpBpU(*O4MmVsa; zcW-o6ykLn@f_!r*U!{*6!I?99ZLGM>}$~M+IXp zEkYtu-(1p0pLLYRdz1Azf1V|m2f3w*XCK#Pu=f)SlXeo`*3YIzcVNUdB?hOHknM{$ zd&?1H9m)u^-8X|TdleF`@Jx`!)!yawYGmGBf=azhclou+JhSoUvy&;DbRk2BT;w#H z-01bmU3yAWN47GYEH33_I7ue!SN&9ZlV^(*nfGoU(nXNv^#rBcpjf1~?ulB|)Lag1 z?MyY0%s{qL8MCBpf__nH}q9_lKx@_ofe7@J8#q}O)mvp+#Y>Yf2lC=n+IM0dgPkW(;h!@&OdoN;-=6ou<&)=;|)rb+T z?F#wTP3FE${MC0I@QoY>Csnp?m^`WwGE?W2^-Av7ZoDK|y$L?-c=0T5u@FOKJ*zAHK?DgAwLi-a^=2w*VMd?L_y|wu!(ekx{hOl zXGSN_)_mChP1_B4am!&Uf#FKLrGr4lX3Vhh+Wr(;l_vx4M~S2LiWO_M9@5eSk7 zA?rG7^)cdPInyh8yE|LM2dHZUJxp-Fmaml$Mj#vKuW=lRlO>#xvJuZ@^*?hKi%xOA zIJhf(JXgegg3an3qA(z06fey!`SZ2&cNHuG=1{)P`A@6GPcb>Il_)~z1K6B%-cjgi z_S~VARCUpPg^(C(OecvXT3N&VngX!nB3AHs$$2oY9-1?@uOOD@=!QUrY4U6U>DhLD z{E*@=!&|*eemYW|9gi1vScRRUs)@A$vojI+=cj9#MYnhNSJ^Cf-j(Z)Ywuu{LyOkN z9t=YKcxe#WMal^B6gqPbt#s;gowkONdT7w1$lf|z*Z-=7M$ zsign9vTI}mmVPQ2a#8|hqYxHWe*yp*6c(&EYdgc$9%9a4N`J`C7%%A*&%zCYB5xAl zUH>4j^hD_TQ-9Z6M7GG$`ei#=tzEHwa2c~>_S=rBFck>|T;brV6`c)H>z7)*qZmC~ zh2v41i8J>CYXus@ps)SGeHYvT5>ZScp1|pT_;#-BSWFt+PxdvBa9Vq1XsYj_E-QGN zndvVRbz@0yF3uG}&Dsv%@(=qX>vzRGiCR?aH)QYQ!HTwiW3UJo;Z@MA_Q!6=Wc`9Y zE7O6#{rPG`KLxnfT7Jt~wr>`%E!Daa>d3lI|G{Nvf=rHsGJMz{6gE$+{edQ|4J9sV zGVF%4v2FI({#sqw8)n0onA#>Won(_)ZI+gI{{1%7fb0S8$}k?P8d72 zG%rX9t+GRS9P+=rZKK>|aM&TY&R5weCpl`O5K6RCrXfyl>@$g-tH6()5caQ=xwesL z>PI7Km*dGH>JVw?zua*ZXJ$E=Kr7-^O;==MU3n#XdZ6cuMQal{gIcZQ(Ved+v-9=H z{#lDlSjM7VPv-3k(Wy2nmEFM1FJbjRUCpX8fvAdtX%7@nK^C#1!hu%-=-~a!1M!a8 z;cZh|2``q}=fYjlw~2@%*hsh|^;&jI1r{f3|Hab~%!M_^p!BW|qnAr|SZ|H_syx5) zdi*q5oNoqiY|EK_+u>}Jxf?Fe-$uJ7Mu318VLcQ&!q)be^1%dtmCO)U5gQX+v=lOz zDEHWy^_9d(-IhR+N&=%@(V`epp;Re;%iG4q1w%Rxn3pFjO0r3$iZ5;t*F`F35LR*J zm2-blmJr<6;>Fuj#P=GWG%+nG)n?DqZj1VOxrkFIQ^M>K6%IC2k2|l~T^{w(@ht(e z#loytb|jSZhZ~mm*A>TSfy|hbV3^&JE+sX6x3@feFR|J}Dk`7B{X!1of_!iQtP1%c z5u#7w+Lf>((d()?^)m)KU`{G+QJPio`}@g%|9le2iM*L0e6%Xl@|KcfO)4ArCgp$q zF3Q`Z)kHO1^{Bt!O#XcFd!myJ_+HJ{pX)Pz`-B-n+N2Kn;cBW9vG4tx%7N9A3U^w7 zmk<7ZcKN!Wo(mrmn>L?_%hGasd2jwdKkL)9k7X#_yIQlN=*5;&-4nY$|MefV{`=>O zre82Pj^%)w2k30aiQc!rsAM2i_V%y;st*s%>li;O`W%=Vl0^j-j_q+y8Pc6058O~hvws9CsC6@lep 
zn8j`q7b3c)^ne7c2b?UudV90-FwE>YIRyUmbXB?20D>5Uhl`m={a7{$Ll$Z*FLWKc zzU%aAs#5+C$G+2n!`^q3PT&|@v(kxVW=W(ZP;bEoAH7^Ayx@e9D}Aue748a6UlN?h zb0LuQPhv_(o5uiY!c7s0`S0(S|2%n`Ok1ch$xR7`BsBi^=Fwud$7I-(s)8XXg-I|~ zx;raOP<~*eLSw%fruRb16yCGFbzj(h`KaO>mqtKRJ%_kAy!-fI_T^$XVsB)L?e>BX zcJlwplzKRi`g}O~^)JVB6<7unw?FyqlT&G7&0MO9ExO|*NL}qZu}hT|qL*!ixFg~}NJY|tO;o5yQ0=xc?E8JAlf=JCggK{dFnx_a+$z0Ih%3+!?g=J+8 zwV8Cy0Nf>5VBuB28DAK|b?<-pOx_HXs8+Z_0z)I>9(Q0>9@@b j+&_Q1%3|N{{<=}QnsH?2bIDiOH<(A*x7cUc`@Nt2{i|A(Hv=S~Rp->H^ZnJ|ygYqvj9KRY zKm6f;zj^u7zZ=u|AM=lI|K(r*vtj+1FP8B zYpr;3CCj_{z`gGUVyE*RJG;AOW5CGy9tM}^eb?U!6tDC55h%aD>)3b=-bA--@gB6n z+45-R`MhVna6QyC`#kX0gX91c46~}I6N@*yOzqfeRT{-l!n^)3se0k?jVC($ zT$rwWibzdUtPAFfcYnV~qu;Z|gv)pV$H5Es!u@x{U)=l= z&O3w!RC@Htm`6!>2S zI<`8KQj!n((8V1a^^({n#dQb;P5^?&15XyDiWl=o7(mab)k;ver$x-`Fo8=S z;|>1M8-~Ba;fW68!;Z1eR4b2dgtIxqRZud5ptf4+7}JU4t~Gy2%X1G3o|IJW$HqJc zWK5T51+w8*it|?&g7%%DM24WQT6UORbY1z0z`NzS$I9=fQJ%a9!fT*%=`UBZHl+FX zF4pd*k?_ego{?&G*O* z&4;G^lvH#qKb+#nqfs~Ba^=h5udu;g5ckOw$I5v>Dd%0tL-Ux3-*Wbzi_+|Uz`tWm zJ23Dwz7Q+N_2%2epipt}An_=qx}(Qk*GY|YeOTfMqM(6tDQ8Fiv;t#C!`yw9McuSm z;vKRmNC*ITLM0BYFsH=c#neS3XV+b{Wow=m*89q%m!J8t!4~k<6JJD-_~I||mCUY5 zy3Bc&W1B|g-F1=;K+vXv$fxs%){ECv0vH7{XAg~em{ce5I7mJ~F-p_uULroMPv4+{ z0?*}e_)9Lu_&k(UD`Eim_6%SH7+A#NlpI!{^q#m}m_&;Y&Z;R)o2MqaXR7w3>uNgK zsEJ253C<^}>12^}(Y7H`8KYc<;vS<_Av@E}qOY>*jX{mFy*Xam`igmf_>SIDwf+B}Q6Cpd}=MvNrLC$>R zqOV->V!uO@~Yh@)Qnkp_) z-~UsBSZh#6(E=8d+es?8lvMn9fvs3}aU+Xb?tSVt4^{|TUZU=7(>aiS>ilvf&!LI% zX=ZlqV;X;S@t3yP27&U{AlC$2koyUGlsU$Ky@JYfIWBQTf#C4XUf#$1A6j{T0}#jH zLHbK${k=}$IQG%UN8GX|N6@lng4DDxA2w1!q&#rmlV?jwGUnjG>zTEw~ zE|fI3)~-`QIRlT&kOcMXph&nUweoITCeds3b)turphiBez z@+}ivW7kd9qdnN8eK|1ZV7Gkbe{xSM`qG)NT|Y=)HxcXJna0q5Qa3RC1PAch0~P45 z{KR{OKTN8#E3!4n;q`b&=J1ZMX_AG|FNYp4a2jnnvz{G72gJ)9zpf6pw}&Iw&(u-i zg&?KvM*Tdx@5~q3F$e(3#^@6;lrM>N=5$aids(%5up1A$(Pn1W@n+n}FK_v2w6ujSFE6XVh$ z=u@x*ElaJvH+$WxrxT1Ga7OY0olqKdVvF^OaV--yO!0y=-vCo!Tr+?>V^V*QdZW`i zgA&(!oYF_11ycUGMD2X}DN((u(;HRD2E~$3hUFbZVLIx-r^?sZ(&Cp-53sEijGCn2 zte8At9GvzqTYH}&Ns3#h{7kB+FX@AzD{-;C!>O*Jp@!y}r~?Bn8#mOrGL^!2kAPZM zbtSx%B$DFisUd}VqxJ$I^u;h^)Icz@Hgz7Cuf}FycW#uYEfY7)aU+ePejn=OkCiPt zVl=y=90LD4S^08k4kbGVkm0=s%kw?w2Btj4>Lc5S!}Ce%=#W=Jk!U&7$ll&JWozKd z^IrO1qN<#V6?cp}T6nEjQuZF%KRf7#2`wvbA^#`S?)yDc`snZC&*jsxc{|}}U}7F< zWmwiUavFU8r@KVyTl;xvyu0PObP7G(YQI~pPVryR4q)!j{i;1RRqsW=IvaKEN4!=m z6J|s_YvmaPi10kHzFAOw| zdC7mffpT>u+=#I}`KPgZ!)QXK+UHSwX<_QC9wheSAD+{mpDx9UVNhAYOr z+OqW=y>j1J5BuA)|2f{osQz|!w3&s0KiszEoehIvsJ`e3BbjY%IH%o!tua_T$QeTc zVgd|*U5Ie)9kiK;m#i)}K?#`ZSh3Ov{z@>dg&ohw2VU9|K* zSR&m@s*`Z;w@Y?*^L`So+sLvp!kgj!Hn)75M4IvrO+>DQ*Pe?Eh2uVn!An?O$onnE z(xKm~*y8&@Bcx6d7_#rGGiE=TJ@mM<=P2;4f-wHuS;yMVIA|(bOSJHK0fc}+7_gzC z!MvKzrOX?%&tc=p>$M7#MRhc!sqw(-%ySgFckDTKxj)aDpEtEm2>}bsN-as)rDS!a zx(yp?i^2`4fS{LY(OF8+qlZPxV$$%{t3zn4)L)q&K zAuYc52<4W3a zR*i1RTs`{GDNUfiOLy;|ug>aCN?J66u{aL)qRw%@{6w!ak7$$9#Q*&~h?IIdW!ml9 z2i7lBQol(-AeaD%JhY+WV8S}U&x0c^(XP?|2al;I%L|`$e_a|fK%}fRSf}_w9HgLw z(9K@;MP*C-1u7I^oUkb*Iun1nA2fg9fiXSPrIJerUIqq zDc<~rc`F20d@qEoe5xN8dx`Q6+f}EIDA*wkHood&JE2hbB?{evB%rIB^rJ7-55|2Q z7ZAf05TX}sDA&&LS{=?2jnTO>ZxWsvUi z$BDCj7Rt)LMF$L^ivVB>ds;iO!T>wd231E&R@x$hoCMMA?B5PD7VUOZi}#`f2&9mQ z;R&VI>Oa;Tyb2K?hR~9H{D3s(8-U^(7(242PVg&9ocS4r25{GJ%Dll{aedQ7+W0^1 z)2-;*2su~iU40{x##^KWcSs1o5~FNABrt0rA*WH6neSL_1Q}KjOzhQBjdo2mN&QfQ zcD3Cpx4?lJ;OHVV=DSNP*}3;|is6CNVtf!sGHLrEht1suNKRU*V+N)*EWy#m zIJ3<(+OWl~MR|b9D^@7c$J!rFNAxm1g`-_d>Y=f@=!aCu*XIK^5yNEmsA)V0wRdNa;$`r^em8YQ*xRVwryOq>E`v+-UYQz5ORuky8# zi?>nlWp)4yig+-Qbt$jDqRoEXvA^9Jr!Guky|^556RCq$xrX5N1vbzm43%Tpr$+nj zNd2`VnLUU2SSxvIhwMP 
zQgHXHqr+XscJVHO(e;$sjSnglQ!M%L4=P}rJn7>}#d@#~$lSt&po4Z8jLzznr=R6S z5RZ^*COk`Ik8ld*J$9yb5Kozv?)FpLa~`$97$GX&}=yfJbNg5Y1Fq- zFr+>cH+;NsGxQh*3jD1#XML%mzLi>G#RNQm z*viN%(d!TMi8i8suu3%l`N^S|l%KQ=r2VHENQ4pcP@@NJ8PnKG3%dN1G}DPSjkUvh z{oPNXPJB!WKL=wIcNVjWpO0=-7di`n#M)0E3rIubNklH0-MqpkC_SxT8*9m&#^FPx zMg>2#KAvurFP+f`X`#&<0wnPui93*wJw_Z2{X`pS*i^8gzz-yA)@~n%AusX{Bcl1o GJ@bDb^^So6 diff --git a/tests/test_data/expressive.mid b/tests/test_data/expressive.mid deleted file mode 100644 index 40d9e8467f06d21bcbc2fae9d7aa0dbe3fa53def..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4786 zcmXAtTU*=Kv&P>gA28=r2nk7`Ar5DZ0?0-}Hj-sY%d%u4TNpbW116-qP0}W9lO`Ot zXrE`_`YY_qefR%{-D7%Zq@}f_nR$Qn4qo`K4Il{yf&I4^e)@_28pl6Dtml7!z(37+ zTzHgX|E%HPTVv5r|Iu#$??Wdj#r~t=Yo}p-5BKH*um+;nMe8nQY@k*eDMAhrm3;m^ zMvUPt|!o-ze;FG)bK$^G#&;&1UT`WtKhig0U<`{oW%|gztS=B zPNu{v7X^;t3WNX;GnjJDuGSs!){S-E&yjTCtz$Yei`cAuv0RE+ah)W-qC{bC?lEZ)k6bP&P@?84W_7&aiHL#i29(#K!lL-d4p zyN6$4LsJuW_TZ6`*-6KM@nrdKra!0==PkM(k0 z+haN`4Z~)eEtaC+P4pVabgx(6({;qP2IY2*ZWnf~0t9gj%QjPY@8+~UdMH^a_o>K; zm{eV%fE|_AneH+r-A)wni;BKYG1aaLJF~_^b5aaq6j)YpelD2PU zDnunF;Y7t0H#SUg4yKl7Wjs(}XXM)q*k=BU8^fFUnLbxF3>C_%dMy=ZsqkNScc(!3eO*8q5rFq@;2QtI@Qw$98QR7et6S7r5K4-IBxhy z6TRwQta-|^!3Py77d>u$B8!tIhPRu{J7+on8C|5k4m*wGL=OGiNO=ewc-p8z-(|T9 za~jhQerRAu$9(;g)?%`ggIC8)JrEpa#=+lDLXD|1w2ey&Tpd{rzX$xL;XDs<5Y`?A zNDw00Mbf?@vzis72vMP$zIANiS;!Y_*sC4%Q~mge&a}#1GT34IE8!jyB$Ox_29p7P z>>}c%7GjClmObR@ttUDen^i7H>81KR97_RY=XE@TSOveYd*J6o4nA0iSVxNEcXOX* z^j-^R8H|_^jOJI8u4qC({6Hn8V^G__O;b2ehBUO(7*|@%z&8a${Fte4*rF?l#aCh% z5-H`NU&V%kl)M+k2@aX8pt=yVu?%KaJmDyaR3|~yXH>o!h!u+wz7l>q~13WlwtYu!#*_@;5_xX#Hv8ZL4e(zI&=qG3VKy*4mS-yiGf z+b`(j4?Bwnf(yofWD*w$9OA~9j7eHCqaoE10SUu|&_xtT!Nd&z4Z`DlW?0eZ5WV5s zgpWtWjNJu1SvRo>%oXo?>iY{J#IQX^sHnqiobH5pBHCyk5Ho~z#2auLtwwnMuz_R~ zg?c+i@d?4w@Pea%nK2Zg8n#tpMmt0*v}<)I4D5z)TOkI$VxJGgL(<(T)-fC)9#r8H zYIx}B#2HHaV^e3oVk2pqM3g6nPj{#FN}waA?I{|jfd6!*6;f+CAa>B!>s@x#i5)#^7!!-Ec|!Q0y#D+zL9)d7+UF*Qwp+|= zvs^IbVxIPOi6CP+B)8Gy;hnwL=ANHr02%-cu&P$puSg-uZ-;odi5RsMK)t> zo8Kk6`ILx6v2ARrRLi`cc(#;$hG7lohJRzdr3!Bu=r_Yj3vXqE`RyC;c*RG-#F(Tp z_hiY#eRn!|EO)zXxK6P??04|VzPu+$gRK=3>wt}mHH@3o|Mx%A;mak#c^Ad?a@xWC zGO{|7#sL${_x+igux+?RopIlbdY#HcAD3lw(82pM1Wk%i-ubLSC=;eF z=$qf@dtAUVuJwZ8;L7uQb$sHx4=KL3O~ij~Q}J!QFJQzrzu8D=GIN~o+)g?8&22NK zDrzuk2jWzC83}%k$<=9&bMJ`BAi-af3D|THoGX++r{f<_5UNS~l&4IGG*^2oNSvVj zTW~xeL8#$s%_a-5LaT`@OPw{OS_e_Gg+%N0pxQX{+p!j2i?3gZhzeC~q<1x1`d7Q- zBk7-xr!hf0uXjvsmrR;v>po%@9*Ihp)}N$}-R|O6$K97{+4D}ZIOd{QM4w^4GO#S> zqm+nW(|Bqfgcg!Js=9;cnojDR5%YJGMs0(f88sAH-ItM6tRn?Xb4+TS*?cpLM6T%N zFd`Q3AmX8RDL5BU-UylIl!hbf$rL09(L??=rk}FaP>`G3*k zz}wWBeDC}*$KN9@D`=JyRCx`n9$I0|A)%0*0=wEks!`2QG^RTt5@Bh#hGfls+=3L) zQN_Tw36oBMzQF6EAVq;r`Bp4|QX_lTZRI_lHqxXb$O|4;8u-CQ+%br`Wo|UQ%9sPq zwS1RMbmb{oOAlMT#-6RkpTV<&x>=)Lk?dIAIKV`LlssEvRO&H0ZBD~}ho3*CpD#OB zlUmz~18Q46W-5`E z4I(UrG~)Ws7agh9?W+yr>jrux^@~J3Iybh8kCO6-ghZyWq@%uW+~`^A@8=mjk>qdl zGLtZ+<8&UeRqNv_M)lkkU8m}u+aUFPoI^5W|B)dgWv?aWtAzN=Tvt^u5nnT|z2ng} zWIaf_dTk#)((%>?bo(YakL&ni*7E_fCPpQfG$X8)Kd+(JtS=EK0`yy#Y_r5flbewJ z6Our;iCVZ+n#`z^%Z_^8b94j-1v>+|S&~;;c)`Zy(C-6PTSvEm^n>C8PgOeoF)b1@}T z8qivI89$hGC^VOYWM9|u!ECX$A|a7hz9J+Tt=|j`EnCd_F^cJY!NJHfcIh3BBPMn; zoYZT*BE=ZJS~E%!@C>Jm@T9*-P;^Ipd@Q>Y4vx$8;J4c1$=Z*nc^-vn6vRn!6>O9d^f7$ z7aBbL-B%o@j6QFF)W>d!`AwVB8}qP1mA%{XPX_`_ZeotAS?|;?nB-3S_!OKbXbP+0 z&o%3I4FjzsTJ>8CeXTdVh0*H%iGQ(Dg(x1-R1saKcVE+He@$zcCyn;`4JVCJv)Cj} zS@sJbTjKV#)g85>_i59irSlp>?e?zSaZkE9-tMF^Q`DziY~KU1WMERX?jBl)OCnxa z=-W|`cL;a-&OHsgq>KiQ8U=e=)(MAk#cC+%(>%KWB�giYr;9a~Kr0J7OW6wxFEF 
zGmSwSX*BQ*1{L%W4wL+ypr2wjOvp`Q)i{~)pS?4iaCk^leJ<=-1$o%$le=H@+8+5h zzfI|JmLU*{7ApwOnB=8$u>3>vffet>_9^(uT(VhuNBm41MPnU)rdbWrM9vJ#Mx~p!M5mkJZX{E)ctOKx zIuG5v`+){hvsvxVAQh6_gkG|AYIBu5na(s#tjZNt`91n>HXz)iFRo|bt>TfC`!l1G z&-uQd|79IXiL|u&HiO5~*42EL&`0yapYsq_jVnDHF-fLp#w4&gX;z0nhyILmtLJVc z`2%H`uD{$I)$x9rYh5BRGR-OLEfY&@(~QRa^I4N-7*Gk4YIco)DGerC(TaOM(iJ&|6eL6>hef+=g6V7cYm!pa>8Ed!yN#*)kEs6#xirzc diff --git a/tests/test_data/noisy/1.mid b/tests/test_data/noisy/1.mid deleted file mode 100644 index 522cb3b449a48297c4c4d6d532c7b3de018c3a1f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 75600 zcmZs^?QR@bmaZACV|W??!7LO4g45`VBPiWfVM?KrLZpI}LL^fZLn$U{hG@#DP8ja` zK?p|8fG~{0fG~_UBI)XxIXFGYo;#R3=u4Q3n*TG;`>sfeReb=Pirle3)?Rzs|MEZm-*=P$CrSTf@8`uwdr7|c^IvlR z^S}Sg@*t-;`AfI=^S^((JV+A#X!m~R&7@xZus!bIH)X%Eog~{uf451Jrpgc63;iz7 z^|S2nPTT&L9>Tmf+dK|yU7p_|zC&_V>x`tcF+fwbR zz}pj#+23B)MS4`zSNg0>56kR`ha|0&`LiNDE>(WWJPwl#u+n8cP7m|`j)$D{PQ&CH z;4lZG{=I(J0j>w94@##C^s{)doKkH}NwV*1EcDO)k6KZ%tR9yKMLHd8E}(c) zrJF^5{gOTxA1%xA;)iOSu_TOCx>dT|8DFaZPnY{;Ru}!9NmYPI=E5-5(DPnJqjWDx zCr!aaU8a+|(gTI{oUwN_T(y&=n@PX3SEP^fjM|Tqi?S3z=gZCF@?b(MDkPhwKw$ix zLV)+TmNYy_`fpljpvpFrjNgLkhvi|DJ}k1E0VUuTtZeLb~wS(V9tUF;`Wm0V8hi%D|6UomD*mHl$QpDgy; zDtnj!!{Xb$N`&-ithqr*kCLljO-j~3-OL-{?cXn1#K11|`sLTlvw1e{(u3BE^8NC( zOPOLU3UfMJf!HQzv+Pu3qzAL)?YsO;_(S;UaH;$|981!#9e$((H-dz3axJ@VQ zQ0h_PoBbv|nWRTe+EnU^ZPX;`NpkfWo3^Qv{U$k@q&Hpw<4iATv0tatI)fD)iJ;P{ z+NJvyT!yy^?WgXN{``t+iytUGs?+^4ZN}+T-0SD7&zHeX*xBw#bdxUMR@#4%{lsK{ z_1W^QOQ#@%z1CD;cLf+e>&{PSpj{g<*by6$^0!Iy^$Pzg4i;tq& zaSgI#?cfe-qXraK1-1D(NwU*P@pi)C1wRjJjp5a=z}5aFZE7vG7bQKOrQ3BfnP|su zd;NCcb;rs1vG?D$HXp0XUQCYLY&(gcS0C$j+4|k$-btM{MY<0|DMY)A@0O>vW^z)3 znB=s&`uuu5*eDdS9@+A!N~fh-v2g9UpdrjKUz|l<@E&Uy+c{evm+8Sco#bhgwbcE~ zxtJMs*+JCx3O6ML(r*ZaPnR(L12|8irHn3J$JNiN0ymt}U#9c!0;2hN38tq-`iv3s zw9V6l3{)nQ(c+`)9^l*oPP&{V{q0GfZRc!f!9#a9xAUcYd@)ZScj-UQ(l@j8c(T|# ztU1T#X*J8S{CE~Pqn#_}KEt$q`qSm%BxO%; zHWbn^uIi)9w%c<(wCQ2f zinsHi5qvhzhG);|MhUF?%`7|We3H+)lMW(J-lo~1PJDC#T)w>5F&ivo5V6 z>lIkR1sIJQR0Yq#S1jL7_!p8po%DB4ChkPc4=$7YZ*e@S3~O`5=uji@P0}|_&S6(K zH|o#4jBTnVox~g}J zzJJ>WOx;P+9nV;a>S%_z0nQqY(RThDgwz@WKR1#*Vx*hHY$QqH>34fSUo5vdUyJ^9 zr)np~xk$a?C`~w14|H57lUfqNK{Jq=Ue&_tHqr`9SM)pj(O3|xgNqiXjvzr!=47)T zdQ^8DAAu@yJc)c1(kazg51rFK_cjbx1--{af)MhztrugK?$o)~JBVe{uX~BS8c>&a z!;XqNx1O11+WxlSGD!fhpH0)lYV~($cyoI&03hYd?p<1SI_(Z;$?Gn8)n!{<5Y6)$ z5BxoaFT^ z{p}=qH4~cCIK$+AM2#_f{l~1}Edxyf>KifiuEmKcGq|4o7Yw+9FZ#lh_{YeW~d6Fbg z>X@1L9Me_9R@`6~&0V>P=I&Nv)g@wt+j6k#v}Pdsq%DQ448(A`tU2o{_fc{1q9&>0 z6!}^Il(~wkn2ukR{Wk))>8FNK5)tnofr=-^#TF+fzj=T(%Iww9SEIEnFVvud01^Z# zK~-cbAMj>6tSXgEau!q7Cg(ARDp#J`gu=js>(@7t7ASSk)x7JnM~Si2cEZ&xW2w8D zM`SHjGGOT}(Ht0vX~yK#hA+w3-?xxhK$ho^#E)j42r`1U{84 zJ^|PV&<^0|MpWU~d2jV}>=wCo(|FF#11-j3^mqQUGBE_+n&u^wn50k3SZLQ_{e&wM z-c7oiWlxjjY2Z!!&Xw!#%g*HkIi(Mpfe61$QfB>S2VJEPlH@_1KC1FBRT2n~-g3RF zav`wm{G6W6Jj2SZJoQJQ$Ycfp^Z=!$OJ?>Uku<_%dGV;6YpO`*U$)G~9Tb^Ho1u+} zJRcQ$y{=l27r(3i{rx(vn~tN-rIr9D2#y2sdNZW=ljOe1nKnqr)Vu~Fv)od#IVTn@ zVwX1ky7`D{eA0h|c=D+3o-|O1`Ek=|G%4SbH~idAXFh7HH~`IQ8f0Cc&NG96qM^NM zd7$-NVRd3jjs7BsP!du_7Z7VN1t$-fz@9o>Eg{U}>d#B|)l?$8 zb_=kkl1w{6UUXf$H4)t1QoT0<&N8AfcRgL|LK<$#p4=48r4Zbc2F&W2qwq8P1JRIA2B{*ex5EGa+BWI zYnAMyF&-89QSz^Mf%}?58_t&QDYjDa>kT{U4#e&JAIV&>Hygk zHIY)+Icl0COAXjND6VtKX}p2og1a$~Kj&C1(!+5#8RxGH3r=NrBKyQ};8T&4%e|c> z`ANep`&%cWAZ|(LI4l-VMU`b*XWk6s#Ycxp!EY=q8m36GV5o*IJS-ZWaT;s!q~>>4 zE9z`7MqvsY%MCOorfhw;mdeC@Hxe0cFaW7f*`Hixcl}HH$HZ5!xp5tpEXiF~@BAXoS@&eLa%4wK!WI$p$3g*&VAYpzx>^>JZ5a zDMuzh>P*A5^zD+(3jl#?EK})ndkoJ>*!hd6OB5~;!|8NVK`3b4-oACG=SgXOI

BpiAaJR~!fzOUxL&q@v_K$z2+uUtflFhJrE6o>xeInrON_X_B|CD=k?_ zq4L&}Eg8)bJWl-ba{1hATP?qwq&w1C`#Z>1)B;J=S}{^I+8wK;{T<9rgNYCwv`$~M z*w%884LJugwYK@IwYyC^}>IR-BTA7(X$pin`M zIDTsJ1NiQdSi%HO0Qse?t}px^7sN zkkR}q>=#8S_JvsjkV(vA0Lb*3D-0yX0`+&%h&etED(wQ8KyvxE1w&FPY3hnF2)43{ zE7Rs}m_vrIV86W77{&ZcD^#)RqcSLVQwvA!bEt&F0Q<5UxsUZF7O&@p_FJ@bb>_rK zAmn*5KTa~*9hySG@D6(qArcl4oBkH`_oVhxXy&V+>a?Y{4*Z%jk$;wP{r?@zjB zpljur)veITSx{CCqPlo3>!H>gtqUrO4gZ>j=itWj3=T~|P=9}Szk=}d{j%L>ovL^M|0eo?}%!QQxE zXwMpkjxE3(cg7bx3{Yefn5^nxA7>+I8v=ZtBWiG7_jmbcUwy?3$URSYa~uNNPrEXj zpYP_)Zjz}m%UyEG^AmoKjJ%B{gUFFo0FKTvWGS!BZgS3Pw>zQ{T1o~f?~Xcm&^^kj zu9a0g@hrw+wElL5ak!KKG)aodsF^SWL6T#2tn?F!cD|d@Ec<<)?$dilV}7Q;Pc9&O zcp-Io_Gf>2g&k!2}KW4N){*8bSS?NB^9D-X;58vilC8lie9`r=J@S z`^n|5#f3+u3K>x)ns~!%JQ@wt@bvmy6zz_|muBXoOT1OPBT){Rr-bd!au{WFQSge{ z-`cH{FJ@Rdu-oxF*==RQ`9fVZyR{bK6?3i9SGCsR6$HW{TSdW-EiIGkp^Wa1HNq

[... GIT binary patch data omitted ...]

diff --git a/tests/test_data/pop.mid b/tests/test_data/pop.mid
deleted file mode 100644
index be83c69da85868db81a45465d75c08b1c2384477..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 12527
[... GIT binary patch data omitted ...]

diff --git a/tests/test_data/pop_copy.mid b/tests/test_data/pop_copy.mid
deleted file mode 100644
index be83c69da85868db81a45465d75c08b1c2384477..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 12527
[... GIT binary patch data omitted ...]

diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py
deleted file mode 100644
index f110a7f..0000000
--- a/tests/test_tokenizers.py
+++ /dev/null
@@ -1,535 +0,0 @@
-import unittest
-import logging
-import os
-import time
-
-from typing import Callable
-
-from aria import tokenizer
-from aria.config import load_config
-from ariautils.midi import MidiDict
-from aria.data.datasets import _get_combined_mididict, _noise_midi_dict
-from aria.utils import midi_to_audio
-
-
-if not os.path.isdir("tests/test_results"):
-    os.makedirs("tests/test_results")
-
-
-# TODO: Implement with tokenizer functions
-def get_short_seq_abs(tknzr: tokenizer.AbsTokenizer):
-    return [
-        ("prefix", "instrument", "piano"),
-        ("prefix", "instrument", "drum"),
-        "",
-        ("piano", 62, tknzr._quantize_velocity(45)),
-        ("onset", tknzr._quantize_onset(0)),
-        ("dur", tknzr._quantize_dur(50)),
-        ("drum", 50),
-        ("onset", tknzr._quantize_onset(100)),
-        ("piano", 64, tknzr._quantize_velocity(75)),
-        ("onset", tknzr._quantize_onset(100)),
-        ("dur", tknzr._quantize_dur(5000)),
-        "",
-        "",
-        "",
-        ("piano", 65, tknzr._quantize_velocity(75)),
-        ("onset", tknzr._quantize_onset(170)),
-        ("dur", tknzr._quantize_dur(100)),
-        "",
-        ("piano", 60, tknzr._quantize_velocity(45)),
-        ("onset", tknzr._quantize_onset(270)),
-        ("dur", tknzr._quantize_dur(60)),
-        "",
-        ("onset", tknzr._quantize_onset(270)),
-        ("dur", tknzr._quantize_dur(70)),
-        ("drum", 50),
-        ("onset", tknzr._quantize_onset(270)),
-        "",
-        ("piano", 80, tknzr._quantize_velocity(45)),
-        ("onset", tknzr._quantize_onset(270)),
-        ("dur", tknzr._quantize_dur(80)),
-        "",
-    ]
-
-
-def get_concat_seq_abs(tknzr: tokenizer.AbsTokenizer):
-    return [
-        ("onset", tknzr._quantize_onset(270)),
-        ("dur", tknzr._quantize_dur(60)),
-        "",
-        ("onset", tknzr._quantize_onset(270)),
-        ("dur", tknzr._quantize_dur(70)),
-        ("drum", 50),
-        ("onset", tknzr._quantize_onset(270)),
-        "",
-        ("piano", 80, tknzr._quantize_velocity(45)),
-        ("onset", tknzr._quantize_onset(270)),
-        ("dur", tknzr._quantize_dur(80)),
-        "",
-        ("prefix", "instrument", "piano"),
-        ("prefix", "instrument", "drum"),
-        "",
-        ("piano", 62, tknzr._quantize_velocity(45)),
-        ("onset", tknzr._quantize_onset(0)),
-        ("dur", tknzr._quantize_dur(50)),
-        ("drum", 50),
-        ("onset", tknzr._quantize_onset(100)),
-        ("piano", 64, tknzr._quantize_velocity(75)),
-        ("onset", tknzr._quantize_onset(100)),
-        ("dur", tknzr._quantize_dur(5000)),
-        "",
-        "",
-        "",
-        ("piano", 65, tknzr._quantize_velocity(75)),
-        ("onset", tknzr._quantize_onset(170)),
-        ("dur", tknzr._quantize_dur(100)),
-        "",
-        ("piano", 60, tknzr._quantize_velocity(45)),
-        ("onset", tknzr._quantize_onset(270)),
-        ("dur", tknzr._quantize_dur(60)),
-        "",
-        ("onset", tknzr._quantize_onset(270)),
-        ("dur", tknzr._quantize_dur(70)),
-        ("drum", 50),
-        ("onset", tknzr._quantize_onset(270)),
-        "",
-        ("piano", 80, tknzr._quantize_velocity(45)),
-        ("onset", tknzr._quantize_onset(270)),
-        ("dur", tknzr._quantize_dur(80)),
-        "",
-        ("prefix", "instrument", "piano"),
-        ("prefix", "instrument", "drum"),
-        "",
-        ("piano", 62, tknzr._quantize_velocity(45)),
-        ("onset", tknzr._quantize_onset(0)),
-        ("dur", tknzr._quantize_dur(50)),
-        ("drum", 50),
-        ("onset", tknzr._quantize_onset(100)),
-        ("piano", 64, tknzr._quantize_velocity(75)),
-        ("onset", tknzr._quantize_onset(100)),
-        ("dur", tknzr._quantize_dur(5000)),
-        "",
-        "",
-    ]
-
-
-def get_short_seq_rel(tknzr: tokenizer.RelTokenizer):
-    return [
-        ("prefix", "instrument", "piano"),
-        ("prefix", "instrument", "drum"),
-        ("prefix", "composer", "bach"),
-        "",
-        ("piano", 62, tknzr._quantize_velocity(50)),
-        ("dur", tknzr._quantize_time(50)),
-        ("wait", tknzr._quantize_time(100)),
-        ("drum", 50),
-        ("piano", 64, tknzr._quantize_velocity(70)),
-        ("dur", tknzr._quantize_time(1000000)),
-        ("wait", tknzr._quantize_time(1000000)),
-        ("wait", tknzr._quantize_time(1000000)),
-        ("wait", tknzr._quantize_time(1000000)),
-        ("wait", tknzr._quantize_time(100)),
-        ("piano", 65, tknzr._quantize_velocity(70)),
-        ("dur", tknzr._quantize_time(100)),
-        ("wait", tknzr._quantize_time(100)),
-        ("piano", 60, tknzr._quantize_velocity(50)),
-        ("dur", tknzr._quantize_time(60)),
-        ("piano", 70, tknzr._quantize_velocity(50)),
-        ("dur", tknzr._quantize_time(70)),
-        ("drum", 50),
-        ("piano", 80, tknzr._quantize_velocity(50)),
-        ("dur", tknzr._quantize_time(80)),
-        ("wait", tknzr._quantize_time(100)),
-        "",
-    ]
-
-
-def get_concat_seq_rel(tknzr: tokenizer.RelTokenizer):
-    return [
-        ("dur", tknzr._quantize_time(1000000)),
-        ("wait", tknzr._quantize_time(1000000)),
-        ("wait", tknzr._quantize_time(1000000)),
-        ("wait", tknzr._quantize_time(1000000)),
-        ("wait", tknzr._quantize_time(100)),
-        ("piano", 65, tknzr._quantize_velocity(70)),
-        ("dur", tknzr._quantize_time(100)),
-        ("wait", tknzr._quantize_time(100)),
-        ("piano", 60, tknzr._quantize_velocity(50)),
-        ("dur", tknzr._quantize_time(60)),
-        ("piano", 70, tknzr._quantize_velocity(50)),
-        ("dur", tknzr._quantize_time(70)),
-        ("drum", 50),
-        ("piano", 80, tknzr._quantize_velocity(50)),
-        ("dur", tknzr._quantize_time(80)),
-        ("wait", tknzr._quantize_time(100)),
-        "",
-        ("prefix", "instrument", "piano"),
-        ("prefix", "instrument", "drum"),
"composer", "bach"), - "", - ("piano", 62, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(50)), - ("wait", tknzr._quantize_time(100)), - ("drum", tknzr._quantize_time(50)), - ("piano", 64, tknzr._quantize_velocity(70)), - ("dur", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(100)), - ("piano", 65, tknzr._quantize_velocity(70)), - ("dur", tknzr._quantize_time(100)), - ("wait", tknzr._quantize_time(100)), - ("piano", 60, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(60)), - ("piano", 70, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(70)), - ("drum", 50), - ("piano", 80, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(80)), - ("wait", tknzr._quantize_time(100)), - "", - ("prefix", "instrument", "piano"), - ("prefix", "instrument", "drum"), - ("prefix", "composer", "bach"), - "", - ("piano", 62, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(50)), - ("wait", tknzr._quantize_time(100)), - ("drum", tknzr._quantize_time(50)), - ("piano", 64, tknzr._quantize_velocity(70)), - ] - - -class TestAbsTokenizer(unittest.TestCase): - def test_tokenize_detokenize_mididict(self): - def tokenize_detokenize(file_name: str): - mid_path = f"tests/test_data/{file_name}" - midi_dict = MidiDict.from_midi(mid_path=mid_path) - tokenized_seq = tknzr.tokenize(midi_dict) - detokenized_midi_dict = tknzr.detokenize(tokenized_seq) - res = detokenized_midi_dict.to_midi() - res.save(f"tests/test_results/{file_name}") - - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - tokenize_detokenize("basic.mid") - tokenize_detokenize("arabesque.mid") - tokenize_detokenize("beethoven_sonata.mid") - tokenize_detokenize("bach.mid") - tokenize_detokenize("expressive.mid") - tokenize_detokenize("pop.mid") - tokenize_detokenize("beethoven_moonlight.mid") - tokenize_detokenize("maestro.mid") - - def test_aug(self): - def tokenize_aug_detokenize( - file_name: str, - aug_fn: Callable, - aug_name: str, - audio=False, - ): - mid_path = f"tests/test_data/{file_name}" - midi_dict = MidiDict.from_midi(mid_path=mid_path) - tokenized_seq = tknzr.tokenize(midi_dict) - tokenized_seq_aug = aug_fn(tokenized_seq) - detokenized_midi_dict = tknzr.detokenize(tokenized_seq_aug) - res = detokenized_midi_dict.to_midi() - save_path = f"tests/test_results/abs_{aug_name}_{file_name}" - res.save(save_path) - if audio is True: - midi_to_audio(save_path) - - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - seq = get_short_seq_abs(tknzr) - seq_concat = get_concat_seq_abs(tknzr) - pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) - velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) - tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5, mixup=True) - - # Pitch augmentation - seq_pitch_augmented = pitch_aug_fn(get_short_seq_abs(tknzr)) - logging.info(f"pitch_aug_fn:\n{seq} ->\n\n{seq_pitch_augmented}\n") - tokenize_aug_detokenize("basic.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("arabesque.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("beethoven_sonata.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("bach.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("expressive.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("pop.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize( - "beethoven_moonlight.mid", pitch_aug_fn, "pitch" - ) - - # Velocity augmentation - seq_velocity_augmented = 
-        seq_velocity_augmented = velocity_aug_fn(get_short_seq_abs(tknzr))
-        logging.info(
-            f"velocity_aug_fn:\n{seq} ->\n\n{seq_velocity_augmented}\n"
-        )
-        tokenize_aug_detokenize("basic.mid", velocity_aug_fn, "velocity")
-        tokenize_aug_detokenize("arabesque.mid", velocity_aug_fn, "velocity")
-        tokenize_aug_detokenize(
-            "beethoven_sonata.mid", velocity_aug_fn, "velocity"
-        )
-        tokenize_aug_detokenize("bach.mid", velocity_aug_fn, "velocity")
-        tokenize_aug_detokenize("expressive.mid", velocity_aug_fn, "velocity")
-        tokenize_aug_detokenize("pop.mid", velocity_aug_fn, "velocity")
-        tokenize_aug_detokenize(
-            "beethoven_moonlight.mid", velocity_aug_fn, "velocity"
-        )
-
-        # Tempo augmentation
-        seq_tempo_augmented = tempo_aug_fn(get_short_seq_abs(tknzr))
-        logging.info(f"tempo_aug_fn:\n{seq} ->\n\n{seq_tempo_augmented}\n")
-
-        seq_concat_tempo_augmented = tempo_aug_fn(get_concat_seq_abs(tknzr))
-        logging.info(
-            f"tempo_aug_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n"
-        )
-
-        tokenize_aug_detokenize("basic.mid", tempo_aug_fn, "tempo")
-        tokenize_aug_detokenize("arabesque.mid", tempo_aug_fn, "tempo")
-        tokenize_aug_detokenize("beethoven_sonata.mid", tempo_aug_fn, "tempo")
-        tokenize_aug_detokenize("bach.mid", tempo_aug_fn, "tempo")
-        tokenize_aug_detokenize("expressive.mid", tempo_aug_fn, "tempo")
-        tokenize_aug_detokenize("pop.mid", tempo_aug_fn, "tempo")
-        tokenize_aug_detokenize(
-            "beethoven_moonlight.mid", tempo_aug_fn, "tempo"
-        )
-
-    def test_aug_time(self):
-        tknzr = tokenizer.AbsTokenizer()
-        mid_dict = MidiDict.from_midi("tests/test_data/beethoven_sonata.mid")
-        tokenized_seq = tknzr.tokenize(mid_dict)[:4096]
-        pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5)
-        velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2)
-        tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5, mixup=True)
-
-        # Pitch augmentation
-        t_start = time.perf_counter()
-        pitch_aug_fn(tokenized_seq)
-        t_pitch_aug = (time.perf_counter() - t_start) * 1e3
-        logging.info(f"pitch_aug_fn took {int(t_pitch_aug)}ms")
-        self.assertLessEqual(t_pitch_aug, 50)
-
-        # Velocity augmentation
-        t_start = time.perf_counter()
-        velocity_aug_fn(tokenized_seq)
-        t_vel_aug = (time.perf_counter() - t_start) * 1e3
-        logging.info(f"velocity_aug_fn took {int(t_vel_aug)}ms")
-        self.assertLessEqual(t_vel_aug, 50)
-
-        # Tempo augmentation
-        t_start = time.perf_counter()
-        tempo_aug_fn(tokenized_seq)
-        t_tempo_aug = (time.perf_counter() - t_start) * 1e3
-        logging.info(f"tempo_aug_fn took {int(t_tempo_aug)}ms")
-        self.assertLessEqual(t_tempo_aug, 50)
-
-    def test_no_unk_token(self):
-        def _test_no_unk_token(file_name: str):
-            mid_path = f"tests/test_data/{file_name}"
-            midi_dict = MidiDict.from_midi(mid_path=mid_path)
-            seq = tknzr.tokenize(midi_dict)
-            enc_dec_seq = tknzr.decode(tknzr.encode(seq))
-            for tok in enc_dec_seq:
-                self.assertTrue(tok != tknzr.unk_tok)
-
-        tknzr = tokenizer.AbsTokenizer()
-        _test_no_unk_token("basic.mid")
-        _test_no_unk_token("arabesque.mid")
-        _test_no_unk_token("bach.mid")
-        _test_no_unk_token("expressive.mid")
-        _test_no_unk_token("pop.mid")
-        _test_no_unk_token("beethoven_moonlight.mid")
-
-
-# TODO: This example is not working, I'm pretty sure the issue is in _get_combined_mididict somewhere
-# Fix this!!
-class TestSeparatedTokenizer(unittest.TestCase):
-    def test_tokenize_detokenize_mididict(self):
-        def _find_inst_onsets(_seq: list):
-            curr_time_ms = 0
-            time_toks = 0
-            for tok in _seq:
-                if tok == "":
-                    time_toks += 1
-                elif isinstance(tok, tuple) and tok[0] == "onset":
-                    curr_time_ms = 5000 * time_toks + tok[1]
-                elif tok == "":
-                    print("Seen at", curr_time_ms)
-
-        tknzr = tokenizer.SeparatedAbsTokenizer()
-
-        clean_midi_dict = MidiDict.from_midi(
-            mid_path="/mnt/ssd1/data/mp3/raw/maestro-mp3/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi"
-        )
-        noisy_midi_dict = MidiDict.from_midi(
-            mid_path="/mnt/ssd1/data/mp3/raw/maestro-mp3/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi"
-            # mid_path="/mnt/ssd1/amt/transcribed_data/noisy_maestro/small-long-e7/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.mid"
-        )
-
-        noisy_midi_dict = _noise_midi_dict(
-            noisy_midi_dict, load_config()["data"]["finetuning"]["noising"]
-        )
-
-        clean_mid = clean_midi_dict.to_midi()
-        clean_mid.save(f"tests/test_results/combined_clean.mid")
-        noisy_mid = noisy_midi_dict.to_midi()
-        noisy_mid.save(f"tests/test_results/combined_noisy.mid")
-
-        comb_midi_dict = _get_combined_mididict(
-            clean_midi_dict,
-            noisy_midi_dict,
-            min_noisy_ms=10000,
-            max_noisy_ms=25000,
-            min_clean_ms=30000,
-            max_clean_ms=60000,
-        )
-
-        comb_midi = comb_midi_dict.to_midi()
-        comb_midi.save(f"tests/test_results/combined_raw.mid")
-        tokenized_seq = tknzr.tokenize(comb_midi_dict)
-        detokenized_midi_dict = tknzr.detokenize(tokenized_seq)
-        res = detokenized_midi_dict.to_midi()
-        res.save(f"tests/test_results/combined.mid")
-
-        for idx, sub_seq in enumerate(tknzr.split(tokenized_seq, 4096)):
-            if idx == 3:
-                _find_inst_onsets(sub_seq)
-            print(idx)
-            print(sub_seq)
-            detokenized_midi_dict = tknzr.detokenize(sub_seq)
-            res = detokenized_midi_dict.to_midi()
-            res.save(f"tests/test_results/combined{idx}.mid")
-
-
-class TestRelTokenizer(unittest.TestCase):
-    def test_tokenize_detokenize_mididict(self):
-        def tokenize_detokenize(file_name: str):
-            mid_path = f"tests/test_data/{file_name}"
-            midi_dict = MidiDict.from_midi(mid_path=mid_path)
-            tokenized_seq = tknzr.tokenize(midi_dict)
-            detokenized_midi_dict = tknzr.detokenize(tokenized_seq)
-            res = detokenized_midi_dict.to_midi()
-            res.save(f"tests/test_results/{file_name}")
-
-        tknzr = tokenizer.RelTokenizer(return_tensors=False)
-
-        tokenize_detokenize("basic.mid")
-        tokenize_detokenize("arabesque.mid")
-        tokenize_detokenize("beethoven_sonata.mid")
-        tokenize_detokenize("bach.mid")
-        tokenize_detokenize("expressive.mid")
-        tokenize_detokenize("pop.mid")
-        tokenize_detokenize("beethoven_moonlight.mid")
-
-    def test_aug(self):
-        tknzr = tokenizer.RelTokenizer(return_tensors=False)
-        seq = get_short_seq_rel(tknzr)
-        seq_concat = get_concat_seq_rel(tknzr)
-        pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5)
-        velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2)
-        tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.8)
-        chord_mixup_fn = tknzr.export_chord_mixup()
-
-        # Pitch augmentation
-        seq_pitch_augmented = pitch_aug_fn(get_short_seq_rel(tknzr))
-        logging.info(f"pitch_aug_fn:\n{seq} ->\n\n{seq_pitch_augmented}\n")
-        self.assertEqual(
-            seq_pitch_augmented[4][1] - seq[4][1],
-            seq_pitch_augmented[8][1] - seq[8][1],
-        )
-
-        # Velocity augmentation
-        seq_velocity_augmented = velocity_aug_fn(get_short_seq_rel(tknzr))
-        logging.info(
-            f"velocity_aug_fn:\n{seq} ->\n\n{seq_velocity_augmented}\n"
-        )
-        self.assertEqual(
-            seq_velocity_augmented[4][2] - seq[4][2],
-            seq_velocity_augmented[8][2] - seq[8][2],
-        )
-
-        # Tempo augmentation
-        seq_tempo_augmented = tempo_aug_fn(get_short_seq_rel(tknzr))
-        logging.info(f"tempo_aug_fn:\n{seq} ->\n\n{seq_tempo_augmented}\n")
-
-        seq_concat_tempo_augmented = tempo_aug_fn(get_concat_seq_rel(tknzr))
-        logging.info(
-            f"tempo_aug_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n"
-        )
-
-        # Chord mix-up augmentation
-        seq_mixup_augmented = chord_mixup_fn(get_short_seq_rel(tknzr))
-        logging.info(f"chord_mixup_fn:\n{seq} ->\n\n{seq_mixup_augmented}\n")
-
-        seq_concat_tempo_augmented = chord_mixup_fn(get_concat_seq_rel(tknzr))
-        logging.info(
-            f"chord_mixup_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n"
-        )
-
-    def test_aug_time(self):
-        tknzr = tokenizer.RelTokenizer()
-        mid_dict = MidiDict.from_midi("tests/test_data/beethoven_sonata.mid")
-        tokenized_seq = tknzr.tokenize(mid_dict)[:4096]
-
-        pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5)
-        velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2)
-        tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5)
-        chord_mixup_fn = tknzr.export_chord_mixup()
-
-        # Pitch augmentation
-        t_start = time.perf_counter()
-        pitch_aug_fn(tokenized_seq)
-        t_pitch_aug = (time.perf_counter() - t_start) * 1e3
-        logging.info(f"pitch_aug_fn took {int(t_pitch_aug)}ms")
-        self.assertLessEqual(t_pitch_aug, 50)
-
-        # Velocity augmentation
-        t_start = time.perf_counter()
-        velocity_aug_fn(tokenized_seq)
-        t_vel_aug = (time.perf_counter() - t_start) * 1e3
-        logging.info(f"velocity_aug_fn took {int(t_vel_aug)}ms")
-        self.assertLessEqual(t_vel_aug, 50)
-
-        # Tempo augmentation
-        t_start = time.perf_counter()
-        tempo_aug_fn(tokenized_seq)
-        t_tempo_aug = (time.perf_counter() - t_start) * 1e3
-        logging.info(f"tempo_aug_fn took {int(t_tempo_aug)}ms")
-        self.assertLessEqual(t_tempo_aug, 50)
-
-        # Chord mixup augmentation
-        t_start = time.perf_counter()
-        chord_mixup_fn(tokenized_seq)
-        t_mixup_aug = (time.perf_counter() - t_start) * 1e3
-        logging.info(f"mixup_aug_fn took {int(t_mixup_aug)}ms")
-        self.assertLessEqual(t_mixup_aug, 50)
-
-    def test_encode_decode(self):
-        tknzr = tokenizer.RelTokenizer(return_tensors=True)
-        seq = get_short_seq_rel(tknzr)
-        enc_dec_seq = tknzr.decode(tknzr.encode(seq))
-        for x, y in zip(seq, enc_dec_seq):
-            self.assertEqual(x, y)
-
-        tknzr = tokenizer.RelTokenizer(return_tensors=False)
-        seq = get_short_seq_rel(tknzr)
-        enc_dec_seq = tknzr.decode(tknzr.encode(seq))
-        for x, y in zip(seq, enc_dec_seq):
-            self.assertEqual(x, y)
-
-    def test_no_unk_token(self):
-        tknzr = tokenizer.RelTokenizer()
-        seq = get_short_seq_rel(tknzr)
-        enc_dec_seq = tknzr.decode(tknzr.encode(seq))
-        for tok in enc_dec_seq:
-            self.assertTrue(tok != tknzr.unk_tok)
-
-
-if __name__ == "__main__":
-    if os.path.isdir("tests/test_results") is False:
-        os.mkdir("tests/test_results")
-
-    logging.basicConfig(level=logging.INFO)
-    unittest.main()

From a37ba9c1960dd035476b0127ebb6d0dd1e52f1b7 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 3 Jun 2025 11:52:44 +0000
Subject: [PATCH 51/72] add resid dropout to model

---
 aria/model.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/aria/model.py b/aria/model.py
index 5ed277f..68b3598 100644
--- a/aria/model.py
+++ b/aria/model.py
@@ -19,9 +19,9 @@ class ModelConfig:
     drop_p: float
     max_seq_len: int
     grad_checkpoint: bool
+    resid_dropout: float = 0.0
     vocab_size: Optional[int] = None
     class_size: Optional[int] = None
-    tag_to_id: Optional[dict] = None
     emb_size: Optional[dict] = None
 
     def set_vocab_size(self, vocab_size: int):
@@ -29,13 +29,13 @@ def set_vocab_size(self, vocab_size: int):
 
 
 class FusedEncoderBlock(nn.Module):
-    def __init__(self, model_config: ModelConfig):
+    def __init__(self, model_config: ModelConfig, resid_dropout: float = 0.0):
         super().__init__()
 
-        self.drop_p = model_config.drop_p
         self.n_heads = model_config.n_heads
         self.d_head = model_config.d_model // model_config.n_heads
         self.max_seq_len = model_config.max_seq_len
+        self.resid_dropout = resid_dropout
 
         # Attention
         self.mixed_qkv = nn.Linear(
@@ -71,8 +71,11 @@ def __init__(self, model_config: ModelConfig):
         self.norm2 = nn.LayerNorm(model_config.d_model)
 
     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
-        x = x + self._att_block(self.norm1(x), freqs_cis)
-        x = x + self._ff_block(self.norm2(x))
+        att_out = self._att_block(self.norm1(x), freqs_cis)
+        x = x + F.dropout(att_out, p=self.resid_dropout, training=self.training)
+
+        ff_out = self._ff_block(self.norm2(x))
+        x = x + F.dropout(ff_out, p=self.resid_dropout, training=self.training)
 
         return x
 
@@ -136,8 +139,18 @@ def __init__(self, model_config: ModelConfig):
         self.out_layer_norm = nn.LayerNorm(model_config.d_model)
 
         self.encode_layers = nn.ModuleList()
-        for _ in range(model_config.n_layers):
-            self.encode_layers.append(FusedEncoderBlock(model_config))
+
+        for layer_index in range(model_config.n_layers):
+            if model_config.resid_dropout > 0:
+                layer_dropout = model_config.resid_dropout * (
+                    layer_index / (model_config.n_layers - 1)
+                )
+            else:
+                layer_dropout = 0.0
+
+            self.encode_layers.append(
+                FusedEncoderBlock(model_config, resid_dropout=layer_dropout)
+            )
 
     def forward(
         self,
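
The last hunk above scales residual dropout linearly with depth: layer 0 gets
p = 0.0 and the final layer gets the full model_config.resid_dropout. A
minimal sketch of the resulting schedule, assuming the illustrative values
n_layers=4 and resid_dropout=0.1 (not taken from the patch):

    # Per-layer residual dropout, mirroring the loop in the diff above.
    # Assumes n_layers > 1, since the expression divides by n_layers - 1.
    n_layers, resid_dropout = 4, 0.1
    schedule = [resid_dropout * i / (n_layers - 1) for i in range(n_layers)]
    print(schedule)  # [0.0, 0.0333..., 0.0666..., 0.1]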
class_size: Optional[int] = None - tag_to_id: Optional[dict] = None emb_size: Optional[dict] = None def set_vocab_size(self, vocab_size: int): @@ -29,13 +29,13 @@ def set_vocab_size(self, vocab_size: int): class FusedEncoderBlock(nn.Module): - def __init__(self, model_config: ModelConfig): + def __init__(self, model_config: ModelConfig, resid_dropout: float = 0.0): super().__init__() - self.drop_p = model_config.drop_p self.n_heads = model_config.n_heads self.d_head = model_config.d_model // model_config.n_heads self.max_seq_len = model_config.max_seq_len + self.resid_dropout = resid_dropout # Attention self.mixed_qkv = nn.Linear( @@ -71,8 +71,11 @@ def __init__(self, model_config: ModelConfig): self.norm2 = nn.LayerNorm(model_config.d_model) def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): - x = x + self._att_block(self.norm1(x), freqs_cis) - x = x + self._ff_block(self.norm2(x)) + att_out = self._att_block(self.norm1(x), freqs_cis) + x = x + F.dropout(att_out, p=self.resid_dropout, training=self.training) + + ff_out = self._ff_block(self.norm2(x)) + x = x + F.dropout(ff_out, p=self.resid_dropout, training=self.training) return x @@ -136,8 +139,18 @@ def __init__(self, model_config: ModelConfig): self.out_layer_norm = nn.LayerNorm(model_config.d_model) self.encode_layers = nn.ModuleList() - for _ in range(model_config.n_layers): - self.encode_layers.append(FusedEncoderBlock(model_config)) + + for layer_index in range(model_config.n_layers): + if model_config.resid_dropout > 0: + layer_dropout = model_config.resid_dropout * ( + layer_index / (model_config.n_layers - 1) + ) + else: + layer_dropout = 0.0 + + self.encode_layers.append( + FusedEncoderBlock(model_config, resid_dropout=layer_dropout) + ) def forward( self, From 91802dfc35d3fc7b96685317855820133199563a Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 3 Jun 2025 12:04:13 +0000 Subject: [PATCH 52/72] import fix --- aria/eval/linear_probe.py | 4 ++-- aria/eval/m3/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aria/eval/linear_probe.py b/aria/eval/linear_probe.py index 2db8c25..0766a74 100644 --- a/aria/eval/linear_probe.py +++ b/aria/eval/linear_probe.py @@ -176,7 +176,7 @@ def get_mert_embedding( hook_pianoteq_exec_path: str, hook_pianoteq_num_procs: int, ): - from aria.embeddings.mert.emb import ( + from aria.eval.mert.emb import ( seq_to_audio_path, compute_audio_embedding, ) @@ -210,7 +210,7 @@ def get_clamp3_embedding( hook_patchilizer, hook_tokenizer: AbsTokenizer, ): - from aria.embeddings.m3.emb import get_midi_embedding + from aria.eval.m3.emb import get_midi_embedding emb = [ get_midi_embedding( diff --git a/aria/eval/m3/utils.py b/aria/eval/m3/utils.py index 2713aaf..0b84964 100644 --- a/aria/eval/m3/utils.py +++ b/aria/eval/m3/utils.py @@ -3,7 +3,7 @@ import math import torch import random -from aria.embeddings.m3.config import * +from aria.eval.m3.config import * from unidecode import unidecode from torch.nn import functional as F from transformers import ( From f6890296f42d3782fce8c7ac38726fb943bf8844 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 3 Jun 2025 12:04:46 +0000 Subject: [PATCH 53/72] inference tree skeleton --- aria/inference/__init__.py | 1 - aria/inference/{model.py => model_cuda.py} | 0 aria/inference/model_mlx.py | 2 - aria/inference/sample_cuda.py | 429 +++++++++++++++++++++ aria/inference/sample_mlx.py | 2 - aria/run.py | 4 +- 6 files changed, 431 insertions(+), 7 deletions(-) rename aria/inference/{model.py => model_cuda.py} (100%) create mode 100644 
aria/inference/sample_cuda.py diff --git a/aria/inference/__init__.py b/aria/inference/__init__.py index f87bb30..e69de29 100644 --- a/aria/inference/__init__.py +++ b/aria/inference/__init__.py @@ -1 +0,0 @@ -from .model import TransformerLM diff --git a/aria/inference/model.py b/aria/inference/model_cuda.py similarity index 100% rename from aria/inference/model.py rename to aria/inference/model_cuda.py diff --git a/aria/inference/model_mlx.py b/aria/inference/model_mlx.py index 22a211e..cb80d85 100644 --- a/aria/inference/model_mlx.py +++ b/aria/inference/model_mlx.py @@ -6,7 +6,6 @@ import mlx.nn as nn -# TODO: Implement this with dynamic kv-size class KVCache(nn.Module): def __init__( self, @@ -229,7 +228,6 @@ def setup_cache( dtype=dtype, ) - # mx.bool isn't a thing? How do I do this in mlx, is it mx.bool_ ? self.model.causal_mask = mx.tril( mx.ones((max_seq_len, max_seq_len), dtype=mx.bool_) ) diff --git a/aria/inference/sample_cuda.py b/aria/inference/sample_cuda.py new file mode 100644 index 0000000..3fcedbf --- /dev/null +++ b/aria/inference/sample_cuda.py @@ -0,0 +1,429 @@ +"""Contains generation/sampling code""" + +import torch +import torch._dynamo.config +import torch._inductor.config + +from typing import List +from tqdm import tqdm + +from aria.inference.model_cuda import TransformerLM +from ariautils.tokenizer import Tokenizer, AbsTokenizer +from ariautils.midi import MidiDict + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True + + +def get_cfg_prompt(prompts: list): + cfg_prompts = [] + for prompt in prompts: + cfg_prompts.append(prompt) + cfg_prompts.append(prompt) + + return cfg_prompts + + +@torch.inference_mode() +def decode_one( + model: TransformerLM, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +) -> torch.Tensor: + assert input_pos.shape[-1] == 1 + + logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + )[:, -1] + + return logits + + +@torch.inference_mode() +def prefill( + model: TransformerLM, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +) -> torch.Tensor: + logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + )[:, -1] + + return logits + + +def update_seq_ids_( + seq: torch.Tensor, + idx: int, + next_token_ids: torch.Tensor, + dim_tok_inserted: list, + eos_tok_seen: list, + max_len: int, + force_end: bool, + tokenizer: Tokenizer, +): + # Insert dim and pad toks + for _idx in range(seq.shape[0]): + if eos_tok_seen[_idx] == True: + next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.pad_tok] + elif ( + force_end + and idx >= max_len - 130 + and dim_tok_inserted[_idx] is False + and tokenizer.id_to_tok[next_token_ids[_idx].item()][0] + not in ("dur", "onset") + ): + next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok] + + # Update dim_tok_inserted and eos_tok_seen + if next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: + dim_tok_inserted[_idx] = True + elif next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.eos_tok]: + eos_tok_seen[_idx] = True + + seq[:, idx] = next_token_ids + + +@torch.autocast( + "cuda", + dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, +) +@torch.inference_mode() +def sample_batch( + model: TransformerLM, + tokenizer: Tokenizer, + prompts: List[list], + max_new_tokens: int, + force_end=False, + temp: float = 0.95, + embedding: 
list[float] | None = None, + top_p: float | None = None, + min_p: float | None = None, + compile: bool = False, +): + assert top_p is not None or min_p is not None + if top_p is not None: + assert 0.5 <= top_p <= 1.0 + if min_p is not None: + assert 0.0 <= min_p <= 1.0 + if temp is not None: + assert 0.0 <= temp <= 2.0 + if force_end: + assert max_new_tokens > 130, "prompt too long to use force_end=True" + + prompt_len = len(prompts[0]) + num_prompts = len(prompts) + assert all([len(p) == prompt_len for p in prompts]) + + model.eval() + dim_tok_inserted = [False for _ in range(num_prompts)] + eos_tok_seen = [False for _ in range(num_prompts)] + total_len = prompt_len + max_new_tokens + seq = torch.stack( + [ + torch.tensor( + tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) + ) + for p in prompts + ] + ).cuda() + + if compile is True: + global decode_one + decode_one = torch.compile( + decode_one, + mode="reduce-overhead", + fullgraph=True, + ) + + model.setup_cache( + batch_size=num_prompts, + max_seq_len=total_len, + dtype=( + torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ), + ) + + if embedding: + condition_embedding = torch.tensor( + [embedding for _ in range(num_prompts)], device=seq.device + ) + model.fill_condition_kv(cond_emb=condition_embedding) + emb_offset = 1 + else: + emb_offset = 0 + + print( + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gen_len={max_new_tokens}" + ) + + for idx in ( + pbar := tqdm( + range(prompt_len, total_len), + total=total_len - prompt_len, + leave=False, + ) + ): + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + if idx == prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=torch.arange( + emb_offset, idx + emb_offset, device=seq.device + ), + ) + else: + logits = decode_one( + model, + idxs=seq[:, idx - 1 : idx], + input_pos=torch.tensor( + [(idx + emb_offset) - 1], + device=seq.device, + dtype=torch.int, + ), + ) + + if temp > 0.0: + probs = torch.softmax(logits / temp, dim=-1) + if min_p is not None: + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = sample_top_p(probs, top_p).flatten() + else: + next_token_ids = torch.argmax(logits, dim=-1).flatten() + + update_seq_ids_( + seq=seq, + idx=idx, + next_token_ids=next_token_ids, + dim_tok_inserted=dim_tok_inserted, + eos_tok_seen=eos_tok_seen, + max_len=total_len, + force_end=force_end, + tokenizer=tokenizer, + ) + + if all(seen_eos is True for seen_eos in eos_tok_seen): + break + + decoded_results = [tokenizer.decode(s) for s in seq.tolist()] + decoded_results = [ + ( + res[: res.index(tokenizer.eos_tok) + 1] + if tokenizer.eos_tok in res + else res + ) + for res in decoded_results + ] + + return decoded_results + + +# Not tested but I think this works +@torch.autocast( + "cuda", + dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, +) +@torch.inference_mode() +def sample_batch_cfg( + model: TransformerLM, + tokenizer: AbsTokenizer, + prompts: List[list], + max_new_tokens: int, + cfg_gamma: float, + embedding: list[float], + force_end=False, + temp: float = 0.95, + top_p: float | None = None, + min_p: float | None = None, + compile: bool = False, +): + assert 0.0 <= cfg_gamma <= 15.0 + assert top_p is not None or min_p is not None + if top_p is not None: + assert 0.5 <= top_p <= 1.0 + if temp is not None: + assert 0.0 <= temp <= 2.0 + if force_end: + assert max_new_tokens > 130, "prompt too long to use force_end=True" + + prompts = 
get_cfg_prompt(prompts) + + prompt_len = len(prompts[0]) + num_prompts = len(prompts) + assert all([len(p) == prompt_len for p in prompts]) + + model.eval() + total_context_len = prompt_len + max_new_tokens + seq = torch.stack( + [ + torch.tensor( + tokenizer.encode( + p + [tokenizer.pad_tok] * (total_context_len - len(p)) + ) + ) + for p in prompts + ] + ).cuda() + dim_tok_inserted = [False for _ in range(num_prompts)] + eos_tok_seen = [False for _ in range(num_prompts)] + + if compile is True: + global decode_one + decode_one = torch.compile( + decode_one, + mode="reduce-overhead", + fullgraph=True, + ) + + model.setup_cache( + batch_size=num_prompts, + max_seq_len=total_context_len, + dtype=( + torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ), + ) + + condition_embedding = torch.tensor( + [embedding for _ in range(num_prompts)], device=seq.device + ) + model.fill_condition_kv(cond_emb=condition_embedding) + embedding_offset = 1 + pad_idxs = torch.zeros_like(seq, dtype=torch.bool) + pad_idxs[1::2, 0] = True + + print( + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" + ) + + CFG_WARM_UP_STEPS = 250 + curr_step = 0 + for idx in ( + pbar := tqdm( + range(prompt_len, total_context_len), + total=total_context_len - prompt_len, + leave=False, + ) + ): + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + if idx == prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=torch.arange( + embedding_offset, + idx + embedding_offset, + device=seq.device, + ), + pad_idxs=pad_idxs, + ) + else: + logits = decode_one( + model, + idxs=seq[:, idx - 1 : idx], + input_pos=torch.tensor( + [(idx + embedding_offset) - 1], + device=seq.device, + dtype=torch.int, + ), + pad_idxs=pad_idxs, + ) + + curr_step += 1 + _cfg_gamma = min(cfg_gamma, (curr_step / CFG_WARM_UP_STEPS) * cfg_gamma) + + logits_cfg = _cfg_gamma * logits[::2] + (1 - _cfg_gamma) * logits[1::2] + logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + + if temp > 0.0: + probs = torch.softmax(logits_cfg / temp, dim=-1) + if min_p is not None: + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = sample_top_p(probs, top_p).flatten() + else: + next_token_ids = torch.argmax(logits_cfg, dim=-1).flatten() + + next_token_ids = next_token_ids.repeat_interleave(2) + update_seq_ids_( + seq=seq, + idx=idx, + next_token_ids=next_token_ids, + dim_tok_inserted=dim_tok_inserted, + eos_tok_seen=eos_tok_seen, + max_len=total_context_len, + force_end=force_end, + tokenizer=tokenizer, + ) + + if all(seen_eos is True for seen_eos in eos_tok_seen): + break + + decoded_results = [tokenizer.decode(s) for s in seq.tolist()][::2] + decoded_results = [ + ( + res[: res.index(tokenizer.eos_tok) + 1] + if tokenizer.eos_tok in res + else res + ) + for res in decoded_results + ] + + return decoded_results + + +# Working +def sample_min_p(probs, p_base): + """See - https://arxiv.org/pdf/2407.01082""" + p_max, _ = torch.max(probs, dim=-1, keepdim=True) + p_scaled = p_base * p_max + mask = probs >= p_scaled + + masked_probs = probs.clone() + masked_probs[~mask] = 0.0 + masked_probs.div_(masked_probs.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(masked_probs, num_samples=1) + + return next_token + + +def sample_top_p(probs, p): + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 
0.0 + + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + + return next_token + + +def get_inference_prompt( + midi_dict: MidiDict, tokenizer: AbsTokenizer, prompt_len_ms: int +): + midi_dict.note_msgs = [ + msg + for msg in midi_dict.note_msgs + if midi_dict.tick_to_ms(msg["data"]["start"]) <= prompt_len_ms + ] + + if len(midi_dict.note_msgs) == 0: + return [("prefix", "instrument", "piano"), tokenizer.bos_tok] + + seq = tokenizer.tokenize(midi_dict=midi_dict) + if tokenizer.dim_tok in seq: + seq.remove(tokenizer.dim_tok) + if tokenizer.eos_tok in seq: + seq.remove(tokenizer.eos_tok) + + return seq diff --git a/aria/inference/sample_mlx.py b/aria/inference/sample_mlx.py index 65a33ae..4dea923 100644 --- a/aria/inference/sample_mlx.py +++ b/aria/inference/sample_mlx.py @@ -182,8 +182,6 @@ def sample_min_p(probs: mx.array, p_base: float): # Added type hint p_scaled = p_base * p_max mask = probs >= p_scaled - print(mx.sum(mask).item()) - masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) masked_probs_normalized = masked_probs / sum_masked_probs diff --git a/aria/run.py b/aria/run.py index 079b1e8..5f68afa 100644 --- a/aria/run.py +++ b/aria/run.py @@ -79,7 +79,7 @@ def _get_embedding( from aria.model import ModelConfig from aria.config import load_model_config from aria.utils import _load_weight - from aria.embeddings.evaluate import ( + from aria.eval.linear_probe import ( get_aria_contrastive_embedding, process_entry, ) @@ -149,7 +149,7 @@ def sample(args): """Entrypoint for sampling""" from torch.cuda import is_available as cuda_is_available - from aria.inference import TransformerLM + from aria.inference.model_cuda import TransformerLM from aria.model import ModelConfig from aria.config import load_model_config from aria.sample import sample_batch, sample_batch_cfg, get_inference_prompt From 3491ae4f685918a0036e3e14470d1a721dd6572e Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 3 Jun 2025 12:07:44 +0000 Subject: [PATCH 54/72] fix tree --- aria/inference/sample_mlx.py | 2 +- aria/run.py | 6 +- aria/sample.py | 426 ----------------- aria/train.py | 903 ----------------------------------- 4 files changed, 6 insertions(+), 1331 deletions(-) delete mode 100644 aria/sample.py delete mode 100644 aria/train.py diff --git a/aria/inference/sample_mlx.py b/aria/inference/sample_mlx.py index 4dea923..2abeddc 100644 --- a/aria/inference/sample_mlx.py +++ b/aria/inference/sample_mlx.py @@ -205,7 +205,7 @@ def sample(): from ariautils.midi import MidiDict from ariautils.tokenizer import AbsTokenizer - from aria.sample import get_inference_prompt + from aria.inference.sample_cuda import get_inference_prompt CHECKPOINT_PATH = ( "/Users/louis/work/aria/models/medium-75-annealed.safetensors" diff --git a/aria/run.py b/aria/run.py index 5f68afa..06f11af 100644 --- a/aria/run.py +++ b/aria/run.py @@ -152,7 +152,11 @@ def sample(args): from aria.inference.model_cuda import TransformerLM from aria.model import ModelConfig from aria.config import load_model_config - from aria.sample import sample_batch, sample_batch_cfg, get_inference_prompt + from aria.inference.sample_cuda import ( + sample_batch, + sample_batch_cfg, + get_inference_prompt, + ) from aria.utils import _load_weight from ariautils.midi import MidiDict diff --git a/aria/sample.py b/aria/sample.py deleted file mode 100644 index 653505a..0000000 --- 
a/aria/sample.py +++ /dev/null @@ -1,426 +0,0 @@ -"""Contains generation/sampling code""" - -import torch -import torch._dynamo.config -import torch._inductor.config - -from typing import List -from tqdm import tqdm - -from aria.inference import TransformerLM -from ariautils.tokenizer import Tokenizer, AbsTokenizer -from ariautils.midi import MidiDict - -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True - - -def get_cfg_prompt(prompts: list): - cfg_prompts = [] - for prompt in prompts: - cfg_prompts.append(prompt) - cfg_prompts.append(prompt) - - return cfg_prompts - - -@torch.inference_mode() -def decode_one( - model: TransformerLM, - idxs: torch.Tensor, - input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, -): - assert input_pos.shape[-1] == 1 - - logits = model.forward( - idxs=idxs, - input_pos=input_pos, - pad_idxs=pad_idxs, - )[:, -1] - - return logits - - -@torch.inference_mode() -def prefill( - model: TransformerLM, - idxs: torch.Tensor, - input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, -): - logits = model.forward(idxs=idxs, input_pos=input_pos, pad_idxs=pad_idxs)[ - :, -1 - ] - - return logits - - -def update_seq_ids_( - seq: torch.Tensor, - idx: int, - next_token_ids: torch.Tensor, - dim_tok_inserted: list, - eos_tok_seen: list, - max_len: int, - force_end: bool, - tokenizer: Tokenizer, -): - # Insert dim and pad toks - for _idx in range(seq.shape[0]): - if eos_tok_seen[_idx] == True: - next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.pad_tok] - elif ( - force_end - and idx >= max_len - 130 - and dim_tok_inserted[_idx] is False - and tokenizer.id_to_tok[next_token_ids[_idx].item()][0] - not in ("dur", "onset") - ): - next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok] - - # Update dim_tok_inserted and eos_tok_seen - if next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: - dim_tok_inserted[_idx] = True - elif next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.eos_tok]: - eos_tok_seen[_idx] = True - - seq[:, idx] = next_token_ids - - -@torch.autocast( - "cuda", - dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, -) -@torch.inference_mode() -def sample_batch( - model: TransformerLM, - tokenizer: Tokenizer, - prompts: List[list], - max_new_tokens: int, - force_end=False, - temp: float = 0.95, - embedding: list[float] | None = None, - top_p: float | None = None, - min_p: float | None = None, - compile: bool = False, -): - assert top_p is not None or min_p is not None - if top_p is not None: - assert 0.5 <= top_p <= 1.0 - if min_p is not None: - assert 0.0 <= min_p <= 1.0 - if temp is not None: - assert 0.0 <= temp <= 2.0 - if force_end: - assert max_new_tokens > 130, "prompt too long to use force_end=True" - - prompt_len = len(prompts[0]) - num_prompts = len(prompts) - assert all([len(p) == prompt_len for p in prompts]) - - model.eval() - dim_tok_inserted = [False for _ in range(num_prompts)] - eos_tok_seen = [False for _ in range(num_prompts)] - total_len = prompt_len + max_new_tokens - seq = torch.stack( - [ - torch.tensor( - tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) - ) - for p in prompts - ] - ).cuda() - - if compile is True: - global decode_one - decode_one = torch.compile( - decode_one, - mode="reduce-overhead", - fullgraph=True, - ) - - model.setup_cache( - batch_size=num_prompts, - max_seq_len=total_len, - dtype=( - torch.bfloat16 if torch.cuda.is_bf16_supported() 
else torch.float16 - ), - ) - - if embedding: - condition_embedding = torch.tensor( - [embedding for _ in range(num_prompts)], device=seq.device - ) - model.fill_condition_kv(cond_emb=condition_embedding) - emb_offset = 1 - else: - emb_offset = 0 - - print( - f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gen_len={max_new_tokens}" - ) - - for idx in ( - pbar := tqdm( - range(prompt_len, total_len), - total=total_len - prompt_len, - leave=False, - ) - ): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - if idx == prompt_len: - logits = prefill( - model, - idxs=seq[:, :idx], - input_pos=torch.arange( - emb_offset, idx + emb_offset, device=seq.device - ), - ) - else: - logits = decode_one( - model, - idxs=seq[:, idx - 1 : idx], - input_pos=torch.tensor( - [(idx + emb_offset) - 1], - device=seq.device, - dtype=torch.int, - ), - ) - - if temp > 0.0: - probs = torch.softmax(logits / temp, dim=-1) - if min_p is not None: - next_token_ids = sample_min_p(probs, min_p).flatten() - else: - next_token_ids = sample_top_p(probs, top_p).flatten() - else: - next_token_ids = torch.argmax(logits, dim=-1).flatten() - - update_seq_ids_( - seq=seq, - idx=idx, - next_token_ids=next_token_ids, - dim_tok_inserted=dim_tok_inserted, - eos_tok_seen=eos_tok_seen, - max_len=total_len, - force_end=force_end, - tokenizer=tokenizer, - ) - - if all(seen_eos is True for seen_eos in eos_tok_seen): - break - - decoded_results = [tokenizer.decode(s) for s in seq.tolist()] - decoded_results = [ - ( - res[: res.index(tokenizer.eos_tok) + 1] - if tokenizer.eos_tok in res - else res - ) - for res in decoded_results - ] - - return decoded_results - - -# Not tested but I think this works -@torch.autocast( - "cuda", - dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, -) -@torch.inference_mode() -def sample_batch_cfg( - model: TransformerLM, - tokenizer: AbsTokenizer, - prompts: List[list], - max_new_tokens: int, - cfg_gamma: float, - embedding: list[float], - force_end=False, - temp: float = 0.95, - top_p: float | None = None, - min_p: float | None = None, - compile: bool = False, -): - assert 0.0 <= cfg_gamma <= 15.0 - assert top_p is not None or min_p is not None - if top_p is not None: - assert 0.5 <= top_p <= 1.0 - if temp is not None: - assert 0.0 <= temp <= 2.0 - if force_end: - assert max_new_tokens > 130, "prompt too long to use force_end=True" - - prompts = get_cfg_prompt(prompts) - - prompt_len = len(prompts[0]) - num_prompts = len(prompts) - assert all([len(p) == prompt_len for p in prompts]) - - model.eval() - total_context_len = prompt_len + max_new_tokens - seq = torch.stack( - [ - torch.tensor( - tokenizer.encode( - p + [tokenizer.pad_tok] * (total_context_len - len(p)) - ) - ) - for p in prompts - ] - ).cuda() - dim_tok_inserted = [False for _ in range(num_prompts)] - eos_tok_seen = [False for _ in range(num_prompts)] - - if compile is True: - global decode_one - decode_one = torch.compile( - decode_one, - mode="reduce-overhead", - fullgraph=True, - ) - - model.setup_cache( - batch_size=num_prompts, - max_seq_len=total_context_len, - dtype=( - torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - ), - ) - - condition_embedding = torch.tensor( - [embedding for _ in range(num_prompts)], device=seq.device - ) - model.fill_condition_kv(cond_emb=condition_embedding) - embedding_offset = 1 - pad_idxs = torch.zeros_like(seq, dtype=torch.bool) - pad_idxs[1::2, 0] = True - - print( - f"Using hyperparams: temp={temp}, top_p={top_p}, 
min_p={min_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" - ) - - CFG_WARM_UP_STEPS = 250 - curr_step = 0 - for idx in ( - pbar := tqdm( - range(prompt_len, total_context_len), - total=total_context_len - prompt_len, - leave=False, - ) - ): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - if idx == prompt_len: - logits = prefill( - model, - idxs=seq[:, :idx], - input_pos=torch.arange( - embedding_offset, - idx + embedding_offset, - device=seq.device, - ), - pad_idxs=pad_idxs, - ) - else: - logits = decode_one( - model, - idxs=seq[:, idx - 1 : idx], - input_pos=torch.tensor( - [(idx + embedding_offset) - 1], - device=seq.device, - dtype=torch.int, - ), - pad_idxs=pad_idxs, - ) - - curr_step += 1 - _cfg_gamma = min(cfg_gamma, (curr_step / CFG_WARM_UP_STEPS) * cfg_gamma) - - logits_cfg = _cfg_gamma * logits[::2] + (1 - _cfg_gamma) * logits[1::2] - logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") - - if temp > 0.0: - probs = torch.softmax(logits_cfg / temp, dim=-1) - if min_p is not None: - next_token_ids = sample_min_p(probs, min_p).flatten() - else: - next_token_ids = sample_top_p(probs, top_p).flatten() - else: - next_token_ids = torch.argmax(logits_cfg, dim=-1).flatten() - - next_token_ids = next_token_ids.repeat_interleave(2) - update_seq_ids_( - seq=seq, - idx=idx, - next_token_ids=next_token_ids, - dim_tok_inserted=dim_tok_inserted, - eos_tok_seen=eos_tok_seen, - max_len=total_context_len, - force_end=force_end, - tokenizer=tokenizer, - ) - - if all(seen_eos is True for seen_eos in eos_tok_seen): - break - - decoded_results = [tokenizer.decode(s) for s in seq.tolist()][::2] - decoded_results = [ - ( - res[: res.index(tokenizer.eos_tok) + 1] - if tokenizer.eos_tok in res - else res - ) - for res in decoded_results - ] - - return decoded_results - - -def sample_min_p(probs, p_base): - """See - https://arxiv.org/pdf/2407.01082""" - p_max, _ = torch.max(probs, dim=-1, keepdim=True) - p_scaled = p_base * p_max - mask = probs >= p_scaled - - masked_probs = probs.clone() - masked_probs[~mask] = 0.0 - masked_probs.div_(masked_probs.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(masked_probs, num_samples=1) - - return next_token - - -def sample_top_p(probs, p): - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - - return next_token - - -def get_inference_prompt( - midi_dict: MidiDict, tokenizer: AbsTokenizer, prompt_len_ms: int -): - midi_dict.note_msgs = [ - msg - for msg in midi_dict.note_msgs - if midi_dict.tick_to_ms(msg["data"]["start"]) <= prompt_len_ms - ] - - if len(midi_dict.note_msgs) == 0: - return [("prefix", "instrument", "piano"), tokenizer.bos_tok] - - seq = tokenizer.tokenize(midi_dict=midi_dict) - if tokenizer.dim_tok in seq: - seq.remove(tokenizer.dim_tok) - if tokenizer.eos_tok in seq: - seq.remove(tokenizer.eos_tok) - - return seq diff --git a/aria/train.py b/aria/train.py deleted file mode 100644 index 222eab1..0000000 --- a/aria/train.py +++ /dev/null @@ -1,903 +0,0 @@ -import os -import sys -import csv -import argparse -import logging -import random -import torch -import accelerate - -from torch import nn as nn -from torch.utils.data import DataLoader - -from accelerate.logging import get_logger -from 
safetensors.torch import load_file -from logging.handlers import RotatingFileHandler -from tqdm import tqdm -from typing import List - -from aria.config import load_model_config -from aria.model import ModelConfig, TransformerLM, TransformerLM_CND -from ariautils.tokenizer import Tokenizer, AbsTokenizer, RelTokenizer -from aria.datasets import ( - TrainingDataset, - PretrainingDataset, -) -from aria.utils import _load_weight - -torch._dynamo.config.optimize_ddp = False - - -# ----- USAGE ----- -# -# This script is meant to be run using the huggingface accelerate cli, see: -# -# https://huggingface.co/docs/accelerate/basic_tutorials/launch -# https://huggingface.co/docs/accelerate/package_reference/cli -# -# For example usage you could run the pre-training script with: -# -# accelerate launch [arguments] aria/train.py train \ -# small \ -# -train_data data/train \ -# -val_data data/val \ -# -epochs 10 \ -# -bs 32 \ -# -workers 8 -# -# You could resume a run from an accelerate checkpoint with: -# -# accelerate launch [arguments] aria/train.py resume \ -# small \ -# -train_data data/train \ -# -val_data data/val \ -# -cp_dir models/epoch5_step0 \ -# -r_step 0 \ -# -r_epoch 5 \ -# -epochs 5 \ -# -bs 32 \ -# -workers 8 - - -def setup_logger(project_dir: str): - # Get logger and reset all handlers - logger = logging.getLogger(__name__) - for h in logger.handlers[:]: - logger.removeHandler(h) - - logger.propagate = False - logger.setLevel(logging.DEBUG) - formatter = logging.Formatter( - "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", - ) - - fh = RotatingFileHandler( - os.path.join(project_dir, "logs.txt"), backupCount=5, maxBytes=1024**3 - ) - fh.setLevel(logging.DEBUG) - fh.setFormatter(formatter) - logger.addHandler(fh) - - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - ch.setFormatter(formatter) - logger.addHandler(ch) - - return get_logger(__name__) # using accelerate.logging.get_logger() - - -def get_tokenizer_name( - train_data_paths: str, - val_data_path: str, -): - """This will throw an error if there is a tokenizer mismatch""" - train_config = TrainingDataset.get_config_from_path(train_data_paths[0]) - val_config = TrainingDataset.get_config_from_path(val_data_path) - - assert ( - train_config["tokenizer_name"] == val_config["tokenizer_name"] - ), "Dataset tokenizers don't match" - - return train_config["tokenizer_name"] - - -def setup_project_dir(project_dir: str | None): - if not project_dir: - # Create project directory - if not os.path.isdir("./experiments"): - os.mkdir("./experiments") - - project_dirs = [ - _dir - for _dir in os.listdir("./experiments") - if os.path.isdir(os.path.join("experiments", _dir)) - ] - - ind = 0 - while True: - if str(ind) not in project_dirs: - break - else: - ind += 1 - - project_dir_abs = os.path.abspath(os.path.join("experiments", str(ind))) - assert not os.path.isdir(project_dir_abs) - os.mkdir(project_dir_abs) - - elif project_dir: - # Run checks on project directory - if os.path.isdir(project_dir): - assert ( - len(os.listdir(project_dir)) == 0 - ), "Provided project directory is not empty" - project_dir_abs = os.path.abspath(project_dir) - elif os.path.isfile(project_dir): - raise FileExistsError( - "The provided path points toward an existing file" - ) - else: - try: - os.mkdir(project_dir) - except Exception as e: - raise e(f"Failed to create project directory at {project_dir}") - project_dir_abs = os.path.abspath(project_dir) - - os.mkdir(os.path.join(project_dir_abs, "checkpoints")) - - return project_dir_abs - - -def 
_get_optim( - lr: float, - model: nn.Module, - num_epochs: int, - steps_per_epoch: int, - warmup: int = 100, - end_ratio: int = 0.1, -): - optimizer = torch.optim.AdamW( - model.parameters(), - lr=lr, - weight_decay=0.1, - betas=(0.9, 0.95), - eps=1e-5, - ) - - warmup_lrs = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=0.000001, - end_factor=1, - total_iters=warmup, - ) - linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=1, - end_factor=end_ratio, - total_iters=(num_epochs * steps_per_epoch) - warmup, - ) - - lr_scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, - schedulers=[warmup_lrs, linear_decay_lrs], - milestones=[warmup], - ) - - return optimizer, lr_scheduler - - -def get_optim( - model: nn.Module, - num_epochs: int, - steps_per_epoch: int, -): - LR = 3e-4 - END_RATIO = 0.1 - WARMUP_STEPS = 200 - - return _get_optim( - lr=LR, - model=model, - num_epochs=num_epochs, - steps_per_epoch=steps_per_epoch, - warmup=WARMUP_STEPS, - end_ratio=END_RATIO, - ) - - -def get_dataloaders( - train_data_dirs: List[str], - val_data_dir: str, - tokenizer: Tokenizer, - batch_size: int, - num_workers: int, - use_embeddings: bool, - init_epoch: int | None = None, - apply_aug: bool = True, -): - train_dataset = PretrainingDataset( - dir_paths=train_data_dirs, - tokenizer=tokenizer, - ) - val_dataset = PretrainingDataset( - dir_paths=val_data_dir, - tokenizer=tokenizer, - ) - - if init_epoch: - train_dataset.init_epoch(idx=init_epoch) - - assert ( - len(val_dataset.epoch_files_by_dir[0]) == 1 - ), "val-data directory should only contain one epoch" - - if apply_aug: - train_dataset.set_transform(tokenizer.export_data_aug()) - - train_dataloader = DataLoader( - train_dataset, - batch_size=batch_size, - num_workers=num_workers, - shuffle=True, - ) - val_dataloader = DataLoader( - val_dataset, - batch_size=batch_size, - num_workers=num_workers, - shuffle=False, - ) - - if use_embeddings is True: - _src, _tgt, _mask, _emb = train_dataset[0] - _src, _tgt, _mask, __emb = val_dataset[0] - assert _emb.numel() != 0, "Embeddings not present in train dataset" - assert __emb.numel() != 0, "Embeddings not present in val dataset" - - return train_dataloader, val_dataloader - - -def _train( - epochs: int, - accelerator: accelerate.Accelerator, - model: TransformerLM, - train_dataloader: DataLoader, - val_dataloader: DataLoader, - use_embeddings: bool, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler = None, - steps_per_checkpoint: int | None = None, - resume_step: int | None = None, - resume_epoch: int | None = None, - project_dir: str | None = None, -): - def make_checkpoint( - _accelerator: accelerate.Accelerator, _epoch: int, _step: int - ): - if accelerator.is_main_process: - checkpoint_dir = os.path.join( - project_dir, - "checkpoints", - f"epoch{_epoch}_step{_step}", - ) - - logger.info( - f"EPOCH {_epoch}/{epochs + start_epoch}: Saving checkpoint - {checkpoint_dir}" - ) - _accelerator.save_state(checkpoint_dir) - - # This is all slightly messy as train_loop and val_loop make use of the - # variables in the wider scope. Perhaps refactor this at some point. 
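The `_get_optim` helper earlier in this file chains a linear warmup into a linear decay with `torch.optim.lr_scheduler.SequentialLR`. A minimal, self-contained sketch of the trajectory that composition produces, using assumed toy values (10 warmup steps, 90 decay steps, peak lr 3e-4) rather than the script's real arguments:

    # Standalone sketch of the warmup -> linear-decay schedule built by
    # _get_optim above. The model, step counts, and peak lr are assumed
    # toy values, not the script's real configuration.
    import torch

    model = torch.nn.Linear(8, 8)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

    warmup = torch.optim.lr_scheduler.LinearLR(
        opt, start_factor=1e-6, end_factor=1.0, total_iters=10
    )
    decay = torch.optim.lr_scheduler.LinearLR(
        opt, start_factor=1.0, end_factor=0.1, total_iters=90
    )
    sched = torch.optim.lr_scheduler.SequentialLR(
        opt, schedulers=[warmup, decay], milestones=[10]
    )

    for step in range(100):
        opt.step()
        sched.step()
    # lr ramps from ~0 up to 3e-4 over the first 10 steps, then decays
    # linearly to 3e-5 over the remaining 90.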
- def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0): - loss = torch.tensor([0.0]) - avg_train_loss = 0 - trailing_loss = 0 - loss_buffer = [] - - try: - lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) - except Exception: - pass - else: - lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) - - model.train() - for __step, batch in ( - pbar := tqdm( - enumerate(dataloader), - total=len(dataloader) + _resume_step, - initial=_resume_step, - leave=False, - ) - ): - pbar.set_postfix_str( - f"lr={lr_for_print}, " - f"loss={round(loss.item(), 4)}, " - f"trailing={round(trailing_loss, 4)}" - ) - - with accelerator.accumulate(model): - step = __step + _resume_step + 1 - src, tgt, mask, emb = ( - batch # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb) - ) - - use_embeddings_cond = use_embeddings and (random.random() > 0.5) - - if use_embeddings_cond is True: - logits = model(src=src, emb=emb) # (b_sz, s_len - 1, v_sz) - tgt = tgt[:, :-1] # (b_sz, s_len - 1) - mask = mask[:, :-1] # (b_sz, s_len - 1) - else: - logits = model(src) # (b_sz, s_len, v_sz) - - logits = logits.transpose( - 1, 2 - ) # Transpose for CrossEntropyLoss - loss = loss_fn(logits, tgt) - - if mask.sum() == 0: - loss = (loss * 0).sum() - else: - loss = loss * mask - loss = loss[loss != 0.0].mean() - - # Calculate statistics - loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) - trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( - loss_buffer[-TRAILING_LOSS_STEPS:] - ) - avg_train_loss = sum(loss_buffer) / len(loss_buffer) - - # Logging - logger.debug( - f"EPOCH {_epoch} STEP {step}: " - f"lr={lr_for_print}, " - f"loss={round(loss.item(), 4)}, " - f"trailing_loss={round(trailing_loss, 4)}, " - f"average_loss={round(avg_train_loss, 4)}" - ) - - if accelerator.is_main_process: - loss_writer.writerow([_epoch, step, loss.item()]) - - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - if scheduler: - scheduler.step() - lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) - - if steps_per_checkpoint: - if step % steps_per_checkpoint == 0: - make_checkpoint( - _accelerator=accelerator, - _epoch=_epoch, - _step=step, - ) - - logger.info( - f"EPOCH {_epoch}/{epochs + start_epoch}: Finished training - " - f"average_loss={round(avg_train_loss, 4)}" - ) - - return avg_train_loss - - @torch.no_grad() - def val_loop(dataloader, _epoch: int): - loss_buffer = [] - model.eval() - for step, batch in ( - pbar := tqdm( - enumerate(dataloader), - total=len(dataloader), - leave=False, - ) - ): - src, tgt, mask, emb = ( - batch # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb) - ) - use_embeddings_cond = use_embeddings and (random.random() > 0.5) - - if use_embeddings_cond is True: - logits = model(src=src, emb=emb) # (b_sz, s_len - 1, v_sz) - tgt = tgt[:, :-1] # (b_sz, s_len - 1) - mask = mask[:, :-1] # (b_sz, s_len - 1) - else: - logits = model(src) # (b_sz, s_len, v_sz) - - logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss - loss = loss_fn(logits, tgt) - - if mask.sum() == 0: - loss = (loss * 0).sum() - else: - loss = loss * mask - loss = loss[loss != 0.0].mean() - - # Logging - loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) - avg_val_loss = sum(loss_buffer) / len(loss_buffer) - pbar.set_postfix_str(f"average_loss={round(avg_val_loss, 4)}") - - # EPOCH - logger.info( - f"EPOCH {_epoch}/{epochs + start_epoch}: Finished evaluation - " - f"average_loss={round(avg_val_loss, 4)}" - ) - - return avg_val_loss - 
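Both loops above share the same masked cross-entropy pattern: per-token losses are computed with reduction="none", zeroed at positions excluded by the mask, and averaged over what remains. A minimal sketch with assumed toy shapes (the real code additionally uses ignore_index=PAD_ID and guards the all-masked case, mask.sum() == 0, to avoid a NaN mean):

    # Sketch of the masked loss used by train_loop/val_loop above. Shapes
    # and values are assumed toy inputs, not taken from the patch.
    import torch

    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    logits = torch.randn(2, 32, 5)      # (b_sz, v_sz, s_len), post-transpose
    tgt = torch.randint(0, 32, (2, 5))  # (b_sz, s_len)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 0, 0, 0]])  # 1 = token counts toward loss

    loss = loss_fn(logits, tgt)      # (b_sz, s_len) per-token losses
    loss = loss * mask               # zero out ignored positions
    loss = loss[loss != 0.0].mean()  # mean over the unmasked positions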
- if steps_per_checkpoint: - assert ( - steps_per_checkpoint > 1 - ), "Invalid checkpoint mode value (too small)" - - TRAILING_LOSS_STEPS = 200 - PAD_ID = train_dataloader.dataset.tokenizer.pad_id - logger = get_logger(__name__) # Accelerate logger - loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID, reduction="none") - - logger.info( - f"Model has " - f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " - "parameters" - ) - - if accelerator.is_main_process: - loss_csv = open(os.path.join(project_dir, "loss.csv"), "w") - loss_writer = csv.writer(loss_csv) - loss_writer.writerow(["epoch", "step", "loss"]) - epoch_csv = open(os.path.join(project_dir, "epoch.csv"), "w") - epoch_writer = csv.writer(epoch_csv) - epoch_writer.writerow(["epoch", "avg_train_loss", "avg_val_loss"]) - - if resume_epoch is not None: - start_epoch = resume_epoch + 1 - else: - start_epoch = 0 - - if resume_step is not None: - assert resume_epoch is not None, "Must provide resume epoch" - logger.info( - f"Resuming training from step {resume_step} - logging as EPOCH {resume_epoch}" - ) - skipped_dataloader = accelerator.skip_first_batches( - dataloader=train_dataloader, - num_batches=resume_step, - ) - - avg_train_loss = train_loop( - dataloader=skipped_dataloader, - _epoch=resume_epoch, - _resume_step=resume_step, - ) - avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=resume_epoch) - if accelerator.is_main_process: - epoch_writer.writerow([resume_epoch, avg_train_loss, avg_val_loss]) - epoch_csv.flush() - make_checkpoint( - _accelerator=accelerator, _epoch=start_epoch, _step=0 - ) - - for epoch in range(start_epoch, epochs + start_epoch): - train_dataloader.dataset.init_epoch(epoch) - avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch) - avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=epoch) - if accelerator.is_main_process: - epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss]) - epoch_csv.flush() - make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0) - - logging.shutdown() - if accelerator.is_main_process: - loss_csv.close() - epoch_csv.close() - - -# TODO: Add use_embeddings logic to this code path -def resume_train( - model_name: str, - train_data_paths: str, - val_data_path: str, - use_embeddings: bool, - num_workers: int, - batch_size: int, - grad_acc_steps: int, - epochs: int, - checkpoint_dir: str, - resume_epoch: int, - resume_step: int, - steps_per_checkpoint: int | None = None, - project_dir: str = None, -): - # Validate inputs - assert 0 < num_workers <= 128, "Too many workers" - assert epochs > 0, "Invalid number of epochs" - assert batch_size > 0, "Invalid batch size" - assert torch.cuda.is_available() is True, "CUDA not available" - assert os.path.isdir(checkpoint_dir), f"No dir at {checkpoint_dir}" - for train_data_path in train_data_paths: - assert os.path.isdir( - train_data_path - ), f"No dir found at {train_data_path}" - assert os.path.isdir(val_data_path), f"No dir found at {val_data_path}" - - tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) - if tokenizer_name == "abs": - tokenizer = AbsTokenizer() - elif tokenizer_name == "rel": - tokenizer = RelTokenizer() - else: - raise Exception("Invalid tokenizer name") - - accelerator = accelerate.Accelerator( - project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps - ) - if accelerator.is_main_process: - project_dir = setup_project_dir(project_dir) - logger = setup_logger(project_dir) - - logger = get_logger(__name__) - 
logger.info(f"Using project directory {project_dir} ") - logger.warning( - "Please insure that the training config and resume step are set " - "correctly, the script does not currently check that this is the case. " - "If the previous checkpoint was saved at step n, then resume_step " - "should be n. If there is a mismatch between the batch size then the " - "script will resume at the wrong step. It is also important that the " - "same distributed setup is used for training." - ) - logger.info( - f"Using training config: " - f"model_name={model_name}, " - f"use_embeddings={use_embeddings}, " - f"epochs={epochs}, " - f"batch_size={batch_size}, " - f"grad_acc_steps={grad_acc_steps}, " - f"num_workers={num_workers}, " - f"checkpoint_dir={checkpoint_dir}, " - f"resume_step={resume_step}, " - f"resume_epoch={resume_epoch}" - ) - - if steps_per_checkpoint: - logger.info(f"Creating checkpoints every {steps_per_checkpoint}") - - # Init model - model_config = ModelConfig(**load_model_config(model_name)) - model_config.set_vocab_size(tokenizer.vocab_size) - - if use_embeddings: - model = TransformerLM_CND(model_config) - else: - model = TransformerLM(model_config) - - model.compile() - - train_dataloader, val_dataloader = get_dataloaders( - train_data_dirs=train_data_paths, - val_data_dir=val_data_path, - tokenizer=tokenizer, - init_epoch=resume_epoch, - batch_size=batch_size, - num_workers=num_workers, - apply_aug=True, - use_embeddings=use_embeddings, - ) - optimizer, scheduler = get_optim( - model, - num_epochs=epochs, - steps_per_epoch=len(train_dataloader), - ) - - ( - model, - train_dataloader, - val_dataloader, - optimizer, - scheduler, - ) = accelerator.prepare( - model, - train_dataloader, - val_dataloader, - optimizer, - scheduler, - ) - - try: - accelerator.load_state(checkpoint_dir) - except Exception as e: - raise Exception( - f"Failed to load checkpoint: {e}\n" - "This could be due to a mismatch between the tokenizer used " - "to build the pre-training and fine-tuning datasets" - ) - logger.info(f"Loaded checkpoint at {checkpoint_dir}") - logger.info("Starting train job") - - _train( - epochs=epochs, - accelerator=accelerator, - model=model, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - use_embeddings=use_embeddings, - optimizer=optimizer, - scheduler=scheduler, - steps_per_checkpoint=steps_per_checkpoint, - resume_step=resume_step, - resume_epoch=resume_epoch, - project_dir=project_dir, - ) - - -def train( - model_name: str, - train_data_paths: List[str], - val_data_path: str, - use_embeddings: bool, - num_workers: int, - batch_size: int, - grad_acc_steps: int, - epochs: int, - checkpoint_path: str | None = None, - steps_per_checkpoint: int | None = None, - project_dir: str = None, -): - # Validate inputs - assert 0 < num_workers <= 128, "Too many workers" - assert epochs > 0, "Invalid number of epochs" - assert batch_size > 0, "Invalid batch size" - assert torch.cuda.is_available() is True, "CUDA not available" - for train_data_path in train_data_paths: - assert os.path.isdir( - train_data_path - ), f"No dir found at {train_data_path}" - assert os.path.isdir(val_data_path), f"No dir found at {val_data_path}" - - tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) - if tokenizer_name == "abs": - tokenizer = AbsTokenizer() - elif tokenizer_name == "rel": - tokenizer = RelTokenizer() - else: - raise Exception("Invalid tokenizer name") - - accelerator = accelerate.Accelerator( - project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps 
- ) - if accelerator.is_main_process: - project_dir = setup_project_dir(project_dir) - logger = setup_logger(project_dir) - - logger = get_logger(__name__) - logger.info(f"Using project directory {project_dir}") - logger.info( - f"Using training config: " - f"model_name={model_name}, " - f"use_embeddings={use_embeddings}, " - f"checkpoint_path={checkpoint_path}, " - if checkpoint_path - else "" - f"epochs={epochs}, " - f"batch_size={batch_size}, " - f"grad_acc_steps={grad_acc_steps}, " - f"num_workers={num_workers}" - ) - - if steps_per_checkpoint: - logger.info(f"Creating checkpoints every {steps_per_checkpoint}") - - # Init model - model_config = ModelConfig(**load_model_config(model_name)) - model_config.set_vocab_size(tokenizer.vocab_size) - - if use_embeddings is True: - model = TransformerLM_CND(model_config) - else: - model = TransformerLM(model_config) - - model.compile() - logger.info(f"Loaded model with config: {load_model_config(model_name)}") - if checkpoint_path: - try: - model.load_state_dict(_load_weight(checkpoint_path)) - except RuntimeError as e: - print(e) - logger.info( - f"Failed to load {model_name} into {model_name}, attempting with strict=False" - ) - model.load_state_dict(_load_weight(checkpoint_path), strict=False) - - logger.info(f"Loaded finetune checkpoint located at: {checkpoint_path}") - - train_dataloader, val_dataloader = get_dataloaders( - train_data_dirs=train_data_paths, - val_data_dir=val_data_path, - tokenizer=tokenizer, - batch_size=batch_size, - num_workers=num_workers, - apply_aug=False, - use_embeddings=use_embeddings, - ) - - assert ( - train_dataloader.dataset.config["max_seq_len"] - == model_config.max_seq_len - ) - assert ( - val_dataloader.dataset.config["max_seq_len"] == model_config.max_seq_len - ) - - optimizer, scheduler = get_optim( - model, - num_epochs=epochs, - steps_per_epoch=len(train_dataloader), - ) - - ( - model, - train_dataloader, - val_dataloader, - optimizer, - scheduler, - ) = accelerator.prepare( - model, - train_dataloader, - val_dataloader, - optimizer, - scheduler, - ) - - logger.info(f"Starting {'finetune' if checkpoint_path else 'pretrain'} job") - _train( - epochs=epochs, - accelerator=accelerator, - model=model, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - use_embeddings=use_embeddings, - optimizer=optimizer, - scheduler=scheduler, - steps_per_checkpoint=steps_per_checkpoint, - project_dir=project_dir, - ) - - -def convert_cp_from_safetensors(checkpoint_path: str, save_path: str): - d = load_file(checkpoint_path) - key = list(d.keys())[0] - gap = len(key.split(".")[0]) - d = {s[gap + 1 :]: v for s, v in d.items()} - torch.save(d, save_path) - - -def convert_cp_from_accelerate( - model_name: str, tokenizer_name: str, checkpoint_dir: str, save_path: str -): - def _load_state_dict(_tokenizer: Tokenizer): - model_config = ModelConfig(**load_model_config(model_name)) - model_config.set_vocab_size(_tokenizer.vocab_size) - model = TransformerLM(model_config) - model = accelerator.prepare(model) - accelerator.load_state(checkpoint_dir) - - return model.state_dict() - - accelerator = accelerate.Accelerator() - - # Try both - if tokenizer_name == "abs": - state_dict = _load_state_dict(_tokenizer=AbsTokenizer()) - elif tokenizer_name == "rel": - state_dict = _load_state_dict(_tokenizer=RelTokenizer()) - else: - print("Invalid choice of tokenizer") - - torch.save(state_dict, save_path) - - -def parse_resume_args(): - argp = argparse.ArgumentParser(prog="python aria/train.py resume") - 
argp.add_argument("model", help="name of model config file") - argp.add_argument("-train_data", nargs="+", help="path to train dir") - argp.add_argument("-val_data", help="path to val dir") - argp.add_argument("-cp_dir", help="checkpoint dir", type=str, required=True) - argp.add_argument( - "-use_embeddings", help="prepend embeddings", action="store_true" - ) - argp.add_argument("-r_step", help="resume step", type=int, required=True) - argp.add_argument("-r_epoch", help="resume epoch", type=int, required=True) - argp.add_argument("-epochs", help="train epochs", type=int, required=True) - argp.add_argument("-bs", help="batch size", type=int, default=32) - argp.add_argument( - "-grad_acc_steps", - help="gradient accumulation steps", - type=int, - default=1, - ) - argp.add_argument("-workers", help="number workers", type=int, default=1) - argp.add_argument("-pdir", help="project dir", type=str, required=False) - argp.add_argument( - "-spc", help="steps per checkpoint", type=int, required=False - ) - - return argp.parse_args(sys.argv[2:]) - - -def parse_train_args(): - argp = argparse.ArgumentParser(prog="python aria/train.py train") - argp.add_argument("model", help="name of model config file") - argp.add_argument("-train_data", nargs="+", help="path to train dir") - argp.add_argument("-val_data", help="path to val dir") - argp.add_argument( - "-cp_path", help="path to checkpoint", required=False, default=None - ) - argp.add_argument( - "-use_embeddings", help="prepend embeddings", action="store_true" - ) - argp.add_argument("-epochs", help="train epochs", type=int, required=True) - argp.add_argument("-bs", help="batch size", type=int, default=32) - argp.add_argument( - "-grad_acc_steps", - help="gradient accumulation steps", - type=int, - default=1, - ) - argp.add_argument("-workers", help="number workers", type=int, default=1) - argp.add_argument("-pdir", help="project dir", type=str, required=False) - argp.add_argument( - "-spc", help="steps per checkpoint", type=int, required=False - ) - - return argp.parse_args(sys.argv[2:]) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - usage="python aria/train.py []" - ) - parser.add_argument( - "mode", help="training function", choices=("train", "resume") - ) - - args = parser.parse_args(sys.argv[1:2]) - if not hasattr(args, "mode"): - parser.print_help() - print("Unrecognized command") - exit(1) - elif args.mode == "train": - train_args = parse_train_args() - train( - model_name=train_args.model, - train_data_paths=train_args.train_data, - use_embeddings=train_args.use_embeddings, - val_data_path=train_args.val_data, - num_workers=train_args.workers, - batch_size=train_args.bs, - grad_acc_steps=train_args.grad_acc_steps, - epochs=train_args.epochs, - checkpoint_path=train_args.cp_path, - steps_per_checkpoint=train_args.spc, - project_dir=train_args.pdir, - ) - elif args.mode == "resume": - resume_args = parse_resume_args() - resume_train( - model_name=resume_args.model, - train_data_paths=resume_args.train_data, - val_data_path=resume_args.val_data, - use_embeddings=resume_args.use_embeddings, - num_workers=resume_args.workers, - batch_size=resume_args.bs, - grad_acc_steps=resume_args.grad_acc_steps, - epochs=resume_args.epochs, - checkpoint_dir=resume_args.cp_dir, - resume_step=resume_args.r_step, - resume_epoch=resume_args.r_epoch, - steps_per_checkpoint=resume_args.spc, - project_dir=resume_args.pdir, - ) - else: - print("Unrecognized command") - parser.print_help() - exit(1) From 3614a8ba38e46d6d40c2fe66cb8e19168f622869 Mon 
Sep 17 00:00:00 2001 From: Louis Date: Tue, 3 Jun 2025 12:09:06 +0000 Subject: [PATCH 55/72] rm scripts --- scripts/download_data.sh | 3 --- scripts/midi_to_audio.py | 20 -------------------- scripts/upload_data.sh | 2 -- 3 files changed, 25 deletions(-) delete mode 100644 scripts/download_data.sh delete mode 100644 scripts/midi_to_audio.py delete mode 100644 scripts/upload_data.sh diff --git a/scripts/download_data.sh b/scripts/download_data.sh deleted file mode 100644 index cba0165..0000000 --- a/scripts/download_data.sh +++ /dev/null @@ -1,3 +0,0 @@ -mkdir data -gsutil cp gs://gpt-aria/train_data/train.jsonl data/train.jsonl -gsutil cp gs://gpt-aria/train_data/val.jsonl data/val.jsonl \ No newline at end of file diff --git a/scripts/midi_to_audio.py b/scripts/midi_to_audio.py deleted file mode 100644 index 8221383..0000000 --- a/scripts/midi_to_audio.py +++ /dev/null @@ -1,20 +0,0 @@ -import os - -from aria.utils import midi_to_audio - - -def main(): - root_dir = "/Users/louis/work/data/mid/prompts/survey" - for dirpath, dirnames, filenames in os.walk(root_dir): - for filename in filenames: - if filename.endswith(".mid"): - midi_path = os.path.join(dirpath, filename) - midi_to_audio(midi_path) - - -if __name__ == "__main__": - main() - - -if __name__ == "__main__": - main() diff --git a/scripts/upload_data.sh b/scripts/upload_data.sh deleted file mode 100644 index c037d37..0000000 --- a/scripts/upload_data.sh +++ /dev/null @@ -1,2 +0,0 @@ -gsutil cp data/train.jsonl gs://gpt-aria/train_data/train.jsonl -gsutil cp data/val.jsonl gs://gpt-aria/train_data/val.jsonl \ No newline at end of file From 97e2a5c4758dddc763705a8ccb1f494e2b0f4584 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 3 Jun 2025 20:00:04 +0000 Subject: [PATCH 56/72] refactor entrypoint for generate --- aria/inference/__init__.py | 58 ++++++ aria/inference/model_cuda.py | 231 +++++++++++---------- aria/inference/model_mlx.py | 35 +++- aria/inference/sample_cuda.py | 151 ++++---------- aria/inference/sample_mlx.py | 134 ++++-------- aria/run.py | 381 ++++++++++++++++++---------------- 6 files changed, 492 insertions(+), 498 deletions(-) diff --git a/aria/inference/__init__.py b/aria/inference/__init__.py index e69de29..ceac4b4 100644 --- a/aria/inference/__init__.py +++ b/aria/inference/__init__.py @@ -0,0 +1,58 @@ +import torch + +from ariautils.tokenizer import AbsTokenizer +from ariautils.midi import MidiDict + + +def sample_min_p(probs: torch.Tensor, p_base: float) -> torch.Tensor: + """See - https://arxiv.org/pdf/2407.01082""" + p_max, _ = torch.max(probs, dim=-1, keepdim=True) + p_scaled = p_base * p_max + mask = probs >= p_scaled + + masked_probs = probs.clone() + masked_probs[~mask] = 0.0 + masked_probs.div_(masked_probs.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(masked_probs, num_samples=1) + + return next_token + + +def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + probs_sort[mask] = 0.0 + + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + + return next_token + + +def get_cfg_prompt(prompts: list): + cfg_prompts = [] + for prompt in prompts: + cfg_prompts.append(prompt) + cfg_prompts.append(prompt) + + return cfg_prompts + + +def get_inference_prompt( + midi_dict: MidiDict, tokenizer: AbsTokenizer, 
prompt_len_ms: int +): + midi_dict.note_msgs = [ + msg + for msg in midi_dict.note_msgs + if midi_dict.tick_to_ms(msg["data"]["start"]) <= prompt_len_ms + ] + + if len(midi_dict.note_msgs) == 0: + return [("prefix", "instrument", "piano"), tokenizer.bos_tok] + + seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) + seq.remove(tokenizer.eos_tok) + + return seq diff --git a/aria/inference/model_cuda.py b/aria/inference/model_cuda.py index 707e200..8dbfbd4 100644 --- a/aria/inference/model_cuda.py +++ b/aria/inference/model_cuda.py @@ -1,4 +1,4 @@ -"""Inference implementation with torch-compiler friendly kv-cache.""" +"""Inference implementation for torch (cuda) backend""" import torch import torch.nn as nn @@ -34,118 +34,6 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out -class TransformerLM(nn.Module): - def __init__(self, model_config: ModelConfig): - super().__init__() - self.model_config = model_config - self.max_seq_len = model_config.max_seq_len - self.model = Transformer(model_config) - self.lm_head = nn.Linear( - model_config.d_model, model_config.vocab_size, bias=False - ) - self.embedding_adapter = nn.Linear( - model_config.emb_size, model_config.d_model, bias=False - ) - - def forward( - self, - idxs: torch.Tensor, - input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, - ): - hidden_states = self.model( - idxs=idxs, - input_pos=input_pos, - pad_idxs=pad_idxs, - ) - logits = self.lm_head(hidden_states) - - return logits - - def fill_condition_kv(self, cond_emb: torch.Tensor): - adapted_emb = self.embedding_adapter(cond_emb) - self.model.fill_condition_kv(emb=adapted_emb) - - def setup_cache( - self, - batch_size: int, - max_seq_len=4096, - dtype=torch.bfloat16, - ): - assert batch_size >= 1 - for b in self.model.encode_layers: - b.kv_cache = KVCache( - max_batch_size=batch_size, - max_seq_length=max_seq_len, - n_heads=self.model_config.n_heads, - head_dim=self.model_config.d_model // self.model_config.n_heads, - dtype=dtype, - ).cuda() - - self.model.freqs_cis = precompute_freqs_cis( - seq_len=max_seq_len, - n_elem=self.model_config.d_model // self.model_config.n_heads, - base=500000, - dtype=dtype, - ).cuda() - self.model.causal_mask = torch.tril( - torch.ones(max_seq_len, max_seq_len, dtype=torch.bool) - ).cuda() - - -class Transformer(nn.Module): - def __init__(self, model_config: ModelConfig) -> None: - super().__init__() - self.model_config = model_config - - self.tok_embeddings = nn.Embedding( - num_embeddings=model_config.vocab_size, - embedding_dim=model_config.d_model, - ) - self.encode_layers = nn.ModuleList( - TransformerBlock(model_config) for _ in range(model_config.n_layers) - ) - self.out_layer_norm = nn.LayerNorm(model_config.d_model) - - self.freqs_cis = None - self.causal_mask = None - - def fill_condition_kv(self, emb: torch.Tensor): - assert self.freqs_cis is not None, "Caches must be initialized first" - - input_pos = torch.tensor([0], device=emb.device) - mask = self.causal_mask[None, None, input_pos] - freqs_cis = self.freqs_cis[input_pos] - - x = emb.unsqueeze(dim=1) - - for layer in self.encode_layers: - x = layer(x, input_pos, freqs_cis, mask) - - def forward( - self, - idxs: torch.Tensor, - input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, - ): - assert self.freqs_cis is not None, "Caches must be initialized first" - - mask = self.causal_mask[None, None, input_pos] - - if pad_idxs is not None: - mask = mask & ~(pad_idxs.unsqueeze(1).unsqueeze(1)) - - freqs_cis = self.freqs_cis[input_pos] - - x = 
self.tok_embeddings(idxs) - for layer in self.encode_layers: - x = layer(x, input_pos, freqs_cis, mask) - - x = self.out_layer_norm(x) - - return x - - class TransformerBlock(nn.Module): def __init__(self, model_config: ModelConfig) -> None: super().__init__() @@ -256,6 +144,123 @@ def _ff_block(self, x: torch.Tensor): ) +class Transformer(nn.Module): + def __init__(self, model_config: ModelConfig) -> None: + super().__init__() + self.model_config = model_config + + self.tok_embeddings = nn.Embedding( + num_embeddings=model_config.vocab_size, + embedding_dim=model_config.d_model, + ) + self.encode_layers = nn.ModuleList( + TransformerBlock(model_config) for _ in range(model_config.n_layers) + ) + self.out_layer_norm = nn.LayerNorm(model_config.d_model) + + self.freqs_cis = None + self.causal_mask = None + + def fill_condition_kv(self, emb: torch.Tensor): + assert self.freqs_cis is not None, "Caches must be initialized first" + assert self.model_config.emb_size is not None + + input_pos = torch.tensor([0], device=emb.device) + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + + x = emb.unsqueeze(dim=1) + + for layer in self.encode_layers: + x = layer(x, input_pos, freqs_cis, mask) + + def forward( + self, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, + ): + assert self.freqs_cis is not None, "Caches must be initialized first" + + mask = self.causal_mask[None, None, input_pos] + + if pad_idxs is not None: + mask = mask & ~(pad_idxs.unsqueeze(1).unsqueeze(1)) + + freqs_cis = self.freqs_cis[input_pos] + + x = self.tok_embeddings(idxs) + for layer in self.encode_layers: + x = layer(x, input_pos, freqs_cis, mask) + + x = self.out_layer_norm(x) + + return x + + +class TransformerLM(nn.Module): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) + self.lm_head = nn.Linear( + model_config.d_model, model_config.vocab_size, bias=False + ) + + if model_config.emb_size is not None: + self.embedding_adapter = nn.Linear( + model_config.emb_size, model_config.d_model, bias=False + ) + + def forward( + self, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, + ): + hidden_states = self.model( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + ) + logits = self.lm_head(hidden_states) + + return logits + + def fill_condition_kv(self, cond_emb: torch.Tensor): + assert self.model_config.emb_size is not None + + adapted_emb = self.embedding_adapter(cond_emb) + self.model.fill_condition_kv(emb=adapted_emb) + + def setup_cache( + self, + batch_size: int, + max_seq_len=8096, + dtype=torch.bfloat16, + ): + assert batch_size >= 1 + for b in self.model.encode_layers: + b.kv_cache = KVCache( + max_batch_size=batch_size, + max_seq_length=max_seq_len, + n_heads=self.model_config.n_heads, + head_dim=self.model_config.d_model // self.model_config.n_heads, + dtype=dtype, + ).cuda() + + self.model.freqs_cis = precompute_freqs_cis( + seq_len=max_seq_len, + n_elem=self.model_config.d_model // self.model_config.n_heads, + base=500000, + dtype=dtype, + ).cuda() + self.model.causal_mask = torch.tril( + torch.ones(max_seq_len, max_seq_len, dtype=torch.bool) + ).cuda() + + def precompute_freqs_cis( seq_len: int, n_elem: int, diff --git a/aria/inference/model_mlx.py b/aria/inference/model_mlx.py index cb80d85..7e306c7 100644 --- a/aria/inference/model_mlx.py +++ 
b/aria/inference/model_mlx.py @@ -1,10 +1,10 @@ """Inference implementation for mlx backend""" -from aria.model import ModelConfig - import mlx.core as mx import mlx.nn as nn +from aria.model import ModelConfig + class KVCache(nn.Module): def __init__( @@ -13,7 +13,7 @@ def __init__( max_seq_length: int, n_heads: int, head_dim: int, - dtype: mx.Dtype = mx.bfloat16, + dtype: mx.Dtype = mx.float32, ): super().__init__() self.dtype = dtype @@ -159,6 +159,20 @@ def __init__(self, model_config: ModelConfig): TransformerBlock(model_config) for _ in range(model_config.n_layers) ] self.out_layer_norm = nn.LayerNorm(model_config.d_model) + + def fill_condition_kv(self, emb: mx.array): + assert self.causal_mask is not None, "Caches must be initialized first" + assert self.model_config.emb_size is not None + + input_pos = mx.array([0], dtype=mx.int32) + mask = self.causal_mask[None, None, input_pos] + offset = 0 + + x = mx.expand_dims(emb, axis=1) + + for layer in self.encode_layers: + x = layer(x, input_pos, offset, mask) + self.causal_mask = None def __call__( @@ -195,6 +209,11 @@ def __init__(self, model_config: ModelConfig): model_config.d_model, model_config.vocab_size, bias=False ) + if model_config.emb_size is not None: + self.embedding_adapter = nn.Linear( + model_config.emb_size, model_config.d_model, bias=False + ) + def __call__( self, idxs: mx.array, @@ -212,11 +231,17 @@ def __call__( return logits + def fill_condition_kv(self, cond_emb: mx.array): + assert self.model_config.emb_size is not None + + adapted_emb = self.embedding_adapter(cond_emb) + self.model.fill_condition_kv(emb=adapted_emb) + def setup_cache( self, batch_size, - max_seq_len=4096, - dtype=mx.bfloat16, + max_seq_len=8096, + dtype=mx.float32, ): # Init cache for b in self.model.encode_layers: diff --git a/aria/inference/sample_cuda.py b/aria/inference/sample_cuda.py index 3fcedbf..0c49d25 100644 --- a/aria/inference/sample_cuda.py +++ b/aria/inference/sample_cuda.py @@ -1,20 +1,20 @@ """Contains generation/sampling code""" import torch -import torch._dynamo.config import torch._inductor.config -from typing import List from tqdm import tqdm +from aria.inference import sample_min_p, sample_top_p from aria.inference.model_cuda import TransformerLM from ariautils.tokenizer import Tokenizer, AbsTokenizer -from ariautils.midi import MidiDict torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.fx_graph_cache = True +DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + def get_cfg_prompt(prompts: list): cfg_prompts = [] @@ -54,7 +54,7 @@ def prefill( idxs=idxs, input_pos=input_pos, pad_idxs=pad_idxs, - )[:, -1] + ) return logits @@ -91,19 +91,16 @@ def update_seq_ids_( seq[:, idx] = next_token_ids -@torch.autocast( - "cuda", - dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, -) +@torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def sample_batch( model: TransformerLM, tokenizer: Tokenizer, - prompts: List[list], + prompt: list, + num_variations: list, max_new_tokens: int, - force_end=False, - temp: float = 0.95, - embedding: list[float] | None = None, + temp: float, + force_end: bool = False, top_p: float | None = None, min_p: float | None = None, compile: bool = False, @@ -118,20 +115,21 @@ def sample_batch( if force_end: assert max_new_tokens > 130, "prompt too long to use force_end=True" - prompt_len = len(prompts[0]) - num_prompts = len(prompts) - assert all([len(p) == prompt_len for 
p in prompts]) + prompt_len = len(prompt) + model = model.cuda() model.eval() - dim_tok_inserted = [False for _ in range(num_prompts)] - eos_tok_seen = [False for _ in range(num_prompts)] + dim_tok_inserted = [False for _ in range(num_variations)] + eos_tok_seen = [False for _ in range(num_variations)] total_len = prompt_len + max_new_tokens seq = torch.stack( [ torch.tensor( - tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) + tokenizer.encode( + prompt + [tokenizer.pad_tok] * (total_len - prompt_len) + ) ) - for p in prompts + for _ in range(num_variations) ] ).cuda() @@ -144,22 +142,11 @@ def sample_batch( ) model.setup_cache( - batch_size=num_prompts, + batch_size=num_variations, max_seq_len=total_len, - dtype=( - torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - ), + dtype=DTYPE, ) - if embedding: - condition_embedding = torch.tensor( - [embedding for _ in range(num_prompts)], device=seq.device - ) - model.fill_condition_kv(cond_emb=condition_embedding) - emb_offset = 1 - else: - emb_offset = 0 - print( f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gen_len={max_new_tokens}" ) @@ -176,16 +163,14 @@ def sample_batch( logits = prefill( model, idxs=seq[:, :idx], - input_pos=torch.arange( - emb_offset, idx + emb_offset, device=seq.device - ), - ) + input_pos=torch.arange(0, idx, device=seq.device), + )[:, -1] else: logits = decode_one( model, idxs=seq[:, idx - 1 : idx], input_pos=torch.tensor( - [(idx + emb_offset) - 1], + [(idx) - 1], device=seq.device, dtype=torch.int, ), @@ -227,16 +212,13 @@ def sample_batch( return decoded_results -# Not tested but I think this works -@torch.autocast( - "cuda", - dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, -) +# TODO: Verify and implement for mlx +@torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def sample_batch_cfg( model: TransformerLM, tokenizer: AbsTokenizer, - prompts: List[list], + prompts: list[list], max_new_tokens: int, cfg_gamma: float, embedding: list[float], @@ -287,9 +269,7 @@ def sample_batch_cfg( model.setup_cache( batch_size=num_prompts, max_seq_len=total_context_len, - dtype=( - torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - ), + dtype=DTYPE, ) condition_embedding = torch.tensor( @@ -304,7 +284,7 @@ def sample_batch_cfg( f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" ) - CFG_WARM_UP_STEPS = 250 + CFG_WARM_UP_STEPS = min(250, max_new_tokens) curr_step = 0 for idx in ( pbar := tqdm( @@ -313,19 +293,21 @@ def sample_batch_cfg( leave=False, ) ): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - if idx == prompt_len: - logits = prefill( - model, - idxs=seq[:, :idx], - input_pos=torch.arange( - embedding_offset, - idx + embedding_offset, - device=seq.device, - ), - pad_idxs=pad_idxs, - ) - else: + if idx == prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=torch.arange( + embedding_offset, + idx + embedding_offset, + device=seq.device, + ), + pad_idxs=pad_idxs, + )[:, -1] + else: + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH + ): logits = decode_one( model, idxs=seq[:, idx - 1 : idx], @@ -378,52 +360,3 @@ def sample_batch_cfg( ] return decoded_results - - -# Working -def sample_min_p(probs, p_base): - """See - https://arxiv.org/pdf/2407.01082""" - p_max, _ = torch.max(probs, dim=-1, keepdim=True) - p_scaled = p_base * p_max - mask = probs >= p_scaled - - masked_probs = probs.clone() 
- masked_probs[~mask] = 0.0 - masked_probs.div_(masked_probs.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(masked_probs, num_samples=1) - - return next_token - - -def sample_top_p(probs, p): - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - - return next_token - - -def get_inference_prompt( - midi_dict: MidiDict, tokenizer: AbsTokenizer, prompt_len_ms: int -): - midi_dict.note_msgs = [ - msg - for msg in midi_dict.note_msgs - if midi_dict.tick_to_ms(msg["data"]["start"]) <= prompt_len_ms - ] - - if len(midi_dict.note_msgs) == 0: - return [("prefix", "instrument", "piano"), tokenizer.bos_tok] - - seq = tokenizer.tokenize(midi_dict=midi_dict) - if tokenizer.dim_tok in seq: - seq.remove(tokenizer.dim_tok) - if tokenizer.eos_tok in seq: - seq.remove(tokenizer.eos_tok) - - return seq diff --git a/aria/inference/sample_mlx.py b/aria/inference/sample_mlx.py index 2abeddc..9742b66 100644 --- a/aria/inference/sample_mlx.py +++ b/aria/inference/sample_mlx.py @@ -3,14 +3,15 @@ import torch import numpy as np import mlx.core as mx -import mlx.nn as nn -from typing import List from tqdm import tqdm +from aria.inference import sample_min_p, sample_top_p from aria.inference.model_mlx import TransformerLM from ariautils.tokenizer import Tokenizer +DTYPE = mx.float32 + def decode_one( model: TransformerLM, @@ -41,7 +42,7 @@ def prefill( input_pos=input_pos, offset=input_pos[0], pad_idxs=pad_idxs, - )[:, -1] + ) return logits @@ -81,13 +82,17 @@ def update_seq_ids_( def sample_batch( model: TransformerLM, tokenizer: Tokenizer, - prompts: List[list], + prompt: list, + num_variations: list, max_new_tokens: int, - force_end=False, temp: float = 0.95, + force_end: bool = False, + top_p: float | None = None, min_p: float | None = None, - # compile: bool = False, ): + assert top_p is not None or min_p is not None + if top_p is not None: + assert 0.5 <= top_p <= 1.0 if min_p is not None: assert 0.0 <= min_p <= 1.0 if temp is not None: @@ -95,27 +100,31 @@ def sample_batch( if force_end: assert max_new_tokens > 130, "prompt too long to use force_end=True" - prompt_len = len(prompts[0]) - num_prompts = len(prompts) - assert all([len(p) == prompt_len for p in prompts]) + prompt_len = len(prompt) model.eval() - dim_tok_inserted = [False for _ in range(num_prompts)] - eos_tok_seen = [False for _ in range(num_prompts)] + dim_tok_inserted = [False for _ in range(num_variations)] + eos_tok_seen = [False for _ in range(num_variations)] total_len = prompt_len + max_new_tokens + seq = mx.stack( [ mx.array( - tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) + tokenizer.encode( + prompt + [tokenizer.pad_tok] * (total_len - prompt_len) + ), + dtype=mx.int32, ) - for p in prompts - ] + for _ in range(num_variations) + ], ) model.setup_cache( - batch_size=num_prompts, max_seq_len=total_len, dtype=mx.float32 + batch_size=num_variations, + max_seq_len=total_len, + dtype=DTYPE, ) print( - f"Using hyperparams: temp={temp}, min_p={min_p}, gen_len={max_new_tokens}" + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gen_len={max_new_tokens}" ) for idx in ( @@ -130,7 +139,7 @@ def sample_batch( model, idxs=seq[:, :idx], input_pos=mx.arange(0, idx), - ) + )[:, -1] else: logits = decode_one( model, 
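

Both backends now sample through the shared helpers in aria/inference/__init__.py; the mlx wrappers simply round-trip the probabilities through torch.multinomial, since mlx has no categorical sampler over probability vectors. As a reference, here is a minimal self-contained sketch of the two filters (torch only; the toy distribution is illustrative, not taken from the repo):

    import torch

    probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]])

    # top-p (nucleus): keep the smallest prefix of sorted tokens whose mass >= top_p
    top_p = 0.9
    sorted_p, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum_p = torch.cumsum(sorted_p, dim=-1)
    sorted_p[cum_p - sorted_p > top_p] = 0.0           # drop tokens past the nucleus
    sorted_p /= sorted_p.sum(dim=-1, keepdim=True)     # renormalize the survivors
    tok = torch.gather(sorted_idx, -1, torch.multinomial(sorted_p, 1))

    # min-p: keep tokens whose probability is at least p_base * max(probs)
    p_base = 0.05
    keep = probs >= p_base * probs.max(dim=-1, keepdim=True).values
    masked = torch.where(keep, probs, torch.zeros_like(probs))
    masked /= masked.sum(dim=-1, keepdim=True)
    tok2 = torch.multinomial(masked, 1)
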
@@ -143,12 +152,13 @@ def sample_batch( if temp > 0.0: probs = mx.softmax(logits / temp, axis=-1) - next_token_ids = sample_min_p(probs, min_p).flatten() + if min_p is not None: + next_token_ids = sample_min_p_mlx(probs, min_p).flatten() + else: + next_token_ids = sample_top_p_mlx(probs, top_p).flatten() else: next_token_ids = mx.argmax(logits, axis=-1).flatten() - print(tokenizer.id_to_tok[next_token_ids[0].item()]) - update_seq_ids_( seq=seq, idx=idx, @@ -176,86 +186,18 @@ def sample_batch( return decoded_results -def sample_min_p(probs: mx.array, p_base: float): # Added type hint +def sample_min_p_mlx(probs: mx.array, p_base: float) -> mx.array: """See - https://arxiv.org/pdf/2407.01082""" - p_max = mx.max(probs, axis=-1, keepdims=True) - p_scaled = p_base * p_max - mask = probs >= p_scaled - - masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) - sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) - masked_probs_normalized = masked_probs / sum_masked_probs - - # Dumb workaround for mlx not having categorical probs sampler - next_token = mx.array( - torch.multinomial( - torch.from_numpy(np.array(masked_probs_normalized)), num_samples=1 - ), - dtype=mx.int32, - ) - return next_token + probs_t = torch.from_numpy(np.array(probs)) + next_token_t = sample_min_p(probs=probs_t, p_base=p_base) + return mx.array(next_token_t, dtype=mx.int32) -def sample(): - import os - - from aria.model import ModelConfig - from aria.config import load_model_config - - from ariautils.midi import MidiDict - from ariautils.tokenizer import AbsTokenizer - from aria.inference.sample_cuda import get_inference_prompt - - CHECKPOINT_PATH = ( - "/Users/louis/work/aria/models/medium-75-annealed.safetensors" - ) - PROMPT_MIDI_PATH = "/Users/louis/Dropbox/shared/audio.mid" - - NUM_VARIATIONS = 1 # Number of samples (e.g., 2 variations) - TRUNCATE_LEN_MS = 15000 # Prompt length in milliseconds (e.g., 10 seconds) - GEN_LENGTH = 1024 # Number of new tokens to generate (args.l) - FORCE_END = False # Whether to force sequence end (args.e) - TEMPERATURE = 0.95 # Sampling temperature (args.temp) - MIN_P = 0.05 # Min-p sampling (args.min_p) - - SAMPLES_DIR = os.path.join(os.getcwd(), "/Users/louis/Dropbox/shared") - - tokenizer = AbsTokenizer() - model_config = ModelConfig(**load_model_config("medium-emb")) - model_config.set_vocab_size(tokenizer.vocab_size) - model = TransformerLM(model_config) - model.load_weights(CHECKPOINT_PATH) - nn.quantize(model.model, group_size=128, bits=8) - - midi_dict = MidiDict.from_midi(mid_path=PROMPT_MIDI_PATH) - prompt_seq = get_inference_prompt( - tokenizer=tokenizer, - midi_dict=midi_dict, - prompt_len_ms=TRUNCATE_LEN_MS, - ) - - print(prompt_seq) - print(f"Prompt sequence length: {len(prompt_seq)} tokens") - prompts = [prompt_seq for _ in range(NUM_VARIATIONS)] - - results = sample_batch( - model=model, - tokenizer=tokenizer, - prompts=prompts, - max_new_tokens=GEN_LENGTH, - force_end=FORCE_END, - temp=TEMPERATURE, - min_p=MIN_P, - ) - for idx, tokenized_seq in enumerate(results): - res_midi_dict = tokenizer.detokenize(tokenized_seq) - res_midi = res_midi_dict.to_midi() - output_file_path = os.path.join(SAMPLES_DIR, f"res_{idx + 1}.mid") - res_midi.save(output_file_path) - print(f"Saved result {idx + 1} to {output_file_path}") +def sample_top_p_mlx(probs: mx.array, top_p: float) -> mx.array: + probs_t = torch.from_numpy(np.array(probs)) + next_token_t = sample_top_p(probs=probs_t, top_p=top_p) -if __name__ == "__main__": - sample() + return mx.array(next_token_t, 
dtype=mx.int32) diff --git a/aria/run.py b/aria/run.py index 06f11af..2844d01 100644 --- a/aria/run.py +++ b/aria/run.py @@ -6,258 +6,287 @@ import sys -def _parse_sample_args(): - argp = argparse.ArgumentParser(prog="aria sample") +def _parse_generate_args(): + argp = argparse.ArgumentParser(prog="aria generate") argp.add_argument( - "-checkpoint_path", help="path to model used for decoding" + "--backend", + choices=["torch_cuda", "mlx"], + default="torch_cuda", + help="backend for inference", ) - argp.add_argument("-prompt_midi_path", help="path to midi file") argp.add_argument( - "-embedding_checkpoint_path", - required=False, - help="path to model checkpoint used for embeddings", + "--checkpoint_path", help="path to model used for decoding" ) + argp.add_argument("--prompt_midi_path", help="path to midi file") argp.add_argument( - "-embedding_midi_paths", - nargs="+", - required=False, - help="path(s) to midi file(s) used for embeddings", + "--prompt_duration", + help="length of the input MIDI prompt, in seconds", + type=int, + default=20, ) argp.add_argument( - "-temp", + "--variations", + help="number of variations to generate", + type=int, + default=1, + ) + argp.add_argument( + "--temp", help="sampling temperature value", type=float, required=False, - default=0.95, + default=0.98, + ) + argp.add_argument( + "--min_p", + help="sampling min_p value", + type=float, + default=0.035, + required=False, ) argp.add_argument( - "-top_p", + "--top_p", help="sampling top_p value", type=float, required=False, ) argp.add_argument( - "-min_p", - help="sampling min_p value", + "--end", action="store_true", help="generate ending for piece" + ) + argp.add_argument( + "--length", + type=int, + help="number of tokens to generate per variation", + default=2048, + ) + argp.add_argument( + "--compile", + action="store_true", + help="use torch compiler to generate cudagraph for inference", + ) + argp.add_argument( + "--save_dir", + type=str, + default=".", + help="directory to save generated MIDI files", + ) + + return argp.parse_args(sys.argv[2:]) + + +def _parse_conditioned_generate_args(): + argp = argparse.ArgumentParser(prog="aria conditioned-generate") + argp.add_argument( + "--checkpoint_path", help="path to model used for decoding" + ) + argp.add_argument("--prompt_midi_path", help="path to midi file") + argp.add_argument( + "--prompt_duration", + help="length of the input MIDI prompt (seconds)", + type=int, + default=20, + ) + argp.add_argument( + "--embedding_model_checkpoint_path", + required=False, + help="path to model checkpoint used for embeddings", + ) + argp.add_argument( + "--embedding_midi_path", + help="path to MIDI file used for conditioning", + ) + argp.add_argument( + "--variations", + help="number of variations to generate", + type=int, + default=1, + ) + argp.add_argument( + "--temp", + help="sampling temperature value", type=float, required=False, + default=0.98, ) argp.add_argument( - "-cfg", + "--cfg", help="sampling cfg gamma value", type=float, + default=1.0, + ) + argp.add_argument( + "--min_p", + help="sampling min_p value", + type=float, + default=0.035, required=False, ) argp.add_argument( - "-var", - help="number of variations", - type=int, - default=1, + "--top_p", + help="sampling top_p value", + type=float, + required=False, ) argp.add_argument( - "-trunc", - help="length (in seconds) of the prompt", + "--end", action="store_true", help="generate ending for piece" + ) + argp.add_argument( + "--length", type=int, - default=20, + help="number of tokens to generate per 
variation", + default=2048, + ) + argp.add_argument( + "--backend", + choices=["torch_cuda", "mlx"], + default="torch_cuda", + help="backend for inference", + ) + argp.add_argument( + "--compile", + action="store_true", + help="use torch compiler to generate cudagraph for inference - only applies to backend='torch_cuda'", + ) + argp.add_argument( + "--save_dir", + type=str, + default=".", + help="directory to save generated MIDI files", ) - argp.add_argument("-e", action="store_true", help="enable force end") - argp.add_argument("-l", type=int, help="generation length", default=1024) - argp.add_argument("-compile", action="store_true", help="compile cudagraph") return argp.parse_args(sys.argv[2:]) -def _get_embedding( - embedding_checkpoint_path: str, - midi_paths: list[str], - start_ms: int | None = None, - end_ms: int | None = None, +def _get_prompt( + midi_path: str, + prompt_duration_s: int, ): - import torch - - from aria.model import TransformerEMB - from aria.model import ModelConfig - from aria.config import load_model_config - from aria.utils import _load_weight - from aria.eval.linear_probe import ( - get_aria_contrastive_embedding, - process_entry, - ) - from ariautils.midi import MidiDict from ariautils.tokenizer import AbsTokenizer + from aria.inference import get_inference_prompt - SLICE_NUM_NOTES = 300 - SLICE_MAX_SEQ_LEN = 1024 - - tokenizer = AbsTokenizer() - - model_state = _load_weight(embedding_checkpoint_path, "cuda") - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() - } - - model_config = ModelConfig(**load_model_config("medium-emb")) - model_config.set_vocab_size(tokenizer.vocab_size) - model_config.grad_checkpoint = False - model = TransformerEMB(model_config).cuda().eval() - model.load_state_dict(model_state) - - seqs = [] - for midi_path in midi_paths: - midi_dict = MidiDict.from_midi(midi_path) - midi_dict.note_msgs = [ - msg - for msg in midi_dict.note_msgs - if ( - midi_dict.tick_to_ms(msg["tick"]) >= start_ms - if start_ms is not None - else True - ) - and ( - midi_dict.tick_to_ms(msg["tick"]) <= end_ms - if end_ms is not None - else True - ) - ] - - seqs.extend( - process_entry( - entry=midi_dict, - slice_len_notes=SLICE_NUM_NOTES, - max_seq_len=SLICE_MAX_SEQ_LEN, - tokenizer=tokenizer, - ) - ) - - def model_forward(model, idxs): - return model(idxs) - - embeddings = get_aria_contrastive_embedding( - seqs=[s["seq"] for s in seqs], - hook_model=model, - hook_max_seq_len=SLICE_MAX_SEQ_LEN, - hook_tokenizer=tokenizer, - hook_model_forward=model_forward, + return get_inference_prompt( + midi_dict=MidiDict.from_midi(midi_path), + tokenizer=AbsTokenizer(), + prompt_len_ms=1e3 * prompt_duration_s, ) - embedding = torch.tensor(embeddings, device="cuda").mean(0).tolist() - return embedding +def _load_model_torch( + checkpoint_path: str, + config_name: str, + strict: bool = True, +): + from safetensors.torch import load_file -def sample(args): - """Entrypoint for sampling""" - - from torch.cuda import is_available as cuda_is_available + from ariautils.tokenizer import AbsTokenizer from aria.inference.model_cuda import TransformerLM from aria.model import ModelConfig from aria.config import load_model_config - from aria.inference.sample_cuda import ( - sample_batch, - sample_batch_cfg, - get_inference_prompt, - ) - from aria.utils import _load_weight - from ariautils.midi import MidiDict - from ariautils.tokenizer import AbsTokenizer + model_config = ModelConfig(**load_model_config(name=config_name)) + 
model_config.set_vocab_size(AbsTokenizer().vocab_size) + model = TransformerLM(model_config) - if not cuda_is_available(): - raise Exception("CUDA device is not available.") + state_dict = load_file(filename=checkpoint_path) + model.load_state_dict(state_dict=state_dict, strict=strict) - num_variations = args.var - truncate_len = args.trunc - force_end = args.e + return model - tokenizer = AbsTokenizer() - if args.embedding_checkpoint_path and args.embedding_midi_paths: - print(f"Using embedding from {args.embedding_midi_paths}") - embedding = _get_embedding( - embedding_checkpoint_path=args.embedding_checkpoint_path, - midi_paths=args.embedding_midi_paths, - start_ms=args.trunc * 1e3, - end_ms=None, - ) - else: - embedding = None +def _load_model_mlx( + checkpoint_path: str, + config_name: str, + strict: bool = True, +): + import mlx.core as mx - model_state = _load_weight(args.checkpoint_path, "cuda") - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() - } + from ariautils.tokenizer import AbsTokenizer + from aria.inference.model_mlx import TransformerLM + from aria.model import ModelConfig + from aria.config import load_model_config - model_config = ModelConfig(**load_model_config("medium-emb")) - model_config.set_vocab_size(tokenizer.vocab_size) - model_config.grad_checkpoint = False - model = TransformerLM(model_config).cuda() + model_config = ModelConfig(**load_model_config(name=config_name)) + model_config.set_vocab_size(AbsTokenizer().vocab_size) + model = TransformerLM(model_config) + model.load_weights(checkpoint_path, strict=strict) + mx.eval(model.parameters()) - try: - model.load_state_dict(model_state) - except Exception as e: - print("Failed to load model_state - loading with strict=False") - model.load_state_dict(model_state, strict=False) + return model - assert args.l > 0, "Generation length must be positive." 
- max_new_tokens = args.l - # Load and format prompts and metadata - midi_dict = MidiDict.from_midi(mid_path=args.prompt_midi_path) +def generate(args): + from ariautils.tokenizer import AbsTokenizer - prompt_seq = get_inference_prompt( - tokenizer=tokenizer, - midi_dict=midi_dict, - prompt_len_ms=truncate_len * 1e3, - ) + num_variations = args.variations + prompt_duration_s = args.prompt_duration + backend = args.backend + max_new_tokens = args.length - print(prompt_seq) + assert num_variations > 0 + assert prompt_duration_s > 0 + assert max_new_tokens > 0 + assert os.path.isdir(args.save_dir) - if len(prompt_seq) + args.l > model_config.max_seq_len: - print( - "WARNING: Required context exceeds max_seq_len supported by model" - ) - prompts = [prompt_seq for _ in range(num_variations)] + tokenizer = AbsTokenizer() + prompt = _get_prompt( + args.prompt_midi_path, + prompt_duration_s=prompt_duration_s, + ) + max_new_tokens = min(8096 - len(prompt), max_new_tokens) - samples_dir = "/home/loubb/Dropbox/shared" - if os.path.isdir(samples_dir) is False: - os.mkdir(samples_dir) + if backend == "torch_cuda": + from torch.cuda import is_available + from aria.inference.sample_cuda import sample_batch as sample_batch_t - if args.cfg and embedding is not None: - results = sample_batch_cfg( + assert is_available(), "CUDA not available" + + model = _load_model_torch( + checkpoint_path=args.checkpoint_path, + config_name="medium", + strict=True, + ) + results = sample_batch_t( model=model, tokenizer=tokenizer, - prompts=prompts, + prompt=prompt, + num_variations=num_variations, max_new_tokens=max_new_tokens, - force_end=force_end, + force_end=args.end, temp=args.temp, top_p=args.top_p, min_p=args.min_p, - cfg_gamma=args.cfg, compile=args.compile, - embedding=embedding, ) - else: - results = sample_batch( + elif backend == "mlx": + from aria.inference.sample_mlx import sample_batch as sample_batch_mlx + + model = _load_model_mlx( + checkpoint_path=args.checkpoint_path, + config_name="medium", + strict=True, + ) + results = sample_batch_mlx( model=model, tokenizer=tokenizer, - prompts=prompts, + prompt=prompt, + num_variations=num_variations, max_new_tokens=max_new_tokens, - force_end=force_end, + force_end=args.end, temp=args.temp, top_p=args.top_p, min_p=args.min_p, - compile=args.compile, - embedding=embedding, ) for idx, tokenized_seq in enumerate(results): res_midi_dict = tokenizer.detokenize(tokenized_seq) res_midi = res_midi_dict.to_midi() - res_midi.save(os.path.join(samples_dir, f"res_{idx + 1}.mid")) + res_midi.save(os.path.join(args.save_dir, f"res_{idx + 1}.mid")) - print("Results saved to samples/") + print(f"Results saved to {os.path.realpath(args.save_dir)}") +# TODO: Add turn - to -- flags def _parse_midi_dataset_args(): argp = argparse.ArgumentParser(prog="aria midi-dataset") argp.add_argument("dir", help="directory containing midi files") @@ -304,6 +333,7 @@ def build_midi_dataset(args): ) +# TODO: Add turn - to -- flags def _parse_pretrain_dataset_args(): argp = argparse.ArgumentParser(prog="aria pretrain-dataset") argp.add_argument("-load_path", help="path midi_dict dataset") @@ -360,22 +390,23 @@ def main(): "command", help="command to run", choices=( - "sample", + "generate", + "conditional-generate", "midi-dataset", "pretrain-dataset", ), ) - # parse_args defaults to [1:] for args, but you need to - # exclude the rest of the args too, or validation will fail args = parser.parse_args(sys.argv[1:2]) if not hasattr(args, "command"): parser.print_help() print("Unrecognized command") 
exit(1) - elif args.command == "sample": - sample(args=_parse_sample_args()) + elif args.command == "generate": + generate(args=_parse_generate_args()) + # elif args.command == "conditioned-generate": + # condi(args=_parse_conditioned_generate_args()) elif args.command == "midi-dataset": build_midi_dataset(args=_parse_midi_dataset_args()) elif args.command == "pretrain-dataset": From 1daac4433843472145abdaa6dca763250cc70fc4 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 3 Jun 2025 21:32:51 +0000 Subject: [PATCH 57/72] cfg conditioned generation refactored for torch_cuda --- aria/embedding.py | 17 +++- aria/inference/sample_cuda.py | 52 ++++++------ aria/inference/sample_mlx.py | 3 +- aria/model.py | 2 +- aria/run.py | 146 +++++++++++++++++++++++++++++----- 5 files changed, 169 insertions(+), 51 deletions(-) diff --git a/aria/embedding.py b/aria/embedding.py index 81a1cdd..030bcb1 100644 --- a/aria/embedding.py +++ b/aria/embedding.py @@ -7,6 +7,8 @@ from aria.model import TransformerEMB +MAX_EMBEDDING_SEQ_LEN = 2048 + def _validate_midi_for_emb(midi_dict: MidiDict): present_instruments = { @@ -41,9 +43,10 @@ def get_embedding_from_seq( ): tokenizer = AbsTokenizer() - assert len(seq) <= 2048, "Sequence lengths above 2048 not supported" + assert len(seq) <= MAX_EMBEDDING_SEQ_LEN, f"Sequence lengths above {MAX_EMBEDDING_SEQ_LEN} not supported" # fmt: skip _validate_midi_for_emb(tokenizer.detokenize(seq)) + model.eval() eos_pos = seq.index(tokenizer.eos_tok) seq_enc = torch.tensor(tokenizer.encode(seq), device=device) emb = model.forward(seq_enc.view(1, -1))[0, eos_pos] @@ -51,6 +54,7 @@ def get_embedding_from_seq( return emb +# TODO: Make sure this is bug free def get_global_embedding_from_midi( model: TransformerEMB, midi_dict: MidiDict | None = None, @@ -70,7 +74,16 @@ def get_global_embedding_from_midi( _validate_midi_for_emb(midi_dict) chunks = _get_chunks(midi_dict=midi_dict, notes_per_chunk=notes_per_chunk) - seqs = [tokenizer.tokenize(c, add_dim_tok=False)[:2048] for c in chunks] + seqs = [ + tokenizer.tokenize(c, add_dim_tok=False)[:MAX_EMBEDDING_SEQ_LEN] + for c in chunks + ] + + # Add back eos_tok if truncated by MAX_EMBEDDING_SEQ_LEN + for seq in seqs: + if seq[-1] != tokenizer.eos_tok: + seq[-1] = tokenizer.eos_tok + embs = [ get_embedding_from_seq(model=model, seq=s, device=device) for s in seqs ] diff --git a/aria/inference/sample_cuda.py b/aria/inference/sample_cuda.py index 0c49d25..7f5f692 100644 --- a/aria/inference/sample_cuda.py +++ b/aria/inference/sample_cuda.py @@ -106,12 +106,11 @@ def sample_batch( compile: bool = False, ): assert top_p is not None or min_p is not None + assert 0.0 <= temp <= 2.0 if top_p is not None: assert 0.5 <= top_p <= 1.0 if min_p is not None: assert 0.0 <= min_p <= 1.0 - if temp is not None: - assert 0.0 <= temp <= 2.0 if force_end: assert max_new_tokens > 130, "prompt too long to use force_end=True" @@ -212,51 +211,52 @@ def sample_batch( return decoded_results -# TODO: Verify and implement for mlx @torch.autocast("cuda", dtype=DTYPE) @torch.inference_mode() def sample_batch_cfg( model: TransformerLM, tokenizer: AbsTokenizer, - prompts: list[list], + prompt: list, + num_variations: list, max_new_tokens: int, cfg_gamma: float, embedding: list[float], + temp: float, force_end=False, - temp: float = 0.95, top_p: float | None = None, min_p: float | None = None, compile: bool = False, ): - assert 0.0 <= cfg_gamma <= 15.0 assert top_p is not None or min_p is not None + assert 0.0 <= temp <= 2.0 + assert 0.0 <= cfg_gamma <= 10.0 if top_p is not 
None: assert 0.5 <= top_p <= 1.0 - if temp is not None: - assert 0.0 <= temp <= 2.0 + if min_p is not None: + assert 0.0 <= min_p <= 1.0 if force_end: assert max_new_tokens > 130, "prompt too long to use force_end=True" - prompts = get_cfg_prompt(prompts) - - prompt_len = len(prompts[0]) - num_prompts = len(prompts) - assert all([len(p) == prompt_len for p in prompts]) + prompt_len = len(prompt) + num_variations = 2 * num_variations # For CFG + model = model.cuda() model.eval() - total_context_len = prompt_len + max_new_tokens + dim_tok_inserted = [False for _ in range(num_variations)] + eos_tok_seen = [False for _ in range(num_variations)] + total_len = prompt_len + max_new_tokens seq = torch.stack( [ torch.tensor( tokenizer.encode( - p + [tokenizer.pad_tok] * (total_context_len - len(p)) + prompt + [tokenizer.pad_tok] * (total_len - prompt_len) ) ) - for p in prompts + for _ in range(num_variations) ] ).cuda() - dim_tok_inserted = [False for _ in range(num_prompts)] - eos_tok_seen = [False for _ in range(num_prompts)] + dim_tok_inserted = [False for _ in range(num_variations)] + eos_tok_seen = [False for _ in range(num_variations)] if compile is True: global decode_one @@ -267,13 +267,13 @@ def sample_batch_cfg( ) model.setup_cache( - batch_size=num_prompts, - max_seq_len=total_context_len, + batch_size=num_variations, + max_seq_len=total_len, dtype=DTYPE, ) condition_embedding = torch.tensor( - [embedding for _ in range(num_prompts)], device=seq.device + [embedding for _ in range(num_variations)], device=seq.device ) model.fill_condition_kv(cond_emb=condition_embedding) embedding_offset = 1 @@ -281,15 +281,15 @@ def sample_batch_cfg( pad_idxs[1::2, 0] = True print( - f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, cfg={cfg_gamma}, gen_len={max_new_tokens}" ) - CFG_WARM_UP_STEPS = min(250, max_new_tokens) + CFG_WARM_UP_STEPS = min(10, max_new_tokens) curr_step = 0 for idx in ( pbar := tqdm( - range(prompt_len, total_context_len), - total=total_context_len - prompt_len, + range(prompt_len, total_len), + total=total_len - prompt_len, leave=False, ) ): @@ -341,7 +341,7 @@ def sample_batch_cfg( next_token_ids=next_token_ids, dim_tok_inserted=dim_tok_inserted, eos_tok_seen=eos_tok_seen, - max_len=total_context_len, + max_len=total_len, force_end=force_end, tokenizer=tokenizer, ) diff --git a/aria/inference/sample_mlx.py b/aria/inference/sample_mlx.py index 9742b66..3b49f30 100644 --- a/aria/inference/sample_mlx.py +++ b/aria/inference/sample_mlx.py @@ -91,12 +91,11 @@ def sample_batch( min_p: float | None = None, ): assert top_p is not None or min_p is not None + assert 0.0 <= temp <= 2.0 if top_p is not None: assert 0.5 <= top_p <= 1.0 if min_p is not None: assert 0.0 <= min_p <= 1.0 - if temp is not None: - assert 0.0 <= temp <= 2.0 if force_end: assert max_new_tokens > 130, "prompt too long to use force_end=True" diff --git a/aria/model.py b/aria/model.py index 68b3598..573f548 100644 --- a/aria/model.py +++ b/aria/model.py @@ -182,7 +182,7 @@ def forward( ).to(src.device) freqs_cis = self.freqs_cis[: src.shape[1]] - if self.model_config.grad_checkpoint is True: + if self.model_config.grad_checkpoint is True and self.training: for layer in self.encode_layers: def create_custom_forward(module): diff --git a/aria/run.py b/aria/run.py index 2844d01..2b226db 100644 --- a/aria/run.py +++ b/aria/run.py @@ -75,20 +75,25 @@ def _parse_generate_args(): def 
_parse_conditioned_generate_args(): - argp = argparse.ArgumentParser(prog="aria conditioned-generate") + argp = argparse.ArgumentParser(prog="aria generate") + argp.add_argument( + "--backend", + choices=["torch_cuda", "mlx"], + default="torch_cuda", + help="backend for inference", + ) argp.add_argument( "--checkpoint_path", help="path to model used for decoding" ) argp.add_argument("--prompt_midi_path", help="path to midi file") argp.add_argument( "--prompt_duration", - help="length of the input MIDI prompt (seconds)", + help="length of the input MIDI prompt, in seconds", type=int, default=20, ) argp.add_argument( "--embedding_model_checkpoint_path", - required=False, help="path to model checkpoint used for embeddings", ) argp.add_argument( @@ -136,16 +141,10 @@ def _parse_conditioned_generate_args(): help="number of tokens to generate per variation", default=2048, ) - argp.add_argument( - "--backend", - choices=["torch_cuda", "mlx"], - default="torch_cuda", - help="backend for inference", - ) argp.add_argument( "--compile", action="store_true", - help="use torch compiler to generate cudagraph for inference - only applies to backend='torch_cuda'", + help="use torch compiler to generate cudagraph for inference", ) argp.add_argument( "--save_dir", @@ -172,7 +171,24 @@ def _get_prompt( ) -def _load_model_torch( +def _load_embedding_model(checkpoint_path: str): + from safetensors.torch import load_file + + from ariautils.tokenizer import AbsTokenizer + from aria.model import TransformerEMB, ModelConfig + from aria.config import load_model_config + + model_config = ModelConfig(**load_model_config(name="medium-emb")) + model_config.set_vocab_size(AbsTokenizer().vocab_size) + model = TransformerEMB(model_config) + + state_dict = load_file(filename=checkpoint_path) + model.load_state_dict(state_dict=state_dict, strict=True) + + return model + + +def _load_inference_model_torch( checkpoint_path: str, config_name: str, strict: bool = True, @@ -194,7 +210,7 @@ def _load_model_torch( return model -def _load_model_mlx( +def _load_inference_model_mlx( checkpoint_path: str, config_name: str, strict: bool = True, @@ -224,7 +240,7 @@ def generate(args): max_new_tokens = args.length assert num_variations > 0 - assert prompt_duration_s > 0 + assert prompt_duration_s >= 0 assert max_new_tokens > 0 assert os.path.isdir(args.save_dir) @@ -241,19 +257,19 @@ def generate(args): assert is_available(), "CUDA not available" - model = _load_model_torch( + model = _load_inference_model_torch( checkpoint_path=args.checkpoint_path, config_name="medium", strict=True, - ) + ) # Might want strict = False results = sample_batch_t( model=model, tokenizer=tokenizer, prompt=prompt, num_variations=num_variations, max_new_tokens=max_new_tokens, - force_end=args.end, temp=args.temp, + force_end=args.end, top_p=args.top_p, min_p=args.min_p, compile=args.compile, @@ -261,7 +277,7 @@ def generate(args): elif backend == "mlx": from aria.inference.sample_mlx import sample_batch as sample_batch_mlx - model = _load_model_mlx( + model = _load_inference_model_mlx( checkpoint_path=args.checkpoint_path, config_name="medium", strict=True, @@ -272,12 +288,102 @@ def generate(args): prompt=prompt, num_variations=num_variations, max_new_tokens=max_new_tokens, + temp=args.temp, force_end=args.end, + top_p=args.top_p, + min_p=args.min_p, + ) + + for idx, tokenized_seq in enumerate(results): + res_midi_dict = tokenizer.detokenize(tokenized_seq) + res_midi = res_midi_dict.to_midi() + res_midi.save(os.path.join(args.save_dir, f"res_{idx + 1}.mid")) 
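+        # filenames are 1-indexed: res_1.mid, res_2.mid, ... under --save_dir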
+ + print(f"Results saved to {os.path.realpath(args.save_dir)}") + + +# TODO: Double checking during training we didn't do a weighted global sum +def _get_embedding( + embedding_model_checkpoints_path: str, + embedding_midi_path: str, +): + from aria.embedding import get_global_embedding_from_midi + + model = _load_embedding_model( + checkpoint_path=embedding_model_checkpoints_path + ).cpu() + global_embedding = get_global_embedding_from_midi( + model=model, + midi_path=embedding_midi_path, + device="cpu", + ) + + return global_embedding.tolist() + + +def conditioned_generate(args): + from ariautils.tokenizer import AbsTokenizer + + num_variations = args.variations + prompt_duration_s = args.prompt_duration + backend = args.backend + max_new_tokens = args.length + + assert num_variations > 0 + assert prompt_duration_s >= 0 + assert max_new_tokens > 0 + assert os.path.isdir(args.save_dir) + + tokenizer = AbsTokenizer() + prompt = _get_prompt( + args.prompt_midi_path, + prompt_duration_s=prompt_duration_s, + ) + embedding = _get_embedding( + embedding_model_checkpoints_path=args.embedding_model_checkpoint_path, + embedding_midi_path=args.embedding_midi_path, + ) + max_new_tokens = min(8096 - len(prompt), max_new_tokens) + + if backend == "torch_cuda": + from torch.cuda import is_available + from aria.inference.sample_cuda import ( + sample_batch_cfg as sample_batch_cfg_t, + ) + + assert is_available(), "CUDA not available" + + model = _load_inference_model_torch( + checkpoint_path=args.checkpoint_path, + config_name="medium-emb", + strict=True, + ) + results = sample_batch_cfg_t( + model=model, + tokenizer=tokenizer, + prompt=prompt, + num_variations=num_variations, + max_new_tokens=max_new_tokens, + cfg_gamma=args.cfg, + embedding=embedding, temp=args.temp, + force_end=args.end, top_p=args.top_p, min_p=args.min_p, + compile=args.compile, ) + elif backend == "mlx": + from aria.inference.sample_mlx import sample_batch as sample_batch_mlx + + model = _load_inference_model_mlx( + checkpoint_path=args.checkpoint_path, + config_name="medium", + strict=True, + ) + + raise NotImplementedError + for idx, tokenized_seq in enumerate(results): res_midi_dict = tokenizer.detokenize(tokenized_seq) res_midi = res_midi_dict.to_midi() @@ -391,7 +497,7 @@ def main(): help="command to run", choices=( "generate", - "conditional-generate", + "conditioned-generate", "midi-dataset", "pretrain-dataset", ), @@ -405,8 +511,8 @@ def main(): exit(1) elif args.command == "generate": generate(args=_parse_generate_args()) - # elif args.command == "conditioned-generate": - # condi(args=_parse_conditioned_generate_args()) + elif args.command == "conditioned-generate": + conditioned_generate(args=_parse_conditioned_generate_args()) elif args.command == "midi-dataset": build_midi_dataset(args=_parse_midi_dataset_args()) elif args.command == "pretrain-dataset": From 479edc1fc705e1565ec3fbc9e20c8f0bd750f67f Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Jun 2025 12:50:27 +0000 Subject: [PATCH 58/72] add mlx backend for conditioned generation --- aria/inference/model_mlx.py | 2 - aria/inference/sample_cuda.py | 5 +- aria/inference/sample_mlx.py | 139 +++++++++++++++++++++++++++++++++- aria/run.py | 21 ++++- models/placeholder.txt | 0 5 files changed, 155 insertions(+), 12 deletions(-) delete mode 100644 models/placeholder.txt diff --git a/aria/inference/model_mlx.py b/aria/inference/model_mlx.py index 7e306c7..169b30b 100644 --- a/aria/inference/model_mlx.py +++ b/aria/inference/model_mlx.py @@ -173,8 +173,6 @@ def 
fill_condition_kv(self, emb: mx.array): for layer in self.encode_layers: x = layer(x, input_pos, offset, mask) - self.causal_mask = None - def __call__( self, idxs: mx.array, diff --git a/aria/inference/sample_cuda.py b/aria/inference/sample_cuda.py index 7f5f692..909bd8d 100644 --- a/aria/inference/sample_cuda.py +++ b/aria/inference/sample_cuda.py @@ -255,8 +255,6 @@ def sample_batch_cfg( for _ in range(num_variations) ] ).cuda() - dim_tok_inserted = [False for _ in range(num_variations)] - eos_tok_seen = [False for _ in range(num_variations)] if compile is True: global decode_one @@ -284,7 +282,7 @@ def sample_batch_cfg( f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, cfg={cfg_gamma}, gen_len={max_new_tokens}" ) - CFG_WARM_UP_STEPS = min(10, max_new_tokens) + CFG_WARM_UP_STEPS = min(250, max_new_tokens) curr_step = 0 for idx in ( pbar := tqdm( @@ -335,6 +333,7 @@ def sample_batch_cfg( next_token_ids = torch.argmax(logits_cfg, dim=-1).flatten() next_token_ids = next_token_ids.repeat_interleave(2) + update_seq_ids_( seq=seq, idx=idx, diff --git a/aria/inference/sample_mlx.py b/aria/inference/sample_mlx.py index 3b49f30..9dccb7b 100644 --- a/aria/inference/sample_mlx.py +++ b/aria/inference/sample_mlx.py @@ -8,7 +8,7 @@ from aria.inference import sample_min_p, sample_top_p from aria.inference.model_mlx import TransformerLM -from ariautils.tokenizer import Tokenizer +from ariautils.tokenizer import AbsTokenizer DTYPE = mx.float32 @@ -55,7 +55,7 @@ def update_seq_ids_( eos_tok_seen: list, max_len: int, force_end: bool, - tokenizer: Tokenizer, + tokenizer: AbsTokenizer, ): # Insert dim and pad toks for _idx in range(seq.shape[0]): @@ -81,7 +81,7 @@ def update_seq_ids_( def sample_batch( model: TransformerLM, - tokenizer: Tokenizer, + tokenizer: AbsTokenizer, prompt: list, num_variations: list, max_new_tokens: int, @@ -185,6 +185,139 @@ def sample_batch( return decoded_results +def sample_batch_cfg( + model: TransformerLM, + tokenizer: AbsTokenizer, + prompt: list, + num_variations: list, + max_new_tokens: int, + cfg_gamma: float, + embedding: list[float], + temp: float, + force_end=False, + top_p: float | None = None, + min_p: float | None = None, +): + assert top_p is not None or min_p is not None + assert 0.0 <= temp <= 2.0 + assert 0.0 <= cfg_gamma <= 10.0 + if top_p is not None: + assert 0.5 <= top_p <= 1.0 + if min_p is not None: + assert 0.0 <= min_p <= 1.0 + if force_end: + assert max_new_tokens > 130, "prompt too long to use force_end=True" + + prompt_len = len(prompt) + num_variations = 2 * num_variations # For CFG + + model.eval() + dim_tok_inserted = [False for _ in range(num_variations)] + eos_tok_seen = [False for _ in range(num_variations)] + total_len = prompt_len + max_new_tokens + seq = mx.stack( + [ + mx.array( + tokenizer.encode( + prompt + [tokenizer.pad_tok] * (total_len - prompt_len) + ), + dtype=mx.int32, + ) + for _ in range(num_variations) + ] + ) + + model.setup_cache( + batch_size=num_variations, + max_seq_len=total_len, + dtype=DTYPE, + ) + + condition_embedding = mx.array( + [embedding for _ in range(num_variations)], + dtype=DTYPE, + ) + model.fill_condition_kv(cond_emb=condition_embedding) + embedding_offset = 1 + pad_idxs = mx.zeros_like(seq) + pad_idxs[1::2, 0] = True + + print( + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, cfg={cfg_gamma}, gen_len={max_new_tokens}" + ) + + CFG_WARM_UP_STEPS = min(250, max_new_tokens) + curr_step = 0 + for idx in ( + pbar := tqdm( + range(prompt_len, total_len), + total=total_len - 
prompt_len, + leave=False, + ) + ): + if idx == prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=mx.arange(embedding_offset, idx + embedding_offset), + pad_idxs=pad_idxs, + )[:, -1] + else: + logits = decode_one( + model, + idxs=seq[:, idx - 1 : idx], + input_pos=mx.array( + [(idx + embedding_offset) - 1], + dtype=mx.int32, + ), + pad_idxs=pad_idxs, + ) + + curr_step += 1 + _cfg_gamma = min(cfg_gamma, (curr_step / CFG_WARM_UP_STEPS) * cfg_gamma) + + logits_cfg = _cfg_gamma * logits[::2] + (1 - _cfg_gamma) * logits[1::2] + logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + + if temp > 0.0: + probs = mx.softmax(logits_cfg / temp, axis=-1) + + if min_p is not None: + next_token_ids = sample_min_p_mlx(probs, min_p).flatten() + else: + next_token_ids = sample_top_p_mlx(probs, top_p).flatten() + else: + next_token_ids = mx.argmax(logits_cfg, axis=-1).flatten() + + next_token_ids = mx.repeat(next_token_ids, repeats=2) + + update_seq_ids_( + seq=seq, + idx=idx, + next_token_ids=next_token_ids, + dim_tok_inserted=dim_tok_inserted, + eos_tok_seen=eos_tok_seen, + max_len=total_len, + force_end=force_end, + tokenizer=tokenizer, + ) + + if all(seen_eos is True for seen_eos in eos_tok_seen): + break + + decoded_results = [tokenizer.decode(s) for s in seq.tolist()][::2] + decoded_results = [ + ( + res[: res.index(tokenizer.eos_tok) + 1] + if tokenizer.eos_tok in res + else res + ) + for res in decoded_results + ] + + return decoded_results + + def sample_min_p_mlx(probs: mx.array, p_base: float) -> mx.array: """See - https://arxiv.org/pdf/2407.01082""" diff --git a/aria/run.py b/aria/run.py index 2b226db..a0e4074 100644 --- a/aria/run.py +++ b/aria/run.py @@ -374,15 +374,28 @@ def conditioned_generate(args): ) elif backend == "mlx": - from aria.inference.sample_mlx import sample_batch as sample_batch_mlx + from aria.inference.sample_mlx import ( + sample_batch_cfg as sample_batch_cfg_mlx, + ) model = _load_inference_model_mlx( checkpoint_path=args.checkpoint_path, - config_name="medium", + config_name="medium-emb", strict=True, ) - - raise NotImplementedError + results = sample_batch_cfg_mlx( + model=model, + tokenizer=tokenizer, + prompt=prompt, + num_variations=num_variations, + max_new_tokens=max_new_tokens, + cfg_gamma=args.cfg, + embedding=embedding, + temp=args.temp, + force_end=args.end, + top_p=args.top_p, + min_p=args.min_p, + ) for idx, tokenized_seq in enumerate(results): res_midi_dict = tokenizer.detokenize(tokenized_seq) diff --git a/models/placeholder.txt b/models/placeholder.txt deleted file mode 100644 index e69de29..0000000 From 9ba3a0006231200d784cd237df962ceceb67956d Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Jun 2025 14:38:05 +0100 Subject: [PATCH 59/72] fix mlx backend for conditioned gen --- aria/inference/sample_mlx.py | 8 +++++--- aria/run.py | 28 +++++++++++++++++++++------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/aria/inference/sample_mlx.py b/aria/inference/sample_mlx.py index 9dccb7b..c6a6e25 100644 --- a/aria/inference/sample_mlx.py +++ b/aria/inference/sample_mlx.py @@ -137,7 +137,7 @@ def sample_batch( logits = prefill( model, idxs=seq[:, :idx], - input_pos=mx.arange(0, idx), + input_pos=mx.arange(0, idx, dtype=mx.int32), )[:, -1] else: logits = decode_one( @@ -239,7 +239,7 @@ def sample_batch_cfg( ) model.fill_condition_kv(cond_emb=condition_embedding) embedding_offset = 1 - pad_idxs = mx.zeros_like(seq) + pad_idxs = mx.zeros(seq.shape, dtype=mx.bool_) pad_idxs[1::2, 0] = True print( @@ -259,7 
+259,9 @@ def sample_batch_cfg( logits = prefill( model, idxs=seq[:, :idx], - input_pos=mx.arange(embedding_offset, idx + embedding_offset), + input_pos=mx.arange( + embedding_offset, idx + embedding_offset, dtype=mx.int32 + ), pad_idxs=pad_idxs, )[:, -1] else: diff --git a/aria/run.py b/aria/run.py index a0e4074..7c7a687 100644 --- a/aria/run.py +++ b/aria/run.py @@ -15,14 +15,20 @@ def _parse_generate_args(): help="backend for inference", ) argp.add_argument( - "--checkpoint_path", help="path to model used for decoding" + "--checkpoint_path", + help="path to model used for decoding", + required=True, + ) + argp.add_argument( + "--prompt_midi_path", + help="path to midi file", + required=True, ) - argp.add_argument("--prompt_midi_path", help="path to midi file") argp.add_argument( "--prompt_duration", help="length of the input MIDI prompt, in seconds", type=int, - default=20, + default=15, ) argp.add_argument( "--variations", @@ -83,22 +89,30 @@ def _parse_conditioned_generate_args(): help="backend for inference", ) argp.add_argument( - "--checkpoint_path", help="path to model used for decoding" + "--checkpoint_path", + help="path to model used for decoding", + required=True, + ) + argp.add_argument( + "--prompt_midi_path", + help="path to midi file", + required=True, ) - argp.add_argument("--prompt_midi_path", help="path to midi file") argp.add_argument( "--prompt_duration", help="length of the input MIDI prompt, in seconds", type=int, - default=20, + default=15, ) argp.add_argument( "--embedding_model_checkpoint_path", help="path to model checkpoint used for embeddings", + required=True, ) argp.add_argument( "--embedding_midi_path", help="path to MIDI file used for conditioning", + required=True, ) argp.add_argument( "--variations", @@ -280,7 +294,7 @@ def generate(args): model = _load_inference_model_mlx( checkpoint_path=args.checkpoint_path, config_name="medium", - strict=True, + strict=False, ) results = sample_batch_mlx( model=model, From 225203991dc444a60845400ab7d581f63cf341bd Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Jun 2025 15:29:36 +0000 Subject: [PATCH 60/72] update cli flags to standard unix format --- aria/run.py | 85 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 24 deletions(-) diff --git a/aria/run.py b/aria/run.py index 7c7a687..71c319b 100644 --- a/aria/run.py +++ b/aria/run.py @@ -274,8 +274,8 @@ def generate(args): model = _load_inference_model_torch( checkpoint_path=args.checkpoint_path, config_name="medium", - strict=True, - ) # Might want strict = False + strict=False, + ) results = sample_batch_t( model=model, tokenizer=tokenizer, @@ -316,7 +316,6 @@ def generate(args): print(f"Results saved to {os.path.realpath(args.save_dir)}") -# TODO: Double checking during training we didn't do a weighted global sum def _get_embedding( embedding_model_checkpoints_path: str, embedding_midi_path: str, @@ -419,26 +418,39 @@ def conditioned_generate(args): print(f"Results saved to {os.path.realpath(args.save_dir)}") -# TODO: Add turn - to -- flags def _parse_midi_dataset_args(): argp = argparse.ArgumentParser(prog="aria midi-dataset") - argp.add_argument("dir", help="directory containing midi files") - argp.add_argument("save_path", help="path to save dataset") - argp.add_argument("-r", action="store_true", help="recursively search dirs") argp.add_argument( - "-s", action="store_true", help="shuffle dataset", default=False + "dir", + help="directory containing midi files", + ) + argp.add_argument( + "save_path", + help="path to 
save dataset", ) argp.add_argument( - "-metadata", + "--recursive", + action="store_true", + help="recursively search dirs", + ) + argp.add_argument( + "--shuffle", + action="store_true", + help="shuffle dataset", + ) + argp.add_argument( + "--split", + type=float, + help="create train/val split", + required=False, + ) + argp.add_argument( + "--metadata", nargs=2, metavar=("KEY", "VALUE"), action="append", help="manually add metadata key-value pair when building dataset", ) - argp.add_argument( - "-split", type=float, help="create train/val split", required=False - ) - return argp.parse_args(sys.argv[2:]) @@ -451,10 +463,10 @@ def build_midi_dataset(args): MidiDataset.build_to_file( dir=args.dir, save_path=args.save_path, - recur=args.r, + recur=args.recursive, overwrite=True, manual_metadata=manual_metadata, - shuffle=args.s, + shuffle=args.shuffle, ) if args.split: @@ -469,19 +481,44 @@ def build_midi_dataset(args): # TODO: Add turn - to -- flags def _parse_pretrain_dataset_args(): argp = argparse.ArgumentParser(prog="aria pretrain-dataset") - argp.add_argument("-load_path", help="path midi_dict dataset") - argp.add_argument("-save_dir", help="path to save dataset") argp.add_argument( - "-tokenizer_name", help="tokenizer name", choices=["abs", "rel"] + "--load_path", + help="path midi_dict dataset", + required=True, + ) + argp.add_argument( + "--save_dir", + help="path to save dataset", + required=True, + ) + argp.add_argument( + "--tokenizer_name", + help="tokenizer name", + choices=["abs", "rel"], + required=True, + ) + argp.add_argument( + "--seq_len", + help="sequence length (tokens)", + type=int, + default=4096, ) - argp.add_argument("-l", help="max sequence length", type=int, default=4096) - argp.add_argument("-e", help="num epochs", type=int, default=1) argp.add_argument( - "-sep_sequences", + "--num_epochs", + help="number of epochs to build", + type=int, + default=1, + ) + argp.add_argument( + "--sep_sequences", help="start each with a new entry", action="store_true", ) - argp.add_argument("-embedding_dataset_path", required=False) + argp.add_argument( + "--embedding_dataset_path", + help="path to embedding dataset - same format as EvaluationDataset", + required=False, + ) return argp.parse_args(sys.argv[2:]) @@ -508,8 +545,8 @@ def build_pretraining_dataset(args): PretrainingDataset.build( tokenizer=tokenizer, save_dir=args.save_dir, - max_seq_len=args.l, - num_epochs=args.e, + max_seq_len=args.seq_len, + num_epochs=args.num_epochs, midi_dataset_path=args.load_path, separate_sequences=args.sep_sequences, file_embeddings=file_embeddings, From 5c0a435fe5e5edb34a34a1b1f6507708e1be1a5a Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Jun 2025 15:55:54 +0000 Subject: [PATCH 61/72] migrate to pyproject.toml --- Makefile | 9 ----- aria/training/train.py | 82 ++++++++++++++++++++++-------------------- requirements-dev.txt | 2 -- requirements.txt | 6 ---- setup.py | 25 ------------- 5 files changed, 44 insertions(+), 80 deletions(-) delete mode 100644 Makefile delete mode 100644 requirements-dev.txt delete mode 100644 requirements.txt delete mode 100644 setup.py diff --git a/Makefile b/Makefile deleted file mode 100644 index 774c3fc..0000000 --- a/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -.PHONY: test -test: - python -m unittest tests/test_*.py - - -.PHONY: format -format: - black --line-length 80 ./aria - black --line-length 80 ./tests diff --git a/aria/training/train.py b/aria/training/train.py index 4b2a074..67001a2 100644 --- a/aria/training/train.py +++ b/aria/training/train.py @@ 
-38,25 +38,25 @@ # For example usage you could run the pre-training script with: # # accelerate launch [arguments] aria/train.py train \ -# small \ -# -train_data data/train \ -# -val_data data/val \ -# -epochs 10 \ -# -bs 32 \ -# -workers 8 +# medium \ +# --train_data data/train \ +# --val_data data/val \ +# --epochs 10 \ +# --bs 32 \ +# --workers 8 # # You could resume a run from an accelerate checkpoint with: # # accelerate launch [arguments] aria/train.py resume \ -# small \ -# -train_data data/train \ -# -val_data data/val \ -# -cp_dir models/epoch5_step0 \ -# -r_step 0 \ -# -r_epoch 5 \ -# -epochs 5 \ -# -bs 32 \ -# -workers 8 +# medium \ +# --train_data data/train \ +# --val_data data/val \ +# --cp_dir models/epoch5_step0 \ +# --r_step 0 \ +# --r_epoch 5 \ +# --epochs 5 \ +# --bs 32 \ +# --workers 8 def setup_logger(project_dir: str): @@ -153,7 +153,7 @@ def _get_optim( num_epochs: int, steps_per_epoch: int, warmup: int = 100, - end_ratio: int = 0.1, + end_ratio: float = 0.1, ): optimizer = torch.optim.AdamW( model.parameters(), @@ -799,26 +799,30 @@ def _load_state_dict(_tokenizer: Tokenizer): def parse_resume_args(): argp = argparse.ArgumentParser(prog="python aria/train.py resume") argp.add_argument("model", help="name of model config file") - argp.add_argument("-train_data", nargs="+", help="path to train dir") - argp.add_argument("-val_data", help="path to val dir") - argp.add_argument("-cp_dir", help="checkpoint dir", type=str, required=True) argp.add_argument( - "-use_embeddings", help="prepend embeddings", action="store_true" + "--train_data", nargs="+", help="path to train dir", required=True ) - argp.add_argument("-r_step", help="resume step", type=int, required=True) - argp.add_argument("-r_epoch", help="resume epoch", type=int, required=True) - argp.add_argument("-epochs", help="train epochs", type=int, required=True) - argp.add_argument("-bs", help="batch size", type=int, default=32) + argp.add_argument("--val_data", help="path to val dir", required=True) argp.add_argument( - "-grad_acc_steps", + "--cp_dir", help="checkpoint dir", type=str, required=True + ) + argp.add_argument( + "--use_embeddings", help="prepend embeddings", action="store_true" + ) + argp.add_argument("--r_step", help="resume step", type=int, required=True) + argp.add_argument("--r_epoch", help="resume epoch", type=int, required=True) + argp.add_argument("--epochs", help="train epochs", type=int, required=True) + argp.add_argument("--bs", help="batch size", type=int, default=32) + argp.add_argument( + "--grad_acc_steps", help="gradient accumulation steps", type=int, default=1, ) - argp.add_argument("-workers", help="number workers", type=int, default=1) - argp.add_argument("-pdir", help="project dir", type=str, required=False) + argp.add_argument("--workers", help="number workers", type=int, default=1) + argp.add_argument("--pdir", help="project dir", type=str, required=False) argp.add_argument( - "-spc", help="steps per checkpoint", type=int, required=False + "--spc", help="steps per checkpoint", type=int, required=False ) return argp.parse_args(sys.argv[2:]) @@ -827,26 +831,28 @@ def parse_resume_args(): def parse_train_args(): argp = argparse.ArgumentParser(prog="python aria/train.py train") argp.add_argument("model", help="name of model config file") - argp.add_argument("-train_data", nargs="+", help="path to train dir") - argp.add_argument("-val_data", help="path to val dir") argp.add_argument( - "-cp_path", help="path to checkpoint", required=False, default=None + "--train_data", nargs="+", 
help="path to train dir", required=True + ) + argp.add_argument("--val_data", help="path to val dir", required=True) + argp.add_argument( + "--cp_path", help="path to checkpoint", required=False, default=None ) argp.add_argument( - "-use_embeddings", help="prepend embeddings", action="store_true" + "--use_embeddings", help="prepend embeddings", action="store_true" ) - argp.add_argument("-epochs", help="train epochs", type=int, required=True) - argp.add_argument("-bs", help="batch size", type=int, default=32) + argp.add_argument("--epochs", help="train epochs", type=int, required=True) + argp.add_argument("--bs", help="batch size", type=int, default=32) argp.add_argument( - "-grad_acc_steps", + "--grad_acc_steps", help="gradient accumulation steps", type=int, default=1, ) - argp.add_argument("-workers", help="number workers", type=int, default=1) - argp.add_argument("-pdir", help="project dir", type=str, required=False) + argp.add_argument("--workers", help="number workers", type=int, default=1) + argp.add_argument("--pdir", help="project dir", type=str, required=False) argp.add_argument( - "-spc", help="steps per checkpoint", type=int, required=False + "--spc", help="steps per checkpoint", type=int, required=False ) return argp.parse_args(sys.argv[2:]) diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 28a28f7..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,2 +0,0 @@ -flake8 -black diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 6c397c7..0000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -ariautils @ git+https://github.com/EleutherAI/aria-utils.git -torch >= 2.3 -accelerate -jsonlines -tqdm -safetensors diff --git a/setup.py b/setup.py deleted file mode 100644 index bc558f2..0000000 --- a/setup.py +++ /dev/null @@ -1,25 +0,0 @@ -import os - -import pkg_resources -from setuptools import find_packages, setup - -setup( - name="aria", - py_modules=["aria"], - version="0.0.1", - description="", - author="", - packages=find_packages() + ["config"], - include_package_data=True, - entry_points={ - "console_scripts": [ - "aria=aria.run:main", - ], - }, - install_requires=[ - str(r) - for r in pkg_resources.parse_requirements( - open(os.path.join(os.path.dirname(__file__), "requirements.txt")) - ) - ], -) From 0878402c8c23e3891a09ca95c7b7865e0b1a7082 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Jun 2025 15:59:18 +0000 Subject: [PATCH 62/72] add toml --- pyproject.toml | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..199cfb8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,41 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "aria" +version = "0.0.1" +description = "" +authors = [{name = "Louis Bradshaw", email = "loua19@outlook.com"}] +requires-python = ">=3.11" + +dependencies = [ + "ariautils @ git+https://github.com/EleutherAI/aria-utils.git", + "torch>=2.3", + "mlx", + "safetensors", + "jsonlines", + "tqdm", +] + +[project.optional-dependencies] +dev = ["black"] +train = ["accelerate"] +eval = ["transformers", "torchaudio", "mido"] +demo = ["python-rtmidi"] +all = ["black", "accelerate", "transformers", "torchaudio", "mido", "python-rtmidi"] + +[tool.black] +line-length = 80 +target-version = ["py311"] +include = '\.pyi?$' + +[project.scripts] +aria = "aria.run:main" + +[tool.setuptools.packages.find] 
+where = ["."] +include = ["aria", "aria.*"] + +[tool.setuptools.package-data] +aria = ["../config/*.json", "../config/models/*.json"] From f86435f7c3b4ce5301aa2c41df432e120135794f Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Jun 2025 16:00:29 +0000 Subject: [PATCH 63/72] remove old plan --- aria/eval/__init__.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/aria/eval/__init__.py b/aria/eval/__init__.py index 6b4e906..e69de29 100644 --- a/aria/eval/__init__.py +++ b/aria/eval/__init__.py @@ -1,25 +0,0 @@ -""" -Plan: - -Embeddings experiments: - ARIA-MIDI - Classification of genre classical vs jazz (vs other?) - ARIA-MIDI - Classification musical time period (classical, baroque, ect...) - ARIA-MIDI - Classification of top 5 composers - ARIA-MIDI - Classification of top 5 pianists - - Pianist8 - Classification -- should work - VGMIDI - Classification -- Probably won't work as it's multi-track - WikiMT - Classification -- No idea - -Ablation comparisons: - Frozen classical embeddings (both mean and last-token from pretrained model) - Aria finetuned on these specific classification tasks (define new token) - TODO - Aria trained from scratch on these specific classification tasks (define new token) - TODO - Aria finetuned with contrastive learning - TODO - Aria trained with contrastive learning without next-token pretraining - TODO (Maybe skip) - -Other model comparisons: - Clamp2 or Clamp3 - MusicBERT - -""" From d5f46b9c40adde539b86e408a32aa203c891098c Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 9 Jun 2025 17:22:06 +0000 Subject: [PATCH 64/72] add README draft --- .gitignore | 3 + README.md | 137 +++++++++++++++++++++++--------- example-prompts/classical.mid | Bin 0 -> 10045 bytes example-prompts/nocturne.mid | Bin 0 -> 7444 bytes example-prompts/pokey_jazz.mid | Bin 0 -> 14240 bytes example-prompts/smooth_jazz.mid | Bin 0 -> 6961 bytes example-prompts/yesterday.mid | Bin 0 -> 6536 bytes 7 files changed, 102 insertions(+), 38 deletions(-) create mode 100644 example-prompts/classical.mid create mode 100644 example-prompts/nocturne.mid create mode 100644 example-prompts/pokey_jazz.mid create mode 100644 example-prompts/smooth_jazz.mid create mode 100644 example-prompts/yesterday.mid diff --git a/.gitignore b/.gitignore index dbc6464..946afa1 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,6 @@ fluidsynth/ tests/test_results lightning_logs/ .vscode/ +paper +hf +_scripts diff --git a/README.md b/README.md index 7b11af7..4475208 100644 --- a/README.md +++ b/README.md @@ -1,64 +1,125 @@ -# gpt-aria +# Aria -[Discord](https://discord.com/invite/zBGx3azzUn) +This repository contains training, inference, and evaluation code for the paper *Scaling Self-Supervised Representation Learning for Symbolic Piano Performance*, as well as implementations of our real-time piano continuation demo. *Aria* is a pretrained autoregressive generative model for symbolic music, based on the LLaMA 3.2 (1B) architecture, which was trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. Alongside the base model, we are releasing a checkpoint finetuned to improve generative quality, as well as a checkpoint finetuned to produce general-purpose piano MIDI embeddings using a SimCSE-style contrastive training objective. -A repository containing resources for pre-training, fine-tuning, and evaluating musical (MIDI) transformer models. 
+📖 Read our [release blog post](https://example.com/) and [paper](https://example.com/)
+🤗 Access our models via the [HuggingFace page](https://huggingface.co/loubb/aria-medium-base)
+📊 Get access to our training dataset [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) and train your own models
 
-***Note that this project is under active development***
+## Installation
 
-## Description
+Installation requires Python 3.11+. To install the package and all dependencies with pip:
 
-The main goal of the gpt-aria project is to create a suite of powerful pre-trained generative (symbolic) music models. We want to investigate how modern training (pre-training & fine-tuning) techniques can be used to improve the quality/usefulness of such models. Alongside this we are building various data (MIDI) preprocessing tools, allowing **you** to easily fine-tune our models on your own data.
+```bash
+git clone https://github.com/EleutherAI/aria
+cd aria
+pip install -e ".[all]"
+```
 
-If you are new to symbolic music models, a good place to start are the following projects/blogposts by Google Magenta and OpenAI:
+## Quickstart
 
-- [Music Transformer](https://magenta.tensorflow.org/music-transformer)
-- [MuseNet](https://openai.com/research/musenet)
+Download model weights from the official HuggingFace page for our pretrained model, as well as checkpoints finetuned for piano-continuation and generating MIDI-embeddings:
 
- Long story short: Transformer + MIDI + GPUs = 🎵 x ∞
+- `aria-medium-base` ([huggingface](https://example.com/), [direct-download](https://example.com/))
+- `aria-medium-gen` ([huggingface](https://example.com/), [direct-download](https://example.com/))
+- `aria-medium-embedding` ([huggingface](https://example.com/), [direct-download](https://example.com/))
 
-## Installation
+### Inference (Prompt Continuation)
 
-Make sure you are using Python 3.10+. Note that I haven't explicitly developed this project for anything other than Linux. If you are using Windows, things might not work properly. In this case I suggest installing using WSL.
+We provide optimized model implementations for PyTorch (CUDA) and MLX (Apple Silicon). You can generate continuations of a MIDI file using the CLI, e.g., using CUDA (Linux):
 
-```
-git clone https://github.com/eleutherai/aria
-cd aria
-pip install -e .
+```bash
+aria generate \
+    --backend torch_cuda \
+    --checkpoint_path <checkpoint-path> \
+    --prompt_midi_path <prompt-midi-path> \
+    --prompt_duration <duration-seconds> \
+    --variations <num-variations> \
+    --temp 0.98 \
+    --min_p 0.035 \
+    --length 2048 \
+    --save_dir <save-dir>
 ```
 
-## Inference
+Since the model has not been post-trained with instruction tuning or RLHF (similar to pre-instruct GPT models), it is very sensitive to input quality and performs best when prompted with well-played music. To get sample MIDI files, see the `example-prompts/` directory or explore the Aria-MIDI dataset. For a full list of sampling options: `aria generate -h`. If you wish to do inference on the CPU, please see the platform-agnostic implementation on our HuggingFace page [link].
 
-You can find preliminary checkpoints at the following locations
+### Inference (MIDI embeddings)
 
-Finetuned piano-only checkpoints (improved robustness):
+You can generate embeddings from MIDI files using the `aria.embeddings` module. 
This is primarily exposed via the `get_global_embedding_from_midi` function, for example:
+
+```python
+from aria.embeddings import get_global_embedding_from_midi
+from aria.model import TransformerEMB, ModelConfig
+from aria.config import load_model_config
+from ariautils.tokenizer import AbsTokenizer
+from safetensors.torch import load_file
+
+# Load model
+model_config = ModelConfig(**load_model_config(name="medium-emb"))
+model_config.set_vocab_size(AbsTokenizer().vocab_size)
+model = TransformerEMB(model_config)
+state_dict = load_file(filename=CHECKPOINT_PATH)
+model.load_state_dict(state_dict=state_dict, strict=True)
+
+# Generate embedding
+embedding = get_global_embedding_from_midi(
+    model=model,
+    midi_path=MIDI_PATH,
+    device="cpu",
+)
+```
+
+Our embedding model was trained to capture composition-level and performance-level attributes, and therefore might not be appropriate for every use case.
+
+## Real-time Demo
+
+In `demo/` we provide CUDA (Linux/PyTorch) and MLX (Apple Silicon) implementations of the real-time interactive piano-continuation demo showcased in our release blog post. For the demo we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface.
+
+❗**NOTE**: Responsiveness of the real-time demo is dependent on your system configuration, e.g., GPU FLOPS and memory bandwidth.
+
+A MIDI input device is not strictly required to play around with the demo: by using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI, ALSA) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output.
+
+Example usage (MLX):
+
+```bash
+MIDI_PATH="example-prompts/pokey_jazz.mid"
+
+python demo/demo_mlx.py \
+    --checkpoint <checkpoint-path> \
+    --midi_path ${MIDI_PATH} \
+    --midi_through <through-port> \
+    --midi_out <out-port> \
+    --save_path <save-path> \
+    --temp 0.98 \
+    --min_p 0.035
+```
+
+## Evaluation
+
+We provide the specific files/splits we used for Aria-MIDI-derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://example.com/)). Class labels are provided in `metadata.json` with the schema:
+
+```json
+{
+    "<metadata-category>": {
+        "<split>": {
+            "<filename>": "<class-label>",
+            …
+        },
+        …
+    },
+    …
+}
+```
+
+## License and Attribution
+
+The Aria project has been kindly supported by EleutherAI and Stability AI, as well as by a compute grant from the Ministry of Science and ICT of Korea. Our models and MIDI tooling are released under the Apache-2.0 license. 
If you use the models or tooling for follow-up work, please cite the paper in which they were introduced: +```bibtex +@inproceedings{bradshawscaling, + title={Scaling Self-Supervised Representation Learning for Symbolic Piano Performance}, + author={Bradshaw, Louis and Fan, Honglu and Spangher, Alex and Biderman, Stella and Colton, Simon}, + booktitle={arXiv preprint}, + year={2025}, + url={https://arxiv.org/abs/2504.15071} +} +``` \ No newline at end of file diff --git a/example-prompts/classical.mid b/example-prompts/classical.mid new file mode 100644 index 0000000000000000000000000000000000000000..4d4db8991eacd0ab6e2093bb713c4c86f3296fc7 GIT binary patch literal 10045 zcmai)ZEGA?mdEd8z24z62s0ExvBzbj5_Tz4X~mV~YTD&)RV2#o?r7Aj2q$rbS&R`T z3r3h>3Bs(YYB}D0we=PD8_Y-854G=Sf9IU)wwwe8CX;mCd(J%%|MPUO4o_aabk1Gp zpWprN{|-;y{F`(4@4LVK%Rl_n|J-%|>;8NC*S&YXU5%SJ>t8g>TS>E&2fv-x7hBG4 zCH2L+bL(ldkvo^?{(cs(@|zXsww?Ryo~!Tfy?fSfW%c=X)~x6*R}}nWLuO1uGcFt- zBZ%0*%6;d0nY))MgFo!Od)WpU1m{B5EX&^9?N`3XuA;fNo7HEs6q=2FH~*fyLE%?I z#O8afIk&6Q{3OC`hWcXDxy`&;8BhGcxvF%7QrYG1psY`Kkxw?1&ED9#aix3MR>(Fv zqwx0&*$>kd_BnSj=l-(YF6;BWY*ye4an*5$i#wVqaaAAp`A1AqHVfDI`FWpFZL`7Af_0U1=J8AY#oI9yJKQIIW zcc&@IUFSCVkJ25MZd!_|eS5E{KgZwC+K^+|qFIJ@QMiN3jV8V<>kC+)q4jPtyU7yL z!nfpqIX?00nZJ`p9J4Svc0qze=cXqA&fdHCuiM9khvs2`EDA}q8~Z$o01dogdgtTT zpW3Z7T1~_r=0eCGLuotCVV1T<=JPb(PA>m6$Y;GF7)(Up?KlF7hA8-=UC|C)eTsI? zayuwyH}ZhYsccOyxBB^PwUpD+`1AJa;pm|1K=r9E z&qS9mMVF0skT(zDG7M3k6CI&>&@W7x)UTxIG_RhjhiJLon=ncRzZ zX@nYUiTG2o0mYdpv(@$z-4OqLsyqDcIJRJ5UkyaxPg`M=d1A1adhB#C%x5=bbMDH* zjuXBZ_jbp9=!n%qB)ZuS^X3F0dIbBRXjX(58a*uidgAY>K*$FzXh3l@Hb5|oypP)5 zBnn_Ez1%8r>mZLMM?ng<)VVM;ICVf8sqsUMM_r!&eBMhVTMr4J{5=IHPJ@AHNJe{K zwPhKhAxfYjyF+r>-!GfD@aEeh08C1#d>c~Y2xhPEIWKyC)k-QM$H0aV=q1fcNkSNw zh+O*3#IGb0U2uavThW(tSB8)kU+2% zqzbiZI{?BlRGj-^{3weok8Q?oB^J+^fxE0D;C$AS^&UC*2w@Y~m!P}+?x39AsF0Nd zQ#(-LX0k$ytWa<9NKXwaw4(tg!ED683~n46rPBvMwB)WlfOc@8&wrR`quMT+<+AMIfDg9UnDFvD0>=$OK$i=eYr>$aF9fb(c2j>TJ~B55I1{Hx zIYPAsXzOF*=>@SYJFq%_+)bNhK(SO}S~W_<>UPQoN8PkIF!ta{sLl_+tU{^9LazbL& z15F4UXe$Z79icsFCFE0I#6O=YX@|jG;V5wxh$tc?CeC0Mq)_t}vq-ds4@<*b%!B}W zCOJd`S0Q;drcCkr6ktZ8bDn9=*FatW;)C@Fp0Ha%VB85drF7@AiTdIZos zZ-!vQ!BjVFHEqYP{tR11?RIFcjZ07B-lKAjXdQ~YKCcQ1BR-A~NNwy4H(B+UDn8x_JLZK0P!61Urw z{2X=msUpKfBx$fgb@hDEf(o7aJ{JYHa(WRxlxCn=&XJy@NZ#YJ*_f2k{U8mY#S&_; zW;Te$27uBDpZ4URi75vS0(Wy z7(Fhd^&(zqVk`g$y_OY3F$|4y7Q?O{!H;5!Z69iTO0(EJM?e6+u3!raez+{!E2S5! z#BC5{aXR#8(4W$`zPeWgi;e|lS4;;23)*#B#pW97(lwrHQ;HEQe^fzA^yeQ-8>%BT zwPIR{4T|>21(Er)HqhuERdKJX&r3T+n@=Phd-iJ)CIS$}0;F^rw~~5ZNA4Jmd;e=w;B5uAaWTj|oxk z4LNv4!GU{?(#UEeqmmP^0mWr{@<&B|+6P$-NK@>7Jc)%gtw&NgX>y>SV*h{znkA{R z_#S1BDi1>lPbC;Ck8+2T=B=zu{A!T21*KI2Ho^khW0gF#hn8YAEnMU{Uzqk0WLjyy zqaQnl3&k>%lm##N!o})rXs1RmA<-=D6wwr;<+xIW5(p4Qg$yunITn^{MY@P{d;%j% zrO6WEPz_IH$?K7X#sSk7VPT|`(ho7t0UTxgYEI;H=~q;)1;D$!*ZE~!c44A^W&9iz z?g(`DK@6C~JFoB}8jE>^c&jVrztJ#1FF~60|z>C%K@pADJh=) zxNsqTV?h%j#SxdM(R%9kDSFh^)Q#qJi(@UL8E?*i+1}B%8HK4oL%IWL<7gAqGg&4(NbXPj^qMGz-UTh57D{4I4PqEUQ`w2u7uVE z`7e5Rh^%7@^+F+hKN{ zBmtNaVwC&CUv&tTdJg%HRm5)@H8Z!SMtx1(EVU&yCRmWzo{ObpsoWZoz!7TOL8$Ix-jgaq{0Z12 zePLuSFVu^+2@N=RQPL>ux92(aCDx9lNMdcgFVmPh(p-a4w?({WWuGBcO=foGt8{2X zE&1Y#3

;SEF~f&6+ntmHFk|t>z!ELjX5uz@*jzhIZL8DakNf4>To!sXWbN!&KTK z&qP_m<%GR>OHm*P}}usN^}Yd5sWKtYnJKv_4h*Nmd4FJxS6kI5yG+0adr4 z@MM5NzB56khK#`pp0pUffY5?UASVcxO$$`Yj@fmwfwmJ<%_7q91`qUZ{}wi(M|=eL z?O`@$m!RPsR4Q8{gvhcpI1fQskb5@CeROqtxGiubnz0N%WSI+JI&dg@h5k-xrS*r& z&A24HKm$3ITx?AEw1^i#I`4$JcDh}>yln+UtSdLV8lsSPmiX8nh?V62xITZX;lo~% zA{_?#%S@>|Y8h|m8VN%NSrdFDX>-zH%P3fe=dRo!&#r$z@dS~?bK%wsH!a+W+`|4w zH;++BuivVHi}to-`p|GYYbAqF!xKM9e&RUknQ|r!VSHk>UGo z>=vEQiC|-my>J>(aO6(-sIx(_@G~9*%#<_3Q_g8B)l%|Z$ONg%_ zupXwS1BqGq$jyL%NRibeJCKkGL`ZTmEjfl62f5=hq2@xWqBMsh0>iW}k8||-6%X)h z)VPynV=ihGLPl?*S~PolP&{#+uYaz9M{(WEPm z24YdKi~QyR2Y)}+U_wI;4LeU26MKkZ5_D*FT*>ImQ`BbZRz6j4Q6kWgEPzyKc&dJT*{+y+=PjYwM)g5j!VfclIeua~rf&@&*&vC-j zLUR#L10$1&33sLBCGBH&g%Xw&KF+elHq}0c0Pd5rsHZldbZI4sYj|AfthFZ0Dt4`gGncndB6O4s%>RWG%7^0Cv3t>*n~Ckhx~|u$LNz~80qu`3DP<@uW*j&z*7ZhsB;!> zfewa>q_0D?mAKV3nz~B*p;Kw-`clo@4r9@jXIklm1`lFK}{pShQr+HMo9qhlGCe5H$Z&Q=1E z-8xD^pk1Mr9Xd2)F4Xok5}M~m!6$?ufxwbYYn{#UBIyYAR?=uQeumlUBj}RoBl-+S zVw_ZqXQoZ%7xgijh>jm3UY;bouu(TjdPumy(Cp4>tm%S1&$t0?_n{^ua0m*pquQY%AHeke4m)p_qw(k$6x8#tZBNo%%D-GLkomKBsgCTOYyg{3mQnVaxpS z$tD55_5F;wpY|No#ICkSeIx3!SH}#}E1`(Vm*76k2RW zhsCG`H##)IU11kkJfa5v+(S0mc_t``;O`kggxK*9trh>~Igj-K*Y_g*D-*cN%~Tc( zu;STM`OVcBKT^F4<4EF-yA#8C8T4&e{uZUHV9Fe9y^>evUfPTWbP^L%V`ZM{MzhY> z`sy|Ni^)k6VEmH9lIAUsBg`;dR&gXZVR{E9O%f8_(d3QtCPK0$^QY`}aJ82eQ#gmhxyO65)(N<{s-?s|Es^6`sMjEIy4L2 z5CD44wLV&uV+JPqP8np%?*!$KXI_x^1*P2z=O+8jp8J91oVp_s=|6v5P2%O` z6Dc^hJClcZOaXtf2HR7=mAnt6nlYgum#D$t!XRo5D@f)w(}Q#YD;iPE@`DM&ljrQm z7RV05?5muhOUbG-&G1*tOFqcZ@1 z6hc~)BV@9;{5^g}Uj7=U$Y1`#WE}*EU>b=xgyoHX*6fjuzcdXI2MWx~#@GBAgUrMH z6-<#e<9%)eN}k>;d48W?R)99lfm5&rB>X8tv(k=Hf0WO9N7?M=VRrea<8b+6n*BU_ jk^86F?Ca9~ymR0HxMQ13F{CBwV{N-$C?of`e|G;57}w98 literal 0 HcmV?d00001 diff --git a/example-prompts/nocturne.mid b/example-prompts/nocturne.mid new file mode 100644 index 0000000000000000000000000000000000000000..00d41ff9ddd95d57eb8851a01908410b7d97df9a GIT binary patch literal 7444 zcma)>?P?@Rc7|hRu^y79*&m6}3`-(LtyW5^l%>!|sx_gKLR}qG>g=jc7pfUk2rn!` zG;33c!I+|H+(cII8vn~%@Ey!0_@eyV>*qwMdbA&2FU-)TsEG6Vo^#%a8eG4-RZ9Jw z|9|$g{~lcb@J~wp@_*Ft|Nd|O_CMFuf2rpB$BU1@yJ4HC}?<4YkW|v{f(%_%Ag^tEuz|@ z=|o$)J$47xvp-Uec(j{#;A(tterSWO>xj|#gU@GWvwG!nTSV7!ucuVc=AyLlhhmp* zmV>YJ-vn*xl~Pxnl(lj;s{EmZz&6Wu-CS(t(^{#z%CY88O+oZPsZp@WZ(n?zw0k%f zQb1^o@gq{uA5>|t^j(RJX}(6=X`!yHiQ!}wDd^ZzYELN>WYz6J-Qu5a&q%2yRbqG1 zET3{XPsHHTKdH((n4TP&x>BlgCVU~E_pPcbEq_+wtMRRTc6!%l{c0Lc-MX>g-pnU4 zvB<{~zCZ)#ca`b~OOke2`)y2_w!5V60Jn}Y=@f(=D|L*wgE|qb3jhAHY@VLzw5klp zol=`h4V5}zn^KoSTHc8bW+3 zb`W9*dQZxVkru<*P_oIuC zwP3yS>(3!3U?_ zRQ~;rQac1#mI-S;5jc*g01hA7^^C7+np%{n?e)aq>HNF)q)fgc7=`Z}cY8+b66m{x z$~09bXg8XSZz=W7!h-o%Euuw@qtDYo)SzrW^vonIWm`oiOO*r|SqQogCk^U$TQIn1 z=HJmaj4|v((x6fdj$=7APR3)uW~2I8n#Omp$NrGE{MNahG8v7ND46t8pHT$n;Uuas z0d3=;@RnxqkL|umkIMOP&Wt(@IX_}_x+)g+gs8ra-PkC!O5K>|-Akh-l~vuZ%$`Xf zCy}fWP#{4liDS~fsr6e$D$T!Yr`1o9@L#6U?a6)LmYESS~lT6<= zZgy7r?U2LJ-<-tk$V_4;5g3z6O8i>ebJKoRhw!HdX8!ZnRe(;F>QVqe4vBM&EkF!j zkH)0k0iaJ`v@gK$DY!Gvccsj$=7OZwi`a#}M z`4yx!*X?Qf=?e+IQA3-a7oij7pMJ@b)v`LU&AXm9eKV!-D_ett2!U9p!_W=l6mspQ zPKQof9{haMeCCyxiZpA@DO=ui6fYHvKCCMaj_RtQUr2P3djjT z+QKEo|9@_6F{NX#kWrx+8`*6P7~8BJj;SdJ9t%-ldyWtmhn_ar9o$ z$EfTv{D^w7XyXqfDKJCh*RZswg@bT#-!Ro5I#TRAeq~hE_|t(|)U{^Y{+Em=W47{I zvSQP}W<1fG{Gl+jQ7{(DL2~ab{|>4IrYWsenAs_Y+y8GoO$RC?0(E;rHh|x#3#f1g zdD|h+{)tSU4yyJjha+Q#b{68TMRJ|)n&h$+pixk0c-*x>2`HPj`w(clre)r1f5`M% zJ2wU(Ueb-2$aS)rKTKc(_Sk1MAj32JZrfc3*0J&Z-7rwzcWlllFpZ(?p(D=P0}9}* z6x<(!$di)4geq%WLb{WUJdAzguoW1_qM+;%fGx4GB(~s#N?L{zF>9M2#N5ZoLfNc6 zQ`#!iQ*^r&?n=hF@st7v+T?MrPs`hLozf53ov_<49V?-5^Jgv8_5Knj(_@xHGvC8u 
z`&ZVlFo5qQCUgMlibCiJCy5~+(YOTeDJuMZdsR1`YW(Tf9oKXQb@OyMZdPml-z+Di zDx|#5aFbXffNYUj{6#PtFv%gL&!B9U*UW6RXbr(ufJ#!)^66hmx;n0=pW6k-SJu_F z1B%!6?ECSjp0#z)&aT{aqwr4#wMJ)Dh&GZ(Xd`UD4{A+oLu*BMhdWt}F#HGH%Dq&s zYIKl?W+9zICjqmO!V2vLWpcOUtK2Wicczv`5UC(l5N1K3f6hOF$u40~$NKm6*x&Rk zy`^LS&mJcfX{e0hjsYVg#6n;PB?-NP+uI4^=@pDV=@qG@k2AL9B37hV=3$;hFG4qs4?Cu_V?u5#Ypt9X#!VQ(}lPCE|3lwXJx@}%x+G)om zRT!VFBH{B4iA95BOxbC>$Nd}{+ZP(^YG`b%FS4^2P{(di2e;)w+74>D%yNo^bKQ{Z z0G7QFk^<_RWr?XxXe!P@Hl+iT69mbo1)G&B)$lEhf0z-*Z)2Tvnr0#{gw~BH0q)v3WX(??Mahv)zj6N1RI>+qOUF5J&|6`sN7jG>1Z^!yQ5NWesIj#EKyU5* ztF3at{NZ$hh09$}fEzl~8Id z)#iDT=sdZe$gIXt6vUjoRO;J>2)c6E?3WngxrXAxP9)URK{eeflX{$?8th;+zmNMWQ5ns3y<*HoR`V?h$k-@BY5r8peDh3yB7kc0a|D&bIk1ta5iu}iBAeAYGBTY1vgq3+#ZdSFBhL6Zy;cm^rs z={<^scaLBqusf)QTqbxc#7ME&CZ~(*4oeuj^=HSEI1RHw4l?3H5-)l#4Pqd%W> z2F&w@8~SyXz(enNc#yDIGjuuN+b;ji>j3SNJOTe$bMSCp5rxAN5QysyD0~=0-JW!( zW?7<20Xpbvcz+rvTUw@Di~yKHWP}alD{|HF^1!^mhpIyG>me)v6kIwr9AjxHaGelz zQh@l!&4-8zD_Rn=`jj5botcTjQ_1WnuIjRCoNg9*>?p(q0nrXzLk}{rVA5`zW*KsS zLqT2$rDL99t;w*~m7z3BHJmi7!2JMl$Z{3ueSXX{z=YQ^$D^W~NUcD)>bm1nT0e>Z z>~_RA5Tzj zQhq#9!?*iT%lP8&anm1i*JBFOo&3rGH%nf;^q1JjTQE#Wut(t!VRyU$T?@MTL1IHq zM&aiP`N6`}+OHj{`FH2)(?0-&L9S1fwD?hBzDMe%YCb^xM~;83@^_V*zv(Gtvi(Z^ z`*p4UV~FOFGT*CDU;KEXehZtLwB7aQ-A+1tRnDL3_2%o6hkng#!g}hMeJ^dY!{Ex2 zSK)!743L}hfGzamw&g(QG;h8R2bxu|!fyWPhJ8u9lSz%kX*E9(mBb}V?xMpMS(JQU zrtw4+ogYQ49%rOla^HVX2w(Afnp1yOB8Sl?rQy)OaeHr+uuz1C96|C5e$9m+b|RXiC+hcqr~VgX;}p>V literal 0 HcmV?d00001 diff --git a/example-prompts/pokey_jazz.mid b/example-prompts/pokey_jazz.mid new file mode 100644 index 0000000000000000000000000000000000000000..73c4f216da2e7e75159b1e8a8e81b1ae84d219f3 GIT binary patch literal 14240 zcmZvj-EJdCvghMy8?%dq&@c`Pjlo>7Wm}^zQKF4pZ@Rl^ymL7lK(IHkK6zBpa1uo{2%!r>#cR2BiGq4ZB}-+ta+I&rvENXBWJcmwp?OGGLe@e|2~wY?MkvK$%#0oHiLHHn!$P|{X(a( z0#k+}0})&3__|#e(-&uDG0bz*&&^)e$;B|)w|$Yluxj&7JE+`1?zeJt>e5rOTapa( zbSTaEylB?)p6JY8!nDOfd*RGg(Y%P$&qXdouEIK$|LZ{@<96Dva}S0&pS#^@f;+OCz-o7b6~md&6m!e*zBI~|IY zBCo@irrpfVpfKmwObWA?x=k?yF%wDlBsrJGcnb}7uc5%S#*fJT5j9s(J zAMqJDE+2&l+f8Q!+u4=QsqGB-ub82mA=&OxZuWCi77|3xl~lZ0%L2?Z?Y=bc2F^?} z{kVCdqD|9>f89oGnVgi|Qg%gBkx7`qh!c+D^nA^#>P5GgtL`vvd%0v^j4ft8Nl}Ln zxys)mhFt1F$pg7lSD%VcKhK8weTo90GdqQOmC3mZMaI5a)y3bnXV$RrWiB>1!@PqE zaM>M^S7D!dyOEh*$Mz)I;Jq&;Ig=#QDak?3sx}jIz!wKPUyr+(w61g`m$B$USR!e6 zEE}7P$~Z|!;^V#}+q^*8o_0#wA~&3&m&v%`tz_eKH`ZNEFX&O)vSJZdTuj@2YdFQI zob5}8BaEc!7j9O_Y%4QaYW7nz$atibM;aCwL1&A9$o+|zfzR{+E9lzd15!q$SSpXB z#(~jgBAa})lrrC@pQq>IQd#_YD^ItJY)jI@$+$A-PQEOq=aX$|)^yFa_S3BCMhSA# zp2(LXpDtJ0;C9Xwg^UWGti!zuNjUe%Q>@0D%aT(jL+3Etm%Z##nind~AKUdT0S6*t z7}Ua`PwTRHo|nd!()Y9Q>L^FHm${=t2Bq0AvONKUFwq1c*mmZ?%c-9qh&dIvjWDcX z`q0;)$Cj?xygT9?T#CFAdAmH6mkw6=ZP1+=NUIpLX=TsG zGf$kkDdet9!5DyZ62{HjJCU2P0D8>P`-SXdER~s-a#P+PDS`mO2H8%<+(@&!n>S-c z8!O{dZc8-*+p-`=bQ>0GKDJ|Tt{=4y#XagF*`~{gaoN0>6fCP~&uy@NK1HT!yPKzz z98o6Ml9ojV241;@3tdUGa`QgBk=aRZip&6BgPu%^taKeNfpDJ$eN~uJW>+$scy%zd z7L+auGq8(E`nlg#1XHoGX*4U-reY6~OJycL-IbJ04|R9Flxrus2XFyvCPW#Tb}aLr z?ToAB#wA~s8N$I>C;VH|8*vkvk4wa?1~+d5M*aQ2a6x9Lsa`W_FM)G0Iof_HCIv_l z36fEu+Og=x(n}3-E`UzniEzP7GaHoL!2f_tpMcz%nw4SA3rQm=ePVMth@doVopYno z1tdy9qKHl|tdq1OsV{;vz=TIU=OmXf_>h1sPL>7=1Q|=5turAcn!+rWbt#?8j#DCSSoi8IHzVMz-$ixTtX> z>1<|9L{;idj(=o1z8bpqsz66l2U~3V;q=YQdM^X*1{ROXu~HsQy*}Px4+>`G7~Cu% zBTGdsiUcZvcG+Id69EK(O>rbalx2{9mA zX4;}UsMfJhkK^?|YXc}XFS3Bj+;zIV-!@XH4fmN=9iB>`^RxDrm3uB>%oNw<6OP8C zMu19MG#WHRm+vZ>ROy{06P}HG0UA@l2Zr8gz723R7-qi%?@HH z&2bgB94fB>K!=4z0MVYJ81CDRyxtHLZt*Mr=^)1-c8Ej-t zA2NVdzHmj;zjPi6@xeyW$@r+Iz2v+DV5@q#J19+FI5^4KRt_TaX3Ywkh8*tOQ`f8w 
zRn{duPYMr%ROcE=Hk1@lF)eO6m-(TcA4%sBmCo?3=*#`c8~px(Q;L|Y4a^RpovfA> zD1Fo4!n84dQTMrcP8^VS9U7%hGpwsjzm#ERw!G|DNmuNNSjJ)xFLP)KY7hqF>f?c! zm9*=H^!2n{3bKpY^APy|hqFCr_8r=zUk`jID4QBDd^Wffo`(gRZo3UBE#;t6z2*CP z`+IZ75S`@ZJeM>Qfj)v6c!3TmT4=zck8vN-8+~0&FVrbF-FQZ}1%{%UvSgG{HH7il zU%t71lt$*D2!ufA4*v<6BAAI%<>>lI?^^r<-0iG{->3{GP5~Kz9(@58JgETD8GG=q z+?H7ue=dxLx+2m@VBm{sFWW6}EIJVH&Or6+ol73HGKW5kQicEGY1cUS8VBvcUie-O zT5r$kmmZc;>e8{##UxjhOe{v@R1?QSXlVOr7%mGVTTOfQpy&V(0VwUTOrYINd=`NL zj}E~cS4u^Q3C|)*s*jhUZB`;{esa87+v~|jPTXPe?`N3ULGbUe1TLg8h;*>=b!)Lm zeG6{x%5+qEB8?rpbn(0djw^x$4)JfzSlxRcZ_YfZIG%kYg)r!oJ#9VLJBP3XhXs^3 zt0O-jiBlI}Bmn%r?ygL0RJg)L7q8FRX|0}d`J%ukkFy0fsMqfx4&D-%8Js<1{x;)Z z1cvpK)8|V*u8D1utXTXAZh-oa#^0>rE7mRcs9Z7o0|CHtiEs%D4)Ut$Vvk*KK%=7y z2ur)fz8Z~zg6fWEr zwk*&!hfBP2sVcLk6)L%@Gdwb^RQ#CaLNDyoHj%OMb%m&TB=yS&L)QQMi*4 zmNT~VNjXo8^!I4;WJR#WurZo7ah`*fg}qWDs(W`YYVmajd-IZF4;Vwgm=EBw%F`DU zOqyiC;)5nAG&Txm`vTh8yyy~7SmUc^43usba$A}1BelYCz>{bMrX`fPl^OWWZKhYX zT56xSPa%(k6~SMpx?agD1;G@`IcDflJ-I8pIsWl~$n z^28^*ko=Ob(<_T5!f5o8Ya_2>o3{#}#HtwXB{>@;g!BerLbP`NemrqXxD6~_NCKzs zYg89j!0Iz+4&CC%ZI8414I<#ZyzMr&j=Xw2>B)Q5Xc%_RUCBq4bN1j$2JeFRvpA$a zyjEs|QjLElElg0Xj1Xd2puUt>`3wo?pYJoYe+JYS+9@{XUZzm zd-^0K8T7)xoO$zgB^MPan0_71!nX$nX7Trb#-I-%Yi{xDrOQUXepMI@c*lv?32Gl) z3^_24^D4J*`g)a>7GZUtY~3%fvjJ;&(hC8-bsP}Kj?wcfYK&P6&j{B>DE&b7XZy}% zrGY$LJA9dXk87Ynoxdwe0t6OFZ5auFDGbG(?uQtESmO4Gy4;XFg}6bWN=?7rEOEh$ z4=}z>r4q8w1E(mm2WK=H0g^v*eS!q+ppTsu%?c)VRG>)mX_=%&)`N;x=C#+@{!@*n z$UvCakpe!)!$78pWV~cRJww?CIYKO0T^V;27u$n#JTzQ*L-a{95RlmIywX(Qx%jP$ z`|FL4SfteF1xc3#ct|j!)M{KcYwMbwTasDs0i1#H=qNxES*>rs^zxNY*TsA(?kny| zhOF$ITjd9!3G(+?ke-Xz!{hfkEZ)CTiJp5)kf0Asd|fL0Qc}xZA{SvA`!H@&#rGpw zGxmf+j;Mi&Jnw{6&q>f$a#fjeu#Z{07q;jVcI`L}=MiS(qIhc-#Zhpgq}^3#07@D> z81WSz9&HF`=929HsN|&jA10Y zh2~BxbE`4Rd@LG}+>*F**_+y2;zcSTR%Hky;*RjS2>!frB)W)fS}BFQ^-8mwpHK@p zSr84+&ol8rEE*rEM7vDoIfL95UdR8^BFQ^a47))}cOeEs^D1<*6v6@~l4J{16vBE{ zD;k6#J*e6*?RT}guA3L^h%&ql8>KDw&+J2Q2BnTiy&}n)w^Qf6^JKW-eBF-Q<-d;! 
zD<2N}U04DtS0z`q65wfx%fPW{MCK)V8^!Uloj`cW!`WtyjX6EY(o5&zcuMRLOg>xP~7D(w%)@H!E+e#f-Gks0|6LB&sVh?kL}Rm7u!-%OhaJ5dll^3!6WyvNM?#%HTWK3efjkArd1|I+VFBp+{;Q zw#_m6mvBv%e8k+eeHp&=EwJ4E!Z-arO}me9PW$o4CGA32hKyVHp59Cd31jU zV@9#UYs3VDD)_w?Y$B+`y2$^c*#jMz5o$ve@hStEgLWD+D$e3L6p~R-0`X@ohyx;@ zJb+Y$FCdQvfQB&6i6B^vd1cHm=_Iy9`3YV@9Pk-NP~c>cHbRTAjOa*MM@e^DINPcN z*_CaWW0Ee?Io2b2^xMJZI2oQTB3KwC1pXoWgZtIp7TYZ@g^(XJIxfi!G+o{@0>5eF zD{`WpqFG}cN8sIJg21#pV*XkLRIEJLCc4q%e{6yC<4Rn8|JNXIPjDRg)=FDAY91tM zJWCw=u;hB!%{{J~YKXN*_FdQDeU!c-Gtwz|Y+e3c`$dslXpY6jw|_j%7bq=(jx%U^ z2#d8E6`w-y`Ml-u$P=laN5{v@2K?D7RaIH|HRzf*UB%e~RZ~|83`EI8Gl49|6Pv&= znc&)sT4`w)Gz;rwo)9GgZ22-%7PB(;9H(#LXm3$Y_>$?!+ z0#SO{ZT65&mqmr0w@?9wU)Km#zNs}dXCSG)Z{qN0VR$}c8&2ZYJ6-_4?>QcY%b-+U zBC2`eYwHQA)cJ584#PtOa7@Zt#9h?*u=z1`{jieohd;zGDB3H3om+cvF&!)9Dp1@L zuQ^aifQs*@7F-x0tnCn|vLqQ#hqK3+YH;r7B||N2NKyfIeo&K~;zY>$p3Lfr+5tou zLubzYz*7}!Rw%oar6}T&%_2G%yCW(5bm)DEj9{;G$SI*Qe(s`5;}Zl*l#T5@;coyp zAWIn)`4xmhXpE6UE_f|{%DSBvP4_LJ@=n8)<^>*uzq*L#9RjN5UEQo^!Rq*iPjOv* z!5!0fr&KI{>oqXD$B{H^-pyi=`ZSh-uI*9r%FWc`fB2Yu}Pb^m6mwq)w(@zOOnr!7r&MJ#RV1_hIqxoxG(y zLb49cPlH>;ZL=C*Lv2}#0w{D@HPYx|93ULx&0kuIL~kjAU~wpBbg56rU}D-LD@3cL z0@_f(H-a8NYrm@S?9e{q7_Zy~;JW*rHxDd9w{BMS8E%XE+IjSH?C-N+fc!gd6q7Wj zoz&|1CX&6uuW_UsI2vq*1;+28Dt44+pA7K3%9cgu(W&Vq;hPq_l{PWb`SqXE?zu+xC5c~>|p72!Lo7dIlLN($P!OPCE`Kk!<}~1Hy<|<>k$x#{j*T)5d^3pD#*-I!M46KS2kK~>b!(}$@WBb;-wC8p;C$O{ure-ZO zP^`O}VE!?m`K+CiL&F*)Y@@Z)uurEYHsiJ4RGJ|q(RC%AH=2EWSl7TuYxX2wA&45h zM1qLJ4d}pS*^O=v)Zn&CZ;`IX{Q;onw}c;c@>=1f#(d~lA*o4nhhid=E!9a)s_;fE zaSzW-UQ>R9-W6mbu!qi-(?Bq0A_ixwWX5~+sOiE04hq174giX|3VKRO4m9h%e2fsH z&zQmEW1&HsuL{{AQ>R{a3GleIHQdHi1}Z-1n*cqM)Qs{@^BdGzN$Z5DyHEQvNHs>YqSqgH#iAJw-`^(T8(bY z#WdvK6aXP`xw{9VvVa^>*rr9}U;$f&zy$YPidmwpT7TPw%~j0`Dh;M}3kb0#*?+)c zUUCv@kDC7{dR|PCKlETT_g(j!EwL{p0Hz|Xq#cIATV(pxFLt_uYYgo|{q{tY>Z<#i zlO_a$Vvw-=2ZIbE-$Jp;xDkA6A+hG1!B>JNTw$?If@hC(0xlJQ{9+r zkYFsJ3AMcK_DT&i*#BOD0Jf%y$a)}+mVi%qN=o{4 zLJdOeO^B*6{lzp^l^*^?PK{QWLpld45)@z$R`Q2B0J4NioqCewO&`pUW)8|muqi{* z4n-!J7p@I(^5jH(=LLiy3lS`FD4EK|w$inb)DM|FqHvNG$Sy-9$P!E;E!QV|IOc|E zBQ;$t7>FWz4^{*8u?MfmYZ6#c8Al7zP`;dLcS>j;Xa)>pHtj@fv_L$;WdzD-a$4$l z180fpPL+~UVT_6M<`wsvOaZMNL~T`(>z5P zS#eG+AzXZ|J@-(gX7!aN`DR9y-W(+uzYmE??)GHP(0xk<-38;NDJydh(oKvgm&M}G zBqR@tSxDzEJ{-WG9N5R10)l+jx{79^8Uz78m=WGJgsChM==DGu@j%gnN`1&hu>uP5 zM?k?dk0lEy(Ehp4zz#MZRhTslk+qs`QGzp=ZBTlO`2*gem&B0eF>@D^2xH_PzNmz6cR$( zD3~0bc|jW^JVcCn^4dYFuoP^8JPWa%E$8c+H((zQ1gBjL@j>AXF-)Bk=W;m=A~^4+ ztL+3@Xw+vTNt_znc`RmQplIGB{8Ls$qXA?Z{Z4usp|S^_7qSO<4nTkyOkug_s-AqI zoH~N@yxmhAGkLbyCiDOQL7c4R?RR-v3zfoJar%p23H_Gnqi?K>!UlkzpHpCu`r{|$ z#PAZzcT4l81c1~_EPeYg#Bb7d=G%Zf)E|+9xzT)`+y&|h=Hc`s`ux5GiFyRc!0Fxz z)OoRua8W$=fK6*DvZA1;{VzE5P=rkAC|cV^o!J3zPm*g_EU(S?`1BX!notb}9k}xg zOq)(Gq|tP-HqE$B+&{JwO&^PpSqn9pI9l_-F}ykygOCRiX*W@`R?vq4SW>dG6j`X_ z*fK}tP$(ri01pfSQc}vCbk3w^LLqbtVItavq~Z+b3IXYvZkmD3=oV1Kw4hrlZ&EVR z+|-mGJWz_7blS3PRv0qLYvn^ddQP|qBZ4&ff-oXXzwzd)(oPD}(JHI2I5k(G+H_MB zjW&G^0Yl{Tu)fTzun;pd0CGk};7Jt4Xuh$hkQK;lK9X$pPx%xK@E=Dee-+ZMc{{sDHqa`2syu%rM< zH24qj)S)FpsD(ggwKY?9lolOY1i3HkY2l$+p>KtTAgE~OQj~=uq2}O21+rrcX2Kx| zCnMW&AqhmLzpO*e0q}A(8HHja_@ubcsY5{som@(je6^Ph?1rk4)~v*)@vo$Pwxvhp711l@@QNDE^HjK&2&2vzly6GgxMhX)dC9 zEo)w&Xy6@mJ%ZihSFI1yGH^^iueg8wmR!xlp8yRij0OEh)QqEoawIYf{YkF~hemen9?9mZz zI*5oqBDRJ?9C^IZGUJ1`kj&INXHc#=9h8rueK1CR4_d@#d$R0f;!Q|+YX=1$37h5m zjO@?^DpV-7Rilc1%6&rkW(|qpN#Y>WlIxs0Dzk{f(FaZa0I9wyHIYs9gaN%FsTQuO zSkNM6#6^l%6TfWP5PNn5NAGV)yW)8ZcT zGQ_9gN68t7q8+_`M6u0B?#f%DFy9bJg(ufSCS5Y1rqF zgrpCTQTj`}6G29^3J!7F5Oys-3*|ul409xppsb8GYB|#EEk28p_^))!Wk?5Ba*}E6 zf-lQ00ImrQI-x_M)5G7CacF0F0*yyZob_G_SDz>_bUV?1%yVfS7R`ylO?;<&*oN69 
zJSY>tPp=-9$-{M&mpFO zH$dsJs`;*+-AK(IkK99Dr(i@(7Nu6Ogm{c%~s?2x2WQ%v@N1T|j{LD6bB|PFFjhv~b$25E`oL z(C?ILv=v4zp)<|1T?=FT9pkVRO~1s*44Tv>1+>O70Sg!hF(ix>hs@lVp&}J1 zq<6tdaxV+zz1nK8xn1?`WZXiC)Vw&U4QZv|Tu)Z!q(+Ht`BRIoV^ z={dy1hZu>Ba^yB(oRW};7s`l3>GvCj8Lh%Z=t5O6T7GuH;pl)+7&aTB01Yy6f2_dM z8oDf~(O*T#42*#kgcT!&eMpFQEywt2mCOt2LEZ&CS)uQVEGg+d{c%gA@nqqG^I2v5 zzRaHjsmL8EuaUoTiy?(FY_YGA*)rk6DXr#E`S}#u%C&3~$OFn)qza&NxW zc8uiHaPWs6e%M3{lBTr@qF8~7qw8GI2@QM8GW{VFfgb%|{HTc^h;RndSRs@_Be?y; zU-*3rNh!iY?dj%1;YUKuq;dHqZ`NY(=DZz{b$L1|$^=HR+hW@CBc=gEE#Sw+$xY?lfJ-16o0^x(^>le0!dsH;0kU2~2}wwX_KLK997s)^u@8rP@WP53DGpzM!U2K5 z^;*?t5&apPf+%7#E$hMx*@I(FyS^E3X^x8~jm2})YTwuLFOYs=I_V52Wt^P~>W6*5HK|53{q0U1){cGoAwbl$$oA(lX7MXIO z)S##rMg!k zY6SyI^^}T~dZ5&zQr{_6Dz&E6u2N5w`n99%G^gKb`lTK93fsRQp7fRK59=NOvHmm* z=8bwB>rSj!jU0kK4osBFSm|-QX)*qWc$Y$5 zmWG}8`9`UERw=cr)Q(ac&QQ`kRT*$xG1@iQ9*I9(R@-JrDZhWzXvT+9L7 za&)p+1_D!RdSP{Rcu)=YhgD}!sl80&>atRKsFp*u5e2y2YXu${wqOwqmXz95YF4SC zQgccjIxlpyuk1tIu(z8*Q5}KCdYJ`9s25FuqXgIvH7IjJZsdGyg9N8>vS8Y0Z@*#R zu{fUT%_LY2^|#>2RyMg@03rzbz;)AC_Dqg^*=!~Chaxpgh7<>RC@y0Kyg&+Pzx7AF zWbLTgEXLbdn`@~&(aT0}gpefQ2#ovK#2waFon|x9OW~)ah>V;MQ$o$*ZKE>q2?I#s zcYp(tu<1d%ZZefI9fkcbHxvZCs8`y%LXRvC~9DwESLH%Yms^7DkuM5M9$sP{K z!|GfTRaR@VbZGPW@FFqx;r>tywrrSK74JEU#aQM*qQ)29>?*qgK%d?=pyAf5G7a*` z1tA%(I=ke4U#b0JJv+!8XA;dycs6gu{%UBvaa>5gKu@-v<& zs++>qsU`{QH-dbxsO}2Y3p%!N4zhGauHk+{XIr z2Q52IfLp0BX?-!OyTX(MvR2)p=<{}BZw#{f>nEY=mztvPhKUdPf#Zs&1$(@-kUS#X z{(@4@iPg3DA|#uH>s%6P*VaKRMeB8|SDuOcqJAXGm!Wm>Nj zWc+-Piq>)#b+iidvEDSg2tWVSO++dHZ6lOiQ&e-U`h%b=J5#44p|EqYMXl(p<<&$8 z+76`VamJ?`GP1&= zv*t;n=cI)6T5Dnqes~eM=zRRV(4fH zfCQS7qH8SdZRbvr)URg3YC*D)RcnqH)$4YmJK3a#@F~Lv`G;MR5orPN`c{MQCFzcJ zf(*v>qU0{FcQ|QK3TEx3?tVAaPf|@Zu|UQh8_24w`<-K19l31MVqn?vYt*d7b@yAy zH8`jic7CW{7m+AT7D>E`34Fu!or4;Va@B$khGS<#H#=Ty$A1%`dZm;vM0Y^cv%vIW zUP@BYDs5--B?avHGW76nPQ*Dv21@!e*`1>s_# zSP?Lk@xyXj4f>>-T=MaHzD0|LP|~CYkE7_CH-O9v_~c0=IFM&M-6RwOIa`gZo7_-2 znms)z>)BzZxjS%VQ<^o=L)^y#H_mUJ0ueDqTLERTo+S#hOJ)#l4&N$*)Ljg!n>ej@ zB?C;p!5BJBv7ouK+}&f}IRejxpF1yM%IC=P# z{fojldD#(p5tHB%PDEZrrX|+|P7w;|apQ;`G>=w+5bi}2J)%Eccm3OM^~K&`-CNO< zY$7Q-b)7gKk(PT`l4EJZVb)??zgf=gyRFim^@etYB+Sf3LATWX1J?r()769kjfw1c z!ZHa_l0q3TXH};hRtu=g*Ms4es^E*X!xQQJe^mvvKP6YHQb@^2F5T6*k=EsMLI_WJ!ABTRo-b)Y+(9F_rM)8pa~B|!Y*9z-9A5lfz;3!Yxy-YUIl zQSV4_Ez>XRS7o8+TGi`WW2aW43lKuHxLs)P@@U*%3dc*bHvBx0MxR>a*2ihXIB3u{ z@A`GVLD$y@Re(6)1jqKEro;cLq#GnD2 zq45|I5T|1k8YS*#l+~^@5z~!Gb+f>&XRs=(g#q-M@Q8MZTZI|7(7A~K2@hH6ee(zj zHHA|r4FOO%Qp2tb=TwZixe5ECv`64^)Nc!$xcW& zlh5UTQQdJP5#zl?;6|m0WZ@)Bc3jBL5R=Ii27Xy-@SK8^xLi<-QzG1bVJc-#dweOs zayw5-(w^ZJD_+u)kL1^(J)1X)XMqO;V;&CxcBl1~?v-xvt!A3laQqihVXk^p&tm1$ zF|L|{7j%;tqBI$52I#TKf7f;Zv& zggLUoJ)0jTcKR^2xA*znl@TSeEB(ygmV9LPprVXXm*Zh z7zpX83}U0x5R7nF{!ko@lp2-QP1(uKfvypv^!$-#+S)IxJ3xJtBcZNal-SGrtqb`> zTAjOH_1BSf1imkqQCnznAk45U!u!llfy|9Q81&MMw$fgc1doK0MxYBr_vF}Or?)eE z<0*o?Q`p-FneE3)eV^Ku=cV>-(Q3p4lUk3tE>3<>5(OdHhsU&a^ew55|RcW!rLaC{oVXTSrAT4>Hm>S>;4>{AY|3;yz4*NvvBB>I$uiWMv-ng7`dxL&* zBncqe@5$hECE=|r10nI7MLYuS$Lg2>ILW;owN3St+fpq+m;tK630wo4?vC0rTE28O zfkv$Tm0%=<$gaFlMzTDt&XKyqLhr@Q(7^zC_vQwk12fN5aG;Fj6d9Hh*VF?ibi>c_fAqa;&{s0Xa^9Z*>P zEMzIGBl)0|7ft4Vt7lyf#!I;5v{0^ix3#+aeRe7pKPv6*LlrD@L(Cg}n8{PlAd-fK z8tfU)6#4TSiwW$pACKt?@ECW$_Z~}kciPD2geKx@KaUQK`mOpnY1OazQc&fpibcak z86zk!jd$XX13upzFws;C@)Gz4wPCK}Zo6mHlQ4MHn)Eq1=YbToIGZVWn+cIVe3mzX zn2A|(E4$s?{K)k&_pV*_1Ht>{)b=@@`}wTLLoz}+AkBW5q?zZz#A_`Gctyb=f=6oQ z*sB?`-eIs<+dbgHWmw;$Uy(5~wa+*PxYZ}F)68OAN;;#}bN#a$ zzBgjhM<0LH>Y43AcbIb=qx_ERgor6FXKD4Ab0)1`GHLaglP9A2<^d0*H0p6cipVM9 zB>!ihC$`=-$n2!g%Ms=Na%{K|N0MV+l^2Y(E+E6duX{TJ^BHv4Ice+irX`nor=j)9 
z4^lfdDr}x5`UNJQAsB11Ip&r2-0M0=s3dJYp|yM5^jzq*UZyd}F*t=2}c zuXsCW>^oP9c4aoakOiUbuSBI&SE<$>Pt9R{+>Y(af=WuIyQ+GfhcsijqpG7!=@%;b zF*VO)&8Ct=-E-yu#aL70zw9d9p=NLdr9MV_q)~L!27LKR0LAucm2q@{k?0g*my^&&$ zh>8AC6o-ycq%*8Kbcr+E_&@O`;GWMFNOC z$H=jHSg(V|`5>!r?UkS29+pny5)mM7L^Q}j_GSJD7Y_DF`~GzA%866j%DfpB(r1&R z0S_UuUJO)ej97Dkl68qr-0jm4R_7R_$ND*s>YajT)L!_@^pBZ_Z@^TB&KY6d3BGy` X`M$+V?Z{cnL|IvmZO^5DK2rY!Xx4^!U-CLa1R4QGN3ihC5R!#q1Js%PDr)(wz0?7YG)Umt9vv+2YNG zQtAu-`^|6uXSR6vPf9&_r~W+uyTAX>HTB=>PnYdpxBfi*81A(;4~wHbyt7K#{3jp( zgf~Y@@$p=Kb;D#&sXex(wGV0SO=p+Gk21b+5SVMLVgpfzA3kO1zrZ&C)^Htj*4}tD2RkWpgrAYFJ*0 zEoHM}%v+93x(cP9D>Y$Rsi{&2NNKJ|K-=QC9wsb-Z!^7f!m6Wa{fJgJ+$E5d~B z2l-ED@~0E_dwKfO4d=Vv&YeP|RPXMFZ(H2G6IIygws6myW@+s6VbKhFg->2#i9NeX z>{1T*Wp>zYR*!tRH*(iMY<17L4QHTwtdv*kFm7g}>kRU2bZc}kZ`O`%Zu1Kc-WLxT zEukqgB*o$_rM5a@BEP^x-VN&`AJ+<1R$;hVVxu^XduPWsK}KnleMeE zk7sQ`B#tUAj&V)j6wUill?_T=mikEKNJyAPXX8k`xUSS(saH|sg}5GUZ}&PeTsA_!h!^XIWn4N;Zj^kI+v2oBSpyiN zclYJYB1@NLeXmdlM%nUKY?v2nVDL#MSIu?W$j#V`2Bal3R)!y&rPcsAd^#ioHtcY3 zQy57Tub$OmQWW*`N=tmLdTO(DEta1}lqC1w1?g<~@rs<$1E2Lx)1RK|p>Ng>Y_sGl zH9Kuq$+CXgEP=0~y%43YS=%&u(#qmp+N7MpeRpY7#RxQg>{XMh*(NVV0owM8uyas^ zg_zz6Lva^~>vf37?XSgd-f?+Y#}Dj|BvoV!*KEJ=k{NsX#mkx>O1<#4c(`bm7D_E_ zxOZ5n6Z=+vxU&M7u6DR5h=ix$vctAEn)Ukk4#qxA;jP6?$&P z#9l|uj6Ue7ajv%>2o+!@Zc5-nj05l#JU|v6r(?E3?CVw>Y$GNMEsAb+Va_tTa zMqLd*UjNv})o_m27V)3g|Co2IAjiVPdXRE1rczRclH?>b)ExJHi3__!SR4> zQigl@l*Mq64Rq$0o_5{~m*KZ(Nawoe|Cj73x@*P6Jn*eLASW!7jyKFXu(I+80WO&U|-N%Tt`ui9RIHEa#>uK08XxO(G$KW&na|J5f1Q4vF@Qus-*CYO{%pGUSiM6Glc4juKpWbyVvi zbu?yUx+wZnBn29MfrjWoW@@>adX6?5CP`K7qbyy5Hl-Q>4VV8dX!K)Rn{W?k^kPhx z)VB?wc~D%_V?mS-Ye>49tBKPn?M%6b(2 z^f%Cv(MP(#z;F3Z;cxjvM1;rXbjH$Wur)c^Mc#vPRhZf1-Eac7iC-j`@>m}9xy=>3J z^4x@B{Q7BgL&1eu7M6SJYOAX#{ZeL)e%Jw69ouaNQ>6*bP@V0WW*c__GHSFnnJXha zit8U~lJ;M!JoX}~b&4x|6`lg7i&#Fu+E#(^|Jkx5Tk?)qVTX1W5{qSC<+m*X!AxXI zz%e={L!z}4r5+|ST@&947lF-#8f=F3$4*P>x0=#_UPRP>q;JWfj#h_A?vqbT98v+7yS14Xt-<)hzLW9XMT;wj7IR|G$hBqxI z%%HTmWxGAJ%2(RZbVvv5Oiwu4>C$JT`g)4smiR48%y;P3z$;%%VdoSM#_LHlrV#DK zY@L>de2jiPaR9xL=&%I>X?Ksn^7T|(FC~dHlJ?I-WwGgXE8xO)f+ic*bJtwTKV5cQ z>k$apw0;(=5fL7HMZ=Cc04r%n7#yU^T%&H4U^Hv<(^#2lnBk1|19x^@XA`fVIW=-& z{gHc{{6=KaLQFA=8{w{-f)?q9p7i{Q4fk0zMhnr@;lrNMhfeR>4mQar2LUSYEe0iE zviq)8fR?X+Xlc?oqF1PyQ)8&KP+JrXD?J|upd^(ox-ncrooP;@OA3KJf)Mb`Y0Sem z?mw?TeJlQ9f)GBGf0m3kWz9_vlAmg+Vsd1>RmwP2$9{hRuh15+qWcI&n)_yDYA*!$ zbV>Y3-@h-sv?oy|q05OITN1!<_g%5b+ou+8Ho8xXEuyYjoimnB`6@aswIk^)!!muu zpjMMpssAh%f2EtCe>W|*tvH)eCpdc0ZN@!FwUX&~bwsxPNk`Au4p;(t)8j@Xp^+HPg&wZxanaRm)MuWWq6kSPe1QruQ& zTo#PXZPWjKtzT8@cMMJxz*TD9sw~`yQj#jwtKzJH&qbO|ho1aOGe8*1?wHJkeCiWp8l}DYP-cU9DGgzqAi&I&gRcf_leAq+5KB)7Ee2KeT3Seccu&kCvO7x16v1Thw3)S`nTC zDbA_n91LS}1Q=WZVF1Vr*CX0$n}Gd@gG1+vWpt*X2_abNTUSJobuIe zoWf3f=)(uZk{5>VdUxcT_cM1kcX0|XoA=bIhd$%t++$YpB%~5{6ABV=lLs1F(4Uvu z1*q`14cGTScseo$UnuZ|VknFBUgxeE%U9Yod?@gh1t!r~u31Y9u~kD#jy26m`gqgE zjlq6jK+FS?2=e+M+jt=IrWwZt4?Z8NW{oSKn6R9d2S-J78JFqsIG6VO2W8^g^^cB+ zv16;x5dyh@yL#Bz`OHtkj&pCen=Sb0$O4-2FLb7qhV(z$6quh6#>3 zF7%|(!)~3Di3PTxPb$$;`>Y7VxeLQTQi)G_gW-*baWR3OiJVA2q4TK(>X^s-8*eht zV=xY)A`bXtR0j~5e4biClk8-QrH}X0>Ln=sk_DG>xUSW!vvB{oR^Oe54}+pvLDhGd z>(zI4%sD5ogrJ0m3|FLJhjkhE=$KbtdOdStd0JSBQ7w3vt5Zx92k=h(hVgkJwopqg z(`)?a^%sY{ak~5}deHM06XpGwaN!}BUof+Mo9Kd~1_>d$4~kzn!S#f3(kSWP1p)Gs z0^3PLC=SQ0c!_2p@pt;T%zfT$1Kv@QIHS0bmD2MV6jg?uctCpliMK8@(S&}nwtFR~ z_?zxvUJ8GnB$Hx)$%AwHbq+N$9xh9~YsnVgLF0-b8E$4O$ZPV58=~G>6h#KU9XXNc z$3Wrlggy>Vsf)^+MB`!PZEFi@0TO+G`iEoclgmc*^5yl0A?>?o9E<6{0-HRLTOP=- zzYyEo1cC4@uFrnJY3+ExqlvIQlcdkLMmTOwSn;aAb_LfQ%3x8M>#8;GHw7 za434vnF^}N%Jbt1Zx@X7=K8nHY*Su=i(GH!^1les0gO?TS}96m$_;Ay$4~##9@u|f 
z=hgSCj)tH&JK3WL+H}<8j_y7s`Iu<>^^^ zPwM;QNQ2$GeExUJZ7B`GLj-cY(b2ppLLpn3Xi5An`k^=)O-!XLdMLZ4E0KL28QL(u h*VUASjm}$k2Xf@Uc&|}xP5aov@}K%Fsy{zg{|kUh)baoT literal 0 HcmV?d00001 From 23343cc6730c00a84cf9d50d53c2906f720f0e06 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 9 Jun 2025 17:48:32 +0000 Subject: [PATCH 65/72] update README --- README.md | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 4475208..0a6fb9c 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ pip install -e ".[all]" Download model weights from the official HuggingFace page for our pretrained model, as well as checkpoints finetuned for piano-continuation and generating MIDI-embeddings: -- `aria-medium-base` ([huggingface](https://example.com/), [direct-download](https://example.com/)) +- `aria-medium-base` ([huggingface](https://huggingface.co/loubb/aria-medium-base), [direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model.safetensors?download=true)) - `aria-medium-gen` ([huggingface](https://example.com/), [direct-download](https://example.com/)) - `aria-medium-embedding` ([huggingface](https://example.com/), [direct-download](https://example.com/)) @@ -41,7 +41,13 @@ aria generate \ --save_dir ``` -Since the model has not been post-trained with instruction tuning or RLHF (similar to pre-instruct GPT models), it is very sensitive to input quality and performs best when prompted with well-played music. To get sample MIDI files, see the `example-prompts/` directory or explore the Aria-MIDI dataset. For a full list of sampling options: `aria generate -h`. If you wish to do inference on the CPU, please see the platform-agnostic implementation on our HuggingFace page [link]. +Since the model has not been post-trained with instruction tuning or RLHF (similar to pre-instruct GPT models), it is very sensitive to input quality and performs best when prompted with well-played music. To get prompt MIDI files, see the `example-prompts/` directory, explore the [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) dataset, or transcribe your own files using [piano-transcription model](https://github.com/EleutherAI/aria-amt). For a full list of sampling options: `aria generate -h`. If you wish to do inference on the CPU, please see the platform-agnostic implementation on our HuggingFace page [link]. + +### Intended Use and Limitations + +Aria performs best when **continuing existing piano MIDI files** rather than generating music from scratch. While multi-track tokenization and generation are supported, the model was trained primarily on **single-track expressive piano performances**, and we recommend using single-track inputs for optimal results. + +Due to the high representation of popular classical works (e.g., Chopin) in the training data and the difficulty of complete deduplication, the model may **memorize or closely reproduce** such pieces. For more original outputs, we suggest prompting Aria with **lesser-known works or your own compositions**. ### Inference (MIDI embeddings) @@ -70,7 +76,7 @@ embedding = get_global_embedding_from_midi( Our embedding model was trained to capture composition-level and performance-level attributes, and therefore might not be appropriate for every use case. -## Real-time Demo +## Real-time demo In `demo/` we provide CUDA (Linux/PyTorch) and MLX (Apple Silicon) implementations of the real-time interactive piano-continuation demo showcased in our release blog post. 
For the demo we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface.
 
@@ -86,7 +92,7 @@ MIDI_PATH="example-prompts/pokey_jazz.mid"
 
 python demo/demo_mlx.py \
     --checkpoint <checkpoint-path> \
     --midi_path ${MIDI_PATH} \
-    --midi_through <through-port> \
+    --midi_through <through-port> \
     --midi_out <out-port> \
     --save_path <save-path> \
     --temp 0.98 \
@@ -95,18 +101,18 @@ python demo/demo_mlx.py \
 
 ## Evaluation
 
-We provide the specific files/splits we used for Aria-MIDI-derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://example.com/)). Class labels are provided in `metadata.json` with the schema:
+We provide the specific files/splits we used for Aria-MIDI-derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/eval-splits.tar.gz?download=true)). Class labels are provided in `metadata.json` with the schema:
 
 ```json
 {
     "<metadata-category>": {
         "<split>": {
            "<filename>": "<class-label>",
-            …
+            ...
        },
-        …
+        ...
    },
-    …
+    ...
 }
 ```
 
From e5b877727cc3bc2e9212f6ba7be0c15deb337e60 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 10 Jun 2025 12:48:34 +0000
Subject: [PATCH 66/72] rmv test_dataset

---
 aria/training/classifier_finetune.py | 20 --------------------
 1 file changed, 20 deletions(-)

diff --git a/aria/training/classifier_finetune.py b/aria/training/classifier_finetune.py
index ccc7712..b8d4b9b 100644
--- a/aria/training/classifier_finetune.py
+++ b/aria/training/classifier_finetune.py
@@ -691,24 +691,6 @@ def train(
         json.dump(results, f, indent=4)
 
 
-def test_dataset():
-    tokenizer = AbsTokenizer()
-    dataset = FinetuningDataset(
-        load_path="/mnt/ssd1/aria/data/class_eval/genre/classifier_finetune/test.jsonl",
-        metadata_category="genre",
-        tag_to_id=CATEGORY_TAGS["genre"],
-        max_seq_len=1024,
-        per_file=True,
-    )
-
-    for seq, pos, tag in dataset:
-        print(seq.shape)
-        print(pos.shape)
-        print(tag)
-
-        input("")
-
-
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Finetune a model for classification."
@@ -744,5 +726,3 @@ def parse_args():
         grad_acc_steps=args.grad_acc_steps,
         project_dir=args.project_dir,
     )
-
-    # test_dataset()

From 38d0ff2c42fabb521e820bf8d44511d39058dc89 Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 10 Jun 2025 12:48:52 +0000
Subject: [PATCH 67/72] update README

---
 README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 0a6fb9c..ca65eef 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # Aria
 
-This repository contains training, inference, and evaluation code for the paper *Scaling Self-Supervised Representation Learning for Symbolic Piano Performance*, as well as implementations of our real-time piano continuation demo.
+This repository contains training, inference, and evaluation code for the paper *Scaling Self-Supervised Representation Learning for Symbolic Piano Performance (ISMIR 2025)*, as well as implementations of our real-time piano continuation demo. 
 *Aria* is a pretrained autoregressive generative model for symbolic music, based on the LLaMA 3.2 (1B) architecture, which was trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. Alongside the base model, we are releasing a checkpoint finetuned to improve generative quality, as well as a checkpoint finetuned to produce general-purpose piano MIDI embeddings using a SimCSE-style contrastive training objective.
 
 📖 Read our [release blog post](https://example.com/) and [paper](https://example.com/)
 🤗 Access our models via the [HuggingFace page](https://huggingface.co/loubb/aria-medium-base)
@@ -21,8 +21,8 @@ pip install -e ".[all]"
 
 Download model weights from the official HuggingFace page for our pretrained model, as well as checkpoints finetuned for piano-continuation and generating MIDI-embeddings:
 
 - `aria-medium-base` ([huggingface](https://huggingface.co/loubb/aria-medium-base), [direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model.safetensors?download=true))
-- `aria-medium-gen` ([huggingface](https://example.com/), [direct-download](https://example.com/))
-- `aria-medium-embedding` ([huggingface](https://example.com/), [direct-download](https://example.com/))
+- `aria-medium-gen` ([huggingface](https://huggingface.co/loubb/aria-medium-gen), [direct-download](https://huggingface.co/loubb/aria-medium-gen/resolve/main/model.safetensors?download=true))
+- `aria-medium-embedding` ([huggingface](https://huggingface.co/loubb/aria-medium-embedding), [direct-download](https://huggingface.co/loubb/aria-medium-embedding/resolve/main/model.safetensors?download=true))
 
 ### Inference (Prompt Continuation)
 
From 442cc7aa7ffae88c2f458cf5bb6af0b011f444ec Mon Sep 17 00:00:00 2001
From: Louis
Date: Mon, 16 Jun 2025 14:53:26 +0100
Subject: [PATCH 68/72] demo adjustments

---
 demo/demo_mlx.py | 144 +++++++++++++++++++++++++++++++++--------------
 1 file changed, 101 insertions(+), 43 deletions(-)

diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py
index 615dfc2..7ef8954 100644
--- a/demo/demo_mlx.py
+++ b/demo/demo_mlx.py
@@ -25,24 +25,24 @@ from aria.model import ModelConfig
 from aria.config import load_model_config
 
-# TODO: Investigate DTYPE=mx.float16 (speedup?)
DTYPE = mx.float32 MAX_SEQ_LEN = 2048 -PREFILL_CHUNK_SIZE = 32 +PREFILL_CHUNK_SIZE_L = 128 +PREFILL_CHUNK_SIZE = 16 RECALC_DUR_PREFILL_CHUNK_SIZE = 8 -RECALC_DUR_BUFFER_MS = 50 +RECALC_DUR_BUFFER_MS = 100 BEAM_WIDTH = 3 TIME_TOK_WEIGHTING = -5 -FIRST_ONSET_BUFFER_MS = -150 # Controls onset timing for first generated note +FIRST_ONSET_BUFFER_MS = -200 # Controls onset timing for first generated note # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg # HARDWARE: All messages are sent HARDWARE_LATENCY_MS early -MIN_NOTE_DELTA_MS = 50 -MIN_NOTE_LEN_MS = 100 -HARDWARE_LATENCY_MS = 100 -MAX_STREAM_DELAY_MS = 50 +MIN_NOTE_DELTA_MS = 100 +MIN_NOTE_LEN_MS = 50 +HARDWARE_LATENCY_MS = 150 # There is a bug with how this works +MAX_STREAM_DELAY_MS = 25 file_handler = logging.FileHandler("./demo.log", mode="w") file_handler.setLevel(logging.DEBUG) @@ -231,7 +231,13 @@ def compile_model(model: TransformerLM): ) model = _compile_decode_one(model=model, logger=logger) - for chunk_size in list({PREFILL_CHUNK_SIZE, RECALC_DUR_PREFILL_CHUNK_SIZE}): + for chunk_size in list( + { + # PREFILL_CHUNK_SIZE_L, + PREFILL_CHUNK_SIZE, + RECALC_DUR_PREFILL_CHUNK_SIZE, + } + ): model = _compile_prefill( model=model, logger=logger, chunk_size=chunk_size ) @@ -374,7 +380,7 @@ def decode_first_tokens( ): logger = get_logger("GENERATE") - buffer_ms = FIRST_ONSET_BUFFER_MS + buffer_ms = FIRST_ONSET_BUFFER_MS - HARDWARE_LATENCY_MS time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] eos_tok_id = tokenizer.tok_to_id[tokenizer.eos_tok] dim_tok_id = tokenizer.tok_to_id[tokenizer.dim_tok] @@ -515,6 +521,10 @@ def decode_tokens( f"Using sampling parameters: temperature={temperature}, min_p={min_p}" ) + # TODO: This seems to fix issues? 
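+    # If the control sentinel was tripped while the prompt was still being
+    # prefilled, the stale flag would make the decode loop below exit on its
+    # first iteration, so clear it before entering the loop.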
+ if control_sentinel.is_set(): + control_sentinel.clear() + while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: decode_one_start_time_s = time.time() prev_tok_id = enc_seq[0, idx - 1] @@ -550,14 +560,14 @@ def decode_tokens( ) if next_token == tokenizer.eos_tok: - logger.info("EOS token produced, exiting...") + logger.info("EOS token produced") generated_tokens_queue.put(next_token) return else: generated_tokens_queue.put(next_token) idx += 1 - logger.info("Seen exit signal") + logger.info(f"Finished generating: {idx}") generated_tokens_queue.put(None) @@ -691,7 +701,7 @@ def decode_tokens_to_midi( return elif tok is None: - logger.info(f"Seen exit signal") + logger.info(f"Seen exit signal: Sentinel") return logger.debug(f"Seen token: {tok}") @@ -754,7 +764,6 @@ def decode_tokens_to_midi( note_buffer = [] -# TODO: Refactor for readability def stream_midi( inbound_midi_msg_queue: queue.Queue, msgs: list[mido.Message], @@ -1055,11 +1064,30 @@ def chunked_prefill( num_prefill_toks = len(prefill_toks) logger.debug(f"Tokens to prefill: {len(prefill_toks)}") - if num_prefill_toks > PREFILL_CHUNK_SIZE: + if num_prefill_toks > PREFILL_CHUNK_SIZE_L: logger.debug( - f"Prefilling {PREFILL_CHUNK_SIZE} tokens from idx={prefill_idx}" + f"Prefilling {PREFILL_CHUNK_SIZE_L} tokens from idx={prefill_idx}" + ) + mx.eval( + prefill( + model, + idxs=mx.array( + [prefill_toks[:PREFILL_CHUNK_SIZE_L]], + dtype=mx.int32, + ), + input_pos=mx.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE_L, + dtype=mx.int32, + ), + ) ) + prev_context = curr_context[: prefill_idx + PREFILL_CHUNK_SIZE_L] + elif num_prefill_toks > PREFILL_CHUNK_SIZE: + logger.debug( + f"Prefilling {PREFILL_CHUNK_SIZE} tokens from idx={prefill_idx}" + ) mx.eval( prefill( model, @@ -1131,7 +1159,9 @@ def continuous_prefill( msgs.append(msg) msg_cnt += 1 - if (msg_cnt >= 10 or seen_sentinel) and len(msgs) > 30: + # TODO: This workaround is not good enough. Instead just loop is + # curr context has no notes. 
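+        # Re-tokenize the context only once enough MIDI messages have
+        # accumulated (len(msgs) > 75) to yield a usable context window.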
+ if (msg_cnt >= 10 or seen_sentinel) and len(msgs) > 75: midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) curr_context = tokenizer.encode( @@ -1193,6 +1223,7 @@ def capture_and_update_kv( return msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches +# TODO: Change MIDI-through logic, this is not the best way to mock playing def capture_midi_input( midi_input_port: str, control_sentinel: threading.Event, @@ -1209,8 +1240,6 @@ def capture_midi_input( first_on_msg_epoch_ms = None prev_msg_epoch_time_ms = first_msg_epoch_time_ms # - logger.info(f"Listening on MIDI port: '{midi_input_port}'") - logger.info(f"Using MIDI control signal: {midi_control_signal}") if midi_through_port is not None: logger.info(f"Sending through on MIDI port: '{midi_through_port}'") @@ -1221,6 +1250,15 @@ def capture_midi_input( if midi_through_port else None ) + logger.info(f"Listening on MIDI port: '{midi_input_port}'") + logger.info(f"Ready to capture MIDI events") + + if midi_control_signal is not None: + logger.info( + f"Commencing generation upon keypress or MIDI control: {midi_control_signal}" + ) + else: + logger.info(f"Commencing generation upon keypress") while not control_sentinel.is_set() or ( wait_for_close and active_pitches @@ -1285,18 +1323,18 @@ def capture_midi_input( if midi_through is not None: midi_through.send(msg) - while active_pitches: - pitch = active_pitches.pop() - msg = mido.Message( - type="note_on", - note=pitch, - velocity=0, - channel=midi_capture_channel, - time=0, - ) - received_messages_queue.put(msg) - if midi_through is not None: - midi_through.send(msg) + while active_pitches: + pitch = active_pitches.pop() + msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=midi_capture_channel, + time=0, + ) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) # Turn off pedal msg = mido.Message( @@ -1320,7 +1358,7 @@ def play_midi_file(midi_port: str, midi_path: str): midi_dict = MidiDict.from_midi(midi_path) - if MIN_NOTE_DELTA_MS: + if MIN_NOTE_DELTA_MS > 0: midi_dict.enforce_gaps(min_gap_ms=MIN_NOTE_DELTA_MS) mid = midi_dict.to_midi() @@ -1352,7 +1390,7 @@ def listen_for_keypress_control_signal( ): logger = get_logger("KEYBOARD") while True: - time.sleep(3) + time.sleep(5) _input = input() logger.info(f'Keypress seen "{_input}"') if _input == "": @@ -1363,12 +1401,12 @@ def listen_for_keypress_control_signal( return -# TODO: Get rid of logic for end sentinel -def listen_for_midi_control_signal( +def _listen( midi_input_port: str, - control_sentinel: threading.Event, + logger: logging.Logger, midi_control_signal: int | None = None, ): + logger.info("Listening...") with mido.open_input(midi_input_port) as midi_input: while True: msg = midi_input.receive(block=False) @@ -1379,7 +1417,25 @@ def listen_for_midi_control_signal( and msg.control == midi_control_signal and msg.value >= 64 ): - control_sentinel.set() + return + + +def listen_for_midi_control_signal( + midi_input_port: str, + control_sentinel: threading.Event, + midi_control_signal: int | None = None, +): + logger = get_logger("MIDI-CONTROL") + + while True: + _listen( + midi_input_port=midi_input_port, + midi_control_signal=midi_control_signal, + logger=logger, + ) + control_sentinel.set() + logger.info("Seen MIDI control signal") + time.sleep(5) def parse_args(): @@ -1443,16 +1499,15 @@ def main(args): assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in if args.midi_path: - # TODO: Don't hardcode this 
midi_input_port = "IAC Driver Bus 1" play_file_thread = threading.Thread( target=play_midi_file, args=(midi_input_port, args.midi_path), daemon=True, ) - play_file_thread.start() else: midi_input_port = args.midi_in + play_file_thread = None control_sentinel = threading.Event() generate_ending_sentinel = threading.Event() @@ -1464,7 +1519,9 @@ def main(args): midi_control_thread = threading.Thread( target=listen_for_midi_control_signal, kwargs={ - "midi_input_port": midi_input_port, + "midi_input_port": ( + args.midi_in if args.midi_in else midi_input_port + ), "control_sentinel": control_sentinel, "midi_control_signal": args.midi_control_signal, }, @@ -1473,6 +1530,9 @@ def main(args): keypress_thread.start() midi_control_thread.start() + if play_file_thread is not None: + play_file_thread.start() + msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches = ( capture_and_update_kv( model=model, @@ -1560,6 +1620,4 @@ def exit(midi_out_port: str): try: main(args) except KeyboardInterrupt: - if args.midi_out: - exit(args.midi_out) - raise + exit(args.midi_out) From 893495d504ef88a32f254779d7f3ebab7a5db800 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 16 Jun 2025 19:46:56 +0100 Subject: [PATCH 69/72] add input delay correction --- demo/calibrate.py | 360 ++++++++++++++++++++++++++++++++++++++++++++++ demo/demo_mlx.py | 202 ++++++++++++++------------ 2 files changed, 470 insertions(+), 92 deletions(-) create mode 100644 demo/calibrate.py diff --git a/demo/calibrate.py b/demo/calibrate.py new file mode 100644 index 0000000..74d0720 --- /dev/null +++ b/demo/calibrate.py @@ -0,0 +1,360 @@ +import argparse +import sys +import threading +import time + +import mido + +MIDDLE_C = 60 +C_MAJOR_CHORD = [MIDDLE_C, 64, 67, 72] # C4, E4, G4, C5 + + +def schedule_note_off(port: mido.ports.BaseOutput, note: int, delay: float): + """Schedules a non-blocking MIDI note-off message.""" + + def _off(): + port.send(mido.Message("note_off", note=note, velocity=0)) + + t = threading.Timer(delay, _off) + t.daemon = True # Allow main program to exit even if timers are pending + t.start() + + +def strike( + port: mido.ports.BaseOutput, velocity: int, offset_ms: int, notes: list[int] +): + """ + Performs a "3-2-1-GO!" countdown, sending MIDI notes with a precise offset. + The note-on message is sent `offset_ms` *before* "GO!" is printed. + """ + offset_sec = offset_ms / 1000.0 + + print("3") + time.sleep(1) + print("2") + time.sleep(1) + print("1") + + # Use monotonic time for a clock that is not affected by system time changes + go_time = time.monotonic() + 1.0 + note_on_time = go_time - offset_sec + + # Wait until the calculated time to send the MIDI message + sleep_duration = note_on_time - time.monotonic() + if sleep_duration > 0: + time.sleep(sleep_duration) + + for note in notes: + port.send(mido.Message("note_on", note=note, velocity=velocity)) + schedule_note_off(port, note, delay=0.5) + + # Wait for the exact moment to print "GO!" 
+ sleep_duration = go_time - time.monotonic() + if sleep_duration > 0: + time.sleep(sleep_duration) + + print("GO!\n") + + +def note_repetition_trial( + port: mido.ports.BaseOutput, + velocity: int, + notes: list[int], + note_length_ms: int, + gap_ms: int, +): + """Plays a note or chord repeatedly for a 3-second trial period.""" + print("Playing 3-second loop...") + + note_length_sec = note_length_ms / 1000.0 + gap_sec = gap_ms / 1000.0 + end_time = time.monotonic() + 3.0 + + while time.monotonic() < end_time: + # Ensure there's enough time for one full note cycle before the end + if time.monotonic() + note_length_sec + gap_sec > end_time: + break + + for note in notes: + port.send(mido.Message("note_on", note=note, velocity=velocity)) + + time.sleep(note_length_sec) + + for note in notes: + port.send(mido.Message("note_off", note=note, velocity=0)) + + if gap_sec > 0: + time.sleep(gap_sec) + + print("...loop finished.\n") + + +def calibrate_output_latency( + port_name: str, + velocity: int, + step_ms: int, + initial_offset_ms: int, + chord_mode: bool, +): + """Interactive loop to find the ideal hardware latency offset.""" + notes = C_MAJOR_CHORD if chord_mode else [MIDDLE_C] + offset_ms = initial_offset_ms + + try: + with mido.open_output(port_name) as port: + print(f"Opened MIDI output: {port_name}\n") + while True: + strike(port, velocity, offset_ms, notes) + print(f"Current offset: {offset_ms} ms") + cmd = ( + input("[u]p / [d]own / [r]epeat / [q]uit: ").strip().lower() + ) + + if cmd == "u": + offset_ms += step_ms + elif cmd == "d": + offset_ms = max(0, offset_ms - step_ms) + elif cmd == "q": + break + # Any other key (incl. 'r' or enter) repeats the trial + print() + except (KeyboardInterrupt, SystemExit): + print("\nInterrupted — exiting.") + except Exception as e: + print(f"\nAn error occurred: {e}") + + +def calibrate_note_timing( + port_name: str, + velocity: int, + step_ms: int, + note_length_ms: int, + initial_gap_ms: int, + chord_mode: bool, +): + """Interactive loop to find a comfortable note repetition speed.""" + notes = C_MAJOR_CHORD if chord_mode else [MIDDLE_C] + gap_ms = initial_gap_ms + + try: + with mido.open_output(port_name) as port: + print(f"Opened MIDI output: {port_name}\n") + while True: + note_repetition_trial( + port, velocity, notes, note_length_ms, gap_ms + ) + print(f"Current gap: {gap_ms} ms") + cmd = ( + input("[u]p / [d]own / [r]epeat / [q]uit: ").strip().lower() + ) + + if cmd == "u": + gap_ms += step_ms + elif cmd == "d": + gap_ms = max(0, gap_ms - step_ms) + elif cmd == "q": + break + print() + except (KeyboardInterrupt, SystemExit): + print("\nInterrupted — exiting.") + except Exception as e: + print(f"\nAn error occurred: {e}") + + +def measure_input_latency(port_name: str, timeout_sec: float = 2.0): + """ + 3-2-1-GO countdown → you strike a key on GO. + Prints the latency (note_on arrival – GO). + + • Uses the same MIDI port for input. + • Waits `timeout_sec` seconds for a note-on; repeats if none arrives. 
+ """ + try: + with mido.open_ioport(port_name) as port: + print(f"Opened MIDI I/O port: {port_name}\n") + + while True: + # ── simple countdown ──────────────────────────────────── + for n in ("3", "2", "1"): + print(n) + time.sleep(1) + + go_time = time.monotonic() + print("GO!") + + # wait for first note-on (velocity>0) or timeout + deadline = go_time + timeout_sec + latency_ms = None + while time.monotonic() < deadline: + msg = port.poll() + if msg and msg.type == "note_on" and msg.velocity > 0: + latency_ms = (time.monotonic() - go_time) * 1000.0 + break + + if latency_ms is None: + print("No key press detected – try again.\n") + else: + print(f"Input latency: {latency_ms:.1f} ms\n") + + if input("[r]etry / [q]uit: ").strip().lower() == "q": + break + print() + + except (KeyboardInterrupt, SystemExit): + print("\nInterrupted — exiting.") + except Exception as e: + print(f"\nAn error occurred: {e}") + + +def list_midi_ports() -> None: + """Prints a list of available MIDI output ports.""" + print("Available MIDI output ports:") + try: + port_names = mido.get_output_names() + if not port_names: + print(" (No ports found)") + for name in port_names: + print(f" - {name}") + except Exception as e: + print(f"Could not retrieve MIDI ports: {e}") + + +def parse_args(): + """Parses command-line arguments for the calibration tool.""" + parser = argparse.ArgumentParser( + description="A tool to calibrate Disklavier latency and note timing.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # ── global option ────────────────────────────────────────────────────── + parser.add_argument( + "--list-ports", + action="store_true", + help="List available MIDI output ports and exit.", + ) + + # ── options common to *all* modes ───────────────────────────────────── + parent = argparse.ArgumentParser(add_help=False) + parent.add_argument("--port", "-p", required=True, help="MIDI port name.") + parent.add_argument( + "--velocity", + "-v", + type=int, + default=80, + help="Note-on velocity (1-127).", + ) + parent.add_argument( + "--step", + "-s", + type=int, + default=10, + help="Adjustment step in ms (latency/timing modes).", + ) + parent.add_argument( + "--chord", + "-c", + action="store_true", + help="Use a C-major chord instead of single note.", + ) + + sub = parser.add_subparsers(dest="command", help="Available commands.") + + # ── output-latency calibration ──────────────────────────────────────── + p_lat = sub.add_parser( + "output", + parents=[parent], + help="Calibrate output latency.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p_lat.add_argument( + "--offset", + "-o", + type=int, + default=100, + help="Initial latency offset in ms.", + ) + + # ── repeated-note timing calibration ────────────────────────────────── + p_tim = sub.add_parser( + "timing", + parents=[parent], + help="Calibrate minimum gap between notes.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p_tim.add_argument( + "--note-length", + "-l", + type=int, + default=500, + help="Note duration in ms.", + ) + p_tim.add_argument( + "--gap", + "-g", + type=int, + default=100, + help="Initial gap between notes in ms.", + ) + + # ── input-latency measurement (new) ─────────────────────────────────── + p_in = sub.add_parser( + "input", + parents=[parent], + help="Measure input latency (countdown → strike).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p_in.add_argument( + "--timeout", + "-t", + type=float, + default=2.0, + help="Seconds to wait for a key press before 
retry.", + ) + + args = parser.parse_args() + + # global flag handler + if args.list_ports: + list_midi_ports() + sys.exit(0) + + if not args.command: + parser.error( + "A command is required: choose 'output', 'timing', or 'input'." + ) + + return args + + +def main(): + """Dispatches to the selected calibration or measurement routine.""" + args = parse_args() + + if args.command == "output": + calibrate_output_latency( + port_name=args.port, + velocity=args.velocity, + step_ms=args.step, + initial_offset_ms=args.offset, + chord_mode=args.chord, + ) + + elif args.command == "timing": + calibrate_note_timing( + port_name=args.port, + velocity=args.velocity, + step_ms=args.step, + note_length_ms=args.note_length, + initial_gap_ms=args.gap, + chord_mode=args.chord, + ) + + elif args.command == "input": + measure_input_latency( + port_name=args.port, + timeout_sec=args.timeout, + ) + + +if __name__ == "__main__": + main() diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index 7ef8954..5f34820 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -17,8 +17,6 @@ import mlx.nn as nn import numpy as np -from contextlib import ExitStack - from ariautils.midi import MidiDict, midi_to_dict from ariautils.tokenizer import AbsTokenizer from aria.inference.model_mlx import TransformerLM @@ -34,14 +32,24 @@ BEAM_WIDTH = 3 TIME_TOK_WEIGHTING = -5 -FIRST_ONSET_BUFFER_MS = -200 # Controls onset timing for first generated note +FIRST_ONSET_BUFFER_MS = -100 # Controls onset timing for first generated note # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg -# HARDWARE: All messages are sent HARDWARE_LATENCY_MS early -MIN_NOTE_DELTA_MS = 100 -MIN_NOTE_LEN_MS = 50 -HARDWARE_LATENCY_MS = 150 # There is a bug with how this works +# HARDWARE: All messages are sent HARDWARE_OUTPUT_LATENCY_MS early + +# C4DM Disklavier: +# MIN_NOTE_DELTA_MS = 40 +# MIN_NOTE_LEN_MS = 50 +# HARDWARE_INPUT_LATENCY_MS = 50 +# HARDWARE_OUTPUT_LATENCY_MS = 150 + +# Pianoteq +MIN_NOTE_DELTA_MS = 0 +MIN_NOTE_LEN_MS = 0 +HARDWARE_INPUT_LATENCY_MS = 0 +HARDWARE_OUTPUT_LATENCY_MS = 0 + MAX_STREAM_DELAY_MS = 25 file_handler = logging.FileHandler("./demo.log", mode="w") @@ -150,13 +158,15 @@ def _compile_prefill( compile_start_time_s = time.time() logger.info(f"Compiling prefill (chunk_size={chunk_size})") - for _start_idx in range(0, MAX_SEQ_LEN, chunk_size * 4): + for _ in range(5): mx.eval( prefill( model, idxs=mx.ones([1, chunk_size], dtype=mx.int32), input_pos=mx.arange( - _start_idx, _start_idx + chunk_size, dtype=mx.int32 + MAX_SEQ_LEN - (chunk_size + 1), + MAX_SEQ_LEN - 1, + dtype=mx.int32, ), ) ) @@ -190,12 +200,12 @@ def _compile_decode_one( # Don't need to explicitly compile with mlx, instead we are just precalculating # the computation graphs for different shapes compile_start_time_s = time.time() - for _start_idx in range(0, MAX_SEQ_LEN, 4): + for _ in range(5): mx.eval( decode_one( model, idxs=mx.array([[random.randint(0, 20)]], dtype=mx.int32), - input_pos=mx.array([_start_idx], dtype=mx.int32), + input_pos=mx.array([MAX_SEQ_LEN - 1], dtype=mx.int32), ), ) logger.info( @@ -233,7 +243,7 @@ def compile_model(model: TransformerLM): model = _compile_decode_one(model=model, logger=logger) for chunk_size in list( { - # PREFILL_CHUNK_SIZE_L, + PREFILL_CHUNK_SIZE_L, PREFILL_CHUNK_SIZE, RECALC_DUR_PREFILL_CHUNK_SIZE, } @@ -380,7 +390,8 @@ def decode_first_tokens( ): logger = get_logger("GENERATE") - buffer_ms = FIRST_ONSET_BUFFER_MS - 
HARDWARE_LATENCY_MS + # buffer_ms determines how far in the past to start generating notes + buffer_ms = FIRST_ONSET_BUFFER_MS - HARDWARE_OUTPUT_LATENCY_MS time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] eos_tok_id = tokenizer.tok_to_id[tokenizer.eos_tok] dim_tok_id = tokenizer.tok_to_id[tokenizer.dim_tok] @@ -394,6 +405,7 @@ def decode_first_tokens( num_time_toks_to_add = num_time_toks_required - num_time_toks_in_priming_seq logger.info(f"Time since first onset: {time_since_first_onset_ms}ms") + logger.info(f"Using first note-onset buffer: {buffer_ms}ms") while num_time_toks_to_add > 0: generated_tokens_queue.put(tokenizer.time_tok) @@ -521,7 +533,6 @@ def decode_tokens( f"Using sampling parameters: temperature={temperature}, min_p={min_p}" ) - # TODO: This seems to fix issues? if control_sentinel.is_set(): control_sentinel.clear() @@ -672,7 +683,8 @@ def decode_tokens_to_midi( logger = get_logger("DECODE") assert ( - first_on_msg_epoch_ms + priming_seq_last_onset_ms < get_epoch_time_ms() + first_on_msg_epoch_ms + priming_seq_last_onset_ms + < get_epoch_time_ms() + HARDWARE_INPUT_LATENCY_MS ) logger.info(f"Priming sequence last onset: {priming_seq_last_onset_ms}") @@ -778,7 +790,7 @@ def stream_midi( f"Sending generated messages on MIDI port: '{midi_output_port}'" ) logger.info( - f"Applying hardware latency adjustment: {HARDWARE_LATENCY_MS}ms" + f"Applying hardware output latency adjustment: {HARDWARE_OUTPUT_LATENCY_MS}ms" ) active_pitch_uuid = {} @@ -808,8 +820,9 @@ def stream_midi( break while midi_msgs: + # Messages are sent HARDWARE_OUTPUT_LATENCY_MS early latency_adjusted_epoch_time_ms = ( - get_epoch_time_ms() + HARDWARE_LATENCY_MS + get_epoch_time_ms() + HARDWARE_OUTPUT_LATENCY_MS ) msg = midi_msgs[0] @@ -1159,21 +1172,22 @@ def continuous_prefill( msgs.append(msg) msg_cnt += 1 - # TODO: This workaround is not good enough. Instead just loop is - # curr context has no notes. 
- if (msg_cnt >= 10 or seen_sentinel) and len(msgs) > 75: + if msg_cnt >= 10 or seen_sentinel: midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) - curr_context = tokenizer.encode( - tokenizer.tokenize(midi_dict, add_dim_tok=False) - ) - prev_context = chunked_prefill( - model=model, - tokenizer=tokenizer, - prev_context=prev_context, - curr_context=curr_context, - full=False, - ) + + if len(midi_dict.note_msgs) > 0: + curr_context = tokenizer.encode( + tokenizer.tokenize(midi_dict, add_dim_tok=False) + ) + prev_context = chunked_prefill( + model=model, + tokenizer=tokenizer, + prev_context=prev_context, + curr_context=curr_context, + full=False, + ) + msg_cnt = 0 else: time.sleep(0.01) @@ -1190,7 +1204,6 @@ def capture_and_update_kv( midi_input_port: str, midi_capture_channel: int, midi_control_signal: int | None = None, - midi_through_port: str | None = None, first_msg_epoch_time_ms: int | None = None, ): received_messages_queue = queue.Queue() @@ -1203,7 +1216,6 @@ def capture_and_update_kv( "received_messages_queue": received_messages_queue, "midi_capture_channel": midi_capture_channel, "midi_control_signal": midi_control_signal, - "midi_through_port": midi_through_port, "first_msg_epoch_time_ms": first_msg_epoch_time_ms, "results_queue": results_queue, "wait_for_close": wait_for_close, @@ -1223,7 +1235,6 @@ def capture_and_update_kv( return msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches -# TODO: Change MIDI-through logic, this is not the best way to mock playing def capture_midi_input( midi_input_port: str, control_sentinel: threading.Event, @@ -1231,28 +1242,25 @@ def capture_midi_input( midi_capture_channel: int, results_queue: queue.Queue, midi_control_signal: int | None = None, - midi_through_port: str | None = None, first_msg_epoch_time_ms: int | None = None, wait_for_close: bool = False, ): logger = get_logger("CAPTURE") active_pitches = set() first_on_msg_epoch_ms = None - prev_msg_epoch_time_ms = first_msg_epoch_time_ms # + prev_msg_epoch_time_ms = first_msg_epoch_time_ms - if midi_through_port is not None: - logger.info(f"Sending through on MIDI port: '{midi_through_port}'") + logger.info(f"Listening on MIDI port: '{midi_input_port}'") + logger.info(f"Ready to capture MIDI events") - with ExitStack() as stack: - midi_input = stack.enter_context(mido.open_input(midi_input_port)) - midi_through = ( - stack.enter_context(mido.open_output(midi_through_port)) - if midi_through_port - else None - ) - logger.info(f"Listening on MIDI port: '{midi_input_port}'") - logger.info(f"Ready to capture MIDI events") + # Clear undesired buffered notes + with mido.open_input(midi_input_port) as midi_input: + while True: + msg = midi_input.receive(block=False) + if msg is None: + break + with mido.open_input(midi_input_port) as midi_input: if midi_control_signal is not None: logger.info( f"Commencing generation upon keypress or MIDI control: {midi_control_signal}" @@ -1287,16 +1295,14 @@ def capture_midi_input( ) or msg.type == "note_off": active_pitches.discard(msg.note) received_messages_queue.put(msg) - if midi_through is not None: - midi_through.send(msg) elif msg.type == "note_on" and msg.velocity > 0: if first_on_msg_epoch_ms is None: - first_on_msg_epoch_ms = get_epoch_time_ms() + first_on_msg_epoch_ms = ( + get_epoch_time_ms() - HARDWARE_INPUT_LATENCY_MS + ) active_pitches.add(msg.note) received_messages_queue.put(msg) - if midi_through is not None: - midi_through.send(msg) elif msg.type == "control_change" and msg.control == 64: 
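+                # CC 64 is the sustain pedal: forward it so pedal state is
+                # captured in the priming context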
received_messages_queue.put(msg) elif ( @@ -1320,8 +1326,6 @@ def capture_midi_input( time=get_epoch_time_ms() - prev_msg_epoch_time_ms, ) received_messages_queue.put(msg) - if midi_through is not None: - midi_through.send(msg) while active_pitches: pitch = active_pitches.pop() @@ -1333,8 +1337,6 @@ def capture_midi_input( time=0, ) received_messages_queue.put(msg) - if midi_through is not None: - midi_through.send(msg) # Turn off pedal msg = mido.Message( @@ -1345,43 +1347,48 @@ def capture_midi_input( time=0, ) received_messages_queue.put(msg) - if midi_through is not None: - midi_through.send(msg) - - received_messages_queue.put(None) # Sentinel + received_messages_queue.put(None) results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) -def play_midi_file(midi_port: str, midi_path: str): - logger = get_logger("FILE") - logger.info(f"Playing file at {midi_path} on MIDI port '{midi_port}'") +def play_midi_file( + midi_through_port: str, + midi_in_port: str, + midi_path: str, + currently_streaming_sentinel: threading.Event, +): + def _send_delayed_message(port, msg): + port.send(msg) + logger.debug(f"SENT: {msg}") - midi_dict = MidiDict.from_midi(midi_path) + logger = get_logger("FILE") + logger.info(f"Playing {midi_path} on through-port '{midi_through_port}'") + logger.info( + f"Simulating input to port '{midi_in_port}' with {HARDWARE_INPUT_LATENCY_MS}ms latency" + ) if MIN_NOTE_DELTA_MS > 0: + midi_dict = MidiDict.from_midi(midi_path) midi_dict.enforce_gaps(min_gap_ms=MIN_NOTE_DELTA_MS) - - mid = midi_dict.to_midi() + mid = midi_dict.to_midi() + else: + mid = mido.MidiFile(midi_path) time.sleep(1) - active_pitches = [] - with mido.open_output(midi_port) as output_port: - for msg in mid.play(): - if msg.type == "note_on" and msg.velocity > 0: - if msg.note in active_pitches: - _off_msg = copy.deepcopy(msg) - _off_msg.velocity = 0 - output_port.send(_off_msg) - else: - active_pitches.append(msg.note) - elif msg.type == "note_off" or ( - msg.type == "note_on" and msg.velocity == 0 - ): - if msg.note in active_pitches: - active_pitches.remove(msg.note) + with mido.open_output(midi_through_port) as through_port: + with mido.open_output(midi_in_port) as in_port: + for msg in mid.play(): + if currently_streaming_sentinel.is_set() is False and not ( + msg.type == "control_change" and msg.control == 64 + ): + through_port.send(msg) - logger.debug(f"{msg}") - output_port.send(msg) + timer = threading.Timer( + interval=HARDWARE_INPUT_LATENCY_MS / 1000.0, + function=_send_delayed_message, + args=[in_port, msg], + ) + timer.start() def listen_for_keypress_control_signal( @@ -1487,9 +1494,6 @@ def parse_args(): return argp.parse_args() -# TODO: Need functionality for handing case where we run out of model context - - def main(args): args = parse_args() logger = get_logger() @@ -1498,19 +1502,32 @@ def main(args): model = compile_model(model=model) assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in + + control_sentinel = threading.Event() + generate_ending_sentinel = threading.Event() + currently_generating_sentinel = threading.Event() + + if args.midi_through: + close_notes(args.midi_through) + if args.midi_out: + close_notes(args.midi_out) + if args.midi_path: midi_input_port = "IAC Driver Bus 1" play_file_thread = threading.Thread( target=play_midi_file, - args=(midi_input_port, args.midi_path), + args=( + args.midi_through, + midi_input_port, + args.midi_path, + currently_generating_sentinel, + ), daemon=True, ) else: midi_input_port = args.midi_in play_file_thread = 
None - control_sentinel = threading.Event() - generate_ending_sentinel = threading.Event() keypress_thread = threading.Thread( target=listen_for_keypress_control_signal, args=[control_sentinel, generate_ending_sentinel], @@ -1542,7 +1559,6 @@ def main(args): wait_for_close=args.wait_for_close, midi_input_port=midi_input_port, midi_control_signal=args.midi_control_signal, - midi_through_port=args.midi_through, midi_capture_channel=0, ) ) @@ -1550,6 +1566,7 @@ def main(args): curr_midi_channel = 0 while True: control_sentinel.clear() + currently_generating_sentinel.set() msgs = stream_msgs( model=model, tokenizer=tokenizer, @@ -1573,6 +1590,7 @@ def main(args): if generate_ending_sentinel.is_set(): break else: + currently_generating_sentinel.clear() msgs, prev_context, _, num_active_pitches = capture_and_update_kv( model=model, msgs=msgs, @@ -1581,7 +1599,6 @@ def main(args): wait_for_close=args.wait_for_close, midi_input_port=midi_input_port, midi_control_signal=args.midi_control_signal, - midi_through_port=args.midi_through, midi_capture_channel=curr_midi_channel, first_msg_epoch_time_ms=first_on_msg_epoch_ms, ) @@ -1608,8 +1625,9 @@ def main(args): midi.save(args.save_path) -def exit(midi_out_port: str): +def close_notes(midi_out_port: str): with mido.open_output(midi_out_port) as out: + out.send(mido.Message(type="control_change", control=64, value=0)) for note in range(128): out.send(mido.Message("note_off", note=note, velocity=0)) @@ -1620,4 +1638,4 @@ def exit(midi_out_port: str): try: main(args) except KeyboardInterrupt: - exit(args.midi_out) + close_notes(args.midi_out) From 6832c0226023ee2b1bda5a6ccf6b10a43d279d74 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 19 Jun 2025 17:43:44 +0000 Subject: [PATCH 70/72] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ca65eef..7ed175a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Aria -This repository contains training, inference, and evaluation code for the paper *Scaling Self-Supervised Representation Learning for Symbolic Piano Performance (ISMIR 2025)*, as well as implementations of our real-time piano continuation demo. *Aria* is a pretrained autoregressive generative model for symbolic music, based on the LLaMA 3.2 (1B) architecture, which was trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. Alongside the base model, we are releasing a checkpoint finetuned to improve generative quality, as well as a checkpoint finetuned to produce general-purpose piano MIDI embeddings using a SimCSE-style contrastive training objective. +This repository contains training, inference, and evaluation code for the paper [*Scaling Self-Supervised Representation Learning for Symbolic Piano Performance (ISMIR 2025)*](https://example.com/), as well as implementations of our real-time piano continuation demo. *Aria* is a pretrained autoregressive generative model for symbolic music, based on the LLaMA 3.2 (1B) architecture, which was trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. Alongside the base model, we are releasing a checkpoint finetuned to improve generative quality, as well as a checkpoint finetuned to produce general-purpose piano MIDI embeddings using a SimCSE-style contrastive training objective. 
📖 Read our [release blog post](https://example.com/) and [paper](https://example.com/)
🤗 Access our models via the [HuggingFace page](https://huggingface.co/loubb/aria-medium-base)
@@ -41,7 +41,7 @@
aria generate \
  --save_dir
```

-Since the model has not been post-trained with instruction tuning or RLHF (similar to pre-instruct GPT models), it is very sensitive to input quality and performs best when prompted with well-played music. To get prompt MIDI files, see the `example-prompts/` directory, explore the [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) dataset, or transcribe your own files using [piano-transcription model](https://github.com/EleutherAI/aria-amt). For a full list of sampling options: `aria generate -h`. If you wish to do inference on the CPU, please see the platform-agnostic implementation on our HuggingFace page [link].
+Since the model has not been post-trained with instruction tuning or RLHF (similar to pre-instruct GPT models), it is very sensitive to input quality and performs best when prompted with well-played music. To get prompt MIDI files, see the `example-prompts/` directory, explore the [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) dataset, or transcribe your own files using our [piano-transcription model](https://github.com/EleutherAI/aria-amt). For a full list of sampling options: `aria generate -h`. If you wish to do inference on the CPU, please see the platform-agnostic implementation on our HuggingFace page [link].

### Intended Use and Limitations

From abf8e215fb3b22ac90e36a1beff8f7b8466fcb38 Mon Sep 17 00:00:00 2001
From: Louis
Date: Mon, 30 Jun 2025 15:32:15 +0100
Subject: [PATCH 71/72] add quantize option to demo

---
 demo/demo_mlx.py | 32 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py
index 5f34820..1e10727 100644
--- a/demo/demo_mlx.py
+++ b/demo/demo_mlx.py
@@ -32,7 +32,7 @@

BEAM_WIDTH = 3
TIME_TOK_WEIGHTING = -5
FIRST_ONSET_BUFFER_MS = -100  # Controls onset timing for first generated note

# HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS
# HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg
@@ -40,17 +40,17 @@
# C4DM Disklavier:
# MIN_NOTE_DELTA_MS = 40
-# MIN_NOTE_LEN_MS = 50
+# MIN_NOTE_LEN_MS = 100
# HARDWARE_INPUT_LATENCY_MS = 50
-# HARDWARE_OUTPUT_LATENCY_MS = 150
+# HARDWARE_OUTPUT_LATENCY_MS = 120

# Pianoteq
MIN_NOTE_DELTA_MS = 0
-MIN_NOTE_LEN_MS = 0
+MIN_NOTE_LEN_MS = 30
HARDWARE_INPUT_LATENCY_MS = 0
HARDWARE_OUTPUT_LATENCY_MS = 0

-MAX_STREAM_DELAY_MS = 25
+MAX_STREAM_DELAY_MS = 50

file_handler = logging.FileHandler("./demo.log", mode="w")
file_handler.setLevel(logging.DEBUG)
@@ -158,14 +158,15 @@
    compile_start_time_s = time.time()
    logger.info(f"Compiling prefill (chunk_size={chunk_size})")

-    for _ in range(5):
+    for idx in range(8):  # hard-coded “8”
+        start = idx * (MAX_SEQ_LEN - chunk_size) // 7  # 0 … MAX_SEQ_LEN-chunk
        mx.eval(
            prefill(
                model,
                idxs=mx.ones([1, chunk_size], dtype=mx.int32),
                input_pos=mx.arange(
-                    MAX_SEQ_LEN - (chunk_size + 1),
-                    MAX_SEQ_LEN - 1,
+                    start,
+                    start + chunk_size,
                    dtype=mx.int32,
                ),
            )
        )
@@ -243,8 +244,8 @@
    model = _compile_decode_one(model=model, logger=logger)
    for chunk_size in list(
        {
-            PREFILL_CHUNK_SIZE_L,
-            PREFILL_CHUNK_SIZE,
+            # PREFILL_CHUNK_SIZE_L,
+            # PREFILL_CHUNK_SIZE,
            RECALC_DUR_PREFILL_CHUNK_SIZE,
        }
    ):
@@ 
-269,9 +270,11 @@ def load_model( init_start_time_s = time.time() model = TransformerLM(model_config) model.load_weights(checkpoint_path, strict=False) - nn.quantize(model.model, group_size=64, bits=8) model.eval() + if args.quantize: + nn.quantize(model.model, group_size=64, bits=8) + logger.info( f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds" ) @@ -1172,7 +1175,7 @@ def continuous_prefill( msgs.append(msg) msg_cnt += 1 - if msg_cnt >= 10 or seen_sentinel: + if msg_cnt >= 10: midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) @@ -1484,6 +1487,11 @@ def parse_args(): help="wait for note-offs before generating", action="store_true", ) + argp.add_argument( + "--quantize", + help="apply model quantize", + action="store_true", + ) argp.add_argument( "--save_path", type=str, From a378720128d3624f354e83738ea3dbfc834f3303 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 30 Jun 2025 14:33:51 +0000 Subject: [PATCH 72/72] delete --- demo/demo.py | 1598 ------------------------------------ demo/demo.sh | 12 - demo/midi-tunnel-client.py | 143 ---- demo/midi-tunnel-server.py | 61 -- 4 files changed, 1814 deletions(-) delete mode 100644 demo/demo.py delete mode 100644 demo/demo.sh delete mode 100644 demo/midi-tunnel-client.py delete mode 100755 demo/midi-tunnel-server.py diff --git a/demo/demo.py b/demo/demo.py deleted file mode 100644 index 46a3a0b..0000000 --- a/demo/demo.py +++ /dev/null @@ -1,1598 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import os -import time -import uuid -import copy -import logging -import threading -import queue -import copy -import torch -import mido -import torch._inductor.config - -from torch.cuda import is_available as cuda_is_available -from contextlib import ExitStack - -from ariautils.midi import MidiDict, midi_to_dict -from ariautils.tokenizer import AbsTokenizer -from aria.utils import _load_weight -from aria.inference import TransformerLM -from aria.model import ModelConfig -from aria.config import load_model_config -from aria.sample import sample_min_p - -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True - -DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 -MAX_SEQ_LEN = 4096 -PREFILL_CHUNK_SIZE = 32 -RECALC_DUR_PREFILL_CHUNK_SIZE = 8 -RECALC_DUR_BUFFER_MS = 50 - -# Decode first -BEAM_WIDTH = 3 -TIME_TOK_WEIGHTING = -5 -FIRST_ONSET_BUFFER_MS = 25 - -# HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS -# HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg -# HARDWARE: All messages are sent HARDWARE_LATENCY_MS early -MIN_NOTE_DELTA_MS = 100 -MIN_NOTE_LEN_MS = 200 -HARDWARE_LATENCY_MS = 0 - -file_handler = logging.FileHandler("./demo.log", mode="w") -file_handler.setLevel(logging.DEBUG) - - -def get_logger(name: str | None = None) -> logging.Logger: - logger = logging.getLogger(name) - if not logger.handlers: - logger.propagate = False - logger.setLevel(logging.DEBUG) - - class MillisecondFormatter(logging.Formatter): - def formatTime(self, record, datefmt=None): - created_ms = int(record.created * 1000) - return str(created_ms) - - if name is not None: - formatter = MillisecondFormatter( - "%(asctime)s: [%(levelname)s] [%(name)s] %(message)s" - ) - else: - formatter = MillisecondFormatter( - "%(asctime)s: [%(levelname)s] %(message)s" - ) - - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - 
ch.setFormatter(formatter) - logger.addHandler(ch) - - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - - return logger - - -def get_epoch_time_ms() -> int: - return round(time.time() * 1000) - - -@torch.autocast("cuda", dtype=DTYPE) -@torch.inference_mode() -def prefill( - model: TransformerLM, - idxs: torch.Tensor, - input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, -) -> torch.Tensor: - logits = model.forward( - idxs=idxs, - input_pos=input_pos, - pad_idxs=pad_idxs, - ) - - return logits - - -@torch.autocast("cuda", dtype=DTYPE) -@torch.inference_mode() -def decode_one( - model: TransformerLM, - idxs: torch.Tensor, - input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, -) -> torch.Tensor: - assert input_pos.shape[-1] == 1 - - logits = model.forward( - idxs=idxs, - input_pos=input_pos, - pad_idxs=pad_idxs, - )[:, -1] - - return logits - - -def _compile_prefill( - model: TransformerLM, - logger: logging.Logger, - chunk_size: int, -): - assert chunk_size > 1 - - global prefill - prefill = torch.compile( - prefill, - mode="reduce-overhead", - fullgraph=True, - ) - start_compile_time_s = time.time() - logger.info(f"Compiling prefill (chunk_size={chunk_size})") - prefill( - model, - idxs=torch.ones(1, chunk_size, device="cuda", dtype=torch.int), - input_pos=torch.arange(0, chunk_size, device="cuda", dtype=torch.int), - ) - logger.info( - f"Finished compiling - took {time.time() - start_compile_time_s:.4f} seconds" - ) - - for _ in range(5): - prefill( - model, - idxs=torch.ones(1, chunk_size, device="cuda", dtype=torch.int), - input_pos=torch.arange( - 0, chunk_size, device="cuda", dtype=torch.int - ), - ) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - prefill( - model, - idxs=torch.ones(1, chunk_size, device="cuda", dtype=torch.int), - input_pos=torch.arange(0, chunk_size, device="cuda", dtype=torch.int), - ) - end_event.record() - end_event.synchronize() - compiled_prefill_ms = start_event.elapsed_time(end_event) - compiled_prefill_its = 1000 / compiled_prefill_ms - logger.info( - f"Compiled prefill benchmark: {compiled_prefill_ms:.2f} ms/it ({compiled_prefill_its:.2f} it/s)" - ) - - return model - - -def _compile_decode_one(model: TransformerLM, logger: logging.Logger): - global decode_one - decode_one = torch.compile( - decode_one, - mode="reduce-overhead", - fullgraph=True, - ) - - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - start_compile_time_s = time.time() - logger.info(f"Compiling decode_one") - decode_one( - model, - idxs=torch.tensor([[0]], device="cuda", dtype=torch.int), - input_pos=torch.tensor([0], device="cuda", dtype=torch.int), - ) - logger.info( - f"Finished compiling - took {time.time() - start_compile_time_s:.4f} seconds" - ) - - for _ in range(5): - decode_one( - model, - idxs=torch.tensor([[0]], device="cuda", dtype=torch.int).cuda(), - input_pos=torch.tensor([0], device="cuda", dtype=torch.int), - ) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - decode_one( - model, - idxs=torch.tensor([[0]], device="cuda", dtype=torch.int).cuda(), - input_pos=torch.tensor([0], device="cuda", dtype=torch.int), - ) - end_event.record() - end_event.synchronize() - - compiled_forward_ms = start_event.elapsed_time(end_event) - compiled_forward_its = 1000 / compiled_forward_ms - logger.info( - f"Compiled decode_one benchmark: 
{compiled_forward_ms:.2f} ms/it ({compiled_forward_its:.2f} it/s)" - ) - - return model - - -@torch.inference_mode() -def compile_model(model: TransformerLM, max_seq_len: int): - logger = get_logger() - assert 10 < max_seq_len <= MAX_SEQ_LEN - - model.eval() - model.setup_cache( - batch_size=1, - max_seq_len=max_seq_len, - dtype=DTYPE, - ) - - model = _compile_decode_one(model=model, logger=logger) - for chunk_size in list({PREFILL_CHUNK_SIZE, RECALC_DUR_PREFILL_CHUNK_SIZE}): - model = _compile_prefill( - model=model, logger=logger, chunk_size=chunk_size - ) - - return model - - -def load_model( - checkpoint_path: str, -): - logger = get_logger() - if not cuda_is_available(): - raise Exception("CUDA device is not available.") - - init_start_time_s = time.time() - - tokenizer = AbsTokenizer() - model_config = ModelConfig(**load_model_config("medium-emb")) - model_config.set_vocab_size(tokenizer.vocab_size) - model_config.grad_checkpoint = False - model = TransformerLM(model_config).cuda() - - logging.info(f"Loading model weights from {checkpoint_path}") - model_state = _load_weight(checkpoint_path, "cuda") - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() - } - try: - model.load_state_dict(model_state) - except Exception: - logger.info("Failed to load model, attempting with strict=False...") - model.load_state_dict(model_state, strict=False) - - logger.info( - f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds" - ) - - return model - - -def _first_bad_dur_index( - tokenizer: AbsTokenizer, - priming_seq: list, - pred_ids: list, - chunk_start: int, - last_offset_ms: int, - logger: logging.Logger, -): - num_time_toks = priming_seq[:chunk_start].count(tokenizer.time_tok) - local_onset_ms = tokenizer.calc_length_ms( - priming_seq[:chunk_start], onset=True - ) - logger.debug(f"Starting from local onset {local_onset_ms}") - - for pos, tok_id in enumerate( - pred_ids[: len(priming_seq) - chunk_start], start=chunk_start - ): - prim_tok = priming_seq[pos] # Should never error? - pred_tok = tokenizer.id_to_tok[tok_id] - logger.debug(f"prim={prim_tok}, pred={pred_tok}") - - if isinstance(prim_tok, tuple) and prim_tok[0] == "onset": - local_onset_ms = num_time_toks * 5000 + prim_tok[1] - elif prim_tok == tokenizer.time_tok: - num_time_toks += 1 - elif isinstance(prim_tok, tuple) and prim_tok[0] == "dur": - dur_true = prim_tok[1] - dur_pred = pred_tok[1] - if dur_pred > dur_true and ( - local_onset_ms + dur_true - > last_offset_ms - RECALC_DUR_BUFFER_MS - ): - logger.info( - f"Found token to resample at {pos}: {prim_tok} -> {pred_tok}" - ) - return pos - - return None - - -# TODO: I'm still not 100% sure this is bug free. 
-# A good debugging strat would be to run it over and over again until we -# cover all of the edge cases -@torch.inference_mode() -def recalc_dur_tokens_chunked( - model: TransformerLM, - priming_seq: list, - enc_seq: torch.Tensor, - tokenizer: AbsTokenizer, - start_idx: int, -): - """Speculative-decoding inspired duration re-calculation""" - assert start_idx > 0 - logger = get_logger("GENERATE") - - priming_len = len(priming_seq) - last_offset = tokenizer.calc_length_ms(priming_seq) - - idx = start_idx - while idx <= priming_len: - end_idx = idx + RECALC_DUR_PREFILL_CHUNK_SIZE - - window_ids = torch.tensor( - enc_seq[:, idx - 1 : end_idx - 1].tolist(), - device="cuda", - dtype=torch.int, - ) - window_pos = torch.arange( - idx - 1, end_idx - 1, device="cuda", dtype=torch.int - ) - - logger.info( - f"Recalculating chunked durations for positions: {idx-1} - {end_idx-2}" - ) - logger.debug(f"Inserted: {tokenizer.decode(window_ids[0].tolist())}") - logger.debug(f"Positions: {window_pos.tolist()}") - - logits = prefill(model, idxs=window_ids, input_pos=window_pos) - pred_ids = logits.argmax(dim=-1).flatten().tolist() - - bad_pos = _first_bad_dur_index( - tokenizer=tokenizer, - priming_seq=priming_seq, - pred_ids=pred_ids, - chunk_start=idx, - last_offset_ms=last_offset, - logger=logger, - ) - - if bad_pos is None: - idx = end_idx - else: - new_id = pred_ids[bad_pos - idx] - enc_seq[0, bad_pos] = new_id - priming_seq[bad_pos] = tokenizer.id_to_tok[new_id] - idx = bad_pos - - next_logits = logits[:, priming_len - idx] - - return enc_seq, priming_seq, next_logits - - -# TODO: This is now the latency bottleneck. -# Ideas for reducing it: -# - Get rid of the manual time_tok insert stuff, instead just mask logits -# for all invalid tokens, this should force the model to sample a time tok -# if there aren't any other valid options -@torch.inference_mode() -def decode_first_tokens( - model: TransformerLM, - first_token_logits: torch.Tensor, - enc_seq: torch.Tensor, - priming_seq: list, - tokenizer: AbsTokenizer, - generated_tokens_queue: queue.Queue, - first_on_msg_epoch_ms: int, -): - logger = get_logger("GENERATE") - - buffer_ms = FIRST_ONSET_BUFFER_MS + HARDWARE_LATENCY_MS - time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] - - logits = first_token_logits - time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms - idx = len(priming_seq) + 1 - - num_time_toks_required = (time_since_first_onset_ms + buffer_ms) // 5000 - num_time_toks_in_priming_seq = priming_seq.count(tokenizer.time_tok) - num_time_toks_to_add = num_time_toks_required - num_time_toks_in_priming_seq - - logger.info(f"Time since first onset: {time_since_first_onset_ms}ms") - - while num_time_toks_to_add > 0: - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - generated_tokens_queue.put(tokenizer.time_tok) - logits = decode_one( - model, - idxs=torch.tensor( - [[time_tok_id]], device="cuda", dtype=torch.int - ), - input_pos=torch.tensor( - [idx - 1], device="cuda", dtype=torch.int - ), - ) - - logger.info(f"Inserted time_tok at position {idx-1}") - num_time_toks_to_add -= 1 - enc_seq[:, idx - 1] = torch.tensor([[time_tok_id]]).cuda() - idx += 1 - - logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") - logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") - - log_probs = torch.log_softmax(logits, dim=-1) - top_log_probs, top_ids = torch.topk(log_probs, k=BEAM_WIDTH, dim=-1) - - if time_tok_id not in top_ids[0].tolist(): - top_ids[0, -1] = time_tok_id - top_log_probs[0, 
-1] = log_probs[0, time_tok_id] + TIME_TOK_WEIGHTING - - top_toks = [tokenizer.id_to_tok[id] for id in top_ids[0].tolist()] - - logger.debug(f"Calculated top {BEAM_WIDTH} tokens={top_toks}") - logger.debug( - f"Calculated top {BEAM_WIDTH} scores={top_log_probs[0].tolist()}" - ) - - masked_onset_ids = [ - tokenizer.tok_to_id[tok] - for tok in tokenizer.onset_tokens - if tok[1] < ((time_since_first_onset_ms + buffer_ms) % 5000) - ] - - logger.debug( - f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms + buffer_ms})" - ) - - best_score = float("-inf") - for i in range(BEAM_WIDTH): - tok = top_toks[i] - tok_id = top_ids[0, i].item() - tok_log_prob = top_log_probs[0, i] - - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - next_logits = decode_one( - model, - idxs=torch.tensor([[tok_id]], device="cuda", dtype=torch.int), - input_pos=torch.tensor( - [idx - 1], device="cuda", dtype=torch.int - ), - ) - logger.debug( - f"Sampled logits for positions {idx} by inserting {tok} at position {idx-1}" - ) - - next_log_probs = torch.log_softmax(next_logits, dim=-1) - next_log_probs[:, masked_onset_ids] = float("-inf") - if tok_id == time_tok_id: - next_log_probs[:, time_tok_id] = float("-inf") - - next_tok_log_prob, next_tok_id = torch.max(next_log_probs, dim=-1) - next_tok = tokenizer.id_to_tok[next_tok_id.item()] - score = tok_log_prob + next_tok_log_prob - - logger.info( - f"Calculated tuple {(tok, next_tok)} with scores {(tok_log_prob.item(), next_tok_log_prob.item())} (combined={score.item()})" - ) - - if score > best_score: - best_tok_id_1, best_tok_id_2 = tok_id, next_tok_id.item() - best_tok_1, best_tok_2 = ( - tokenizer.id_to_tok[best_tok_id_1], - tokenizer.id_to_tok[best_tok_id_2], - ) - best_score = score - - logger.info( - f"Chose tuple {(best_tok_1, best_tok_2)} with score {best_score.item()}" - ) - - enc_seq[:, idx - 1] = best_tok_id_1 - enc_seq[:, idx] = best_tok_id_2 - generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_1]) - generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_2]) - - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - decode_one( - model, - idxs=torch.tensor( - [[best_tok_id_1]], device="cuda", dtype=torch.int - ), - input_pos=torch.tensor([idx - 1], device="cuda", dtype=torch.int), - ) - - logger.info( - f"Updated KV-Cache by re-inserting {best_tok_1} at position {idx-1}" - ) - logger.info( - f"Inserted {best_tok_2} at position {idx} without updating KV-Cache" - ) - - return enc_seq, idx + 1 - - -def decode_tokens( - model: TransformerLM, - enc_seq: torch.Tensor, - tokenizer: AbsTokenizer, - control_sentinel: threading.Event, - generated_tokens_queue: queue.Queue, - idx: int, - temperature: float, - min_p: float, -): - logger = get_logger("GENERATE") - logger.info( - f"Using sampling parameters: temperature={temperature}, min_p={min_p}" - ) - - while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: - decode_one_start_time_s = time.time() - - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - prev_tok_id = enc_seq[0, idx - 1] - prev_tok = tokenizer.id_to_tok[prev_tok_id.item()] - - logits = decode_one( - model, - idxs=torch.tensor( - [[prev_tok_id]], device="cuda", dtype=torch.int - ), - input_pos=torch.tensor( - [idx - 1], device="cuda", dtype=torch.int - ), - ) - - logger.debug( - f"Sampled logits for positions {idx} by inserting {prev_tok} at position {idx-1}" - ) - - logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") - for dur_ms in 
range(0, MIN_NOTE_LEN_MS, 10): - logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") - - if temperature > 0.0: - probs = torch.softmax(logits / temperature, dim=-1) - next_token_ids = sample_min_p(probs, min_p).flatten() - else: - next_token_ids = torch.argmax(logits, dim=-1).flatten() - - enc_seq[:, idx] = next_token_ids - next_token = tokenizer.id_to_tok[next_token_ids[0].item()] - logger.debug( - f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" - ) - - if next_token == tokenizer.eos_tok: - logger.info("EOS token produced, exiting...") - generated_tokens_queue.put(next_token) - return - else: - generated_tokens_queue.put(next_token) - idx += 1 - - while not control_sentinel.is_set(): - time.sleep(0.1) - - logger.info("Seen exit signal") - generated_tokens_queue.put(None) - - -@torch.inference_mode() -def generate_tokens( - priming_seq: list, - tokenizer: AbsTokenizer, - model: TransformerLM, - prev_context: list[int], - control_sentinel: threading.Event, - generated_tokens_queue: queue.Queue, - num_preceding_active_pitches: int, - first_on_msg_epoch_ms: int, - temperature: float = 0.97, - min_p: float = 0.03, -): - logger = get_logger("GENERATE") - - generate_start_s = time.time() - priming_seq_len = len(priming_seq) - start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches - 1) - enc_seq = torch.tensor( - [ - tokenizer.encode( - priming_seq - + [tokenizer.pad_tok] * (MAX_SEQ_LEN - len(priming_seq)) - ) - ], - device="cuda", - dtype=torch.int, - ) - - logger.debug(f"Priming sequence {priming_seq}") - logger.info(f"Priming sequence length: {priming_seq_len}") - logger.info(f"Prefilling up to (and including) position: {start_idx-1}") - - # In theory we could reuse the logits from prefill - prefill_start_s = time.time() - chunked_prefill( - model=model, - tokenizer=tokenizer, - prev_context=prev_context, - curr_context=enc_seq[0, :start_idx].tolist(), - full=True, - ) - - torch.cuda.synchronize() - logger.info( - f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" - ) - logger.info(f"Starting duration recalculation from position: {start_idx-1}") - - recalculate_dur_start_s = time.time() - enc_seq, priming_seq, next_token_logits = recalc_dur_tokens_chunked( - model=model, - priming_seq=priming_seq, - enc_seq=enc_seq, - tokenizer=tokenizer, - start_idx=start_idx, - ) - - logger.info( - f"Recalculating durations took {(time.time() - recalculate_dur_start_s) * 1000:.2f} milliseconds" - ) - - decode_first_s = time.time() - enc_seq, idx = decode_first_tokens( - model=model, - first_token_logits=next_token_logits, - enc_seq=enc_seq, - priming_seq=priming_seq, - tokenizer=tokenizer, - generated_tokens_queue=generated_tokens_queue, - first_on_msg_epoch_ms=first_on_msg_epoch_ms, - ) - - logger.info( - f"Decode first two tokens took {(time.time() - decode_first_s) * 1000:.2f} milliseconds" - ) - logger.info( - f"Time to first token took {(time.time() - generate_start_s) * 1000:.2f} milliseconds" - ) - - decode_tokens( - model=model, - enc_seq=enc_seq, - tokenizer=tokenizer, - control_sentinel=control_sentinel, - generated_tokens_queue=generated_tokens_queue, - idx=idx, - temperature=temperature, - min_p=min_p, - ) - - -def decode_tokens_to_midi( - generated_tokens_queue: queue.Queue, - outbound_midi_msg_queue: queue.Queue, - tokenizer: AbsTokenizer, - first_on_msg_epoch_ms: int, - priming_seq_last_onset_ms: int, -): - logger = get_logger("DECODE") - - assert ( - first_on_msg_epoch_ms + priming_seq_last_onset_ms < 
get_epoch_time_ms() - ) - - logger.info(f"Priming sequence last onset: {priming_seq_last_onset_ms}") - logger.info( - f"Total time elapsed since first onset: {get_epoch_time_ms() - first_on_msg_epoch_ms}" - ) - - pitch_to_prev_msg = {} - note_buffer = [] - num_time_toks = priming_seq_last_onset_ms // 5000 - - while True: - while True: - tok = generated_tokens_queue.get() - if tok is tokenizer.eos_tok: - _uuid = uuid.uuid4() - end_msg = { - "pitch": -1, - "vel": -1, - "epoch_time_ms": offset_epoch_ms + 250, # Last note offset - "uuid": _uuid, - } # pitch=-1 denotes end_msg - outbound_midi_msg_queue.put(end_msg) - logger.info(f"Seen exit signal: EOS token") - logger.debug(f"Put message: {end_msg}") - return - - elif tok is None: - logger.info(f"Seen exit signal") - return - - logger.debug(f"Seen token: {tok}") - note_buffer.append(tok) - - if isinstance(tok, tuple) and tok[0] == "dur": - break - - while note_buffer and note_buffer[0] == tokenizer.time_tok: - logger.debug("Popping time_tok") - num_time_toks += 1 - note_buffer.pop(0) - - assert len(note_buffer) == 3 - logger.debug(f"Decoded note: {note_buffer}") - note_tok, onset_tok, dur_tok = note_buffer - _, pitch, vel = note_tok - _, onset = onset_tok - _, dur = dur_tok - - _uuid = uuid.uuid4() - onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset - offset_epoch_ms = onset_epoch_ms + dur - on_msg = { - "pitch": pitch, - "vel": vel, - "epoch_time_ms": onset_epoch_ms, - "uuid": _uuid, - } - off_msg = { - "pitch": pitch, - "vel": 0, - "epoch_time_ms": offset_epoch_ms, - "uuid": _uuid, - } - - # Not thread safe but in theory should be ok? - if pitch_to_prev_msg.get(pitch) is not None and MIN_NOTE_DELTA_MS > 0: - prev_on, prev_off = pitch_to_prev_msg.get(pitch) - adj_off_time = max( - min( - prev_off["epoch_time_ms"], - onset_epoch_ms - MIN_NOTE_DELTA_MS, - ), - prev_on["epoch_time_ms"], - ) - if adj_off_time != prev_off["epoch_time_ms"]: - logger.debug(f"Adjusting {prev_off}: t={adj_off_time}") - prev_off["epoch_time_ms"] = adj_off_time - prev_off["adjusted"] = True - - pitch_to_prev_msg[pitch] = [on_msg, off_msg] - - outbound_midi_msg_queue.put(on_msg) - outbound_midi_msg_queue.put(off_msg) - logger.debug(f"Put message: {on_msg}") - logger.debug(f"Put message: {off_msg}") - logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") - - note_buffer = [] - - -# TODO: Test the new changes in decode_tokens_to_midi and clean this fn up. 
-def stream_midi( - inbound_midi_msg_queue: queue.Queue, - msgs: list[mido.Message], - prev_msg_epoch_time_ms: float, - midi_output_port: str, - control_sentinel: threading.Event, - midi_stream_channel: int, - results_queue: queue.Queue, -): - logger = get_logger("STREAM") - logger.info( - f"Sending generated messages on MIDI port: '{midi_output_port}'" - ) - logger.info( - f"Applying hardware latency adjustment: {HARDWARE_LATENCY_MS}ms" - ) - MAX_DELAY_MS = 50 - - active_pitch_uuid = {} - is_pitch_active = {} - midi_msgs = [] - - with mido.open_output(midi_output_port) as midi_out: - while not control_sentinel.is_set(): - while True: - try: - msg = inbound_midi_msg_queue.get_nowait() - except queue.Empty: - break - else: - logger.debug(f"Received message: {msg}") - midi_msgs.append(msg) - - midi_msgs = sorted( - midi_msgs, - key=lambda msg: ( - msg["epoch_time_ms"], - msg["vel"], - ), - ) - - if control_sentinel.is_set(): - break - - while midi_msgs: - latency_adjusted_epoch_time_ms = ( - get_epoch_time_ms() + HARDWARE_LATENCY_MS - ) - msg = midi_msgs[0] - - if ( - 0 - < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] - <= MAX_DELAY_MS - ): - if msg["pitch"] == -1: # End msg - control_sentinel.set() - break - - mido_msg = mido.Message( - "note_on", - note=msg["pitch"], - velocity=msg["vel"], - channel=0, - time=0, - ) - - if msg["vel"] > 0: - active_pitch_uuid[msg["pitch"]] = msg["uuid"] - should_send_midi_out = True - should_append_to_msgs = True - elif msg.get("adjusted", False) is True: - should_send_midi_out = True - should_append_to_msgs = False - else: - should_send_midi_out = ( - active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] - ) - should_append_to_msgs = should_send_midi_out - - if should_send_midi_out is True: - midi_out.send(mido_msg) - is_pitch_active[msg["pitch"]] = msg["vel"] != 0 - logger.info(f"Sent message: {mido_msg}") - if should_append_to_msgs is True: - mido_msg_with_time = copy.deepcopy(mido_msg) - mido_msg_with_time.channel = midi_stream_channel - mido_msg_with_time.time = max( - 0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms - ) - prev_msg_epoch_time_ms = msg["epoch_time_ms"] - msgs.append(mido_msg_with_time) - - midi_msgs.pop(0) - - elif ( - latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] - > MAX_DELAY_MS - ): - # Message occurs too far in the past - logger.debug( - f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]}ms) in the past: {msg}" - ) - midi_msgs.pop(0) - else: - # Message occurs in the future - break - - time.sleep(0.005) - - remaining_note_off_messages = [ - msg - for msg in midi_msgs - if msg["vel"] == 0 - and active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] - ] - - logger.info("Processing remaining note_off messages") - for __msg in remaining_note_off_messages: - logger.debug(remaining_note_off_messages) - - for msg in remaining_note_off_messages: - mido_msg = mido.Message( - "note_on", - note=msg["pitch"], - velocity=0, - channel=midi_stream_channel, - time=msg["epoch_time_ms"] - prev_msg_epoch_time_ms, - ) - prev_msg_epoch_time_ms = msg["epoch_time_ms"] - msgs.append(mido_msg) - - results_queue.put(msgs) - - while remaining_note_off_messages: - msg = remaining_note_off_messages.pop(0) - while True: - latency_adjusted_epoch_time_ms = ( - get_epoch_time_ms() + HARDWARE_LATENCY_MS - ) - - if 0 < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]: - mido_msg = mido.Message( - "note_on", - note=msg["pitch"], - velocity=0, - channel=midi_stream_channel, - time=0, # Does not matter as only 
used for streaming - ) - midi_out.send(mido_msg) - logger.info(f"Sent message: {mido_msg}") - break - else: - time.sleep(0.01) - - -def stream_msgs( - model: TransformerLM, - tokenizer: AbsTokenizer, - msgs: list[mido.Message], - prev_context: list[int], - midi_output_port: str, - first_on_msg_epoch_ms: int, - control_sentinel: threading.Event, - temperature: float, - min_p: float, - num_preceding_active_pitches: int, - midi_stream_channel: int, - is_ending: bool = False, -): - midi = convert_msgs_to_midi(msgs=msgs) - midi_dict = MidiDict(**midi_to_dict(midi)) - priming_seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) - priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] - - if is_ending is True: - priming_seq.append(tokenizer.dim_tok) - - generated_tokens_queue = queue.Queue() - midi_messages_queue = queue.Queue() - - generate_tokens_thread = threading.Thread( - target=generate_tokens, - kwargs={ - "priming_seq": priming_seq, - "tokenizer": tokenizer, - "model": model, - "prev_context": prev_context, - "control_sentinel": control_sentinel, - "generated_tokens_queue": generated_tokens_queue, - "temperature": temperature, - "min_p": min_p, - "num_preceding_active_pitches": num_preceding_active_pitches, - "first_on_msg_epoch_ms": first_on_msg_epoch_ms, - }, - ) - generate_tokens_thread.start() - - decode_tokens_to_midi_thread = threading.Thread( - target=decode_tokens_to_midi, - kwargs={ - "generated_tokens_queue": generated_tokens_queue, - "outbound_midi_msg_queue": midi_messages_queue, - "tokenizer": tokenizer, - "first_on_msg_epoch_ms": first_on_msg_epoch_ms, - "priming_seq_last_onset_ms": tokenizer.calc_length_ms( - priming_seq, onset=True - ), - }, - ) - decode_tokens_to_midi_thread.start() - - prev_ms_epoch_time_ms = ( - first_on_msg_epoch_ms - + tokenizer.calc_length_ms(priming_seq, onset=False) - if is_ending is False - else first_on_msg_epoch_ms - ) - - stream_midi_results_queue = queue.Queue() - stream_midi_thread = threading.Thread( - target=stream_midi, - kwargs={ - "inbound_midi_msg_queue": midi_messages_queue, - "msgs": msgs, - "prev_msg_epoch_time_ms": prev_ms_epoch_time_ms, - "midi_output_port": midi_output_port, - "control_sentinel": control_sentinel, - "midi_stream_channel": midi_stream_channel, - "results_queue": stream_midi_results_queue, - }, - daemon=True, - ) - stream_midi_thread.start() - - generate_tokens_thread.join() - decode_tokens_to_midi_thread.join() - msgs = stream_midi_results_queue.get() - - if is_ending is True: - stream_midi_thread.join() - - return msgs - - -# TODO: Channel 9 issues here? 
-def convert_msgs_to_midi(msgs: list[mido.Message]): - channel_to_track = { - chan: mido.MidiTrack() - for chan in list(set([msg.channel for msg in msgs])) - } - - for msg in msgs: - channel_to_track[msg.channel].append(msg) - - # Workaround for possibility that track_0 start time != first_on_msg_epoch_ms - for msg in channel_to_track[0]: - if msg.type == "note_on" and msg.velocity > 0: - msg.time = 0 - break - else: - msg.time = 0 - - mid = mido.MidiFile(type=1) - mid.ticks_per_beat = 500 - - for channel, track in channel_to_track.items(): - track.insert(0, mido.MetaMessage("set_tempo", tempo=500000, time=0)) - track.insert( - 0, - mido.Message("program_change", program=0, channel=channel, time=0), - ) - mid.tracks.append(track) - - return mid - - -def _find_divergence( - prev_context: list, - curr_context: list, - logger: logging.Logger, -): - agreement_index = 0 - for prev_val, curr_val in zip(prev_context, curr_context): - if prev_val == curr_val: - agreement_index += 1 - else: - logger.info( - f"Found divergence at position {agreement_index + 1}: {curr_val}, {prev_val}" - ) - break - - return agreement_index, curr_context[agreement_index:] - - -# There is an error here if curr_context < prev_context -@torch.inference_mode() -def chunked_prefill( - model: TransformerLM, - tokenizer: AbsTokenizer, - prev_context: list, - curr_context: list, - full: bool = False, -): - - assert isinstance(curr_context[0], int) - assert tokenizer.pad_id not in prev_context - assert tokenizer.pad_id not in curr_context - - logger = get_logger("PREFILL") - while True: - prefill_idx, prefill_toks = _find_divergence( - prev_context, curr_context, logger=logger - ) - num_prefill_toks = len(prefill_toks) - logger.debug(f"Tokens to prefill: {len(prefill_toks)}") - - if num_prefill_toks > PREFILL_CHUNK_SIZE: - logger.debug( - f"Prefilling {PREFILL_CHUNK_SIZE} tokens from idx={prefill_idx}" - ) - - prefill( - model, - idxs=torch.tensor( - [prefill_toks[:PREFILL_CHUNK_SIZE]], - device="cuda", - dtype=torch.int, - ), - input_pos=torch.arange( - prefill_idx, - prefill_idx + PREFILL_CHUNK_SIZE, - device="cuda", - dtype=torch.int, - ), - ) - prev_context = curr_context[: prefill_idx + PREFILL_CHUNK_SIZE] - - elif num_prefill_toks > 0 and full is True: - logger.debug( - f"Prefilling (force) {num_prefill_toks} tokens from idx={prefill_idx}" - ) - prefill_toks += (PREFILL_CHUNK_SIZE - len(prefill_toks)) * [ - tokenizer.pad_id - ] - prefill( - model, - idxs=torch.tensor( - [prefill_toks], device="cuda", dtype=torch.int - ), - input_pos=torch.arange( - prefill_idx, - prefill_idx + PREFILL_CHUNK_SIZE, - device="cuda", - dtype=torch.int, - ), - ) - prev_context = curr_context - break - else: - break - - logger.info( - f"KV stored up to idx={max(0, len(prev_context)- 1)} (curr_context_len={len(curr_context)})" - ) - - return prev_context - - -def continuous_prefill( - model: TransformerLM, - msgs: list, - received_messages_queue: queue.Queue, - prev_context: list[int], -): - tokenizer = AbsTokenizer() - logger = get_logger("PREFILL") - msg_cnt = 0 - seen_sentinel = False - - while seen_sentinel is False: - while seen_sentinel is False: - try: - msg = received_messages_queue.get_nowait() - except queue.Empty: - break - else: - if msg is None: - logger.info("Seen sentinel in message received messages") - seen_sentinel = True - else: - msgs.append(msg) - msg_cnt += 1 - - if (msg_cnt >= 5 or seen_sentinel) and len(msgs) > 10: - midi = convert_msgs_to_midi(msgs=msgs) - midi_dict = MidiDict(**midi_to_dict(midi)) - curr_context = 
-def continuous_prefill(
-    model: TransformerLM,
-    msgs: list,
-    received_messages_queue: queue.Queue,
-    prev_context: list[int],
-):
-    tokenizer = AbsTokenizer()
-    logger = get_logger("PREFILL")
-    msg_cnt = 0
-    seen_sentinel = False
-
-    while seen_sentinel is False:
-        while seen_sentinel is False:
-            try:
-                msg = received_messages_queue.get_nowait()
-            except queue.Empty:
-                break
-            else:
-                if msg is None:
-                    logger.info("Seen sentinel in received messages")
-                    seen_sentinel = True
-                else:
-                    msgs.append(msg)
-                    msg_cnt += 1
-
-        if (msg_cnt >= 5 or seen_sentinel) and len(msgs) > 10:
-            midi = convert_msgs_to_midi(msgs=msgs)
-            midi_dict = MidiDict(**midi_to_dict(midi))
-            curr_context = tokenizer.encode(
-                tokenizer.tokenize(midi_dict, add_dim_tok=False)
-            )
-            prev_context = chunked_prefill(
-                model=model,
-                tokenizer=tokenizer,
-                prev_context=prev_context,
-                curr_context=curr_context,
-                full=False,
-            )
-            msg_cnt = 0
-        else:
-            time.sleep(0.01)
-
-    return msgs, prev_context
-
-
-def capture_and_update_kv(
-    model: TransformerLM,
-    msgs: list,
-    prev_context: list,
-    control_sentinel: threading.Event,
-    midi_input_port: str,
-    midi_capture_channel: int,
-    midi_control_signal: int | None = None,
-    midi_through_port: str | None = None,
-    first_msg_epoch_time_ms: int | None = None,
-):
-    received_messages_queue = queue.Queue()
-    results_queue = queue.Queue()
-    capture_midi_thread = threading.Thread(
-        target=capture_midi_input,
-        kwargs={
-            "midi_input_port": midi_input_port,
-            "control_sentinel": control_sentinel,
-            "received_messages_queue": received_messages_queue,
-            "midi_capture_channel": midi_capture_channel,
-            "midi_control_signal": midi_control_signal,
-            "midi_through_port": midi_through_port,
-            "first_msg_epoch_time_ms": first_msg_epoch_time_ms,
-            "results_queue": results_queue,
-        },
-    )
-    capture_midi_thread.start()
-
-    msgs, prev_context = continuous_prefill(
-        model=model,
-        msgs=msgs,
-        received_messages_queue=received_messages_queue,
-        prev_context=prev_context,
-    )
-    capture_midi_thread.join()
-    first_on_msg_epoch_ms, num_active_pitches = results_queue.get()
-
-    return msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches
-
-
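capture_and_update_kv pairs a capture thread with the prefill loop over received_messages_queue, again using None as the end-of-stream sentinel. Each captured message is stamped with the milliseconds elapsed since the previous one before being queued. A standalone sketch of that delta stamping; get_epoch_time_ms is defined elsewhere in demo.py, so the definition below is an assumption for illustration:

import time

def get_epoch_time_ms() -> int:
    # Assumed definition; demo.py provides its own get_epoch_time_ms.
    return int(time.time() * 1000)

prev_ms = None
for _ in range(3):
    now_ms = get_epoch_time_ms()
    delta_ms = 0 if prev_ms is None else now_ms - prev_ms  # first msg gets 0
    prev_ms = now_ms
    print(f"msg.time = {delta_ms} ms")
    time.sleep(0.05)  # simulate a gap between incoming MIDI messages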
-def capture_midi_input(
-    midi_input_port: str,
-    control_sentinel: threading.Event,
-    received_messages_queue: queue.Queue,
-    midi_capture_channel: int,
-    results_queue: queue.Queue,
-    midi_control_signal: int | None = None,
-    midi_through_port: str | None = None,
-    first_msg_epoch_time_ms: int | None = None,
-):
-    logger = get_logger("CAPTURE")
-    active_pitches = set()
-    first_on_msg_epoch_ms = None
-    prev_msg_epoch_time_ms = first_msg_epoch_time_ms
-
-    logger.info(f"Listening on MIDI port: '{midi_input_port}'")
-    logger.info(f"Using MIDI control signal: {midi_control_signal}")
-    if midi_through_port is not None:
-        logger.info(f"Sending through on MIDI port: '{midi_through_port}'")
-
-    with ExitStack() as stack:
-        midi_input = stack.enter_context(mido.open_input(midi_input_port))
-        midi_through = (
-            stack.enter_context(mido.open_output(midi_through_port))
-            if midi_through_port
-            else None
-        )
-
-        while not control_sentinel.is_set():
-            msg = midi_input.receive(block=False)
-
-            if msg is None:
-                time.sleep(0.001)
-                continue
-
-            if prev_msg_epoch_time_ms is None:
-                msg_time_ms = 0
-            else:
-                msg_time_ms = get_epoch_time_ms() - prev_msg_epoch_time_ms
-
-            prev_msg_epoch_time_ms = get_epoch_time_ms()
-            msg.time = msg_time_ms
-            msg.channel = midi_capture_channel
-            logger.info(f"Received message: [{msg}]")
-
-            if msg.is_meta is True or msg.type == "program_change":
-                continue
-
-            if (
-                msg.type == "note_on" and msg.velocity == 0
-            ) or msg.type == "note_off":
-                active_pitches.discard(msg.note)
-                received_messages_queue.put(msg)
-                if midi_through is not None:
-                    midi_through.send(msg)
-            elif msg.type == "note_on" and msg.velocity > 0:
-                if first_on_msg_epoch_ms is None:
-                    first_on_msg_epoch_ms = get_epoch_time_ms()
-
-                active_pitches.add(msg.note)
-                received_messages_queue.put(msg)
-                if midi_through is not None:
-                    midi_through.send(msg)
-            elif msg.type == "control_change" and msg.control == 64:
-                received_messages_queue.put(msg)
-            elif (
-                msg.type == "control_change"
-                and msg.control == midi_control_signal
-                and msg.value > 0
-            ):
-                control_sentinel.set()
-
-        logger.info("Control signal seen")
-        logger.info(f"Active pitches: {active_pitches}")
-        num_active_pitches = len(active_pitches)
-
-        if active_pitches:
-            pitch = active_pitches.pop()
-            msg = mido.Message(
-                type="note_on",
-                note=pitch,
-                velocity=0,
-                channel=midi_capture_channel,
-                time=get_epoch_time_ms() - prev_msg_epoch_time_ms,
-            )
-            received_messages_queue.put(msg)
-            if midi_through is not None:
-                midi_through.send(msg)
-
-        while active_pitches:
-            pitch = active_pitches.pop()
-            msg = mido.Message(
-                type="note_on",
-                note=pitch,
-                velocity=0,
-                channel=midi_capture_channel,
-                time=0,
-            )
-            received_messages_queue.put(msg)
-            if midi_through is not None:
-                midi_through.send(msg)
-
-        # Turn off pedal
-        msg = mido.Message(
-            type="control_change",
-            control=64,
-            value=0,
-            channel=midi_capture_channel,
-            time=0,
-        )
-        received_messages_queue.put(msg)
-        if midi_through is not None:
-            midi_through.send(msg)
-
-    received_messages_queue.put(None)  # Sentinel
-    results_queue.put((first_on_msg_epoch_ms, num_active_pitches))
-
-
-def play_midi_file(midi_port: str, midi_path: str):
-    logger = get_logger("FILE")
-    logger.info(f"Playing file at {midi_path} on MIDI port '{midi_port}'")
-    time.sleep(1)
-    active_pitches = []
-    with mido.open_output(midi_port) as output_port:
-        for msg in mido.MidiFile(midi_path).play():
-            if msg.type == "note_on" and msg.velocity > 0:
-                if msg.note in active_pitches:
-                    _off_msg = copy.deepcopy(msg)
-                    _off_msg.velocity = 0
-                    output_port.send(_off_msg)
-                else:
-                    active_pitches.append(msg.note)
-            elif msg.type == "note_off" or (
-                msg.type == "note_on" and msg.velocity == 0
-            ):
-                if msg.note in active_pitches:
-                    active_pitches.remove(msg.note)
-
-            logger.debug(f"{msg}")
-            output_port.send(msg)
-
-
-def listen_for_keypress_control_signal(
-    control_sentinel: threading.Event,
-    end_sentinel: threading.Event,
-):
-    logger = get_logger("KEYBOARD")
-    while True:
-        time.sleep(1)
-        _input = input()
-        logger.info(f'Keypress seen "{_input}"')
-        control_sentinel.set()
-
-        if _input == "e":
-            end_sentinel.set()
-
-
-# TODO: Not tested
-def listen_for_midi_control_signal(
-    midi_input_port: str,
-    control_sentinel: threading.Event,
-    end_sentinel: threading.Event,
-    midi_control_signal: int | None = None,
-    midi_end_signal: int | None = None,
-):
-    with mido.open_input(midi_input_port) as midi_input:
-        while True:
-            msg = midi_input.receive(block=False)
-            if msg is None:
-                time.sleep(0.01)
-            elif (
-                msg.type == "control_change"
-                and msg.control == midi_control_signal
-                and msg.value > 0
-            ):
-                control_sentinel.set()
-            elif (
-                msg.type == "control_change"
-                and msg.control == midi_end_signal
-                and msg.value > 0
-            ):
-                control_sentinel.set()
-                end_sentinel.set()
-
-
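Both listeners above react only to control_change messages whose control number matches the configured signal and whose value is positive. For reference, such messages can be produced with mido as follows; control numbers 66 and 67 are taken from the deleted demo.sh further down and are just example values:

import mido

# CC numbers matching -midi_control_signal / -midi_end_signal in demo.sh
# (66 and 67 here are example values; any CC number passed on the CLI works).
takeover_msg = mido.Message("control_change", control=66, value=127)
ending_msg = mido.Message("control_change", control=67, value=127)

# e.g. with an open output port:
# with mido.open_output("Midi Through:Midi Through Port-0") as port:
#     port.send(takeover_msg)
print(takeover_msg, ending_msg)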
-def parse_args():
-    argp = argparse.ArgumentParser()
-    argp.add_argument("-cp", help="path to model checkpoint")
-    argp.add_argument("-midi_in", required=False, help="MIDI input port")
-    argp.add_argument("-midi_out", required=True, help="MIDI output port")
-    argp.add_argument(
-        "-midi_through",
-        required=False,
-        help="MIDI through port for received input",
-    )
-    argp.add_argument(
-        "-midi_path",
-        required=False,
-        help="Use MIDI file instead of MIDI input port",
-    )
-    argp.add_argument(
-        "-midi_control_signal",
-        type=int,
-        help="MIDI control change message for AI takeover",
-    )
-    argp.add_argument(
-        "-midi_end_signal",
-        type=int,
-        help="MIDI control change message to generate ending",
-    )
-    argp.add_argument(
-        "-temp",
-        help="sampling temperature value",
-        type=float,
-        required=False,
-        default=0.95,
-    )
-    argp.add_argument(
-        "-min_p",
-        help="sampling min_p value",
-        type=float,
-        required=False,
-        default=0.03,
-    )
-    argp.add_argument(
-        "-cfg",
-        help="sampling cfg gamma value",
-        type=float,
-        required=False,
-    )
-    argp.add_argument(
-        "-metadata",
-        nargs=2,
-        metavar=("KEY", "VALUE"),
-        action="append",
-        help="manually add metadata key-value pair when sampling",
-    )
-    argp.add_argument(
-        "-save_path",
-        type=str,
-        required=False,
-        help="Path to save complete MIDI file",
-    )
-
-    return argp.parse_args()
-
-
-# TODO: Need functionality for handling the case where we run out of model context
-# TODO: Make sure channel=9 (drum) case is covered
-def main():
-    args = parse_args()
-    logger = get_logger()
-    tokenizer = AbsTokenizer()
-    model = load_model(checkpoint_path=args.cp)
-    model = compile_model(model=model, max_seq_len=MAX_SEQ_LEN)
-
-    assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in
-    if args.midi_path:
-        midi_input_port = "Midi Through:Midi Through Port-0"
-        play_file_thread = threading.Thread(
-            target=play_midi_file,
-            args=(midi_input_port, args.midi_path),
-            daemon=True,
-        )
-        play_file_thread.start()
-    else:
-        midi_input_port = args.midi_in
-
-    control_sentinel = threading.Event()
-    end_sentinel = threading.Event()
-    keypress_thread = threading.Thread(
-        target=listen_for_keypress_control_signal,
-        args=[control_sentinel, end_sentinel],
-        daemon=True,
-    )
-    midi_control_thread = threading.Thread(
-        target=listen_for_midi_control_signal,
-        kwargs={
-            "midi_input_port": midi_input_port,
-            "control_sentinel": control_sentinel,
-            "end_sentinel": end_sentinel,
-            "midi_control_signal": args.midi_control_signal,
-            "midi_end_signal": args.midi_end_signal,
-        },
-        daemon=True,
-    )
-    keypress_thread.start()
-    midi_control_thread.start()
-
-    msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches = (
-        capture_and_update_kv(
-            model=model,
-            msgs=[],
-            prev_context=[],
-            control_sentinel=control_sentinel,
-            midi_input_port=midi_input_port,
-            midi_control_signal=args.midi_control_signal,
-            midi_through_port=args.midi_through,
-            midi_capture_channel=0,
-        )
-    )
-
-    itt = 0
-    while True:
-        control_sentinel.clear()
-        msgs = stream_msgs(
-            model=model,
-            tokenizer=tokenizer,
-            msgs=msgs,
-            prev_context=prev_context,
-            midi_output_port=args.midi_out,
-            first_on_msg_epoch_ms=first_on_msg_epoch_ms,
-            control_sentinel=control_sentinel,
-            temperature=args.temp,
-            min_p=args.min_p,
-            num_preceding_active_pitches=num_active_pitches,
-            midi_stream_channel=itt,
-            is_ending=False,
-        )
-
-        itt += 1
-        control_sentinel.clear()
-        if end_sentinel.is_set():
-            break
-
-        msgs, prev_context, _, num_active_pitches = capture_and_update_kv(
-            model=model,
-            msgs=msgs,
-            prev_context=prev_context,
-            control_sentinel=control_sentinel,
-            midi_input_port=midi_input_port,
-            midi_control_signal=args.midi_control_signal,
-            midi_through_port=args.midi_through,
-            midi_capture_channel=itt,
-            first_msg_epoch_time_ms=first_on_msg_epoch_ms,
-        )
-
-    # TODO: There is a bug with the token somewhere?
-    msgs = stream_msgs(
-        model=model,
-        tokenizer=tokenizer,
-        msgs=msgs,
-        prev_context=prev_context,
-        midi_output_port=args.midi_out,
-        first_on_msg_epoch_ms=first_on_msg_epoch_ms,
-        control_sentinel=control_sentinel,
-        temperature=args.temp / 2,
-        min_p=args.min_p,
-        num_preceding_active_pitches=num_active_pitches,
-        midi_stream_channel=itt,
-        is_ending=True,
-    )
-
-    if args.save_path:
-        logger.info(f"Saving result to {args.save_path}")
-        midi = convert_msgs_to_midi(msgs=msgs)
-        midi.save(args.save_path)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/demo/demo.sh b/demo/demo.sh
deleted file mode 100644
index 03f2d89..0000000
--- a/demo/demo.sh
+++ /dev/null
@@ -1,12 +0,0 @@
-MID_PATH="/home/loubb/Dropbox/shared/demo.mid"
-
-python /home/loubb/work/aria/demo/demo.py \
-    -cp /mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors \
-    -midi_path ${MID_PATH} \
-    -midi_out "Midi Through:Midi Through Port-1" \
-    -midi_through "Midi Through:Midi Through Port-2" \
-    -save_path /home/loubb/Dropbox/shared/output.mid \
-    -midi_control_signal 66 \
-    -midi_end_signal 67 \
-    -temp 0.98 \
-    -min_p 0.02
\ No newline at end of file
diff --git a/demo/midi-tunnel-client.py b/demo/midi-tunnel-client.py
deleted file mode 100644
index d4c6f83..0000000
--- a/demo/midi-tunnel-client.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import socket
-import rtmidi
-import time
-import subprocess
-import signal
-import sys
-import os
-import argparse
-
-SSH_SERVER = "home-4090.remote"
-def parse_arguments():
-    parser = argparse.ArgumentParser(description='MIDI UDP bridge with SSH tunnel')
-    parser.add_argument('-p', '--port', type=int, default=5004,
-                        help='UDP port number (default: 5004)')
-    return parser.parse_args()
-
-def kill_existing_process(port):
-    # Check and kill existing process on remote server
-    check_command = f"ssh {SSH_SERVER} 'lsof -ti :{port}'"
-    try:
-        pid = subprocess.check_output(check_command, shell=True).decode().strip()
-        if pid:
-            print(f"Found existing process {pid} on port {port}, killing it...")
-            kill_command = f"ssh {SSH_SERVER} 'kill -9 {pid}'"
-            subprocess.run(kill_command, shell=True)
-            # Wait a moment for the port to be freed
-            time.sleep(1)
-    except subprocess.CalledProcessError:
-        # No existing process found
-        pass
-
-def setup_ssh_tunnel(port):
-    while True:
-        try:
-            # Kill any existing process first
-            kill_existing_process(port)
-
-            # Start SSH tunnel using socat
-            print(f"Attempting to establish SSH tunnel on port {port}...")
-            ssh_command = f"ssh {SSH_SERVER} 'socat -u UDP4-RECV:{port} STDOUT'"
-            local_socat = f"socat -u STDIN UDP4-SEND:localhost:{port}"
-
-            ssh_process = subprocess.Popen(ssh_command, shell=True, stdout=subprocess.PIPE)
-            socat_process = subprocess.Popen(local_socat, shell=True, stdin=ssh_process.stdout)
-
-            # Check if the processes started successfully
-            time.sleep(1)
-            if ssh_process.poll() is not None:  # Process terminated
-                raise subprocess.CalledProcessError(ssh_process.returncode, ssh_command)
-
-            print("SSH tunnel established successfully!")
-            return ssh_process, socat_process
-
-        except (subprocess.CalledProcessError, OSError) as e:
-            print(f"Failed to establish SSH tunnel: {str(e)}")
-            print("Retrying in 1 second...")
-            time.sleep(1)
-
-def create_virtual_port(port):
-    midi_out = rtmidi.MidiOut()
-    # Create a virtual MIDI port with port number in name
-    midi_out.open_virtual_port(f"UDP_{port}")
-    return midi_out

-def start_udp_listener(port):
-    # Create UDP socket
-    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    sock.bind(('localhost', port))
-    return sock
-
-def split_midi_messages(data):
-    """Split a byte array into individual MIDI messages."""
-    messages = []
-    data_list = list(data)
-    i = 0
-    while i < len(data_list):
-        # Check if we have a status byte (most significant bit is 1)
-        if data_list[i] >= 0x80:
-            # Most MIDI messages are 3 bytes
-            if i + 2 < len(data_list):
-                messages.append(data_list[i:i+3])
-                i += 3
-            else:
-                # Handle incomplete message at end of buffer
-                break
-        else:
-            # Skip non-status bytes (shouldn't happen in properly formatted MIDI)
-            i += 1
-    return messages
-
-def cleanup(ssh_process, socat_process, midi_out, sock):
-    print("\nCleaning up...")
-    # Kill the SSH and socat processes
-    if ssh_process:
-        os.killpg(os.getpgid(ssh_process.pid), signal.SIGTERM)
-    if socat_process:
-        socat_process.terminate()
-    # Close MIDI and socket
-    if midi_out:
-        midi_out.close_port()
-    if sock:
-        sock.close()
-
-def main():
-    args = parse_arguments()
-    port = args.port
-
-    ssh_process = None
-    socat_process = None
-    midi_out = None
-    sock = None
-
-    try:
-        # Setup SSH tunnel first
-        print(f"Setting up SSH tunnel on port {port}...")
-        ssh_process, socat_process = setup_ssh_tunnel(port)
-
-        # Setup MIDI and UDP
-        print(f"Creating virtual MIDI port UDP_{port}...")
-        midi_out = create_virtual_port(port)
-        print(f"Starting UDP listener on port {port}...")
-        sock = start_udp_listener(port)
-
-        print(f"UDP MIDI Bridge started - listening on port {port}")
-
-        while True:
-            data, addr = sock.recvfrom(1024)
-            if data:
-                # Split the data into individual MIDI messages
-                midi_messages = split_midi_messages(data)
-                for midi_message in midi_messages:
-                    print(f"Sending MIDI message: {midi_message}")
-                    midi_out.send_message(midi_message)
-
-    except KeyboardInterrupt:
-        print("\nShutting down UDP MIDI Bridge...")
-    except Exception as e:
-        print(f"Error: {e}")
-    finally:
-        cleanup(ssh_process, socat_process, midi_out, sock)
-
-if __name__ == "__main__":
-    main()
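One caveat in the client's framing logic: split_midi_messages assumes every message is exactly three bytes, which holds for the note and control-change traffic this demo produces but would mis-frame two-byte messages such as program change. A standalone check, with the function body copied from the deleted file above:

def split_midi_messages(data):
    # Body copied from the deleted client above, inlined for a standalone run.
    messages = []
    data_list = list(data)
    i = 0
    while i < len(data_list):
        if data_list[i] >= 0x80:            # status byte starts a message
            if i + 2 < len(data_list):      # room for a full 3-byte message
                messages.append(data_list[i:i+3])
                i += 3
            else:
                break                       # incomplete tail is dropped
        else:
            i += 1                          # skip stray data bytes
    return messages

print(split_midi_messages(bytes([0x90, 60, 100, 0x80, 60, 0])))
# -> [[144, 60, 100], [128, 60, 0]] (note-on then note-off)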
diff --git a/demo/midi-tunnel-server.py b/demo/midi-tunnel-server.py
deleted file mode 100755
index 988e200..0000000
--- a/demo/midi-tunnel-server.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import rtmidi
-import socket
-import time
-import struct
-import argparse
-
-class MIDIRouter:
-    def __init__(self, midi_port="14:0", udp_port=5004):
-        self.midi_in = rtmidi.MidiIn()
-        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        self.udp_port = udp_port
-
-        # Print available ports
-        ports = self.midi_in.get_ports()
-        print(f"Available MIDI ports: {ports}")
-
-        # Find and open MIDI port
-        for i, port in enumerate(ports):
-            if midi_port in port:
-                print(f"Opening MIDI port {i}: {port}")
-                self.midi_in.open_port(i)
-                break
-        else:
-            print(f"Warning: Could not find port containing '{midi_port}'")
-
-        self.midi_in.set_callback(self._midi_callback)
-
-    def _midi_callback(self, message, timestamp):
-        try:
-            print(f"Received MIDI message: {message[0]}")
-            midi_data = struct.pack('B' * len(message[0]), *message[0])
-            self.socket.sendto(midi_data, ('localhost', self.udp_port))
-            print(f"Sent {len(midi_data)} bytes to localhost:{self.udp_port}")
-        except Exception as e:
-            print(f"Error in callback: {e}")
-
-    def start(self):
-        print(f"Routing MIDI messages through SSH tunnel on port {self.udp_port}...")
-        try:
-            while True:
-                time.sleep(0.1)
-        except KeyboardInterrupt:
-            self.stop()
-
-    def stop(self):
-        print("Shutting down...")
-        self.midi_in.close_port()
-        self.socket.close()
-
-def parse_args():
-    parser = argparse.ArgumentParser(description='MIDI to UDP router')
-    parser.add_argument('-midi_p', type=str, default="14:0",
-                        help='MIDI port identifier (default: 14:0)')
-    parser.add_argument('-udp_p', type=int, default=5004,
-                        help='UDP port for forwarding (default: 5004)')
-    return parser.parse_args()
-
-if __name__ == "__main__":
-    args = parse_args()
-    router = MIDIRouter(midi_port=args.midi_p, udp_port=args.udp_p)
-    router.start()
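For reference, the server's _midi_callback frames each rtmidi message by packing its list of byte values into a raw UDP datagram, which the client above unpacks back into a byte list. The round trip in isolation:

import struct

message = [0x90, 60, 100]  # note-on, middle C, velocity 100
payload = struct.pack("B" * len(message), *message)
print(payload)  # b'\x90<d' -- three raw bytes on the wire
assert list(payload) == message  # the client recovers the same byte values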