diff --git a/cove-db/src/query/user.rs b/cove-db/src/query/user.rs index 53d9962..1cdbf45 100644 --- a/cove-db/src/query/user.rs +++ b/cove-db/src/query/user.rs @@ -1,3 +1,4 @@ +use anyhow::anyhow; use async_trait::async_trait; use sqlx::Executor; use sqlx::postgres::PgQueryResult; @@ -5,7 +6,7 @@ use cove_net_common::id::SnowflakeID; use crate::{CoveDB, CoveDBImpl, CoveDbError}; use crate::part::{BindQueryBuilder, SqlPart}; use crate::query::{PartialQueryResult, QueryResult}; -use crate::rows::{DeletableRow, InsertableRow, SelectableRow, TableRow, WhereRow}; +use crate::rows::{DeletableRow, InsertableRow, PartialTableRow, SelectableRow, TableRow, WhereRow}; use crate::rows::user::{PartialUserRow, UserRow}; #[async_trait] @@ -39,7 +40,23 @@ pub trait UserQueries: CoveDBImpl { let out_partial_row: ::PartialRow = self.get_pool().fetch_one(query.sql_query).await?.into(); - Ok(PartialQueryResult::new(out_partial_row, None)) + if let Some(username) = &out_partial_row.username { + if username == user_request_row.username.as_ref().unwrap() { + if let Some(discriminator) = &out_partial_row.discriminator { + if discriminator == user_request_row.discriminator.as_ref().unwrap() { + Ok(PartialQueryResult::new(out_partial_row, None)) + } else { + Err(anyhow!("Discriminator mismatch for user query. Expected: '{}', got: '{}'", user_request_row.discriminator.as_ref().unwrap(), discriminator).into()) + } + } else { + Err(anyhow!("Bad discriminator '{}' for username '{}'", user_request_row.discriminator.as_ref().unwrap(), username).into()) + } + } else { + Err(anyhow!("Username mismatch for user query. Expected: '{}', got: '{}'", user_request_row.username.as_ref().unwrap(), username).into()) + } + } else { + Err(anyhow!("Bad username '{}' for user query", user_request_row.username.as_ref().unwrap()).into()) + } } async fn get_user_by_id(&self, user_id: SnowflakeID) -> Result, CoveDbError> { @@ -57,7 +74,15 @@ pub trait UserQueries: CoveDBImpl { let out_partial_row: ::PartialRow = self.get_pool().fetch_one(query.sql_query).await?.into(); - Ok(PartialQueryResult::new(out_partial_row, None)) + if let Some(id) = &out_partial_row.id { + if id == user_request_row.id.as_ref().unwrap() { + Ok(PartialQueryResult::new(out_partial_row, None)) + } else { + Err(anyhow!("User ID mismatch for user query. Expected '{}', got: '{}'", user_request_row.id.as_ref().unwrap(), id).into()) + } + } else { + Err(anyhow!("Bad user ID {} for user query", out_partial_row.id.unwrap()).into()) + } } async fn delete_user_by_id(&self, user_id: SnowflakeID) -> Result, CoveDbError> { @@ -66,13 +91,7 @@ pub trait UserQueries: CoveDBImpl { ..Default::default() }; - let mut query_builder = BindQueryBuilder::new(); - user_delete_row.select(vec!["*"]).encode(&mut query_builder)?; - let sql_where = user_delete_row.where_all(); - sql_where.encode(&mut query_builder)?; - let query = user_delete_row.bind(sql_where, query_builder.to_query())?; - - let out_val: UserRow = self.get_pool().fetch_one(query.sql_query).await?.try_into()?; + let out_val: UserRow = user_delete_row.get_full(self).await?.into(); let mut query_builder = BindQueryBuilder::new(); @@ -84,7 +103,11 @@ pub trait UserQueries: CoveDBImpl { let res = self.run_query::(query.sql_query).await?; - Ok(QueryResult::new(out_val, Some(res))) + if res.rows_affected() > 0 { + Ok(QueryResult::new(out_val, Some(res))) + } else { + Err(anyhow!("User with ID '{}' not found for deletion", user_delete_row.id.as_ref().unwrap()).into()) + } } } diff --git a/cove-db/src/rows/mod.rs b/cove-db/src/rows/mod.rs index e995d36..a0a35cd 100644 --- a/cove-db/src/rows/mod.rs +++ b/cove-db/src/rows/mod.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use sqlx::Executor; use sqlx::postgres::PgRow; -use crate::{CoveDB, CoveDBImpl}; +use crate::CoveDBImpl; use crate::part::{BindQuery, BindQueryBuilder, SqlPart}; use crate::part::delete::SqlDelete; use crate::part::insert::SqlInsert; @@ -30,7 +30,7 @@ pub trait PartialTableRow { Self::FullTableRow::get_table_name() } - async fn get_full(&self, db: &CoveDB) -> Result where + async fn get_full(&self, db: &(impl CoveDBImpl + Sync + ?Sized)) -> Result where Self: SelectableRow + WhereRow, ::Error: From, Self::FullTableRow: TryFrom, @@ -67,6 +67,9 @@ pub trait SelectableRow: PartialTableRow { } pub trait DeletableRow: PartialTableRow { + /// Equivalent to ` + /// DELETE FROM table_name + /// ` fn delete(&'_ self) -> SqlDelete { SqlDelete::with_table(Self::get_table_name()) } diff --git a/cove-net/common/src/id/message_type.rs b/cove-net/common/src/id/message_type.rs index c2a48ed..1c1c081 100644 --- a/cove-net/common/src/id/message_type.rs +++ b/cove-net/common/src/id/message_type.rs @@ -5,7 +5,7 @@ use crate::id::types::channel::ChannelMessageType; use crate::id::types::text::TextMessageType; #[repr(u8)] -#[derive(Debug, Serialize, Deserialize, Copy, Clone)] +#[derive(Debug, Serialize, Deserialize, Copy, Clone, Eq, PartialEq)] pub enum MessageType { Text(TextMessageType) = 0, Channel(ChannelMessageType) = 1, diff --git a/cove-net/common/src/id/mod.rs b/cove-net/common/src/id/mod.rs index f862f83..73c2cd5 100644 --- a/cove-net/common/src/id/mod.rs +++ b/cove-net/common/src/id/mod.rs @@ -32,7 +32,7 @@ use sqlx::postgres::{PgHasArrayType, PgTypeInfo}; use crate::id::message_type::MessageType; #[repr(C, packed(1))] -#[derive(Debug, DeserializeFromStr, SerializeDisplay, Clone)] +#[derive(Debug, DeserializeFromStr, SerializeDisplay, Clone, Eq, PartialEq)] pub struct SnowflakeID { pub message_type: MessageType, pub location: [u8;4], diff --git a/cove-net/common/src/id/types/channel.rs b/cove-net/common/src/id/types/channel.rs index 008f7c2..ac705cb 100644 --- a/cove-net/common/src/id/types/channel.rs +++ b/cove-net/common/src/id/types/channel.rs @@ -7,7 +7,7 @@ use sqlx::error::BoxDynError; use sqlx::postgres::PgValueRef; #[repr(u8)] -#[derive(Debug, Serialize, Deserialize, Copy, Clone)] +#[derive(Debug, Serialize, Deserialize, Copy, Clone, Eq, PartialEq)] pub enum ChannelMessageType { Text = 0, Voice = 1, diff --git a/cove-net/common/src/id/types/text.rs b/cove-net/common/src/id/types/text.rs index bd2b634..0ce079a 100644 --- a/cove-net/common/src/id/types/text.rs +++ b/cove-net/common/src/id/types/text.rs @@ -7,7 +7,7 @@ use sqlx::error::BoxDynError; use sqlx::postgres::PgValueRef; #[repr(u8)] -#[derive(Debug, Serialize, Deserialize, Copy, Clone)] +#[derive(Debug, Serialize, Deserialize, Copy, Clone, Eq, PartialEq)] pub enum TextMessageType { Text = 0, Reaction = 1,