diff --git a/.env b/.env
index 6075b6a..77e3b5d 100644
--- a/.env
+++ b/.env
@@ -1 +1 @@
-DATABASE_URL=emgauwa-core.sqlite
+DATABASE_URL=sqlite://emgauwa-core.sqlite
diff --git a/.gitignore b/.gitignore
index 746d5d5..13b0cab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ emgauwa-core.conf.d
 
 emgauwa-core.sqlite
+emgauwa-core.sqlite-*
 
 # Added by cargo
 
diff --git a/Cargo.lock b/Cargo.lock
index 5229765..769c3b5 100644
Binary files a/Cargo.lock and b/Cargo.lock differ
diff --git a/Cargo.toml b/Cargo.toml
index bc91bd0..fa180c0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,8 +12,7 @@ authors = ["Tobias Reisinger "]
 
 [dependencies]
 actix-web = "4.4"
-diesel = { version = "2.1", features = ["uuid", "sqlite"] }
-diesel_migrations = "2.1"
+sqlx = { version = "0.7", features = ["sqlite", "runtime-async-std", "macros", "chrono"] }
 dotenv = "0.15"
 config = "0.13"
@@ -30,3 +29,5 @@ serde_json = "1.0"
 serde_derive = "1.0"
 
 libsqlite3-sys = { version = "*", features = ["bundled"] }
+
+futures = "0.3.29"
diff --git a/diesel.toml b/diesel.toml
deleted file mode 100644
index 71215db..0000000
--- a/diesel.toml
+++ /dev/null
@@ -1,5 +0,0 @@
-# For documentation on how to configure this file,
-# see diesel.rs/guides/configuring-diesel-cli
-
-[print_schema]
-file = "src/db/schema.rs"
diff --git a/migrations/2021-10-13-000000_init/down.sql b/migrations/20231120000000_init.down.sql
similarity index 100%
rename from migrations/2021-10-13-000000_init/down.sql
rename to migrations/20231120000000_init.down.sql
diff --git a/migrations/2021-10-13-000000_init/up.sql b/migrations/20231120000000_init.up.sql
similarity index 100%
rename from migrations/2021-10-13-000000_init/up.sql
rename to migrations/20231120000000_init.up.sql
diff --git a/src/db.rs b/src/db.rs
index 970a207..7960693 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -1,73 +1,80 @@
-use std::env;
+use log::{info, trace};
+use sqlx::migrate::Migrator;
+use sqlx::{Pool, Sqlite};
+use sqlx::sqlite::SqlitePoolOptions;
 
 use crate::db::errors::DatabaseError;
 use crate::db::model_utils::Period;
-use crate::db::models::{NewSchedule, Periods};
+use crate::db::models::{Schedule, Periods};
 use crate::types::EmgauwaUid;
-use diesel::prelude::*;
-use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
-use dotenv::dotenv;
-use log::{info, trace};
 
 pub mod errors;
 pub mod models;
 pub mod schedules;
-pub mod schema;
 pub mod tag;
 
 mod model_utils;
 
-pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
+static MIGRATOR: Migrator = sqlx::migrate!(); // defaults to "./migrations"
 
-fn get_connection() -> SqliteConnection {
-    dotenv().ok();
-
-    let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
-    SqliteConnection::establish(&database_url)
-        .unwrap_or_else(|_| panic!("Error connecting to {}", database_url))
-}
-
-pub fn run_migrations() {
+pub async fn run_migrations(pool: &Pool<Sqlite>) {
     info!("Running migrations");
-    let mut connection = get_connection();
-    connection
-        .run_pending_migrations(MIGRATIONS)
+    MIGRATOR
+        .run(pool)
+        .await
         .expect("Failed to run migrations.");
 }
 
-fn init_schedule(schedule: &NewSchedule) -> Result<(), DatabaseError> {
-    trace!("Initializing schedule {:?}", schedule.name);
-    match schedules::get_schedule_by_uid(schedule.uid.clone()) {
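+// The schedule is looked up by UID first, so init stays idempotent across restarts.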
+async fn init_schedule(pool: &Pool<Sqlite>, uid: &EmgauwaUid, name: &str, periods: Periods) -> Result<(), DatabaseError> {
+    trace!("Initializing schedule {:?}", name);
+    match schedules::get_schedule_by_uid(pool, uid).await {
         Ok(_) => Ok(()),
         Err(err) => match err {
             DatabaseError::NotFound => {
-                trace!("Schedule {:?} not found, inserting", schedule.name);
-                let mut connection = get_connection();
-                diesel::insert_into(schema::schedules::table)
-                    .values(schedule)
-                    .execute(&mut connection)
+                trace!("Schedule {:?} not found, inserting", name);
+                sqlx::query_as!(Schedule, "INSERT INTO schedules (uid, name, periods) VALUES (?, ?, ?) RETURNING *",
+                    uid,
+                    name,
+                    periods,
+                )
+                .fetch_optional(pool)
+                .await?
+                .ok_or(DatabaseError::InsertGetError)
                .map(|_| ())
-                .map_err(DatabaseError::InsertError)
             }
             _ => Err(err),
         },
     }
 }
 
-pub fn init(db: &str) {
-    run_migrations();
-    init_schedule(&NewSchedule {
-        uid: &EmgauwaUid::Off,
-        name: "Off",
-        periods: &Periods(vec![]),
-    })
-    .expect("Error initializing schedule Off");
+pub async fn init(db: &str) -> Pool<Sqlite> {
+    let pool: Pool<Sqlite> = SqlitePoolOptions::new()
+        .acquire_timeout(std::time::Duration::from_secs(1))
+        .max_connections(5)
+        .connect(db)
+        .await
+        .expect("Error connecting to database.");
 
-    init_schedule(&NewSchedule {
-        uid: &EmgauwaUid::On,
-        name: "On",
-        periods: &Periods(vec![Period::new_on()]),
-    })
-    .expect("Error initializing schedule On");
+    run_migrations(&pool).await;
+
+    init_schedule(
+        &pool,
+        &EmgauwaUid::Off,
+        "Off",
+        Periods(vec![])
+    )
+    .await
+    .expect("Error initializing schedule Off");
+
+    init_schedule(
+        &pool,
+        &EmgauwaUid::On,
+        "On",
+        Periods(vec![Period::new_on()])
+    )
+    .await
+    .expect("Error initializing schedule On");
+
+    pool
 }
diff --git a/src/db/errors.rs b/src/db/errors.rs
index af36c6d..8ba2fc7 100644
--- a/src/db/errors.rs
+++ b/src/db/errors.rs
@@ -2,15 +2,16 @@ use actix_web::http::StatusCode;
 use actix_web::HttpResponse;
 use serde::ser::SerializeStruct;
 use serde::{Serialize, Serializer};
+use sqlx::Error;
 
 #[derive(Debug)]
 pub enum DatabaseError {
     DeleteError,
-    InsertError(diesel::result::Error),
+    InsertError,
     InsertGetError,
     NotFound,
     Protected,
-    UpdateError(diesel::result::Error),
+    UpdateError,
     Unknown,
 }
 
@@ -40,14 +41,14 @@ impl Serialize for DatabaseError {
 impl From<&DatabaseError> for String {
     fn from(err: &DatabaseError) -> Self {
         match err {
-            DatabaseError::InsertError(_) => String::from("error on inserting into database"),
+            DatabaseError::InsertError => String::from("error on inserting into database"),
             DatabaseError::InsertGetError => {
                 String::from("error on retrieving new entry from database (your entry was saved)")
             }
             DatabaseError::NotFound => String::from("model was not found in database"),
             DatabaseError::DeleteError => String::from("error on deleting from database"),
             DatabaseError::Protected => String::from("model is protected"),
-            DatabaseError::UpdateError(_) => String::from("error on updating the model"),
+            DatabaseError::UpdateError => String::from("error on updating the model"),
             DatabaseError::Unknown => String::from("unknown error"),
         }
     }
@@ -58,3 +59,12 @@ impl From<DatabaseError> for HttpResponse {
         HttpResponse::build(err.get_code()).json(err)
     }
 }
+
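+// Only RowNotFound has a direct DatabaseError equivalent; everything else becomes Unknown.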
+impl From<Error> for DatabaseError {
+    fn from(value: Error) -> Self {
+        match value {
+            Error::RowNotFound => DatabaseError::NotFound,
+            _ => DatabaseError::Unknown,
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/db/model_utils.rs b/src/db/model_utils.rs
index 0d15068..28679d1 100644
--- a/src/db/model_utils.rs
+++ b/src/db/model_utils.rs
@@ -1,14 +1,13 @@
 use crate::db::models::Periods;
 use chrono::{NaiveTime, Timelike};
-use diesel::deserialize::FromSql;
-use diesel::serialize::{IsNull, Output, ToSql};
-use diesel::sql_types::Binary;
-use diesel::sqlite::Sqlite;
-use diesel::{deserialize, serialize};
 use serde::{Deserialize, Serialize};
+use sqlx::{Decode, Encode, Sqlite, Type};
+use sqlx::database::HasArguments;
+use sqlx::encode::IsNull;
+use sqlx::error::BoxDynError;
+use sqlx::sqlite::{SqliteTypeInfo, SqliteValueRef};
 
-#[derive(Debug, Serialize, Deserialize, AsExpression, FromSqlRow, PartialEq, Clone)]
-#[diesel(sql_type = Binary)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
 pub struct Period {
     #[serde(with = "period_format")]
     pub start: NaiveTime,
@@ -52,13 +51,81 @@ impl Period {
     }
 }
 
-impl ToSql<Binary, Sqlite> for Periods
-where
-    Vec<u8>: ToSql<Binary, Sqlite>,
-{
-    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
-        let periods_u8: Vec<u8> = self
-            .0
+//impl ToSql<Binary, Sqlite> for Periods
+//where
+//    Vec<u8>: ToSql<Binary, Sqlite>,
+//{
+//    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
+//        let periods_u8: Vec<u8> = self
+//            .0
+//            .iter()
+//            .flat_map(|period| {
+//                let vec = vec![
+//                    period.start.hour() as u8,
+//                    period.start.minute() as u8,
+//                    period.end.hour() as u8,
+//                    period.end.minute() as u8,
+//                ];
+//                vec
+//            })
+//            .collect();
+//
+//        out.set_value(periods_u8);
+//
+//        Ok(IsNull::No)
+//    }
+//}
+//
+//impl<DB> FromSql<Binary, DB> for Periods
+//where
+//    DB: diesel::backend::Backend,
+//    Vec<u8>: FromSql<Binary, DB>,
+//{
+//    fn from_sql(bytes: DB::RawValue<'_>) -> deserialize::Result<Self> {
+//        let blob: Vec<u8> = Vec::from_sql(bytes).unwrap();
+//
+//        let mut vec = Vec::new();
+//        for i in (3..blob.len()).step_by(4) {
+//            let start_val_h: u32 = blob[i - 3] as u32;
+//            let start_val_m: u32 = blob[i - 2] as u32;
+//            let end_val_h: u32 = blob[i - 1] as u32;
+//            let end_val_m: u32 = blob[i] as u32;
+//            vec.push(Period {
+//                start: NaiveTime::from_hms_opt(start_val_h, start_val_m, 0).unwrap(),
+//                end: NaiveTime::from_hms_opt(end_val_h, end_val_m, 0).unwrap(),
+//            });
+//        }
+//        Ok(Periods(vec))
+//    }
+//}
+
+impl Type<Sqlite> for Periods {
+    fn type_info() -> SqliteTypeInfo {
+        <&[u8] as Type<Sqlite>>::type_info()
+    }
+
+    fn compatible(ty: &SqliteTypeInfo) -> bool {
+        <&[u8] as Type<Sqlite>>::compatible(ty)
+    }
+}
+
+impl<'q> Encode<'q, Sqlite> for Periods {
+    //noinspection DuplicatedCode
+    fn encode_by_ref(&self, buf: &mut <Sqlite as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
+        <&Vec<u8> as Encode<Sqlite>>::encode(&Vec::from(self), buf)
+    }
+}
+
+impl<'r> Decode<'r, Sqlite> for Periods {
+    fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
+        let blob = <&[u8] as Decode<Sqlite>>::decode(value)?;
+        Ok(Periods::from(Vec::from(blob)))
+    }
+}
+
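+// Storage format: four bytes per period (start hour, start minute, end hour, end minute).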
+impl From<&Periods> for Vec<u8> {
+    fn from(periods: &Periods) -> Vec<u8> {
+        periods.0
             .iter()
             .flat_map(|period| {
                 let vec = vec![
@@ -69,33 +136,23 @@ where
                 ];
                 vec
             })
-            .collect();
-
-        out.set_value(periods_u8);
-
-        Ok(IsNull::No)
+            .collect()
     }
 }
 
-impl<DB> FromSql<Binary, DB> for Periods
-where
-    DB: diesel::backend::Backend,
-    Vec<u8>: FromSql<Binary, DB>,
-{
-    fn from_sql(bytes: DB::RawValue<'_>) -> deserialize::Result<Self> {
-        let blob: Vec<u8> = Vec::from_sql(bytes).unwrap();
-
+impl From<Vec<u8>> for Periods {
+    fn from(value: Vec<u8>) -> Self {
         let mut vec = Vec::new();
-        for i in (3..blob.len()).step_by(4) {
-            let start_val_h: u32 = blob[i - 3] as u32;
-            let start_val_m: u32 = blob[i - 2] as u32;
-            let end_val_h: u32 = blob[i - 1] as u32;
-            let end_val_m: u32 = blob[i] as u32;
+        for i in (3..value.len()).step_by(4) {
+            let start_val_h: u32 = value[i - 3] as u32;
+            let start_val_m: u32 = value[i - 2] as u32;
+            let end_val_h: u32 = value[i - 1] as u32;
+            let end_val_m: u32 = value[i] as u32;
             vec.push(Period {
                 start: NaiveTime::from_hms_opt(start_val_h, start_val_m, 0).unwrap(),
                 end: NaiveTime::from_hms_opt(end_val_h, end_val_m, 0).unwrap(),
             });
         }
-        Ok(Periods(vec))
+        Periods(vec)
     }
 }
diff --git a/src/db/models.rs b/src/db/models.rs
index 3c80774..4b1c99d 100644
--- a/src/db/models.rs
+++ b/src/db/models.rs
@@ -1,68 +1,37 @@
 use crate::db::model_utils::Period;
-use diesel::sql_types::Binary;
 use serde::{Deserialize, Serialize};
 
-use super::schema::*;
 use crate::types::EmgauwaUid;
 
-#[derive(Debug, Serialize, Identifiable, Queryable)]
+#[derive(Debug, Serialize)]
 pub struct Relay {
     #[serde(skip)]
-    pub id: i32,
+    pub id: i64,
     // TODO
 }
 
-#[derive(Debug, Serialize, Identifiable, Queryable, Clone)]
+#[derive(Debug, Serialize, Clone)]
 pub struct Schedule {
     #[serde(skip)]
-    pub id: i32,
+    pub id: i64,
     #[serde(rename(serialize = "id"))]
     pub uid: EmgauwaUid,
     pub name: String,
     pub periods: Periods,
 }
 
-#[derive(Insertable)]
-#[diesel(table_name = crate::db::schema::schedules)]
-pub struct NewSchedule<'a> {
-    pub uid: &'a EmgauwaUid,
-    pub name: &'a str,
-    pub periods: &'a Periods,
-}
-
-#[derive(Debug, Serialize, Deserialize, AsExpression, FromSqlRow, PartialEq, Clone)]
-#[diesel(sql_type = Binary)]
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
 pub struct Periods(pub Vec<Period>);
 
-#[derive(Debug, Serialize, Identifiable, Queryable, Clone)]
-#[diesel(table_name = crate::db::schema::tags)]
+#[derive(Debug, Serialize, Clone)]
 pub struct Tag {
-    pub id: i32,
+    pub id: i64,
     pub tag: String,
 }
 
-#[derive(Insertable)]
-#[diesel(table_name = crate::db::schema::tags)]
-pub struct NewTag<'a> {
-    pub tag: &'a str,
-}
-
-#[derive(Queryable, Associations, Identifiable)]
-#[diesel(belongs_to(Relay))]
-#[diesel(belongs_to(Schedule))]
-#[diesel(belongs_to(Tag))]
-#[diesel(table_name = crate::db::schema::junction_tag)]
 pub struct JunctionTag {
-    pub id: i32,
-    pub tag_id: i32,
-    pub relay_id: Option<i32>,
-    pub schedule_id: Option<i32>,
-}
-
-#[derive(Insertable)]
-#[diesel(table_name = crate::db::schema::junction_tag)]
-pub struct NewJunctionTag {
-    pub tag_id: i32,
-    pub relay_id: Option<i32>,
-    pub schedule_id: Option<i32>,
+    pub id: i64,
+    pub tag_id: i64,
+    pub relay_id: Option<i64>,
+    pub schedule_id: Option<i64>,
 }
diff --git a/src/db/schedules.rs b/src/db/schedules.rs
index 01f13f5..89431fe 100644
--- a/src/db/schedules.rs
+++ b/src/db/schedules.rs
@@ -1,141 +1,106 @@
-use diesel::dsl::sql;
-use diesel::prelude::*;
 use std::borrow::Borrow;
+use sqlx::{Pool, Sqlite};
 
 use crate::types::EmgauwaUid;
 
 use crate::db::errors::DatabaseError;
 use crate::db::models::*;
-use crate::db::schema::junction_tag::dsl::junction_tag;
-use crate::db::schema::schedules::dsl::schedules;
-use crate::db::schema::tags::dsl::tags;
-use crate::db::tag::{create_junction_tag, create_tag};
-use crate::db::{get_connection, schema};
+use crate::db::tag::{create_junction_tag_schedule, create_tag};
 
-pub fn get_schedule_tags(schedule: &Schedule) -> Vec<String> {
-    let mut connection = get_connection();
-    JunctionTag::belonging_to(schedule)
-        .inner_join(schema::tags::dsl::tags)
-        .select(schema::tags::tag)
-        .load::<String>(&mut connection)
-        .expect("Error loading tags")
+pub async fn get_schedule_tags(pool: &Pool<Sqlite>, schedule: &Schedule) -> Result<Vec<String>, DatabaseError> {
+    Ok(sqlx::query_scalar!("SELECT tag FROM tags INNER JOIN junction_tag ON junction_tag.tag_id = tags.id WHERE junction_tag.schedule_id = ?", schedule.id)
+        .fetch_all(pool)
+        .await?)
 }
 
-pub fn get_schedules() -> Vec<Schedule> {
-    let mut connection = get_connection();
-    schedules
-        .load::<Schedule>(&mut connection)
-        .expect("Error loading schedules")
+pub async fn get_schedules(pool: &Pool<Sqlite>) -> Result<Vec<Schedule>, DatabaseError> {
+    Ok(sqlx::query_as!(Schedule, "SELECT * FROM schedules")
+        .fetch_all(pool)
+        .await?)
 }
 
-pub fn get_schedule_by_uid(filter_uid: EmgauwaUid) -> Result<Schedule, DatabaseError> {
-    let mut connection = get_connection();
-    let result = schedules
-        .filter(schema::schedules::uid.eq(filter_uid))
-        .first::<Schedule>(&mut connection)
-        .or(Err(DatabaseError::NotFound))?;
-
-    Ok(result)
+pub async fn get_schedule_by_uid(pool: &Pool<Sqlite>, filter_uid: &EmgauwaUid) -> Result<Schedule, DatabaseError> {
+    sqlx::query_as!(Schedule, "SELECT * FROM schedules WHERE uid = ?", filter_uid)
+        .fetch_optional(pool)
+        .await
+        .map(|s| s.ok_or(DatabaseError::NotFound))?
 }
 
-pub fn get_schedules_by_tag(tag: &Tag) -> Vec<Schedule> {
-    let mut connection = get_connection();
-    JunctionTag::belonging_to(tag)
-        .inner_join(schedules)
-        .select(schema::schedules::all_columns)
-        .load::<Schedule>(&mut connection)
-        .expect("Error loading tags")
+pub async fn get_schedules_by_tag(pool: &Pool<Sqlite>, tag: &Tag) -> Result<Vec<Schedule>, DatabaseError> {
+    Ok(sqlx::query_as!(Schedule, "SELECT schedule.* FROM schedules AS schedule INNER JOIN junction_tag ON junction_tag.schedule_id = schedule.id WHERE junction_tag.tag_id = ?", tag.id)
+        .fetch_all(pool)
+        .await?)
 }
 
-pub fn delete_schedule_by_uid(filter_uid: EmgauwaUid) -> Result<(), DatabaseError> {
+pub async fn delete_schedule_by_uid(pool: &Pool<Sqlite>, filter_uid: EmgauwaUid) -> Result<(), DatabaseError> {
     let filter_uid = match filter_uid {
         EmgauwaUid::Off => Err(DatabaseError::Protected),
         EmgauwaUid::On => Err(DatabaseError::Protected),
         EmgauwaUid::Any(_) => Ok(filter_uid),
     }?;
 
-    let mut connection = get_connection();
-    match diesel::delete(schedules.filter(schema::schedules::uid.eq(filter_uid)))
-        .execute(&mut connection)
-    {
-        Ok(rows) => {
-            if rows != 0 {
-                Ok(())
-            } else {
-                Err(DatabaseError::DeleteError)
-            }
-        }
-        Err(_) => Err(DatabaseError::DeleteError),
-    }
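+    // rows_affected() == 0 means no schedule matched the given UID.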
+    sqlx::query!("DELETE FROM schedules WHERE uid = ?", filter_uid)
+        .execute(pool)
+        .await
+        .map(|res| match res.rows_affected() {
+            0 => Err(DatabaseError::DeleteError),
+            _ => Ok(()),
+        })?
 }
 
-pub fn create_schedule(new_name: &str, new_periods: &Periods) -> Result<Schedule, DatabaseError> {
-    let mut connection = get_connection();
-
-    let new_schedule = NewSchedule {
-        uid: &EmgauwaUid::default(),
-        name: new_name,
-        periods: new_periods,
-    };
-
-    diesel::insert_into(schedules)
-        .values(&new_schedule)
-        .execute(&mut connection)
-        .map_err(DatabaseError::InsertError)?;
-
-    let result = schedules
-        .find(sql("last_insert_rowid()"))
-        .get_result::<Schedule>(&mut connection)
-        .or(Err(DatabaseError::InsertGetError))?;
-
-    Ok(result)
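+// RETURNING * hands back the inserted row directly, replacing the old last_insert_rowid() lookup.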
+pub async fn create_schedule(pool: &Pool<Sqlite>, new_name: &str, new_periods: &Periods) -> Result<Schedule, DatabaseError> {
+    let uid = EmgauwaUid::default();
+    sqlx::query_as!(Schedule, "INSERT INTO schedules (uid, name, periods) VALUES (?, ?, ?) RETURNING *",
+        uid,
+        new_name,
+        new_periods,
+    )
+    .fetch_optional(pool)
+    .await?
+    .ok_or(DatabaseError::InsertGetError)
 }
 
-pub fn update_schedule(
+pub async fn update_schedule(
+    pool: &Pool<Sqlite>,
     schedule: &Schedule,
     new_name: &str,
     new_periods: &Periods,
 ) -> Result<Schedule, DatabaseError> {
-    let mut connection = get_connection();
-
+    // overwrite periods on protected schedules
     let new_periods = match schedule.uid {
         EmgauwaUid::Off | EmgauwaUid::On => schedule.periods.borrow(),
         EmgauwaUid::Any(_) => new_periods,
     };
 
-    diesel::update(schedule)
-        .set((
-            schema::schedules::name.eq(new_name),
-            schema::schedules::periods.eq(new_periods),
-        ))
-        .execute(&mut connection)
-        .map_err(DatabaseError::UpdateError)?;
+    sqlx::query!("UPDATE schedules SET name = ?, periods = ? WHERE id = ?",
+        new_name,
+        new_periods,
+        schedule.id,
+    )
+    .execute(pool)
+    .await?;
 
-    get_schedule_by_uid(schedule.uid.clone())
+    get_schedule_by_uid(pool, &schedule.uid).await
 }
 
-pub fn set_schedule_tags(schedule: &Schedule, new_tags: &[String]) -> Result<(), DatabaseError> {
-    let mut connection = get_connection();
-    diesel::delete(junction_tag.filter(schema::junction_tag::schedule_id.eq(schedule.id)))
-        .execute(&mut connection)
-        .or(Err(DatabaseError::DeleteError))?;
+pub async fn set_schedule_tags(pool: &Pool<Sqlite>, schedule: &Schedule, new_tags: &[String]) -> Result<(), DatabaseError> {
+    sqlx::query!("DELETE FROM junction_tag WHERE schedule_id = ?", schedule.id)
+        .execute(pool)
+        .await?;
 
-    let mut database_tags: Vec<Tag> = tags
-        .filter(schema::tags::tag.eq_any(new_tags))
-        .load::<Tag>(&mut connection)
-        .expect("Error loading tags");
-
-    // create missing tags
     for new_tag in new_tags {
-        if !database_tags.iter().any(|tab_db| tab_db.tag.eq(new_tag)) {
-            database_tags.push(create_tag(new_tag).expect("Error inserting tag"));
-        }
-    }
+        let tag: Option<Tag> = sqlx::query_as!(Tag, "SELECT * FROM tags WHERE tag = ?", new_tag)
+            .fetch_optional(pool)
+            .await?;
 
-    for database_tag in database_tags {
-        create_junction_tag(database_tag, None, Some(schedule))
-            .expect("Error saving junction between tag and schedule");
-    }
+        let tag = match tag {
+            Some(tag) => tag,
+            None => {
+                create_tag(pool, new_tag).await?
+            }
+        };
+        create_junction_tag_schedule(pool, tag, schedule).await?;
+    }
 
     Ok(())
 }
diff --git a/src/db/schema.rs b/src/db/schema.rs
deleted file mode 100644
index 565cbee..0000000
--- a/src/db/schema.rs
+++ /dev/null
@@ -1,93 +0,0 @@
-table! {
-    controllers (id) {
-        id -> Integer,
-        uid -> Text,
-        name -> Text,
-        ip -> Nullable<Text>,
-        port -> Nullable<Integer>,
-        relay_count -> Nullable<Integer>,
-        active -> Bool,
-    }
-}
-
-table! {
-    junction_relay_schedule (id) {
-        id -> Integer,
-        weekday -> SmallInt,
-        relay_id -> Nullable<Integer>,
-        schedule_id -> Nullable<Integer>,
-    }
-}
-
-table! {
-    junction_tag (id) {
-        id -> Integer,
-        tag_id -> Integer,
-        relay_id -> Nullable<Integer>,
-        schedule_id -> Nullable<Integer>,
-    }
-}
-
-table! {
-    macro_actions (id) {
-        id -> Integer,
-        macro_id -> Integer,
-        relay_id -> Integer,
-        schedule_id -> Integer,
-        weekday -> SmallInt,
-    }
-}
-
-table! {
-    macros (id) {
-        id -> Integer,
-        uid -> Text,
-        name -> Text,
-    }
-}
-
-table! {
-    relays (id) {
-        id -> Integer,
-        name -> Text,
-        number -> Integer,
-        controller_id -> Integer,
-    }
-}
-
-table! {
-    schedules (id) {
-        id -> Integer,
-        uid -> Binary,
-        name -> Text,
-        periods -> Binary,
-    }
-}
-
-table! {
-    tags (id) {
-        id -> Integer,
-        tag -> Text,
-    }
-}
-
-joinable!(junction_relay_schedule -> relays (relay_id));
-joinable!(junction_relay_schedule -> schedules (schedule_id));
-joinable!(junction_tag -> relays (relay_id));
-joinable!(junction_tag -> schedules (schedule_id));
-joinable!(junction_tag -> tags (tag_id));
-joinable!(macro_actions -> macros (macro_id));
-joinable!(macro_actions -> relays (relay_id));
-joinable!(macro_actions -> schedules (schedule_id));
-joinable!(relays -> controllers (controller_id));
-
-allow_tables_to_appear_in_same_query!(
-    controllers,
-    junction_relay_schedule,
-    junction_tag,
-    macro_actions,
-    macros,
-    relays,
-    schedules,
-    tags,
-);
diff --git a/src/db/tag.rs b/src/db/tag.rs
index c31df9a..881b2c0 100644
--- a/src/db/tag.rs
+++ b/src/db/tag.rs
@@ -1,63 +1,41 @@
-use diesel::dsl::sql;
-use diesel::prelude::*;
-
+use sqlx::{Pool, Sqlite};
 use crate::db::errors::DatabaseError;
 use crate::db::models::*;
-use crate::db::schema::junction_tag::dsl::junction_tag;
-use crate::db::schema::tags::dsl::tags;
-use crate::db::{get_connection, schema};
 
-pub fn create_tag(new_tag: &str) -> Result<Tag, DatabaseError> {
-    let mut connection = get_connection();
-
-    let new_tag = NewTag { tag: new_tag };
-
-    diesel::insert_into(tags)
-        .values(&new_tag)
-        .execute(&mut connection)
-        .map_err(DatabaseError::InsertError)?;
-
-    let result = tags
-        .find(sql("last_insert_rowid()"))
-        .get_result::<Tag>(&mut connection)
-        .or(Err(DatabaseError::InsertGetError))?;
-
-    Ok(result)
+pub async fn create_tag(pool: &Pool<Sqlite>, new_tag: &str) -> Result<Tag, DatabaseError> {
+    sqlx::query_as!(Tag, "INSERT INTO tags (tag) VALUES (?) RETURNING *", new_tag)
+        .fetch_optional(pool)
+        .await?
+        .ok_or(DatabaseError::InsertGetError)
 }
 
-pub fn get_tag(target_tag: &str) -> Result<Tag, DatabaseError> {
-    let mut connection = get_connection();
-
-    let result = tags
-        .filter(schema::tags::tag.eq(target_tag))
-        .first::<Tag>(&mut connection)
-        .or(Err(DatabaseError::NotFound))?;
-
-    Ok(result)
+pub async fn get_tag(pool: &Pool<Sqlite>, target_tag: &str) -> Result<Tag, DatabaseError> {
+    sqlx::query_as!(Tag, "SELECT * FROM tags WHERE tag = ?", target_tag)
+        .fetch_optional(pool)
+        .await
+        .map(|t| t.ok_or(DatabaseError::NotFound))?
 }
 
-pub fn create_junction_tag(
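+// Relay junctions are unused for now; Relay itself is still a stub (see models.rs).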
+#[allow(dead_code)]
+pub async fn create_junction_tag_relay(
+    pool: &Pool<Sqlite>,
     target_tag: Tag,
-    target_relay: Option<&Relay>,
-    target_schedule: Option<&Schedule>,
+    target_relay: &Relay,
 ) -> Result<JunctionTag, DatabaseError> {
-    let mut connection = get_connection();
-    let new_junction_tag = NewJunctionTag {
-        relay_id: target_relay.map(|r| r.id),
-        schedule_id: target_schedule.map(|s| s.id),
-        tag_id: target_tag.id,
-    };
-
-    diesel::insert_into(junction_tag)
-        .values(&new_junction_tag)
-        .execute(&mut connection)
-        .map_err(DatabaseError::InsertError)?;
-
-    let result = junction_tag
-        .find(sql("last_insert_rowid()"))
-        .get_result::<JunctionTag>(&mut connection)
-        .or(Err(DatabaseError::InsertGetError))?;
-
-    Ok(result)
+    sqlx::query_as!(JunctionTag, "INSERT INTO junction_tag (tag_id, relay_id) VALUES (?, ?) RETURNING *", target_tag.id, target_relay.id)
+        .fetch_optional(pool)
+        .await?
+        .ok_or(DatabaseError::InsertGetError)
+}
+
+pub async fn create_junction_tag_schedule(
+    pool: &Pool<Sqlite>,
+    target_tag: Tag,
+    target_schedule: &Schedule,
+) -> Result<JunctionTag, DatabaseError> {
+    sqlx::query_as!(JunctionTag, "INSERT INTO junction_tag (tag_id, schedule_id) VALUES (?, ?) RETURNING *", target_tag.id, target_schedule.id)
+        .fetch_optional(pool)
+        .await?
+        .ok_or(DatabaseError::InsertGetError)
 }
diff --git a/src/handlers/v1/schedules.rs b/src/handlers/v1/schedules.rs
index c6ee513..21ce7cc 100644
--- a/src/handlers/v1/schedules.rs
+++ b/src/handlers/v1/schedules.rs
@@ -3,6 +3,8 @@ use actix_web::{delete, get, post, put, web, HttpResponse, Responder};
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
 use std::convert::TryFrom;
+use futures::future;
+use sqlx::{Pool, Sqlite};
 
 use crate::db::models::{Periods, Schedule};
 use crate::db::schedules::*;
@@ -20,38 +22,59 @@ pub struct RequestSchedule {
 }
 
 #[get("/api/v1/schedules")]
-pub async fn index() -> impl Responder {
-    let schedules = get_schedules();
-    let return_schedules: Vec<ReturnSchedule> =
-        schedules.iter().map(ReturnSchedule::from).collect();
+pub async fn index(pool: web::Data<Pool<Sqlite>>) -> impl Responder {
+    let schedules = get_schedules(&pool).await;
+
+    if let Err(err) = schedules {
+        return HttpResponse::from(err);
+    }
+    let schedules = schedules.unwrap();
+
+    let mut return_schedules: Vec<ReturnSchedule> = schedules.iter().map(ReturnSchedule::from).collect();
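+    // From<Schedule> leaves tags empty; fill them in before serializing.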
+    for schedule in return_schedules.iter_mut() {
+        schedule.load_tags(&pool);
+    }
+
     HttpResponse::Ok().json(return_schedules)
 }
 
 #[get("/api/v1/schedules/tag/{tag}")]
-pub async fn tagged(path: web::Path<(String,)>) -> impl Responder {
+pub async fn tagged(pool: web::Data<Pool<Sqlite>>, path: web::Path<(String,)>) -> impl Responder {
     let (tag,) = path.into_inner();
-    let tag_db = get_tag(&tag);
-    if tag_db.is_err() {
-        return HttpResponse::from(tag_db.unwrap_err());
+    let tag_db = get_tag(&pool, &tag).await;
+    if let Err(err) = tag_db {
+        return HttpResponse::from(err);
     }
     let tag_db = tag_db.unwrap();
 
-    let schedules = get_schedules_by_tag(&tag_db);
-    let return_schedules: Vec<ReturnSchedule> =
+    let schedules = get_schedules_by_tag(&pool, &tag_db).await;
+    if let Err(err) = schedules {
+        return HttpResponse::from(err);
+    }
+    let schedules = schedules.unwrap();
+
+    let mut return_schedules: Vec<ReturnSchedule> =
         schedules.iter().map(ReturnSchedule::from).collect();
+    for schedule in return_schedules.iter_mut() {
+        schedule.load_tags(&pool);
+    }
     HttpResponse::Ok().json(return_schedules)
 }
 
 #[get("/api/v1/schedules/{schedule_id}")]
-pub async fn show(path: web::Path<(String,)>) -> impl Responder {
+pub async fn show(pool: web::Data<Pool<Sqlite>>, path: web::Path<(String,)>) -> impl Responder {
     let (schedule_uid,) = path.into_inner();
     let emgauwa_uid = EmgauwaUid::try_from(schedule_uid.as_str()).or(Err(HandlerError::BadUid));
 
     match emgauwa_uid {
         Ok(uid) => {
-            let schedule = get_schedule_by_uid(uid);
+            let schedule = get_schedule_by_uid(&pool, &uid).await;
             match schedule {
-                Ok(ok) => HttpResponse::Ok().json(ReturnSchedule::from(ok)),
+                Ok(ok) => {
+                    let mut return_schedule = ReturnSchedule::from(ok);
+                    return_schedule.load_tags(&pool);
+                    HttpResponse::Ok().json(return_schedule)
+                },
                 Err(err) => HttpResponse::from(err),
             }
         }
@@ -60,35 +83,40 @@ pub async fn show(path: web::Path<(String,)>) -> impl Responder {
 }
 
 #[post("/api/v1/schedules")]
-pub async fn add(data: web::Json<RequestSchedule>) -> impl Responder {
-    let new_schedule = create_schedule(&data.name, &data.periods);
+pub async fn add(pool: web::Data<Pool<Sqlite>>, data: web::Json<RequestSchedule>) -> impl Responder {
+    let new_schedule = create_schedule(&pool, &data.name, &data.periods).await;
 
-    if new_schedule.is_err() {
-        return HttpResponse::from(new_schedule.unwrap_err());
+    if let Err(err) = new_schedule {
+        return HttpResponse::from(err);
     }
     let new_schedule = new_schedule.unwrap();
 
-    let result = set_schedule_tags(&new_schedule, data.tags.as_slice());
-    if result.is_err() {
-        return HttpResponse::from(result.unwrap_err());
+    let result = set_schedule_tags(&pool, &new_schedule, data.tags.as_slice()).await;
+    if let Err(err) = result {
+        return HttpResponse::from(err);
     }
 
-    HttpResponse::Created().json(ReturnSchedule::from(new_schedule))
+    let mut return_schedule = ReturnSchedule::from(new_schedule);
+    return_schedule.load_tags(&pool);
+    HttpResponse::Created().json(return_schedule)
+}
+
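+// Extracted so add_list can turn each request item into one future for future::join_all.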
+async fn add_list_single(pool: &Pool<Sqlite>, request_schedule: &RequestSchedule) -> Result<Schedule, DatabaseError> {
+    let new_schedule = create_schedule(pool, &request_schedule.name, &request_schedule.periods).await?;
+
+    set_schedule_tags(pool, &new_schedule, request_schedule.tags.as_slice()).await?;
+
+    Ok(new_schedule)
 }
 
 #[post("/api/v1/schedules/list")]
-pub async fn add_list(data: web::Json<Vec<RequestSchedule>>) -> impl Responder {
-    let result: Vec<Result<Schedule, DatabaseError>> = data
+pub async fn add_list(pool: web::Data<Pool<Sqlite>>, data: web::Json<Vec<RequestSchedule>>) -> impl Responder {
+    let result: Vec<Result<Schedule, DatabaseError>> = future::join_all(
+        data
         .as_slice()
         .iter()
-        .map(|request_schedule| {
-            let new_schedule = create_schedule(&request_schedule.name, &request_schedule.periods)?;
-
-            set_schedule_tags(&new_schedule, request_schedule.tags.as_slice())?;
-
-            Ok(new_schedule)
-        })
-        .collect();
+        .map(|request_schedule| add_list_single(&pool, request_schedule))
+    ).await;
 
     match vec_has_error(&result) {
         true => HttpResponse::from(
@@ -99,10 +127,14 @@ pub async fn add_list(data: web::Json<Vec<RequestSchedule>>) -> impl Responder {
                 .unwrap_err(),
         ),
         false => {
-            let return_schedules: Vec<ReturnSchedule> = result
+            let mut return_schedules: Vec<ReturnSchedule> = result
                 .iter()
                 .map(|s| ReturnSchedule::from(s.as_ref().unwrap()))
                 .collect();
+
+            for schedule in return_schedules.iter_mut() {
+                schedule.load_tags(&pool);
+            }
             HttpResponse::Created().json(return_schedules)
         }
     }
@@ -110,38 +142,41 @@
 
 #[put("/api/v1/schedules/{schedule_id}")]
 pub async fn update(
+    pool: web::Data<Pool<Sqlite>>,
     path: web::Path<(String,)>,
     data: web::Json<RequestSchedule>,
 ) -> impl Responder {
     let (schedule_uid,) = path.into_inner();
     let emgauwa_uid = EmgauwaUid::try_from(schedule_uid.as_str()).or(Err(HandlerError::BadUid));
-    if emgauwa_uid.is_err() {
-        return HttpResponse::from(emgauwa_uid.unwrap_err());
+    if let Err(err) = emgauwa_uid {
+        return HttpResponse::from(err);
     }
     let emgauwa_uid = emgauwa_uid.unwrap();
 
-    let schedule = get_schedule_by_uid(emgauwa_uid);
-    if schedule.is_err() {
-        return HttpResponse::from(schedule.unwrap_err());
+    let schedule = get_schedule_by_uid(&pool, &emgauwa_uid).await;
+    if let Err(err) = schedule {
+        return HttpResponse::from(err);
     }
     let schedule = schedule.unwrap();
 
-    let schedule = update_schedule(&schedule, data.name.as_str(), data.periods.borrow());
-    if schedule.is_err() {
-        return HttpResponse::from(schedule.unwrap_err());
+    let schedule = update_schedule(&pool, &schedule, data.name.as_str(), data.periods.borrow()).await;
+    if let Err(err) = schedule {
+        return HttpResponse::from(err);
     }
     let schedule = schedule.unwrap();
 
-    let result = set_schedule_tags(&schedule, data.tags.as_slice());
-    if result.is_err() {
-        return HttpResponse::from(result.unwrap_err());
+    let result = set_schedule_tags(&pool, &schedule, data.tags.as_slice()).await;
+    if let Err(err) = result {
+        return HttpResponse::from(err);
     }
 
-    HttpResponse::Ok().json(ReturnSchedule::from(schedule))
+    let mut return_schedule = ReturnSchedule::from(schedule);
+    return_schedule.load_tags(&pool);
+    HttpResponse::Ok().json(return_schedule)
 }
 
 #[delete("/api/v1/schedules/{schedule_id}")]
-pub async fn delete(path: web::Path<(String,)>) -> impl Responder {
+pub async fn delete(pool: web::Data<Pool<Sqlite>>, path: web::Path<(String,)>) -> impl Responder {
     let (schedule_uid,) = path.into_inner();
 
     let emgauwa_uid = EmgauwaUid::try_from(schedule_uid.as_str()).or(Err(HandlerError::BadUid));
@@ -149,7 +184,7 @@
         Ok(uid) => match uid {
             EmgauwaUid::Off => HttpResponse::from(HandlerError::ProtectedSchedule),
             EmgauwaUid::On => HttpResponse::from(HandlerError::ProtectedSchedule),
-            EmgauwaUid::Any(_) => match delete_schedule_by_uid(uid) {
+            EmgauwaUid::Any(_) => match delete_schedule_by_uid(&pool, uid).await {
                 Ok(_) => HttpResponse::Ok().json("schedule got deleted"),
                 Err(err) => HttpResponse::from(err),
             },
diff --git a/src/main.rs b/src/main.rs
index 0aa72eb..8d3f292 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,13 +1,9 @@
-#[macro_use]
-extern crate diesel;
-extern crate diesel_migrations;
 extern crate dotenv;
 
 use actix_web::middleware::TrailingSlash;
 use actix_web::{middleware, web, App, HttpServer};
 use log::{trace, LevelFilter};
 use simple_logger::SimpleLogger;
-use std::fmt::format;
 use std::str::FromStr;
 
 mod db;
@@ -31,9 +27,9 @@ async fn main() -> std::io::Result<()> {
         .init()
         .unwrap_or_else(|_| panic!("Error initializing logger."));
 
-    db::init(&settings.database);
+    let pool = db::init(&settings.database).await;
 
-    HttpServer::new(|| {
+    HttpServer::new(move || {
         App::new()
             .wrap(
                 middleware::DefaultHeaders::new()
@@ -44,6 +40,7 @@ async fn main() -> std::io::Result<()> {
             .wrap(middleware::Logger::default())
             .wrap(middleware::NormalizePath::new(TrailingSlash::Trim))
             .app_data(web::JsonConfig::default().error_handler(handlers::json_error_handler))
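+            // Pool is a cheap, reference-counted handle, so every worker can own a clone.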
+            .app_data(web::Data::new(pool.clone()))
             .service(handlers::v1::schedules::index)
             .service(handlers::v1::schedules::tagged)
             .service(handlers::v1::schedules::show)
diff --git a/src/return_models.rs b/src/return_models.rs
index c000435..49e7212 100644
--- a/src/return_models.rs
+++ b/src/return_models.rs
@@ -1,3 +1,4 @@
+use futures::executor;
 use serde::Serialize;
 
 use crate::db::models::Schedule;
@@ -10,10 +11,15 @@ pub struct ReturnSchedule {
     pub tags: Vec<String>,
 }
 
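+// From<Schedule> cannot be async, so tags are loaded afterwards by blocking on the query.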
+impl ReturnSchedule {
+    pub fn load_tags(&mut self, pool: &sqlx::Pool<sqlx::Sqlite>) {
+        self.tags = executor::block_on(get_schedule_tags(pool, &self.schedule)).unwrap();
+    }
+}
+
 impl From<Schedule> for ReturnSchedule {
     fn from(schedule: Schedule) -> Self {
-        let tags: Vec<String> = get_schedule_tags(&schedule);
-        ReturnSchedule { schedule, tags }
+        ReturnSchedule { schedule, tags: vec![] }
     }
 }
diff --git a/src/types.rs b/src/types.rs
index 9fb50b5..f75f885 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -1,10 +1,8 @@
-use diesel::sql_types::Binary;
 use uuid::Uuid;
 
 pub mod emgauwa_uid;
 
-#[derive(AsExpression, FromSqlRow, PartialEq, Clone)]
-#[diesel(sql_type = Binary)]
+#[derive(PartialEq, Clone)]
 pub enum EmgauwaUid {
     Off,
     On,
diff --git a/src/types/emgauwa_uid.rs b/src/types/emgauwa_uid.rs
index 7183f0d..84f0503 100644
--- a/src/types/emgauwa_uid.rs
+++ b/src/types/emgauwa_uid.rs
@@ -3,12 +3,12 @@ use std::fmt::{Debug, Formatter};
 use std::str::FromStr;
 
 use crate::types::EmgauwaUid;
-use diesel::backend::Backend;
-use diesel::deserialize::FromSql;
-use diesel::serialize::{IsNull, Output, ToSql};
-use diesel::sql_types::Binary;
-use diesel::{deserialize, serialize};
 use serde::{Serialize, Serializer};
+use sqlx::{Decode, Encode, Sqlite, Type};
+use sqlx::database::HasArguments;
+use sqlx::encode::IsNull;
+use sqlx::error::BoxDynError;
+use sqlx::sqlite::{SqliteTypeInfo, SqliteValueRef};
 use uuid::Uuid;
 
 impl EmgauwaUid {
@@ -36,34 +36,26 @@ impl Debug for EmgauwaUid {
     }
 }
 
-impl<DB> ToSql<Binary, DB> for EmgauwaUid
-where
-    DB: Backend,
-    [u8]: ToSql<Binary, DB>,
-{
-    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, DB>) -> serialize::Result {
-        match self {
-            EmgauwaUid::Off => [EmgauwaUid::OFF_U8].to_sql(out)?,
-            EmgauwaUid::On => [EmgauwaUid::ON_U8].to_sql(out)?,
-            EmgauwaUid::Any(value) => value.as_bytes().to_sql(out)?,
-        };
-        Ok(IsNull::No)
+impl Type<Sqlite> for EmgauwaUid {
+    fn type_info() -> SqliteTypeInfo {
+        <&[u8] as Type<Sqlite>>::type_info()
+    }
+
+    fn compatible(ty: &SqliteTypeInfo) -> bool {
+        <&[u8] as Type<Sqlite>>::compatible(ty)
     }
 }
 
-impl<DB> FromSql<Binary, DB> for EmgauwaUid
-where
-    DB: Backend,
-    Vec<u8>: FromSql<Binary, DB>,
-{
-    fn from_sql(bytes: DB::RawValue<'_>) -> deserialize::Result<Self> {
-        let blob: Vec<u8> = FromSql::<Binary, DB>::from_sql(bytes)?;
+impl<'q> Encode<'q, Sqlite> for EmgauwaUid {
+    //noinspection DuplicatedCode
+    fn encode_by_ref(&self, buf: &mut <Sqlite as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
+        <&Vec<u8> as Encode<Sqlite>>::encode(&Vec::from(self), buf)
+    }
+}
 
-        match blob.as_slice() {
-            [EmgauwaUid::OFF_U8] => Ok(EmgauwaUid::Off),
-            [EmgauwaUid::ON_U8] => Ok(EmgauwaUid::On),
-            value_bytes => Ok(EmgauwaUid::Any(Uuid::from_slice(value_bytes).unwrap())),
-        }
+impl<'r> Decode<'r, Sqlite> for EmgauwaUid {
+    fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
+        Ok(EmgauwaUid::from(<&[u8] as Decode<Sqlite>>::decode(value)?))
     }
 }
 
@@ -104,8 +96,8 @@ impl TryFrom<&str> for EmgauwaUid {
 impl From<&EmgauwaUid> for Uuid {
     fn from(emgauwa_uid: &EmgauwaUid) -> Uuid {
         match emgauwa_uid {
-            EmgauwaUid::Off => uuid::Uuid::from_u128(EmgauwaUid::OFF_U128),
-            EmgauwaUid::On => uuid::Uuid::from_u128(EmgauwaUid::ON_U128),
+            EmgauwaUid::Off => Uuid::from_u128(EmgauwaUid::OFF_U128),
+            EmgauwaUid::On => Uuid::from_u128(EmgauwaUid::ON_U128),
             EmgauwaUid::Any(value) => *value,
         }
     }
@@ -120,3 +112,33 @@ impl From<&EmgauwaUid> for String {
         }
     }
 }
+
+impl From<&EmgauwaUid> for Vec<u8> {
+    fn from(emgauwa_uid: &EmgauwaUid) -> Vec<u8> {
+        match emgauwa_uid {
+            EmgauwaUid::Off => vec![EmgauwaUid::OFF_U8],
+            EmgauwaUid::On => vec![EmgauwaUid::ON_U8],
+            EmgauwaUid::Any(value) => value.as_bytes().to_vec(),
+        }
+    }
+}
+
+impl From<&[u8]> for EmgauwaUid {
+    fn from(value: &[u8]) -> Self {
+        match value {
+            [EmgauwaUid::OFF_U8] => EmgauwaUid::Off,
+            [EmgauwaUid::ON_U8] => EmgauwaUid::On,
+            value_bytes => EmgauwaUid::Any(Uuid::from_slice(value_bytes).unwrap()),
+        }
+    }
+}
+
+impl From<Vec<u8>> for EmgauwaUid {
+    fn from(value: Vec<u8>) -> Self {
+        match value.as_slice() {
+            [EmgauwaUid::OFF_U8] => EmgauwaUid::Off,
+            [EmgauwaUid::ON_U8] => EmgauwaUid::On,
+            value_bytes => EmgauwaUid::Any(Uuid::from_slice(value_bytes).unwrap()),
+        }
+    }
+}