diff --git a/Cargo.lock b/Cargo.lock
index 61d2e7d7b6..a0da6aa04d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3578,6 +3578,7 @@ dependencies = [
  "smallvec",
  "smol",
  "sqlx",
+ "tempfile",
  "thiserror 2.0.12",
  "time",
  "tokio",
diff --git a/sqlx-cli/README.md b/sqlx-cli/README.md
index b20461b8fd..177680d435 100644
--- a/sqlx-cli/README.md
+++ b/sqlx-cli/README.md
@@ -60,6 +60,12 @@ this new file.
 sqlx migrate run
 ```
 
+Recursively resolve migrations in subdirectories of your migrations folder with `--recursive`:
+
+```bash
+sqlx migrate run --recursive
+```
+
 Compares the migration history of the running database against the `migrations/` folder
 and runs any scripts that are still pending.
diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs
index cb09bc2ff5..aa80746ab7 100644
--- a/sqlx-cli/src/opt.rs
+++ b/sqlx-cli/src/opt.rs
@@ -334,11 +334,16 @@ pub struct AddMigrationOpts {
 /// Argument for the migration scripts source.
 #[derive(Args, Debug)]
 pub struct MigrationSourceOpt {
-    /// Path to folder containing migrations.
+    /// Path to the folder containing migrations.
     ///
     /// Defaults to `migrations/` if not specified, but a different default may be set by `sqlx.toml`.
     #[clap(long)]
     pub source: Option<String>,
+
+    /// Recursively resolve migrations in subdirectories of the source directory when set.
+    /// By default, only files in the top-level source directory are considered.
+    #[clap(long)]
+    pub recursive: bool,
 }
 
 impl MigrationSourceOpt {
@@ -351,10 +356,11 @@ impl MigrationSourceOpt {
     }
 
     pub async fn resolve(&self, config: &Config) -> Result<Migrator, MigrateError> {
-        Migrator::new(ResolveWith(
-            self.resolve_path(config),
-            config.migrate.to_resolve_config(),
-        ))
+        {
+            let mut rc = config.migrate.to_resolve_config();
+            rc.set_recursive(self.recursive);
+            Migrator::new(ResolveWith(self.resolve_path(config), rc))
+        }
         .await
     }
 }
diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml
index 58c5b67e05..b8ce6626c3 100644
--- a/sqlx-core/Cargo.toml
+++ b/sqlx-core/Cargo.toml
@@ -107,6 +107,7 @@ hashbrown = "0.15.0"
 [dev-dependencies]
 sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }
 tokio = { version = "1", features = ["rt"] }
+tempfile = "3.10.1"
 
 [lints]
 workspace = true
diff --git a/sqlx-core/src/migrate/source.rs b/sqlx-core/src/migrate/source.rs
index 4648e53f1e..4f4d8cc4e6 100644
--- a/sqlx-core/src/migrate/source.rs
+++ b/sqlx-core/src/migrate/source.rs
@@ -87,6 +87,9 @@ pub struct ResolveError {
 #[derive(Debug, Default)]
 pub struct ResolveConfig {
     ignored_chars: BTreeSet<char>,
+    // When true, traverse subdirectories of the migrations directory and include
+    // any files that match the migration filename pattern.
+    recursive: bool,
 }
 
 impl ResolveConfig {
@@ -94,6 +97,7 @@ pub fn new() -> Self {
         ResolveConfig {
             ignored_chars: BTreeSet::new(),
+            recursive: false,
         }
     }
@@ -143,6 +147,12 @@
     pub fn ignored_chars(&self) -> impl Iterator<Item = char> + '_ {
         self.ignored_chars.iter().copied()
     }
+
+    /// Enable or disable recursive directory traversal when resolving migrations.
+    pub fn set_recursive(&mut self, recursive: bool) -> &mut Self {
+        self.recursive = recursive;
+        self
+    }
 }
 
 // FIXME: paths should just be part of `Migration` but we can't add a field backwards compatibly
@@ -162,92 +172,109 @@ pub fn resolve_blocking_with_config(
         source: Some(e),
     })?;
 
-    let s = fs::read_dir(&path).map_err(|e| ResolveError {
-        message: format!("error reading migration directory {}", path.display()),
-        source: Some(e),
-    })?;
-
     let mut migrations = Vec::new();
 
-    for res in s {
-        let entry = res.map_err(|e| ResolveError {
-            message: format!(
-                "error reading contents of migration directory {}",
-                path.display()
-            ),
+    fn collect_dir(
+        dir: &Path,
+        config: &ResolveConfig,
+        out: &mut Vec<(Migration, PathBuf)>,
+    ) -> Result<(), ResolveError> {
+        let s = fs::read_dir(dir).map_err(|e| ResolveError {
+            message: format!("error reading migration directory {}", dir.display()),
             source: Some(e),
         })?;
 
-        let entry_path = entry.path();
-
-        let metadata = fs::metadata(&entry_path).map_err(|e| ResolveError {
-            message: format!(
-                "error getting metadata of migration path {}",
-                entry_path.display()
-            ),
-            source: Some(e),
-        })?;
-
-        if !metadata.is_file() {
-            // not a file; ignore
-            continue;
-        }
-
-        let file_name = entry.file_name();
-        // This is arguably the wrong choice,
-        // but it really only matters for parsing the version and description.
-        //
-        // Using `.to_str()` and returning an error if the filename is not UTF-8
-        // would be a breaking change.
-        let file_name = file_name.to_string_lossy();
+        for res in s {
+            let entry = res.map_err(|e| ResolveError {
+                message: format!(
+                    "error reading contents of migration directory {}",
+                    dir.display()
+                ),
+                source: Some(e),
+            })?;
 
-        let parts = file_name.splitn(2, '_').collect::<Vec<_>>();
+            let entry_path = entry.path();
 
-        if parts.len() != 2 || !parts[1].ends_with(".sql") {
-            // not of the format: <VERSION>_<DESCRIPTION>.sql; ignore
-            continue;
-        }
+            let metadata = fs::metadata(&entry_path).map_err(|e| ResolveError {
+                message: format!(
+                    "error getting metadata of migration path {}",
+                    entry_path.display()
+                ),
+                source: Some(e),
+            })?;
 
-        let version: i64 = parts[0].parse()
-            .map_err(|_e| ResolveError {
-                message: format!("error parsing migration filename {file_name:?}; expected integer version prefix (e.g. `01_foo.sql`)"),
+            if metadata.is_dir() {
+                if config.recursive {
+                    collect_dir(&entry_path, config, out)?;
+                }
+                continue;
+            }
+
+            if !metadata.is_file() {
+                continue;
+            }
+
+            let file_name = entry.file_name();
+            // This is arguably the wrong choice,
+            // but it really only matters for parsing the version and description.
+            //
+            // Using `.to_str()` and returning an error if the filename is not UTF-8
+            // would be a breaking change.
+            let file_name = file_name.to_string_lossy();
+
+            let parts = file_name.splitn(2, '_').collect::<Vec<_>>();
+
+            if parts.len() != 2 || !parts[1].ends_with(".sql") {
+                // not of the format: <VERSION>_<DESCRIPTION>.sql; ignore
+                continue;
+            }
+
+            let version: i64 = parts[0].parse().map_err(|_e| ResolveError {
+                message: format!(
+                    "error parsing migration filename {file_name:?}; expected integer version prefix (e.g. `01_foo.sql`)"
+                ),
                 source: None,
             })?;
 
-        let migration_type = MigrationType::from_filename(parts[1]);
+            let migration_type = MigrationType::from_filename(parts[1]);
 
-        // remove the `.sql` and replace `_` with ` `
-        let description = parts[1]
-            .trim_end_matches(migration_type.suffix())
-            .replace('_', " ")
-            .to_owned();
+            // remove the `.sql` and replace `_` with ` `
+            let description = parts[1]
+                .trim_end_matches(migration_type.suffix())
+                .replace('_', " ")
+                .to_owned();
 
-        let sql = fs::read_to_string(&entry_path).map_err(|e| ResolveError {
-            message: format!(
-                "error reading contents of migration {}: {e}",
-                entry_path.display()
-            ),
-            source: Some(e),
-        })?;
+            let sql = fs::read_to_string(&entry_path).map_err(|e| ResolveError {
+                message: format!(
+                    "error reading contents of migration {}: {e}",
+                    entry_path.display()
+                ),
+                source: Some(e),
+            })?;
+
+            // opt-out of migration transaction
+            let no_tx = sql.starts_with("-- no-transaction");
+
+            let checksum = checksum_with(&sql, &config.ignored_chars);
+
+            out.push((
+                Migration::with_checksum(
+                    version,
+                    Cow::Owned(description),
+                    migration_type,
+                    AssertSqlSafe(sql).into_sql_str(),
+                    checksum.into(),
+                    no_tx,
+                ),
+                entry_path,
+            ));
+        }
 
-        // opt-out of migration transaction
-        let no_tx = sql.starts_with("-- no-transaction");
-
-        let checksum = checksum_with(&sql, &config.ignored_chars);
-
-        migrations.push((
-            Migration::with_checksum(
-                version,
-                Cow::Owned(description),
-                migration_type,
-                AssertSqlSafe(sql).into_sql_str(),
-                checksum.into(),
-                no_tx,
-            ),
-            entry_path,
-        ));
+        Ok(())
     }
 
+    collect_dir(&path, config, &mut migrations)?;
+
     // Ensure that we are sorted by version in ascending order.
     migrations.sort_by_key(|(m, _)| m.version);
@@ -297,3 +324,45 @@ fn checksum_with_ignored_chars() {
     assert_eq!(digest_ignored, digest_stripped);
 }
+
+#[cfg(test)]
+mod recursive_tests {
+    use super::*;
+    use std::fs;
+
+    #[test]
+    fn non_recursive_ignores_subdirs() {
+        let tmp = tempfile::tempdir().expect("tempdir");
+        let root = tmp.path();
+        // top-level migration
+        fs::write(root.join("1_top.sql"), "-- top\nSELECT 1;\n").expect("write top");
+        // subdir migration
+        let sub = root.join("nested");
+        fs::create_dir(&sub).expect("create nested");
+        fs::write(sub.join("2_sub.sql"), "-- sub\nSELECT 2;\n").expect("write sub");
+
+        let cfg = ResolveConfig::new();
+        let got = resolve_blocking_with_config(root, &cfg).expect("resolve ok");
+        // should only see the top-level one
+        assert_eq!(got.len(), 1);
+        assert_eq!(got[0].0.version, 1);
+    }
+
+    #[test]
+    fn recursive_finds_subdirs() {
+        let tmp = tempfile::tempdir().expect("tempdir");
+        let root = tmp.path();
+        fs::write(root.join("1_top.sql"), "-- top\nSELECT 1;\n").expect("write top");
+        let sub = root.join("nested");
+        fs::create_dir(&sub).expect("create nested");
+        fs::write(sub.join("2_sub.sql"), "-- sub\nSELECT 2;\n").expect("write sub");
+
+        let mut cfg = ResolveConfig::new();
+        cfg.set_recursive(true);
+        let got = resolve_blocking_with_config(root, &cfg).expect("resolve ok");
+        // should see both, sorted by version
+        assert_eq!(got.len(), 2);
+        assert_eq!(got[0].0.version, 1);
+        assert_eq!(got[1].0.version, 2);
+    }
+}
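Reviewer note: the sketch below shows the intended call pattern for the new knob, written as if it sat next to `recursive_tests` in `source.rs` so that `use super::*;` resolves `ResolveConfig` and `resolve_blocking_with_config`. It is not part of the patch; the year-based sub-folder layout is only an illustrative assumption. CLI users get the same behaviour via `sqlx migrate run --recursive`.

```rust
// Illustrative sketch, not part of the patch: another `#[cfg(test)]` item
// placed alongside `recursive_tests` so `use super::*;` brings in
// `ResolveConfig` and `resolve_blocking_with_config`.
#[cfg(test)]
mod recursive_usage_sketch {
    use super::*;
    use std::fs;

    #[test]
    fn year_subdirectories_are_merged_into_one_ordered_plan() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let root = tmp.path();

        // Hypothetical layout: migrations grouped by year in sub-folders.
        fs::create_dir(root.join("2023")).expect("create 2023");
        fs::create_dir(root.join("2024")).expect("create 2024");
        fs::write(root.join("2023/1_users.sql"), "CREATE TABLE users (id BIGINT);")
            .expect("write 1_users");
        fs::write(root.join("2024/2_posts.sql"), "CREATE TABLE posts (id BIGINT);")
            .expect("write 2_posts");

        // `set_recursive(true)` is the new opt-in; without it both
        // sub-folders would be skipped and the plan would be empty.
        let mut cfg = ResolveConfig::new();
        cfg.set_recursive(true);
        let plan = resolve_blocking_with_config(root, &cfg).expect("resolve ok");

        // Files from different subdirectories are merged and sorted by their
        // integer version prefix, exactly like a flat migrations folder.
        assert_eq!(plan.len(), 2);
        assert_eq!(plan[0].0.version, 1);
        assert_eq!(plan[1].0.version, 2);
    }
}
```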