Add sql transactions

This commit is contained in:
Tobias Reisinger 2024-05-02 13:30:14 +02:00
parent 455ca50695
commit 9823511b62
Signed by: serguzim
GPG key ID: 13AD60C237A28DFE
9 changed files with 171 additions and 138 deletions

View file

@ -63,9 +63,10 @@ impl AppState {
} }
fn get_relays(&self) -> Result<Vec<Relay>, EmgauwaError> { fn get_relays(&self) -> Result<Vec<Relay>, EmgauwaError> {
let mut pool_conn = block_on(self.pool.acquire())?; let mut tx = block_on(self.pool.begin())?;
let db_controllers = block_on(DbController::get_all(&mut pool_conn))?; let db_controllers = block_on(DbController::get_all(&mut tx))?;
let mut controllers: Vec<Controller> = convert_db_list(&mut pool_conn, db_controllers)?; let mut controllers: Vec<Controller> = convert_db_list(&mut tx, db_controllers)?;
block_on(tx.commit())?;
self.connected_controllers self.connected_controllers
.iter() .iter()
@ -113,11 +114,11 @@ impl Handler<DisconnectController> for AppState {
type Result = Result<(), EmgauwaError>; type Result = Result<(), EmgauwaError>;
fn handle(&mut self, msg: DisconnectController, _ctx: &mut Self::Context) -> Self::Result { fn handle(&mut self, msg: DisconnectController, _ctx: &mut Self::Context) -> Self::Result {
let mut pool_conn = block_on(self.pool.acquire())?; let mut tx = block_on(self.pool.begin())?;
if let Some((controller, address)) = self.connected_controllers.remove(&msg.controller_uid) if let Some((controller, address)) = self.connected_controllers.remove(&msg.controller_uid)
{ {
if let Err(err) = block_on(controller.c.update_active(&mut pool_conn, false)) { if let Err(err) = block_on(controller.c.update_active(&mut tx, false)) {
log::error!( log::error!(
"Failed to mark controller {} as inactive: {:?}", "Failed to mark controller {} as inactive: {:?}",
controller.c.uid, controller.c.uid,
@ -128,6 +129,7 @@ impl Handler<DisconnectController> for AppState {
//block_on(address.send(ControllerWsAction::Disconnect))??; //block_on(address.send(ControllerWsAction::Disconnect))??;
address.do_send(ControllerWsAction::Disconnect); address.do_send(ControllerWsAction::Disconnect);
} }
block_on(tx.commit())?;
self.notify_relay_clients(); self.notify_relay_clients();
Ok(()) Ok(())
} }

View file

@ -11,12 +11,13 @@ use crate::app_state::AppState;
#[get("/controllers")] #[get("/controllers")]
pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> { pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let db_controllers = DbController::get_all(&mut pool_conn).await?; let db_controllers = DbController::get_all(&mut tx).await?;
let controllers: Vec<Controller> = convert_db_list(&mut pool_conn, db_controllers)?; let controllers: Vec<Controller> = convert_db_list(&mut tx, db_controllers)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(controllers)) Ok(HttpResponse::Ok().json(controllers))
} }
@ -25,16 +26,18 @@ pub async fn show(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (controller_uid,) = path.into_inner(); let (controller_uid,) = path.into_inner();
let uid = EmgauwaUid::try_from(controller_uid.as_str())?; let uid = EmgauwaUid::try_from(controller_uid.as_str())?;
let controller = DbController::get_by_uid(&mut pool_conn, &uid) let controller = DbController::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let return_controller = Controller::from_db_model(&mut pool_conn, controller)?; let return_controller = Controller::from_db_model(&mut tx, controller)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(return_controller)) Ok(HttpResponse::Ok().json(return_controller))
} }
@ -45,20 +48,20 @@ pub async fn update(
path: web::Path<(String,)>, path: web::Path<(String,)>,
data: web::Json<RequestControllerUpdate>, data: web::Json<RequestControllerUpdate>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (controller_uid,) = path.into_inner(); let (controller_uid,) = path.into_inner();
let uid = EmgauwaUid::try_from(controller_uid.as_str())?; let uid = EmgauwaUid::try_from(controller_uid.as_str())?;
let controller = DbController::get_by_uid(&mut pool_conn, &uid) let controller = DbController::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let controller = controller let controller = controller
.update(&mut pool_conn, data.name.as_str(), controller.relay_count) .update(&mut tx, data.name.as_str(), controller.relay_count)
.await?; .await?;
let return_controller = Controller::from_db_model(&mut pool_conn, controller)?; let return_controller = Controller::from_db_model(&mut tx, controller)?;
app_state app_state
.send(app_state::Action { .send(app_state::Action {
@ -67,6 +70,7 @@ pub async fn update(
}) })
.await??; .await??;
tx.commit().await?;
Ok(HttpResponse::Ok().json(return_controller)) Ok(HttpResponse::Ok().json(return_controller))
} }
@ -76,7 +80,7 @@ pub async fn delete(
app_state: web::Data<Addr<AppState>>, app_state: web::Data<Addr<AppState>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (controller_uid,) = path.into_inner(); let (controller_uid,) = path.into_inner();
let uid = EmgauwaUid::try_from(controller_uid.as_str())?; let uid = EmgauwaUid::try_from(controller_uid.as_str())?;
@ -87,6 +91,8 @@ pub async fn delete(
}) })
.await??; .await??;
DbController::delete_by_uid(&mut pool_conn, uid).await?; DbController::delete_by_uid(&mut tx, uid).await?;
tx.commit().await?;
Ok(HttpResponse::Ok().json("controller got deleted")) Ok(HttpResponse::Ok().json("controller got deleted"))
} }

View file

@ -14,11 +14,12 @@ use crate::app_state::AppState;
#[get("/macros")] #[get("/macros")]
pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> { pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let db_macros = DbMacro::get_all(&mut pool_conn).await?; let db_macros = DbMacro::get_all(&mut tx).await?;
let macros: Vec<Macro> = convert_db_list(&mut pool_conn, db_macros)?; let macros: Vec<Macro> = convert_db_list(&mut tx, db_macros)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(macros)) Ok(HttpResponse::Ok().json(macros))
} }
@ -27,16 +28,18 @@ pub async fn show(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (macro_uid,) = path.into_inner(); let (macro_uid,) = path.into_inner();
let uid = EmgauwaUid::try_from(macro_uid.as_str())?; let uid = EmgauwaUid::try_from(macro_uid.as_str())?;
let db_macro = DbMacro::get_by_uid(&mut pool_conn, &uid) let db_macro = DbMacro::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let return_macro = Macro::from_db_model(&mut pool_conn, db_macro)?; let return_macro = Macro::from_db_model(&mut tx, db_macro)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(return_macro)) Ok(HttpResponse::Ok().json(return_macro))
} }
@ -45,15 +48,17 @@ pub async fn add(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
data: web::Json<RequestMacroCreate>, data: web::Json<RequestMacroCreate>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let new_macro = DbMacro::create(&mut pool_conn, EmgauwaUid::default(), &data.name).await?; let new_macro = DbMacro::create(&mut tx, EmgauwaUid::default(), &data.name).await?;
new_macro new_macro
.set_actions(&mut pool_conn, data.actions.as_slice()) .set_actions(&mut tx, data.actions.as_slice())
.await?; .await?;
let return_macro = Macro::from_db_model(&mut pool_conn, new_macro)?; let return_macro = Macro::from_db_model(&mut tx, new_macro)?;
tx.commit().await?;
Ok(HttpResponse::Created().json(return_macro)) Ok(HttpResponse::Created().json(return_macro))
} }
@ -63,26 +68,28 @@ pub async fn update(
path: web::Path<(String,)>, path: web::Path<(String,)>,
data: web::Json<RequestMacroUpdate>, data: web::Json<RequestMacroUpdate>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (macro_uid,) = path.into_inner(); let (macro_uid,) = path.into_inner();
let uid = EmgauwaUid::try_from(macro_uid.as_str())?; let uid = EmgauwaUid::try_from(macro_uid.as_str())?;
let db_macro = DbMacro::get_by_uid(&mut pool_conn, &uid) let db_macro = DbMacro::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
if let Some(name) = &data.name { if let Some(name) = &data.name {
db_macro.update(&mut pool_conn, name).await?; db_macro.update(&mut tx, name).await?;
} }
if let Some(actions) = &data.actions { if let Some(actions) = &data.actions {
db_macro db_macro
.set_actions(&mut pool_conn, actions.as_slice()) .set_actions(&mut tx, actions.as_slice())
.await?; .await?;
} }
let return_macro = Macro::from_db_model(&mut pool_conn, db_macro)?; let return_macro = Macro::from_db_model(&mut tx, db_macro)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(return_macro)) Ok(HttpResponse::Ok().json(return_macro))
} }
@ -91,12 +98,14 @@ pub async fn delete(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (macro_uid,) = path.into_inner(); let (macro_uid,) = path.into_inner();
let uid = EmgauwaUid::try_from(macro_uid.as_str())?; let uid = EmgauwaUid::try_from(macro_uid.as_str())?;
DbMacro::delete_by_uid(&mut pool_conn, uid).await?; DbMacro::delete_by_uid(&mut tx, uid).await?;
tx.commit().await?;
Ok(HttpResponse::Ok().json("macro got deleted")) Ok(HttpResponse::Ok().json("macro got deleted"))
} }
@ -107,27 +116,27 @@ pub async fn execute(
path: web::Path<(String,)>, path: web::Path<(String,)>,
query: web::Query<RequestMacroExecute>, query: web::Query<RequestMacroExecute>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (macro_uid,) = path.into_inner(); let (macro_uid,) = path.into_inner();
let uid = EmgauwaUid::try_from(macro_uid.as_str())?; let uid = EmgauwaUid::try_from(macro_uid.as_str())?;
let db_macro = DbMacro::get_by_uid(&mut pool_conn, &uid) let db_macro = DbMacro::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let actions_db = match query.weekday { let actions_db = match query.weekday {
None => db_macro.get_actions(&mut pool_conn).await?, None => db_macro.get_actions(&mut tx).await?,
Some(weekday) => { Some(weekday) => {
db_macro db_macro
.get_actions_weekday(&mut pool_conn, weekday) .get_actions_weekday(&mut tx, weekday)
.await? .await?
} }
}; };
let mut actions: Vec<MacroAction> = convert_db_list(&mut pool_conn, actions_db)?; let mut actions: Vec<MacroAction> = convert_db_list(&mut tx, actions_db)?;
for action in &actions { for action in &actions {
action.execute(&mut pool_conn).await?; action.execute(&mut tx).await?;
} }
let affected_controller_uids: Vec<EmgauwaUid> = actions let affected_controller_uids: Vec<EmgauwaUid> = actions
@ -144,7 +153,7 @@ pub async fn execute(
if affected_relay_ids.contains(&action.relay.r.id) { if affected_relay_ids.contains(&action.relay.r.id) {
continue; continue;
} }
action.relay.reload(&mut pool_conn)?; action.relay.reload(&mut tx)?;
affected_relays.push(action.relay.clone()); affected_relays.push(action.relay.clone());
affected_relay_ids.push(action.relay.r.id); affected_relay_ids.push(action.relay.r.id);
} }
@ -157,5 +166,6 @@ pub async fn execute(
.await??; .await??;
} }
tx.commit().await?;
Ok(HttpResponse::Ok().finish()) // TODO add a message? Ok(HttpResponse::Ok().finish()) // TODO add a message?
} }

View file

@ -14,12 +14,13 @@ use crate::app_state::AppState;
#[get("/relays")] #[get("/relays")]
pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> { pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let db_relays = DbRelay::get_all(&mut pool_conn).await?; let db_relays = DbRelay::get_all(&mut tx).await?;
let relays: Vec<Relay> = convert_db_list(&mut pool_conn, db_relays)?; let relays: Vec<Relay> = convert_db_list(&mut tx, db_relays)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(relays)) Ok(HttpResponse::Ok().json(relays))
} }
@ -28,16 +29,17 @@ pub async fn tagged(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (tag,) = path.into_inner(); let (tag,) = path.into_inner();
let tag_db = DbTag::get_by_tag(&mut pool_conn, &tag) let tag_db = DbTag::get_by_tag(&mut tx, &tag)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let db_relays = DbRelay::get_by_tag(&mut pool_conn, &tag_db).await?; let db_relays = DbRelay::get_by_tag(&mut tx, &tag_db).await?;
let relays: Vec<Relay> = convert_db_list(&mut pool_conn, db_relays)?; let relays: Vec<Relay> = convert_db_list(&mut tx, db_relays)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(relays)) Ok(HttpResponse::Ok().json(relays))
} }
@ -46,18 +48,20 @@ pub async fn index_for_controller(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (controller_uid,) = path.into_inner(); let (controller_uid,) = path.into_inner();
let uid = EmgauwaUid::try_from(controller_uid.as_str())?; let uid = EmgauwaUid::try_from(controller_uid.as_str())?;
let controller = DbController::get_by_uid(&mut pool_conn, &uid) let controller = DbController::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let db_relays = controller.get_relays(&mut pool_conn).await?; let db_relays = controller.get_relays(&mut tx).await?;
let relays: Vec<Relay> = convert_db_list(&mut pool_conn, db_relays)?; let relays: Vec<Relay> = convert_db_list(&mut tx, db_relays)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(relays)) Ok(HttpResponse::Ok().json(relays))
} }
@ -66,20 +70,22 @@ pub async fn show_for_controller(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String, i64)>, path: web::Path<(String, i64)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (controller_uid, relay_num) = path.into_inner(); let (controller_uid, relay_num) = path.into_inner();
let uid = EmgauwaUid::try_from(controller_uid.as_str())?; let uid = EmgauwaUid::try_from(controller_uid.as_str())?;
let controller = DbController::get_by_uid(&mut pool_conn, &uid) let controller = DbController::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let relay = DbRelay::get_by_controller_and_num(&mut pool_conn, &controller, relay_num) let relay = DbRelay::get_by_controller_and_num(&mut tx, &controller, relay_num)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let return_relay = Relay::from_db_model(&mut pool_conn, relay)?; let return_relay = Relay::from_db_model(&mut tx, relay)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(return_relay)) Ok(HttpResponse::Ok().json(return_relay))
} }
@ -90,32 +96,32 @@ pub async fn update_for_controller(
path: web::Path<(String, i64)>, path: web::Path<(String, i64)>,
data: web::Json<RequestRelayUpdate>, data: web::Json<RequestRelayUpdate>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (controller_uid, relay_num) = path.into_inner(); let (controller_uid, relay_num) = path.into_inner();
let uid = EmgauwaUid::try_from(controller_uid.as_str())?; let uid = EmgauwaUid::try_from(controller_uid.as_str())?;
let controller = DbController::get_by_uid(&mut pool_conn, &uid) let controller = DbController::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let mut relay = DbRelay::get_by_controller_and_num(&mut pool_conn, &controller, relay_num) let mut relay = DbRelay::get_by_controller_and_num(&mut tx, &controller, relay_num)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
if let Some(name) = &data.name { if let Some(name) = &data.name {
relay = relay.update(&mut pool_conn, name.as_str()).await?; relay = relay.update(&mut tx, name.as_str()).await?;
} }
if let Some(schedule_uids) = &data.schedules { if let Some(schedule_uids) = &data.schedules {
if schedule_uids.len() == 7 { if schedule_uids.len() == 7 {
let mut schedules = Vec::new(); let mut schedules = Vec::new();
for s_uid in schedule_uids { for s_uid in schedule_uids {
schedules.push(s_uid.get_schedule(&mut pool_conn).await?); schedules.push(s_uid.get_schedule(&mut tx).await?);
} }
DbJunctionRelaySchedule::set_schedules( DbJunctionRelaySchedule::set_schedules(
&mut pool_conn, &mut tx,
&relay, &relay,
schedules.iter().collect(), schedules.iter().collect(),
) )
@ -124,9 +130,9 @@ pub async fn update_for_controller(
} }
if let Some(s_uid) = &data.active_schedule { if let Some(s_uid) = &data.active_schedule {
let schedule = s_uid.get_schedule(&mut pool_conn).await?; let schedule = s_uid.get_schedule(&mut tx).await?;
DbJunctionRelaySchedule::set_schedule( DbJunctionRelaySchedule::set_schedule(
&mut pool_conn, &mut tx,
&relay, &relay,
&schedule, &schedule,
utils::get_weekday(), utils::get_weekday(),
@ -135,12 +141,12 @@ pub async fn update_for_controller(
} }
if let Some(tags) = &data.tags { if let Some(tags) = &data.tags {
relay.set_tags(&mut pool_conn, tags.as_slice()).await?; relay.set_tags(&mut tx, tags.as_slice()).await?;
} }
let relay = relay.reload(&mut pool_conn).await?; let relay = relay.reload(&mut tx).await?;
let return_relay = Relay::from_db_model(&mut pool_conn, relay)?; let return_relay = Relay::from_db_model(&mut tx, relay)?;
app_state app_state
.send(app_state::Action { .send(app_state::Action {
@ -149,6 +155,7 @@ pub async fn update_for_controller(
}) })
.await??; .await??;
tx.commit().await?;
Ok(HttpResponse::Ok().json(return_relay)) Ok(HttpResponse::Ok().json(return_relay))
} }
@ -159,16 +166,16 @@ pub async fn pulse(
path: web::Path<(String, i64)>, path: web::Path<(String, i64)>,
data: web::Json<RequestRelayPulse>, data: web::Json<RequestRelayPulse>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (controller_uid, relay_num) = path.into_inner(); let (controller_uid, relay_num) = path.into_inner();
let uid = EmgauwaUid::try_from(controller_uid.as_str())?; let uid = EmgauwaUid::try_from(controller_uid.as_str())?;
let controller = DbController::get_by_uid(&mut pool_conn, &uid) let controller = DbController::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let relay = DbRelay::get_by_controller_and_num(&mut pool_conn, &controller, relay_num) let relay = DbRelay::get_by_controller_and_num(&mut tx, &controller, relay_num)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
@ -181,5 +188,6 @@ pub async fn pulse(
}) })
.await??; .await??;
tx.commit().await?;
Ok(HttpResponse::Ok().finish()) // TODO add a message? Ok(HttpResponse::Ok().finish()) // TODO add a message?
} }

View file

@ -7,7 +7,7 @@ use emgauwa_common::types::{
ControllerWsAction, RequestScheduleCreate, RequestScheduleUpdate, ScheduleUid, ControllerWsAction, RequestScheduleCreate, RequestScheduleUpdate, ScheduleUid,
}; };
use itertools::Itertools; use itertools::Itertools;
use sqlx::pool::PoolConnection; use sqlx::Transaction;
use sqlx::{Pool, Sqlite}; use sqlx::{Pool, Sqlite};
use crate::app_state; use crate::app_state;
@ -15,11 +15,12 @@ use crate::app_state::AppState;
#[get("/schedules")] #[get("/schedules")]
pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> { pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let db_schedules = DbSchedule::get_all(&mut pool_conn).await?; let db_schedules = DbSchedule::get_all(&mut tx).await?;
let schedules: Vec<Schedule> = convert_db_list(&mut pool_conn, db_schedules)?; let schedules: Vec<Schedule> = convert_db_list(&mut tx, db_schedules)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(schedules)) Ok(HttpResponse::Ok().json(schedules))
} }
@ -28,16 +29,17 @@ pub async fn tagged(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (tag,) = path.into_inner(); let (tag,) = path.into_inner();
let tag_db = DbTag::get_by_tag(&mut pool_conn, &tag) let tag_db = DbTag::get_by_tag(&mut tx, &tag)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let db_schedules = DbSchedule::get_by_tag(&mut pool_conn, &tag_db).await?; let db_schedules = DbSchedule::get_by_tag(&mut tx, &tag_db).await?;
let schedules: Vec<Schedule> = convert_db_list(&mut pool_conn, db_schedules)?; let schedules: Vec<Schedule> = convert_db_list(&mut tx, db_schedules)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(schedules)) Ok(HttpResponse::Ok().json(schedules))
} }
@ -46,16 +48,18 @@ pub async fn show(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (schedule_uid,) = path.into_inner(); let (schedule_uid,) = path.into_inner();
let uid = ScheduleUid::try_from(schedule_uid.as_str())?; let uid = ScheduleUid::try_from(schedule_uid.as_str())?;
let schedule = DbSchedule::get_by_uid(&mut pool_conn, &uid) let schedule = DbSchedule::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let return_schedule = Schedule::from_db_model(&mut pool_conn, schedule)?; let return_schedule = Schedule::from_db_model(&mut tx, schedule)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(return_schedule)) Ok(HttpResponse::Ok().json(return_schedule))
} }
@ -64,10 +68,10 @@ pub async fn add(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
data: web::Json<RequestScheduleCreate>, data: web::Json<RequestScheduleCreate>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let new_schedule = DbSchedule::create( let new_schedule = DbSchedule::create(
&mut pool_conn, &mut tx,
ScheduleUid::default(), ScheduleUid::default(),
&data.name, &data.name,
&data.periods, &data.periods,
@ -76,20 +80,22 @@ pub async fn add(
if let Some(tags) = &data.tags { if let Some(tags) = &data.tags {
new_schedule new_schedule
.set_tags(&mut pool_conn, tags.as_slice()) .set_tags(&mut tx, tags.as_slice())
.await?; .await?;
} }
let return_schedule = Schedule::from_db_model(&mut pool_conn, new_schedule)?; let return_schedule = Schedule::from_db_model(&mut tx, new_schedule)?;
tx.commit().await?;
Ok(HttpResponse::Created().json(return_schedule)) Ok(HttpResponse::Created().json(return_schedule))
} }
async fn add_list_single( async fn add_list_single(
conn: &mut PoolConnection<Sqlite>, tx: &mut Transaction<'_, Sqlite>,
request_schedule: &RequestScheduleCreate, request_schedule: &RequestScheduleCreate,
) -> Result<DbSchedule, DatabaseError> { ) -> Result<DbSchedule, DatabaseError> {
let new_schedule = DbSchedule::create( let new_schedule = DbSchedule::create(
conn, tx,
ScheduleUid::default(), ScheduleUid::default(),
&request_schedule.name, &request_schedule.name,
&request_schedule.periods, &request_schedule.periods,
@ -97,7 +103,7 @@ async fn add_list_single(
.await?; .await?;
if let Some(tags) = &request_schedule.tags { if let Some(tags) = &request_schedule.tags {
new_schedule.set_tags(conn, tags.as_slice()).await?; new_schedule.set_tags(tx, tags.as_slice()).await?;
} }
Ok(new_schedule) Ok(new_schedule)
@ -108,15 +114,17 @@ pub async fn add_list(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
data: web::Json<Vec<RequestScheduleCreate>>, data: web::Json<Vec<RequestScheduleCreate>>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let mut db_schedules: Vec<DbSchedule> = Vec::new(); let mut db_schedules: Vec<DbSchedule> = Vec::new();
for s in data.iter() { for s in data.iter() {
let new_s = futures::executor::block_on(add_list_single(&mut pool_conn, s))?; let new_s = add_list_single(&mut tx, s).await?;
db_schedules.push(new_s); db_schedules.push(new_s);
} }
let schedules: Vec<Schedule> = convert_db_list(&mut pool_conn, db_schedules)?; let schedules: Vec<Schedule> = convert_db_list(&mut tx, db_schedules)?;
tx.commit().await?;
Ok(HttpResponse::Created().json(schedules)) Ok(HttpResponse::Created().json(schedules))
} }
@ -127,12 +135,12 @@ pub async fn update(
path: web::Path<(String,)>, path: web::Path<(String,)>,
data: web::Json<RequestScheduleUpdate>, data: web::Json<RequestScheduleUpdate>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (schedule_uid,) = path.into_inner(); let (schedule_uid,) = path.into_inner();
let uid = ScheduleUid::try_from(schedule_uid.as_str())?; let uid = ScheduleUid::try_from(schedule_uid.as_str())?;
let schedule = DbSchedule::get_by_uid(&mut pool_conn, &uid) let schedule = DbSchedule::get_by_uid(&mut tx, &uid)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
@ -146,13 +154,13 @@ pub async fn update(
Some(period) => period, Some(period) => period,
}; };
let schedule = schedule.update(&mut pool_conn, name, periods).await?; let schedule = schedule.update(&mut tx, name, periods).await?;
if let Some(tags) = &data.tags { if let Some(tags) = &data.tags {
schedule.set_tags(&mut pool_conn, tags.as_slice()).await?; schedule.set_tags(&mut tx, tags.as_slice()).await?;
} }
let controller_ids: Vec<i64> = DbJunctionRelaySchedule::get_relays(&mut pool_conn, &schedule) let controller_ids: Vec<i64> = DbJunctionRelaySchedule::get_relays(&mut tx, &schedule)
.await? .await?
.into_iter() .into_iter()
.map(|r| r.controller_id) .map(|r| r.controller_id)
@ -160,7 +168,7 @@ pub async fn update(
.collect(); .collect();
for controller_id in controller_ids { for controller_id in controller_ids {
let controller = DbController::get(&mut pool_conn, controller_id) let controller = DbController::get(&mut tx, controller_id)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
app_state app_state
@ -171,7 +179,9 @@ pub async fn update(
.await??; .await??;
} }
let return_schedule = Schedule::from_db_model(&mut pool_conn, schedule)?; let return_schedule = Schedule::from_db_model(&mut tx, schedule)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(return_schedule)) Ok(HttpResponse::Ok().json(return_schedule))
} }
@ -180,8 +190,6 @@ pub async fn delete(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?;
let (schedule_uid,) = path.into_inner(); let (schedule_uid,) = path.into_inner();
let uid = ScheduleUid::try_from(schedule_uid.as_str())?; let uid = ScheduleUid::try_from(schedule_uid.as_str())?;
@ -189,7 +197,9 @@ pub async fn delete(
ScheduleUid::Off => Err(EmgauwaError::from(ApiError::ProtectedSchedule)), ScheduleUid::Off => Err(EmgauwaError::from(ApiError::ProtectedSchedule)),
ScheduleUid::On => Err(EmgauwaError::from(ApiError::ProtectedSchedule)), ScheduleUid::On => Err(EmgauwaError::from(ApiError::ProtectedSchedule)),
ScheduleUid::Any(_) => { ScheduleUid::Any(_) => {
DbSchedule::delete_by_uid(&mut pool_conn, uid).await?; let mut tx = pool.begin().await?;
DbSchedule::delete_by_uid(&mut tx, uid).await?;
tx.commit().await?;
Ok(HttpResponse::Ok().json("schedule got deleted")) Ok(HttpResponse::Ok().json("schedule got deleted"))
} }
} }

View file

@ -7,12 +7,13 @@ use sqlx::{Pool, Sqlite};
#[get("/tags")] #[get("/tags")]
pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> { pub async fn index(pool: web::Data<Pool<Sqlite>>) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let db_tags = DbTag::get_all(&mut pool_conn).await?; let db_tags = DbTag::get_all(&mut tx).await?;
let tags: Vec<String> = db_tags.iter().map(|t| t.tag.clone()).collect(); let tags: Vec<String> = db_tags.iter().map(|t| t.tag.clone()).collect();
tx.commit().await?;
Ok(HttpResponse::Ok().json(tags)) Ok(HttpResponse::Ok().json(tags))
} }
@ -21,15 +22,17 @@ pub async fn show(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (tag_name,) = path.into_inner(); let (tag_name,) = path.into_inner();
let tag = DbTag::get_by_tag(&mut pool_conn, &tag_name) let tag = DbTag::get_by_tag(&mut tx, &tag_name)
.await? .await?
.ok_or(DatabaseError::NotFound)?; .ok_or(DatabaseError::NotFound)?;
let return_tag = Tag::from_db_model(&mut pool_conn, tag)?; let return_tag = Tag::from_db_model(&mut tx, tag)?;
tx.commit().await?;
Ok(HttpResponse::Ok().json(return_tag)) Ok(HttpResponse::Ok().json(return_tag))
} }
@ -38,11 +41,13 @@ pub async fn delete(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
path: web::Path<(String,)>, path: web::Path<(String,)>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let (tag_name,) = path.into_inner(); let (tag_name,) = path.into_inner();
DbTag::delete_by_tag(&mut pool_conn, &tag_name).await?; DbTag::delete_by_tag(&mut tx, &tag_name).await?;
tx.commit().await?;
Ok(HttpResponse::Ok().json("tag got deleted")) Ok(HttpResponse::Ok().json("tag got deleted"))
} }
@ -51,11 +56,13 @@ pub async fn add(
pool: web::Data<Pool<Sqlite>>, pool: web::Data<Pool<Sqlite>>,
data: web::Json<RequestTagCreate>, data: web::Json<RequestTagCreate>,
) -> Result<HttpResponse, EmgauwaError> { ) -> Result<HttpResponse, EmgauwaError> {
let mut pool_conn = pool.acquire().await?; let mut tx = pool.begin().await?;
let new_tag = DbTag::create(&mut pool_conn, &data.tag).await?; let new_tag = DbTag::create(&mut tx, &data.tag).await?;
let cache = (Vec::new(), Vec::new()); // a new tag can't have any relays or schedules let cache = (Vec::new(), Vec::new()); // a new tag can't have any relays or schedules
let return_tag = Tag::from_db_model_cache(&mut pool_conn, new_tag, cache)?; let return_tag = Tag::from_db_model_cache(&mut tx, new_tag, cache)?;
tx.commit().await?;
Ok(HttpResponse::Created().json(return_tag)) Ok(HttpResponse::Created().json(return_tag))
} }

View file

@ -5,8 +5,6 @@ use emgauwa_common::models::{Controller, FromDbModel};
use emgauwa_common::types::{ControllerWsAction, EmgauwaUid, RelayStates}; use emgauwa_common::types::{ControllerWsAction, EmgauwaUid, RelayStates};
use emgauwa_common::utils; use emgauwa_common::utils;
use futures::executor::block_on; use futures::executor::block_on;
use sqlx::pool::PoolConnection;
use sqlx::Sqlite;
use crate::app_state::{Action, ConnectController, UpdateRelayStates}; use crate::app_state::{Action, ConnectController, UpdateRelayStates};
use crate::handlers::v1::ws::controllers::ControllersWs; use crate::handlers::v1::ws::controllers::ControllersWs;
@ -14,7 +12,6 @@ use crate::handlers::v1::ws::controllers::ControllersWs;
impl ControllersWs { impl ControllersWs {
pub fn handle_register( pub fn handle_register(
&mut self, &mut self,
conn: &mut PoolConnection<Sqlite>,
ctx: &mut <ControllersWs as Actor>::Context, ctx: &mut <ControllersWs as Actor>::Context,
controller: Controller, controller: Controller,
) -> Result<(), EmgauwaError> { ) -> Result<(), EmgauwaError> {
@ -23,16 +20,19 @@ impl ControllersWs {
controller.c.name, controller.c.name,
controller.c.uid controller.c.uid
); );
let mut tx = block_on(self.pool.begin())?;
let c = &controller.c; let c = &controller.c;
let controller_db = block_on(DbController::get_by_uid_or_create( let controller_db = block_on(DbController::get_by_uid_or_create(
conn, &mut tx,
&c.uid, &c.uid,
&c.name, &c.name,
c.relay_count, c.relay_count,
))?; ))?;
block_on(controller_db.update_active(conn, true))?; block_on(controller_db.update_active(&mut tx, true))?;
// update only the relay count // update only the relay count
block_on(controller_db.update(conn, &controller_db.name, c.relay_count))?; block_on(controller_db.update(&mut tx, &controller_db.name, c.relay_count))?;
for relay in &controller.relays { for relay in &controller.relays {
log::debug!( log::debug!(
@ -45,7 +45,7 @@ impl ControllersWs {
} }
); );
let (new_relay, created) = block_on(DbRelay::get_by_controller_and_num_or_create( let (new_relay, created) = block_on(DbRelay::get_by_controller_and_num_or_create(
conn, &mut tx,
&controller_db, &controller_db,
relay.r.number, relay.r.number,
&relay.r.name, &relay.r.name,
@ -54,7 +54,7 @@ impl ControllersWs {
let mut relay_schedules = Vec::new(); let mut relay_schedules = Vec::new();
for schedule in &relay.schedules { for schedule in &relay.schedules {
let (new_schedule, _) = block_on(DbSchedule::get_by_uid_or_create( let (new_schedule, _) = block_on(DbSchedule::get_by_uid_or_create(
conn, &mut tx,
schedule.uid.clone(), schedule.uid.clone(),
&schedule.name, &schedule.name,
&schedule.periods, &schedule.periods,
@ -63,7 +63,7 @@ impl ControllersWs {
} }
block_on(DbJunctionRelaySchedule::set_schedules( block_on(DbJunctionRelaySchedule::set_schedules(
conn, &mut tx,
&new_relay, &new_relay,
relay_schedules.iter().collect(), relay_schedules.iter().collect(),
))?; ))?;
@ -71,9 +71,9 @@ impl ControllersWs {
} }
let controller_uid = &controller.c.uid; let controller_uid = &controller.c.uid;
let controller_db = block_on(DbController::get_by_uid(conn, controller_uid))? let controller_db = block_on(DbController::get_by_uid(&mut tx, controller_uid))?
.ok_or(DatabaseError::InsertGetError)?; .ok_or(DatabaseError::InsertGetError)?;
let controller = Controller::from_db_model(conn, controller_db)?; let controller = Controller::from_db_model(&mut tx, controller_db)?;
let addr = ctx.address(); let addr = ctx.address();
self.controller_uid = Some(controller_uid.clone()); self.controller_uid = Some(controller_uid.clone());
@ -91,6 +91,7 @@ impl ControllersWs {
action: ControllerWsAction::Relays(controller.relays), action: ControllerWsAction::Relays(controller.relays),
}))??; }))??;
block_on(tx.commit())?;
log::debug!("Done registering controller"); log::debug!("Done registering controller");
Ok(()) Ok(())
} }

View file

@ -9,7 +9,6 @@ use emgauwa_common::constants::{HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT};
use emgauwa_common::errors::EmgauwaError; use emgauwa_common::errors::EmgauwaError;
use emgauwa_common::types::{ControllerWsAction, EmgauwaUid}; use emgauwa_common::types::{ControllerWsAction, EmgauwaUid};
use futures::executor::block_on; use futures::executor::block_on;
use sqlx::pool::PoolConnection;
use sqlx::{Pool, Sqlite}; use sqlx::{Pool, Sqlite};
use ws::Message; use ws::Message;
@ -48,12 +47,11 @@ impl Actor for ControllersWs {
impl ControllersWs { impl ControllersWs {
pub fn handle_action( pub fn handle_action(
&mut self, &mut self,
conn: &mut PoolConnection<Sqlite>,
ctx: &mut <ControllersWs as Actor>::Context, ctx: &mut <ControllersWs as Actor>::Context,
action: ControllerWsAction, action: ControllerWsAction,
) { ) {
let action_res = match action { let action_res = match action {
ControllerWsAction::Register(controller) => self.handle_register(conn, ctx, controller), ControllerWsAction::Register(controller) => self.handle_register(ctx, controller),
ControllerWsAction::RelayStates((controller_uid, relay_states)) => { ControllerWsAction::RelayStates((controller_uid, relay_states)) => {
self.handle_relay_states(controller_uid, relay_states) self.handle_relay_states(controller_uid, relay_states)
} }
@ -103,15 +101,6 @@ impl Handler<ControllerWsAction> for ControllersWs {
impl StreamHandler<Result<Message, ProtocolError>> for ControllersWs { impl StreamHandler<Result<Message, ProtocolError>> for ControllersWs {
fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) { fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) {
let mut pool_conn = match block_on(self.pool.acquire()) {
Ok(conn) => conn,
Err(err) => {
log::error!("Failed to acquire database connection: {:?}", err);
ctx.stop();
return;
}
};
let msg = match msg { let msg = match msg {
Err(_) => { Err(_) => {
ctx.stop(); ctx.stop();
@ -130,7 +119,7 @@ impl StreamHandler<Result<Message, ProtocolError>> for ControllersWs {
} }
Message::Text(text) => match serde_json::from_str(&text) { Message::Text(text) => match serde_json::from_str(&text) {
Ok(action) => { Ok(action) => {
self.handle_action(&mut pool_conn, ctx, action); self.handle_action(ctx, action);
} }
Err(e) => { Err(e) => {
log::error!("Error deserializing action: {:?}", e); log::error!("Error deserializing action: {:?}", e);

View file

@ -28,11 +28,11 @@ async fn main() -> Result<(), std::io::Error> {
let pool = emgauwa_common::db::init(&settings.database).await?; let pool = emgauwa_common::db::init(&settings.database).await?;
let mut conn = pool.acquire().await.map_err(EmgauwaError::from)?; let mut tx = pool.begin().await.map_err(EmgauwaError::from)?;
DbController::all_inactive(&mut conn) DbController::all_inactive(&mut tx)
.await .await
.map_err(EmgauwaError::from)?; .map_err(EmgauwaError::from)?;
conn.close().await.map_err(EmgauwaError::from)?; tx.commit().await.map_err(EmgauwaError::from)?;
let app_state_arbiter = Arbiter::with_tokio_rt(|| { let app_state_arbiter = Arbiter::with_tokio_rt(|| {
tokio::runtime::Builder::new_multi_thread() tokio::runtime::Builder::new_multi_thread()