From 5a7b2de0eaccd71c8c59d51d72a6595f507005c8 Mon Sep 17 00:00:00 2001 From: Tobias Reisinger Date: Fri, 1 Dec 2023 18:27:04 +0100 Subject: [PATCH] Refactor errors and some other stuff/fixes --- api.v1.yaml | 18 +- emgauwa-controller/src/main.rs | 51 +++--- emgauwa-core/src/app_state.rs | 71 ++++++++ emgauwa-core/src/handlers/errors.rs | 83 --------- emgauwa-core/src/handlers/mod.rs | 1 - emgauwa-core/src/handlers/v1/controllers.rs | 18 +- emgauwa-core/src/handlers/v1/relays.rs | 20 +-- emgauwa-core/src/handlers/v1/schedules.rs | 28 ++- emgauwa-core/src/handlers/v1/tags.rs | 5 +- .../src/handlers/v1/ws/controllers.rs | 168 ------------------ .../handlers/v1/ws/controllers/handlers.rs | 70 ++++++++ .../src/handlers/v1/ws/controllers/mod.rs | 118 ++++++++++++ emgauwa-core/src/handlers/v1/ws/mod.rs | 17 +- emgauwa-core/src/main.rs | 16 +- emgauwa-lib/src/db/controllers.rs | 2 +- emgauwa-lib/src/db/junction_relay_schedule.rs | 2 +- emgauwa-lib/src/db/junction_tag.rs | 2 +- emgauwa-lib/src/db/mod.rs | 1 - emgauwa-lib/src/db/relays.rs | 2 +- emgauwa-lib/src/db/schedules.rs | 2 +- emgauwa-lib/src/db/tag.rs | 2 +- emgauwa-lib/src/errors/api_error.rs | 25 +++ .../errors.rs => errors/database_error.rs} | 0 emgauwa-lib/src/errors/emgauwa_error.rs | 107 +++++++++++ emgauwa-lib/src/errors/mod.rs | 7 + emgauwa-lib/src/lib.rs | 1 + emgauwa-lib/src/models/mod.rs | 2 +- emgauwa-lib/src/types/mod.rs | 9 +- 28 files changed, 507 insertions(+), 341 deletions(-) create mode 100644 emgauwa-core/src/app_state.rs delete mode 100644 emgauwa-core/src/handlers/errors.rs delete mode 100644 emgauwa-core/src/handlers/v1/ws/controllers.rs create mode 100644 emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs create mode 100644 emgauwa-core/src/handlers/v1/ws/controllers/mod.rs create mode 100644 emgauwa-lib/src/errors/api_error.rs rename emgauwa-lib/src/{db/errors.rs => errors/database_error.rs} (100%) create mode 100644 emgauwa-lib/src/errors/emgauwa_error.rs create mode 100644 emgauwa-lib/src/errors/mod.rs diff --git a/api.v1.yaml b/api.v1.yaml index e4a8bba..186c319 100644 --- a/api.v1.yaml +++ b/api.v1.yaml @@ -748,13 +748,13 @@ components: controller_id: $ref: '#/components/schemas/controller_id' active_schedule: - $ref: '#/components/schemas/schedule' + $ref: '#/components/schemas/schedule-untagged' schedules: type: array maxItems: 7 minItems: 7 items: - $ref: '#/components/schemas/schedule' + $ref: '#/components/schemas/schedule-untagged' tags: type: array items: @@ -762,6 +762,20 @@ components: is_on: type: boolean description: NULL when unknown + schedule-untagged: + title: schedule + type: object + description: '' + properties: + id: + $ref: '#/components/schemas/schedule_id' + name: + type: string + example: Sprinkler Sunny Day + periods: + type: array + items: + $ref: '#/components/schemas/period' schedule: title: schedule type: object diff --git a/emgauwa-controller/src/main.rs b/emgauwa-controller/src/main.rs index 59db1ac..93a08f4 100644 --- a/emgauwa-controller/src/main.rs +++ b/emgauwa-controller/src/main.rs @@ -1,6 +1,6 @@ use emgauwa_lib::constants::WEBSOCKET_RETRY_TIMEOUT; -use emgauwa_lib::db::errors::DatabaseError; use emgauwa_lib::db::{DbController, DbJunctionRelaySchedule, DbRelay, DbSchedule}; +use emgauwa_lib::errors::DatabaseError; use emgauwa_lib::models::{Controller, FromDbModel}; use emgauwa_lib::types::{ControllerUid, ControllerWsAction}; use emgauwa_lib::{db, utils}; @@ -62,7 +62,10 @@ async fn main() { let pool = db::init(&settings.database).await; - let mut conn = pool.acquire().await.unwrap(); + let mut conn = pool + .acquire() + .await + .expect("Failed to get database connection"); let db_controller = DbController::get_all(&mut conn) .await @@ -100,31 +103,33 @@ async fn main() { tokio::spawn(run_relay_loop(settings)); loop { + match connect_async(&url).await { + Ok(connection) => { + let (ws_stream, _) = connection; + + let (mut write, read) = ws_stream.split(); + + let ws_action = ControllerWsAction::Register(this.clone()); + + let ws_action_json = serde_json::to_string(&ws_action).unwrap(); + write.send(Message::text(ws_action_json)).await.unwrap(); + + let read_handler = read.for_each(handle_message); + + read_handler.await; + + log::warn!("Lost connection to websocket"); + } + Err(err) => { + log::warn!("Failed to connect to websocket: {}", err,); + } + } + log::info!( - "Trying to connect in {} seconds...", + "Retrying to connect in {} seconds...", WEBSOCKET_RETRY_TIMEOUT.as_secs() ); time::sleep(WEBSOCKET_RETRY_TIMEOUT).await; - - let connect_result = connect_async(&url).await; - if let Err(err) = connect_result { - log::warn!("Failed to connect to websocket: {}", err,); - continue; - } - let (ws_stream, _) = connect_result.unwrap(); - - let (mut write, read) = ws_stream.split(); - - let ws_action = ControllerWsAction::Register(this.clone()); - - let ws_action_json = serde_json::to_string(&ws_action).unwrap(); - write.send(Message::text(ws_action_json)).await.unwrap(); - - let read_handler = read.for_each(handle_message); - - read_handler.await; - - log::warn!("Lost connection to websocket"); } } diff --git a/emgauwa-core/src/app_state.rs b/emgauwa-core/src/app_state.rs new file mode 100644 index 0000000..bcbd2c6 --- /dev/null +++ b/emgauwa-core/src/app_state.rs @@ -0,0 +1,71 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use actix::{Actor, Context, Handler, Message, Recipient}; +use emgauwa_lib::errors::DatabaseError; +use emgauwa_lib::models::Controller; +use emgauwa_lib::types::{ControllerUid, ControllerWsAction}; +use futures::executor::block_on; +use sqlx::{Pool, Sqlite}; + +#[derive(Message)] +#[rtype(result = "Result<(), DatabaseError>")] +pub struct DisconnectController { + pub controller_uid: ControllerUid, +} + +#[derive(Message)] +#[rtype(result = "Result<(), DatabaseError>")] +pub struct ConnectController { + pub address: Recipient, + pub controller: Controller, +} + +pub struct AppServer { + pub pool: Pool, + pub connected_controllers: Arc>>, +} + +impl AppServer { + pub fn new(pool: Pool) -> AppServer { + AppServer { + pool, + connected_controllers: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +impl Actor for AppServer { + type Context = Context; +} + +impl Handler for AppServer { + type Result = Result<(), DatabaseError>; + + fn handle(&mut self, msg: DisconnectController, _ctx: &mut Self::Context) -> Self::Result { + let mut pool_conn = block_on(self.pool.acquire()).unwrap(); + let mut data = self.connected_controllers.lock().unwrap(); + + if let Some(controller) = data.remove(&msg.controller_uid) { + if let Err(err) = block_on(controller.c.update_active(&mut pool_conn, false)) { + log::error!( + "Failed to mark controller {} as inactive: {:?}", + controller.c.uid, + err + ); + } + } + Ok(()) + } +} + +impl Handler for AppServer { + type Result = Result<(), DatabaseError>; + + fn handle(&mut self, msg: ConnectController, _ctx: &mut Self::Context) -> Self::Result { + let mut data = self.connected_controllers.lock().unwrap(); + data.insert(msg.controller.c.uid.clone(), msg.controller); + + Ok(()) + } +} diff --git a/emgauwa-core/src/handlers/errors.rs b/emgauwa-core/src/handlers/errors.rs deleted file mode 100644 index 367ac42..0000000 --- a/emgauwa-core/src/handlers/errors.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::fmt::{Display, Formatter}; - -use actix_web::http::StatusCode; -use actix_web::HttpResponse; -use emgauwa_lib::db::errors::DatabaseError; -use serde::ser::SerializeStruct; -use serde::{Serialize, Serializer}; - -#[derive(Debug)] -pub enum ApiError { - BadUid, - ProtectedSchedule, - DatabaseError(DatabaseError), - InternalError(String), -} - -impl ApiError { - fn get_code(&self) -> StatusCode { - match self { - ApiError::BadUid => StatusCode::BAD_REQUEST, - ApiError::ProtectedSchedule => StatusCode::FORBIDDEN, - ApiError::DatabaseError(db_error) => db_error.get_code(), - ApiError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, - } - } -} - -impl Serialize for ApiError { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut s = serializer.serialize_struct("error", 2)?; - s.serialize_field("code", &self.get_code().as_u16())?; - s.serialize_field("description", &String::from(self))?; - s.end() - } -} - -impl From<&ApiError> for String { - fn from(err: &ApiError) -> Self { - match err { - ApiError::BadUid => String::from("the uid is in a bad format"), - ApiError::ProtectedSchedule => String::from("the targeted schedule is protected"), - ApiError::DatabaseError(db_err) => String::from(db_err), - ApiError::InternalError(msg) => msg.clone(), - } - } -} - -impl From<&ApiError> for HttpResponse { - fn from(err: &ApiError) -> Self { - HttpResponse::build(err.get_code()).json(err) - } -} - -impl Display for ApiError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}: {}", self.get_code(), String::from(self)) - } -} - -impl actix_web::error::ResponseError for ApiError { - fn status_code(&self) -> StatusCode { - self.get_code() - } - - fn error_response(&self) -> HttpResponse { - HttpResponse::from(self) - } -} - -impl From for ApiError { - fn from(err: sqlx::Error) -> Self { - ApiError::DatabaseError(DatabaseError::from(err)) - } -} - -impl From for ApiError { - fn from(err: DatabaseError) -> Self { - ApiError::DatabaseError(err) - } -} diff --git a/emgauwa-core/src/handlers/mod.rs b/emgauwa-core/src/handlers/mod.rs index 507edc4..3c13f6c 100644 --- a/emgauwa-core/src/handlers/mod.rs +++ b/emgauwa-core/src/handlers/mod.rs @@ -2,7 +2,6 @@ use actix_web::{error, Error, HttpRequest, HttpResponse}; use serde::ser::SerializeStruct; use serde::{Serialize, Serializer}; -pub(crate) mod errors; pub mod v1; enum EmgauwaJsonPayLoadError { diff --git a/emgauwa-core/src/handlers/v1/controllers.rs b/emgauwa-core/src/handlers/v1/controllers.rs index 90a0994..10fb2dc 100644 --- a/emgauwa-core/src/handlers/v1/controllers.rs +++ b/emgauwa-core/src/handlers/v1/controllers.rs @@ -1,20 +1,18 @@ use actix_web::{delete, get, put, web, HttpResponse}; -use emgauwa_lib::db::errors::DatabaseError; use emgauwa_lib::db::DbController; +use emgauwa_lib::errors::{DatabaseError, EmgauwaError}; use emgauwa_lib::models::{convert_db_list, Controller, FromDbModel}; use emgauwa_lib::types::ControllerUid; use serde_derive::{Deserialize, Serialize}; use sqlx::{Pool, Sqlite}; -use crate::handlers::errors::ApiError; - #[derive(Debug, Serialize, Deserialize)] pub struct RequestController { name: String, } #[get("/api/v1/controllers")] -pub async fn index(pool: web::Data>) -> Result { +pub async fn index(pool: web::Data>) -> Result { let mut pool_conn = pool.acquire().await?; let db_controllers = DbController::get_all(&mut pool_conn).await?; @@ -28,11 +26,11 @@ pub async fn index(pool: web::Data>) -> Result>, path: web::Path<(String,)>, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (controller_uid,) = path.into_inner(); - let uid = ControllerUid::try_from(controller_uid.as_str()).or(Err(ApiError::BadUid))?; + let uid = ControllerUid::try_from(controller_uid.as_str())?; let controller = DbController::get_by_uid(&mut pool_conn, &uid) .await? @@ -47,11 +45,11 @@ pub async fn update( pool: web::Data>, path: web::Path<(String,)>, data: web::Json, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (controller_uid,) = path.into_inner(); - let uid = ControllerUid::try_from(controller_uid.as_str()).or(Err(ApiError::BadUid))?; + let uid = ControllerUid::try_from(controller_uid.as_str())?; let controller = DbController::get_by_uid(&mut pool_conn, &uid) .await? @@ -69,11 +67,11 @@ pub async fn update( pub async fn delete( pool: web::Data>, path: web::Path<(String,)>, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (controller_uid,) = path.into_inner(); - let uid = ControllerUid::try_from(controller_uid.as_str()).or(Err(ApiError::BadUid))?; + let uid = ControllerUid::try_from(controller_uid.as_str())?; DbController::delete_by_uid(&mut pool_conn, uid).await?; Ok(HttpResponse::Ok().json("controller got deleted")) diff --git a/emgauwa-core/src/handlers/v1/relays.rs b/emgauwa-core/src/handlers/v1/relays.rs index 5e03d59..cd3ade7 100644 --- a/emgauwa-core/src/handlers/v1/relays.rs +++ b/emgauwa-core/src/handlers/v1/relays.rs @@ -1,13 +1,11 @@ use actix_web::{get, put, web, HttpResponse}; -use emgauwa_lib::db::errors::DatabaseError; use emgauwa_lib::db::{DbController, DbRelay, DbTag}; +use emgauwa_lib::errors::{DatabaseError, EmgauwaError}; use emgauwa_lib::models::{convert_db_list, FromDbModel, Relay}; use emgauwa_lib::types::ControllerUid; use serde::{Deserialize, Serialize}; use sqlx::{Pool, Sqlite}; -use crate::handlers::errors::ApiError; - #[derive(Debug, Serialize, Deserialize)] pub struct RequestRelay { name: String, @@ -15,7 +13,7 @@ pub struct RequestRelay { } #[get("/api/v1/relays")] -pub async fn index(pool: web::Data>) -> Result { +pub async fn index(pool: web::Data>) -> Result { let mut pool_conn = pool.acquire().await?; let db_relays = DbRelay::get_all(&mut pool_conn).await?; @@ -29,7 +27,7 @@ pub async fn index(pool: web::Data>) -> Result>, path: web::Path<(String,)>, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (tag,) = path.into_inner(); @@ -47,11 +45,11 @@ pub async fn tagged( pub async fn index_for_controller( pool: web::Data>, path: web::Path<(String,)>, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (controller_uid,) = path.into_inner(); - let uid = ControllerUid::try_from(controller_uid.as_str()).or(Err(ApiError::BadUid))?; + let uid = ControllerUid::try_from(controller_uid.as_str())?; let controller = DbController::get_by_uid(&mut pool_conn, &uid) .await? @@ -67,11 +65,11 @@ pub async fn index_for_controller( pub async fn show_for_controller( pool: web::Data>, path: web::Path<(String, i64)>, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (controller_uid, relay_num) = path.into_inner(); - let uid = ControllerUid::try_from(controller_uid.as_str()).or(Err(ApiError::BadUid))?; + let uid = ControllerUid::try_from(controller_uid.as_str())?; let controller = DbController::get_by_uid(&mut pool_conn, &uid) .await? @@ -90,11 +88,11 @@ pub async fn update_for_controller( pool: web::Data>, path: web::Path<(String, i64)>, data: web::Json, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (controller_uid, relay_num) = path.into_inner(); - let uid = ControllerUid::try_from(controller_uid.as_str()).or(Err(ApiError::BadUid))?; + let uid = ControllerUid::try_from(controller_uid.as_str())?; let controller = DbController::get_by_uid(&mut pool_conn, &uid) .await? diff --git a/emgauwa-core/src/handlers/v1/schedules.rs b/emgauwa-core/src/handlers/v1/schedules.rs index 81f908b..7e19ffa 100644 --- a/emgauwa-core/src/handlers/v1/schedules.rs +++ b/emgauwa-core/src/handlers/v1/schedules.rs @@ -1,14 +1,12 @@ use actix_web::{delete, get, post, put, web, HttpResponse}; -use emgauwa_lib::db::errors::DatabaseError; use emgauwa_lib::db::{DbPeriods, DbSchedule, DbTag}; +use emgauwa_lib::errors::{ApiError, DatabaseError, EmgauwaError}; use emgauwa_lib::models::{convert_db_list, FromDbModel, Schedule}; use emgauwa_lib::types::ScheduleUid; use serde::{Deserialize, Serialize}; use sqlx::pool::PoolConnection; use sqlx::{Pool, Sqlite}; -use crate::handlers::errors::ApiError; - #[derive(Debug, Serialize, Deserialize)] pub struct RequestSchedule { name: String, @@ -17,7 +15,7 @@ pub struct RequestSchedule { } #[get("/api/v1/schedules")] -pub async fn index(pool: web::Data>) -> Result { +pub async fn index(pool: web::Data>) -> Result { let mut pool_conn = pool.acquire().await?; let db_schedules = DbSchedule::get_all(&mut pool_conn).await?; @@ -30,7 +28,7 @@ pub async fn index(pool: web::Data>) -> Result>, path: web::Path<(String,)>, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (tag,) = path.into_inner(); @@ -48,11 +46,11 @@ pub async fn tagged( pub async fn show( pool: web::Data>, path: web::Path<(String,)>, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (schedule_uid,) = path.into_inner(); - let uid = ScheduleUid::try_from(schedule_uid.as_str()).or(Err(ApiError::BadUid))?; + let uid = ScheduleUid::try_from(schedule_uid.as_str())?; let schedule = DbSchedule::get_by_uid(&mut pool_conn, &uid) .await? @@ -66,7 +64,7 @@ pub async fn show( pub async fn add( pool: web::Data>, data: web::Json, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let new_schedule = DbSchedule::create( @@ -108,7 +106,7 @@ async fn add_list_single( pub async fn add_list( pool: web::Data>, data: web::Json>, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let mut db_schedules: Vec = Vec::new(); @@ -126,11 +124,11 @@ pub async fn update( pool: web::Data>, path: web::Path<(String,)>, data: web::Json, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (schedule_uid,) = path.into_inner(); - let uid = ScheduleUid::try_from(schedule_uid.as_str()).or(Err(ApiError::BadUid))?; + let uid = ScheduleUid::try_from(schedule_uid.as_str())?; let schedule = DbSchedule::get_by_uid(&mut pool_conn, &uid) .await? @@ -152,15 +150,15 @@ pub async fn update( pub async fn delete( pool: web::Data>, path: web::Path<(String,)>, -) -> Result { +) -> Result { let mut pool_conn = pool.acquire().await?; let (schedule_uid,) = path.into_inner(); - let uid = ScheduleUid::try_from(schedule_uid.as_str()).or(Err(ApiError::BadUid))?; + let uid = ScheduleUid::try_from(schedule_uid.as_str())?; match uid { - ScheduleUid::Off => Err(ApiError::ProtectedSchedule), - ScheduleUid::On => Err(ApiError::ProtectedSchedule), + ScheduleUid::Off => Err(EmgauwaError::from(ApiError::ProtectedSchedule)), + ScheduleUid::On => Err(EmgauwaError::from(ApiError::ProtectedSchedule)), ScheduleUid::Any(_) => { DbSchedule::delete_by_uid(&mut pool_conn, uid).await?; Ok(HttpResponse::Ok().json("schedule got deleted")) diff --git a/emgauwa-core/src/handlers/v1/tags.rs b/emgauwa-core/src/handlers/v1/tags.rs index e5eb0e9..5b03038 100644 --- a/emgauwa-core/src/handlers/v1/tags.rs +++ b/emgauwa-core/src/handlers/v1/tags.rs @@ -1,11 +1,10 @@ use actix_web::{get, web, HttpResponse}; use emgauwa_lib::db::DbTag; +use emgauwa_lib::errors::EmgauwaError; use sqlx::{Pool, Sqlite}; -use crate::handlers::errors::ApiError; - #[get("/api/v1/tags")] -pub async fn index(pool: web::Data>) -> Result { +pub async fn index(pool: web::Data>) -> Result { let mut pool_conn = pool.acquire().await?; let db_tags = DbTag::get_all(&mut pool_conn).await?; diff --git a/emgauwa-core/src/handlers/v1/ws/controllers.rs b/emgauwa-core/src/handlers/v1/ws/controllers.rs deleted file mode 100644 index 1ec5b12..0000000 --- a/emgauwa-core/src/handlers/v1/ws/controllers.rs +++ /dev/null @@ -1,168 +0,0 @@ -use std::time::Instant; - -use actix::{Actor, ActorContext, AsyncContext, StreamHandler}; -use actix_web_actors::ws; -use actix_web_actors::ws::ProtocolError; -use emgauwa_lib::constants::{HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT}; -use emgauwa_lib::db::errors::DatabaseError; -use emgauwa_lib::db::{DbController, DbJunctionRelaySchedule, DbRelay, DbSchedule}; -use emgauwa_lib::models::{Controller, FromDbModel}; -use emgauwa_lib::types::{ConnectedControllersType, ControllerUid, ControllerWsAction}; -use futures::executor::block_on; -use sqlx::pool::PoolConnection; -use sqlx::{Pool, Sqlite}; -use ws::Message; - -pub struct ControllerWs { - pub pool: Pool, - pub controller_uid: Option, - pub connected_controllers: ConnectedControllersType, - pub hb: Instant, -} - -impl Actor for ControllerWs { - type Context = ws::WebsocketContext; - - fn started(&mut self, ctx: &mut Self::Context) { - self.hb(ctx); - } - - fn stopped(&mut self, _ctx: &mut Self::Context) { - if let Some(controller_uid) = &self.controller_uid { - let mut pool_conn = block_on(self.pool.acquire()).unwrap(); - - let mut data = self.connected_controllers.lock().unwrap(); - if let Some(controller) = data.remove(controller_uid) { - if let Err(err) = block_on(controller.c.update_active(&mut pool_conn, false)) { - log::error!( - "Failed to mark controller {} as inactive: {:?}", - controller.c.uid, - err - ) - } - } - } - } -} - -impl ControllerWs { - pub fn handle_action( - &mut self, - conn: &mut PoolConnection, - action: ControllerWsAction, - ) -> Result<(), DatabaseError> { - match action { - ControllerWsAction::Register(controller) => { - log::info!("Registering controller: {:?}", controller); - let c = &controller.c; - let controller_db = block_on(DbController::get_by_uid_or_create( - conn, - &c.uid, - &c.name, - c.relay_count, - ))?; - block_on(controller_db.update_active(conn, true))?; - - for relay in &controller.relays { - let (new_relay, created) = - block_on(DbRelay::get_by_controller_and_num_or_create( - conn, - &controller_db, - relay.r.number, - &relay.r.name, - ))?; - if created { - let mut relay_schedules = Vec::new(); - for schedule in &relay.schedules { - let (new_schedule, _) = block_on(DbSchedule::get_by_uid_or_create( - conn, - schedule.uid.clone(), - &schedule.name, - &schedule.periods, - ))?; - relay_schedules.push(new_schedule); - } - - block_on(DbJunctionRelaySchedule::set_schedules( - conn, - &new_relay, - relay_schedules.iter().collect(), - ))?; - } - } - - let controller_uid = &controller.c.uid; - let controller_db = block_on(DbController::get_by_uid(conn, controller_uid))? - .ok_or(DatabaseError::InsertGetError)?; - let controller = Controller::from_db_model(conn, controller_db)?; - - self.controller_uid = Some(controller_uid.clone()); - - let mut data = self.connected_controllers.lock().unwrap(); - data.insert(controller_uid.clone(), controller); - - Ok(()) - } - } - } - - // helper method that sends ping to client every 5 seconds (HEARTBEAT_INTERVAL). - fn hb(&self, ctx: &mut ws::WebsocketContext) { - ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { - // check client heartbeats - if Instant::now().duration_since(act.hb) > HEARTBEAT_TIMEOUT { - log::warn!("Websocket Controller heartbeat failed, disconnecting!"); - ctx.stop(); - // don't try to send a ping - return; - } - - ctx.ping(&[]); - }); - } -} - -impl StreamHandler> for ControllerWs { - fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { - let mut pool_conn = block_on(self.pool.acquire()).unwrap(); - - let msg = match msg { - Err(_) => { - ctx.stop(); - return; - } - Ok(msg) => msg, - }; - - match msg { - Message::Ping(msg) => { - self.hb = Instant::now(); - ctx.pong(&msg) - } - Message::Pong(_) => { - self.hb = Instant::now(); - } - Message::Text(text) => { - let action: ControllerWsAction = serde_json::from_str(&text).unwrap(); - let action_res = self.handle_action(&mut pool_conn, action); - if let Err(e) = action_res { - log::error!("Error handling action: {:?}", e); - ctx.text(serde_json::to_string(&e).unwrap()); - } - } - Message::Binary(_) => log::warn!("Received unexpected binary in controller ws"), - Message::Close(reason) => { - ctx.close(reason); - ctx.stop(); - } - Message::Continuation(_) => { - ctx.stop(); - } - Message::Nop => (), - } - - //let schedules = futures::executor::block_on(DbSchedule::get_all(&mut pool_conn)).unwrap(); - //let schedules_json = serde_json::to_string(&schedules).unwrap(); - //ctx.text(schedules_json); - } -} diff --git a/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs b/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs new file mode 100644 index 0000000..68c5ad0 --- /dev/null +++ b/emgauwa-core/src/handlers/v1/ws/controllers/handlers.rs @@ -0,0 +1,70 @@ +use actix::{Actor, AsyncContext}; +use emgauwa_lib::db::{DbController, DbJunctionRelaySchedule, DbRelay, DbSchedule}; +use emgauwa_lib::errors::DatabaseError; +use emgauwa_lib::models::{Controller, FromDbModel}; +use futures::executor::block_on; +use sqlx::pool::PoolConnection; +use sqlx::Sqlite; + +use crate::app_state::ConnectController; +use crate::handlers::v1::ws::controllers::ControllerWs; + +impl ControllerWs { + pub fn handle_register( + &mut self, + conn: &mut PoolConnection, + ctx: &mut ::Context, + controller: Controller, + ) -> Result<(), DatabaseError> { + log::info!("Registering controller: {:?}", controller); + let c = &controller.c; + let controller_db = block_on(DbController::get_by_uid_or_create( + conn, + &c.uid, + &c.name, + c.relay_count, + ))?; + block_on(controller_db.update_active(conn, true))?; + + for relay in &controller.relays { + let (new_relay, created) = block_on(DbRelay::get_by_controller_and_num_or_create( + conn, + &controller_db, + relay.r.number, + &relay.r.name, + ))?; + if created { + let mut relay_schedules = Vec::new(); + for schedule in &relay.schedules { + let (new_schedule, _) = block_on(DbSchedule::get_by_uid_or_create( + conn, + schedule.uid.clone(), + &schedule.name, + &schedule.periods, + ))?; + relay_schedules.push(new_schedule); + } + + block_on(DbJunctionRelaySchedule::set_schedules( + conn, + &new_relay, + relay_schedules.iter().collect(), + ))?; + } + } + + let controller_uid = &controller.c.uid; + let controller_db = block_on(DbController::get_by_uid(conn, controller_uid))? + .ok_or(DatabaseError::InsertGetError)?; + let controller = Controller::from_db_model(conn, controller_db)?; + + let addr = ctx.address(); + self.controller_uid = Some(controller_uid.clone()); + self.app_server.do_send(ConnectController { + address: addr.recipient(), + controller, + }); + + Ok(()) + } +} diff --git a/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs b/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs new file mode 100644 index 0000000..bf58deb --- /dev/null +++ b/emgauwa-core/src/handlers/v1/ws/controllers/mod.rs @@ -0,0 +1,118 @@ +mod handlers; + +use std::time::Instant; + +use actix::{Actor, ActorContext, Addr, AsyncContext, Handler, StreamHandler}; +use actix_web_actors::ws; +use actix_web_actors::ws::ProtocolError; +use emgauwa_lib::constants::{HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT}; +use emgauwa_lib::errors::{DatabaseError, EmgauwaError}; +use emgauwa_lib::types::{ControllerUid, ControllerWsAction}; +use futures::executor::block_on; +use sqlx::pool::PoolConnection; +use sqlx::{Pool, Sqlite}; +use ws::Message; + +use crate::app_state::{AppServer, DisconnectController}; + +pub struct ControllerWs { + pub pool: Pool, + pub controller_uid: Option, + pub app_server: Addr, + pub hb: Instant, +} + +impl Actor for ControllerWs { + type Context = ws::WebsocketContext; + + fn started(&mut self, ctx: &mut Self::Context) { + self.hb(ctx); + } + + fn stopped(&mut self, _ctx: &mut Self::Context) { + if let Some(controller_uid) = &self.controller_uid { + self.app_server.do_send(DisconnectController { + controller_uid: controller_uid.clone(), + }) + } + } +} + +impl ControllerWs { + pub fn handle_action( + &mut self, + conn: &mut PoolConnection, + ctx: &mut ::Context, + action: ControllerWsAction, + ) -> Result<(), DatabaseError> { + match action { + ControllerWsAction::Register(controller) => self.handle_register(conn, ctx, controller), + } + } + + // helper method that sends ping to client every 5 seconds (HEARTBEAT_INTERVAL). + fn hb(&self, ctx: &mut ws::WebsocketContext) { + ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { + // check client heartbeats + if Instant::now().duration_since(act.hb) > HEARTBEAT_TIMEOUT { + log::warn!("Websocket Controller heartbeat failed, disconnecting!"); + ctx.stop(); + // don't try to send a ping + return; + } + + ctx.ping(&[]); + }); + } +} + +impl Handler for ControllerWs { + type Result = Result<(), EmgauwaError>; + + fn handle(&mut self, action: ControllerWsAction, ctx: &mut Self::Context) -> Self::Result { + let action_json = serde_json::to_string(&action)?; + ctx.text(action_json); + Ok(()) + } +} + +impl StreamHandler> for ControllerWs { + fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { + let mut pool_conn = block_on(self.pool.acquire()).unwrap(); + + let msg = match msg { + Err(_) => { + ctx.stop(); + return; + } + Ok(msg) => msg, + }; + + match msg { + Message::Ping(msg) => { + self.hb = Instant::now(); + ctx.pong(&msg) + } + Message::Pong(_) => { + self.hb = Instant::now(); + } + Message::Text(text) => { + let action: ControllerWsAction = serde_json::from_str(&text).unwrap(); + let action_res = self.handle_action(&mut pool_conn, ctx, action); + if let Err(e) = action_res { + log::error!("Error handling action: {:?}", e); + ctx.text(serde_json::to_string(&e).unwrap()); + } + } + Message::Binary(_) => log::warn!("Received unexpected binary in controller ws"), + Message::Close(reason) => { + ctx.close(reason); + ctx.stop(); + } + Message::Continuation(_) => { + ctx.stop(); + } + Message::Nop => (), + } + } +} diff --git a/emgauwa-core/src/handlers/v1/ws/mod.rs b/emgauwa-core/src/handlers/v1/ws/mod.rs index e3ebdd0..8a45ed0 100644 --- a/emgauwa-core/src/handlers/v1/ws/mod.rs +++ b/emgauwa-core/src/handlers/v1/ws/mod.rs @@ -1,11 +1,12 @@ use std::time::Instant; +use actix::Addr; use actix_web::{get, web, HttpRequest, HttpResponse}; use actix_web_actors::ws; -use emgauwa_lib::types::ConnectedControllersType; +use emgauwa_lib::errors::{ApiError, EmgauwaError}; use sqlx::{Pool, Sqlite}; -use crate::handlers::errors::ApiError; +use crate::app_state::AppServer; use crate::handlers::v1::ws::controllers::ControllerWs; pub mod controllers; @@ -13,20 +14,24 @@ pub mod controllers; #[get("/api/v1/ws/controllers")] pub async fn ws_controllers( pool: web::Data>, - connected_controllers: web::Data, + app_server: web::Data>, req: HttpRequest, stream: web::Payload, -) -> Result { +) -> Result { let resp = ws::start( ControllerWs { pool: pool.get_ref().clone(), controller_uid: None, - connected_controllers: connected_controllers.get_ref().clone(), + app_server: app_server.get_ref().clone(), hb: Instant::now(), }, &req, stream, ) - .map_err(|_| ApiError::InternalError(String::from("error starting websocket"))); + .map_err(|_| { + EmgauwaError::from(ApiError::InternalError(String::from( + "error starting websocket", + ))) + }); resp } diff --git a/emgauwa-core/src/main.rs b/emgauwa-core/src/main.rs index fc4740c..8e83d16 100644 --- a/emgauwa-core/src/main.rs +++ b/emgauwa-core/src/main.rs @@ -1,16 +1,16 @@ -use std::collections::HashMap; use std::net::TcpListener; -use std::sync::{Arc, Mutex}; +use actix::Actor; use actix_cors::Cors; use actix_web::middleware::TrailingSlash; use actix_web::{middleware, web, App, HttpServer}; use emgauwa_lib::db::DbController; -use emgauwa_lib::types::ConnectedControllersType; use emgauwa_lib::utils::init_logging; +use crate::app_state::AppServer; use crate::utils::drop_privileges; +mod app_state; mod handlers; mod settings; mod utils; @@ -29,15 +29,19 @@ async fn main() -> std::io::Result<()> { // This block is to ensure that the connection is dropped after use. { - let mut conn = pool.acquire().await.unwrap(); + let mut conn = pool + .acquire() + .await + .expect("Failed to get database connection"); DbController::all_inactive(&mut conn) .await .expect("Error setting all controllers inactive"); } - let connected_controllers: ConnectedControllersType = Arc::new(Mutex::new(HashMap::new())); + let app_server = AppServer::new(pool.clone()).start(); log::info!("Starting server on {}:{}", settings.host, settings.port); + HttpServer::new(move || { let cors = Cors::default().allow_any_method().allow_any_header(); @@ -55,7 +59,7 @@ async fn main() -> std::io::Result<()> { .wrap(middleware::NormalizePath::new(TrailingSlash::Trim)) .app_data(web::JsonConfig::default().error_handler(handlers::json_error_handler)) .app_data(web::Data::new(pool.clone())) - .app_data(web::Data::new(connected_controllers.clone())) + .app_data(web::Data::new(app_server.clone())) .service(handlers::v1::controllers::index) .service(handlers::v1::controllers::show) .service(handlers::v1::controllers::update) diff --git a/emgauwa-lib/src/db/controllers.rs b/emgauwa-lib/src/db/controllers.rs index 9b83768..1b316d7 100644 --- a/emgauwa-lib/src/db/controllers.rs +++ b/emgauwa-lib/src/db/controllers.rs @@ -4,8 +4,8 @@ use serde_derive::{Deserialize, Serialize}; use sqlx::pool::PoolConnection; use sqlx::Sqlite; -use crate::db::errors::DatabaseError; use crate::db::{DbRelay, DbTag}; +use crate::errors::DatabaseError; use crate::types::ControllerUid; #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/emgauwa-lib/src/db/junction_relay_schedule.rs b/emgauwa-lib/src/db/junction_relay_schedule.rs index b331610..a090604 100644 --- a/emgauwa-lib/src/db/junction_relay_schedule.rs +++ b/emgauwa-lib/src/db/junction_relay_schedule.rs @@ -3,8 +3,8 @@ use std::ops::DerefMut; use sqlx::pool::PoolConnection; use sqlx::Sqlite; -use crate::db::errors::DatabaseError; use crate::db::{DbRelay, DbSchedule}; +use crate::errors::DatabaseError; use crate::types::Weekday; pub struct DbJunctionRelaySchedule { diff --git a/emgauwa-lib/src/db/junction_tag.rs b/emgauwa-lib/src/db/junction_tag.rs index 8229636..f1b8816 100644 --- a/emgauwa-lib/src/db/junction_tag.rs +++ b/emgauwa-lib/src/db/junction_tag.rs @@ -3,8 +3,8 @@ use std::ops::DerefMut; use sqlx::pool::PoolConnection; use sqlx::Sqlite; -use crate::db::errors::DatabaseError; use crate::db::{DbRelay, DbSchedule, DbTag}; +use crate::errors::DatabaseError; pub struct DbJunctionTag { pub id: i64, diff --git a/emgauwa-lib/src/db/mod.rs b/emgauwa-lib/src/db/mod.rs index e0ccaeb..145a83f 100644 --- a/emgauwa-lib/src/db/mod.rs +++ b/emgauwa-lib/src/db/mod.rs @@ -5,7 +5,6 @@ use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use sqlx::{Pool, Sqlite}; mod controllers; -pub mod errors; mod junction_relay_schedule; mod junction_tag; mod model_utils; diff --git a/emgauwa-lib/src/db/relays.rs b/emgauwa-lib/src/db/relays.rs index a5033eb..35468cc 100644 --- a/emgauwa-lib/src/db/relays.rs +++ b/emgauwa-lib/src/db/relays.rs @@ -4,8 +4,8 @@ use serde_derive::{Deserialize, Serialize}; use sqlx::pool::PoolConnection; use sqlx::Sqlite; -use crate::db::errors::DatabaseError; use crate::db::{DbController, DbJunctionTag, DbTag}; +use crate::errors::DatabaseError; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbRelay { diff --git a/emgauwa-lib/src/db/schedules.rs b/emgauwa-lib/src/db/schedules.rs index 2e72061..39bdb8b 100644 --- a/emgauwa-lib/src/db/schedules.rs +++ b/emgauwa-lib/src/db/schedules.rs @@ -5,9 +5,9 @@ use serde_derive::{Deserialize, Serialize}; use sqlx::pool::PoolConnection; use sqlx::Sqlite; -use crate::db::errors::DatabaseError; use crate::db::model_utils::Period; use crate::db::{DbJunctionTag, DbTag}; +use crate::errors::DatabaseError; use crate::types::ScheduleUid; #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/emgauwa-lib/src/db/tag.rs b/emgauwa-lib/src/db/tag.rs index bd3e1cd..abd964d 100644 --- a/emgauwa-lib/src/db/tag.rs +++ b/emgauwa-lib/src/db/tag.rs @@ -4,7 +4,7 @@ use serde_derive::Serialize; use sqlx::pool::PoolConnection; use sqlx::Sqlite; -use crate::db::errors::DatabaseError; +use crate::errors::DatabaseError; #[derive(Debug, Serialize, Clone)] pub struct DbTag { diff --git a/emgauwa-lib/src/errors/api_error.rs b/emgauwa-lib/src/errors/api_error.rs new file mode 100644 index 0000000..26f2e4f --- /dev/null +++ b/emgauwa-lib/src/errors/api_error.rs @@ -0,0 +1,25 @@ +use actix_web::http::StatusCode; + +#[derive(Debug)] +pub enum ApiError { + ProtectedSchedule, + InternalError(String), +} + +impl ApiError { + pub fn get_code(&self) -> StatusCode { + match self { + ApiError::ProtectedSchedule => StatusCode::FORBIDDEN, + ApiError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +impl From<&ApiError> for String { + fn from(err: &ApiError) -> Self { + match err { + ApiError::ProtectedSchedule => String::from("the targeted schedule is protected"), + ApiError::InternalError(msg) => msg.clone(), + } + } +} diff --git a/emgauwa-lib/src/db/errors.rs b/emgauwa-lib/src/errors/database_error.rs similarity index 100% rename from emgauwa-lib/src/db/errors.rs rename to emgauwa-lib/src/errors/database_error.rs diff --git a/emgauwa-lib/src/errors/emgauwa_error.rs b/emgauwa-lib/src/errors/emgauwa_error.rs new file mode 100644 index 0000000..bbdf84b --- /dev/null +++ b/emgauwa-lib/src/errors/emgauwa_error.rs @@ -0,0 +1,107 @@ +use std::fmt::{Debug, Display, Formatter}; + +use actix_web::http::StatusCode; +use actix_web::HttpResponse; +use serde::ser::SerializeStruct; +use serde::{Serialize, Serializer}; + +use crate::errors::{ApiError, DatabaseError}; + +pub enum EmgauwaError { + Api(ApiError), + Uid(uuid::Error), + Serialization(serde_json::Error), + Database(DatabaseError), +} + +impl EmgauwaError { + fn get_code(&self) -> StatusCode { + match self { + EmgauwaError::Api(err) => err.get_code(), + EmgauwaError::Serialization(_) => StatusCode::INTERNAL_SERVER_ERROR, + EmgauwaError::Database(err) => err.get_code(), + EmgauwaError::Uid(_) => StatusCode::BAD_REQUEST, + } + } +} + +impl From<&EmgauwaError> for String { + fn from(err: &EmgauwaError) -> Self { + match err { + EmgauwaError::Api(err) => String::from(err), + EmgauwaError::Serialization(_) => String::from("error during (de-)serialization"), + EmgauwaError::Database(err) => String::from(err), + EmgauwaError::Uid(_) => String::from("the uid is in a bad format"), + } + } +} + +impl From for EmgauwaError { + fn from(value: ApiError) -> Self { + EmgauwaError::Api(value) + } +} + +impl From for EmgauwaError { + fn from(value: DatabaseError) -> Self { + EmgauwaError::Database(value) + } +} + +impl From for EmgauwaError { + fn from(value: serde_json::Error) -> Self { + EmgauwaError::Serialization(value) + } +} + +impl From for EmgauwaError { + fn from(value: sqlx::Error) -> Self { + EmgauwaError::Database(DatabaseError::from(value)) + } +} + +impl From for EmgauwaError { + fn from(value: uuid::Error) -> Self { + EmgauwaError::Uid(value) + } +} + +impl From<&EmgauwaError> for HttpResponse { + fn from(err: &EmgauwaError) -> Self { + HttpResponse::build(err.get_code()).json(err) + } +} + +impl Serialize for EmgauwaError { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut s = serializer.serialize_struct("error", 2)?; + s.serialize_field("code", &self.get_code().as_u16())?; + s.serialize_field("description", &String::from(self))?; + s.end() + } +} + +impl Display for EmgauwaError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {}", self.get_code(), String::from(self)) + } +} + +impl Debug for EmgauwaError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", String::from(self)) + } +} + +impl actix_web::error::ResponseError for EmgauwaError { + fn status_code(&self) -> StatusCode { + self.get_code() + } + + fn error_response(&self) -> HttpResponse { + HttpResponse::from(self) + } +} diff --git a/emgauwa-lib/src/errors/mod.rs b/emgauwa-lib/src/errors/mod.rs new file mode 100644 index 0000000..0209bcc --- /dev/null +++ b/emgauwa-lib/src/errors/mod.rs @@ -0,0 +1,7 @@ +mod api_error; +mod database_error; +mod emgauwa_error; + +pub use api_error::ApiError; +pub use database_error::DatabaseError; +pub use emgauwa_error::EmgauwaError; diff --git a/emgauwa-lib/src/lib.rs b/emgauwa-lib/src/lib.rs index dc3750e..ddaf738 100644 --- a/emgauwa-lib/src/lib.rs +++ b/emgauwa-lib/src/lib.rs @@ -1,5 +1,6 @@ pub mod constants; pub mod db; +pub mod errors; pub mod models; pub mod types; pub mod utils; diff --git a/emgauwa-lib/src/models/mod.rs b/emgauwa-lib/src/models/mod.rs index 187d640..031e44f 100644 --- a/emgauwa-lib/src/models/mod.rs +++ b/emgauwa-lib/src/models/mod.rs @@ -3,8 +3,8 @@ use serde_derive::{Deserialize, Serialize}; use sqlx::pool::PoolConnection; use sqlx::Sqlite; -use crate::db::errors::DatabaseError; use crate::db::{DbController, DbJunctionRelaySchedule, DbRelay, DbSchedule}; +use crate::errors::DatabaseError; use crate::types::{ControllerUid, Weekday}; use crate::utils; diff --git a/emgauwa-lib/src/types/mod.rs b/emgauwa-lib/src/types/mod.rs index 838ddec..d03678e 100644 --- a/emgauwa-lib/src/types/mod.rs +++ b/emgauwa-lib/src/types/mod.rs @@ -1,19 +1,18 @@ mod controller_uid; mod schedule_uid; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; - +use actix::Message; pub use controller_uid::ControllerUid; pub use schedule_uid::ScheduleUid; use serde_derive::{Deserialize, Serialize}; +use crate::errors::EmgauwaError; use crate::models::Controller; -pub type ConnectedControllersType = Arc>>; pub type Weekday = i64; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Message)] +#[rtype(result = "Result<(), EmgauwaError>")] pub enum ControllerWsAction { Register(Controller), }