diff --git a/Cargo.toml b/Cargo.toml index 7b1757f..7054b9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,10 @@ name = "aetherlink" path = "src/main.rs" [dependencies] -# Core dependencies -iroh = { version = "0.28.1", features = ["metrics"] } +# Core dependencies - using same version as pirohxy +iroh = "0.28.1" +iroh-base = "0.28.1" +iroh-net = "0.28.1" tokio = { version = "1.41", features = ["full"] } anyhow = "1.0" tracing = "0.1" @@ -37,11 +39,11 @@ bytes = "1.8" # Utilities directories = "5.0" -ed25519-dalek = { version = "2.1", features = ["pem"] } -rand_core = { version = "0.6", features = ["getrandom"] } +ed25519-dalek = { version = "2.1", features = ["pem", "pkcs8"] } futures-util = "0.3" url = "2.5" base64 = "0.22" +shellexpand = "3.1" [dev-dependencies] tempfile = "3.10" diff --git a/src/client.rs b/src/client.rs index 3750e6e..4eaa431 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,8 +4,8 @@ use http_body_util::{BodyExt, Full}; use hyper::client::conn::http1; use hyper::server::conn::http1 as server_http1; use hyper::service::service_fn; -use hyper::{Request, Response, StatusCode, Uri}; -use iroh::Endpoint; +use hyper::{Request, Response, StatusCode}; +use iroh_net::endpoint::Endpoint; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; @@ -23,21 +23,26 @@ pub async fn create_tunnel( local_port: u16, bind_addr: SocketAddr, ) -> Result<()> { + use iroh_base::key::NodeId; + use std::str::FromStr; + // Parse server node ID - let server_node_id = server_id.parse() + let server_node_id = NodeId::from_str(&server_id) .context("Invalid server node ID")?; // Start Iroh endpoint let endpoint = Endpoint::builder() .secret_key(identity.secret_key.clone()) - .discovery_n0() - .bind() + .alpns(vec![TUNNEL_ALPN.to_vec()]) + .bind(0) .await?; info!("Connecting to server: {}", server_id); - // Connect to server - let conn = endpoint.connect(server_node_id, TUNNEL_ALPN).await + // Connect to server + use iroh_net::NodeAddr; + let node_addr = NodeAddr::new(server_node_id); + let conn = endpoint.connect(node_addr, &TUNNEL_ALPN).await .context("Failed to connect to server")?; // Register tunnel @@ -127,7 +132,7 @@ pub async fn create_tunnel( async fn handle_client_request( stream: TcpStream, - _conn: Arc, + _conn: Arc, domain: String, local_port: u16, ) -> Result<()> { @@ -223,19 +228,24 @@ pub async fn list_tunnels( identity: Identity, server_id: String, ) -> Result<()> { + use iroh_base::key::NodeId; + use std::str::FromStr; + // Parse server node ID - let server_node_id = server_id.parse() + let server_node_id = NodeId::from_str(&server_id) .context("Invalid server node ID")?; // Start Iroh endpoint let endpoint = Endpoint::builder() .secret_key(identity.secret_key.clone()) - .discovery_n0() - .bind() + .alpns(vec![TUNNEL_ALPN.to_vec()]) + .bind(0) .await?; // Connect to server - let conn = endpoint.connect(server_node_id, TUNNEL_ALPN).await + use iroh_net::NodeAddr; + let node_addr = NodeAddr::new(server_node_id); + let conn = endpoint.connect(node_addr, &TUNNEL_ALPN).await .context("Failed to connect to server")?; // Request tunnel list diff --git a/src/config.rs b/src/config.rs index 623c50f..364b472 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,7 @@ use anyhow::{Context, Result}; -use ed25519_dalek::{SigningKey, pkcs8::{DecodePrivateKey, EncodePrivateKey}}; -use iroh::SecretKey; -use rand_core::OsRng; +use ed25519_dalek::SigningKey; +use ed25519_dalek::pkcs8::{DecodePrivateKey, EncodePrivateKey, LineEnding}; +use iroh_base::key::SecretKey; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::Path; @@ -76,7 +76,7 @@ impl Default for Config { impl Identity { pub fn generate() -> Self { - let secret_key = SecretKey::generate(&mut OsRng); + let secret_key = SecretKey::generate(); Self { secret_key } } @@ -89,14 +89,15 @@ impl Identity { .with_context(|| format!("Failed to read identity file: {:?}", path))?; let signing_key = SigningKey::from_pkcs8_pem(&pem) .context("Failed to parse identity key")?; - Ok(Self { - secret_key: signing_key.into(), - }) + let bytes = signing_key.to_bytes(); + let secret_key = SecretKey::from_bytes(&bytes); + Ok(Self { secret_key }) } pub fn to_file(&self, path: &Path) -> Result<()> { - let signing_key: SigningKey = self.secret_key.secret().clone(); - let pem = signing_key.to_pkcs8_pem(ed25519_dalek::pkcs8::spki::der::pem::LineEnding::default()) + let bytes = self.secret_key.to_bytes(); + let signing_key = SigningKey::from_bytes(&bytes); + let pem = signing_key.to_pkcs8_pem(LineEnding::default()) .context("Failed to encode identity key")?; std::fs::write(path, pem.as_bytes()) .with_context(|| format!("Failed to write identity file: {:?}", path))?; @@ -113,8 +114,7 @@ mod secret_key_serde { where S: Serializer, { - let signing_key: SigningKey = key.secret().clone(); - let bytes = signing_key.to_bytes(); + let bytes = key.to_bytes(); let encoded = STANDARD.encode(&bytes); serializer.serialize_str(&encoded) } @@ -126,10 +126,9 @@ mod secret_key_serde { let encoded = String::deserialize(deserializer)?; let bytes = STANDARD.decode(&encoded) .map_err(serde::de::Error::custom)?; - let signing_key = SigningKey::from_bytes(&bytes.try_into().map_err(|_| { - serde::de::Error::custom("Invalid key length") - })?); - Ok(signing_key.into()) + let bytes_32: [u8; 32] = bytes.try_into() + .map_err(|_| serde::de::Error::custom("Invalid key length"))?; + Ok(SecretKey::from_bytes(&bytes_32)) } } diff --git a/src/main.rs b/src/main.rs index 5db131e..88ca475 100644 --- a/src/main.rs +++ b/src/main.rs @@ -209,17 +209,24 @@ async fn main() -> Result<()> { } Commands::AddServer { name, node_id } => { - let server_id = node_id.parse() - .context("Invalid node ID format")?; + use iroh_base::key::NodeId; + use std::str::FromStr; + + let server_id = NodeId::from_str(&node_id) + .context("Invalid node ID format")? + .to_string(); config.servers.insert(name.clone(), server_id); config.save(&config_path)?; - info!("✓ Added server alias '{}' → {}", name, server_id); + info!("✓ Added server alias '{}' → {}", name, node_id); } Commands::Authorize { client_id } => { - let client_node_id = client_id.parse() + use iroh_base::key::NodeId; + use std::str::FromStr; + + let client_node_id = NodeId::from_str(&client_id) .context("Invalid client node ID")?; let auth_file = config_path.join("auth").join(&client_id); @@ -241,15 +248,4 @@ async fn main() -> Result<()> { } Ok(()) -} - -mod shellexpand { - pub fn tilde(s: &str) -> std::borrow::Cow { - if s.starts_with("~/") { - if let Ok(home) = std::env::var("HOME") { - return std::borrow::Cow::Owned(s.replacen("~", &home, 1)); - } - } - std::borrow::Cow::Borrowed(s) - } } \ No newline at end of file diff --git a/src/server.rs b/src/server.rs index 4075e47..a3c3638 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,14 +1,13 @@ -use anyhow::{Context, Result}; +use anyhow::Result; use bytes::Bytes; use http_body_util::{BodyExt, Full}; use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{Method, Request, Response, StatusCode}; -use iroh::protocol::{Router, ProtocolHandler}; -use iroh::{Endpoint, NodeId}; +use iroh_net::endpoint::{Endpoint, Connection}; +use iroh_base::key::NodeId; use std::collections::HashMap; use std::net::SocketAddr; -use std::path::Path; use std::sync::Arc; use tokio::net::TcpListener; use tokio::sync::RwLock; @@ -30,22 +29,33 @@ pub async fn run_server( // Start Iroh endpoint let endpoint = Endpoint::builder() .secret_key(identity.secret_key.clone()) - .discovery_n0() - .bind() + .alpns(vec![TUNNEL_ALPN.to_vec()]) + .bind(0) .await?; info!("Server listening on Iroh network"); info!("Node ID: {}", identity.node_id()); - // Start tunnel protocol handler - let handler = TunnelHandler { - state: state.clone(), - auth: auth.clone(), - }; - - let router = Router::builder(endpoint.clone()) - .accept(TUNNEL_ALPN, handler) - .spawn(); + // Handle incoming connections + let accept_state = state.clone(); + let accept_auth = auth.clone(); + let accept_endpoint = endpoint.clone(); + tokio::spawn(async move { + loop { + match accept_endpoint.accept().await { + Some(incoming) => { + let state = accept_state.clone(); + let auth = accept_auth.clone(); + tokio::spawn(async move { + if let Ok(conn) = incoming.accept().await { + handle_connection(conn, state, auth).await; + } + }); + } + None => break, + } + } + }); // Start admin API let admin_listener = TcpListener::bind(admin_bind).await?; @@ -80,10 +90,93 @@ pub async fn run_server( tokio::signal::ctrl_c().await?; info!("Shutting down server..."); - router.shutdown().await?; Ok(()) } +async fn handle_connection( + conn: Connection, + state: Arc, + auth: Arc, +) { + let client_id = conn.remote_node_id(); + + // Check authorization + if !auth.is_authorized(&client_id.to_string()) { + warn!("Unauthorized client attempted connection: {}", client_id); + return; + } + + debug!("Accepted connection from {}", client_id); + + // Handle tunnel requests + loop { + match conn.accept_bi().await { + Ok((mut send, mut recv)) => { + // Read tunnel message + let mut buf = Vec::new(); + if recv.read_to_end(1024 * 1024, &mut buf).await.is_err() { + break; + } + + match serde_json::from_slice::(&buf) { + Ok(msg) => match msg { + TunnelMessage::Register { domain, port } => { + match state.register_tunnel(domain.clone(), client_id, port).await { + Ok(_) => { + let response = TunnelMessage::Registered { domain }; + if let Ok(data) = serde_json::to_vec(&response) { + let _ = send.write_all(&data).await; + let _ = send.finish(); + } + } + Err(e) => { + let response = TunnelMessage::Error { + message: e.to_string() + }; + if let Ok(data) = serde_json::to_vec(&response) { + let _ = send.write_all(&data).await; + let _ = send.finish(); + } + } + } + } + TunnelMessage::Unregister { domain } => { + state.unregister_tunnel(&domain).await; + let response = TunnelMessage::Unregistered { domain }; + if let Ok(data) = serde_json::to_vec(&response) { + let _ = send.write_all(&data).await; + let _ = send.finish(); + } + } + TunnelMessage::List => { + let tunnels = state.list_tunnels().await; + let domains: Vec = tunnels.iter() + .filter(|t| t.client_id == client_id) + .map(|t| t.domain.clone()) + .collect(); + let response = TunnelMessage::TunnelList { tunnels: domains }; + if let Ok(data) = serde_json::to_vec(&response) { + let _ = send.write_all(&data).await; + let _ = send.finish(); + } + } + _ => { + warn!("Unexpected message from client"); + } + }, + Err(e) => { + error!("Failed to parse tunnel message: {}", e); + } + } + } + Err(e) => { + debug!("Connection closed: {}", e); + break; + } + } + } +} + #[derive(Debug)] struct ServerState { tunnels: Arc>>, @@ -140,101 +233,6 @@ impl ServerState { } } -struct TunnelHandler { - state: Arc, - auth: Arc, -} - -impl ProtocolHandler for TunnelHandler { - fn accept( - &self, - conn: iroh::endpoint::Connection, - ) -> impl std::future::Future> + Send { - let state = self.state.clone(); - let auth = self.auth.clone(); - - async move { - let client_id = conn.remote_node_id()?; - - // Check authorization - if !auth.is_authorized(&client_id.to_string()) { - warn!("Unauthorized client attempted connection: {}", client_id); - return Err(iroh::protocol::AcceptError::User { - source: anyhow::anyhow!("Unauthorized client").into(), - }); - } - - debug!("Accepted connection from {}", client_id); - - // Handle tunnel requests - loop { - match conn.accept_bi().await { - Ok((send, mut recv)) => { - // Read tunnel message - let mut buf = Vec::new(); - recv.read_to_end(1024 * 1024, &mut buf).await?; - - match serde_json::from_slice::(&buf) { - Ok(msg) => { - match msg { - TunnelMessage::Register { domain, port } => { - match state.register_tunnel(domain.clone(), client_id, port).await { - Ok(_) => { - let response = TunnelMessage::Registered { domain }; - let data = serde_json::to_vec(&response)?; - send.write_all(&data).await?; - send.finish()?; - } - Err(e) => { - let response = TunnelMessage::Error { - message: e.to_string() - }; - let data = serde_json::to_vec(&response)?; - send.write_all(&data).await?; - send.finish()?; - } - } - } - TunnelMessage::Unregister { domain } => { - state.unregister_tunnel(&domain).await; - let response = TunnelMessage::Unregistered { domain }; - let data = serde_json::to_vec(&response)?; - send.write_all(&data).await?; - send.finish()?; - } - TunnelMessage::List => { - let tunnels = state.list_tunnels().await; - let domains: Vec = tunnels.iter() - .filter(|t| t.client_id == client_id) - .map(|t| t.domain.clone()) - .collect(); - let response = TunnelMessage::TunnelList { tunnels: domains }; - let data = serde_json::to_vec(&response)?; - send.write_all(&data).await?; - send.finish()?; - } - _ => { - warn!("Unexpected message from client"); - } - } - } - Err(e) => { - error!("Failed to parse tunnel message: {}", e); - } - } - } - Err(e) => { - debug!("Connection closed: {}", e); - break; - } - } - } - - Ok(()) - } - } -} - async fn handle_admin_request( state: Arc, req: Request, diff --git a/src/tunnel.rs b/src/tunnel.rs index 1e066e4..8669724 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; /// ALPN protocol identifier for AetherLink tunnels -pub const TUNNEL_ALPN: &[u8] = b"/aetherlink/tunnel/1.0.0"; +pub const TUNNEL_ALPN: &[u8] = b"aetherlink/tunnel/1.0.0"; /// Messages exchanged between client and server #[derive(Debug, Clone, Serialize, Deserialize)]