Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions sqlx-cli/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ this new file.
sqlx migrate run
```

Recursively resolve migrations in sub-directories of your migrations folder with `--recursive`:

```bash
sqlx migrate run --recursive
```

Compares the migration history of the running database against the `migrations/` folder and runs
any scripts that are still pending.

Expand Down
16 changes: 11 additions & 5 deletions sqlx-cli/src/opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,16 @@ pub struct AddMigrationOpts {
/// Argument for the migration scripts source.
#[derive(Args, Debug)]
pub struct MigrationSourceOpt {
/// Path to folder containing migrations.
/// Path to the folder containing migrations.
///
/// Defaults to `migrations/` if not specified, but a different default may be set by `sqlx.toml`.
#[clap(long)]
pub source: Option<String>,

/// Recursively resolve migrations in subdirectories of the source directory
/// when set. By default, only files in the top-level source directory are considered.
#[clap(long)]
pub recursive: bool,
}

impl MigrationSourceOpt {
Expand All @@ -351,10 +356,11 @@ impl MigrationSourceOpt {
}

pub async fn resolve(&self, config: &Config) -> Result<Migrator, MigrateError> {
Migrator::new(ResolveWith(
self.resolve_path(config),
config.migrate.to_resolve_config(),
))
{
let mut rc = config.migrate.to_resolve_config();
rc.set_recursive(self.recursive);
Migrator::new(ResolveWith(self.resolve_path(config), rc))
}
.await
}
}
Expand Down
1 change: 1 addition & 0 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ hashbrown = "0.15.0"
[dev-dependencies]
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }
tokio = { version = "1", features = ["rt"] }
tempfile = "3.10.1"

[lints]
workspace = true
209 changes: 139 additions & 70 deletions sqlx-core/src/migrate/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,17 @@ pub struct ResolveError {
#[derive(Debug, Default)]
pub struct ResolveConfig {
ignored_chars: BTreeSet<char>,
// When true, traverse subdirectories of the migrations directory and include
// any files that match the migration filename pattern.
recursive: bool,
}

impl ResolveConfig {
/// Return a default, empty configuration.
pub fn new() -> Self {
ResolveConfig {
ignored_chars: BTreeSet::new(),
recursive: false,
}
}

Expand Down Expand Up @@ -143,6 +147,12 @@ impl ResolveConfig {
pub fn ignored_chars(&self) -> impl Iterator<Item = char> + '_ {
self.ignored_chars.iter().copied()
}

/// Enable or disable recursive directory traversal when resolving migrations.
pub fn set_recursive(&mut self, recursive: bool) -> &mut Self {
self.recursive = recursive;
self
}
}

// FIXME: paths should just be part of `Migration` but we can't add a field backwards compatibly
Expand All @@ -162,92 +172,109 @@ pub fn resolve_blocking_with_config(
source: Some(e),
})?;

let s = fs::read_dir(&path).map_err(|e| ResolveError {
message: format!("error reading migration directory {}", path.display()),
source: Some(e),
})?;

let mut migrations = Vec::new();

for res in s {
let entry = res.map_err(|e| ResolveError {
message: format!(
"error reading contents of migration directory {}",
path.display()
),
fn collect_dir(
dir: &Path,
config: &ResolveConfig,
out: &mut Vec<(Migration, PathBuf)>,
) -> Result<(), ResolveError> {
let s = fs::read_dir(dir).map_err(|e| ResolveError {
message: format!("error reading migration directory {}", dir.display()),
source: Some(e),
})?;

let entry_path = entry.path();

let metadata = fs::metadata(&entry_path).map_err(|e| ResolveError {
message: format!(
"error getting metadata of migration path {}",
entry_path.display()
),
source: Some(e),
})?;

if !metadata.is_file() {
// not a file; ignore
continue;
}

let file_name = entry.file_name();
// This is arguably the wrong choice,
// but it really only matters for parsing the version and description.
//
// Using `.to_str()` and returning an error if the filename is not UTF-8
// would be a breaking change.
let file_name = file_name.to_string_lossy();
for res in s {
let entry = res.map_err(|e| ResolveError {
message: format!(
"error reading contents of migration directory {}",
dir.display()
),
source: Some(e),
})?;

let parts = file_name.splitn(2, '_').collect::<Vec<_>>();
let entry_path = entry.path();

if parts.len() != 2 || !parts[1].ends_with(".sql") {
// not of the format: <VERSION>_<DESCRIPTION>.<REVERSIBLE_DIRECTION>.sql; ignore
continue;
}
let metadata = fs::metadata(&entry_path).map_err(|e| ResolveError {
message: format!(
"error getting metadata of migration path {}",
entry_path.display()
),
source: Some(e),
})?;

let version: i64 = parts[0].parse()
.map_err(|_e| ResolveError {
message: format!("error parsing migration filename {file_name:?}; expected integer version prefix (e.g. `01_foo.sql`)"),
if metadata.is_dir() {
if config.recursive {
collect_dir(&entry_path, config, out)?;
}
continue;
}

if !metadata.is_file() {
continue;
}

let file_name = entry.file_name();
// This is arguably the wrong choice,
// but it really only matters for parsing the version and description.
//
// Using `.to_str()` and returning an error if the filename is not UTF-8
// would be a breaking change.
let file_name = file_name.to_string_lossy();

let parts = file_name.splitn(2, '_').collect::<Vec<_>>();

if parts.len() != 2 || !parts[1].ends_with(".sql") {
// not of the format: <VERSION>_<DESCRIPTION>.<REVERSIBLE_DIRECTION>.sql; ignore
continue;
}

let version: i64 = parts[0].parse().map_err(|_e| ResolveError {
message: format!(
"error parsing migration filename {file_name:?}; expected integer version prefix (e.g. `01_foo.sql`)"
),
source: None,
})?;

let migration_type = MigrationType::from_filename(parts[1]);
let migration_type = MigrationType::from_filename(parts[1]);

// remove the `.sql` and replace `_` with ` `
let description = parts[1]
.trim_end_matches(migration_type.suffix())
.replace('_', " ")
.to_owned();
// remove the `.sql` and replace `_` with ` `
let description = parts[1]
.trim_end_matches(migration_type.suffix())
.replace('_', " ")
.to_owned();

let sql = fs::read_to_string(&entry_path).map_err(|e| ResolveError {
message: format!(
"error reading contents of migration {}: {e}",
entry_path.display()
),
source: Some(e),
})?;
let sql = fs::read_to_string(&entry_path).map_err(|e| ResolveError {
message: format!(
"error reading contents of migration {}: {e}",
entry_path.display()
),
source: Some(e),
})?;

// opt-out of migration transaction
let no_tx = sql.starts_with("-- no-transaction");

let checksum = checksum_with(&sql, &config.ignored_chars);

out.push((
Migration::with_checksum(
version,
Cow::Owned(description),
migration_type,
AssertSqlSafe(sql).into_sql_str(),
checksum.into(),
no_tx,
),
entry_path,
));
}

// opt-out of migration transaction
let no_tx = sql.starts_with("-- no-transaction");

let checksum = checksum_with(&sql, &config.ignored_chars);

migrations.push((
Migration::with_checksum(
version,
Cow::Owned(description),
migration_type,
AssertSqlSafe(sql).into_sql_str(),
checksum.into(),
no_tx,
),
entry_path,
));
Ok(())
}

collect_dir(&path, config, &mut migrations)?;

// Ensure that we are sorted by version in ascending order.
migrations.sort_by_key(|(m, _)| m.version);

Expand Down Expand Up @@ -297,3 +324,45 @@ fn checksum_with_ignored_chars() {

assert_eq!(digest_ignored, digest_stripped);
}

#[cfg(test)]
mod recursive_tests {
use super::*;
use std::fs;

#[test]
fn non_recursive_ignores_subdirs() {
let tmp = tempfile::tempdir().expect("tempdir");
let root = tmp.path();
// top-level migration
fs::write(root.join("1_top.sql"), "-- top\nSELECT 1;\n").expect("write top");
// subdir migration
let sub = root.join("nested");
fs::create_dir(&sub).expect("create nested");
fs::write(sub.join("2_sub.sql"), "-- sub\nSELECT 2;\n").expect("write sub");

let cfg = ResolveConfig::new();
let got = resolve_blocking_with_config(root, &cfg).expect("resolve ok");
// should only see the top-level one
assert_eq!(got.len(), 1);
assert_eq!(got[0].0.version, 1);
}

#[test]
fn recursive_finds_subdirs() {
let tmp = tempfile::tempdir().expect("tempdir");
let root = tmp.path();
fs::write(root.join("1_top.sql"), "-- top\nSELECT 1;\n").expect("write top");
let sub = root.join("nested");
fs::create_dir(&sub).expect("create nested");
fs::write(sub.join("2_sub.sql"), "-- sub\nSELECT 2;\n").expect("write sub");

let mut cfg = ResolveConfig::new();
cfg.set_recursive(true);
let got = resolve_blocking_with_config(root, &cfg).expect("resolve ok");
// should see both, sorted by version
assert_eq!(got.len(), 2);
assert_eq!(got[0].0.version, 1);
assert_eq!(got[1].0.version, 2);
}
}
Loading