use crate::constants::{HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT}; use crate::db::errors::DatabaseError; use crate::db::{DbController, DbRelay}; use crate::models::{Controller, FromDbModel}; use crate::types::{ConnectedControllersType, ControllerUid}; use actix::{Actor, ActorContext, AsyncContext, StreamHandler}; use actix_web_actors::ws; use actix_web_actors::ws::ProtocolError; use serde_derive::{Deserialize, Serialize}; use sqlx::pool::PoolConnection; use sqlx::{Pool, Sqlite}; use std::time::Instant; use ws::Message; #[derive(Debug, Serialize, Deserialize)] pub enum ControllerWsAction { Register(Controller), } pub struct ControllerWs { pub pool: Pool<Sqlite>, pub controller_uid: Option<ControllerUid>, pub connected_controllers: ConnectedControllersType, pub hb: Instant, } impl Actor for ControllerWs { type Context = ws::WebsocketContext<Self>; fn started(&mut self, ctx: &mut Self::Context) { self.hb(ctx); } fn stopped(&mut self, _ctx: &mut Self::Context) { if let Some(controller_uid) = &self.controller_uid { let mut pool_conn = futures::executor::block_on(self.pool.acquire()).unwrap(); let mut data = self.connected_controllers.lock().unwrap(); if let Some(controller) = data.remove(controller_uid) { futures::executor::block_on(controller.c.update_active(&mut pool_conn, false)) .unwrap(); } } } } impl ControllerWs { pub fn handle_action( &mut self, conn: &mut PoolConnection<Sqlite>, action: ControllerWsAction, ) -> Result<(), DatabaseError> { match action { ControllerWsAction::Register(controller) => { log::info!("Registering controller: {:?}", controller); let c = &controller.c; let controller_db = futures::executor::block_on( DbController::get_by_uid_or_create(conn, &c.uid, &c.name, c.relay_count), )?; futures::executor::block_on(controller_db.update_active(conn, true))?; for relay in &controller.relays { let r = &relay.r; futures::executor::block_on(DbRelay::get_by_controller_and_num_or_create( conn, &controller_db, r.number, &r.name, ))?; } let controller = Controller::from_db_model(conn, controller_db)?; let controller_uid = &controller.c.uid; self.controller_uid = Some(controller_uid.clone()); let mut data = self.connected_controllers.lock().unwrap(); data.insert(controller_uid.clone(), controller); Ok(()) } } } /// helper method that sends ping to client every 5 seconds (HEARTBEAT_INTERVAL). /// /// also this method checks heartbeats from client fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { // check client heartbeats if Instant::now().duration_since(act.hb) > HEARTBEAT_TIMEOUT { log::warn!("Websocket Controller heartbeat failed, disconnecting!"); ctx.stop(); // don't try to send a ping return; } ctx.ping(&[]); }); } } impl StreamHandler<Result<Message, ProtocolError>> for ControllerWs { fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) { let mut pool_conn = futures::executor::block_on(self.pool.acquire()).unwrap(); let msg = match msg { Err(_) => { ctx.stop(); return; } Ok(msg) => msg, }; match msg { Message::Ping(msg) => { self.hb = Instant::now(); ctx.pong(&msg) } Message::Pong(_) => { self.hb = Instant::now(); } Message::Text(text) => { let action: ControllerWsAction = serde_json::from_str(&text).unwrap(); let action_res = self.handle_action(&mut pool_conn, action); if let Err(e) = action_res { log::error!("Error handling action: {:?}", e); ctx.text(serde_json::to_string(&e).unwrap()); } } Message::Binary(_) => log::warn!("Received unexpected binary in controller ws"), Message::Close(reason) => { ctx.close(reason); ctx.stop(); } Message::Continuation(_) => { ctx.stop(); } Message::Nop => (), } //let schedules = futures::executor::block_on(DbSchedule::get_all(&mut pool_conn)).unwrap(); //let schedules_json = serde_json::to_string(&schedules).unwrap(); //ctx.text(schedules_json); } }