From 9823511b6225e80a1da687762582ba18cfdc2b3c Mon Sep 17 00:00:00 2001 From: Tobias Reisinger Date: Thu, 2 May 2024 13:30:14 +0200 Subject: [PATCH] Add sql transactions --- src/app_state.rs | 12 ++-- src/handlers/v1/controllers.rs | 30 +++++---- src/handlers/v1/macros.rs | 58 +++++++++-------- src/handlers/v1/relays.rs | 66 +++++++++++--------- src/handlers/v1/schedules.rs | 72 ++++++++++++---------- src/handlers/v1/tags.rs | 27 +++++--- src/handlers/v1/ws/controllers/handlers.rs | 23 +++---- src/handlers/v1/ws/controllers/mod.rs | 15 +---- src/main.rs | 6 +- 9 files changed, 171 insertions(+), 138 deletions(-) diff --git a/src/app_state.rs b/src/app_state.rs index 04747a2..f3eac5d 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -63,9 +63,10 @@ impl AppState { } fn get_relays(&self) -> Result, EmgauwaError> { - let mut pool_conn = block_on(self.pool.acquire())?; - let db_controllers = block_on(DbController::get_all(&mut pool_conn))?; - let mut controllers: Vec = convert_db_list(&mut pool_conn, db_controllers)?; + let mut tx = block_on(self.pool.begin())?; + let db_controllers = block_on(DbController::get_all(&mut tx))?; + let mut controllers: Vec = convert_db_list(&mut tx, db_controllers)?; + block_on(tx.commit())?; self.connected_controllers .iter() @@ -113,11 +114,11 @@ impl Handler for AppState { type Result = Result<(), EmgauwaError>; fn handle(&mut self, msg: DisconnectController, _ctx: &mut Self::Context) -> Self::Result { - let mut pool_conn = block_on(self.pool.acquire())?; + let mut tx = block_on(self.pool.begin())?; if let Some((controller, address)) = self.connected_controllers.remove(&msg.controller_uid) { - if let Err(err) = block_on(controller.c.update_active(&mut pool_conn, false)) { + if let Err(err) = block_on(controller.c.update_active(&mut tx, false)) { log::error!( "Failed to mark controller {} as inactive: {:?}", controller.c.uid, @@ -128,6 +129,7 @@ impl Handler for AppState { //block_on(address.send(ControllerWsAction::Disconnect))??; address.do_send(ControllerWsAction::Disconnect); } + block_on(tx.commit())?; self.notify_relay_clients(); Ok(()) } diff --git a/src/handlers/v1/controllers.rs b/src/handlers/v1/controllers.rs index 3fdb931..90c39e6 100644 --- a/src/handlers/v1/controllers.rs +++ b/src/handlers/v1/controllers.rs @@ -11,12 +11,13 @@ use crate::app_state::AppState; #[get("/controllers")] pub async fn index(pool: web::Data>) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; - let db_controllers = DbController::get_all(&mut pool_conn).await?; + let db_controllers = DbController::get_all(&mut tx).await?; - let controllers: Vec = convert_db_list(&mut pool_conn, db_controllers)?; + let controllers: Vec = convert_db_list(&mut tx, db_controllers)?; + tx.commit().await?; Ok(HttpResponse::Ok().json(controllers)) } @@ -25,16 +26,18 @@ pub async fn show( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (controller_uid,) = path.into_inner(); let uid = EmgauwaUid::try_from(controller_uid.as_str())?; - let controller = DbController::get_by_uid(&mut pool_conn, &uid) + let controller = DbController::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; - let return_controller = Controller::from_db_model(&mut pool_conn, controller)?; + let return_controller = Controller::from_db_model(&mut tx, controller)?; + + tx.commit().await?; Ok(HttpResponse::Ok().json(return_controller)) } @@ -45,20 +48,20 @@ pub async fn update( path: web::Path<(String,)>, data: web::Json, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (controller_uid,) = path.into_inner(); let uid = EmgauwaUid::try_from(controller_uid.as_str())?; - let controller = DbController::get_by_uid(&mut pool_conn, &uid) + let controller = DbController::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; let controller = controller - .update(&mut pool_conn, data.name.as_str(), controller.relay_count) + .update(&mut tx, data.name.as_str(), controller.relay_count) .await?; - let return_controller = Controller::from_db_model(&mut pool_conn, controller)?; + let return_controller = Controller::from_db_model(&mut tx, controller)?; app_state .send(app_state::Action { @@ -67,6 +70,7 @@ pub async fn update( }) .await??; + tx.commit().await?; Ok(HttpResponse::Ok().json(return_controller)) } @@ -76,7 +80,7 @@ pub async fn delete( app_state: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (controller_uid,) = path.into_inner(); let uid = EmgauwaUid::try_from(controller_uid.as_str())?; @@ -87,6 +91,8 @@ pub async fn delete( }) .await??; - DbController::delete_by_uid(&mut pool_conn, uid).await?; + DbController::delete_by_uid(&mut tx, uid).await?; + + tx.commit().await?; Ok(HttpResponse::Ok().json("controller got deleted")) } diff --git a/src/handlers/v1/macros.rs b/src/handlers/v1/macros.rs index 0e3fc60..9703c7f 100644 --- a/src/handlers/v1/macros.rs +++ b/src/handlers/v1/macros.rs @@ -14,11 +14,12 @@ use crate::app_state::AppState; #[get("/macros")] pub async fn index(pool: web::Data>) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; - let db_macros = DbMacro::get_all(&mut pool_conn).await?; - let macros: Vec = convert_db_list(&mut pool_conn, db_macros)?; + let db_macros = DbMacro::get_all(&mut tx).await?; + let macros: Vec = convert_db_list(&mut tx, db_macros)?; + tx.commit().await?; Ok(HttpResponse::Ok().json(macros)) } @@ -27,16 +28,18 @@ pub async fn show( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (macro_uid,) = path.into_inner(); let uid = EmgauwaUid::try_from(macro_uid.as_str())?; - let db_macro = DbMacro::get_by_uid(&mut pool_conn, &uid) + let db_macro = DbMacro::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; - let return_macro = Macro::from_db_model(&mut pool_conn, db_macro)?; + let return_macro = Macro::from_db_model(&mut tx, db_macro)?; + + tx.commit().await?; Ok(HttpResponse::Ok().json(return_macro)) } @@ -45,15 +48,17 @@ pub async fn add( pool: web::Data>, data: web::Json, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; - let new_macro = DbMacro::create(&mut pool_conn, EmgauwaUid::default(), &data.name).await?; + let new_macro = DbMacro::create(&mut tx, EmgauwaUid::default(), &data.name).await?; new_macro - .set_actions(&mut pool_conn, data.actions.as_slice()) + .set_actions(&mut tx, data.actions.as_slice()) .await?; - let return_macro = Macro::from_db_model(&mut pool_conn, new_macro)?; + let return_macro = Macro::from_db_model(&mut tx, new_macro)?; + + tx.commit().await?; Ok(HttpResponse::Created().json(return_macro)) } @@ -63,26 +68,28 @@ pub async fn update( path: web::Path<(String,)>, data: web::Json, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (macro_uid,) = path.into_inner(); let uid = EmgauwaUid::try_from(macro_uid.as_str())?; - let db_macro = DbMacro::get_by_uid(&mut pool_conn, &uid) + let db_macro = DbMacro::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; if let Some(name) = &data.name { - db_macro.update(&mut pool_conn, name).await?; + db_macro.update(&mut tx, name).await?; } if let Some(actions) = &data.actions { db_macro - .set_actions(&mut pool_conn, actions.as_slice()) + .set_actions(&mut tx, actions.as_slice()) .await?; } - let return_macro = Macro::from_db_model(&mut pool_conn, db_macro)?; + let return_macro = Macro::from_db_model(&mut tx, db_macro)?; + + tx.commit().await?; Ok(HttpResponse::Ok().json(return_macro)) } @@ -91,12 +98,14 @@ pub async fn delete( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (macro_uid,) = path.into_inner(); let uid = EmgauwaUid::try_from(macro_uid.as_str())?; - DbMacro::delete_by_uid(&mut pool_conn, uid).await?; + DbMacro::delete_by_uid(&mut tx, uid).await?; + + tx.commit().await?; Ok(HttpResponse::Ok().json("macro got deleted")) } @@ -107,27 +116,27 @@ pub async fn execute( path: web::Path<(String,)>, query: web::Query, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (macro_uid,) = path.into_inner(); let uid = EmgauwaUid::try_from(macro_uid.as_str())?; - let db_macro = DbMacro::get_by_uid(&mut pool_conn, &uid) + let db_macro = DbMacro::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; let actions_db = match query.weekday { - None => db_macro.get_actions(&mut pool_conn).await?, + None => db_macro.get_actions(&mut tx).await?, Some(weekday) => { db_macro - .get_actions_weekday(&mut pool_conn, weekday) + .get_actions_weekday(&mut tx, weekday) .await? } }; - let mut actions: Vec = convert_db_list(&mut pool_conn, actions_db)?; + let mut actions: Vec = convert_db_list(&mut tx, actions_db)?; for action in &actions { - action.execute(&mut pool_conn).await?; + action.execute(&mut tx).await?; } let affected_controller_uids: Vec = actions @@ -144,7 +153,7 @@ pub async fn execute( if affected_relay_ids.contains(&action.relay.r.id) { continue; } - action.relay.reload(&mut pool_conn)?; + action.relay.reload(&mut tx)?; affected_relays.push(action.relay.clone()); affected_relay_ids.push(action.relay.r.id); } @@ -157,5 +166,6 @@ pub async fn execute( .await??; } + tx.commit().await?; Ok(HttpResponse::Ok().finish()) // TODO add a message? } diff --git a/src/handlers/v1/relays.rs b/src/handlers/v1/relays.rs index f07c081..96c1e15 100644 --- a/src/handlers/v1/relays.rs +++ b/src/handlers/v1/relays.rs @@ -14,12 +14,13 @@ use crate::app_state::AppState; #[get("/relays")] pub async fn index(pool: web::Data>) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; - let db_relays = DbRelay::get_all(&mut pool_conn).await?; + let db_relays = DbRelay::get_all(&mut tx).await?; - let relays: Vec = convert_db_list(&mut pool_conn, db_relays)?; + let relays: Vec = convert_db_list(&mut tx, db_relays)?; + tx.commit().await?; Ok(HttpResponse::Ok().json(relays)) } @@ -28,16 +29,17 @@ pub async fn tagged( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (tag,) = path.into_inner(); - let tag_db = DbTag::get_by_tag(&mut pool_conn, &tag) + let tag_db = DbTag::get_by_tag(&mut tx, &tag) .await? .ok_or(DatabaseError::NotFound)?; - let db_relays = DbRelay::get_by_tag(&mut pool_conn, &tag_db).await?; - let relays: Vec = convert_db_list(&mut pool_conn, db_relays)?; + let db_relays = DbRelay::get_by_tag(&mut tx, &tag_db).await?; + let relays: Vec = convert_db_list(&mut tx, db_relays)?; + tx.commit().await?; Ok(HttpResponse::Ok().json(relays)) } @@ -46,18 +48,20 @@ pub async fn index_for_controller( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (controller_uid,) = path.into_inner(); let uid = EmgauwaUid::try_from(controller_uid.as_str())?; - let controller = DbController::get_by_uid(&mut pool_conn, &uid) + let controller = DbController::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; - let db_relays = controller.get_relays(&mut pool_conn).await?; + let db_relays = controller.get_relays(&mut tx).await?; - let relays: Vec = convert_db_list(&mut pool_conn, db_relays)?; + let relays: Vec = convert_db_list(&mut tx, db_relays)?; + + tx.commit().await?; Ok(HttpResponse::Ok().json(relays)) } @@ -66,20 +70,22 @@ pub async fn show_for_controller( pool: web::Data>, path: web::Path<(String, i64)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (controller_uid, relay_num) = path.into_inner(); let uid = EmgauwaUid::try_from(controller_uid.as_str())?; - let controller = DbController::get_by_uid(&mut pool_conn, &uid) + let controller = DbController::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; - let relay = DbRelay::get_by_controller_and_num(&mut pool_conn, &controller, relay_num) + let relay = DbRelay::get_by_controller_and_num(&mut tx, &controller, relay_num) .await? .ok_or(DatabaseError::NotFound)?; - let return_relay = Relay::from_db_model(&mut pool_conn, relay)?; + let return_relay = Relay::from_db_model(&mut tx, relay)?; + + tx.commit().await?; Ok(HttpResponse::Ok().json(return_relay)) } @@ -90,32 +96,32 @@ pub async fn update_for_controller( path: web::Path<(String, i64)>, data: web::Json, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (controller_uid, relay_num) = path.into_inner(); let uid = EmgauwaUid::try_from(controller_uid.as_str())?; - let controller = DbController::get_by_uid(&mut pool_conn, &uid) + let controller = DbController::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; - let mut relay = DbRelay::get_by_controller_and_num(&mut pool_conn, &controller, relay_num) + let mut relay = DbRelay::get_by_controller_and_num(&mut tx, &controller, relay_num) .await? .ok_or(DatabaseError::NotFound)?; if let Some(name) = &data.name { - relay = relay.update(&mut pool_conn, name.as_str()).await?; + relay = relay.update(&mut tx, name.as_str()).await?; } if let Some(schedule_uids) = &data.schedules { if schedule_uids.len() == 7 { let mut schedules = Vec::new(); for s_uid in schedule_uids { - schedules.push(s_uid.get_schedule(&mut pool_conn).await?); + schedules.push(s_uid.get_schedule(&mut tx).await?); } DbJunctionRelaySchedule::set_schedules( - &mut pool_conn, + &mut tx, &relay, schedules.iter().collect(), ) @@ -124,9 +130,9 @@ pub async fn update_for_controller( } if let Some(s_uid) = &data.active_schedule { - let schedule = s_uid.get_schedule(&mut pool_conn).await?; + let schedule = s_uid.get_schedule(&mut tx).await?; DbJunctionRelaySchedule::set_schedule( - &mut pool_conn, + &mut tx, &relay, &schedule, utils::get_weekday(), @@ -135,12 +141,12 @@ pub async fn update_for_controller( } if let Some(tags) = &data.tags { - relay.set_tags(&mut pool_conn, tags.as_slice()).await?; + relay.set_tags(&mut tx, tags.as_slice()).await?; } - let relay = relay.reload(&mut pool_conn).await?; + let relay = relay.reload(&mut tx).await?; - let return_relay = Relay::from_db_model(&mut pool_conn, relay)?; + let return_relay = Relay::from_db_model(&mut tx, relay)?; app_state .send(app_state::Action { @@ -149,6 +155,7 @@ pub async fn update_for_controller( }) .await??; + tx.commit().await?; Ok(HttpResponse::Ok().json(return_relay)) } @@ -159,16 +166,16 @@ pub async fn pulse( path: web::Path<(String, i64)>, data: web::Json, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (controller_uid, relay_num) = path.into_inner(); let uid = EmgauwaUid::try_from(controller_uid.as_str())?; - let controller = DbController::get_by_uid(&mut pool_conn, &uid) + let controller = DbController::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; - let relay = DbRelay::get_by_controller_and_num(&mut pool_conn, &controller, relay_num) + let relay = DbRelay::get_by_controller_and_num(&mut tx, &controller, relay_num) .await? .ok_or(DatabaseError::NotFound)?; @@ -181,5 +188,6 @@ pub async fn pulse( }) .await??; + tx.commit().await?; Ok(HttpResponse::Ok().finish()) // TODO add a message? } diff --git a/src/handlers/v1/schedules.rs b/src/handlers/v1/schedules.rs index 8e329a1..548610e 100644 --- a/src/handlers/v1/schedules.rs +++ b/src/handlers/v1/schedules.rs @@ -7,7 +7,7 @@ use emgauwa_common::types::{ ControllerWsAction, RequestScheduleCreate, RequestScheduleUpdate, ScheduleUid, }; use itertools::Itertools; -use sqlx::pool::PoolConnection; +use sqlx::Transaction; use sqlx::{Pool, Sqlite}; use crate::app_state; @@ -15,11 +15,12 @@ use crate::app_state::AppState; #[get("/schedules")] pub async fn index(pool: web::Data>) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; - let db_schedules = DbSchedule::get_all(&mut pool_conn).await?; - let schedules: Vec = convert_db_list(&mut pool_conn, db_schedules)?; + let db_schedules = DbSchedule::get_all(&mut tx).await?; + let schedules: Vec = convert_db_list(&mut tx, db_schedules)?; + tx.commit().await?; Ok(HttpResponse::Ok().json(schedules)) } @@ -28,16 +29,17 @@ pub async fn tagged( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (tag,) = path.into_inner(); - let tag_db = DbTag::get_by_tag(&mut pool_conn, &tag) + let tag_db = DbTag::get_by_tag(&mut tx, &tag) .await? .ok_or(DatabaseError::NotFound)?; - let db_schedules = DbSchedule::get_by_tag(&mut pool_conn, &tag_db).await?; - let schedules: Vec = convert_db_list(&mut pool_conn, db_schedules)?; + let db_schedules = DbSchedule::get_by_tag(&mut tx, &tag_db).await?; + let schedules: Vec = convert_db_list(&mut tx, db_schedules)?; + tx.commit().await?; Ok(HttpResponse::Ok().json(schedules)) } @@ -46,16 +48,18 @@ pub async fn show( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (schedule_uid,) = path.into_inner(); let uid = ScheduleUid::try_from(schedule_uid.as_str())?; - let schedule = DbSchedule::get_by_uid(&mut pool_conn, &uid) + let schedule = DbSchedule::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; - let return_schedule = Schedule::from_db_model(&mut pool_conn, schedule)?; + let return_schedule = Schedule::from_db_model(&mut tx, schedule)?; + + tx.commit().await?; Ok(HttpResponse::Ok().json(return_schedule)) } @@ -64,10 +68,10 @@ pub async fn add( pool: web::Data>, data: web::Json, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let new_schedule = DbSchedule::create( - &mut pool_conn, + &mut tx, ScheduleUid::default(), &data.name, &data.periods, @@ -76,20 +80,22 @@ pub async fn add( if let Some(tags) = &data.tags { new_schedule - .set_tags(&mut pool_conn, tags.as_slice()) + .set_tags(&mut tx, tags.as_slice()) .await?; } - let return_schedule = Schedule::from_db_model(&mut pool_conn, new_schedule)?; + let return_schedule = Schedule::from_db_model(&mut tx, new_schedule)?; + + tx.commit().await?; Ok(HttpResponse::Created().json(return_schedule)) } async fn add_list_single( - conn: &mut PoolConnection, + tx: &mut Transaction<'_, Sqlite>, request_schedule: &RequestScheduleCreate, ) -> Result { let new_schedule = DbSchedule::create( - conn, + tx, ScheduleUid::default(), &request_schedule.name, &request_schedule.periods, @@ -97,7 +103,7 @@ async fn add_list_single( .await?; if let Some(tags) = &request_schedule.tags { - new_schedule.set_tags(conn, tags.as_slice()).await?; + new_schedule.set_tags(tx, tags.as_slice()).await?; } Ok(new_schedule) @@ -108,15 +114,17 @@ pub async fn add_list( pool: web::Data>, data: web::Json>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let mut db_schedules: Vec = Vec::new(); for s in data.iter() { - let new_s = futures::executor::block_on(add_list_single(&mut pool_conn, s))?; + let new_s = add_list_single(&mut tx, s).await?; db_schedules.push(new_s); } - let schedules: Vec = convert_db_list(&mut pool_conn, db_schedules)?; + let schedules: Vec = convert_db_list(&mut tx, db_schedules)?; + + tx.commit().await?; Ok(HttpResponse::Created().json(schedules)) } @@ -127,12 +135,12 @@ pub async fn update( path: web::Path<(String,)>, data: web::Json, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (schedule_uid,) = path.into_inner(); let uid = ScheduleUid::try_from(schedule_uid.as_str())?; - let schedule = DbSchedule::get_by_uid(&mut pool_conn, &uid) + let schedule = DbSchedule::get_by_uid(&mut tx, &uid) .await? .ok_or(DatabaseError::NotFound)?; @@ -146,13 +154,13 @@ pub async fn update( Some(period) => period, }; - let schedule = schedule.update(&mut pool_conn, name, periods).await?; + let schedule = schedule.update(&mut tx, name, periods).await?; if let Some(tags) = &data.tags { - schedule.set_tags(&mut pool_conn, tags.as_slice()).await?; + schedule.set_tags(&mut tx, tags.as_slice()).await?; } - let controller_ids: Vec = DbJunctionRelaySchedule::get_relays(&mut pool_conn, &schedule) + let controller_ids: Vec = DbJunctionRelaySchedule::get_relays(&mut tx, &schedule) .await? .into_iter() .map(|r| r.controller_id) @@ -160,7 +168,7 @@ pub async fn update( .collect(); for controller_id in controller_ids { - let controller = DbController::get(&mut pool_conn, controller_id) + let controller = DbController::get(&mut tx, controller_id) .await? .ok_or(DatabaseError::NotFound)?; app_state @@ -171,7 +179,9 @@ pub async fn update( .await??; } - let return_schedule = Schedule::from_db_model(&mut pool_conn, schedule)?; + let return_schedule = Schedule::from_db_model(&mut tx, schedule)?; + + tx.commit().await?; Ok(HttpResponse::Ok().json(return_schedule)) } @@ -180,8 +190,6 @@ pub async fn delete( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; - let (schedule_uid,) = path.into_inner(); let uid = ScheduleUid::try_from(schedule_uid.as_str())?; @@ -189,7 +197,9 @@ pub async fn delete( ScheduleUid::Off => Err(EmgauwaError::from(ApiError::ProtectedSchedule)), ScheduleUid::On => Err(EmgauwaError::from(ApiError::ProtectedSchedule)), ScheduleUid::Any(_) => { - DbSchedule::delete_by_uid(&mut pool_conn, uid).await?; + let mut tx = pool.begin().await?; + DbSchedule::delete_by_uid(&mut tx, uid).await?; + tx.commit().await?; Ok(HttpResponse::Ok().json("schedule got deleted")) } } diff --git a/src/handlers/v1/tags.rs b/src/handlers/v1/tags.rs index cf4ef36..6dfa221 100644 --- a/src/handlers/v1/tags.rs +++ b/src/handlers/v1/tags.rs @@ -7,12 +7,13 @@ use sqlx::{Pool, Sqlite}; #[get("/tags")] pub async fn index(pool: web::Data>) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; - let db_tags = DbTag::get_all(&mut pool_conn).await?; + let db_tags = DbTag::get_all(&mut tx).await?; let tags: Vec = db_tags.iter().map(|t| t.tag.clone()).collect(); + tx.commit().await?; Ok(HttpResponse::Ok().json(tags)) } @@ -21,15 +22,17 @@ pub async fn show( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (tag_name,) = path.into_inner(); - let tag = DbTag::get_by_tag(&mut pool_conn, &tag_name) + let tag = DbTag::get_by_tag(&mut tx, &tag_name) .await? .ok_or(DatabaseError::NotFound)?; - let return_tag = Tag::from_db_model(&mut pool_conn, tag)?; + let return_tag = Tag::from_db_model(&mut tx, tag)?; + + tx.commit().await?; Ok(HttpResponse::Ok().json(return_tag)) } @@ -38,11 +41,13 @@ pub async fn delete( pool: web::Data>, path: web::Path<(String,)>, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; let (tag_name,) = path.into_inner(); - DbTag::delete_by_tag(&mut pool_conn, &tag_name).await?; + DbTag::delete_by_tag(&mut tx, &tag_name).await?; + + tx.commit().await?; Ok(HttpResponse::Ok().json("tag got deleted")) } @@ -51,11 +56,13 @@ pub async fn add( pool: web::Data>, data: web::Json, ) -> Result { - let mut pool_conn = pool.acquire().await?; + let mut tx = pool.begin().await?; - let new_tag = DbTag::create(&mut pool_conn, &data.tag).await?; + let new_tag = DbTag::create(&mut tx, &data.tag).await?; let cache = (Vec::new(), Vec::new()); // a new tag can't have any relays or schedules - let return_tag = Tag::from_db_model_cache(&mut pool_conn, new_tag, cache)?; + let return_tag = Tag::from_db_model_cache(&mut tx, new_tag, cache)?; + + tx.commit().await?; Ok(HttpResponse::Created().json(return_tag)) } diff --git a/src/handlers/v1/ws/controllers/handlers.rs b/src/handlers/v1/ws/controllers/handlers.rs index 9a01a7b..bb7ff50 100644 --- a/src/handlers/v1/ws/controllers/handlers.rs +++ b/src/handlers/v1/ws/controllers/handlers.rs @@ -5,8 +5,6 @@ use emgauwa_common::models::{Controller, FromDbModel}; use emgauwa_common::types::{ControllerWsAction, EmgauwaUid, RelayStates}; use emgauwa_common::utils; use futures::executor::block_on; -use sqlx::pool::PoolConnection; -use sqlx::Sqlite; use crate::app_state::{Action, ConnectController, UpdateRelayStates}; use crate::handlers::v1::ws::controllers::ControllersWs; @@ -14,7 +12,6 @@ use crate::handlers::v1::ws::controllers::ControllersWs; impl ControllersWs { pub fn handle_register( &mut self, - conn: &mut PoolConnection, ctx: &mut ::Context, controller: Controller, ) -> Result<(), EmgauwaError> { @@ -23,16 +20,19 @@ impl ControllersWs { controller.c.name, controller.c.uid ); + + let mut tx = block_on(self.pool.begin())?; + let c = &controller.c; let controller_db = block_on(DbController::get_by_uid_or_create( - conn, + &mut tx, &c.uid, &c.name, c.relay_count, ))?; - block_on(controller_db.update_active(conn, true))?; + block_on(controller_db.update_active(&mut tx, true))?; // update only the relay count - block_on(controller_db.update(conn, &controller_db.name, c.relay_count))?; + block_on(controller_db.update(&mut tx, &controller_db.name, c.relay_count))?; for relay in &controller.relays { log::debug!( @@ -45,7 +45,7 @@ impl ControllersWs { } ); let (new_relay, created) = block_on(DbRelay::get_by_controller_and_num_or_create( - conn, + &mut tx, &controller_db, relay.r.number, &relay.r.name, @@ -54,7 +54,7 @@ impl ControllersWs { let mut relay_schedules = Vec::new(); for schedule in &relay.schedules { let (new_schedule, _) = block_on(DbSchedule::get_by_uid_or_create( - conn, + &mut tx, schedule.uid.clone(), &schedule.name, &schedule.periods, @@ -63,7 +63,7 @@ impl ControllersWs { } block_on(DbJunctionRelaySchedule::set_schedules( - conn, + &mut tx, &new_relay, relay_schedules.iter().collect(), ))?; @@ -71,9 +71,9 @@ impl ControllersWs { } let controller_uid = &controller.c.uid; - let controller_db = block_on(DbController::get_by_uid(conn, controller_uid))? + let controller_db = block_on(DbController::get_by_uid(&mut tx, controller_uid))? .ok_or(DatabaseError::InsertGetError)?; - let controller = Controller::from_db_model(conn, controller_db)?; + let controller = Controller::from_db_model(&mut tx, controller_db)?; let addr = ctx.address(); self.controller_uid = Some(controller_uid.clone()); @@ -91,6 +91,7 @@ impl ControllersWs { action: ControllerWsAction::Relays(controller.relays), }))??; + block_on(tx.commit())?; log::debug!("Done registering controller"); Ok(()) } diff --git a/src/handlers/v1/ws/controllers/mod.rs b/src/handlers/v1/ws/controllers/mod.rs index 2b39e5a..ac72561 100644 --- a/src/handlers/v1/ws/controllers/mod.rs +++ b/src/handlers/v1/ws/controllers/mod.rs @@ -9,7 +9,6 @@ use emgauwa_common::constants::{HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT}; use emgauwa_common::errors::EmgauwaError; use emgauwa_common::types::{ControllerWsAction, EmgauwaUid}; use futures::executor::block_on; -use sqlx::pool::PoolConnection; use sqlx::{Pool, Sqlite}; use ws::Message; @@ -48,12 +47,11 @@ impl Actor for ControllersWs { impl ControllersWs { pub fn handle_action( &mut self, - conn: &mut PoolConnection, ctx: &mut ::Context, action: ControllerWsAction, ) { let action_res = match action { - ControllerWsAction::Register(controller) => self.handle_register(conn, ctx, controller), + ControllerWsAction::Register(controller) => self.handle_register(ctx, controller), ControllerWsAction::RelayStates((controller_uid, relay_states)) => { self.handle_relay_states(controller_uid, relay_states) } @@ -103,15 +101,6 @@ impl Handler for ControllersWs { impl StreamHandler> for ControllersWs { fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { - let mut pool_conn = match block_on(self.pool.acquire()) { - Ok(conn) => conn, - Err(err) => { - log::error!("Failed to acquire database connection: {:?}", err); - ctx.stop(); - return; - } - }; - let msg = match msg { Err(_) => { ctx.stop(); @@ -130,7 +119,7 @@ impl StreamHandler> for ControllersWs { } Message::Text(text) => match serde_json::from_str(&text) { Ok(action) => { - self.handle_action(&mut pool_conn, ctx, action); + self.handle_action(ctx, action); } Err(e) => { log::error!("Error deserializing action: {:?}", e); diff --git a/src/main.rs b/src/main.rs index 0ed9f92..7748bb3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,11 +28,11 @@ async fn main() -> Result<(), std::io::Error> { let pool = emgauwa_common::db::init(&settings.database).await?; - let mut conn = pool.acquire().await.map_err(EmgauwaError::from)?; - DbController::all_inactive(&mut conn) + let mut tx = pool.begin().await.map_err(EmgauwaError::from)?; + DbController::all_inactive(&mut tx) .await .map_err(EmgauwaError::from)?; - conn.close().await.map_err(EmgauwaError::from)?; + tx.commit().await.map_err(EmgauwaError::from)?; let app_state_arbiter = Arbiter::with_tokio_rt(|| { tokio::runtime::Builder::new_multi_thread()