diff --git a/shared/src/lib.rs b/shared/src/lib.rs index 4585321..4758cbb 100644 --- a/shared/src/lib.rs +++ b/shared/src/lib.rs @@ -66,8 +66,6 @@ pub enum ControlMsg { message: String, timestamp: u64, }, - /// Keep-alive heartbeat - Heartbeat, /// General Error Error { message: String, diff --git a/src/handlers.rs b/src/handlers.rs index 930089a..9785077 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -37,18 +37,16 @@ pub async fn handle_socket(mut socket: WebSocket, _addr: SocketAddr, state: AppS }).collect(); // Notify others - let _ = room.tx.send(crate::state::RoomMessage { - from_user_id: uid, - content: crate::state::RoomMessageContent::Control(ControlMsg::PeerJoined { - user_id: uid, - display_name: display_name.clone(), - }) + let _ = room.tx.send(ControlMsg::PeerJoined { + user_id: uid, + display_name: display_name.clone(), }); // Add self to room room.peers.insert(uid, Peer { id: uid, display_name: display_name.clone(), + addr: None, }); user_id = Some(uid); @@ -60,9 +58,7 @@ pub async fn handle_socket(mut socket: WebSocket, _addr: SocketAddr, state: AppS room_code: room_code.clone(), peers: peers_list, }; - let json = serde_json::to_string(&resp).unwrap(); - info!("Sending Joined response to {}: {}", uid, json); - if let Err(e) = sender.send(Message::Text(json.into())).await { + if let Err(e) = sender.send(Message::Text(serde_json::to_string(&resp).unwrap().into())).await { error!("Failed to send Joined response: {}", e); return; } @@ -95,107 +91,62 @@ pub async fn handle_socket(mut socket: WebSocket, _addr: SocketAddr, state: AppS match msg { Some(Ok(Message::Text(text))) => { if let Ok(control) = serde_json::from_str::(&text) { - // Broadcast control msg to room - if let Some(room) = state.rooms.get(&rid) { - // Some control messages might need adjustment or just raw forwarding? - // For chat/updateStream, we usually just want to forward but ensure the from_user_id is set correctly. - // But wait, the client sends "ChatMessage" with fields. - // We should trust the authenticated UID or override it? - // Overriding/Verified is safer. - - let verified_msg = match control { - ControlMsg::UpdateStream { stream_id, active, media_type, .. } => { - ControlMsg::UpdateStream { - user_id: uid, - stream_id, - active, - media_type, - } - }, - ControlMsg::ChatMessage { message, .. } => { - ControlMsg::ChatMessage { - user_id: uid, - display_name: room.peers.get(&uid).map(|p| p.display_name.clone()).unwrap_or_default(), - message, - timestamp: std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis() as u64, - } - }, - ControlMsg::Heartbeat => { - // Echo back to sender for latency/connectivity check - let _ = sender.send(Message::Text(serde_json::to_string(&ControlMsg::Heartbeat).unwrap().into())).await; - ControlMsg::Heartbeat - }, - _ => control, - }; - - if let Err(e) = room.tx.send(crate::state::RoomMessage { - from_user_id: uid, - content: crate::state::RoomMessageContent::Control(verified_msg), - }) { - error!("Failed to broadcast control message from {}: {}", uid, e); + match control { + ControlMsg::UpdateStream { stream_id, active, media_type, .. } => { + // Broadcast to room with sender's user_id + let update = ControlMsg::UpdateStream { + user_id: uid, + stream_id, + active, + media_type, + }; + if let Some(room) = state.rooms.get(&rid) { + let _ = room.tx.send(update); + } } + ControlMsg::ChatMessage { message, display_name, .. } => { + // Broadcast chat with sender info + let chat = ControlMsg::ChatMessage { + user_id: uid, + display_name: state.rooms.get(&rid) + .and_then(|r| r.peers.get(&uid).map(|p| p.display_name.clone())) + .unwrap_or(display_name), // Fallback to provided name + message, + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64, + }; + if let Some(room) = state.rooms.get(&rid) { + let _ = room.tx.send(chat); + } + } + _ => {} } } } - Some(Ok(Message::Binary(data))) => { - // Binary Media Data (Audio/Video/Screen) - if let Some(room) = state.rooms.get(&rid) { - // Diagnostics: Echo back if it starts with 0xAA (our probe marker) - if !data.is_empty() && data[0] == 0xAA { - let _ = sender.send(Message::Binary(data.clone())).await; - } - - if let Err(e) = room.tx.send(crate::state::RoomMessage { - from_user_id: uid, - content: crate::state::RoomMessageContent::Media(data.to_vec()), - }) { - error!("Failed to broadcast binary media from {}: {}", uid, e); - } - } - } Some(Ok(Message::Close(_))) => break, - Some(Err(e)) => { - error!("WS receive error from {}: {}", uid, e); - break; - } + Some(Err(_)) => break, None => break, _ => {} } } Ok(msg) = rx.recv() => { - // Determine if we should send this message to the client - if msg.from_user_id != uid { - match msg.content { - crate::state::RoomMessageContent::Control(c) => { - if let Err(e) = sender.send(Message::Text(serde_json::to_string(&c).unwrap().into())).await { - error!("Failed to relay control msg to {}: {}", uid, e); - } - }, - crate::state::RoomMessageContent::Media(data) => { - // Only log occasionally to avoid flooding - if rand::random::() % 500 == 0 { - info!("Relaying media chunk to user {}: {} bytes", uid, data.len()); - } - if let Err(e) = sender.send(Message::Binary(data.into())).await { - error!("Failed to relay media to {}: {}", uid, e); - } - } - } - } + // Forward broadcast to client + let _ = sender.send(Message::Text(serde_json::to_string(&msg).unwrap().into())).await; } } } // Cleanup if let Some(room) = state.rooms.get(&rid) { - room.peers.remove(&uid); - let _ = room.tx.send(crate::state::RoomMessage { - from_user_id: uid, - content: crate::state::RoomMessageContent::Control(ControlMsg::PeerLeft { user_id: uid }), - }); + if let Some((_, peer)) = room.peers.remove(&uid) { + // Clean up address mapping if present + if let Some(addr) = peer.addr { + state.peers_by_addr.remove(&addr); + } + } + let _ = room.tx.send(ControlMsg::PeerLeft { user_id: uid }); } info!("User {} left room {}", uid, rid); diff --git a/src/main.rs b/src/main.rs index da7dc92..50e5df6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,17 @@ use axum::{ - extract::{State, ConnectInfo}, - extract::ws::WebSocketUpgrade, + extract::{State, Request, ConnectInfo}, + extract::ws::{WebSocketUpgrade, WebSocket, Message}, // Explicit import response::IntoResponse, routing::get, Router, }; use std::net::SocketAddr; -use tracing::info; +use tokio::net::UdpSocket; +use tracing::{info, error}; mod state; mod handlers; - +mod udp; use state::AppState; #[tokio::main] @@ -19,7 +20,13 @@ async fn main() -> anyhow::Result<()> { let state = AppState::new(); - + // Spawn UDP Server + let udp_state = state.clone(); + tokio::spawn(async move { + if let Err(e) = udp::run_udp_server(udp_state).await { + error!("UDP server error: {}", e); + } + }); // HTTP/WS Server let app = Router::new() diff --git a/src/state.rs b/src/state.rs index 02792d1..205a453 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,52 +1,48 @@ use dashmap::DashMap; -use shared::UserId; +use shared::{ControlMsg, UserId}; +use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::broadcast; #[derive(Debug, Clone)] pub struct AppState { pub rooms: Arc>, + pub peers_by_addr: Arc>, } - +#[derive(Debug, Clone)] +pub struct PeerLocation { + pub room_id: String, + pub user_id: UserId, +} #[derive(Debug, Clone)] pub struct Room { pub id: String, pub peers: DashMap, // Channel for broadcasting control messages within the room - pub tx: broadcast::Sender, -} - -#[derive(Clone, Debug)] -pub struct RoomMessage { - pub from_user_id: UserId, - pub content: RoomMessageContent, -} - -#[derive(Clone, Debug)] -pub enum RoomMessageContent { - Control(shared::ControlMsg), - Media(Vec), + pub tx: broadcast::Sender, } #[derive(Debug, Clone)] pub struct Peer { pub id: UserId, pub display_name: String, + pub addr: Option, // UDP address } impl AppState { pub fn new() -> Self { Self { rooms: Arc::new(DashMap::new()), + peers_by_addr: Arc::new(DashMap::new()), } } } impl Room { pub fn new(id: String) -> Self { - let (tx, _) = broadcast::channel(1024); + let (tx, _) = broadcast::channel(100); Self { id, peers: DashMap::new(), diff --git a/src/udp.rs b/src/udp.rs new file mode 100644 index 0000000..2c65b63 --- /dev/null +++ b/src/udp.rs @@ -0,0 +1,120 @@ +use std::sync::Arc; +use tokio::net::UdpSocket; +use tracing::{info, warn, error}; +use shared::{PacketHeader, MediaType, UserId}; +use bincode; +use crate::state::{AppState, PeerLocation}; + +#[derive(serde::Serialize, serde::Deserialize)] +struct Handshake { + user_id: UserId, + room_code: String, +} + +pub async fn run_udp_server(state: AppState) -> anyhow::Result<()> { + let socket = Arc::new(UdpSocket::bind("0.0.0.0:4000").await?); + info!("UDP Server listening on 0.0.0.0:4000"); + + let mut buf = [0u8; 65535]; + + loop { + match socket.recv_from(&mut buf).await { + Ok((len, addr)) => { + let data = &buf[..len]; + info!("UDP RECV from {}: {} bytes", addr, len); + + // Manually parse header (24 bytes) to match client's raw byte layout: + // Byte 0: version (u8) + // Byte 1: media_type (u8) + // Bytes 2-5: user_id (u32 LE) + // Bytes 6-9: sequence (u32 LE) + // Bytes 10-17: timestamp (u64 LE) + // Bytes 18-19: fragment_index (u16 LE) + // Bytes 20-21: fragment_count (u16 LE) + // Bytes 22-23: flags (u16 LE) + + if data.len() < 24 { + warn!("UDP packet too small: {} bytes from {}", data.len(), addr); + continue; + } + + let version = data[0]; + let media_type_byte = data[1]; + let user_id = u32::from_le_bytes([data[2], data[3], data[4], data[5]]); + let _sequence = u32::from_le_bytes([data[6], data[7], data[8], data[9]]); + let _timestamp = u64::from_le_bytes([data[10], data[11], data[12], data[13], data[14], data[15], data[16], data[17]]); + let _fragment_index = u16::from_le_bytes([data[18], data[19]]); + let _fragment_count = u16::from_le_bytes([data[20], data[21]]); + let _flags = u16::from_le_bytes([data[22], data[23]]); + + let media_type = match media_type_byte { + 0 => MediaType::Audio, + 1 => MediaType::Video, + 2 => MediaType::Screen, + 3 => MediaType::Command, + _ => MediaType::Unknown, + }; + + let payload = &data[24..]; + + + match media_type { + MediaType::Command => { + // Handshake + match bincode::deserialize::(payload) { + Ok(handshake) => { + // Validate User in Room + if let Some(room) = state.rooms.get(&handshake.room_code) { + if room.peers.contains_key(&handshake.user_id) { + // Update Address + state.peers_by_addr.insert(addr, PeerLocation { + room_id: handshake.room_code.clone(), + user_id: handshake.user_id, + }); + + // Update Peer in Room + if let Some(mut peer) = room.peers.get_mut(&handshake.user_id) { + peer.addr = Some(addr); + info!( + "UDP Handshake: User {} at {}, Room {}", + handshake.user_id, addr, handshake.room_code + ); + } + } + } + } + Err(e) => { + warn!("Failed to deserialize Handshake from {}: {}", addr, e); + } + } + } + _ => { + // Media Packet: Relay + if let Some(loc) = state.peers_by_addr.get(&addr) { + let room_id = &loc.room_id; + let sender_id: UserId = loc.user_id; + + // Forward to all valid peers in room + if let Some(room) = state.rooms.get(room_id) { + for peer in room.peers.iter() { + // Don't echo back to sender + if *peer.key() != sender_id { + if let Some(target_addr) = peer.value().addr { + // Send + let _ = socket.send_to(data, target_addr).await; + } + } + } + } + } else { + warn!("Dropping Relay Packet from unknown sender: {}", addr); + } + } + } + } + Err(e) => { + error!("UDP recv error: {}", e); + } + } + } +}