diff --git a/Cargo.lock b/Cargo.lock index 3f48599..d56fed7 100644 Binary files a/Cargo.lock and b/Cargo.lock differ diff --git a/src/app_state.rs b/src/app_state.rs index 743270f..ec9e262 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -86,10 +86,9 @@ impl Handler for AppState { fn handle(&mut self, _msg: Reload, _ctx: &mut Self::Context) -> Self::Result { log::debug!("Reloading controller"); - let mut tx = block_on(self.pool.begin())?; + let mut pool_conn = block_on(self.pool.acquire())?; - self.this.reload(&mut tx)?; - block_on(tx.commit())?; + self.this.reload(&mut pool_conn)?; self.notify_controller_change(); diff --git a/src/main.rs b/src/main.rs index d18fec0..1ba2d42 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ use emgauwa_common::models::{Controller, FromDbModel}; use emgauwa_common::types::EmgauwaUid; use emgauwa_common::utils::{drop_privileges, init_logging}; use rppal_pfd::PiFaceDigital; -use sqlx::Transaction; +use sqlx::pool::PoolConnection; use sqlx::Sqlite; use crate::relay_loop::run_relays_loop; @@ -21,11 +21,11 @@ mod utils; mod ws; async fn create_this_controller( - tx: &mut Transaction<'_, Sqlite>, + conn: &mut PoolConnection, settings: &Settings, ) -> Result { DbController::create( - tx, + conn, &EmgauwaUid::default(), &settings.name, settings.relays.len() as i64, @@ -35,12 +35,12 @@ async fn create_this_controller( } async fn create_this_relay( - tx: &mut Transaction<'_, Sqlite>, + conn: &mut PoolConnection, this_controller: &DbController, settings_relay: &settings::Relay, ) -> Result { let relay = DbRelay::create( - tx, + conn, &settings_relay.name, settings_relay.number.ok_or(EmgauwaError::Internal( "Relay number is missing".to_string(), @@ -49,9 +49,9 @@ async fn create_this_relay( ) .await?; - let off = DbSchedule::get_off(tx).await?; + let off = DbSchedule::get_off(conn).await?; for weekday in 0..7 { - DbJunctionRelaySchedule::set_schedule(tx, &relay, &off, weekday).await?; + DbJunctionRelaySchedule::set_schedule(conn, &relay, &off, weekday).await?; } Ok(relay) @@ -72,20 +72,20 @@ async fn main() -> Result<(), std::io::Error> { .await .map_err(EmgauwaError::from)?; - let mut tx = pool.begin().await.map_err(EmgauwaError::from)?; + let mut conn = pool.acquire().await.map_err(EmgauwaError::from)?; - let db_controller = match DbController::get_all(&mut tx) + let db_controller = match DbController::get_all(&mut conn) .await .map_err(EmgauwaError::from)? .pop() { - None => futures::executor::block_on(create_this_controller(&mut tx, &settings))?, + None => futures::executor::block_on(create_this_controller(&mut conn, &settings))?, Some(c) => c, }; for relay in &settings.relays { if DbRelay::get_by_controller_and_num( - &mut tx, + &mut conn, &db_controller, relay.number.ok_or(EmgauwaError::Internal( "Relay number is missing".to_string(), @@ -95,19 +95,18 @@ async fn main() -> Result<(), std::io::Error> { .map_err(EmgauwaError::from)? .is_none() { - create_this_relay(&mut tx, &db_controller, relay) + create_this_relay(&mut conn, &db_controller, relay) .await .map_err(EmgauwaError::from)?; } } let db_controller = db_controller - .update(&mut tx, &db_controller.name, settings.relays.len() as i64) + .update(&mut conn, &db_controller.name, settings.relays.len() as i64) .await .map_err(EmgauwaError::from)?; - let this = Controller::from_db_model(&mut tx, db_controller).map_err(EmgauwaError::from)?; - tx.commit().await.map_err(EmgauwaError::from)?; + let this = Controller::from_db_model(&mut conn, db_controller).map_err(EmgauwaError::from)?; let url = format!( "ws://{}:{}/api/v1/ws/controllers", diff --git a/src/ws/mod.rs b/src/ws/mod.rs index cf496c1..994cec6 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -5,7 +5,7 @@ use emgauwa_common::errors::{DatabaseError, EmgauwaError}; use emgauwa_common::models::{Controller, Relay}; use emgauwa_common::types::{ControllerWsAction, ScheduleUid}; use futures::{future, pin_mut, SinkExt, StreamExt}; -use sqlx::Transaction; +use sqlx::pool::PoolConnection; use sqlx::{Pool, Sqlite}; use tokio::time; use tokio_tungstenite::tungstenite::Message; @@ -108,7 +108,14 @@ async fn handle_message( match serde_json::from_str(&text) { Ok(action) => { log::debug!("Received action: {:?}", action); - let action_res = handle_action(pool, app_state, action).await; + let mut pool_conn = match pool.acquire().await { + Ok(conn) => conn, + Err(err) => { + log::error!("Failed to acquire database connection: {:?}", err); + return; + } + }; + let action_res = handle_action(&mut pool_conn, app_state, action).await; if let Err(e) = action_res { log::error!("Error handling action: {:?}", e); } @@ -121,31 +128,29 @@ async fn handle_message( } pub async fn handle_action( - pool: Pool, + conn: &mut PoolConnection, app_state: &Addr, action: ControllerWsAction, ) -> Result<(), EmgauwaError> { let this = app_state_get_this(app_state).await?; - let mut tx = pool.begin().await?; match action { ControllerWsAction::Controller(controller) => { - handle_controller(&mut tx, &this, controller).await? + handle_controller(conn, &this, controller).await? } - ControllerWsAction::Relays(relays) => handle_relays(&mut tx, &this, relays).await?, - ControllerWsAction::Schedules(schedules) => handle_schedules(&mut tx, schedules).await?, + ControllerWsAction::Relays(relays) => handle_relays(conn, &this, relays).await?, + ControllerWsAction::Schedules(schedules) => handle_schedules(conn, schedules).await?, ControllerWsAction::RelayPulse((relay_num, duration)) => { handle_relay_pulse(app_state, relay_num, duration).await? } _ => return Ok(()), }; - tx.commit().await?; utils::app_state_reload(app_state).await } async fn handle_controller( - tx: &mut Transaction<'_, Sqlite>, + conn: &mut PoolConnection, this: &Controller, controller: Controller, ) -> Result<(), EmgauwaError> { @@ -154,17 +159,17 @@ async fn handle_controller( "Controller UID mismatch during update", ))); } - DbController::get_by_uid(tx, &controller.c.uid) + DbController::get_by_uid(conn, &controller.c.uid) .await? .ok_or(DatabaseError::NotFound)? - .update(tx, controller.c.name.as_str(), this.c.relay_count) + .update(conn, controller.c.name.as_str(), this.c.relay_count) .await?; Ok(()) } async fn handle_schedules( - tx: &mut Transaction<'_, Sqlite>, + conn: &mut PoolConnection, schedules: Vec, ) -> Result<(), EmgauwaError> { let mut handled_uids = vec![ @@ -179,15 +184,15 @@ async fn handle_schedules( handled_uids.push(schedule.uid.clone()); log::debug!("Handling schedule: {:?}", schedule); - let schedule_db = DbSchedule::get_by_uid(tx, &schedule.uid).await?; + let schedule_db = DbSchedule::get_by_uid(conn, &schedule.uid).await?; if let Some(schedule_db) = schedule_db { schedule_db - .update(tx, schedule.name.as_str(), &schedule.periods) + .update(conn, schedule.name.as_str(), &schedule.periods) .await?; } else { DbSchedule::create( - tx, + conn, schedule.uid.clone(), schedule.name.as_str(), &schedule.periods, @@ -200,7 +205,7 @@ async fn handle_schedules( } async fn handle_relays( - tx: &mut Transaction<'_, Sqlite>, + conn: &mut PoolConnection, this: &Controller, relays: Vec, ) -> Result<(), EmgauwaError> { @@ -210,24 +215,24 @@ async fn handle_relays( "Controller UID mismatch during relay update", ))); } - let db_relay = DbRelay::get_by_controller_and_num(tx, &this.c, relay.r.number) + let db_relay = DbRelay::get_by_controller_and_num(conn, &this.c, relay.r.number) .await? .ok_or(DatabaseError::NotFound)?; - db_relay.update(tx, relay.r.name.as_str()).await?; + db_relay.update(conn, relay.r.name.as_str()).await?; - handle_schedules(tx, relay.schedules.clone()).await?; + handle_schedules(conn, relay.schedules.clone()).await?; let mut schedules = Vec::new(); // We need to get the schedules from the database to have the right IDs for schedule in relay.schedules { schedules.push( - DbSchedule::get_by_uid(tx, &schedule.uid) + DbSchedule::get_by_uid(conn, &schedule.uid) .await? .ok_or(DatabaseError::NotFound)?, ); } - DbJunctionRelaySchedule::set_schedules(tx, &db_relay, schedules.iter().collect()).await?; + DbJunctionRelaySchedule::set_schedules(conn, &db_relay, schedules.iter().collect()).await?; } Ok(())