Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple types of DBs at once #3278

Closed
wants to merge 10 commits into from
2 changes: 2 additions & 0 deletions sqlx-core/src/any/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ impl Database for Any {
const NAME: &'static str = "Any";

const URL_SCHEMES: &'static [&'static str] = &[];

const TYPE_IMPORT_PATH: &'static str = "sqlx::any::database::Any";
}

// This _may_ be true, depending on the selected database
Expand Down
4 changes: 4 additions & 0 deletions sqlx-core/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ pub trait Database: 'static + Sized + Send + Debug {

/// The schemes for database URLs that should match this driver.
const URL_SCHEMES: &'static [&'static str];

// This can be removed once https://github.com/rust-lang/rust/issues/63084 is resolved and type_name is available in const fns.
/// The path to the database-specific type system.
const TYPE_IMPORT_PATH: &'static str;
}

/// A [`Database`] that maintains a client-side cache of prepared statements.
Expand Down
10 changes: 10 additions & 0 deletions sqlx-macros-core/src/query/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use syn::{Expr, LitBool, LitStr, Token};
use syn::{ExprArray, Type};

/// Macro input shared by `query!()` and `query_file!()`
#[derive(Clone)]
pub struct QueryMacroInput {
pub(super) sql: String,

Expand All @@ -19,13 +20,17 @@ pub struct QueryMacroInput {
pub(super) checked: bool,

pub(super) file_path: Option<String>,

// TODO: This should be some type and not a string
pub(super) driver: Option<Type>,
}

enum QuerySrc {
String(String),
File(String),
}

#[derive(Clone)]
pub enum RecordType {
Given(Type),
Scalar,
Expand All @@ -38,6 +43,7 @@ impl Parse for QueryMacroInput {
let mut args: Option<Vec<Expr>> = None;
let mut record_type = RecordType::Generated;
let mut checked = true;
let mut driver = None;

let mut expect_comma = false;

Expand Down Expand Up @@ -82,6 +88,9 @@ impl Parse for QueryMacroInput {
} else if key == "checked" {
let lit_bool = input.parse::<LitBool>()?;
checked = lit_bool.value;
} else if key == "driver" {
// TODO: This should be some actual type and not a string
driver = Some(input.parse::<Type>()?);
} else {
let message = format!("unexpected input key: {key}");
return Err(syn::Error::new_spanned(key, message));
Expand All @@ -104,6 +113,7 @@ impl Parse for QueryMacroInput {
arg_exprs,
checked,
file_path,
driver,
})
}
}
Expand Down
144 changes: 124 additions & 20 deletions sqlx-macros-core/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct QueryDriver {
db_name: &'static str,
url_schemes: &'static [&'static str],
expand: fn(QueryMacroInput, QueryDataSource) -> crate::Result<TokenStream>,
db_type_name: &'static str,
}

impl QueryDriver {
Expand All @@ -38,41 +39,94 @@ impl QueryDriver {
db_name: DB::NAME,
url_schemes: DB::URL_SCHEMES,
expand: expand_with::<DB>,
db_type_name: DB::TYPE_IMPORT_PATH,
}
}
}

#[derive(Clone)]
pub struct QueryDataSourceUrl<'a> {
database_url: &'a str,
database_url_parsed: Url,
}

impl<'a> From<&'a String> for QueryDataSourceUrl<'a> {
fn from(database_url: &'a String) -> Self {
let database_url_parsed = Url::parse(database_url).expect("invalid URL");

QueryDataSourceUrl {
database_url,
database_url_parsed,
}
}
}

#[derive(Clone)]
pub enum QueryDataSource<'a> {
Live {
database_url: &'a str,
database_url_parsed: Url,
database_urls: Vec<QueryDataSourceUrl<'a>>,
},
Cached(DynQueryData),
}

impl<'a> QueryDataSource<'a> {
pub fn live(database_url: &'a str) -> crate::Result<Self> {
pub fn live(database_urls: Vec<QueryDataSourceUrl<'a>>) -> crate::Result<Self> {
Ok(QueryDataSource::Live {
database_url,
database_url_parsed: database_url.parse()?,
database_urls,
})
}

pub fn matches_driver(&self, driver: &QueryDriver) -> bool {
match self {
Self::Live {
database_url_parsed,
database_urls,
..
} => driver.url_schemes.contains(&database_url_parsed.scheme()),
} => driver.url_schemes.iter().any(|scheme| {
database_urls.iter().any(|url| url.database_url_parsed.scheme() == *scheme)
}),
Self::Cached(dyn_data) => dyn_data.db_name == driver.db_name,
}
}

pub fn get_url_for_schemes(&self, schemes: &[&str]) -> Option<&QueryDataSourceUrl> {
match self {
Self::Live {
database_urls,
..
} => {
for scheme in schemes {
if let Some(url) = database_urls.iter().find(|url| url.database_url_parsed.scheme() == *scheme) {
return Some(url);
}
}
None
}
Self::Cached(_) => {
None
}
}
}

pub fn supported_schemes(&self) -> Vec<&str> {
match self {
Self::Live {
database_urls,
..
} => {
let mut schemes = vec![];
schemes.extend(database_urls.iter().map(|url| url.database_url_parsed.scheme()));
schemes
}
Self::Cached(..) => vec![],
}
}
}

struct Metadata {
#[allow(unused)]
manifest_dir: PathBuf,
offline: bool,
database_url: Option<String>,
database_urls: Vec<String>,
workspace_root: Arc<Mutex<Option<PathBuf>>>,
}

Expand Down Expand Up @@ -139,12 +193,10 @@ static METADATA: Lazy<Metadata> = Lazy::new(|| {
.map(|s| s.eq_ignore_ascii_case("true") || s == "1")
.unwrap_or(false);

let database_url = env("DATABASE_URL").ok();

Metadata {
manifest_dir,
offline,
database_url,
database_urls: env_db_urls(),
workspace_root: Arc::new(Mutex::new(None)),
}
});
Expand All @@ -156,9 +208,11 @@ pub fn expand_input<'a>(
let data_source = match &*METADATA {
Metadata {
offline: false,
database_url: Some(db_url),
database_urls: db_urls,
..
} => QueryDataSource::live(db_url)?,
} => {
QueryDataSource::live(db_urls.iter().map(QueryDataSourceUrl::from).collect())?
},

Metadata { offline, .. } => {
// Try load the cached query metadata file.
Expand Down Expand Up @@ -189,19 +243,64 @@ pub fn expand_input<'a>(
}
};

let mut working_drivers = vec![];

// If the driver was explicitly set, use it directly.
if let Some(input_driver) = input.driver.clone() {
for driver in drivers {
if (driver.expand)(input.clone(), data_source.clone()).is_ok() {
working_drivers.push(driver);
}
}

return match working_drivers.len() {
0 => {
Err(format!(
"no database driver found matching for query; the corresponding Cargo feature may need to be enabled"
).into())
}
1 => {
let driver = working_drivers.pop().unwrap();
(driver.expand)(input, data_source)
}
_ => {
let expansions = working_drivers.iter().map(|driver| {
let driver_name = driver.db_type_name;
let driver_type: Type = syn::parse_str(driver_name).unwrap();
let expanded = (driver.expand)(input.clone(), data_source.clone()).unwrap();
quote! {
impl ProvideQuery<#driver_type> for #driver_type {
fn provide_query<'a>() -> Query<'a, #driver_type, <#driver_type as sqlx::Database>::Arguments<'a>> {
#expanded
}
}
}
});
Ok(quote! {
{
use sqlx::query::Query;
trait ProvideQuery<DB: sqlx::Database> {
fn provide_query<'a>() -> Query<'a, DB, DB::Arguments<'a>>;
}
#(#expansions)*
#input_driver::provide_query()
}
})
}
}
}

// If no driver was set, try to find a matching driver for the data source.
for driver in drivers {
if data_source.matches_driver(driver) {
return (driver.expand)(input, data_source);
}
}

match data_source {
QueryDataSource::Live {
database_url_parsed,
..
} => Err(format!(
QueryDataSource::Live{..} => Err(format!(
"no database driver found matching URL scheme {:?}; the corresponding Cargo feature may need to be enabled",
database_url_parsed.scheme()
data_source.supported_schemes()
).into()),
QueryDataSource::Cached(data) => {
Err(format!(
Expand All @@ -221,8 +320,9 @@ where
{
let (query_data, offline): (QueryData<DB>, bool) = match data_source {
QueryDataSource::Cached(dyn_data) => (QueryData::from_dyn_data(dyn_data)?, true),
QueryDataSource::Live { database_url, .. } => {
let describe = DB::describe_blocking(&input.sql, database_url)?;
QueryDataSource::Live { .. } => {
let data_source_url = data_source.get_url_for_schemes(DB::URL_SCHEMES).unwrap();
let describe = DB::describe_blocking(&input.sql, data_source_url.database_url)?;
(QueryData::from_describe(&input.sql, describe), false)
}
};
Expand Down Expand Up @@ -386,3 +486,7 @@ fn env(name: &str) -> Result<String, std::env::VarError> {
std::env::var(name)
}
}

fn env_db_urls() -> Vec<String> {
std::env::vars().filter(|(k, _)| k.starts_with("DATABASE_URL")).map(|(_, v)| v).collect()
}
2 changes: 2 additions & 0 deletions sqlx-mysql/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ impl Database for MySql {
const NAME: &'static str = "MySQL";

const URL_SCHEMES: &'static [&'static str] = &["mysql", "mariadb"];

const TYPE_IMPORT_PATH: &'static str = "sqlx::mysql::MySql";
}

impl HasStatementCache for MySql {}
2 changes: 2 additions & 0 deletions sqlx-postgres/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ impl Database for Postgres {
const NAME: &'static str = "PostgreSQL";

const URL_SCHEMES: &'static [&'static str] = &["postgres", "postgresql"];

const TYPE_IMPORT_PATH: &'static str = "sqlx::postgres::Postgres";
}

impl HasStatementCache for Postgres {}
2 changes: 2 additions & 0 deletions sqlx-sqlite/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ impl Database for Sqlite {
const NAME: &'static str = "SQLite";

const URL_SCHEMES: &'static [&'static str] = &["sqlite"];

const TYPE_IMPORT_PATH: &'static str = "sqlx::sqlite::Sqlite";
}

impl HasStatementCache for Sqlite {}
6 changes: 6 additions & 0 deletions src/macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,9 @@ macro_rules! query (
($query:expr) => ({
$crate::sqlx_macros::expand_query!(source = $query)
});
($driver:ty, $query:expr) => ({
$crate::sqlx_macros::expand_query!(source = $query, driver = $driver)
});
// RFC: this semantically should be `$($args:expr),*` (with `$(,)?` to allow trailing comma)
// but that doesn't work in 1.45 because `expr` fragments get wrapped in a way that changes
// their hygiene, which is fixed in 1.46 so this is technically just a temp. workaround.
Expand All @@ -326,6 +329,9 @@ macro_rules! query (
// not like it makes them magically understandable at-a-glance.
($query:expr, $($args:tt)*) => ({
$crate::sqlx_macros::expand_query!(source = $query, args = [$($args)*])
});
($driver:ty, $query:expr, $($args:tt)*) => ({
$crate::sqlx_macros::expand_query!(source = $query, args = [$($args)*], driver = $driver)
})
);

Expand Down