From 628cb9f287d394bb68beadc4575b179f89e64d4d Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 7 Jul 2025 14:40:51 +0100 Subject: [PATCH 01/21] add reset to demo --- demo/demo_mlx.py | 386 ++++++++++++++++++++++++++++------------------- 1 file changed, 229 insertions(+), 157 deletions(-) diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index e28a56a..7aa852c 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -22,7 +22,9 @@ from aria.inference.model_mlx import TransformerLM from aria.model import ModelConfig from aria.config import load_model_config +from aria.run import _get_embedding +EMBEDDING_OFFSET = 0 DTYPE = mx.float32 MAX_SEQ_LEN = 2048 PREFILL_CHUNK_SIZE_L = 128 @@ -32,23 +34,23 @@ BEAM_WIDTH = 3 TIME_TOK_WEIGHTING = -5 -FIRST_ONSET_BUFFER_MS = -100 # Controls onset timing for first generated not +FIRST_ONSET_BUFFER_MS = -200 # Controls onset timing for first generated not # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg # HARDWARE: All messages are sent HARDWARE_OUTPUT_LATENCY_MS early # C4DM Disklavier: -# MIN_NOTE_DELTA_MS = 40 -# MIN_NOTE_LEN_MS = 100 -# HARDWARE_INPUT_LATENCY_MS = 50 -# HARDWARE_OUTPUT_LATENCY_MS = 120 +MIN_NOTE_DELTA_MS = 40 +MIN_NOTE_LEN_MS = 100 +HARDWARE_INPUT_LATENCY_MS = 50 +HARDWARE_OUTPUT_LATENCY_MS = 120 # Pianoteq -MIN_NOTE_DELTA_MS = 0 -MIN_NOTE_LEN_MS = 30 -HARDWARE_INPUT_LATENCY_MS = 0 -HARDWARE_OUTPUT_LATENCY_MS = 0 +# MIN_NOTE_DELTA_MS = 0 +# MIN_NOTE_LEN_MS = 30 +# HARDWARE_INPUT_LATENCY_MS = 0 +# HARDWARE_OUTPUT_LATENCY_MS = 0 MAX_STREAM_DELAY_MS = 50 @@ -87,6 +89,77 @@ def formatTime(self, record, datefmt=None): return logger +def parse_args(): + argp = argparse.ArgumentParser() + argp.add_argument("--checkpoint", help="path to model checkpoint") + argp.add_argument("--midi_in", required=False, help="MIDI input port") + argp.add_argument("--midi_out", required=True, help="MIDI output port") + argp.add_argument( + "--midi_through", + required=False, + help="MIDI through port for received input", + ) + argp.add_argument( + "--midi_path", + required=False, + help="Use MIDI file instead of MIDI input port", + ) + argp.add_argument( + "--midi_control_signal", + type=int, + help="MIDI control change message for AI takeover", + ) + argp.add_argument( + "--midi_reset_control_signal", + type=int, + help="MIDI control change message context reset", + ) + argp.add_argument( + "--temp", + help="sampling temperature value", + type=float, + required=False, + default=0.95, + ) + argp.add_argument( + "--min_p", + help="sampling min_p value", + type=float, + required=False, + default=0.03, + ) + argp.add_argument( + "--wait_for_close", + help="wait for note-offs before generating", + action="store_true", + ) + argp.add_argument( + "--quantize", + help="apply model quantize", + action="store_true", + ) + argp.add_argument( + "--save_path", + type=str, + required=False, + help="path to save complete MIDI file", + ) + argp.add_argument( + "--embedding_checkpoint", + type=str, + help="path to embedding model checkpoint for conditioned generation", + required=False, + ) + argp.add_argument( + "--embedding_midi_path", + type=str, + help="path to embedding MIDI file for conditioned generation", + required=False, + ) + + return argp.parse_args() + + def get_epoch_time_ms() -> int: return round(time.time() * 1000) @@ -95,14 +168,12 @@ def prefill( model: TransformerLM, idxs: mx.array, input_pos: mx.array, - pad_idxs: mx.array | None = None, ) -> mx.array: # pad_idxs is only needed for prepended pad tokens logits = model( idxs=idxs, - input_pos=input_pos, - offset=input_pos[0], - pad_idxs=pad_idxs, + input_pos=input_pos + EMBEDDING_OFFSET, + offset=input_pos[0] + EMBEDDING_OFFSET, ) return logits @@ -112,16 +183,13 @@ def decode_one( model: TransformerLM, idxs: mx.array, input_pos: mx.array, - pad_idxs: mx.array | None = None, ) -> mx.array: - # pad_idxs is only needed for prepended pad tokens assert input_pos.shape[-1] == 1 logits = model( idxs=idxs, - input_pos=input_pos, - offset=input_pos[0], - pad_idxs=pad_idxs, + input_pos=input_pos + EMBEDDING_OFFSET, + offset=input_pos[0] + EMBEDDING_OFFSET, )[:, -1] return logits @@ -244,8 +312,7 @@ def compile_model(model: TransformerLM): model = _compile_decode_one(model=model, logger=logger) for chunk_size in list( { - # PREFILL_CHUNK_SIZE_L, - # PREFILL_CHUNK_SIZE, + PREFILL_CHUNK_SIZE, RECALC_DUR_PREFILL_CHUNK_SIZE, } ): @@ -1203,10 +1270,10 @@ def capture_and_update_kv( msgs: list, prev_context: list, control_sentinel: threading.Event, + reset_sentinel: threading.Event, wait_for_close: bool, midi_input_port: str, midi_capture_channel: int, - midi_control_signal: int | None = None, first_msg_epoch_time_ms: int | None = None, ): received_messages_queue = queue.Queue() @@ -1216,9 +1283,9 @@ def capture_and_update_kv( kwargs={ "midi_input_port": midi_input_port, "control_sentinel": control_sentinel, + "reset_sentinel": reset_sentinel, "received_messages_queue": received_messages_queue, "midi_capture_channel": midi_capture_channel, - "midi_control_signal": midi_control_signal, "first_msg_epoch_time_ms": first_msg_epoch_time_ms, "results_queue": results_queue, "wait_for_close": wait_for_close, @@ -1241,10 +1308,10 @@ def capture_and_update_kv( def capture_midi_input( midi_input_port: str, control_sentinel: threading.Event, + reset_sentinel: threading.Event, received_messages_queue: queue.Queue, midi_capture_channel: int, results_queue: queue.Queue, - midi_control_signal: int | None = None, first_msg_epoch_time_ms: int | None = None, wait_for_close: bool = False, ): @@ -1253,8 +1320,7 @@ def capture_midi_input( first_on_msg_epoch_ms = None prev_msg_epoch_time_ms = first_msg_epoch_time_ms - logger.info(f"Listening on MIDI port: '{midi_input_port}'") - logger.info(f"Ready to capture MIDI events") + logger.info(f"Listening for input on MIDI port: '{midi_input_port}'") # Clear undesired buffered notes with mido.open_input(midi_input_port) as midi_input: @@ -1264,20 +1330,15 @@ def capture_midi_input( break with mido.open_input(midi_input_port) as midi_input: - if midi_control_signal is not None: - logger.info( - f"Commencing generation upon keypress or MIDI control: {midi_control_signal}" - ) - else: - logger.info(f"Commencing generation upon keypress") + logger.info(f"Commencing generation upon keypress or control signal") - while not control_sentinel.is_set() or ( - wait_for_close and active_pitches + while (not reset_sentinel.is_set()) and ( + not control_sentinel.is_set() or (wait_for_close and active_pitches) ): msg = midi_input.receive(block=False) if msg is None: - time.sleep(0.001) + time.sleep(0.01) continue if prev_msg_epoch_time_ms is None: @@ -1308,13 +1369,6 @@ def capture_midi_input( received_messages_queue.put(msg) elif msg.type == "control_change" and msg.control == 64: received_messages_queue.put(msg) - elif ( - msg.type == "control_change" - and msg.control == midi_control_signal - and msg.value > 0 - ): - control_sentinel.set() - logger.info("Control signal seen") logger.info(f"Active pitches: {active_pitches}") num_active_pitches = len(active_pitches) @@ -1359,6 +1413,7 @@ def play_midi_file( midi_in_port: str, midi_path: str, currently_streaming_sentinel: threading.Event, + reset_sentinel: threading.Event, ): def _send_delayed_message(port, msg): port.send(msg) @@ -1381,6 +1436,10 @@ def _send_delayed_message(port, msg): with mido.open_output(midi_through_port) as through_port: with mido.open_output(midi_in_port) as in_port: for msg in mid.play(): + if reset_sentinel.is_set(): + logger.debug("Exiting") + return + if currently_streaming_sentinel.is_set() is False and not ( msg.type == "control_change" and msg.control == 64 ): @@ -1396,30 +1455,34 @@ def _send_delayed_message(port, msg): def listen_for_keypress_control_signal( control_sentinel: threading.Event, - generate_ending_sentinel: threading.Event, + reset_sentinel: threading.Event, ): logger = get_logger("KEYBOARD") - while True: - time.sleep(5) + while not reset_sentinel.is_set(): + time.sleep(2) _input = input() logger.info(f'Keypress seen "{_input}"') if _input == "": control_sentinel.set() else: + reset_sentinel.set() control_sentinel.set() - generate_ending_sentinel.set() + logger.debug("Exiting") return def _listen( midi_input_port: str, + reset_sentinel: threading.Event, logger: logging.Logger, midi_control_signal: int | None = None, + midi_reset_control_signal: int | None = None, ): logger.info("Listening...") with mido.open_input(midi_input_port) as midi_input: - while True: + while not reset_sentinel.is_set(): msg = midi_input.receive(block=False) + if msg is None: time.sleep(0.01) elif ( @@ -1427,130 +1490,101 @@ def _listen( and msg.control == midi_control_signal and msg.value >= 64 ): - return + return midi_control_signal + elif ( + msg.type == "control_change" + and msg.control == midi_reset_control_signal + and msg.value >= 64 + ): + return midi_reset_control_signal def listen_for_midi_control_signal( midi_input_port: str, control_sentinel: threading.Event, + reset_sentinel: threading.Event, midi_control_signal: int | None = None, + midi_reset_control_signal: int | None = None, ): logger = get_logger("MIDI-CONTROL") - while True: - _listen( + while not reset_sentinel.is_set(): + signal_received = _listen( midi_input_port=midi_input_port, + reset_sentinel=reset_sentinel, midi_control_signal=midi_control_signal, + midi_reset_control_signal=midi_reset_control_signal, logger=logger, ) - control_sentinel.set() - logger.info("Seen MIDI control signal") - time.sleep(5) + if signal_received is not None: + logger.info(f"Seen MIDI control signal ({signal_received})") -def parse_args(): - argp = argparse.ArgumentParser() - argp.add_argument("--checkpoint", help="path to model checkpoint") - argp.add_argument("--midi_in", required=False, help="MIDI input port") - argp.add_argument("--midi_out", required=True, help="MIDI output port") - argp.add_argument( - "--midi_through", - required=False, - help="MIDI through port for received input", - ) - argp.add_argument( - "--midi_path", - required=False, - help="Use MIDI file instead of MIDI input port", - ) - argp.add_argument( - "--midi_control_signal", - type=int, - help="MIDI control change message for AI takeover", - ) - argp.add_argument( - "--temp", - help="sampling temperature value", - type=float, - required=False, - default=0.95, - ) - argp.add_argument( - "--min_p", - help="sampling min_p value", - type=float, - required=False, - default=0.03, - ) - argp.add_argument( - "--wait_for_close", - help="wait for note-offs before generating", - action="store_true", - ) - argp.add_argument( - "--quantize", - help="apply model quantize", - action="store_true", - ) - argp.add_argument( - "--save_path", - type=str, - required=False, - help="Path to save complete MIDI file", - ) + if signal_received == midi_control_signal: + control_sentinel.set() + time.sleep(2) + elif signal_received == midi_reset_control_signal: + reset_sentinel.set() + control_sentinel.set() - return argp.parse_args() + logger.debug("Exiting") -def main(args): - args = parse_args() +def run( + model: TransformerLM, + midi_in_port: str | None, + midi_through_port: str | None, + midi_out_port: str | None, + midi_path: str | None, + midi_save_path: str | None, + midi_control_signal: int, + midi_reset_control_signal: int, + reset_sentinel: str, + wait_for_close: bool, + temperature: float, + min_p: float, +): logger = get_logger() tokenizer = AbsTokenizer() - model = load_model(checkpoint_path=args.checkpoint) - model = compile_model(model=model) - - assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in - control_sentinel = threading.Event() - generate_ending_sentinel = threading.Event() currently_generating_sentinel = threading.Event() - if args.midi_through: - close_notes(args.midi_through) - if args.midi_out: - close_notes(args.midi_out) + if midi_through_port: + close_notes(midi_through_port) + if midi_out_port: + close_notes(midi_out_port) - if args.midi_path: - midi_input_port = "IAC Driver Bus 1" + if midi_path: + midi_playback_port = "IAC Driver Bus 1" play_file_thread = threading.Thread( target=play_midi_file, args=( - args.midi_through, - midi_input_port, - args.midi_path, + midi_through_port, + midi_playback_port, + midi_path, currently_generating_sentinel, + reset_sentinel, ), - daemon=True, ) else: - midi_input_port = args.midi_in + midi_playback_port = midi_in_port play_file_thread = None keypress_thread = threading.Thread( target=listen_for_keypress_control_signal, - args=[control_sentinel, generate_ending_sentinel], - daemon=True, + args=[control_sentinel, reset_sentinel], ) midi_control_thread = threading.Thread( target=listen_for_midi_control_signal, kwargs={ "midi_input_port": ( - args.midi_in if args.midi_in else midi_input_port + midi_in_port if midi_in_port else midi_playback_port ), "control_sentinel": control_sentinel, - "midi_control_signal": args.midi_control_signal, + "reset_sentinel": reset_sentinel, + "midi_control_signal": midi_control_signal, + "midi_reset_control_signal": midi_reset_control_signal, }, - daemon=True, ) keypress_thread.start() midi_control_thread.start() @@ -1564,15 +1598,15 @@ def main(args): msgs=[], prev_context=[], control_sentinel=control_sentinel, - wait_for_close=args.wait_for_close, - midi_input_port=midi_input_port, - midi_control_signal=args.midi_control_signal, + reset_sentinel=reset_sentinel, + wait_for_close=wait_for_close, + midi_input_port=midi_playback_port, midi_capture_channel=0, ) ) curr_midi_channel = 0 - while True: + while True and not reset_sentinel.is_set(): control_sentinel.clear() currently_generating_sentinel.set() msgs = stream_msgs( @@ -1580,23 +1614,28 @@ def main(args): tokenizer=tokenizer, msgs=msgs, prev_context=prev_context, - midi_output_port=args.midi_out, + midi_output_port=midi_out_port, first_on_msg_epoch_ms=first_on_msg_epoch_ms, control_sentinel=control_sentinel, - temperature=args.temp, - min_p=args.min_p, + temperature=temperature, + min_p=min_p, num_preceding_active_pitches=num_active_pitches, midi_stream_channel=curr_midi_channel, is_ending=False, ) + if midi_save_path: + logger.info(f"Saving result to {midi_save_path}") + midi = convert_msgs_to_midi(msgs=msgs) + midi.save(midi_save_path) + curr_midi_channel += 1 if curr_midi_channel == 9: curr_midi_channel += 1 control_sentinel.clear() - if generate_ending_sentinel.is_set(): - break + if reset_sentinel.is_set(): + return else: currently_generating_sentinel.clear() msgs, prev_context, _, num_active_pitches = capture_and_update_kv( @@ -1604,33 +1643,66 @@ def main(args): msgs=msgs, prev_context=prev_context, control_sentinel=control_sentinel, - wait_for_close=args.wait_for_close, - midi_input_port=midi_input_port, - midi_control_signal=args.midi_control_signal, + reset_sentinel=reset_sentinel, + wait_for_close=wait_for_close, + midi_input_port=midi_playback_port, midi_capture_channel=curr_midi_channel, first_msg_epoch_time_ms=first_on_msg_epoch_ms, ) - # Generate ending - msgs = stream_msgs( - model=model, - tokenizer=tokenizer, - msgs=msgs, - prev_context=prev_context, - midi_output_port=args.midi_out, - first_on_msg_epoch_ms=first_on_msg_epoch_ms, - control_sentinel=control_sentinel, - temperature=args.temp / 2, - min_p=args.min_p, - num_preceding_active_pitches=num_active_pitches, - midi_stream_channel=curr_midi_channel, - is_ending=True, + keypress_thread.join() + midi_control_thread.join() + if play_file_thread: + play_file_thread.join() + + +def insert_embedding( + model: TransformerLM, + embedding_model_checkpoint_path: str, + embedding_midi_path: str, +): + logger = get_logger() + logger.info(f"Loading embedding from {embedding_midi_path}") + emb = _get_embedding( + embedding_model_checkpoint_path=embedding_model_checkpoint_path, + embedding_midi_path=embedding_midi_path, ) + logger.info(f"Inserting embedding into context") + model.fill_condition_kv(mx.array([emb], dtype=DTYPE)) + + global EMBEDDING_OFFSET + EMBEDDING_OFFSET = 1 + + +def main(args): + model = load_model(checkpoint_path=args.checkpoint) + model = compile_model(model=model) + if args.embedding_checkpoint and args.embedding_midi_path: + insert_embedding( + model=model, + embedding_model_checkpoint_path=args.embedding_checkpoint, + embedding_midi_path=args.embedding_midi_path, + ) - if args.save_path: - logger.info(f"Saving result to {args.save_path}") - midi = convert_msgs_to_midi(msgs=msgs) - midi.save(args.save_path) + assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in + + reset_sentinel = threading.Event() + while True: + run( + model=model, + midi_in_port=args.midi_in, + midi_through_port=args.midi_through, + midi_out_port=args.midi_out, + midi_path=args.midi_path, + midi_save_path=args.save_path, + midi_control_signal=args.midi_control_signal, + midi_reset_control_signal=args.midi_reset_control_signal, + reset_sentinel=reset_sentinel, + wait_for_close=args.wait_for_close, + temperature=args.temp, + min_p=args.min_p, + ) + reset_sentinel = threading.Event() def close_notes(midi_out_port: str): From 3373f12c31e78d800e6aaf291b8dc411840cdd95 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 7 Jul 2025 14:44:50 +0100 Subject: [PATCH 02/21] Fix README and run --- aria/run.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aria/run.py b/aria/run.py index 71c319b..804fa12 100644 --- a/aria/run.py +++ b/aria/run.py @@ -317,13 +317,13 @@ def generate(args): def _get_embedding( - embedding_model_checkpoints_path: str, + embedding_model_checkpoint_path: str, embedding_midi_path: str, ): from aria.embedding import get_global_embedding_from_midi model = _load_embedding_model( - checkpoint_path=embedding_model_checkpoints_path + checkpoint_path=embedding_model_checkpoint_path ).cpu() global_embedding = get_global_embedding_from_midi( model=model, @@ -353,7 +353,7 @@ def conditioned_generate(args): prompt_duration_s=prompt_duration_s, ) embedding = _get_embedding( - embedding_model_checkpoints_path=args.embedding_model_checkpoint_path, + embedding_model_checkpoint_path=args.embedding_model_checkpoint_path, embedding_midi_path=args.embedding_midi_path, ) max_new_tokens = min(8096 - len(prompt), max_new_tokens) From caa1d7b115df5e0cc4cbf83c9648f64e314893f9 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 9 Jul 2025 14:11:26 +0100 Subject: [PATCH 03/21] add short delay to control_signal detection --- demo/demo_mlx.sh | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 demo/demo_mlx.sh diff --git a/demo/demo_mlx.sh b/demo/demo_mlx.sh new file mode 100644 index 0000000..5ed3a9a --- /dev/null +++ b/demo/demo_mlx.sh @@ -0,0 +1,31 @@ +# # MID_PATH="/Users/louis/Library/CloudStorage/Dropbox/shared/audio_piano.mid" +# MID_PATH="/Users/louis/Library/CloudStorage/Dropbox/shared/bill_evans.mid" +# # --midi_path ${MID_PATH} \ +# # --midi_through "IAC Driver Bus 3" \ + +# python /Users/louis/work/aria/demo/demo_mlx.py \ +# --checkpoint /Users/louis/work/aria/models/medium-75-ft.safetensors \ +# --embedding_checkpoint /Users/louis/work/aria/models/medium-emb.safetensors \ +# --embedding_midi /Users/louis/Library/CloudStorage/Dropbox/shared/prompt/noodle.mid \ +# --midi_in "Scarlett 18i8 USB" \ +# --midi_out "Scarlett 18i8 USB" \ +# --midi_control_signal 67 \ +# --midi_reset_control_signal 66 \ +# --save_path /Users/louis/Dropbox/shared/output.mid \ +# --quantize \ +# --temp 0.98 \ +# --min_p 0.035 + +#### + +MID_PATH="./example-prompts/nocturne.mid" + +python ./demo/demo_mlx.py \ + --checkpoint ./models/medium-75-annealed.safetensors \ + --midi_path ${MID_PATH} \ + --midi_through "IAC Driver Bus 2" \ + --midi_out "IAC Driver Bus 3" \ + --save_path ./output.mid \ + --quantize \ + --temp 0.98 \ + --min_p 0.035 From 35b1a53860ed372e5e05cfb2adf347fca58e3eb2 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 9 Jul 2025 14:12:30 +0100 Subject: [PATCH 04/21] Fix incorrect commit --- demo/demo_mlx.py | 1 + 1 file changed, 1 insertion(+) diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index 7aa852c..2f8f470 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -1478,6 +1478,7 @@ def _listen( midi_control_signal: int | None = None, midi_reset_control_signal: int | None = None, ): + time.sleep(2) logger.info("Listening...") with mido.open_input(midi_input_port) as midi_input: while not reset_sentinel.is_set(): From e8f7e05c6f3fea7fa4be8cb2400a5db0b604076c Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 9 Jul 2025 14:13:14 +0100 Subject: [PATCH 05/21] rm script --- demo/demo_mlx.sh | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 demo/demo_mlx.sh diff --git a/demo/demo_mlx.sh b/demo/demo_mlx.sh deleted file mode 100644 index 5ed3a9a..0000000 --- a/demo/demo_mlx.sh +++ /dev/null @@ -1,31 +0,0 @@ -# # MID_PATH="/Users/louis/Library/CloudStorage/Dropbox/shared/audio_piano.mid" -# MID_PATH="/Users/louis/Library/CloudStorage/Dropbox/shared/bill_evans.mid" -# # --midi_path ${MID_PATH} \ -# # --midi_through "IAC Driver Bus 3" \ - -# python /Users/louis/work/aria/demo/demo_mlx.py \ -# --checkpoint /Users/louis/work/aria/models/medium-75-ft.safetensors \ -# --embedding_checkpoint /Users/louis/work/aria/models/medium-emb.safetensors \ -# --embedding_midi /Users/louis/Library/CloudStorage/Dropbox/shared/prompt/noodle.mid \ -# --midi_in "Scarlett 18i8 USB" \ -# --midi_out "Scarlett 18i8 USB" \ -# --midi_control_signal 67 \ -# --midi_reset_control_signal 66 \ -# --save_path /Users/louis/Dropbox/shared/output.mid \ -# --quantize \ -# --temp 0.98 \ -# --min_p 0.035 - -#### - -MID_PATH="./example-prompts/nocturne.mid" - -python ./demo/demo_mlx.py \ - --checkpoint ./models/medium-75-annealed.safetensors \ - --midi_path ${MID_PATH} \ - --midi_through "IAC Driver Bus 2" \ - --midi_out "IAC Driver Bus 3" \ - --save_path ./output.mid \ - --quantize \ - --temp 0.98 \ - --min_p 0.035 From 4e37206563787a56d0fc6785b4f0c4c3f806d876 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 11 Jul 2025 13:37:39 +0000 Subject: [PATCH 06/21] add support for pedal sampling --- aria/inference/__init__.py | 14 ++++++++++++-- aria/inference/sample_cuda.py | 9 --------- aria/run.py | 1 + config/models/medium.json | 1 + 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/aria/inference/__init__.py b/aria/inference/__init__.py index ceac4b4..4cb243e 100644 --- a/aria/inference/__init__.py +++ b/aria/inference/__init__.py @@ -48,11 +48,21 @@ def get_inference_prompt( for msg in midi_dict.note_msgs if midi_dict.tick_to_ms(msg["data"]["start"]) <= prompt_len_ms ] + midi_dict.pedal_msgs = [ + msg + for msg in midi_dict.pedal_msgs + if midi_dict.tick_to_ms(msg["tick"]) <= prompt_len_ms + ] + if midi_dict.pedal_msgs and midi_dict.pedal_msgs[-1]["data"] == 1: + midi_dict.pedal_msgs.pop() if len(midi_dict.note_msgs) == 0: return [("prefix", "instrument", "piano"), tokenizer.bos_tok] - seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) - seq.remove(tokenizer.eos_tok) + seq = tokenizer.tokenize( + midi_dict=midi_dict, + add_dim_tok=False, + add_eos_tok=False, + ) return seq diff --git a/aria/inference/sample_cuda.py b/aria/inference/sample_cuda.py index 909bd8d..5a0f793 100644 --- a/aria/inference/sample_cuda.py +++ b/aria/inference/sample_cuda.py @@ -16,15 +16,6 @@ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 -def get_cfg_prompt(prompts: list): - cfg_prompts = [] - for prompt in prompts: - cfg_prompts.append(prompt) - cfg_prompts.append(prompt) - - return cfg_prompts - - @torch.inference_mode() def decode_one( model: TransformerLM, diff --git a/aria/run.py b/aria/run.py index 804fa12..6304430 100644 --- a/aria/run.py +++ b/aria/run.py @@ -263,6 +263,7 @@ def generate(args): args.prompt_midi_path, prompt_duration_s=prompt_duration_s, ) + print(prompt) max_new_tokens = min(8096 - len(prompt), max_new_tokens) if backend == "torch_cuda": diff --git a/config/models/medium.json b/config/models/medium.json index a1df8a6..40384d2 100644 --- a/config/models/medium.json +++ b/config/models/medium.json @@ -5,5 +5,6 @@ "ff_mult": 4, "drop_p": 0.0, "max_seq_len": 8192, + "vocab_size": 17727, "grad_checkpoint": true } From 3ffa60acb1e7fa74cdbde98fec5f2b05fa808f22 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 15 Jul 2025 10:08:30 +0100 Subject: [PATCH 07/21] update model implementation (training) to fp32 rotary embeddings --- aria/model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/aria/model.py b/aria/model.py index 573f548..7a04f1b 100644 --- a/aria/model.py +++ b/aria/model.py @@ -379,7 +379,6 @@ def precompute_freqs_cis( seq_len: int, n_elem: int, base: int = 500000, - dtype: torch.dtype = torch.bfloat16, ): freqs = 1.0 / ( base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) @@ -389,7 +388,7 @@ def precompute_freqs_cis( freqs_cis = torch.polar(torch.ones_like(freqs), freqs) cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) - return cache.to(dtype=dtype) + return cache @torch.jit.script @@ -397,14 +396,15 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: """ In-place RoPE. Credits to Katherine Crowson: x shape (b_sz, s_len, n_head, d_head). - cos, sin shape (s_len, d_head // 2). + freqs_cis shape (s_len, d_head // 2, 2) and is float32. """ - - d = x.shape[-1] // 2 + x_float = x.float() + freqs_cis = freqs_cis.detach() + d = x_float.shape[-1] // 2 cos = freqs_cis[..., 0][None, :, None] sin = freqs_cis[..., 1][None, :, None] - x1, x2 = x[..., :d], x[..., d : d * 2] + x1, x2 = x_float[..., :d], x_float[..., d : d * 2] tmp = x1.clone() x1.mul_(cos).addcmul_(x2, sin, value=-1) x2.mul_(cos).addcmul_(tmp, sin, value=1) - return x + return x.copy_(x_float) From c5e42e3168b455986d999beb291cd65251e509d6 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 15 Jul 2025 11:47:16 +0100 Subject: [PATCH 08/21] fix arg --- aria/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aria/model.py b/aria/model.py index 7a04f1b..25b892b 100644 --- a/aria/model.py +++ b/aria/model.py @@ -178,7 +178,6 @@ def forward( seq_len=self.model_config.max_seq_len, n_elem=self.model_config.d_model // self.model_config.n_heads, base=500000, - dtype=hidden_states.dtype, ).to(src.device) freqs_cis = self.freqs_cis[: src.shape[1]] From d0129a9511278f87706d7d92e740c41b87e1fb7c Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 15 Jul 2025 12:54:09 +0100 Subject: [PATCH 09/21] fix grad_acc_steps in lr decay --- aria/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aria/training/train.py b/aria/training/train.py index 67001a2..6521272 100644 --- a/aria/training/train.py +++ b/aria/training/train.py @@ -586,7 +586,7 @@ def resume_train( optimizer, scheduler = get_optim( model, num_epochs=epochs, - steps_per_epoch=len(train_dataloader), + steps_per_epoch=len(train_dataloader) // grad_acc_steps, ) ( @@ -731,7 +731,7 @@ def train( optimizer, scheduler = get_optim( model, num_epochs=epochs, - steps_per_epoch=len(train_dataloader), + steps_per_epoch=len(train_dataloader) // grad_acc_steps, ) ( From 73e8072bf7d7ec9a839d0154e783919184d6aa91 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 16 Jul 2025 21:50:05 +0100 Subject: [PATCH 10/21] add dynamic playback support to demo --- demo/calibrate.py | 126 ++++- demo/demo_mlx.py | 745 +++++++++++++++++++---------- demo/hardware/c4dm-disklavier.json | 22 + 3 files changed, 634 insertions(+), 259 deletions(-) create mode 100644 demo/hardware/c4dm-disklavier.json diff --git a/demo/calibrate.py b/demo/calibrate.py index 74d0720..73a9120 100644 --- a/demo/calibrate.py +++ b/demo/calibrate.py @@ -6,7 +6,7 @@ import mido MIDDLE_C = 60 -C_MAJOR_CHORD = [MIDDLE_C, 64, 67, 72] # C4, E4, G4, C5 +C_MAJOR_CHORD = [MIDDLE_C - 12, 64 - 12, 67 - 12, 72 - 12] # C4, E4, G4, C5 def schedule_note_off(port: mido.ports.BaseOutput, note: int, delay: float): @@ -89,6 +89,41 @@ def note_repetition_trial( print("...loop finished.\n") +def velocity_strike_pair( + port: mido.ports.BaseOutput, + high_velocity: int, + low_velocity: int, + delay_ms: int, +): + """ + Sends a low-velocity C5, waits, then sends a high-velocity C4. + The goal is to adjust the delay until they sound simultaneous. + """ + print("Playing velocity pair (C4 high-vel, C5 low-vel)...") + delay_sec = delay_ms / 1000.0 + note_duration_sec = 1.0 # Audible duration of the notes + note_high_vel = MIDDLE_C # C4 + note_low_vel = MIDDLE_C + 1 # D4 + + # Send the low velocity note (C5) first + port.send(mido.Message("note_on", note=note_low_vel, velocity=low_velocity)) + schedule_note_off(port, note_low_vel, delay=delay_sec + note_duration_sec) + + # Wait for the specified delay + if delay_sec > 0: + time.sleep(delay_sec) + + # Send the high velocity note (C4) + port.send( + mido.Message("note_on", note=note_high_vel, velocity=high_velocity) + ) + schedule_note_off(port, note_high_vel, delay=note_duration_sec) + + # Give the user time to hear the result + time.sleep(note_duration_sec + 0.5) + print("...done.\n") + + def calibrate_output_latency( port_name: str, velocity: int, @@ -161,6 +196,52 @@ def calibrate_note_timing( print(f"\nAn error occurred: {e}") +def calibrate_velocity_latency( + port_name: str, + velocity: int, + low_velocity: int, + step_ms: int, + initial_delay_ms: int, + chord_mode: bool, +): + """ + Interactive loop to find the latency difference between velocities. + This mode uses fixed notes (C4 and C5) and ignores the --chord flag. + """ + delay_ms = initial_delay_ms + + try: + with mido.open_output(port_name) as port: + print(f"Opened MIDI output: {port_name}\n") + print( + f"High-velocity note (C4): {velocity}\n" + f"Low-velocity note (C5): {low_velocity}\n" + ) + while True: + velocity_strike_pair( + port, + high_velocity=velocity, + low_velocity=low_velocity, + delay_ms=delay_ms, + ) + print(f"Current low-velocity pre-send delay: {delay_ms} ms") + cmd = ( + input("[u]p / [d]own / [r]epeat / [q]uit: ").strip().lower() + ) + + if cmd == "u": + delay_ms += step_ms + elif cmd == "d": + delay_ms = max(0, delay_ms - step_ms) + elif cmd == "q": + break + print() + except (KeyboardInterrupt, SystemExit): + print("\nInterrupted — exiting.") + except Exception as e: + print(f"\nAn error occurred: {e}") + + def measure_input_latency(port_name: str, timeout_sec: float = 2.0): """ 3-2-1-GO countdown → you strike a key on GO. @@ -240,21 +321,21 @@ def parse_args(): "--velocity", "-v", type=int, - default=80, - help="Note-on velocity (1-127).", + default=120, + help="Note-on velocity (1-127). For 'velocity' mode, this is the HIGH velocity.", ) parent.add_argument( "--step", "-s", type=int, default=10, - help="Adjustment step in ms (latency/timing modes).", + help="Adjustment step in ms (for latency/timing/velocity modes).", ) parent.add_argument( "--chord", "-c", action="store_true", - help="Use a C-major chord instead of single note.", + help="Use a C-major chord instead of single note (ignored in 'velocity' mode).", ) sub = parser.add_subparsers(dest="command", help="Available commands.") @@ -296,7 +377,7 @@ def parse_args(): help="Initial gap between notes in ms.", ) - # ── input-latency measurement (new) ─────────────────────────────────── + # ── input-latency measurement ───────────────────────────────────────── p_in = sub.add_parser( "input", parents=[parent], @@ -311,6 +392,27 @@ def parse_args(): help="Seconds to wait for a key press before retry.", ) + # ── velocity-latency calibration (NEW) ──────────────────────────────── + p_vel = sub.add_parser( + "velocity", + parents=[parent], + help="Calibrate additional latency of low-velocity notes.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p_vel.add_argument( + "--low-velocity", + "-lv", + type=int, + default=20, + help="The low velocity to compare against (1-127).", + ) + p_vel.add_argument( + "--delay", + type=int, + default=50, + help="Initial pre-send delay for the low-velocity note in ms.", + ) + args = parser.parse_args() # global flag handler @@ -320,7 +422,7 @@ def parse_args(): if not args.command: parser.error( - "A command is required: choose 'output', 'timing', or 'input'." + "A command is required: choose 'output', 'timing', 'input', or 'velocity'." ) return args @@ -355,6 +457,16 @@ def main(): timeout_sec=args.timeout, ) + elif args.command == "velocity": + calibrate_velocity_latency( + port_name=args.port, + velocity=args.velocity, + low_velocity=args.low_velocity, + step_ms=args.step, + initial_delay_ms=args.delay, + chord_mode=args.chord, + ) + if __name__ == "__main__": main() diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index 2f8f470..c834cf3 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -4,12 +4,11 @@ import os import time import uuid -import copy import random import logging import threading import queue -import copy +import json import mido import torch @@ -24,35 +23,26 @@ from aria.config import load_model_config from aria.run import _get_embedding -EMBEDDING_OFFSET = 0 +EMBEDDING_OFFSET: int = 0 DTYPE = mx.float32 -MAX_SEQ_LEN = 2048 -PREFILL_CHUNK_SIZE_L = 128 -PREFILL_CHUNK_SIZE = 16 -RECALC_DUR_PREFILL_CHUNK_SIZE = 8 -RECALC_DUR_BUFFER_MS = 100 - -BEAM_WIDTH = 3 -TIME_TOK_WEIGHTING = -5 -FIRST_ONSET_BUFFER_MS = -200 # Controls onset timing for first generated not - -# HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS -# HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg -# HARDWARE: All messages are sent HARDWARE_OUTPUT_LATENCY_MS early - -# C4DM Disklavier: -MIN_NOTE_DELTA_MS = 40 -MIN_NOTE_LEN_MS = 100 -HARDWARE_INPUT_LATENCY_MS = 50 -HARDWARE_OUTPUT_LATENCY_MS = 120 - -# Pianoteq -# MIN_NOTE_DELTA_MS = 0 -# MIN_NOTE_LEN_MS = 30 -# HARDWARE_INPUT_LATENCY_MS = 0 -# HARDWARE_OUTPUT_LATENCY_MS = 0 - -MAX_STREAM_DELAY_MS = 50 +MAX_SEQ_LEN: int = 2048 +PREFILL_CHUNK_SIZE_L: int = 128 +PREFILL_CHUNK_SIZE: int = 16 +RECALC_DUR_PREFILL_CHUNK_SIZE: int = 8 +RECALC_DUR_BUFFER_MS: int = 100 + +BEAM_WIDTH: int = 3 +TIME_TOK_WEIGHTING: int = -5 +FIRST_ONSET_BUFFER_MS: int = -200 +MAX_STREAM_DELAY_MS: int = 50 + +MIN_NOTE_DELTA_MS: int = 0 +MIN_PEDAL_DELTA_MS: int = 0 +MIN_NOTE_LENGTH_MS: int = 0 +HARDWARE_INPUT_LATENCY_MS: int = 0 +BASE_OUTPUT_LATENCY_MS: int = 0 +VELOCITY_OUTPUT_LATENCY_MS: dict[int, int] = {v: 0 for v in range(0, 127, 10)} + file_handler = logging.FileHandler("./demo.log", mode="w") file_handler.setLevel(logging.DEBUG) @@ -144,6 +134,12 @@ def parse_args(): required=False, help="path to save complete MIDI file", ) + argp.add_argument( + "--hardware", + type=str, + required=False, + help="path to json file containing hardware calibration settings", + ) argp.add_argument( "--embedding_checkpoint", type=str, @@ -156,10 +152,41 @@ def parse_args(): help="path to embedding MIDI file for conditioned generation", required=False, ) + argp.add_argument( + "--playback", + action="store_true", + help="playback file at midi_path through output_port", + required=False, + ) return argp.parse_args() +def set_calibration_settings(load_path: str): + with open(load_path, "r") as f: + _settings = json.load(f) + + global MIN_NOTE_DELTA_MS + global MIN_PEDAL_DELTA_MS + global MIN_NOTE_LENGTH_MS + global HARDWARE_INPUT_LATENCY_MS + global BASE_OUTPUT_LATENCY_MS + global VELOCITY_OUTPUT_LATENCY_MS + + MIN_NOTE_DELTA_MS = _settings["MIN_NOTE_DELTA_MS"] + MIN_PEDAL_DELTA_MS = _settings["MIN_PEDAL_DELTA_MS"] + MIN_NOTE_LENGTH_MS = _settings["MIN_NOTE_LENGTH_MS"] + HARDWARE_INPUT_LATENCY_MS = _settings["HARDWARE_INPUT_LATENCY_MS"] + BASE_OUTPUT_LATENCY_MS = _settings["BASE_OUTPUT_LATENCY_MS"] + VELOCITY_OUTPUT_LATENCY_MS = { + int(k): v for k, v in _settings["VELOCITY_OUTPUT_LATENCY_MS"].items() + } + + +def _get_input_latency_ms(velocity: int): + return BASE_OUTPUT_LATENCY_MS + VELOCITY_OUTPUT_LATENCY_MS[velocity] + + def get_epoch_time_ms() -> int: return round(time.time() * 1000) @@ -217,7 +244,7 @@ def sample_min_p(probs: mx.array, p_base: float): return next_token -def _compile_prefill( +def _warmup_prefill( model: TransformerLM, logger: logging.Logger, chunk_size: int, @@ -262,7 +289,7 @@ def _compile_prefill( return model -def _compile_decode_one( +def _warmup_decode_one( model: TransformerLM, logger: logging.Logger, ): @@ -299,7 +326,7 @@ def _compile_decode_one( return model -def compile_model(model: TransformerLM): +def warmup_model(model: TransformerLM): logger = get_logger() model.eval() @@ -309,14 +336,14 @@ def compile_model(model: TransformerLM): dtype=DTYPE, ) - model = _compile_decode_one(model=model, logger=logger) + model = _warmup_decode_one(model=model, logger=logger) for chunk_size in list( { PREFILL_CHUNK_SIZE, RECALC_DUR_PREFILL_CHUNK_SIZE, } ): - model = _compile_prefill( + model = _warmup_prefill( model=model, logger=logger, chunk_size=chunk_size ) @@ -460,8 +487,8 @@ def decode_first_tokens( ): logger = get_logger("GENERATE") - # buffer_ms determines how far in the past to start generating notes - buffer_ms = FIRST_ONSET_BUFFER_MS - HARDWARE_OUTPUT_LATENCY_MS + # buffer_ms determines how far in the past to start generating notes. + buffer_ms = FIRST_ONSET_BUFFER_MS time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] eos_tok_id = tokenizer.tok_to_id[tokenizer.eos_tok] dim_tok_id = tokenizer.tok_to_id[tokenizer.dim_tok] @@ -625,7 +652,7 @@ def decode_tokens( if is_ending is False: logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") - for dur_ms in range(0, MIN_NOTE_LEN_MS, 10): + for dur_ms in range(0, MIN_NOTE_LENGTH_MS, 10): logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") if temperature > 0.0: @@ -743,6 +770,130 @@ def generate_tokens( ) +def _adjust_previous_off_time( + pitch_to_prev_msg: dict, + key: str | int, + new_on_send_time: int, + min_delta_ms: int, + logger: logging.Logger, +): + prev_on, prev_off = pitch_to_prev_msg.get(key, (None, None)) + + if prev_on is not None and prev_off is not None and min_delta_ms > 0: + adj_send_off_time = max( + min( + prev_off["send_epoch_time_ms"], + new_on_send_time - min_delta_ms, + ), + prev_on[ + "send_epoch_time_ms" + ], # Don't move prev_off before prev_on + ) + if adj_send_off_time != prev_off["send_epoch_time_ms"]: + logger.debug(f"Adjusting {prev_off}: t={adj_send_off_time}") + prev_off["send_epoch_time_ms"] = adj_send_off_time + prev_off["adjusted"] = True + + +# TODO: Verify that only ON -> OFF sequences are possible in tokenizer +def _decode_pedal_double( + note_buffer: list, + first_on_msg_epoch_ms: int, + num_time_toks: int, + pitch_to_prev_msg: dict, + outbound_midi_msg_queue: queue.Queue, + logger: logging.Logger, + tokenizer: AbsTokenizer, +): + pedal_tok, onset_tok = note_buffer + velocity = 127 if pedal_tok == tokenizer.ped_on_tok else 0 + _, onset = onset_tok + + onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset + send_onset_epoch_ms = onset_epoch_ms - BASE_OUTPUT_LATENCY_MS + pedal_msg = { + "pitch": "pedal", + "vel": velocity, + "epoch_time_ms": onset_epoch_ms, + "send_epoch_time_ms": send_onset_epoch_ms, + "uuid": "pedal", # All pedals have the same id + } + + if pedal_tok == tokenizer.ped_on_tok: + _adjust_previous_off_time( + pitch_to_prev_msg=pitch_to_prev_msg, + key="pedal", + new_on_send_time=send_onset_epoch_ms, + min_delta_ms=MIN_PEDAL_DELTA_MS, + logger=logger, + ) + pitch_to_prev_msg["pedal"] = (pedal_msg, None) + + elif pedal_tok == tokenizer.ped_off_tok: + prev_on, _ = pitch_to_prev_msg.get("pedal", (None, None)) + pitch_to_prev_msg["pedal"] = (prev_on, pedal_msg) + + outbound_midi_msg_queue.put(pedal_msg) + logger.debug(f"Put message: {pedal_msg}") + logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") + + return onset_epoch_ms + + +def _decode_note_triple( + note_buffer: list, + first_on_msg_epoch_ms: int, + num_time_toks: int, + pitch_to_prev_msg: dict, + outbound_midi_msg_queue: queue.Queue, + logger: logging.Logger, +): + note_tok, onset_tok, dur_tok = note_buffer + _, pitch, vel = note_tok + _, onset = onset_tok + _, dur = dur_tok + + _uuid = uuid.uuid4() + onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset + offset_epoch_ms = onset_epoch_ms + dur + send_onset_epoch_ms = onset_epoch_ms - _get_input_latency_ms(vel) + send_offset_epoch_ms = offset_epoch_ms - BASE_OUTPUT_LATENCY_MS + + on_msg = { + "pitch": pitch, + "vel": vel, + "epoch_time_ms": onset_epoch_ms, + "send_epoch_time_ms": send_onset_epoch_ms, + "uuid": _uuid, + } + off_msg = { + "pitch": pitch, + "vel": 0, + "epoch_time_ms": offset_epoch_ms, + "send_epoch_time_ms": send_offset_epoch_ms, + "uuid": _uuid, + } + + _adjust_previous_off_time( + pitch_to_prev_msg=pitch_to_prev_msg, + key=pitch, + new_on_send_time=send_onset_epoch_ms, + min_delta_ms=MIN_NOTE_DELTA_MS, + logger=logger, + ) + + pitch_to_prev_msg[pitch] = (on_msg, off_msg) + + outbound_midi_msg_queue.put(on_msg) + outbound_midi_msg_queue.put(off_msg) + logger.debug(f"Put message: {on_msg}") + logger.debug(f"Put message: {off_msg}") + logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") + + return offset_epoch_ms + + +# TODO: Refactor this method to prettify it def decode_tokens_to_midi( generated_tokens_queue: queue.Queue, outbound_midi_msg_queue: queue.Queue, @@ -770,13 +921,15 @@ def decode_tokens_to_midi( while True: tok = generated_tokens_queue.get() if tok is tokenizer.eos_tok: + # pitch=-1 is interpreted as the end message by stream_midi _uuid = uuid.uuid4() end_msg = { "pitch": -1, "vel": -1, - "epoch_time_ms": offset_epoch_ms + 100, # Last note offset + "epoch_time_ms": offset_epoch_ms + 100, + "send_epoch_time_ms": offset_epoch_ms + 100, "uuid": _uuid, - } # pitch=-1 denotes end_msg + } outbound_midi_msg_queue.put(end_msg) logger.info(f"Seen exit signal: EOS token") logger.debug(f"Put message: {end_msg}") @@ -790,6 +943,15 @@ def decode_tokens_to_midi( note_buffer.append(tok) if isinstance(tok, tuple) and tok[0] == "dur": + msg_type = "note" + break + elif ( + isinstance(tok, tuple) + and tok[0] == "onset" + and note_buffer[-2] + in {tokenizer.ped_on_tok, tokenizer.ped_off_tok} + ): + msg_type = "pedal" break while note_buffer and note_buffer[0] == tokenizer.time_tok: @@ -797,53 +959,58 @@ def decode_tokens_to_midi( num_time_toks += 1 note_buffer.pop(0) - assert len(note_buffer) == 3 + assert len(note_buffer) in {2, 3}, f"Generation error: buffer={note_buffer}" # fmt: skip + logger.debug(f"Decoded note: {note_buffer}") - note_tok, onset_tok, dur_tok = note_buffer - _, pitch, vel = note_tok - _, onset = onset_tok - _, dur = dur_tok - - _uuid = uuid.uuid4() - onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset - offset_epoch_ms = onset_epoch_ms + dur - on_msg = { - "pitch": pitch, - "vel": vel, - "epoch_time_ms": onset_epoch_ms, - "uuid": _uuid, - } - off_msg = { - "pitch": pitch, - "vel": 0, - "epoch_time_ms": offset_epoch_ms, - "uuid": _uuid, - } - # Not thread safe but in theory should be ok? - if pitch_to_prev_msg.get(pitch) is not None and MIN_NOTE_DELTA_MS > 0: - prev_on, prev_off = pitch_to_prev_msg.get(pitch) - adj_off_time = max( - min( - prev_off["epoch_time_ms"], - onset_epoch_ms - MIN_NOTE_DELTA_MS, - ), - prev_on["epoch_time_ms"], + if msg_type == "note": + offset_epoch_ms = _decode_note_triple( + note_buffer=note_buffer, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + num_time_toks=num_time_toks, + pitch_to_prev_msg=pitch_to_prev_msg, + outbound_midi_msg_queue=outbound_midi_msg_queue, + logger=logger, + ) + elif msg_type == "pedal": + offset_epoch_ms = _decode_pedal_double( + note_buffer=note_buffer, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + num_time_toks=num_time_toks, + pitch_to_prev_msg=pitch_to_prev_msg, + outbound_midi_msg_queue=outbound_midi_msg_queue, + logger=logger, + tokenizer=tokenizer, ) - if adj_off_time != prev_off["epoch_time_ms"]: - logger.debug(f"Adjusting {prev_off}: t={adj_off_time}") - prev_off["epoch_time_ms"] = adj_off_time - prev_off["adjusted"] = True + else: + raise ValueError - pitch_to_prev_msg[pitch] = [on_msg, off_msg] + note_buffer = [] - outbound_midi_msg_queue.put(on_msg) - outbound_midi_msg_queue.put(off_msg) - logger.debug(f"Put message: {on_msg}") - logger.debug(f"Put message: {off_msg}") - logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") - note_buffer = [] +def _create_mido_message( + msg_dict: dict, + channel: int, + time_delta_ms: int, +) -> mido.Message: + # Creates a mido message from an event dictionary + if msg_dict["pitch"] == "pedal": + return mido.Message( + "control_change", + control=64, + value=msg_dict["vel"], + channel=channel, + time=time_delta_ms, + ) + else: + # note-on or note-off + return mido.Message( + "note_on", + note=msg_dict["pitch"], + velocity=msg_dict["vel"], + channel=channel, + time=time_delta_ms, + ) def stream_midi( @@ -856,127 +1023,114 @@ def stream_midi( results_queue: queue.Queue, ): logger = get_logger("STREAM") - logger.info( - f"Sending generated messages on MIDI port: '{midi_output_port}'" - ) - logger.info( - f"Applying hardware output latency adjustment: {HARDWARE_OUTPUT_LATENCY_MS}ms" - ) + logger.info(f"Sending generated messages on port: '{midi_output_port}'") active_pitch_uuid = {} - is_pitch_active = {} - midi_msgs = [] + pending_msgs = [] + msgs_to_archive = [] with mido.open_output(midi_output_port) as midi_out: while not control_sentinel.is_set(): - while True: + while not inbound_midi_msg_queue.empty(): try: msg = inbound_midi_msg_queue.get_nowait() + if msg: + pending_msgs.append(msg) except queue.Empty: break - else: - logger.debug(f"Received message: {msg}") - midi_msgs.append(msg) - - midi_msgs = sorted( - midi_msgs, - key=lambda msg: ( - msg["epoch_time_ms"], - msg["vel"], - ), - ) - if control_sentinel.is_set(): - break + pending_msgs.sort(key=lambda m: (m["send_epoch_time_ms"], m["vel"])) - while midi_msgs: - # Messages are sent HARDWARE_OUTPUT_LATENCY_MS early - latency_adjusted_epoch_time_ms = ( - get_epoch_time_ms() + HARDWARE_OUTPUT_LATENCY_MS - ) - msg = midi_msgs[0] + while pending_msgs: + curr_epoch_time_ms = get_epoch_time_ms() + msg = pending_msgs[0] - if ( - 0 - < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] - <= MAX_STREAM_DELAY_MS + if msg["send_epoch_time_ms"] > curr_epoch_time_ms: + break + elif ( + curr_epoch_time_ms - msg["send_epoch_time_ms"] + > MAX_STREAM_DELAY_MS ): - if msg["pitch"] == -1: # End msg - control_sentinel.set() - break - - mido_msg = mido.Message( - "note_on", - note=msg["pitch"], - velocity=msg["vel"], - channel=0, - time=0, - ) + logger.debug(f"Skipping stale message: {msg}") + pending_msgs.pop(0) + continue - if msg["vel"] > 0: - active_pitch_uuid[msg["pitch"]] = msg["uuid"] - should_send_midi_out = True - should_append_to_msgs = True - elif msg.get("adjusted", False) is True: - should_send_midi_out = True - should_append_to_msgs = False - else: - should_send_midi_out = ( - active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] - ) - should_append_to_msgs = should_send_midi_out - - if should_send_midi_out is True: - midi_out.send(mido_msg) - is_pitch_active[msg["pitch"]] = msg["vel"] != 0 - logger.info(f"Sent message: {mido_msg}") - if should_append_to_msgs is True: - mido_msg_with_time = copy.deepcopy(mido_msg) - mido_msg_with_time.channel = midi_stream_channel - mido_msg_with_time.time = max( - 0, - msg["epoch_time_ms"] - - last_channel_msg_epoch_time_ms, - ) - last_channel_msg_epoch_time_ms = msg["epoch_time_ms"] - msgs.append(mido_msg_with_time) + logger.debug(f"Processing: {msg}") - midi_msgs.pop(0) + # End signal + if msg["pitch"] == -1: + control_sentinel.set() + break - elif ( - latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] - > MAX_STREAM_DELAY_MS - ): - # Message occurs too far in the past - logger.debug( - f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg['epoch_time_ms']}ms) in the past: {msg}" + should_send = False + should_archive = False + if msg["vel"] > 0: # note-on or pedal-on + active_pitch_uuid[msg["pitch"]] = msg["uuid"] + should_send = True + should_archive = True + else: # note-off or pedal-off (vel == 0) + if msg.get("adjusted", False): + should_send = True + should_archive = msg["pitch"] == "pedal" + elif active_pitch_uuid.get(msg["pitch"]) == msg["uuid"]: + should_send = True + should_archive = True + active_pitch_uuid.pop(msg["pitch"], None) + + if should_send: + mido_msg = _create_mido_message( + msg_dict=msg, channel=0, time_delta_ms=0 ) - midi_msgs.pop(0) - else: - # Message occurs in the future - break + midi_out.send(mido_msg) + logger.info(f"Sent message: {mido_msg}") + + if should_archive: + msgs_to_archive.append(msg) + + pending_msgs.pop(0) + + if control_sentinel.is_set(): + break time.sleep(0.005) - remaining_note_off_messages = [ + last_archive_time_ms = last_channel_msg_epoch_time_ms + msgs_to_archive.sort(key=lambda m: (m["epoch_time_ms"], m["vel"])) + + for msg in msgs_to_archive: + time_delta_ms = round(msg["epoch_time_ms"] - last_archive_time_ms) + mido_msg = _create_mido_message( + msg_dict=msg, + channel=midi_stream_channel, + time_delta_ms=time_delta_ms, + ) + msgs.append(mido_msg) + last_archive_time_ms = msg["epoch_time_ms"] + + logger.info("Sending final note-off messages for cleanup.") + remaining_off_msgs = [ msg - for msg in midi_msgs + for msg in pending_msgs if msg["vel"] == 0 + and msg["pitch"] != "pedal" and active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] ] + remaining_off_msgs.sort(key=lambda m: (m["epoch_time_ms"])) - logger.info("Processing remaining note_off messages") - for msg in remaining_note_off_messages: - mido_msg = mido.Message( - "note_on", - note=msg["pitch"], - velocity=0, - channel=midi_stream_channel, - time=msg["epoch_time_ms"] - last_channel_msg_epoch_time_ms, + for msg in remaining_off_msgs: + mido_msg = _create_mido_message( + msg_dict=msg, channel=0, time_delta_ms=0 ) midi_out.send(mido_msg) - last_channel_msg_epoch_time_ms = msg["epoch_time_ms"] - msgs.append(mido_msg) + + time_delta_ms = round(msg["epoch_time_ms"] - last_archive_time_ms) + archived_msg = _create_mido_message( + msg_dict=msg, + channel=midi_stream_channel, + time_delta_ms=time_delta_ms, + ) + msgs.append(archived_msg) + last_archive_time_ms = msg["epoch_time_ms"] results_queue.put(msgs) @@ -995,11 +1149,19 @@ def stream_msgs( midi_stream_channel: int, is_ending: bool = False, ): + + logger = get_logger("STREAM") midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) + midi_dict.remove_redundant_pedals() priming_seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] + if priming_seq[-2] == tokenizer.ped_off_tok: + # Final pedal-off is needed for tokenizer, but unneeded in tokenized sequence + logger.info("Removing final pedal_off from tokenized sequence") + priming_seq = priming_seq[:-2] + if is_ending is True: priming_seq.append(tokenizer.dim_tok) @@ -1059,17 +1221,14 @@ def stream_msgs( "midi_stream_channel": midi_stream_channel, "results_queue": stream_midi_results_queue, }, - daemon=True, ) stream_midi_thread.start() generate_tokens_thread.join() decode_tokens_to_midi_thread.join() + stream_midi_thread.join() msgs = stream_midi_results_queue.get() - if is_ending is True: - stream_midi_thread.join() - return msgs @@ -1245,6 +1404,7 @@ def continuous_prefill( if msg_cnt >= 10: midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) + midi_dict.remove_redundant_pedals() if len(midi_dict.note_msgs) > 0: curr_context = tokenizer.encode( @@ -1315,95 +1475,100 @@ def capture_midi_input( first_msg_epoch_time_ms: int | None = None, wait_for_close: bool = False, ): + """Captures MIDI input with improved structure and readability.""" logger = get_logger("CAPTURE") - active_pitches = set() first_on_msg_epoch_ms = None prev_msg_epoch_time_ms = first_msg_epoch_time_ms + pedal_down = False + pitches_held_down = set() + pitches_sustained_by_pedal = set() logger.info(f"Listening for input on MIDI port: '{midi_input_port}'") - # Clear undesired buffered notes + # Clear any buffered MIDI messages before starting with mido.open_input(midi_input_port) as midi_input: - while True: - msg = midi_input.receive(block=False) - if msg is None: - break + for _ in midi_input.iter_pending(): + pass with mido.open_input(midi_input_port) as midi_input: - logger.info(f"Commencing generation upon keypress or control signal") + logger.info("Commencing generation upon keypress or control signal") - while (not reset_sentinel.is_set()) and ( - not control_sentinel.is_set() or (wait_for_close and active_pitches) - ): - msg = midi_input.receive(block=False) + while True: + epoch_time_ms = get_epoch_time_ms() + active_notes = pitches_held_down.union(pitches_sustained_by_pedal) + should_stop = not wait_for_close or not active_notes + if reset_sentinel.is_set() or ( + control_sentinel.is_set() and should_stop + ): + break - if msg is None: + msg = midi_input.receive(block=False) + if not msg: time.sleep(0.01) continue + if msg.is_meta or msg.type == "program_change": + continue + + msg.channel = midi_capture_channel if prev_msg_epoch_time_ms is None: - msg_time_ms = 0 + msg.time = 0 else: - msg_time_ms = get_epoch_time_ms() - prev_msg_epoch_time_ms + msg.time = epoch_time_ms - prev_msg_epoch_time_ms - prev_msg_epoch_time_ms = get_epoch_time_ms() - msg.time = msg_time_ms - msg.channel = midi_capture_channel + prev_msg_epoch_time_ms = epoch_time_ms logger.info(f"Received message: [{msg}]") - if msg.is_meta is True or msg.type == "program_change": - continue - - if ( - msg.type == "note_on" and msg.velocity == 0 - ) or msg.type == "note_off": - active_pitches.discard(msg.note) - received_messages_queue.put(msg) - elif msg.type == "note_on" and msg.velocity > 0: - if first_on_msg_epoch_ms is None: - first_on_msg_epoch_ms = ( - get_epoch_time_ms() - HARDWARE_INPUT_LATENCY_MS - ) - - active_pitches.add(msg.note) - received_messages_queue.put(msg) - elif msg.type == "control_change" and msg.control == 64: - received_messages_queue.put(msg) + match msg.type: + case "note_on" if msg.velocity > 0: + if first_on_msg_epoch_ms is None: + first_on_msg_epoch_ms = ( + get_epoch_time_ms() - HARDWARE_INPUT_LATENCY_MS + ) + pitches_held_down.add(msg.note) + if pedal_down: + pitches_sustained_by_pedal.add(msg.note) + received_messages_queue.put(msg) + + case "note_off" | "note_on": + # Note-off + pitches_held_down.discard(msg.note) + received_messages_queue.put(msg) + + case "control_change" if msg.control == 64: + if msg.value >= 64: + pedal_down = True + pitches_sustained_by_pedal.update(pitches_held_down) + else: + pedal_down = False + pitches_sustained_by_pedal.clear() + received_messages_queue.put(msg) - logger.info(f"Active pitches: {active_pitches}") + active_pitches = pitches_held_down.union(pitches_sustained_by_pedal) num_active_pitches = len(active_pitches) + logger.info(f"Active pitches ({num_active_pitches}): {active_pitches}") - if active_pitches: - pitch = active_pitches.pop() - msg = mido.Message( - type="note_on", + time_offset = get_epoch_time_ms() - prev_msg_epoch_time_ms + for pitch in pitches_held_down: + note_off_msg = mido.Message( + "note_off", note=pitch, - velocity=0, channel=midi_capture_channel, - time=get_epoch_time_ms() - prev_msg_epoch_time_ms, + time=time_offset, + ) + received_messages_queue.put(note_off_msg) + time_offset = 0 + + received_messages_queue.put( + mido.Message( + "control_change", + control=64, + value=0, + channel=midi_capture_channel, + time=0, ) - received_messages_queue.put(msg) - - while active_pitches: - pitch = active_pitches.pop() - msg = mido.Message( - type="note_on", - note=pitch, - velocity=0, - channel=midi_capture_channel, - time=0, - ) - received_messages_queue.put(msg) - - # Turn off pedal - msg = mido.Message( - type="control_change", - control=64, - value=0, - channel=midi_capture_channel, - time=0, ) - received_messages_queue.put(msg) + received_messages_queue.put(None) results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) @@ -1425,7 +1590,7 @@ def _send_delayed_message(port, msg): f"Simulating input to port '{midi_in_port}' with {HARDWARE_INPUT_LATENCY_MS}ms latency" ) - if MIN_NOTE_DELTA_MS > 0: + if BASE_OUTPUT_LATENCY_MS > 0: midi_dict = MidiDict.from_midi(midi_path) midi_dict.enforce_gaps(min_gap_ms=MIN_NOTE_DELTA_MS) mid = midi_dict.to_midi() @@ -1440,9 +1605,7 @@ def _send_delayed_message(port, msg): logger.debug("Exiting") return - if currently_streaming_sentinel.is_set() is False and not ( - msg.type == "control_change" and msg.control == 64 - ): + if currently_streaming_sentinel.is_set() is False: through_port.send(msg) timer = threading.Timer( @@ -1607,7 +1770,7 @@ def run( ) curr_midi_channel = 0 - while True and not reset_sentinel.is_set(): + while not reset_sentinel.is_set(): control_sentinel.clear() currently_generating_sentinel.set() msgs = stream_msgs( @@ -1677,7 +1840,7 @@ def insert_embedding( def main(args): model = load_model(checkpoint_path=args.checkpoint) - model = compile_model(model=model) + model = warmup_model(model=model) if args.embedding_checkpoint and args.embedding_midi_path: insert_embedding( model=model, @@ -1706,6 +1869,68 @@ def main(args): reset_sentinel = threading.Event() +def playback(midi_path: str, midi_out: str, save_path: str | None = None): + # Mocks generated playback by streaming from a real MIDI file + + close_notes(midi_out) + starting_epoch_time_ms = get_epoch_time_ms() + tokenizer = AbsTokenizer() + tokens_queue = queue.Queue() + midi_messages_queue = queue.Queue() + stream_midi_results_queue = queue.Queue() + control_sentinel = threading.Event() + + midi_dict = MidiDict.from_midi(midi_path) + tokenized_sequence = tokenizer.tokenize( + midi_dict, + add_dim_tok=False, + remove_preceding_silence=False, + ) + tokenized_sequence = tokenized_sequence[ + tokenized_sequence.index(tokenizer.bos_tok) + 1 : + ] + + # Populate token queue synthetically + for tok in tokenized_sequence: + tokens_queue.put(tok) + + decode_tokens_to_midi_thread = threading.Thread( + target=decode_tokens_to_midi, + kwargs={ + "generated_tokens_queue": tokens_queue, + "outbound_midi_msg_queue": midi_messages_queue, + "tokenizer": tokenizer, + "first_on_msg_epoch_ms": starting_epoch_time_ms, + "priming_seq_last_onset_ms": 0, + }, + ) + decode_tokens_to_midi_thread.start() + + stream_midi_thread = threading.Thread( + target=stream_midi, + kwargs={ + "inbound_midi_msg_queue": midi_messages_queue, + "msgs": [], + "last_channel_msg_epoch_time_ms": starting_epoch_time_ms, + "midi_output_port": midi_out, + "control_sentinel": control_sentinel, + "midi_stream_channel": 0, + "results_queue": stream_midi_results_queue, + }, + ) + stream_midi_thread.start() + + decode_tokens_to_midi_thread.join() + stream_midi_thread.join() + msgs = stream_midi_results_queue.get() + mid = convert_msgs_to_midi(msgs) + + if save_path is not None: + mid.save(save_path) + + return msgs + + def close_notes(midi_out_port: str): with mido.open_output(midi_out_port) as out: out.send(mido.Message(type="control_change", control=64, value=0)) @@ -1716,7 +1941,23 @@ def close_notes(midi_out_port: str): if __name__ == "__main__": args = parse_args() - try: - main(args) - except KeyboardInterrupt: - close_notes(args.midi_out) + if args.hardware: + set_calibration_settings(args.hardware) + + if args.playback is True: + # Playback only mode for testing + assert args.midi_path is not None, "Must provide midi_path" + try: + playback( + midi_path=args.midi_path, + midi_out=args.midi_out, + save_path=args.save_path, + ) + except KeyboardInterrupt: + close_notes(args.midi_out) + else: + # Main logic + try: + main(args) + except KeyboardInterrupt: + close_notes(args.midi_out) diff --git a/demo/hardware/c4dm-disklavier.json b/demo/hardware/c4dm-disklavier.json new file mode 100644 index 0000000..f7382fe --- /dev/null +++ b/demo/hardware/c4dm-disklavier.json @@ -0,0 +1,22 @@ +{ + "MIN_NOTE_DELTA_MS": 100, + "MIN_PEDAL_DELTA_MS": 100, + "MIN_NOTE_LENGTH_MS": 100, + "HARDWARE_INPUT_LATENCY_MS": 50, + "BASE_OUTPUT_LATENCY_MS": 50, + "VELOCITY_OUTPUT_LATENCY_MS": { + "120": 0, + "110": 0, + "100": 0, + "90": 4, + "80": 10, + "70": 30, + "60": 60, + "50": 85, + "40": 105, + "30": 130, + "20": 140, + "10": 155, + "0": 155 + } +} From 344ff78c05487b47b3b5b9a62b8d52be933f3ad2 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 23 Jul 2025 20:07:58 +0100 Subject: [PATCH 11/21] fix mido race condition --- demo/demo_mlx.py | 94 ++++++++-- demo/paper/figure/figure.py | 345 ++++++++++++++++++++++++++++++++++++ 2 files changed, 423 insertions(+), 16 deletions(-) create mode 100644 demo/paper/figure/figure.py diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index c834cf3..61da844 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -6,6 +6,7 @@ import uuid import random import logging +import contextlib import threading import queue import json @@ -23,9 +24,13 @@ from aria.config import load_model_config from aria.run import _get_embedding +VIRTUAL_PLAYBACK_PORT = "Playback MIDI Port" +VIRTUAL_PERFORMANCE_PORT = "Performance MIDI Port" +VIRTUAL_CONTROL_PORT = "Control MIDI Port" + EMBEDDING_OFFSET: int = 0 DTYPE = mx.float32 -MAX_SEQ_LEN: int = 2048 +MAX_SEQ_LEN: int = 4096 PREFILL_CHUNK_SIZE_L: int = 128 PREFILL_CHUNK_SIZE: int = 16 RECALC_DUR_PREFILL_CHUNK_SIZE: int = 8 @@ -648,6 +653,7 @@ def decode_tokens( f"Sampled logits for positions {idx} by inserting {prev_tok} at position {idx-1}" ) + # logits[0, tokenizer.tok_to_id[tokenizer.ped_on_tok]] += 2 logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") if is_ending is False: logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") @@ -1475,7 +1481,6 @@ def capture_midi_input( first_msg_epoch_time_ms: int | None = None, wait_for_close: bool = False, ): - """Captures MIDI input with improved structure and readability.""" logger = get_logger("CAPTURE") first_on_msg_epoch_ms = None prev_msg_epoch_time_ms = first_msg_epoch_time_ms @@ -1592,6 +1597,7 @@ def _send_delayed_message(port, msg): if BASE_OUTPUT_LATENCY_MS > 0: midi_dict = MidiDict.from_midi(midi_path) + midi_dict.remove_redundant_pedals() midi_dict.enforce_gaps(min_gap_ms=MIN_NOTE_DELTA_MS) mid = midi_dict.to_midi() else: @@ -1641,8 +1647,10 @@ def _listen( midi_control_signal: int | None = None, midi_reset_control_signal: int | None = None, ): - time.sleep(2) - logger.info("Listening...") + time.sleep(1) + logger.info( + f"Listening for takeover signal ({midi_control_signal}) and reset signal ({midi_reset_control_signal}) on MIDI port: '{midi_input_port}'" + ) with mido.open_input(midi_input_port) as midi_input: while not reset_sentinel.is_set(): msg = midi_input.receive(block=False) @@ -1696,14 +1704,15 @@ def listen_for_midi_control_signal( def run( model: TransformerLM, - midi_in_port: str | None, + midi_in_performance_port: str, + midi_in_control_port: str, midi_through_port: str | None, midi_out_port: str | None, midi_path: str | None, midi_save_path: str | None, midi_control_signal: int, midi_reset_control_signal: int, - reset_sentinel: str, + reset_sentinel: threading.Event, # Changed from string to Event wait_for_close: bool, temperature: float, min_p: float, @@ -1719,19 +1728,17 @@ def run( close_notes(midi_out_port) if midi_path: - midi_playback_port = "IAC Driver Bus 1" play_file_thread = threading.Thread( target=play_midi_file, args=( midi_through_port, - midi_playback_port, + VIRTUAL_PLAYBACK_PORT, midi_path, currently_generating_sentinel, reset_sentinel, ), ) else: - midi_playback_port = midi_in_port play_file_thread = None keypress_thread = threading.Thread( @@ -1741,9 +1748,7 @@ def run( midi_control_thread = threading.Thread( target=listen_for_midi_control_signal, kwargs={ - "midi_input_port": ( - midi_in_port if midi_in_port else midi_playback_port - ), + "midi_input_port": midi_in_control_port, "control_sentinel": control_sentinel, "reset_sentinel": reset_sentinel, "midi_control_signal": midi_control_signal, @@ -1764,7 +1769,7 @@ def run( control_sentinel=control_sentinel, reset_sentinel=reset_sentinel, wait_for_close=wait_for_close, - midi_input_port=midi_playback_port, + midi_input_port=midi_in_performance_port, midi_capture_channel=0, ) ) @@ -1809,7 +1814,7 @@ def run( control_sentinel=control_sentinel, reset_sentinel=reset_sentinel, wait_for_close=wait_for_close, - midi_input_port=midi_playback_port, + midi_input_port=midi_in_performance_port, midi_capture_channel=curr_midi_channel, first_msg_epoch_time_ms=first_on_msg_epoch_ms, ) @@ -1838,6 +1843,47 @@ def insert_embedding( EMBEDDING_OFFSET = 1 +def forward_midi_input_port( + midi_input_port: str, + midi_performance_port: str, + midi_control_port: str, + create_virtual_input_port: bool = False, +): + logger = get_logger("MIDI-FORWARD") + + try: + with contextlib.ExitStack() as stack: + if create_virtual_input_port: + in_port = stack.enter_context( + mido.open_ioport(midi_input_port, virtual=True) + ) + else: + in_port = stack.enter_context(mido.open_input(midi_input_port)) + + perf_port = stack.enter_context( + mido.open_output(midi_performance_port, virtual=True) + ) + ctrl_port = stack.enter_context( + mido.open_output(midi_control_port, virtual=True) + ) + + logger.info( + f"Forwarding MIDI port '{midi_input_port}' to " + f"['{midi_performance_port}', '{midi_control_port}']" + ) + + while True: + msg = in_port.receive(block=True) + if msg: + perf_port.send(msg) + ctrl_port.send(msg) + + except (Exception, KeyboardInterrupt) as e: + logger.error(f"Error in MIDI forwarder: {e}") + finally: + logger.info("MIDI forwarder has shut down.") + + def main(args): model = load_model(checkpoint_path=args.checkpoint) model = warmup_model(model=model) @@ -1850,11 +1896,26 @@ def main(args): assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in + forwarder_thread = threading.Thread( + target=forward_midi_input_port, + kwargs={ + "midi_input_port": ( + VIRTUAL_PLAYBACK_PORT if args.midi_path else args.midi_in + ), + "midi_performance_port": VIRTUAL_PERFORMANCE_PORT, + "midi_control_port": VIRTUAL_CONTROL_PORT, + "create_virtual_input_port": True if args.midi_path else False, + }, + daemon=True, + ) + forwarder_thread.start() + reset_sentinel = threading.Event() while True: run( model=model, - midi_in_port=args.midi_in, + midi_in_performance_port=VIRTUAL_PERFORMANCE_PORT, + midi_in_control_port=VIRTUAL_CONTROL_PORT, midi_through_port=args.midi_through, midi_out_port=args.midi_out, midi_path=args.midi_path, @@ -1881,6 +1942,7 @@ def playback(midi_path: str, midi_out: str, save_path: str | None = None): control_sentinel = threading.Event() midi_dict = MidiDict.from_midi(midi_path) + midi_dict.remove_redundant_pedals() tokenized_sequence = tokenizer.tokenize( midi_dict, add_dim_tok=False, @@ -1938,6 +2000,7 @@ def close_notes(midi_out_port: str): out.send(mido.Message("note_off", note=note, velocity=0)) +# TODO: Debug issue with incorrect tokens being generated (model or mlx issue?) if __name__ == "__main__": args = parse_args() @@ -1956,7 +2019,6 @@ def close_notes(midi_out_port: str): except KeyboardInterrupt: close_notes(args.midi_out) else: - # Main logic try: main(args) except KeyboardInterrupt: diff --git a/demo/paper/figure/figure.py b/demo/paper/figure/figure.py new file mode 100644 index 0000000..e9e790b --- /dev/null +++ b/demo/paper/figure/figure.py @@ -0,0 +1,345 @@ +# /// script +# requires-python = ">=3.8" +# dependencies = [ +# # ### --- SIMPLIFIED --- ###: Pillow is no longer needed +# "matplotlib", +# "ariautils @ https://github.com/EleutherAI/aria-utils.git" +# ] +# /// + +import argparse +import math +from pathlib import Path +from typing import List, Literal, Optional, TypedDict + +# ### --- SIMPLIFIED --- ###: Matplotlib is now the primary plotting library +import matplotlib.pyplot as plt +import matplotlib.patches as patches +from ariautils.midi import MidiDict + +# --- Hardcoded Configuration Constants --- +# You can modify these values directly as needed. + +# Note Pitch Range (inclusive) +MIN_PITCH = 52 +MAX_PITCH = 80 + +# Image Dimensions and Resolution +DPI = 500 +IMG_WIDTH_CM = 35 +IMG_HEIGHT_CM = 4 + +# Visual Style +NOTE_SPACING_RATIO = ( + 0.4 # 0.0 for no spacing, 0.15 means 15% of row height is spacing +) +LINE_THICKNESS_PX = 12 +DASHED_LINE_PATTERN_PX = (61, 17) # (dash_length, gap_length) in pixels + +# Figure Mode Constants +CHUNK_LEN_MS = 690 # Length of a prefill chunk for 'prefill' mode + +# Colors +COLOR_YELLOW_BG = "#FFF9C4" +COLOR_GRAY_LINE = "#A9A9A9" +COLOR_RED_BG = "#FFB3B3" +COLOR_BLUE_NOTE = "#4A90E2" +COLOR_GREEN_NOTE = "#4CAF50" + +# --- End of Configuration --- + + +class NoteToDraw(TypedDict): + """A simple structure to hold note information for drawing.""" + + pitch: int + start_ms: int + end_ms: int + + +def create_pianoroll_pdf( + midi_path: Path, + output_path: Path, + cutoff_ms: int, + total_duration_ms: int, + truncate: bool, + mode: Literal["none", "prefill", "recalc", "generate"], + generation_len_ms: Optional[int] = None, +): + """ + Generates a minimalist, vector-based (PDF) piano roll from a MIDI file. + """ + print(f"Loading MIDI file: {midi_path.name}") + try: + midi_data = MidiDict.from_midi(midi_path) + except Exception as e: + print(f"Error loading MIDI file with MidiDict: {e}") + return + + # Determine the total time span to display notes for + display_until_ms = total_duration_ms + if mode == "generate" and generation_len_ms is not None: + display_until_ms = cutoff_ms + generation_len_ms + + # 1. Filter and collect all potentially visible notes + all_notes: List[NoteToDraw] = [] + time_offset_ms = -300 + for note_msg in midi_data.note_msgs: + pitch = note_msg["data"]["pitch"] + if not (MIN_PITCH <= pitch <= MAX_PITCH): + continue + start_ms = midi_data.tick_to_ms(note_msg["data"]["start"]) + if start_ms > display_until_ms: + continue + end_ms = midi_data.tick_to_ms(note_msg["data"]["end"]) + if mode == "none" and truncate and end_ms > cutoff_ms: + end_ms = cutoff_ms + all_notes.append( + { + "pitch": pitch, + "start_ms": start_ms + time_offset_ms, + "end_ms": end_ms + time_offset_ms, + } + ) + + # Filter notes to draw based on the mode and time window + notes_to_draw = [] + filter_end_time = display_until_ms + for n in all_notes: + if n["start_ms"] < filter_end_time: + notes_to_draw.append(n) + + if mode == "prefill": + for note in notes_to_draw: + if note["end_ms"] > cutoff_ms: + note["end_ms"] = cutoff_ms + + print(f"Found {len(notes_to_draw)} notes to draw for mode '{mode}'.") + + # --- SHARED CALCULATIONS --- + img_width_px = int(IMG_WIDTH_CM / 2.54 * DPI) + img_height_px = int(IMG_HEIGHT_CM / 2.54 * DPI) + num_pitches = MAX_PITCH - MIN_PITCH + 1 + pitch_row_height = img_height_px / num_pitches + y_spacing = pitch_row_height * NOTE_SPACING_RATIO / 2.0 + line_half_width = LINE_THICKNESS_PX / 2.0 + + # Boundary calculations + cutoff_px_center = (cutoff_ms / total_duration_ms) * img_width_px + cutoff_boundary_px = cutoff_px_center + line_half_width + + recalc_boundary_start_ms = 0 + if mode in ["recalc", "generate"]: + truncated_notes = [ + n for n in notes_to_draw if n["start_ms"] <= cutoff_ms < n["end_ms"] + ] + if truncated_notes: + first_trunc_start_ms = min(n["start_ms"] for n in truncated_notes) + recalc_boundary_start_ms = first_trunc_start_ms + + yellow_end_ms = 0 + if mode == "prefill": + yellow_end_ms = math.floor(cutoff_ms / CHUNK_LEN_MS) * CHUNK_LEN_MS + elif mode in ["recalc", "generate"]: + yellow_end_ms = recalc_boundary_start_ms + + # ### --- SIMPLIFIED --- ###: This is now the only drawing implementation + # --- MATPLOTLIB (VECTOR) IMPLEMENTATION --- + fig_width_in = IMG_WIDTH_CM / 2.54 + fig_height_in = IMG_HEIGHT_CM / 2.54 + + fig, ax = plt.subplots(figsize=(fig_width_in, fig_height_in), dpi=DPI) + ax.set_rasterization_zorder(0) + fig.subplots_adjust(left=0, right=1, top=1, bottom=0) + ax.axis("off") + + ax.set_xlim(0, img_width_px) + ax.set_ylim(0, img_height_px) + ax.invert_yaxis() + ax.set_facecolor("white") + + # 1. Backgrounds + if yellow_end_ms > 0: + x_boundary_yellow = (yellow_end_ms / total_duration_ms) * img_width_px + ax.add_patch( + patches.Rectangle( + (0, 0), + x_boundary_yellow, + img_height_px, + facecolor=COLOR_YELLOW_BG, + zorder=1, + ) + ) + + if mode in ["recalc", "generate"] and recalc_boundary_start_ms < cutoff_ms: + x_start_red = ( + recalc_boundary_start_ms / total_duration_ms + ) * img_width_px + ax.add_patch( + patches.Rectangle( + (x_start_red, 0), + cutoff_boundary_px - x_start_red, + img_height_px, + facecolor=COLOR_RED_BG, + zorder=1, + ) + ) + + # 2. Notes + for note in notes_to_draw: + note_end_ms = min(note["end_ms"], display_until_ms) + x0 = (note["start_ms"] / total_duration_ms) * img_width_px + x1 = (note_end_ms / total_duration_ms) * img_width_px + if x1 <= x0: + continue + + pitch_index = MAX_PITCH - note["pitch"] + y0 = pitch_index * pitch_row_height + y_spacing + height = pitch_row_height - 2 * y_spacing + + is_split_note = ( + mode in ["recalc", "generate"] + and note["start_ms"] <= cutoff_ms < note["end_ms"] + ) + is_new_note = mode == "generate" and note["start_ms"] >= cutoff_ms + + if is_split_note: + ax.add_patch( + patches.Rectangle( + (x0, y0), + cutoff_boundary_px - x0, + height, + facecolor="black", + zorder=2, + ) + ) + ax.add_patch( + patches.Rectangle( + (cutoff_boundary_px, y0), + x1 - cutoff_boundary_px, + height, + facecolor=COLOR_BLUE_NOTE, + zorder=2, + ) + ) + elif is_new_note: + ax.add_patch( + patches.Rectangle( + (x0, y0), + x1 - x0, + height, + facecolor=COLOR_GREEN_NOTE, + zorder=2, + ) + ) + else: + ax.add_patch( + patches.Rectangle( + (x0, y0), x1 - x0, height, facecolor="black", zorder=2 + ) + ) + + # 3. Lines + if yellow_end_ms > 0: + for t in range(CHUNK_LEN_MS, int(yellow_end_ms) + 1, CHUNK_LEN_MS): + x_chunk_center = (t / total_duration_ms) * img_width_px + ax.add_patch( + patches.Rectangle( + (x_chunk_center - line_half_width, 0), + LINE_THICKNESS_PX, + img_height_px, + facecolor=COLOR_GRAY_LINE, + zorder=3, + ) + ) + + # Convert pixel-based line styles to Matplotlib's point-based system + lw_pt = LINE_THICKNESS_PX * 72 / DPI + dash_pt = [d * 72 / DPI for d in DASHED_LINE_PATTERN_PX] + ax.axvline( + x=cutoff_px_center, + color=COLOR_GRAY_LINE, + linestyle=(0, dash_pt), + linewidth=lw_pt, + zorder=4, + ) + + # Add border + ax.add_patch( + patches.Rectangle( + (0, 0), + img_width_px, + img_height_px, + facecolor="none", + edgecolor="black", + linewidth=lw_pt, + zorder=5, + ) + ) + + plt.savefig(output_path, format="pdf", bbox_inches="tight", pad_inches=0) + plt.close(fig) + + print(f"Successfully generated PDF asset: {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate a minimalist vector (PDF) piano roll from a MIDI file.", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "midi_path", type=Path, help="Path to the input MIDI file." + ) + parser.add_argument( + "--output-path", + type=Path, + required=True, + help="Path to save the output PDF image.", + ) + parser.add_argument( + "--cutoff-ms", + type=int, + required=True, + help="Time in ms to capture notes up to. This is the main point of interest.", + ) + parser.add_argument( + "--total-duration-ms", + type=int, + required=True, + help="Total time duration the piano roll's width should represent.", + ) + parser.add_argument( + "--truncate", + action="store_true", + help="If set, notes that end after the cutoff time will be visually cut off (only in 'none' mode).", + ) + parser.add_argument( + "--mode", + type=str, + default="none", + choices=["none", "prefill", "recalc", "generate"], + help="Selects the visualization mode for the figure.", + ) + parser.add_argument( + "--generation-len-ms", + type=int, + help="Required for 'generate' mode. How many ms of notes to show after the cutoff.", + ) + + args = parser.parse_args() + + if args.output_path.suffix.lower() != ".pdf": + args.output_path = args.output_path.with_suffix(".pdf") + + if args.mode == "generate" and args.generation_len_ms is None: + parser.error( + "--generation-len-ms is required when using --mode generate" + ) + + args.output_path.parent.mkdir(parents=True, exist_ok=True) + create_pianoroll_pdf(**vars(args)) + + +if __name__ == "__main__": + main() From ea773dd49021878c34f9768fbda169e7e5efc14e Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 23 Jul 2025 20:09:39 +0100 Subject: [PATCH 12/21] rm --- demo/paper/figure/figure.py | 345 ------------------------------------ 1 file changed, 345 deletions(-) delete mode 100644 demo/paper/figure/figure.py diff --git a/demo/paper/figure/figure.py b/demo/paper/figure/figure.py deleted file mode 100644 index e9e790b..0000000 --- a/demo/paper/figure/figure.py +++ /dev/null @@ -1,345 +0,0 @@ -# /// script -# requires-python = ">=3.8" -# dependencies = [ -# # ### --- SIMPLIFIED --- ###: Pillow is no longer needed -# "matplotlib", -# "ariautils @ https://github.com/EleutherAI/aria-utils.git" -# ] -# /// - -import argparse -import math -from pathlib import Path -from typing import List, Literal, Optional, TypedDict - -# ### --- SIMPLIFIED --- ###: Matplotlib is now the primary plotting library -import matplotlib.pyplot as plt -import matplotlib.patches as patches -from ariautils.midi import MidiDict - -# --- Hardcoded Configuration Constants --- -# You can modify these values directly as needed. - -# Note Pitch Range (inclusive) -MIN_PITCH = 52 -MAX_PITCH = 80 - -# Image Dimensions and Resolution -DPI = 500 -IMG_WIDTH_CM = 35 -IMG_HEIGHT_CM = 4 - -# Visual Style -NOTE_SPACING_RATIO = ( - 0.4 # 0.0 for no spacing, 0.15 means 15% of row height is spacing -) -LINE_THICKNESS_PX = 12 -DASHED_LINE_PATTERN_PX = (61, 17) # (dash_length, gap_length) in pixels - -# Figure Mode Constants -CHUNK_LEN_MS = 690 # Length of a prefill chunk for 'prefill' mode - -# Colors -COLOR_YELLOW_BG = "#FFF9C4" -COLOR_GRAY_LINE = "#A9A9A9" -COLOR_RED_BG = "#FFB3B3" -COLOR_BLUE_NOTE = "#4A90E2" -COLOR_GREEN_NOTE = "#4CAF50" - -# --- End of Configuration --- - - -class NoteToDraw(TypedDict): - """A simple structure to hold note information for drawing.""" - - pitch: int - start_ms: int - end_ms: int - - -def create_pianoroll_pdf( - midi_path: Path, - output_path: Path, - cutoff_ms: int, - total_duration_ms: int, - truncate: bool, - mode: Literal["none", "prefill", "recalc", "generate"], - generation_len_ms: Optional[int] = None, -): - """ - Generates a minimalist, vector-based (PDF) piano roll from a MIDI file. - """ - print(f"Loading MIDI file: {midi_path.name}") - try: - midi_data = MidiDict.from_midi(midi_path) - except Exception as e: - print(f"Error loading MIDI file with MidiDict: {e}") - return - - # Determine the total time span to display notes for - display_until_ms = total_duration_ms - if mode == "generate" and generation_len_ms is not None: - display_until_ms = cutoff_ms + generation_len_ms - - # 1. Filter and collect all potentially visible notes - all_notes: List[NoteToDraw] = [] - time_offset_ms = -300 - for note_msg in midi_data.note_msgs: - pitch = note_msg["data"]["pitch"] - if not (MIN_PITCH <= pitch <= MAX_PITCH): - continue - start_ms = midi_data.tick_to_ms(note_msg["data"]["start"]) - if start_ms > display_until_ms: - continue - end_ms = midi_data.tick_to_ms(note_msg["data"]["end"]) - if mode == "none" and truncate and end_ms > cutoff_ms: - end_ms = cutoff_ms - all_notes.append( - { - "pitch": pitch, - "start_ms": start_ms + time_offset_ms, - "end_ms": end_ms + time_offset_ms, - } - ) - - # Filter notes to draw based on the mode and time window - notes_to_draw = [] - filter_end_time = display_until_ms - for n in all_notes: - if n["start_ms"] < filter_end_time: - notes_to_draw.append(n) - - if mode == "prefill": - for note in notes_to_draw: - if note["end_ms"] > cutoff_ms: - note["end_ms"] = cutoff_ms - - print(f"Found {len(notes_to_draw)} notes to draw for mode '{mode}'.") - - # --- SHARED CALCULATIONS --- - img_width_px = int(IMG_WIDTH_CM / 2.54 * DPI) - img_height_px = int(IMG_HEIGHT_CM / 2.54 * DPI) - num_pitches = MAX_PITCH - MIN_PITCH + 1 - pitch_row_height = img_height_px / num_pitches - y_spacing = pitch_row_height * NOTE_SPACING_RATIO / 2.0 - line_half_width = LINE_THICKNESS_PX / 2.0 - - # Boundary calculations - cutoff_px_center = (cutoff_ms / total_duration_ms) * img_width_px - cutoff_boundary_px = cutoff_px_center + line_half_width - - recalc_boundary_start_ms = 0 - if mode in ["recalc", "generate"]: - truncated_notes = [ - n for n in notes_to_draw if n["start_ms"] <= cutoff_ms < n["end_ms"] - ] - if truncated_notes: - first_trunc_start_ms = min(n["start_ms"] for n in truncated_notes) - recalc_boundary_start_ms = first_trunc_start_ms - - yellow_end_ms = 0 - if mode == "prefill": - yellow_end_ms = math.floor(cutoff_ms / CHUNK_LEN_MS) * CHUNK_LEN_MS - elif mode in ["recalc", "generate"]: - yellow_end_ms = recalc_boundary_start_ms - - # ### --- SIMPLIFIED --- ###: This is now the only drawing implementation - # --- MATPLOTLIB (VECTOR) IMPLEMENTATION --- - fig_width_in = IMG_WIDTH_CM / 2.54 - fig_height_in = IMG_HEIGHT_CM / 2.54 - - fig, ax = plt.subplots(figsize=(fig_width_in, fig_height_in), dpi=DPI) - ax.set_rasterization_zorder(0) - fig.subplots_adjust(left=0, right=1, top=1, bottom=0) - ax.axis("off") - - ax.set_xlim(0, img_width_px) - ax.set_ylim(0, img_height_px) - ax.invert_yaxis() - ax.set_facecolor("white") - - # 1. Backgrounds - if yellow_end_ms > 0: - x_boundary_yellow = (yellow_end_ms / total_duration_ms) * img_width_px - ax.add_patch( - patches.Rectangle( - (0, 0), - x_boundary_yellow, - img_height_px, - facecolor=COLOR_YELLOW_BG, - zorder=1, - ) - ) - - if mode in ["recalc", "generate"] and recalc_boundary_start_ms < cutoff_ms: - x_start_red = ( - recalc_boundary_start_ms / total_duration_ms - ) * img_width_px - ax.add_patch( - patches.Rectangle( - (x_start_red, 0), - cutoff_boundary_px - x_start_red, - img_height_px, - facecolor=COLOR_RED_BG, - zorder=1, - ) - ) - - # 2. Notes - for note in notes_to_draw: - note_end_ms = min(note["end_ms"], display_until_ms) - x0 = (note["start_ms"] / total_duration_ms) * img_width_px - x1 = (note_end_ms / total_duration_ms) * img_width_px - if x1 <= x0: - continue - - pitch_index = MAX_PITCH - note["pitch"] - y0 = pitch_index * pitch_row_height + y_spacing - height = pitch_row_height - 2 * y_spacing - - is_split_note = ( - mode in ["recalc", "generate"] - and note["start_ms"] <= cutoff_ms < note["end_ms"] - ) - is_new_note = mode == "generate" and note["start_ms"] >= cutoff_ms - - if is_split_note: - ax.add_patch( - patches.Rectangle( - (x0, y0), - cutoff_boundary_px - x0, - height, - facecolor="black", - zorder=2, - ) - ) - ax.add_patch( - patches.Rectangle( - (cutoff_boundary_px, y0), - x1 - cutoff_boundary_px, - height, - facecolor=COLOR_BLUE_NOTE, - zorder=2, - ) - ) - elif is_new_note: - ax.add_patch( - patches.Rectangle( - (x0, y0), - x1 - x0, - height, - facecolor=COLOR_GREEN_NOTE, - zorder=2, - ) - ) - else: - ax.add_patch( - patches.Rectangle( - (x0, y0), x1 - x0, height, facecolor="black", zorder=2 - ) - ) - - # 3. Lines - if yellow_end_ms > 0: - for t in range(CHUNK_LEN_MS, int(yellow_end_ms) + 1, CHUNK_LEN_MS): - x_chunk_center = (t / total_duration_ms) * img_width_px - ax.add_patch( - patches.Rectangle( - (x_chunk_center - line_half_width, 0), - LINE_THICKNESS_PX, - img_height_px, - facecolor=COLOR_GRAY_LINE, - zorder=3, - ) - ) - - # Convert pixel-based line styles to Matplotlib's point-based system - lw_pt = LINE_THICKNESS_PX * 72 / DPI - dash_pt = [d * 72 / DPI for d in DASHED_LINE_PATTERN_PX] - ax.axvline( - x=cutoff_px_center, - color=COLOR_GRAY_LINE, - linestyle=(0, dash_pt), - linewidth=lw_pt, - zorder=4, - ) - - # Add border - ax.add_patch( - patches.Rectangle( - (0, 0), - img_width_px, - img_height_px, - facecolor="none", - edgecolor="black", - linewidth=lw_pt, - zorder=5, - ) - ) - - plt.savefig(output_path, format="pdf", bbox_inches="tight", pad_inches=0) - plt.close(fig) - - print(f"Successfully generated PDF asset: {output_path}") - - -def main(): - parser = argparse.ArgumentParser( - description="Generate a minimalist vector (PDF) piano roll from a MIDI file.", - formatter_class=argparse.RawTextHelpFormatter, - ) - parser.add_argument( - "midi_path", type=Path, help="Path to the input MIDI file." - ) - parser.add_argument( - "--output-path", - type=Path, - required=True, - help="Path to save the output PDF image.", - ) - parser.add_argument( - "--cutoff-ms", - type=int, - required=True, - help="Time in ms to capture notes up to. This is the main point of interest.", - ) - parser.add_argument( - "--total-duration-ms", - type=int, - required=True, - help="Total time duration the piano roll's width should represent.", - ) - parser.add_argument( - "--truncate", - action="store_true", - help="If set, notes that end after the cutoff time will be visually cut off (only in 'none' mode).", - ) - parser.add_argument( - "--mode", - type=str, - default="none", - choices=["none", "prefill", "recalc", "generate"], - help="Selects the visualization mode for the figure.", - ) - parser.add_argument( - "--generation-len-ms", - type=int, - help="Required for 'generate' mode. How many ms of notes to show after the cutoff.", - ) - - args = parser.parse_args() - - if args.output_path.suffix.lower() != ".pdf": - args.output_path = args.output_path.with_suffix(".pdf") - - if args.mode == "generate" and args.generation_len_ms is None: - parser.error( - "--generation-len-ms is required when using --mode generate" - ) - - args.output_path.parent.mkdir(parents=True, exist_ok=True) - create_pianoroll_pdf(**vars(args)) - - -if __name__ == "__main__": - main() From d87299baf9e1fdcd079d255cba9a2dbce49854d0 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 29 Jul 2025 15:24:22 +0100 Subject: [PATCH 13/21] fix mido deadlock in MIDI IO --- demo/demo_mlx.py | 372 +++++++++++++++++++++++------------------------ 1 file changed, 179 insertions(+), 193 deletions(-) diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index 61da844..65d3c86 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -6,7 +6,6 @@ import uuid import random import logging -import contextlib import threading import queue import json @@ -24,13 +23,9 @@ from aria.config import load_model_config from aria.run import _get_embedding -VIRTUAL_PLAYBACK_PORT = "Playback MIDI Port" -VIRTUAL_PERFORMANCE_PORT = "Performance MIDI Port" -VIRTUAL_CONTROL_PORT = "Control MIDI Port" - EMBEDDING_OFFSET: int = 0 DTYPE = mx.float32 -MAX_SEQ_LEN: int = 4096 +MAX_SEQ_LEN: int = 2048 PREFILL_CHUNK_SIZE_L: int = 128 PREFILL_CHUNK_SIZE: int = 16 RECALC_DUR_PREFILL_CHUNK_SIZE: int = 8 @@ -38,7 +33,7 @@ BEAM_WIDTH: int = 3 TIME_TOK_WEIGHTING: int = -5 -FIRST_ONSET_BUFFER_MS: int = -200 +FIRST_ONSET_BUFFER_MS: int = -100 MAX_STREAM_DELAY_MS: int = 50 MIN_NOTE_DELTA_MS: int = 0 @@ -999,7 +994,6 @@ def _create_mido_message( channel: int, time_delta_ms: int, ) -> mido.Message: - # Creates a mido message from an event dictionary if msg_dict["pitch"] == "pedal": return mido.Message( "control_change", @@ -1438,7 +1432,7 @@ def capture_and_update_kv( control_sentinel: threading.Event, reset_sentinel: threading.Event, wait_for_close: bool, - midi_input_port: str, + midi_performance_queue: queue.Queue, midi_capture_channel: int, first_msg_epoch_time_ms: int | None = None, ): @@ -1447,7 +1441,7 @@ def capture_and_update_kv( capture_midi_thread = threading.Thread( target=capture_midi_input, kwargs={ - "midi_input_port": midi_input_port, + "midi_performance_queue": midi_performance_queue, "control_sentinel": control_sentinel, "reset_sentinel": reset_sentinel, "received_messages_queue": received_messages_queue, @@ -1472,7 +1466,7 @@ def capture_and_update_kv( def capture_midi_input( - midi_input_port: str, + midi_performance_queue: queue.Queue, control_sentinel: threading.Event, reset_sentinel: threading.Event, received_messages_queue: queue.Queue, @@ -1488,112 +1482,109 @@ def capture_midi_input( pitches_held_down = set() pitches_sustained_by_pedal = set() - logger.info(f"Listening for input on MIDI port: '{midi_input_port}'") + while not midi_performance_queue.empty(): + try: + midi_performance_queue.get_nowait() + except queue.Empty: + break - # Clear any buffered MIDI messages before starting - with mido.open_input(midi_input_port) as midi_input: - for _ in midi_input.iter_pending(): - pass + logger.info("Listening for input") + logger.info("Commencing generation upon keypress or control signal") - with mido.open_input(midi_input_port) as midi_input: - logger.info("Commencing generation upon keypress or control signal") + while True: + epoch_time_ms = get_epoch_time_ms() + active_notes = pitches_held_down.union(pitches_sustained_by_pedal) + should_stop = not wait_for_close or not active_notes + if reset_sentinel.is_set() or ( + control_sentinel.is_set() and should_stop + ): + break - while True: - epoch_time_ms = get_epoch_time_ms() - active_notes = pitches_held_down.union(pitches_sustained_by_pedal) - should_stop = not wait_for_close or not active_notes - if reset_sentinel.is_set() or ( - control_sentinel.is_set() and should_stop - ): - break + try: + msg = midi_performance_queue.get(block=True, timeout=0.01) + except queue.Empty: + continue - msg = midi_input.receive(block=False) - if not msg: - time.sleep(0.01) - continue + if msg.is_meta or msg.type == "program_change": + continue - if msg.is_meta or msg.type == "program_change": - continue + msg.channel = midi_capture_channel + if prev_msg_epoch_time_ms is None: + msg.time = 0 + else: + msg.time = epoch_time_ms - prev_msg_epoch_time_ms - msg.channel = midi_capture_channel - if prev_msg_epoch_time_ms is None: - msg.time = 0 - else: - msg.time = epoch_time_ms - prev_msg_epoch_time_ms - - prev_msg_epoch_time_ms = epoch_time_ms - logger.info(f"Received message: [{msg}]") - - match msg.type: - case "note_on" if msg.velocity > 0: - if first_on_msg_epoch_ms is None: - first_on_msg_epoch_ms = ( - get_epoch_time_ms() - HARDWARE_INPUT_LATENCY_MS - ) - pitches_held_down.add(msg.note) - if pedal_down: - pitches_sustained_by_pedal.add(msg.note) - received_messages_queue.put(msg) - - case "note_off" | "note_on": - # Note-off - pitches_held_down.discard(msg.note) - received_messages_queue.put(msg) - - case "control_change" if msg.control == 64: - if msg.value >= 64: - pedal_down = True - pitches_sustained_by_pedal.update(pitches_held_down) - else: - pedal_down = False - pitches_sustained_by_pedal.clear() - received_messages_queue.put(msg) - - active_pitches = pitches_held_down.union(pitches_sustained_by_pedal) - num_active_pitches = len(active_pitches) - logger.info(f"Active pitches ({num_active_pitches}): {active_pitches}") - - time_offset = get_epoch_time_ms() - prev_msg_epoch_time_ms - for pitch in pitches_held_down: - note_off_msg = mido.Message( - "note_off", - note=pitch, - channel=midi_capture_channel, - time=time_offset, - ) - received_messages_queue.put(note_off_msg) - time_offset = 0 - - received_messages_queue.put( - mido.Message( - "control_change", - control=64, - value=0, - channel=midi_capture_channel, - time=0, - ) + prev_msg_epoch_time_ms = epoch_time_ms + logger.info(f"Received message: [{msg}]") + + match msg.type: + case "note_on" if msg.velocity > 0: + if first_on_msg_epoch_ms is None: + first_on_msg_epoch_ms = ( + get_epoch_time_ms() - HARDWARE_INPUT_LATENCY_MS + ) + pitches_held_down.add(msg.note) + if pedal_down: + pitches_sustained_by_pedal.add(msg.note) + received_messages_queue.put(msg) + + case "note_off" | "note_on": + # Note-off + pitches_held_down.discard(msg.note) + received_messages_queue.put(msg) + + case "control_change" if msg.control == 64: + if msg.value >= 64: + pedal_down = True + pitches_sustained_by_pedal.update(pitches_held_down) + else: + pedal_down = False + pitches_sustained_by_pedal.clear() + received_messages_queue.put(msg) + + active_pitches = pitches_held_down.union(pitches_sustained_by_pedal) + num_active_pitches = len(active_pitches) + logger.info(f"Active pitches ({num_active_pitches}): {active_pitches}") + + time_offset = get_epoch_time_ms() - prev_msg_epoch_time_ms + for pitch in pitches_held_down: + note_off_msg = mido.Message( + "note_off", + note=pitch, + channel=midi_capture_channel, + time=time_offset, + ) + received_messages_queue.put(note_off_msg) + time_offset = 0 + + received_messages_queue.put( + mido.Message( + "control_change", + control=64, + value=0, + channel=midi_capture_channel, + time=0, ) + ) - received_messages_queue.put(None) - results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) + received_messages_queue.put(None) + results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) def play_midi_file( midi_through_port: str, - midi_in_port: str, + midi_performance_queue: queue.Queue, midi_path: str, - currently_streaming_sentinel: threading.Event, + currently_generating_sentinel: threading.Event, reset_sentinel: threading.Event, ): - def _send_delayed_message(port, msg): - port.send(msg) + def _send_delayed_message(_midi_performance_queue: queue.Queue, msg): + _midi_performance_queue.put(msg) logger.debug(f"SENT: {msg}") logger = get_logger("FILE") logger.info(f"Playing {midi_path} on through-port '{midi_through_port}'") - logger.info( - f"Simulating input to port '{midi_in_port}' with {HARDWARE_INPUT_LATENCY_MS}ms latency" - ) + logger.info(f"Simulating input with {HARDWARE_INPUT_LATENCY_MS}ms latency") if BASE_OUTPUT_LATENCY_MS > 0: midi_dict = MidiDict.from_midi(midi_path) @@ -1605,21 +1596,20 @@ def _send_delayed_message(port, msg): time.sleep(1) with mido.open_output(midi_through_port) as through_port: - with mido.open_output(midi_in_port) as in_port: - for msg in mid.play(): - if reset_sentinel.is_set(): - logger.debug("Exiting") - return - - if currently_streaming_sentinel.is_set() is False: - through_port.send(msg) - - timer = threading.Timer( - interval=HARDWARE_INPUT_LATENCY_MS / 1000.0, - function=_send_delayed_message, - args=[in_port, msg], - ) - timer.start() + for msg in mid.play(): + if reset_sentinel.is_set(): + logger.debug("Exiting") + return + + if currently_generating_sentinel.is_set() is False: + through_port.send(msg) + + timer = threading.Timer( + interval=HARDWARE_INPUT_LATENCY_MS / 1000.0, + function=_send_delayed_message, + args=[midi_performance_queue, msg], + ) + timer.start() def listen_for_keypress_control_signal( @@ -1641,38 +1631,43 @@ def listen_for_keypress_control_signal( def _listen( - midi_input_port: str, + midi_control_queue: queue.Queue, reset_sentinel: threading.Event, logger: logging.Logger, midi_control_signal: int | None = None, midi_reset_control_signal: int | None = None, ): - time.sleep(1) + while not midi_control_queue.empty(): + try: + midi_control_queue.get_nowait() + except queue.Empty: + break + logger.info( - f"Listening for takeover signal ({midi_control_signal}) and reset signal ({midi_reset_control_signal}) on MIDI port: '{midi_input_port}'" + f"Listening for takeover signal ({midi_control_signal}) and reset signal ({midi_reset_control_signal}) on control queue." ) - with mido.open_input(midi_input_port) as midi_input: - while not reset_sentinel.is_set(): - msg = midi_input.receive(block=False) - - if msg is None: - time.sleep(0.01) - elif ( - msg.type == "control_change" - and msg.control == midi_control_signal - and msg.value >= 64 - ): - return midi_control_signal - elif ( - msg.type == "control_change" - and msg.control == midi_reset_control_signal - and msg.value >= 64 - ): - return midi_reset_control_signal + while not reset_sentinel.is_set(): + try: + msg = midi_control_queue.get(block=True, timeout=0.01) + except queue.Empty: + continue + + if ( + msg.type == "control_change" + and msg.control == midi_control_signal + and msg.value >= 64 + ): + return midi_control_signal + elif ( + msg.type == "control_change" + and msg.control == midi_reset_control_signal + and msg.value >= 64 + ): + return midi_reset_control_signal def listen_for_midi_control_signal( - midi_input_port: str, + midi_control_queue: queue.Queue, control_sentinel: threading.Event, reset_sentinel: threading.Event, midi_control_signal: int | None = None, @@ -1681,8 +1676,9 @@ def listen_for_midi_control_signal( logger = get_logger("MIDI-CONTROL") while not reset_sentinel.is_set(): + time.sleep(1) signal_received = _listen( - midi_input_port=midi_input_port, + midi_control_queue=midi_control_queue, reset_sentinel=reset_sentinel, midi_control_signal=midi_control_signal, midi_reset_control_signal=midi_reset_control_signal, @@ -1694,25 +1690,24 @@ def listen_for_midi_control_signal( if signal_received == midi_control_signal: control_sentinel.set() - time.sleep(2) elif signal_received == midi_reset_control_signal: reset_sentinel.set() control_sentinel.set() - logger.debug("Exiting") + logger.debug("Exiting MIDI control listener") def run( model: TransformerLM, - midi_in_performance_port: str, - midi_in_control_port: str, + midi_performance_queue: queue.Queue, + midi_control_queue: queue.Queue, midi_through_port: str | None, midi_out_port: str | None, midi_path: str | None, midi_save_path: str | None, midi_control_signal: int, midi_reset_control_signal: int, - reset_sentinel: threading.Event, # Changed from string to Event + reset_sentinel: threading.Event, wait_for_close: bool, temperature: float, min_p: float, @@ -1730,13 +1725,13 @@ def run( if midi_path: play_file_thread = threading.Thread( target=play_midi_file, - args=( - midi_through_port, - VIRTUAL_PLAYBACK_PORT, - midi_path, - currently_generating_sentinel, - reset_sentinel, - ), + kwargs={ + "midi_through_port": midi_through_port, + "midi_performance_queue": midi_performance_queue, + "midi_path": midi_path, + "currently_generating_sentinel": currently_generating_sentinel, + "reset_sentinel": reset_sentinel, + }, ) else: play_file_thread = None @@ -1748,7 +1743,7 @@ def run( midi_control_thread = threading.Thread( target=listen_for_midi_control_signal, kwargs={ - "midi_input_port": midi_in_control_port, + "midi_control_queue": midi_control_queue, "control_sentinel": control_sentinel, "reset_sentinel": reset_sentinel, "midi_control_signal": midi_control_signal, @@ -1769,7 +1764,7 @@ def run( control_sentinel=control_sentinel, reset_sentinel=reset_sentinel, wait_for_close=wait_for_close, - midi_input_port=midi_in_performance_port, + midi_performance_queue=midi_performance_queue, midi_capture_channel=0, ) ) @@ -1799,7 +1794,7 @@ def run( midi.save(midi_save_path) curr_midi_channel += 1 - if curr_midi_channel == 9: + if curr_midi_channel == 9: # Skip drum channel curr_midi_channel += 1 control_sentinel.clear() @@ -1814,7 +1809,7 @@ def run( control_sentinel=control_sentinel, reset_sentinel=reset_sentinel, wait_for_close=wait_for_close, - midi_input_port=midi_in_performance_port, + midi_performance_queue=midi_performance_queue, midi_capture_channel=curr_midi_channel, first_msg_epoch_time_ms=first_on_msg_epoch_ms, ) @@ -1845,38 +1840,25 @@ def insert_embedding( def forward_midi_input_port( midi_input_port: str, - midi_performance_port: str, - midi_control_port: str, - create_virtual_input_port: bool = False, + midi_control_queue: queue.Queue, + midi_performance_queue: queue.Queue | None, ): logger = get_logger("MIDI-FORWARD") + logger.info(f"Forwarding MIDI from port: '{midi_input_port}'") - try: - with contextlib.ExitStack() as stack: - if create_virtual_input_port: - in_port = stack.enter_context( - mido.open_ioport(midi_input_port, virtual=True) - ) - else: - in_port = stack.enter_context(mido.open_input(midi_input_port)) - - perf_port = stack.enter_context( - mido.open_output(midi_performance_port, virtual=True) - ) - ctrl_port = stack.enter_context( - mido.open_output(midi_control_port, virtual=True) - ) - - logger.info( - f"Forwarding MIDI port '{midi_input_port}' to " - f"['{midi_performance_port}', '{midi_control_port}']" - ) + if midi_performance_queue is None: + logger.info( + f"MIDI file provided - only forwarding {midi_input_port} to control queue" + ) + try: + with mido.open_input(midi_input_port) as midi_in: while True: - msg = in_port.receive(block=True) + msg = midi_in.receive(block=True) if msg: - perf_port.send(msg) - ctrl_port.send(msg) + midi_control_queue.put(msg) + if midi_performance_queue is not None: + midi_performance_queue.put(msg) except (Exception, KeyboardInterrupt) as e: logger.error(f"Error in MIDI forwarder: {e}") @@ -1896,26 +1878,29 @@ def main(args): assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in - forwarder_thread = threading.Thread( - target=forward_midi_input_port, - kwargs={ - "midi_input_port": ( - VIRTUAL_PLAYBACK_PORT if args.midi_path else args.midi_in - ), - "midi_performance_port": VIRTUAL_PERFORMANCE_PORT, - "midi_control_port": VIRTUAL_CONTROL_PORT, - "create_virtual_input_port": True if args.midi_path else False, - }, - daemon=True, - ) - forwarder_thread.start() + midi_performance_queue = queue.Queue() + midi_control_queue = queue.Queue() + + if args.midi_in: + forwarder_thread = threading.Thread( + target=forward_midi_input_port, + kwargs={ + "midi_input_port": args.midi_in, + "midi_control_queue": midi_control_queue, + "midi_performance_queue": ( + midi_performance_queue if args.midi_path is None else None + ), + }, + daemon=True, + ) + forwarder_thread.start() reset_sentinel = threading.Event() while True: run( model=model, - midi_in_performance_port=VIRTUAL_PERFORMANCE_PORT, - midi_in_control_port=VIRTUAL_CONTROL_PORT, + midi_performance_queue=midi_performance_queue, + midi_control_queue=midi_control_queue, midi_through_port=args.midi_through, midi_out_port=args.midi_out, midi_path=args.midi_path, @@ -2001,6 +1986,7 @@ def close_notes(midi_out_port: str): # TODO: Debug issue with incorrect tokens being generated (model or mlx issue?) +# TODO: Fix bug with context length (metal error?) if __name__ == "__main__": args = parse_args() From 928e208975e2bfd832e3fd81a899b53e87fa0e8d Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 20 Aug 2025 13:55:04 +0100 Subject: [PATCH 14/21] add token masks --- demo/demo_mlx.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index 65d3c86..f77cb29 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -31,14 +31,14 @@ RECALC_DUR_PREFILL_CHUNK_SIZE: int = 8 RECALC_DUR_BUFFER_MS: int = 100 -BEAM_WIDTH: int = 3 +BEAM_WIDTH: int = 5 TIME_TOK_WEIGHTING: int = -5 -FIRST_ONSET_BUFFER_MS: int = -100 +FIRST_ONSET_BUFFER_MS: int = -200 MAX_STREAM_DELAY_MS: int = 50 MIN_NOTE_DELTA_MS: int = 0 MIN_PEDAL_DELTA_MS: int = 0 -MIN_NOTE_LENGTH_MS: int = 0 +MIN_NOTE_LENGTH_MS: int = 10 HARDWARE_INPUT_LATENCY_MS: int = 0 BASE_OUTPUT_LATENCY_MS: int = 0 VELOCITY_OUTPUT_LATENCY_MS: dict[int, int] = {v: 0 for v in range(0, 127, 10)} @@ -367,7 +367,7 @@ def load_model( model.eval() if args.quantize: - nn.quantize(model.model, group_size=64, bits=8) + nn.quantize(model.model, group_size=32, bits=8) logger.info( f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds" @@ -633,6 +633,13 @@ def decode_tokens( if control_sentinel.is_set(): control_sentinel.clear() + last_tok_is_pedal = False + dur_ids = [tokenizer.tok_to_id[idx] for idx in tokenizer.dur_tokens] + dur_mask_ids = [ + tokenizer.tok_to_id[("dur", dur_ms)] + for dur_ms in range(0, MIN_NOTE_LENGTH_MS, 10) + ] + while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: decode_one_start_time_s = time.time() prev_tok_id = enc_seq[0, idx - 1] @@ -648,14 +655,16 @@ def decode_tokens( f"Sampled logits for positions {idx} by inserting {prev_tok} at position {idx-1}" ) - # logits[0, tokenizer.tok_to_id[tokenizer.ped_on_tok]] += 2 + logits[:, tokenizer.tok_to_id[tokenizer.ped_off_tok]] += 3 # Manual adj logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + + logits[:, dur_mask_ids] = float("-inf") + if last_tok_is_pedal is True: + logits[:, dur_ids] = float("-inf") + if is_ending is False: logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") - for dur_ms in range(0, MIN_NOTE_LENGTH_MS, 10): - logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") - if temperature > 0.0: probs = mx.softmax(logits / temperature, axis=-1) next_token_ids = sample_min_p(probs, min_p).flatten() @@ -668,6 +677,11 @@ def decode_tokens( f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" ) + if next_token in {tokenizer.ped_on_tok, tokenizer.ped_off_tok}: + last_tok_is_pedal = True + elif isinstance(next_token, tuple) and next_token[0] == "piano": + last_tok_is_pedal = False + if next_token == tokenizer.eos_tok: logger.info("EOS token produced") generated_tokens_queue.put(next_token) @@ -698,7 +712,9 @@ def generate_tokens( generate_start_s = time.time() priming_seq_len = len(priming_seq) - start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches - 1) + start_idx = max( + 2, priming_seq_len - 3 * (num_preceding_active_pitches + 2) - 1 + ) enc_seq = mx.array( [ tokenizer.encode( From 3a3865f9f28129aa188c79149a3af0d08376682a Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 20 Aug 2025 21:36:34 +0100 Subject: [PATCH 15/21] fixes --- aria/inference/model_mlx.py | 52 +++++++++++++++++++++++--- demo/demo_mlx.py | 74 ++++++++++++++++++++----------------- 2 files changed, 87 insertions(+), 39 deletions(-) diff --git a/aria/inference/model_mlx.py b/aria/inference/model_mlx.py index 169b30b..a3d5a26 100644 --- a/aria/inference/model_mlx.py +++ b/aria/inference/model_mlx.py @@ -84,14 +84,15 @@ def __call__( self, x: mx.array, input_pos: mx.array, + max_pos: int, offset: int, mask: mx.array, ): assert self.kv_cache is not None, "Cache not initialized" - x += self._att_block( x=self.norm1(x), input_pos=input_pos, + max_pos=max_pos, offset=offset, mask=mask, ) @@ -99,15 +100,25 @@ def __call__( return x - def get_kv(self, k: mx.array, v: mx.array, input_pos: mx.array): + def get_kv( + self, + k: mx.array, + v: mx.array, + input_pos: mx.array, + max_pos: int | None = None, + ): k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos) - return k, v + if max_pos is not None: + return k[:, :, : max_pos + 1, :], v[:, :, : max_pos + 1, :] + else: + return k, v def _att_block( self, x: mx.array, input_pos: mx.array, + max_pos: int, offset: int, mask: mx.array, ): @@ -124,7 +135,8 @@ def _att_block( k = apply_rotary_emb_mlx(k, offset=offset) q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v)) - k, v = self.get_kv(k, v, input_pos=input_pos) + k, v = self.get_kv(k, v, input_pos=input_pos, max_pos=max_pos) + wv = mx.fast.scaled_dot_product_attention( q=q, k=k, @@ -159,6 +171,7 @@ def __init__(self, model_config: ModelConfig): TransformerBlock(model_config) for _ in range(model_config.n_layers) ] self.out_layer_norm = nn.LayerNorm(model_config.d_model) + self.kv_ctx = None def fill_condition_kv(self, emb: mx.array): assert self.causal_mask is not None, "Caches must be initialized first" @@ -182,15 +195,23 @@ def __call__( ): assert self.causal_mask is not None, "Caches must be initialized first" - mask = self.causal_mask[None, None, input_pos] + if self.kv_ctx is None: + self.kv_ctx = mx.full( + self.model_config.max_seq_len, 3 + ) # unk_tok id + + max_pos = mx.max(input_pos, axis=0).item() + self.kv_ctx[input_pos] = idxs + self.kv_ctx[max_pos + 1 :] = 3 + mask = self.causal_mask[None, None, input_pos, : max_pos + 1] if pad_idxs is not None: pad_mask = mx.expand_dims(mx.expand_dims(pad_idxs, axis=1), axis=1) mask = mask & ~pad_mask x = self.tok_embeddings(idxs) for layer in self.encode_layers: - x = layer(x, input_pos, offset, mask) + x = layer(x, input_pos, max_pos, offset, mask) x = self.out_layer_norm(x) @@ -235,6 +256,25 @@ def fill_condition_kv(self, cond_emb: mx.array): adapted_emb = self.embedding_adapter(cond_emb) self.model.fill_condition_kv(emb=adapted_emb) + def reset_kv_ctx(self): + self.model.kv_ctx = None + + def get_kv_ctx(self): + # Used for debugging kv-cache validation + _kv_ctx = self.model.kv_ctx + + match self.model.kv_ctx: + case None: + return None + case mx.array(): + _kv_ctx = self.model.kv_ctx.tolist() + if 3 in _kv_ctx: + return _kv_ctx[: _kv_ctx.index(3)] + else: + return _kv_ctx + case _: + raise ValueError + def setup_cache( self, batch_size, diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index f77cb29..c8c6ee9 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -10,11 +10,9 @@ import queue import json import mido -import torch import mlx.core as mx import mlx.nn as nn -import numpy as np from ariautils.midi import MidiDict, midi_to_dict from ariautils.tokenizer import AbsTokenizer @@ -25,16 +23,16 @@ EMBEDDING_OFFSET: int = 0 DTYPE = mx.float32 -MAX_SEQ_LEN: int = 2048 +MAX_SEQ_LEN: int = 4096 PREFILL_CHUNK_SIZE_L: int = 128 PREFILL_CHUNK_SIZE: int = 16 RECALC_DUR_PREFILL_CHUNK_SIZE: int = 8 RECALC_DUR_BUFFER_MS: int = 100 -BEAM_WIDTH: int = 5 +BEAM_WIDTH: int = 3 TIME_TOK_WEIGHTING: int = -5 FIRST_ONSET_BUFFER_MS: int = -200 -MAX_STREAM_DELAY_MS: int = 50 +MAX_STREAM_DELAY_MS: int = 100 MIN_NOTE_DELTA_MS: int = 0 MIN_PEDAL_DELTA_MS: int = 0 @@ -222,24 +220,18 @@ def decode_one( return logits -def sample_min_p(probs: mx.array, p_base: float): - """See - https://arxiv.org/pdf/2407.01082""" +def sample_min_p(logits: mx.array, p_base: float): + """Min_p sampler in logit space, see - https://arxiv.org/pdf/2407.01082""" + if p_base <= 0.0: + return mx.argmax(logits, axis=-1, keepdims=True) + if p_base >= 1.0: + return mx.random.categorical(logits, num_samples=1) - p_max = mx.max(probs, axis=-1, keepdims=True) - p_scaled = p_base * p_max - mask = probs >= p_scaled - - masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) - sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) - masked_probs_normalized = masked_probs / sum_masked_probs - - # Dumb workaround for mlx not having categorical probs sampler - next_token = mx.array( - torch.multinomial( - torch.from_numpy(np.array(masked_probs_normalized)), num_samples=1 - ), - dtype=mx.int32, - ) + log_p_max = mx.max(logits, axis=-1, keepdims=True) + log_p_scaled = mx.log(p_base) + log_p_max + mask = logits >= log_p_scaled + masked_logits = mx.where(~mask, -mx.inf, logits) + next_token = mx.random.categorical(masked_logits, num_samples=1) return next_token @@ -473,6 +465,8 @@ def recalc_dur_tokens_chunked( next_logits = logits[:, priming_len - idx] + logger.debug(f"Internal KV-state: {tokenizer.decode(model.get_kv_ctx())}") + return enc_seq, priming_seq, next_logits @@ -492,6 +486,7 @@ def decode_first_tokens( time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] eos_tok_id = tokenizer.tok_to_id[tokenizer.eos_tok] dim_tok_id = tokenizer.tok_to_id[tokenizer.dim_tok] + ped_off_id = tokenizer.tok_to_id[tokenizer.ped_off_tok] logits = first_token_logits time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms @@ -519,6 +514,7 @@ def decode_first_tokens( logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.ped_off_tok]] = float("-inf") # MLX doesn't have a equivalent of torch topk log_probs = nn.log_softmax(logits, axis=-1) @@ -538,11 +534,19 @@ def decode_first_tokens( logger.debug(f"Calculated top {BEAM_WIDTH} tokens={top_toks}") logger.debug(f"Calculated top {BEAM_WIDTH} scores={top_log_probs.tolist()}") - masked_onset_ids = [ - tokenizer.tok_to_id[tok] - for tok in tokenizer.onset_tokens - if tok[1] < ((time_since_first_onset_ms + buffer_ms) % 5000) - ] + priming_seq_last_onset_ms = tokenizer.calc_length_ms( + priming_seq, onset=True + ) + + if priming_seq_last_onset_ms < time_since_first_onset_ms + buffer_ms: + masked_onset_ids = [ + tokenizer.tok_to_id[tok] + for tok in tokenizer.onset_tokens + if tok[1] < ((time_since_first_onset_ms + buffer_ms) % 5000) + ] + + else: + masked_onset_ids = [] logger.debug( f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms + buffer_ms})" @@ -564,9 +568,13 @@ def decode_first_tokens( ) next_log_probs = nn.log_softmax(next_logits, axis=-1) - next_log_probs[:, masked_onset_ids] = float("-inf") + next_log_probs[:, eos_tok_id] = float("-inf") next_log_probs[:, dim_tok_id] = float("-inf") + next_log_probs[:, ped_off_id] = float("-inf") + + if masked_onset_ids: + next_log_probs[:, masked_onset_ids] = float("-inf") if tok_id == time_tok_id: next_log_probs[:, time_tok_id] = float("-inf") @@ -607,9 +615,7 @@ def decode_first_tokens( logger.info( f"Updated KV-Cache by re-inserting {best_tok_1} at position {idx-1}" ) - logger.info( - f"Inserted {best_tok_2} at position {idx} without updating KV-Cache" - ) + logger.debug(f"Internal KV-state: {tokenizer.decode(model.get_kv_ctx())}") return enc_seq, idx + 1 @@ -666,8 +672,7 @@ def decode_tokens( logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") if temperature > 0.0: - probs = mx.softmax(logits / temperature, axis=-1) - next_token_ids = sample_min_p(probs, min_p).flatten() + next_token_ids = sample_min_p(logits, min_p).flatten() else: next_token_ids = mx.argmax(logits, axis=-1).flatten() @@ -1713,6 +1718,9 @@ def listen_for_midi_control_signal( logger.debug("Exiting MIDI control listener") +# TODO: Fix issue with pedal being stuck down (send pedal off at end of stream_msgs) +# TODO: Debug, fix, and perhaps refactor the functionality for going back and forth +# - One idea is on resume, to wait to start the clock until the user plays. def run( model: TransformerLM, midi_performance_queue: queue.Queue, From cc125f3333ffbf72c5dbcaa732cea5c936a7042c Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 21 Aug 2025 15:53:12 +0100 Subject: [PATCH 16/21] add bf16 support --- aria/inference/model_mlx.py | 28 ++++++++++++++++------------ demo/demo_mlx.py | 27 +++++++++++++++++++++------ 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/aria/inference/model_mlx.py b/aria/inference/model_mlx.py index a3d5a26..51ad32e 100644 --- a/aria/inference/model_mlx.py +++ b/aria/inference/model_mlx.py @@ -84,7 +84,7 @@ def __call__( self, x: mx.array, input_pos: mx.array, - max_pos: int, + max_kv_pos: int | None, offset: int, mask: mx.array, ): @@ -92,7 +92,7 @@ def __call__( x += self._att_block( x=self.norm1(x), input_pos=input_pos, - max_pos=max_pos, + max_kv_pos=max_kv_pos, offset=offset, mask=mask, ) @@ -105,12 +105,12 @@ def get_kv( k: mx.array, v: mx.array, input_pos: mx.array, - max_pos: int | None = None, + max_kv_pos: int | None, ): k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos) - if max_pos is not None: - return k[:, :, : max_pos + 1, :], v[:, :, : max_pos + 1, :] + if max_kv_pos is not None: + return k[:, :, : max_kv_pos + 1, :], v[:, :, : max_kv_pos + 1, :] else: return k, v @@ -118,7 +118,7 @@ def _att_block( self, x: mx.array, input_pos: mx.array, - max_pos: int, + max_kv_pos: int | None, offset: int, mask: mx.array, ): @@ -135,7 +135,7 @@ def _att_block( k = apply_rotary_emb_mlx(k, offset=offset) q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v)) - k, v = self.get_kv(k, v, input_pos=input_pos, max_pos=max_pos) + k, v = self.get_kv(k, v, input_pos=input_pos, max_kv_pos=max_kv_pos) wv = mx.fast.scaled_dot_product_attention( q=q, @@ -190,8 +190,10 @@ def __call__( self, idxs: mx.array, input_pos: mx.array, + max_kv_pos: int, offset: int, pad_idxs: mx.array | None = None, + _debug_track_kv: bool = False, ): assert self.causal_mask is not None, "Caches must be initialized first" @@ -200,18 +202,18 @@ def __call__( self.model_config.max_seq_len, 3 ) # unk_tok id - max_pos = mx.max(input_pos, axis=0).item() - self.kv_ctx[input_pos] = idxs - self.kv_ctx[max_pos + 1 :] = 3 + if _debug_track_kv is True: + self.kv_ctx[input_pos] = idxs + self.kv_ctx[input_pos[-1].item() + 1 :] = 3 - mask = self.causal_mask[None, None, input_pos, : max_pos + 1] + mask = self.causal_mask[None, None, input_pos, : max_kv_pos + 1] if pad_idxs is not None: pad_mask = mx.expand_dims(mx.expand_dims(pad_idxs, axis=1), axis=1) mask = mask & ~pad_mask x = self.tok_embeddings(idxs) for layer in self.encode_layers: - x = layer(x, input_pos, max_pos, offset, mask) + x = layer(x, input_pos, max_kv_pos, offset, mask) x = self.out_layer_norm(x) @@ -238,11 +240,13 @@ def __call__( idxs: mx.array, input_pos: mx.array, offset: int, + max_kv_pos: int | None = None, pad_idxs: mx.array | None = None, ): hidden_states = self.model( idxs=idxs, input_pos=input_pos, + max_kv_pos=max_kv_pos, offset=offset, pad_idxs=pad_idxs, ) diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index c8c6ee9..4146c9e 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -8,6 +8,7 @@ import logging import threading import queue +import math import json import mido @@ -22,8 +23,9 @@ from aria.run import _get_embedding EMBEDDING_OFFSET: int = 0 -DTYPE = mx.float32 +DTYPE = mx.bfloat16 MAX_SEQ_LEN: int = 4096 +KV_CHUNK_SIZE: int = 256 PREFILL_CHUNK_SIZE_L: int = 128 PREFILL_CHUNK_SIZE: int = 16 RECALC_DUR_PREFILL_CHUNK_SIZE: int = 8 @@ -198,6 +200,8 @@ def prefill( logits = model( idxs=idxs, input_pos=input_pos + EMBEDDING_OFFSET, + max_kv_pos=math.ceil(input_pos[-1].item() / KV_CHUNK_SIZE) + * KV_CHUNK_SIZE, offset=input_pos[0] + EMBEDDING_OFFSET, ) @@ -214,6 +218,8 @@ def decode_one( logits = model( idxs=idxs, input_pos=input_pos + EMBEDDING_OFFSET, + max_kv_pos=math.ceil(input_pos[-1].item() / KV_CHUNK_SIZE) + * KV_CHUNK_SIZE, offset=input_pos[0] + EMBEDDING_OFFSET, )[:, -1] @@ -342,20 +348,23 @@ def warmup_model(model: TransformerLM): return model -def load_model( - checkpoint_path: str, -): +def load_model(checkpoint_path: str): logger = get_logger() tokenizer = AbsTokenizer() model_config = ModelConfig(**load_model_config("medium-emb")) model_config.set_vocab_size(tokenizer.vocab_size) + weights = mx.load(checkpoint_path) + for key, weight in weights.items(): + if weight.dtype != DTYPE: + weights[key] = weight.astype(DTYPE) + logging.info(f"Loading model weights from {checkpoint_path}") init_start_time_s = time.time() model = TransformerLM(model_config) - model.load_weights(checkpoint_path, strict=False) + model.load_weights(list(weights.items()), strict=False) model.eval() if args.quantize: @@ -1153,7 +1162,13 @@ def stream_midi( msgs.append(archived_msg) last_archive_time_ms = msg["epoch_time_ms"] - results_queue.put(msgs) + midi_out.send( + mido.Message( + "control_change", control=64, value=0, channel=0, time=0 + ) + ) + + results_queue.put(msgs) def stream_msgs( From 918ffcaf2778d6bb31a634bb49fd5e737b02c8c8 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 27 Aug 2025 19:12:25 +0100 Subject: [PATCH 17/21] add flag for back-and-forth mode --- demo/demo_mlx.py | 72 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index 4146c9e..ceab1af 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -9,6 +9,8 @@ import threading import queue import math +import sys +import select import json import mido @@ -104,6 +106,12 @@ def parse_args(): type=int, help="MIDI control change message context reset", ) + argp.add_argument( + "--back-and-forth", + action="store_true", + help="Enable toggling between human and AI. If not set, the control signal will reset the session.", + required=False, + ) argp.add_argument( "--temp", help="sampling temperature value", @@ -1653,22 +1661,34 @@ def listen_for_keypress_control_signal( reset_sentinel: threading.Event, ): logger = get_logger("KEYBOARD") + logger.info( + "Listening for keyboard input (Enter to start AI, any other key + Enter to reset)." + ) + while not reset_sentinel.is_set(): - time.sleep(2) - _input = input() - logger.info(f'Keypress seen "{_input}"') - if _input == "": - control_sentinel.set() - else: - reset_sentinel.set() - control_sentinel.set() - logger.debug("Exiting") - return + rlist, _, _ = select.select([sys.stdin], [], [], 0.01) + + if rlist: + _input = sys.stdin.readline().strip() + logger.info(f'Keypress seen "{_input}"') + + if _input == "": + control_sentinel.set() + else: + reset_sentinel.set() + control_sentinel.set() + logger.debug("Exiting keypress listener on reset signal.") + return + + logger.debug( + "Exiting keypress listener because reset_sentinel was set by another thread." + ) def _listen( midi_control_queue: queue.Queue, reset_sentinel: threading.Event, + currently_generating_sentinel: threading.Event, logger: logging.Logger, midi_control_signal: int | None = None, midi_reset_control_signal: int | None = None, @@ -1682,22 +1702,31 @@ def _listen( logger.info( f"Listening for takeover signal ({midi_control_signal}) and reset signal ({midi_reset_control_signal}) on control queue." ) + seen_note_on = False while not reset_sentinel.is_set(): try: msg = midi_control_queue.get(block=True, timeout=0.01) except queue.Empty: continue + if msg.type == "note_on" and msg.velocity > 0: + seen_note_on = True + + should_return_signal = ( + seen_note_on or currently_generating_sentinel.is_set() + ) if ( msg.type == "control_change" and msg.control == midi_control_signal and msg.value >= 64 + and should_return_signal ): return midi_control_signal elif ( msg.type == "control_change" and msg.control == midi_reset_control_signal and msg.value >= 64 + and should_return_signal ): return midi_reset_control_signal @@ -1706,8 +1735,10 @@ def listen_for_midi_control_signal( midi_control_queue: queue.Queue, control_sentinel: threading.Event, reset_sentinel: threading.Event, + currently_generating_sentinel: threading.Event, midi_control_signal: int | None = None, midi_reset_control_signal: int | None = None, + back_and_forth: bool = False, ): logger = get_logger("MIDI-CONTROL") @@ -1716,6 +1747,7 @@ def listen_for_midi_control_signal( signal_received = _listen( midi_control_queue=midi_control_queue, reset_sentinel=reset_sentinel, + currently_generating_sentinel=currently_generating_sentinel, midi_control_signal=midi_control_signal, midi_reset_control_signal=midi_reset_control_signal, logger=logger, @@ -1724,16 +1756,22 @@ def listen_for_midi_control_signal( if signal_received is not None: logger.info(f"Seen MIDI control signal ({signal_received})") - if signal_received == midi_control_signal: - control_sentinel.set() - elif signal_received == midi_reset_control_signal: + if signal_received == midi_reset_control_signal: + logger.info("Resetting (reset)") reset_sentinel.set() control_sentinel.set() + elif signal_received == midi_control_signal: + if ( + currently_generating_sentinel.is_set() + and back_and_forth is False + ): + logger.info("Resetting (control)") + reset_sentinel.set() + control_sentinel.set() logger.debug("Exiting MIDI control listener") -# TODO: Fix issue with pedal being stuck down (send pedal off at end of stream_msgs) # TODO: Debug, fix, and perhaps refactor the functionality for going back and forth # - One idea is on resume, to wait to start the clock until the user plays. def run( @@ -1750,6 +1788,7 @@ def run( wait_for_close: bool, temperature: float, min_p: float, + back_and_forth: bool, ): logger = get_logger() tokenizer = AbsTokenizer() @@ -1785,8 +1824,10 @@ def run( "midi_control_queue": midi_control_queue, "control_sentinel": control_sentinel, "reset_sentinel": reset_sentinel, + "currently_generating_sentinel": currently_generating_sentinel, "midi_control_signal": midi_control_signal, "midi_reset_control_signal": midi_reset_control_signal, + "back_and_forth": back_and_forth, }, ) keypress_thread.start() @@ -1950,6 +1991,7 @@ def main(args): wait_for_close=args.wait_for_close, temperature=args.temp, min_p=args.min_p, + back_and_forth=args.back_and_forth, ) reset_sentinel = threading.Event() @@ -2024,8 +2066,6 @@ def close_notes(midi_out_port: str): out.send(mido.Message("note_off", note=note, velocity=0)) -# TODO: Debug issue with incorrect tokens being generated (model or mlx issue?) -# TODO: Fix bug with context length (metal error?) if __name__ == "__main__": args = parse_args() From 2253b232cbed235059cf8fd33ce748647650f6cc Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 17 Sep 2025 15:25:00 +0100 Subject: [PATCH 18/21] add config override to demo --- demo/config.json | 51 ++++++++++++++++++++++++++++++++++++++++++++++++ demo/demo_mlx.py | 10 ++++++---- 2 files changed, 57 insertions(+), 4 deletions(-) create mode 100644 demo/config.json diff --git a/demo/config.json b/demo/config.json new file mode 100644 index 0000000..c648810 --- /dev/null +++ b/demo/config.json @@ -0,0 +1,51 @@ +{ + "tokenizer": { + "abs": { + "ignore_instruments": { + "piano": false, + "chromatic": true, + "organ": false, + "guitar": false, + "bass": false, + "strings": false, + "ensemble": false, + "brass": false, + "reed": false, + "pipe": false, + "synth_lead": false, + "synth_pad": true, + "synth_effect": true, + "ethnic": true, + "percussive": true, + "sfx": true + }, + "instrument_programs": { + "piano": 0, + "chromatic": 13, + "organ": 16, + "guitar": 24, + "bass": 32, + "strings": 40, + "ensemble": 48, + "brass": 56, + "reed": 64, + "pipe": 73, + "synth_lead": 80, + "synth_pad": 88, + "synth_effect": 96, + "ethnic": 104, + "percussive": 112, + "sfx": 120 + }, + "drum_velocity": 60, + "velocity_quantization_step": 10, + "abs_time_step_ms": 5000, + "max_dur_ms": 5000, + "time_step_ms": 10, + "include_pedal": true, + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], + "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], + "genre_names": ["jazz", "classical"] + } + } +} diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index ceab1af..e969aff 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -10,6 +10,7 @@ import queue import math import sys +import pathlib import select import json import mido @@ -46,6 +47,7 @@ VELOCITY_OUTPUT_LATENCY_MS: dict[int, int] = {v: 0 for v in range(0, 127, 10)} +config_path = pathlib.Path(__file__).parent.resolve().joinpath("config.json") file_handler = logging.FileHandler("./demo.log", mode="w") file_handler.setLevel(logging.DEBUG) @@ -359,7 +361,7 @@ def warmup_model(model: TransformerLM): def load_model(checkpoint_path: str): logger = get_logger() - tokenizer = AbsTokenizer() + tokenizer = AbsTokenizer(config_path=config_path) model_config = ModelConfig(**load_model_config("medium-emb")) model_config.set_vocab_size(tokenizer.vocab_size) @@ -1426,7 +1428,7 @@ def continuous_prefill( received_messages_queue: queue.Queue, prev_context: list[int], ): - tokenizer = AbsTokenizer() + tokenizer = AbsTokenizer(config_path=config_path) logger = get_logger("PREFILL") msg_cnt = 0 seen_sentinel = False @@ -1791,7 +1793,7 @@ def run( back_and_forth: bool, ): logger = get_logger() - tokenizer = AbsTokenizer() + tokenizer = AbsTokenizer(config_path=config_path) control_sentinel = threading.Event() currently_generating_sentinel = threading.Event() @@ -2001,7 +2003,7 @@ def playback(midi_path: str, midi_out: str, save_path: str | None = None): close_notes(midi_out) starting_epoch_time_ms = get_epoch_time_ms() - tokenizer = AbsTokenizer() + tokenizer = AbsTokenizer(config_path=config_path) tokens_queue = queue.Queue() midi_messages_queue = queue.Queue() stream_midi_results_queue = queue.Queue() From acdeb9781e3fb49ab0405ff3943262242ee55940 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 17 Sep 2025 16:45:19 +0100 Subject: [PATCH 19/21] minor qol changes --- demo/demo_mlx.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index e969aff..9e08dbf 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -109,7 +109,7 @@ def parse_args(): help="MIDI control change message context reset", ) argp.add_argument( - "--back-and-forth", + "--back_and_forth", action="store_true", help="Enable toggling between human and AI. If not set, the control signal will reset the session.", required=False, @@ -374,6 +374,11 @@ def load_model(checkpoint_path: str): init_start_time_s = time.time() model = TransformerLM(model_config) + + assert ( + tokenizer.vocab_size == weights["model.tok_embeddings.weight"].shape[0] + ), "Embedding shape mismatch. Ensure that you are loading the demo-specific checkpoint." + model.load_weights(list(weights.items()), strict=False) model.eval() @@ -1661,6 +1666,8 @@ def _send_delayed_message(_midi_performance_queue: queue.Queue, msg): def listen_for_keypress_control_signal( control_sentinel: threading.Event, reset_sentinel: threading.Event, + currently_generating_sentinel: threading.Event, + back_and_forth: bool = False, ): logger = get_logger("KEYBOARD") logger.info( @@ -1675,12 +1682,17 @@ def listen_for_keypress_control_signal( logger.info(f'Keypress seen "{_input}"') if _input == "": + if ( + currently_generating_sentinel.is_set() + and back_and_forth is False + ): + logger.info("Resetting (control)") + reset_sentinel.set() control_sentinel.set() else: + logger.info("Resetting (reset)") reset_sentinel.set() control_sentinel.set() - logger.debug("Exiting keypress listener on reset signal.") - return logger.debug( "Exiting keypress listener because reset_sentinel was set by another thread." @@ -1818,7 +1830,12 @@ def run( keypress_thread = threading.Thread( target=listen_for_keypress_control_signal, - args=[control_sentinel, reset_sentinel], + kwargs={ + "control_sentinel": control_sentinel, + "reset_sentinel": reset_sentinel, + "currently_generating_sentinel": currently_generating_sentinel, + "back_and_forth": back_and_forth, + }, ) midi_control_thread = threading.Thread( target=listen_for_midi_control_signal, @@ -1949,6 +1966,7 @@ def forward_midi_input_port( def main(args): + logger = get_logger() model = load_model(checkpoint_path=args.checkpoint) model = warmup_model(model=model) if args.embedding_checkpoint and args.embedding_midi_path: @@ -1960,6 +1978,7 @@ def main(args): assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in + logger.info(f"Available MIDI ports: {mido.get_output_names()}") midi_performance_queue = queue.Queue() midi_control_queue = queue.Queue() From 80e6a43f907c452a4005c82b42e2a1c0aea3a823 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 17 Sep 2025 16:45:31 +0100 Subject: [PATCH 20/21] update README --- README.md | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index c50a923..1613777 100644 --- a/README.md +++ b/README.md @@ -79,13 +79,25 @@ Our embedding model was trained to capture composition-level and performance-lev ## Real-time demo -In `demo/` we provide CUDA (Linux/PyTorch) and MLX (Apple Silicon) implementations of the real-time interactive piano-continuation demo showcased in our release blog post. For the demo we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface. +In `demo/` we provide an MLX (Apple Silicon) implementation of the real-time interactive piano-continuation demo showcased in our release blog post. In order to use the demo, you must download the demo-specific model checkpoint which enhances the model to additionally control the sustain pedal ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model-demo.safetensors?download=true)). -❗**NOTE**: Responsiveness of the real-time demo is dependent on your system configuration, e.g., GPU FLOPS and memory bandwidth. +For our demonstration, we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface. We disabled the built-in Disklavier playback mode, instead manually calibrating key-velocity latency to enhance responsiveness. You may recreate this in your own environment with our acoustic calibration settings, using the following script: -A MIDI input device is not strictly required to play around with the demo: By using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI, ALSA) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output. +❗**NOTE**: It is vital that you use the `latency=off`/`realtime` Disklavier playback setting when using the provided configuration for `--hardware`. -Example usage (MLX): +```bash +python ./demo/demo_mlx.py \ + --checkpoint \ + --midi_in \ + --midi_out \ + --hardware ./demo/hardware/c4dm-disklavier.json \ + --midi_control_signal 67 \ + --midi_reset_control_signal 66 \ + --temp 0.9 \ + --min_p 0.03 +``` + +A MIDI input device is not strictly required to play around with the demo: By using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output. In this mode, you can initiate the model takeover by pressing the enter key. ```bash MIDI_PATH="example-prompts/pokey_jazz.mid" @@ -93,13 +105,14 @@ MIDI_PATH="example-prompts/pokey_jazz.mid" python demo/demo_mlx.py \ --checkpoint \ --midi_path ${MIDI_PATH} \ - --midi_through \ - --midi_out \ - --save_path \ - --temp 0.98 \ - --min_p 0.035 + --midi_through \ + --midi_out \ + --temp 0.9 \ + --min_p 0.03 ``` +❗**NOTE**: Responsiveness of the real-time demo is dependent on your system configuration, specifically GPU memory bandwidth. + ## Evaluation We provide the specific files/splits we used for Aria-MIDI derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/eval-splits.tar.gz?download=true)). Class labels are provided in `metadata.json` with the schema: From f993878dfd38e096f878dd7c454acc4f9c23909f Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 17 Sep 2025 16:57:05 +0100 Subject: [PATCH 21/21] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1613777..d9f497c 100644 --- a/README.md +++ b/README.md @@ -100,9 +100,9 @@ python ./demo/demo_mlx.py \ A MIDI input device is not strictly required to play around with the demo: By using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output. In this mode, you can initiate the model takeover by pressing the enter key. ```bash -MIDI_PATH="example-prompts/pokey_jazz.mid" +MIDI_PATH="./example-prompts/smooth_jazz.mid" -python demo/demo_mlx.py \ +python ./demo/demo_mlx.py \ --checkpoint \ --midi_path ${MIDI_PATH} \ --midi_through \