[cove-db] Improve user error handling

This commit is contained in:
CanadianBaconBoi 2026-02-27 18:46:23 +01:00
parent d59ff11799
commit b2c20e9099
6 changed files with 43 additions and 17 deletions

View File

@ -1,3 +1,4 @@
use anyhow::anyhow;
use async_trait::async_trait; use async_trait::async_trait;
use sqlx::Executor; use sqlx::Executor;
use sqlx::postgres::PgQueryResult; use sqlx::postgres::PgQueryResult;
@ -5,7 +6,7 @@ use cove_net_common::id::SnowflakeID;
use crate::{CoveDB, CoveDBImpl, CoveDbError}; use crate::{CoveDB, CoveDBImpl, CoveDbError};
use crate::part::{BindQueryBuilder, SqlPart}; use crate::part::{BindQueryBuilder, SqlPart};
use crate::query::{PartialQueryResult, QueryResult}; use crate::query::{PartialQueryResult, QueryResult};
use crate::rows::{DeletableRow, InsertableRow, SelectableRow, TableRow, WhereRow}; use crate::rows::{DeletableRow, InsertableRow, PartialTableRow, SelectableRow, TableRow, WhereRow};
use crate::rows::user::{PartialUserRow, UserRow}; use crate::rows::user::{PartialUserRow, UserRow};
#[async_trait] #[async_trait]
@ -39,7 +40,23 @@ pub trait UserQueries: CoveDBImpl {
let out_partial_row: <UserRow as TableRow>::PartialRow = self.get_pool().fetch_one(query.sql_query).await?.into(); let out_partial_row: <UserRow as TableRow>::PartialRow = self.get_pool().fetch_one(query.sql_query).await?.into();
Ok(PartialQueryResult::new(out_partial_row, None)) if let Some(username) = &out_partial_row.username {
if username == user_request_row.username.as_ref().unwrap() {
if let Some(discriminator) = &out_partial_row.discriminator {
if discriminator == user_request_row.discriminator.as_ref().unwrap() {
Ok(PartialQueryResult::new(out_partial_row, None))
} else {
Err(anyhow!("Discriminator mismatch for user query. Expected: '{}', got: '{}'", user_request_row.discriminator.as_ref().unwrap(), discriminator).into())
}
} else {
Err(anyhow!("Bad discriminator '{}' for username '{}'", user_request_row.discriminator.as_ref().unwrap(), username).into())
}
} else {
Err(anyhow!("Username mismatch for user query. Expected: '{}', got: '{}'", user_request_row.username.as_ref().unwrap(), username).into())
}
} else {
Err(anyhow!("Bad username '{}' for user query", user_request_row.username.as_ref().unwrap()).into())
}
} }
async fn get_user_by_id(&self, user_id: SnowflakeID) -> Result<PartialQueryResult<PartialUserRow>, CoveDbError> { async fn get_user_by_id(&self, user_id: SnowflakeID) -> Result<PartialQueryResult<PartialUserRow>, CoveDbError> {
@ -57,7 +74,15 @@ pub trait UserQueries: CoveDBImpl {
let out_partial_row: <UserRow as TableRow>::PartialRow = self.get_pool().fetch_one(query.sql_query).await?.into(); let out_partial_row: <UserRow as TableRow>::PartialRow = self.get_pool().fetch_one(query.sql_query).await?.into();
Ok(PartialQueryResult::new(out_partial_row, None)) if let Some(id) = &out_partial_row.id {
if id == user_request_row.id.as_ref().unwrap() {
Ok(PartialQueryResult::new(out_partial_row, None))
} else {
Err(anyhow!("User ID mismatch for user query. Expected '{}', got: '{}'", user_request_row.id.as_ref().unwrap(), id).into())
}
} else {
Err(anyhow!("Bad user ID {} for user query", out_partial_row.id.unwrap()).into())
}
} }
async fn delete_user_by_id(&self, user_id: SnowflakeID) -> Result<QueryResult<UserRow>, CoveDbError> { async fn delete_user_by_id(&self, user_id: SnowflakeID) -> Result<QueryResult<UserRow>, CoveDbError> {
@ -66,13 +91,7 @@ pub trait UserQueries: CoveDBImpl {
..Default::default() ..Default::default()
}; };
let mut query_builder = BindQueryBuilder::new(); let out_val: UserRow = user_delete_row.get_full(self).await?.into();
user_delete_row.select(vec!["*"]).encode(&mut query_builder)?;
let sql_where = user_delete_row.where_all();
sql_where.encode(&mut query_builder)?;
let query = user_delete_row.bind(sql_where, query_builder.to_query())?;
let out_val: UserRow = self.get_pool().fetch_one(query.sql_query).await?.try_into()?;
let mut query_builder = BindQueryBuilder::new(); let mut query_builder = BindQueryBuilder::new();
@ -84,7 +103,11 @@ pub trait UserQueries: CoveDBImpl {
let res = self.run_query::<PgQueryResult>(query.sql_query).await?; let res = self.run_query::<PgQueryResult>(query.sql_query).await?;
Ok(QueryResult::new(out_val, Some(res))) if res.rows_affected() > 0 {
Ok(QueryResult::new(out_val, Some(res)))
} else {
Err(anyhow!("User with ID '{}' not found for deletion", user_delete_row.id.as_ref().unwrap()).into())
}
} }
} }

View File

@ -1,7 +1,7 @@
use async_trait::async_trait; use async_trait::async_trait;
use sqlx::Executor; use sqlx::Executor;
use sqlx::postgres::PgRow; use sqlx::postgres::PgRow;
use crate::{CoveDB, CoveDBImpl}; use crate::CoveDBImpl;
use crate::part::{BindQuery, BindQueryBuilder, SqlPart}; use crate::part::{BindQuery, BindQueryBuilder, SqlPart};
use crate::part::delete::SqlDelete; use crate::part::delete::SqlDelete;
use crate::part::insert::SqlInsert; use crate::part::insert::SqlInsert;
@ -30,7 +30,7 @@ pub trait PartialTableRow {
Self::FullTableRow::get_table_name() Self::FullTableRow::get_table_name()
} }
async fn get_full(&self, db: &CoveDB) -> Result<Self::FullTableRow, Self::Error> where async fn get_full(&self, db: &(impl CoveDBImpl + Sync + ?Sized)) -> Result<Self::FullTableRow, Self::Error> where
Self: SelectableRow + WhereRow, Self: SelectableRow + WhereRow,
<Self as PartialTableRow>::Error: From<anyhow::Error>, <Self as PartialTableRow>::Error: From<anyhow::Error>,
Self::FullTableRow: TryFrom<PgRow>, Self::FullTableRow: TryFrom<PgRow>,
@ -67,6 +67,9 @@ pub trait SelectableRow: PartialTableRow {
} }
pub trait DeletableRow: PartialTableRow { pub trait DeletableRow: PartialTableRow {
/// Equivalent to `
/// DELETE FROM table_name
/// `
fn delete(&'_ self) -> SqlDelete { fn delete(&'_ self) -> SqlDelete {
SqlDelete::with_table(Self::get_table_name()) SqlDelete::with_table(Self::get_table_name())
} }

View File

@ -5,7 +5,7 @@ use crate::id::types::channel::ChannelMessageType;
use crate::id::types::text::TextMessageType; use crate::id::types::text::TextMessageType;
#[repr(u8)] #[repr(u8)]
#[derive(Debug, Serialize, Deserialize, Copy, Clone)] #[derive(Debug, Serialize, Deserialize, Copy, Clone, Eq, PartialEq)]
pub enum MessageType { pub enum MessageType {
Text(TextMessageType) = 0, Text(TextMessageType) = 0,
Channel(ChannelMessageType) = 1, Channel(ChannelMessageType) = 1,

View File

@ -32,7 +32,7 @@ use sqlx::postgres::{PgHasArrayType, PgTypeInfo};
use crate::id::message_type::MessageType; use crate::id::message_type::MessageType;
#[repr(C, packed(1))] #[repr(C, packed(1))]
#[derive(Debug, DeserializeFromStr, SerializeDisplay, Clone)] #[derive(Debug, DeserializeFromStr, SerializeDisplay, Clone, Eq, PartialEq)]
pub struct SnowflakeID { pub struct SnowflakeID {
pub message_type: MessageType, pub message_type: MessageType,
pub location: [u8;4], pub location: [u8;4],

View File

@ -7,7 +7,7 @@ use sqlx::error::BoxDynError;
use sqlx::postgres::PgValueRef; use sqlx::postgres::PgValueRef;
#[repr(u8)] #[repr(u8)]
#[derive(Debug, Serialize, Deserialize, Copy, Clone)] #[derive(Debug, Serialize, Deserialize, Copy, Clone, Eq, PartialEq)]
pub enum ChannelMessageType { pub enum ChannelMessageType {
Text = 0, Text = 0,
Voice = 1, Voice = 1,

View File

@ -7,7 +7,7 @@ use sqlx::error::BoxDynError;
use sqlx::postgres::PgValueRef; use sqlx::postgres::PgValueRef;
#[repr(u8)] #[repr(u8)]
#[derive(Debug, Serialize, Deserialize, Copy, Clone)] #[derive(Debug, Serialize, Deserialize, Copy, Clone, Eq, PartialEq)]
pub enum TextMessageType { pub enum TextMessageType {
Text = 0, Text = 0,
Reaction = 1, Reaction = 1,