Skip to content

Commit

Permalink
Merge branch 'master' into path-params-deserializer
Browse files Browse the repository at this point in the history
  • Loading branch information
m4tx authored Jan 28, 2025
2 parents 5b73ca8 + 68f79b7 commit 621ed98
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 12 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ thiserror = "2"
time = { version = "0.3.35", default-features = false }
tokio = { version = "1.41", default-features = false }
tower = "0.5.2"
# TODO switch back to the published version when https://github.com/leotaku/tower-livereload/pull/24 is released
tower-livereload = { git = "https://github.com/leotaku/tower-livereload.git", rev = "106cc96f91b11a1eca6d3dfc86be4e766a90a415" }
tower-livereload = "0.9.6"
tower-sessions = { version = "0.13", default-features = false }
tracing = { version = "0.1", default-features = false }
tracing-subscriber = "0.3"
Expand Down
2 changes: 1 addition & 1 deletion cot/src/auth/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl DatabaseUser {
})?;

let mut user = Self::new(Auto::auto(), username, &password.into());
user.save(db).await.map_err(AuthError::backend_error)?;
user.insert(db).await.map_err(AuthError::backend_error)?;

Ok(user)
}
Expand Down
170 changes: 163 additions & 7 deletions cot/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ use derive_more::{Debug, Deref, Display};
use mockall::automock;
use query::Query;
pub use relations::{ForeignKey, ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy};
use sea_query::{Iden, IntoColumnRef, ReturningClause, SchemaStatementBuilder, SimpleExpr};
use sea_query::{
Iden, IntoColumnRef, OnConflict, ReturningClause, SchemaStatementBuilder, SimpleExpr,
};
use sea_query_binder::{SqlxBinder, SqlxValues};
use sqlx::{Type, TypeInfo};
use thiserror::Error;
use tracing::debug;
use tracing::{span, trace, Instrument, Level};

#[cfg(feature = "mysql")]
use crate::db::impl_mysql::{DatabaseMySql, MySqlRow, MySqlValueRef};
Expand Down Expand Up @@ -59,6 +61,9 @@ pub enum DatabaseError {
/// Error when applying migrations.
#[error("Error when applying migrations: {0}")]
MigrationError(#[from] migrations::MigrationEngineError),
/// An object could not be found in the database.
#[error("Record with primary key `{primary_key}` not found in the database")]
RecordNotFound { primary_key: DbValue },
/// Foreign Key could not be retrieved from the database because the record
/// was not found.
#[error("Error retrieving a Foreign Key from the database: record not found")]
Expand Down Expand Up @@ -150,16 +155,51 @@ pub trait Model: Sized + Send + 'static {
pk: Self::PrimaryKey,
) -> Result<Option<Self>>;

/// Saves the model to the database.
/// Inserts the model instance to the database, or updates an instance
/// with the same primary key if it already exists.
///
/// To force insert or force update, use the [`Self::insert`] or
/// [`Self::update`] methods instead.
///
/// # Errors
///
/// This method can return an error if the model could not be saved to the
/// database.
/// This method can return an error if the model instance could not be
/// inserted into the database, for instance because the migrations
/// haven't been applied, or there was a problem with the database
/// connection.
async fn save<DB: DatabaseBackend>(&mut self, db: &DB) -> Result<()> {
db.insert_or_update(self).await?;
Ok(())
}

/// Insert the model instance to the database.
///
/// # Errors
///
/// This method can return an error if the model instance could not be
/// inserted into the database, for instance because the migrations
/// haven't been applied, or there was a problem with the database
/// connection.
async fn insert<DB: DatabaseBackend>(&mut self, db: &DB) -> Result<()> {
db.insert(self).await?;
Ok(())
}

/// Update the model instance in the database.
///
/// # Errors
///
/// This method can return an error if the model instance could not be
/// inserted into the database, for instance because the migrations
/// haven't been applied, or there was a problem with the database
/// connection.
///
/// This method can return an error if the model with the given primary key
/// could not be found in the database.
async fn update<DB: DatabaseBackend>(&mut self, db: &DB) -> Result<()> {
db.update(self).await?;
Ok(())
}
}

/// An identifier structure that holds table or column name as a string.
Expand Down Expand Up @@ -539,6 +579,34 @@ impl Database {
/// the database, for instance because the migrations haven't been
/// applied, or there was a problem with the database connection.
pub async fn insert<T: Model>(&self, data: &mut T) -> Result<()> {
let span = span!(Level::TRACE, "insert", table = %T::TABLE_NAME);

Self::insert_or_update_impl(self, data, false)
.instrument(span)
.await
}

/// Inserts a new row into the database, or updates it if a row with the
/// same primary key already exists.
///
/// # Errors
///
/// This method can return an error if the row could not be inserted into
/// the database, for instance because the migrations haven't been
/// applied, or there was a problem with the database connection.
pub async fn insert_or_update<T: Model>(&self, data: &mut T) -> Result<()> {
let span = span!(
Level::TRACE,
"insert_or_update",
table = %T::TABLE_NAME
);

Self::insert_or_update_impl(self, data, true)
.instrument(span)
.await
}

async fn insert_or_update_impl<T: Model>(&self, data: &mut T, update: bool) -> Result<()> {
let column_identifiers = T::COLUMNS
.iter()
.map(|column| Identifier::from(column.name.as_str()));
Expand Down Expand Up @@ -571,7 +639,7 @@ impl Database {

let mut insert_statement = sea_query::Query::insert()
.into_table(T::TABLE_NAME)
.columns(value_identifiers)
.columns(value_identifiers.iter().cloned())
.values(
filtered_values
.into_iter()
Expand All @@ -580,6 +648,13 @@ impl Database {
)?
.or_default_values()
.to_owned();
if update && !value_identifiers.is_empty() {
insert_statement.on_conflict(
OnConflict::column(T::PRIMARY_KEY_NAME)
.update_columns(value_identifiers)
.to_owned(),
);
}

if auto_col_ids.is_empty() {
self.execute_statement(&insert_statement).await?;
Expand Down Expand Up @@ -607,7 +682,76 @@ impl Database {
data.update_from_db(row, &auto_col_ids)?;
}

debug!("Inserted row");
if update {
trace!(primary_key = ?data.primary_key().to_db_field_value(), "Inserted or updated row");
} else {
trace!(primary_key = ?data.primary_key().to_db_field_value(), "Inserted row");
}

Ok(())
}

/// Updates an existing row in a database.
///
/// # Errors
///
/// This method can return an error if the row could not be updated in
/// the database, for instance because the migrations haven't been
/// applied, or there was a problem with the database connection.
///
/// This method can return an error if the row with the given primary key
/// could not be found in the database.
pub async fn update<T: Model>(&self, data: &mut T) -> Result<()> {
let span = span!(
Level::TRACE,
"update",
table = %T::TABLE_NAME,
primary_key = ?data.primary_key().to_db_field_value(),
);

Self::update_impl(self, data).instrument(span).await
}

async fn update_impl<T: Model>(&self, data: &mut T) -> Result<()> {
let column_identifiers = T::COLUMNS
.iter()
.map(|column| Identifier::from(column.name.as_str()));
let value_indices: Vec<_> = T::COLUMNS
.iter()
.enumerate()
.map(|(i, _column)| i)
.collect();
let values = data
.get_values(&value_indices)
.into_iter()
.map(ToDbFieldValue::to_db_field_value);

let mut statement_values = Vec::new();
std::iter::zip(column_identifiers, values).for_each(|(identifier, value)| match value {
DbFieldValue::Auto => {
panic!("Auto values are not supported in update queries");
}
DbFieldValue::Value(value) => {
statement_values.push((identifier, SimpleExpr::Value(value)));
}
});

let primary_key = data
.primary_key()
.to_db_field_value()
.expect_value("primary key cannot be auto when updating");
let update_statement = sea_query::Query::update()
.table(T::TABLE_NAME)
.values(statement_values)
.and_where(sea_query::Expr::col(T::PRIMARY_KEY_NAME).eq(primary_key.clone()))
.to_owned();

let result = self.execute_statement(&update_statement).await?;
if result.rows_affected == RowsNum(0) {
return Err(DatabaseError::RecordNotFound { primary_key });
}

trace!("Updated row");

Ok(())
}
Expand Down Expand Up @@ -841,8 +985,12 @@ impl ColumnTypeMapper for Database {
#[cfg_attr(test, automock)]
#[async_trait]
pub trait DatabaseBackend: Send + Sync {
async fn insert_or_update<T: Model>(&self, data: &mut T) -> Result<()>;

async fn insert<T: Model>(&self, data: &mut T) -> Result<()>;

async fn update<T: Model>(&self, data: &mut T) -> Result<()>;

async fn query<T: Model>(&self, query: &Query<T>) -> Result<Vec<T>>;

async fn get<T: Model>(&self, query: &Query<T>) -> Result<Option<T>>;
Expand All @@ -854,10 +1002,18 @@ pub trait DatabaseBackend: Send + Sync {

#[async_trait]
impl DatabaseBackend for Database {
async fn insert_or_update<T: Model>(&self, data: &mut T) -> Result<()> {
Database::insert_or_update(self, data).await
}

async fn insert<T: Model>(&self, data: &mut T) -> Result<()> {
Database::insert(self, data).await
}

async fn update<T: Model>(&self, data: &mut T) -> Result<()> {
Database::update(self, data).await
}

async fn query<T: Model>(&self, query: &Query<T>) -> Result<Vec<T>> {
Database::query(self, query).await
}
Expand Down
64 changes: 64 additions & 0 deletions cot/tests/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,26 @@ async fn model_crud(test_db: &mut TestDatabase) {

assert_eq!(TestModel::objects().all(&**test_db).await.unwrap(), vec![]);

// Create
let mut model = TestModel {
id: Auto::fixed(1),
name: "test".to_owned(),
};
model.save(&**test_db).await.unwrap();

// Read
let objects = TestModel::objects().all(&**test_db).await.unwrap();
assert_eq!(objects.len(), 1);
assert_eq!(objects[0].name, "test");

// Update (& read again)
model.name = "test2".to_owned();
model.save(&**test_db).await.unwrap();
let objects = TestModel::objects().all(&**test_db).await.unwrap();
assert_eq!(objects.len(), 1);
assert_eq!(objects[0].name, "test2");

// Delete
TestModel::objects()
.filter(<TestModel as Model>::Fields::id.eq(1))
.delete(&**test_db)
Expand All @@ -36,6 +47,59 @@ async fn model_crud(test_db: &mut TestDatabase) {
assert_eq!(TestModel::objects().all(&**test_db).await.unwrap(), vec![]);
}

#[cot_macros::dbtest]
async fn model_insert(test_db: &mut TestDatabase) {
migrate_test_model(&*test_db).await;

// Insert
let mut model = TestModel {
id: Auto::fixed(1),
name: "test".to_owned(),
};
let result = model.insert(&**test_db).await;
assert!(result.is_ok());

// Can't insert the same model instance again
let result = model.insert(&**test_db).await;
assert!(result.is_err());

// Read the model from the database
let objects = TestModel::objects().all(&**test_db).await.unwrap();
assert_eq!(objects.len(), 1);
assert_eq!(objects[0].name, "test");
}

#[cot_macros::dbtest]
async fn model_update(test_db: &mut TestDatabase) {
migrate_test_model(&*test_db).await;

// Insert
let mut model = TestModel {
id: Auto::fixed(1),
name: "test".to_owned(),
};
let result = model.insert(&**test_db).await;
assert!(result.is_ok());

// Update
model.name = "test2".to_owned();
let result = model.update(&**test_db).await;
assert!(result.is_ok());

// Can't update non-existing object
let mut model = TestModel {
id: Auto::fixed(2),
name: "test3".to_owned(),
};
let result = model.update(&**test_db).await;
assert!(result.is_err());

// Read the model from the database
let objects = TestModel::objects().all(&**test_db).await.unwrap();
assert_eq!(objects.len(), 1);
assert_eq!(objects[0].name, "test2");
}

#[cot_macros::dbtest]
async fn model_macro_filtering(test_db: &mut TestDatabase) {
migrate_test_model(&*test_db).await;
Expand Down

0 comments on commit 621ed98

Please sign in to comment.