From ab7090f2c5537d927c83135dd281cf8fefe05aca Mon Sep 17 00:00:00 2001 From: Tobias Reisinger Date: Thu, 25 Apr 2024 19:45:22 +0200 Subject: [PATCH] Add relays websocket --- emgauwa-core/src/app_state.rs | 78 ++++++++++++- .../handlers/v1/ws/controllers/handlers.rs | 6 +- .../src/handlers/v1/ws/controllers/mod.rs | 12 +- emgauwa-core/src/handlers/v1/ws/mod.rs | 24 +++- emgauwa-core/src/handlers/v1/ws/relays/mod.rs | 107 ++++++++++++++++++ emgauwa-core/src/main.rs | 3 +- 6 files changed, 216 insertions(+), 14 deletions(-) create mode 100644 emgauwa-core/src/handlers/v1/ws/relays/mod.rs diff --git a/emgauwa-core/src/app_state.rs b/emgauwa-core/src/app_state.rs index feacab9..4edbd9d 100644 --- a/emgauwa-core/src/app_state.rs +++ b/emgauwa-core/src/app_state.rs @@ -1,12 +1,15 @@ use std::collections::HashMap; -use actix::{Actor, Context, Handler, Message, Recipient}; +use actix::{Actor, Addr, Context, Handler, Message, Recipient}; +use emgauwa_lib::db::DbController; use emgauwa_lib::errors::EmgauwaError; -use emgauwa_lib::models::Controller; +use emgauwa_lib::models::{convert_db_list, Controller, Relay}; use emgauwa_lib::types::{ControllerUid, ControllerWsAction, RelayStates}; use futures::executor::block_on; use sqlx::{Pool, Sqlite}; +use crate::handlers::v1::ws::relays::{RelaysWs, SendRelays}; + #[derive(Message)] #[rtype(result = "Result<(), EmgauwaError>")] pub struct DisconnectController { @@ -27,6 +30,10 @@ pub struct UpdateRelayStates { pub relay_states: RelayStates, } +#[derive(Message)] +#[rtype(result = "Result, EmgauwaError>")] +pub struct GetRelays {} + #[derive(Message)] #[rtype(result = "Result<(), EmgauwaError>")] pub struct Action { @@ -34,9 +41,16 @@ pub struct Action { pub action: ControllerWsAction, } +#[derive(Message)] +#[rtype(result = "()")] +pub struct ConnectRelayClient { + pub addr: Addr, +} + pub struct AppState { pub pool: Pool, pub connected_controllers: HashMap)>, + pub connected_relay_clients: Vec>, } impl AppState { @@ -44,8 +58,51 @@ impl AppState { AppState { pool, connected_controllers: HashMap::new(), + connected_relay_clients: Vec::new(), } } + + fn get_relays(&self) -> Result, EmgauwaError> { + let mut pool_conn = block_on(self.pool.acquire())?; + let db_controllers = block_on(DbController::get_all(&mut pool_conn))?; + let mut controllers: Vec = convert_db_list(&mut pool_conn, db_controllers)?; + + self.connected_controllers + .iter() + .for_each(|(uid, (connected_controller, _))| { + if let Some(c) = controllers.iter_mut().find(|c| c.c.uid == *uid) { + c.apply_relay_states(&connected_controller.get_relay_states()); + } + }); + + let mut relays: Vec = Vec::new(); + controllers.iter().for_each(|c| { + relays.extend(c.relays.clone()); + }); + + Ok(relays) + } + + fn notify_relay_clients(&mut self) { + self.connected_relay_clients.retain(|addr| addr.connected()); + + match self.get_relays() { + Ok(relays) => match serde_json::to_string(&relays) { + Ok(json) => { + self.connected_relay_clients.iter_mut().for_each(|addr| { + let relays_json = json.clone(); + addr.do_send(SendRelays { relays_json }); + }); + } + Err(err) => { + log::error!("Failed to serialize relays: {:?}", err); + } + }, + Err(err) => { + log::error!("Failed to get relays: {:?}", err); + } + }; + } } impl Actor for AppState { @@ -94,6 +151,15 @@ impl Handler for AppState { if let Some((controller, _)) = self.connected_controllers.get_mut(&msg.controller_uid) { controller.apply_relay_states(&msg.relay_states); } + self.notify_relay_clients(); + } +} + +impl Handler for AppState { + type Result = Result, EmgauwaError>; + + fn handle(&mut self, _msg: GetRelays, _ctx: &mut Self::Context) -> Self::Result { + self.get_relays() } } @@ -112,3 +178,11 @@ impl Handler for AppState { } } } + +impl Handler for AppState { + type Result = (); + + fn handle(&mut self, msg: ConnectRelayClient, _ctx: &mut Self::Context) -> Self::Result { + self.connected_relay_clients.push(msg.addr); + } +} diff --git a/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs b/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs index 335f7b1..5f763e5 100644 --- a/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs +++ b/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs @@ -9,13 +9,13 @@ use sqlx::pool::PoolConnection; use sqlx::Sqlite; use crate::app_state::{ConnectController, UpdateRelayStates}; -use crate::handlers::v1::ws::controllers::ControllerWs; +use crate::handlers::v1::ws::controllers::ControllersWs; -impl ControllerWs { +impl ControllersWs { pub fn handle_register( &mut self, conn: &mut PoolConnection, - ctx: &mut ::Context, + ctx: &mut ::Context, controller: Controller, ) -> Result<(), EmgauwaError> { log::info!( diff --git a/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs b/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs index fc95439..88aaf24 100644 --- a/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs +++ b/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs @@ -16,14 +16,14 @@ use ws::Message; use crate::app_state::{AppState, DisconnectController}; use crate::utils::flatten_result; -pub struct ControllerWs { +pub struct ControllersWs { pub pool: Pool, pub controller_uid: Option, pub app_state: Addr, pub hb: Instant, } -impl Actor for ControllerWs { +impl Actor for ControllersWs { type Context = ws::WebsocketContext; fn started(&mut self, ctx: &mut Self::Context) { @@ -45,11 +45,11 @@ impl Actor for ControllerWs { } } -impl ControllerWs { +impl ControllersWs { pub fn handle_action( &mut self, conn: &mut PoolConnection, - ctx: &mut ::Context, + ctx: &mut ::Context, action: ControllerWsAction, ) { let action_res = match action { @@ -83,7 +83,7 @@ impl ControllerWs { } } -impl Handler for ControllerWs { +impl Handler for ControllersWs { type Result = Result<(), EmgauwaError>; fn handle(&mut self, action: ControllerWsAction, ctx: &mut Self::Context) -> Self::Result { @@ -101,7 +101,7 @@ impl Handler for ControllerWs { } } -impl StreamHandler> for ControllerWs { +impl StreamHandler> for ControllersWs { fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { let mut pool_conn = match block_on(self.pool.acquire()) { Ok(conn) => conn, diff --git a/emgauwa-core/src/handlers/v1/ws/mod.rs b/emgauwa-core/src/handlers/v1/ws/mod.rs index cda5a47..4fb3bf7 100644 --- a/emgauwa-core/src/handlers/v1/ws/mod.rs +++ b/emgauwa-core/src/handlers/v1/ws/mod.rs @@ -7,9 +7,11 @@ use emgauwa_lib::errors::EmgauwaError; use sqlx::{Pool, Sqlite}; use crate::app_state::AppState; -use crate::handlers::v1::ws::controllers::ControllerWs; +use crate::handlers::v1::ws::controllers::ControllersWs; +use crate::handlers::v1::ws::relays::RelaysWs; pub mod controllers; +pub mod relays; #[get("/ws/controllers")] pub async fn ws_controllers( @@ -19,7 +21,7 @@ pub async fn ws_controllers( stream: web::Payload, ) -> Result { let resp = ws::start( - ControllerWs { + ControllersWs { pool: pool.get_ref().clone(), controller_uid: None, app_state: app_state.get_ref().clone(), @@ -31,3 +33,21 @@ pub async fn ws_controllers( .map_err(|_| EmgauwaError::Internal(String::from("error starting websocket"))); resp } + +#[get("/ws/relays")] +pub async fn ws_relays( + app_state: web::Data>, + req: HttpRequest, + stream: web::Payload, +) -> Result { + let resp = ws::start( + RelaysWs { + app_state: app_state.get_ref().clone(), + hb: Instant::now(), + }, + &req, + stream, + ) + .map_err(|_| EmgauwaError::Internal(String::from("error starting websocket"))); + resp +} diff --git a/emgauwa-core/src/handlers/v1/ws/relays/mod.rs b/emgauwa-core/src/handlers/v1/ws/relays/mod.rs new file mode 100644 index 0000000..e1208e7 --- /dev/null +++ b/emgauwa-core/src/handlers/v1/ws/relays/mod.rs @@ -0,0 +1,107 @@ +use std::time::Instant; + +use actix::{Actor, ActorContext, Addr, AsyncContext, Handler, Message, StreamHandler}; +use actix_web_actors::ws; +use actix_web_actors::ws::ProtocolError; +use emgauwa_lib::constants::{HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT}; +use emgauwa_lib::errors::EmgauwaError; +use futures::executor::block_on; + +use crate::app_state::{AppState, ConnectRelayClient}; + +pub struct RelaysWs { + pub app_state: Addr, + pub hb: Instant, +} + +#[derive(Message)] +#[rtype(result = "()")] +pub struct SendRelays { + pub relays_json: String, +} + +impl Actor for RelaysWs { + type Context = ws::WebsocketContext; + + fn started(&mut self, ctx: &mut Self::Context) { + // get unique id for ctx + match self.get_relays_json() { + Ok(relays_json) => { + ctx.text(relays_json); + self.hb(ctx); + + block_on(self.app_state.send(ConnectRelayClient { + addr: ctx.address(), + })) + .unwrap(); + } + Err(err) => { + log::error!("Error getting relays: {:?}", err); + ctx.stop(); + return; + } + } + } +} + +impl RelaysWs { + fn get_relays_json(&self) -> Result { + let relays = block_on(self.app_state.send(crate::app_state::GetRelays {}))??; + serde_json::to_string(&relays).map_err(EmgauwaError::from) + } + + // helper method that sends ping to client every 5 seconds (HEARTBEAT_INTERVAL). + fn hb(&self, ctx: &mut ws::WebsocketContext) { + ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { + // check client heartbeats + if Instant::now().duration_since(act.hb) > HEARTBEAT_TIMEOUT { + log::debug!("Websocket Relay heartbeat failed, disconnecting!"); + ctx.stop(); + // don't try to send a ping + return; + } + + ctx.ping(&[]); + }); + } +} + +impl StreamHandler> for RelaysWs { + fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { + let msg = match msg { + Err(_) => { + ctx.stop(); + return; + } + Ok(msg) => msg, + }; + + match msg { + ws::Message::Ping(msg) => { + self.hb = Instant::now(); + ctx.pong(&msg) + } + ws::Message::Pong(_) => { + self.hb = Instant::now(); + } + ws::Message::Text(_) => log::debug!("Received unexpected text in relays ws"), + ws::Message::Binary(_) => log::debug!("Received unexpected binary in relays ws"), + ws::Message::Close(reason) => { + ctx.close(reason); + ctx.stop(); + } + ws::Message::Continuation(_) => { + ctx.stop(); + } + ws::Message::Nop => (), + } + } +} + +impl Handler for RelaysWs { + type Result = (); + + fn handle(&mut self, msg: SendRelays, ctx: &mut Self::Context) -> Self::Result { + ctx.text(msg.relays_json); + } +} diff --git a/emgauwa-core/src/main.rs b/emgauwa-core/src/main.rs index d1b0d31..5fca1e8 100644 --- a/emgauwa-core/src/main.rs +++ b/emgauwa-core/src/main.rs @@ -107,7 +107,8 @@ async fn main() -> Result<(), std::io::Error> { .service(handlers::v1::tags::show) .service(handlers::v1::tags::delete) .service(handlers::v1::tags::add) - .service(handlers::v1::ws::ws_controllers), + .service(handlers::v1::ws::ws_controllers) + .service(handlers::v1::ws::ws_relays), ) }) .listen(listener)?