mod handlers; use std::time::Instant; use actix::{Actor, ActorContext, Addr, AsyncContext, Handler, StreamHandler}; use actix_web_actors::ws; use actix_web_actors::ws::ProtocolError; use emgauwa_common::constants::{HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT}; use emgauwa_common::errors::EmgauwaError; use emgauwa_common::types::{ControllerWsAction, EmgauwaUid}; use futures::executor::block_on; use sqlx::pool::PoolConnection; use sqlx::{Pool, Sqlite}; use ws::Message; use crate::app_state::{AppState, DisconnectController}; use crate::utils::flatten_result; pub struct ControllersWs { pub pool: Pool<Sqlite>, pub controller_uid: Option<EmgauwaUid>, pub app_state: Addr<AppState>, pub hb: Instant, } impl Actor for ControllersWs { type Context = ws::WebsocketContext<Self>; fn started(&mut self, ctx: &mut Self::Context) { self.hb(ctx); } fn stopped(&mut self, _ctx: &mut Self::Context) { if let Some(controller_uid) = &self.controller_uid { let flat_res = flatten_result( block_on(self.app_state.send(DisconnectController { controller_uid: controller_uid.clone(), })) .map_err(EmgauwaError::from), ); if let Err(err) = flat_res { log::error!("Error disconnecting controller: {:?}", err); } } } } impl ControllersWs { pub fn handle_action( &mut self, conn: &mut PoolConnection<Sqlite>, ctx: &mut <ControllersWs as Actor>::Context, action: ControllerWsAction, ) { let action_res = match action { ControllerWsAction::Register(controller) => self.handle_register(conn, ctx, controller), ControllerWsAction::RelayStates((controller_uid, relay_states)) => { self.handle_relay_states(controller_uid, relay_states) } _ => Ok(()), }; if let Err(e) = action_res { log::error!("Error handling action: {:?}", e); ctx.text( serde_json::to_string(&e).unwrap_or(format!("Error in handling action: {:?}", e)), ); } } // helper method that sends ping to client every 5 seconds (HEARTBEAT_INTERVAL). fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { // check client heartbeats if Instant::now().duration_since(act.hb) > HEARTBEAT_TIMEOUT { log::warn!("Websocket Controller heartbeat failed, disconnecting!"); ctx.stop(); // don't try to send a ping return; } ctx.ping(&[]); }); } } impl Handler<ControllerWsAction> for ControllersWs { type Result = Result<(), EmgauwaError>; fn handle(&mut self, action: ControllerWsAction, ctx: &mut Self::Context) -> Self::Result { match action { ControllerWsAction::Disconnect => { ctx.close(None); ctx.stop(); } _ => { let action_json = serde_json::to_string(&action)?; ctx.text(action_json); } } Ok(()) } } impl StreamHandler<Result<Message, ProtocolError>> for ControllersWs { fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) { let mut pool_conn = match block_on(self.pool.acquire()) { Ok(conn) => conn, Err(err) => { log::error!("Failed to acquire database connection: {:?}", err); ctx.stop(); return; } }; let msg = match msg { Err(_) => { ctx.stop(); return; } Ok(msg) => msg, }; match msg { Message::Ping(msg) => { self.hb = Instant::now(); ctx.pong(&msg) } Message::Pong(_) => { self.hb = Instant::now(); } Message::Text(text) => match serde_json::from_str(&text) { Ok(action) => { self.handle_action(&mut pool_conn, ctx, action); } Err(e) => { log::error!("Error deserializing action: {:?}", e); ctx.text( serde_json::to_string(&EmgauwaError::Serialization(e)) .unwrap_or(String::from("Error in deserializing action")), ); } }, Message::Binary(_) => log::warn!("Received unexpected binary in controller ws"), Message::Close(reason) => { ctx.close(reason); ctx.stop(); } Message::Continuation(_) => { ctx.stop(); } Message::Nop => (), } } }