Skip to content

Add support for TLS in postgres/tokio-postgres using native-tls #353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jul 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions refinery/tests/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use barrel::backend::Pg as Sql;
mod postgres {
use assert_cmd::prelude::*;
use predicates::str::contains;
use refinery::config::ConfigDbType;
use refinery::{
config::Config, embed_migrations, error::Kind, Migrate, Migration, Runner, Target,
};
Expand Down Expand Up @@ -728,4 +729,38 @@ mod postgres {
.stdout(contains("applying migration: V3__add_brand_to_cars_table"));
})
}

#[test]
fn migrates_with_tls_enabled() {
run_test(|| {
let mut config = Config::new(ConfigDbType::Postgres)
.set_db_name("postgres")
.set_db_user("postgres")
.set_db_host("localhost")
.set_db_port("5432")
.set_use_tls(true);

let migrations = get_migrations();
let runner = Runner::new(&migrations)
.set_grouped(false)
.set_abort_divergent(true)
.set_abort_missing(true);

let report = runner.run(&mut config).unwrap();

let applied_migrations = report.applied_migrations();
assert_eq!(5, applied_migrations.len());

let last_migration = runner
.get_last_applied_migration(&mut config)
.unwrap()
.unwrap();

assert_eq!(5, last_migration.version());
assert_eq!(migrations[4].name(), last_migration.name());
assert_eq!(migrations[4].checksum(), last_migration.checksum());

assert!(config.use_tls());
});
}
}
34 changes: 34 additions & 0 deletions refinery/tests/tokio_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod tokio_postgres {
use refinery_core::tokio_postgres;
use refinery_core::tokio_postgres::NoTls;
use std::panic::AssertUnwindSafe;
use std::str::FromStr;
use time::OffsetDateTime;

const DEFAULT_TABLE_NAME: &str = "refinery_schema_history";
Expand Down Expand Up @@ -953,4 +954,37 @@ mod tokio_postgres {
})
.await;
}

#[tokio::test]
async fn migrates_with_tls_enabled() {
run_test(async {
let mut config =
Config::from_str("postgres://postgres@localhost:5432/postgres?sslmode=require")
.unwrap();

let migrations = get_migrations();
let runner = Runner::new(&migrations)
.set_grouped(false)
.set_abort_divergent(true)
.set_abort_missing(true);

let report = runner.run_async(&mut config).await.unwrap();

let applied_migrations = report.applied_migrations();
assert_eq!(5, applied_migrations.len());

let last_migration = runner
.get_last_applied_migration_async(&mut config)
.await
.unwrap()
.unwrap();

assert_eq!(5, last_migration.version());
assert_eq!(migrations[4].name(), last_migration.name());
assert_eq!(migrations[4].checksum(), last_migration.checksum());

assert!(config.use_tls());
})
.await;
}
}
9 changes: 6 additions & 3 deletions refinery_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ edition = "2021"

[features]
default = []
mysql_async = ["dep:mysql_async"]
postgres = ["dep:postgres", "dep:postgres-native-tls", "dep:native-tls"]
rusqlite-bundled = ["rusqlite", "rusqlite/bundled"]
serde = ["dep:serde"]
tiberius = ["dep:tiberius", "futures", "tokio", "tokio/net"]
tiberius-config = ["tiberius", "tokio", "tokio-util", "serde"]
tokio-postgres = ["dep:tokio-postgres", "tokio", "tokio/rt"]
mysql_async = ["dep:mysql_async"]
serde = ["dep:serde"]
tokio-postgres = ["dep:postgres-native-tls", "dep:native-tls", "dep:tokio-postgres", "tokio", "tokio/rt"]
toml = ["serde", "dep:toml"]

[dependencies]
Expand All @@ -31,6 +32,8 @@ walkdir = "2.3.1"
# allow multiple versions of the same dependency if API is similar
rusqlite = { version = ">= 0.23, <= 0.37", optional = true }
postgres = { version = ">=0.17, <= 0.19", optional = true }
native-tls = { version = "0.2", optional = true }
postgres-native-tls = { version = "0.5", optional = true}
tokio-postgres = { version = ">= 0.5, <= 0.7", optional = true }
mysql = { version = ">= 21.0.0, <= 26", optional = true, default-features = false, features = ["minimal"] }
mysql_async = { version = ">= 0.28, <= 0.35", optional = true, default-features = false, features = ["minimal"] }
Expand Down
95 changes: 89 additions & 6 deletions refinery_core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ use crate::Error;
use std::convert::TryFrom;
use std::path::PathBuf;
use std::str::FromStr;
#[cfg(any(
feature = "postgres",
feature = "tokio-postgres",
feature = "tiberius-config"
))]
use std::{borrow::Cow, collections::HashMap};
use url::Url;

// refinery config file used by migrate_from_config if migration from a Config struct is preferred instead of using the macros
Expand Down Expand Up @@ -34,6 +40,8 @@ impl Config {
db_user: None,
db_pass: None,
db_name: None,
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
use_tls: false,
#[cfg(feature = "tiberius-config")]
trust_cert: false,
},
Expand Down Expand Up @@ -139,6 +147,11 @@ impl Config {
self.main.db_port.as_deref()
}

#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
pub fn use_tls(&self) -> bool {
self.main.use_tls
}

pub fn set_db_user(self, db_user: &str) -> Config {
Config {
main: Main {
Expand Down Expand Up @@ -183,6 +196,16 @@ impl Config {
},
}
}

#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
pub fn set_use_tls(self, use_tls: bool) -> Config {
Config {
main: Main {
use_tls,
..self.main
},
}
}
}

impl TryFrom<Url> for Config {
Expand All @@ -203,13 +226,17 @@ impl TryFrom<Url> for Config {
}
};

#[cfg(any(
feature = "postgres",
feature = "tokio-postgres",
feature = "tiberius-config"
))]
let query_params = url
.query_pairs()
.collect::<HashMap<Cow<'_, str>, Cow<'_, str>>>();

cfg_if::cfg_if! {
if #[cfg(feature = "tiberius-config")] {
use std::{borrow::Cow, collections::HashMap};
let query_params = url
.query_pairs()
.collect::<HashMap< Cow<'_, str>, Cow<'_, str>>>();

let trust_cert = query_params.
get("trust_cert")
.unwrap_or(&Cow::Borrowed("false"))
Expand All @@ -223,6 +250,18 @@ impl TryFrom<Url> for Config {
}
}

#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
let use_tls = match query_params.get("sslmode") {
Some(Cow::Borrowed("require")) => true,
Some(Cow::Borrowed("disable")) | None => false,
_ => {
return Err(Error::new(
Kind::ConfigError("Invalid sslmode value, please use disable/require".into()),
None,
))
}
};

Ok(Self {
main: Main {
db_type,
Expand All @@ -238,6 +277,8 @@ impl TryFrom<Url> for Config {
db_user: Some(url.username().to_string()),
db_pass: url.password().map(|r| r.to_string()),
db_name: Some(url.path().trim_start_matches('/').to_string()),
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
use_tls,
#[cfg(feature = "tiberius-config")]
trust_cert,
},
Expand Down Expand Up @@ -270,8 +311,11 @@ struct Main {
db_user: Option<String>,
db_pass: Option<String>,
db_name: Option<String>,
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
#[cfg_attr(feature = "serde", serde(default))]
use_tls: bool,
#[cfg(feature = "tiberius-config")]
#[serde(default)]
#[cfg_attr(feature = "serde", serde(default))]
trust_cert: bool,
}

Expand Down Expand Up @@ -453,6 +497,45 @@ mod tests {
);
}

#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
#[test]
fn builds_from_sslmode_str() {
use crate::config::ConfigDbType;

let config_disable =
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=disable")
.unwrap();
assert!(!config_disable.use_tls());

let config_require =
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=require")
.unwrap();
assert!(config_require.use_tls());

// Verify that manually created config matches parsed URL config
let manual_config_disable = Config::new(ConfigDbType::Postgres)
.set_db_user("root")
.set_db_pass("1234")
.set_db_host("localhost")
.set_db_port("5432")
.set_db_name("refinery")
.set_use_tls(false);
assert_eq!(config_disable.use_tls(), manual_config_disable.use_tls());

let manual_config_require = Config::new(ConfigDbType::Postgres)
.set_db_user("root")
.set_db_pass("1234")
.set_db_host("localhost")
.set_db_port("5432")
.set_db_name("refinery")
.set_use_tls(true);
assert_eq!(config_require.use_tls(), manual_config_require.use_tls());

let config =
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=invalidvalue");
assert!(config.is_err());
}

#[test]
fn builds_db_env_var_failure() {
std::env::set_var("DATABASE_URL", "this_is_not_a_url");
Expand Down
37 changes: 29 additions & 8 deletions refinery_core/src/drivers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,16 @@ macro_rules! with_connection {
cfg_if::cfg_if! {
if #[cfg(feature = "postgres")] {
let path = build_db_url("postgresql", &$config);
let conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;

let conn;
if $config.use_tls() {
let connector = native_tls::TlsConnector::new().unwrap();
let connector = postgres_native_tls::MakeTlsConnector::new(connector);
conn = postgres::Client::connect(path.as_str(), connector).migration_err("could not connect to database", None)?;
} else {
conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;
}

$op(conn)
} else {
panic!("tried to migrate from config for a postgresql database, but feature postgres not enabled!");
Expand Down Expand Up @@ -123,13 +132,25 @@ macro_rules! with_connection_async {
cfg_if::cfg_if! {
if #[cfg(feature = "tokio-postgres")] {
let path = build_db_url("postgresql", $config);
let (client, connection ) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
$op(client).await
if $config.use_tls() {
let connector = native_tls::TlsConnector::new().unwrap();
let connector = postgres_native_tls::MakeTlsConnector::new(connector);
let (client, connection) = tokio_postgres::connect(path.as_str(), connector).await.migration_err("could not connect to database", None)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
$op(client).await
} else {
let (client, connection) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
$op(client).await
}
} else {
panic!("tried to migrate async from config for a postgresql database, but tokio-postgres was not enabled!");
}
Expand Down
Loading