diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..5c98b42 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,2 @@ +# Default ignored files +/workspace.xml \ No newline at end of file diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml new file mode 100644 index 0000000..a55e7a1 --- /dev/null +++ b/.idea/codeStyles/codeStyleConfig.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/.idea/dictionaries/rsouvik.xml b/.idea/dictionaries/rsouvik.xml new file mode 100644 index 0000000..006c293 --- /dev/null +++ b/.idea/dictionaries/rsouvik.xml @@ -0,0 +1,7 @@ + + + + webutils + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..28a804d --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..ecb10da --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/simple-chat.iml b/.idea/simple-chat.iml new file mode 100644 index 0000000..d6ebd48 --- /dev/null +++ b/.idea/simple-chat.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..693f8ff --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "chat_app" +version = "0.1.0" +edition = "2021" + +#[package.metadata.release] +#release = false + +[profile.release] +debug = true + +[target.x86_64-unknown-linux-gnu] +linker = "/usr/bin/clang" +rustflags = ["-Clink-arg=-fuse-ld=lld", "-Clink-arg=-Wl,--no-rosegment"] + +[dependencies] +futures-util = { version = "0.3.30", features = ["sink"] } +http = "1.1.0" +#futures-util = "0.3" +tokio = { version = "1.40.0", features = ["full"] } +tokio-websockets = { version = "0.8.0", features = ["client", "fastrand", "server", "sha1_smol"] } +structopt = { version = "0.3", default-features = false } +futures = "0.3.30" +diesel = { version = "2.2.8", features = ["postgres", "mysql"] } +dotenvy = "0.15.7" +actix-web = "4.5.1" +serde = {version = "1.0.27", features = ["derive"] } +serde_json = "1.0.9" \ No newline at end of file diff --git a/src/bin/client.rs b/src/bin/client.rs new file mode 100644 index 0000000..3d651c9 --- /dev/null +++ b/src/bin/client.rs @@ -0,0 +1,110 @@ +use futures_util::stream::StreamExt; +use futures_util::SinkExt; +use http::Uri; +use structopt::StructOpt; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio_websockets::{ClientBuilder, Message}; + +#[derive(StructOpt)] +struct Cli { + #[structopt(short, long, default_value = "127.0.0.1")] + host: String, + + #[structopt(short, long, default_value = "2000")] + port: String, + + #[structopt(short, long)] + username: String, +} + +#[tokio::main] +async fn main() -> Result<(), tokio_websockets::Error> { + let args = Cli::from_args(); + let ws_url = format!("ws://{}:{}/", args.host, args.port); + + let (mut ws_stream, _) = ClientBuilder::from_uri(Uri::from_maybe_shared(ws_url).unwrap()) + .connect() + .await?; + + let stdin = tokio::io::stdin(); + let mut stdin = BufReader::new(stdin).lines(); + + // Send the join message immediately + ws_stream + .send(Message::text(format!("/join {}", args.username))) + .await?; + + loop { + tokio::select! { + incoming = ws_stream.next() => { + match incoming { + Some(Ok(msg)) => { + if let Some(text) = msg.as_text() { + println!("From server: {}", text); + } + }, + Some(Err(err)) => return Err(err.into()), + None => return Ok(()), + } + } + + res = stdin.next_line() => { + match res { + Ok(None) => return Ok(()), + Ok(Some(line)) => { + if line.starts_with("send ") { + let msg = line[5..].to_string(); + ws_stream.send(Message::text(msg)).await?; + } else if line == "leave" { + ws_stream.send(Message::text("/leave".to_string())).await?; + return Ok(()); + } else { + println!("Unknown command."); + } + } + Err(err) => return Err(err.into()), + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures_util::stream; + use tokio_websockets::Message; + use futures::*; + + #[tokio::test] + async fn test_handle_server_messages() { + let messages = vec![ + Message::text("INFO: Welcome to the chat"), + Message::text("ERROR: Username already taken"), + Message::text("Hello from another user"), + ]; + + let ws_stream = stream::iter(messages); + + let mut results = vec![]; + + ws_stream + .for_each(|msg| { + if let Some(text) = msg.as_text() { + if text.starts_with("ERROR:") { + results.push(format!("Error from server: {}", &text[7..])); + } else if text.starts_with("INFO:") { + results.push(format!("Info: {}", &text[6..])); + } else { + results.push(format!("From server: {}", text)); + } + } + futures::future::ready(()) + }) + .await; + + assert_eq!(results[0], "Info: Welcome to the chat"); + assert_eq!(results[1], "Error from server: Username already taken"); + assert_eq!(results[2], "From server: Hello from another user"); + } +} \ No newline at end of file diff --git a/src/bin/server.rs b/src/bin/server.rs new file mode 100644 index 0000000..f8a6300 --- /dev/null +++ b/src/bin/server.rs @@ -0,0 +1,195 @@ + +use futures_util::sink::SinkExt; +use futures_util::stream::StreamExt; +use std::collections::HashMap; +use std::error::Error; +use std::net::SocketAddr; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{broadcast::{channel, Sender}, Mutex}; +use tokio_websockets::{Message, ServerBuilder, WebSocketStream}; +use std::sync::Arc; +use actix_web::{web, App, HttpServer, Responder}; +use serde::{Serialize, Deserialize}; +extern crate actix_web; + +//mod webutils; + +use chat_app::webutils::{statsall, index, indexPost, User, ServerState}; + +async fn handle_connection( + addr: SocketAddr, + mut ws_stream: WebSocketStream, + state: Arc, +) -> Result<(), Box> { + // Send a welcome message + ws_stream.send(Message::text("Welcome to chat! Type '/join ' to join.".to_string())).await?; + + let mut bcast_rx = state.bcast_tx.subscribe(); + let mut username: Option = None; + + loop { + tokio::select! { + // Handle incoming messages from the client + incoming = ws_stream.next() => { + match incoming { + Some(Ok(msg)) => { + if let Some(text) = msg.as_text() { + if text.starts_with("/join ") { + // Handle user joining + let new_username = text[6..].trim().to_string(); + let mut users = state.users.lock().await; + + //if users.values().any(|name| &name.0 == &new_username) { + if users.values().any(|name| &name.username == &new_username) { + ws_stream.send(Message::text("Username already taken.".to_string())).await?; + } else { + if users.contains_key(&addr) == true { + /*if let Some(tuple_ref) = users.get(&addr) { + if let Some(fe) = tuple_ref.as_ref().map(|t| &t.0) + && let Some(se) = tuple_ref.as_ref().map(|t| &t.1) { + users.insert(addr, (fe,se+1)); + } + }*/ + /*if let Some(fv) = users.get(&addr).cloned().map(|t| &t.0){ + let sname = fv; + if let Some(sv) = users.get(&addr).cloned().map(|t| &t.1){ + let count = sv; + users.insert(addr, (sname.clone(),count+1)); + } + }*/ + + if let Some(U) = users.get(&addr).cloned() { + //users.insert(addr, (sname.clone(), count + 1)); + //users.insert(addr, User(sname.clone(), addr, count+1)); + users.insert(addr, User{username: U.username, addr: addr, lifetime_cnt: U.lifetime_cnt+1}); + } + /*users.entry(addr). + .and_modify(|entry| entry.1+=1) + .or_insert("default_name".to_string(),1);*/ + //users.insert(addr, ((users.get(&addr)).as_ref().0,(users.get(&addr)).as_ref().1+1)); + } + else { + users.insert(addr, User{username:new_username.clone(), addr: addr, lifetime_cnt: 0}); + } + ws_stream.send(Message::text(format!("Joined as {}", new_username))).await?; + state.bcast_tx.send(format!("{} has joined the chat.", new_username))?; + username = Some(new_username); + } + } else if text == "/leave" { + // Handle user leaving + if let Some(name) = username.take() { + let mut users = state.users.lock().await; + users.remove(&addr); + state.bcast_tx.send(format!("{} has left the chat.", name))?; + return Ok(()); + } + } else if let Some(_) = username { + // Broadcast regular messages + state.broadcast_message(&addr, text.into()).await; + } else { + ws_stream.send(Message::text("Please join with '/join ' first.".to_string())).await?; + } + } + } + Some(Err(err)) => return Err(err.into()), + None => return Ok(()), + } + } + + // Handle messages from the broadcast channel + msg = bcast_rx.recv() => { + ws_stream.send(Message::text(msg?)).await?; + } + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let (bcast_tx, _) = channel(16); + let state = Arc::new(ServerState { + users: Arc::new(Mutex::new(HashMap::new())), + bcast_tx: bcast_tx.clone(), + }); + + //web data processing + //let (web_sender, mut web_receiver) = mpsc::channel::(100); + + let shared_state = web::Data::new(state.clone()); + //Web server + //Start the web server + let server = HttpServer::new(move || { + App::new() + //.app_data(received_data.clone()) + .app_data(shared_state.clone()) + //.app_data(swarm_controller.clone()) + //.route("/", web::post().to(receive_data)) + .route("/statsuser", web::get().to(index)) + .route("/statsall", web::get().to(statsall)) + .route("/updateuser", web::post().to(indexPost)) + }) + //.bind("127.0.0.1:8080")? + .bind("0.0.0.0:8080")? + .run(); + //.await; + + // Start the event loop + tokio::spawn(server); + + let listener = TcpListener::bind("127.0.0.1:2000").await?; + println!("Listening on port 2000"); + + loop { + let (socket, addr) = listener.accept().await?; + println!("socket: {:?}", socket); + println!("address: {:?}", addr); + let conn_state = state.clone(); + + tokio::spawn(async move { + let ws_stream = ServerBuilder::new().accept(socket).await?; + handle_connection(addr, ws_stream, conn_state).await + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::sync::broadcast; + use tokio_websockets::{Message, ServerBuilder}; + use tokio::net::{TcpListener, TcpStream}; + use std::net::SocketAddr; + + #[tokio::test] + async fn test_user_join() { + let (bcast_tx, _) = broadcast::channel(16); + let state = Arc::new(ServerState { + users: Mutex::new(HashMap::new()), + bcast_tx: bcast_tx.clone(), + }); + + // Create a TCP listener for the server + let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap(); + + // Simulate a client connecting to the server + tokio::spawn(async move { + let (client_socket, _) = listener.accept().await.unwrap(); + let (mut ws_stream, _) = ServerBuilder::new().accept(client_socket).await.unwrap(); // Correctly destructuring the tuple + handle_connection("127.0.0.1:8080".parse().unwrap(), ws_stream, state.clone()).await.unwrap(); + }); + + // Create a client-side TCP stream + let client_socket = TcpStream::connect("127.0.0.1:8080").await.unwrap(); + let (mut client_ws_stream, _) = ServerBuilder::new().accept(client_socket).await.unwrap(); // Correctly destructuring the tuple + + // Simulate client sending a message + let msg = Message::text("/join username"); + client_ws_stream.send(msg).await.unwrap(); // Client sends a join message + + // Simulate receiving the broadcasted message from the server + if let Some(Ok(received)) = client_ws_stream.next().await { + assert!(received.as_text().unwrap().contains("username has joined the chat")); + } + } +} + diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c8ac58c --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,15 @@ +use diesel::prelude::*; +use dotenvy::dotenv; +use std::env; + +pub(crate) mod models; +pub(crate) mod schema; +pub mod webutils; + +pub fn establish_connection() -> PgConnection { + dotenv().ok(); + + let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + PgConnection::establish(&database_url) + .unwrap_or_else(|_| panic!("Error connecting to {}", database_url)) +} \ No newline at end of file diff --git a/src/models.rs b/src/models.rs new file mode 100644 index 0000000..d823997 --- /dev/null +++ b/src/models.rs @@ -0,0 +1,21 @@ +use diesel::prelude::*; + +#[derive(Queryable, Selectable)] +#[diesel(table_name = crate::schema::posts)] +#[diesel(check_for_backend(diesel::pg::Pg))] +pub struct Post { + pub id: i32, + pub title: String, + pub body: String, + pub published: bool, +} + +#[derive(Queryable, Selectable)] +#[diesel(table_name = crate::schema::users)] +#[diesel(check_for_backend(diesel::pg::Pg))] +pub struct User { + pub id: i32, + pub username: String, + pub addr: String, + //ws_stream: WebSocketStream, // do we need this? +} \ No newline at end of file diff --git a/src/schema.rs b/src/schema.rs new file mode 100644 index 0000000..9b61917 --- /dev/null +++ b/src/schema.rs @@ -0,0 +1,16 @@ +diesel::table! { + posts (id) { + id -> Int4, + title -> Varchar, + body -> Text, + published -> Bool, + } +} + +diesel::table! { + users (id) { + id -> Int4, + username -> Varchar, + addr -> Varchar, + } +} \ No newline at end of file diff --git a/src/webutils.rs b/src/webutils.rs new file mode 100644 index 0000000..1ebeefa --- /dev/null +++ b/src/webutils.rs @@ -0,0 +1,129 @@ + +use std::collections::HashMap; +use std::net::SocketAddr; +use actix_web::{web, Responder, HttpResponse}; +//use crate::SwarmWebMessage; +use tokio::{sync::mpsc}; +use serde::{Serialize, Deserialize}; +use tokio::sync::{broadcast::{channel, Sender}, Mutex}; +use std::sync::Arc; + +//#[path = "../src/bin/server.rs"] +//mod server; + +#[derive(Deserialize)] +pub struct MyQueryParams { + user_name: String, +} + +// Structure to hold user data +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct User { + pub username: String, + pub addr: SocketAddr, + //ws_stream: WebSocketStream, // do we need this? + pub lifetime_cnt: i32, +} + +// Structure to hold server state +#[derive(Debug, Clone)] +pub struct ServerState { + pub users: Arc>>, + //pub users: Mutex> , // map addr to username + //users: Mutex> , // map addr to username + pub bcast_tx: Sender, // broadcast channel for sending messages to all users +} + +//Make sure to broadcast to all others except sender +impl ServerState { + pub async fn broadcast_message(&self, addr: &SocketAddr, message: String) { + let users = self.users.lock().await; + if let Some(sname) = users.get(addr).map(|t| &t.username){ + let sender_name = sname; + let full_msg = format!("{}: {}", sender_name, message); + + for (user_addr, _) in users.iter() { + if user_addr != addr { + self.bcast_tx.send(full_msg.clone()).unwrap(); + } + } + } + /*let sender_name = users.get(addr).as_ref().0.unwrap(); + let full_msg = format!("{}: {}", sender_name, message); + + for (user_addr, _) in users.iter() { + if user_addr != addr { + self.bcast_tx.send(full_msg.clone()).unwrap(); + } + }*/ + } +} + +//get handler +pub async fn statsall(query: web::Query, state: web::Data>) -> impl Responder { + let users_map = state.users.lock().await; // be careful with unwrap + let count = users_map.len(); + let s_users_map = users_map.clone(); + for (sock_add, U) in s_users_map.into_iter() { + println!("{} / {}", U.username, U.lifetime_cnt); + } + HttpResponse::Ok().body(format!("Total users: {}", count)) + //HttpResponse::Ok().body(format!("Data sent to libp2p swarm: {}", users.keys().len())) +} + +//get handler +pub async fn index(query: web::Query) -> impl Responder { + let name = &query.user_name; + + HttpResponse::Ok().body(format!("Data sent to libp2p swarm: {}", name)) + + //let sender_clone = sender.get_ref().clone(); + + // Send the query parameter data to the libp2p swarm through the sender channel + /*if sender_clone.send(SwarmWebMessage::DataGet(name.to_owned())).await.is_ok() { + HttpResponse::Ok().body(format!("Data sent to libp2p swarm: {}", name)) +} else { + HttpResponse::InternalServerError().body("Failed to send data to libp2p swarm") +}*/ +} + +pub async fn indexPost(query: web::Json) -> impl Responder { + /*let model_type = &query.mtype; +let model_loc = &query.mlocation; +let model_data_loc = &query.mdataloc; +let model_algo = &query.malgo; +let model_output_loc = &query.moutputloc;*/ + + HttpResponse::Ok().body("Data sent to libp2p swarm") + //let sender_clone = sender.get_ref().clone(); + + /*let metadata = serde_json::json!( + { + "mtype": model_type, + "mlocation": model_loc, //e.g. s3://picxelate/dgp/create_model.py + "mdataloc": model_data_loc, + "malgo": model_algo, + "moutputloc": model_output_loc, + } + );*/ + //e.g. s3://picxelate/dgp/create_model.py + /*let metadata = r#"{ + "mtype": model_type, + "mlocation": model_loc, + "mdataloc": model_data_loc, + "malgo": model_algo, + "moutputloc": model_output_loc, + }"#;*/ + + //do metadata check about known peers + + // Send the query parameter data to the libp2p swarm through the sender channel + //if sender_clone.send(SwarmWebMessage::Data(model_loc.to_owned())).await.is_ok() { + /*if sender_clone.send(SwarmWebMessage::Data(query)).await.is_ok() { + //HttpResponse::Ok().body(format!("Data sent to libp2p swarm: {}", model_loc)) + HttpResponse::Ok().body("Data sent to libp2p swarm") + +} else { + HttpResponse::InternalServerError().body("Failed to send data to libp2p swarm") +}*/ +} diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..1ca3867 --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,76 @@ +use tokio_websockets::{ClientBuilder, Message}; // Import Message from the crate +use http::Uri; +use futures_util::{stream::StreamExt, SinkExt}; +//use futures_util::sink::SinkExt; + +#[path = "../src/bin/server.rs"] +mod server; + +#[path = "../src/bin/client.rs"] +mod client; + +#[tokio::test] +async fn test_chat_interaction() { + let server_handle = tokio::spawn(async { + // Start the server + server::main().await.unwrap(); + }); + + let client1_handle = tokio::spawn(async { + let mut ws_stream = ClientBuilder::from_uri(Uri::from_static("ws://127.0.0.1:2000")) + .connect() + .await + .unwrap(); + + let (mut ws_stream, _) = ws_stream; + + ws_stream.send(Message::text("/join user1")).await.unwrap(); + + let received = ws_stream.next().await.unwrap(); + if let Ok(msg) = received { + assert!(msg.as_text().unwrap().contains("INFO: user1 has joined")); + } else { + panic!("Failed to receive message"); + } + + ws_stream.send(Message::text("send Hello, this is user1")).await.unwrap(); + + let received = ws_stream.next().await.unwrap(); + if let Ok(msg) = received { + assert!(msg.as_text().unwrap().contains("Hello, this is user1")); + } else { + panic!("Failed to receive message"); + } + + //assert!(received.as_text().unwrap().contains("INFO: user1 has joined")); + //assert!(received.as_text().unwrap().contains("Hello, this is user1")); + }); + + let client2_handle = tokio::spawn(async { + let mut ws_stream = ClientBuilder::from_uri(Uri::from_static("ws://127.0.0.1:2000")) + .connect() + .await + .unwrap(); + + let (mut ws_stream, _) = ws_stream; + ws_stream.send(Message::text("/join user2")).await.unwrap(); + let received = ws_stream.next().await.unwrap(); + if let Ok(msg) = received { + assert!(msg.as_text().unwrap().contains("INFO: user2 has joined")); + } else { + panic!("Failed to receive message"); + } + //assert!(received.as_text().unwrap().contains("INFO: user1 has joined")); + + ws_stream.send(Message::text("send Hello from user2")).await.unwrap(); + let received = ws_stream.next().await.unwrap(); + if let Ok(msg) = received { + assert!(msg.as_text().unwrap().contains("Hello from user2")); + } else { + panic!("Failed to receive message"); + } + //assert!(received[0].as_text().unwrap().contains("Hello from user2")); + }); + + let _ = tokio::join!(server_handle, client1_handle, client2_handle); +}