diff --git a/Makefile b/Makefile index 80d21ba..20b80a2 100644 --- a/Makefile +++ b/Makefile @@ -19,5 +19,8 @@ clean-db: rm ./emgauwa-controller.sqlite || true $(MAKE) sqlx-prepare -fmt: +format: cargo +nightly fmt + +lint: + cargo clippy --all-targets --all-features -- -D warnings diff --git a/emgauwa-controller/src/main.rs b/emgauwa-controller/src/main.rs index 93a08f4..546822b 100644 --- a/emgauwa-controller/src/main.rs +++ b/emgauwa-controller/src/main.rs @@ -42,7 +42,7 @@ async fn create_this_relay( let relay = DbRelay::create( conn, &settings_relay.name, - settings_relay.number.unwrap(), + settings_relay.number.expect("Relay number is missing"), this_controller, ) .await?; @@ -55,6 +55,34 @@ async fn create_this_relay( Ok(relay) } +async fn run_websocket(this: Controller, url: &str) { + match connect_async(url).await { + Ok(connection) => { + let (ws_stream, _) = connection; + + let (mut write, read) = ws_stream.split(); + + let ws_action = ControllerWsAction::Register(this.clone()); + + let ws_action_json = + serde_json::to_string(&ws_action).expect("Failed to serialize action"); + if let Err(err) = write.send(Message::text(ws_action_json)).await { + log::error!("Failed to register at websocket: {}", err); + return; + } + + let read_handler = read.for_each(handle_message); + + read_handler.await; + + log::warn!("Lost connection to websocket"); + } + Err(err) => { + log::warn!("Failed to connect to websocket: {}", err,); + } + } +} + #[tokio::main] async fn main() { let settings = settings::init(); @@ -67,19 +95,24 @@ async fn main() { .await .expect("Failed to get database connection"); - let db_controller = DbController::get_all(&mut conn) + let db_controller = match DbController::get_all(&mut conn) .await .expect("Failed to get controller from database") .pop() - .unwrap_or_else(|| { - futures::executor::block_on(create_this_controller(&mut conn, &settings)) - }); + { + None => futures::executor::block_on(create_this_controller(&mut conn, &settings)), + Some(c) => c, + }; for relay in &settings.relays { - if DbRelay::get_by_controller_and_num(&mut conn, &db_controller, relay.number.unwrap()) - .await - .expect("Failed to get relay from database") - .is_none() + if DbRelay::get_by_controller_and_num( + &mut conn, + &db_controller, + relay.number.expect("Relay number is missing"), + ) + .await + .expect("Failed to get relay from database") + .is_none() { create_this_relay(&mut conn, &db_controller, relay) .await @@ -90,7 +123,7 @@ async fn main() { let db_controller = db_controller .update(&mut conn, &db_controller.name, settings.relays.len() as i64) .await - .unwrap(); + .expect("Failed to update controller"); let this = Controller::from_db_model(&mut conn, db_controller) .expect("Failed to convert database models"); @@ -103,27 +136,7 @@ async fn main() { tokio::spawn(run_relay_loop(settings)); loop { - match connect_async(&url).await { - Ok(connection) => { - let (ws_stream, _) = connection; - - let (mut write, read) = ws_stream.split(); - - let ws_action = ControllerWsAction::Register(this.clone()); - - let ws_action_json = serde_json::to_string(&ws_action).unwrap(); - write.send(Message::text(ws_action_json)).await.unwrap(); - - let read_handler = read.for_each(handle_message); - - read_handler.await; - - log::warn!("Lost connection to websocket"); - } - Err(err) => { - log::warn!("Failed to connect to websocket: {}", err,); - } - } + run_websocket(this.clone(), &url).await; log::info!( "Retrying to connect in {} seconds...", diff --git a/emgauwa-core/src/app_state.rs b/emgauwa-core/src/app_state.rs index bcbd2c6..f309420 100644 --- a/emgauwa-core/src/app_state.rs +++ b/emgauwa-core/src/app_state.rs @@ -1,36 +1,42 @@ use std::collections::HashMap; -use std::sync::{Arc, Mutex}; use actix::{Actor, Context, Handler, Message, Recipient}; -use emgauwa_lib::errors::DatabaseError; +use emgauwa_lib::errors::EmgauwaError; use emgauwa_lib::models::Controller; use emgauwa_lib::types::{ControllerUid, ControllerWsAction}; use futures::executor::block_on; use sqlx::{Pool, Sqlite}; #[derive(Message)] -#[rtype(result = "Result<(), DatabaseError>")] +#[rtype(result = "Result<(), EmgauwaError>")] pub struct DisconnectController { pub controller_uid: ControllerUid, } #[derive(Message)] -#[rtype(result = "Result<(), DatabaseError>")] +#[rtype(result = "Result<(), EmgauwaError>")] pub struct ConnectController { pub address: Recipient, pub controller: Controller, } +#[derive(Message)] +#[rtype(result = "Result<(), EmgauwaError>")] +pub struct Action { + pub controller_uid: ControllerUid, + pub action: ControllerWsAction, +} + pub struct AppServer { pub pool: Pool, - pub connected_controllers: Arc>>, + pub connected_controllers: HashMap)>, } impl AppServer { pub fn new(pool: Pool) -> AppServer { AppServer { pool, - connected_controllers: Arc::new(Mutex::new(HashMap::new())), + connected_controllers: HashMap::new(), } } } @@ -40,13 +46,12 @@ impl Actor for AppServer { } impl Handler for AppServer { - type Result = Result<(), DatabaseError>; + type Result = Result<(), EmgauwaError>; fn handle(&mut self, msg: DisconnectController, _ctx: &mut Self::Context) -> Self::Result { - let mut pool_conn = block_on(self.pool.acquire()).unwrap(); - let mut data = self.connected_controllers.lock().unwrap(); + let mut pool_conn = block_on(self.pool.acquire())?; - if let Some(controller) = data.remove(&msg.controller_uid) { + if let Some((controller, _)) = self.connected_controllers.remove(&msg.controller_uid) { if let Err(err) = block_on(controller.c.update_active(&mut pool_conn, false)) { log::error!( "Failed to mark controller {} as inactive: {:?}", @@ -60,12 +65,24 @@ impl Handler for AppServer { } impl Handler for AppServer { - type Result = Result<(), DatabaseError>; + type Result = Result<(), EmgauwaError>; fn handle(&mut self, msg: ConnectController, _ctx: &mut Self::Context) -> Self::Result { - let mut data = self.connected_controllers.lock().unwrap(); - data.insert(msg.controller.c.uid.clone(), msg.controller); + self.connected_controllers + .insert(msg.controller.c.uid.clone(), (msg.controller, msg.address)); Ok(()) } } + +impl Handler for AppServer { + type Result = Result<(), EmgauwaError>; + + fn handle(&mut self, msg: Action, _ctx: &mut Self::Context) -> Self::Result { + if let Some((_, address)) = self.connected_controllers.get(&msg.controller_uid) { + block_on(address.send(msg.action))? + } else { + Err(EmgauwaError::Connection(msg.controller_uid)) + } + } +} diff --git a/emgauwa-core/src/handlers/v1/relays.rs b/emgauwa-core/src/handlers/v1/relays.rs index cd3ade7..0f4ce50 100644 --- a/emgauwa-core/src/handlers/v1/relays.rs +++ b/emgauwa-core/src/handlers/v1/relays.rs @@ -1,11 +1,15 @@ +use actix::Addr; use actix_web::{get, put, web, HttpResponse}; use emgauwa_lib::db::{DbController, DbRelay, DbTag}; use emgauwa_lib::errors::{DatabaseError, EmgauwaError}; use emgauwa_lib::models::{convert_db_list, FromDbModel, Relay}; -use emgauwa_lib::types::ControllerUid; +use emgauwa_lib::types::{ControllerUid, ControllerWsAction}; use serde::{Deserialize, Serialize}; use sqlx::{Pool, Sqlite}; +use crate::app_state; +use crate::app_state::AppServer; + #[derive(Debug, Serialize, Deserialize)] pub struct RequestRelay { name: String, @@ -64,6 +68,7 @@ pub async fn index_for_controller( #[get("/api/v1/controllers/{controller_id}/relays/{relay_num}")] pub async fn show_for_controller( pool: web::Data>, + app_server: web::Data>, path: web::Path<(String, i64)>, ) -> Result { let mut pool_conn = pool.acquire().await?; diff --git a/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs b/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs index 68c5ad0..6408727 100644 --- a/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs +++ b/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs @@ -1,6 +1,6 @@ use actix::{Actor, AsyncContext}; use emgauwa_lib::db::{DbController, DbJunctionRelaySchedule, DbRelay, DbSchedule}; -use emgauwa_lib::errors::DatabaseError; +use emgauwa_lib::errors::{DatabaseError, EmgauwaError}; use emgauwa_lib::models::{Controller, FromDbModel}; use futures::executor::block_on; use sqlx::pool::PoolConnection; @@ -15,7 +15,7 @@ impl ControllerWs { conn: &mut PoolConnection, ctx: &mut ::Context, controller: Controller, - ) -> Result<(), DatabaseError> { + ) -> Result<(), EmgauwaError> { log::info!("Registering controller: {:?}", controller); let c = &controller.c; let controller_db = block_on(DbController::get_by_uid_or_create( @@ -60,10 +60,10 @@ impl ControllerWs { let addr = ctx.address(); self.controller_uid = Some(controller_uid.clone()); - self.app_server.do_send(ConnectController { + block_on(self.app_server.send(ConnectController { address: addr.recipient(), controller, - }); + }))??; Ok(()) } diff --git a/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs b/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs index bf58deb..c4a6425 100644 --- a/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs +++ b/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs @@ -6,7 +6,7 @@ use actix::{Actor, ActorContext, Addr, AsyncContext, Handler, StreamHandler}; use actix_web_actors::ws; use actix_web_actors::ws::ProtocolError; use emgauwa_lib::constants::{HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT}; -use emgauwa_lib::errors::{DatabaseError, EmgauwaError}; +use emgauwa_lib::errors::EmgauwaError; use emgauwa_lib::types::{ControllerUid, ControllerWsAction}; use futures::executor::block_on; use sqlx::pool::PoolConnection; @@ -14,6 +14,7 @@ use sqlx::{Pool, Sqlite}; use ws::Message; use crate::app_state::{AppServer, DisconnectController}; +use crate::utils::flatten_result; pub struct ControllerWs { pub pool: Pool, @@ -31,9 +32,15 @@ impl Actor for ControllerWs { fn stopped(&mut self, _ctx: &mut Self::Context) { if let Some(controller_uid) = &self.controller_uid { - self.app_server.do_send(DisconnectController { - controller_uid: controller_uid.clone(), - }) + let flat_res = flatten_result( + block_on(self.app_server.send(DisconnectController { + controller_uid: controller_uid.clone(), + })) + .map_err(EmgauwaError::from), + ); + if let Err(err) = flat_res { + log::error!("Error disconnecting controller: {:?}", err); + } } } } @@ -44,9 +51,10 @@ impl ControllerWs { conn: &mut PoolConnection, ctx: &mut ::Context, action: ControllerWsAction, - ) -> Result<(), DatabaseError> { + ) -> Result<(), EmgauwaError> { match action { ControllerWsAction::Register(controller) => self.handle_register(conn, ctx, controller), + _ => Ok(()), } } @@ -78,7 +86,14 @@ impl Handler for ControllerWs { impl StreamHandler> for ControllerWs { fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { - let mut pool_conn = block_on(self.pool.acquire()).unwrap(); + let mut pool_conn = match block_on(self.pool.acquire()) { + Ok(conn) => conn, + Err(err) => { + log::error!("Failed to acquire database connection: {:?}", err); + ctx.stop(); + return; + } + }; let msg = match msg { Err(_) => { @@ -96,14 +111,22 @@ impl StreamHandler> for ControllerWs { Message::Pong(_) => { self.hb = Instant::now(); } - Message::Text(text) => { - let action: ControllerWsAction = serde_json::from_str(&text).unwrap(); - let action_res = self.handle_action(&mut pool_conn, ctx, action); - if let Err(e) = action_res { - log::error!("Error handling action: {:?}", e); - ctx.text(serde_json::to_string(&e).unwrap()); + Message::Text(text) => match serde_json::from_str(&text) { + Ok(action) => { + let action_res = self.handle_action(&mut pool_conn, ctx, action); + if let Err(e) = action_res { + log::error!("Error handling action: {:?}", e); + ctx.text(serde_json::to_string(&e).expect("Failed to serialize error")); + } } - } + Err(e) => { + log::error!("Error deserializing action: {:?}", e); + ctx.text( + serde_json::to_string(&EmgauwaError::Serialization(e)) + .expect("Failed to serialize error"), + ); + } + }, Message::Binary(_) => log::warn!("Received unexpected binary in controller ws"), Message::Close(reason) => { ctx.close(reason); diff --git a/emgauwa-core/src/handlers/v1/ws/mod.rs b/emgauwa-core/src/handlers/v1/ws/mod.rs index 8a45ed0..80c3b58 100644 --- a/emgauwa-core/src/handlers/v1/ws/mod.rs +++ b/emgauwa-core/src/handlers/v1/ws/mod.rs @@ -3,7 +3,7 @@ use std::time::Instant; use actix::Addr; use actix_web::{get, web, HttpRequest, HttpResponse}; use actix_web_actors::ws; -use emgauwa_lib::errors::{ApiError, EmgauwaError}; +use emgauwa_lib::errors::EmgauwaError; use sqlx::{Pool, Sqlite}; use crate::app_state::AppServer; @@ -28,10 +28,6 @@ pub async fn ws_controllers( &req, stream, ) - .map_err(|_| { - EmgauwaError::from(ApiError::InternalError(String::from( - "error starting websocket", - ))) - }); + .map_err(|_| EmgauwaError::Internal(String::from("error starting websocket"))); resp } diff --git a/emgauwa-core/src/utils.rs b/emgauwa-core/src/utils.rs index dea14c5..02f1d43 100644 --- a/emgauwa-core/src/utils.rs +++ b/emgauwa-core/src/utils.rs @@ -3,6 +3,14 @@ use std::io::{Error, ErrorKind}; use crate::settings::Settings; +pub fn flatten_result(res: Result, E>) -> Result { + match res { + Ok(Ok(t)) => Ok(t), + Ok(Err(e)) => Err(e), + Err(e) => Err(e), + } +} + // https://blog.lxsang.me/post/id/28.0 pub fn drop_privileges(settings: &Settings) -> Result<(), Error> { log::info!( diff --git a/emgauwa-lib/src/db/mod.rs b/emgauwa-lib/src/db/mod.rs index 145a83f..fd42606 100644 --- a/emgauwa-lib/src/db/mod.rs +++ b/emgauwa-lib/src/db/mod.rs @@ -40,7 +40,10 @@ pub async fn init(db: &str) -> Pool { run_migrations(&pool).await; - let mut pool_conn = pool.acquire().await.unwrap(); + let mut pool_conn = pool + .acquire() + .await + .expect("Failed to acquire pool connection"); DbSchedule::get_on(&mut pool_conn) .await diff --git a/emgauwa-lib/src/db/model_utils.rs b/emgauwa-lib/src/db/model_utils.rs index e1b5835..c3639fe 100644 --- a/emgauwa-lib/src/db/model_utils.rs +++ b/emgauwa-lib/src/db/model_utils.rs @@ -46,8 +46,8 @@ impl Period { pub fn new_on() -> Self { Period { - start: NaiveTime::from_hms_opt(0, 0, 0).unwrap(), - end: NaiveTime::from_hms_opt(0, 0, 0).unwrap(), + start: NaiveTime::MIN, + end: NaiveTime::MIN, } } } @@ -103,8 +103,10 @@ impl From> for DbPeriods { let end_val_h: u32 = value[i - 1] as u32; let end_val_m: u32 = value[i] as u32; vec.push(Period { - start: NaiveTime::from_hms_opt(start_val_h, start_val_m, 0).unwrap(), - end: NaiveTime::from_hms_opt(end_val_h, end_val_m, 0).unwrap(), + start: NaiveTime::from_hms_opt(start_val_h, start_val_m, 0) + .expect("Failed to parse period start time from database"), + end: NaiveTime::from_hms_opt(end_val_h, end_val_m, 0) + .expect("Failed to parse period end time from database"), }); } DbPeriods(vec) diff --git a/emgauwa-lib/src/errors/api_error.rs b/emgauwa-lib/src/errors/api_error.rs index 26f2e4f..bba2248 100644 --- a/emgauwa-lib/src/errors/api_error.rs +++ b/emgauwa-lib/src/errors/api_error.rs @@ -3,14 +3,12 @@ use actix_web::http::StatusCode; #[derive(Debug)] pub enum ApiError { ProtectedSchedule, - InternalError(String), } impl ApiError { pub fn get_code(&self) -> StatusCode { match self { ApiError::ProtectedSchedule => StatusCode::FORBIDDEN, - ApiError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, } } } @@ -19,7 +17,6 @@ impl From<&ApiError> for String { fn from(err: &ApiError) -> Self { match err { ApiError::ProtectedSchedule => String::from("the targeted schedule is protected"), - ApiError::InternalError(msg) => msg.clone(), } } } diff --git a/emgauwa-lib/src/errors/emgauwa_error.rs b/emgauwa-lib/src/errors/emgauwa_error.rs index bbdf84b..ee33ed8 100644 --- a/emgauwa-lib/src/errors/emgauwa_error.rs +++ b/emgauwa-lib/src/errors/emgauwa_error.rs @@ -1,17 +1,22 @@ use std::fmt::{Debug, Display, Formatter}; +use actix::MailboxError; use actix_web::http::StatusCode; use actix_web::HttpResponse; use serde::ser::SerializeStruct; use serde::{Serialize, Serializer}; use crate::errors::{ApiError, DatabaseError}; +use crate::types::ControllerUid; +#[derive(Debug)] pub enum EmgauwaError { Api(ApiError), Uid(uuid::Error), Serialization(serde_json::Error), Database(DatabaseError), + Internal(String), + Connection(ControllerUid), } impl EmgauwaError { @@ -21,6 +26,8 @@ impl EmgauwaError { EmgauwaError::Serialization(_) => StatusCode::INTERNAL_SERVER_ERROR, EmgauwaError::Database(err) => err.get_code(), EmgauwaError::Uid(_) => StatusCode::BAD_REQUEST, + EmgauwaError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR, + EmgauwaError::Connection(_) => StatusCode::GATEWAY_TIMEOUT, } } } @@ -32,6 +39,8 @@ impl From<&EmgauwaError> for String { EmgauwaError::Serialization(_) => String::from("error during (de-)serialization"), EmgauwaError::Database(err) => String::from(err), EmgauwaError::Uid(_) => String::from("the uid is in a bad format"), + EmgauwaError::Internal(_) => String::from("general error"), + EmgauwaError::Connection(_) => String::from("the target controller is not connected"), } } } @@ -66,6 +75,12 @@ impl From for EmgauwaError { } } +impl From for EmgauwaError { + fn from(value: MailboxError) -> Self { + EmgauwaError::Internal(value.to_string()) + } +} + impl From<&EmgauwaError> for HttpResponse { fn from(err: &EmgauwaError) -> Self { HttpResponse::build(err.get_code()).json(err) @@ -90,12 +105,6 @@ impl Display for EmgauwaError { } } -impl Debug for EmgauwaError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", String::from(self)) - } -} - impl actix_web::error::ResponseError for EmgauwaError { fn status_code(&self) -> StatusCode { self.get_code() diff --git a/emgauwa-lib/src/types/controller_uid.rs b/emgauwa-lib/src/types/controller_uid.rs index 7ab51bc..0d45973 100644 --- a/emgauwa-lib/src/types/controller_uid.rs +++ b/emgauwa-lib/src/types/controller_uid.rs @@ -85,7 +85,7 @@ impl TryFrom<&str> for ControllerUid { impl From<&[u8]> for ControllerUid { fn from(value: &[u8]) -> Self { - Self(Uuid::from_slice(value).unwrap()) + Self(Uuid::from_slice(value).expect("Failed to parse controller uid from database")) } } diff --git a/emgauwa-lib/src/types/mod.rs b/emgauwa-lib/src/types/mod.rs index d03678e..3a7046f 100644 --- a/emgauwa-lib/src/types/mod.rs +++ b/emgauwa-lib/src/types/mod.rs @@ -6,8 +6,9 @@ pub use controller_uid::ControllerUid; pub use schedule_uid::ScheduleUid; use serde_derive::{Deserialize, Serialize}; +use crate::db::DbSchedule; use crate::errors::EmgauwaError; -use crate::models::Controller; +use crate::models::{Controller, Relay}; pub type Weekday = i64; @@ -15,4 +16,6 @@ pub type Weekday = i64; #[rtype(result = "Result<(), EmgauwaError>")] pub enum ControllerWsAction { Register(Controller), + Schedules(Vec), + Relays(Vec), } diff --git a/emgauwa-lib/src/types/schedule_uid.rs b/emgauwa-lib/src/types/schedule_uid.rs index d04df24..ebd696d 100644 --- a/emgauwa-lib/src/types/schedule_uid.rs +++ b/emgauwa-lib/src/types/schedule_uid.rs @@ -145,7 +145,9 @@ impl From<&[u8]> for ScheduleUid { match value { [Self::OFF_U8] => Self::Off, [Self::ON_U8] => Self::On, - value_bytes => Self::Any(Uuid::from_slice(value_bytes).unwrap()), + value_bytes => Self::Any( + Uuid::from_slice(value_bytes).expect("Failed to parse schedule uid from database"), + ), } } }