Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate SQL type definitions for unknown types #2787

Merged
merged 9 commits into from
Jun 25, 2021
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/

* Added support for SQL functions without arguments for SQLite.

* Diesel CLI will now generate SQL type definitions for SQL types that are not supported by diesel out of the box. It's possible to disable this behavior via the `generate_missing_sql_type_definitions` config option.

### Removed

* All previously deprecated items have been removed.
Expand Down Expand Up @@ -171,8 +173,6 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
* The `#[table_name]` attribute for derive macros can now refer to any path and is no
longer limited to identifiers from the current scope.

* Interacting with a database requires a mutable connection.

### Fixed

* Many types were incorrectly considered non-aggregate when they should not
Expand Down
22 changes: 19 additions & 3 deletions diesel/src/macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@ pub(crate) mod prelude {
};
}

#[macro_export]
#[doc(hidden)]
macro_rules! __diesel_fix_sql_type_import {
($(use $($import:tt)::+;)*) => {
$(
$crate::__diesel_fix_sql_type_import!(@expand_import: $($import)::+);
)*
};
(@expand_import: super:: $($Type:tt)+) => {
use super::super::$($Type)+;
};
(@expand_import: $($Type:tt)+) => {
use $($Type)+;
}
}

#[macro_export]
#[doc(hidden)]
macro_rules! __diesel_column {
Expand Down Expand Up @@ -616,10 +632,10 @@ macro_rules! __diesel_table_impl {
},)+],
) => {
$($meta)*
#[allow(unused_imports, dead_code)]
pub mod $table_name {
#![allow(dead_code)]
$($imports)*
pub use self::columns::*;
$($imports)*
Ten0 marked this conversation as resolved.
Show resolved Hide resolved

/// Re-exports all of the columns of this table, as well as the
/// table struct renamed to the module name. This is meant to be
Expand Down Expand Up @@ -797,7 +813,7 @@ macro_rules! __diesel_table_impl {
/// Contains all of the columns of this table
pub mod columns {
use super::table;
$($imports)*
$crate::__diesel_fix_sql_type_import!($($imports)*);

#[allow(non_camel_case_types, dead_code)]
#[derive(Debug, Clone, Copy, $crate::query_builder::QueryId)]
Expand Down
1 change: 0 additions & 1 deletion diesel/src/pg/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ pub mod sql_types {
pub type BigSerial = crate::sql_types::BigInt;

/// The `UUID` SQL type. This type can only be used with `feature = "uuid"`
/// (uuid <=0.6) or `feature = "uuidv07"` (uuid = 0.7)
///
/// ### [`ToSql`] impls
///
Expand Down
5 changes: 5 additions & 0 deletions diesel_cli/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ pub fn build_cli() -> App<'static, 'static> {
.multiple(true)
.number_of_values(1)
.help("A list of types to import for every table, separated by commas."),
)
.arg(
Arg::with_name("generate-custom-type-definitions")
weiznich marked this conversation as resolved.
Show resolved Hide resolved
.long("no-generate-missing-sql-type-definitions")
.help("Generate SQL type definitions for types not provided by diesel"),
);

let config_arg = Arg::with_name("CONFIG_FILE")
Expand Down
6 changes: 6 additions & 0 deletions diesel_cli/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,15 @@ pub struct PrintSchema {
pub patch_file: Option<PathBuf>,
#[serde(default)]
pub import_types: Option<Vec<String>>,
#[serde(default)]
pub generate_missing_sql_type_definitions: Option<bool>,
}

impl PrintSchema {
pub fn generate_missing_sql_type_definitions(&self) -> bool {
self.generate_missing_sql_type_definitions.unwrap_or(true)
}

pub fn schema_name(&self) -> Option<&str> {
self.schema.as_deref()
}
Expand Down
4 changes: 2 additions & 2 deletions diesel_cli/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::fs::{self, File};
use std::io::Write;
use std::path::Path;

enum Backend {
pub enum Backend {
#[cfg(feature = "postgres")]
Pg,
#[cfg(feature = "sqlite")]
Expand All @@ -26,7 +26,7 @@ enum Backend {
}

impl Backend {
fn for_url(database_url: &str) -> Self {
pub fn for_url(database_url: &str) -> Self {
match database_url {
_ if database_url.starts_with("postgres://")
|| database_url.starts_with("postgresql://") =>
Expand Down
21 changes: 15 additions & 6 deletions diesel_cli/src/infer_schema_internals/data_structures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ use super::table_data::TableName;
pub struct ColumnInformation {
pub column_name: String,
pub type_name: String,
pub type_schema: Option<String>,
pub nullable: bool,
}

#[derive(Debug)]
#[derive(Debug, PartialEq, Clone)]
pub struct ColumnType {
pub schema: Option<String>,
pub rust_name: String,
pub sql_name: String,
pub is_array: bool,
pub is_nullable: bool,
pub is_unsigned: bool,
Expand Down Expand Up @@ -59,14 +62,20 @@ pub struct ColumnDefinition {
}

impl ColumnInformation {
pub fn new<T, U>(column_name: T, type_name: U, nullable: bool) -> Self
pub fn new<T, U>(
column_name: T,
type_name: U,
type_schema: Option<String>,
nullable: bool,
) -> Self
where
T: Into<String>,
U: Into<String>,
{
ColumnInformation {
column_name: column_name.into(),
type_name: type_name.into(),
type_schema,
nullable,
}
}
Expand All @@ -76,12 +85,12 @@ impl ColumnInformation {
impl<ST, DB> Queryable<ST, DB> for ColumnInformation
where
DB: Backend + UsesInformationSchema,
(String, String, String): FromStaticSqlRow<ST, DB>,
(String, String, Option<String>, String): FromStaticSqlRow<ST, DB>,
{
type Row = (String, String, String);
type Row = (String, String, Option<String>, String);

fn build(row: Self::Row) -> deserialize::Result<Self> {
Ok(ColumnInformation::new(row.0, row.1, row.2 == "YES"))
Ok(ColumnInformation::new(row.0, row.1, row.2, row.3 == "YES"))
}
}

Expand All @@ -93,7 +102,7 @@ where
type Row = (i32, String, String, bool, Option<String>, bool);

fn build(row: Self::Row) -> deserialize::Result<Self> {
Ok(ColumnInformation::new(row.1, row.2, !row.3))
Ok(ColumnInformation::new(row.1, row.2, None, !row.3))
}
}

Expand Down
10 changes: 7 additions & 3 deletions diesel_cli/src/infer_schema_internals/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,17 @@ fn get_column_information(

fn determine_column_type(
attr: &ColumnInformation,
conn: &InferConnection,
conn: &mut InferConnection,
) -> Result<ColumnType, Box<dyn Error + Send + Sync + 'static>> {
match *conn {
#[cfg(feature = "sqlite")]
InferConnection::Sqlite(_) => super::sqlite::determine_column_type(attr),
#[cfg(feature = "postgres")]
InferConnection::Pg(_) => super::pg::determine_column_type(attr),
InferConnection::Pg(ref mut conn) => {
use crate::infer_schema_internals::information_schema::UsesInformationSchema;

super::pg::determine_column_type(attr, diesel::pg::Pg::default_schema(conn)?)
}
#[cfg(feature = "mysql")]
InferConnection::Mysql(_) => super::mysql::determine_column_type(attr),
}
Expand Down Expand Up @@ -206,7 +210,7 @@ pub fn load_table_data(
let column_data = get_column_information(&mut connection, &name, column_sorting)?
.into_iter()
.map(|c| {
let ty = determine_column_type(&c, &connection)?;
let ty = determine_column_type(&c, &mut connection)?;
let rust_name = rust_name_for_sql_name(&c.column_name);

Ok(ColumnDefinition {
Expand Down
49 changes: 42 additions & 7 deletions diesel_cli/src/infer_schema_internals/information_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::error::Error;
use diesel::backend::Backend;
use diesel::deserialize::{FromSql, FromSqlRow};
use diesel::dsl::*;
use diesel::expression::{is_aggregate, QueryMetadata, ValidGrouping};
use diesel::expression::{is_aggregate, MixedAggregates, QueryMetadata, ValidGrouping};
#[cfg(feature = "mysql")]
use diesel::mysql::Mysql;
#[cfg(feature = "postgres")]
Expand All @@ -24,7 +24,16 @@ pub trait UsesInformationSchema: Backend {
+ QueryId
+ QueryFragment<Self>;

type TypeSchema: SelectableExpression<
self::information_schema::columns::table,
SqlType = sql_types::Nullable<sql_types::Text>,
> + ValidGrouping<()>
+ QueryId
+ QueryFragment<Self>;

fn type_column() -> Self::TypeColumn;
fn type_schema() -> Self::TypeSchema;

fn default_schema<C>(conn: &mut C) -> QueryResult<String>
where
C: Connection<Backend = Self>,
Expand All @@ -34,11 +43,16 @@ pub trait UsesInformationSchema: Backend {
#[cfg(feature = "postgres")]
impl UsesInformationSchema for Pg {
type TypeColumn = self::information_schema::columns::udt_name;
type TypeSchema = diesel::dsl::Nullable<self::information_schema::columns::udt_schema>;

fn type_column() -> Self::TypeColumn {
self::information_schema::columns::udt_name
}

fn type_schema() -> Self::TypeSchema {
self::information_schema::columns::udt_schema.nullable()
}

fn default_schema<C>(_conn: &mut C) -> QueryResult<String> {
Ok("public".into())
}
Expand All @@ -50,11 +64,16 @@ sql_function!(fn database() -> VarChar);
#[cfg(feature = "mysql")]
impl UsesInformationSchema for Mysql {
type TypeColumn = self::information_schema::columns::column_type;
type TypeSchema = diesel::dsl::AsExprOf<Option<String>, sql_types::Nullable<sql_types::Text>>;

fn type_column() -> Self::TypeColumn {
self::information_schema::columns::column_type
}

fn type_schema() -> Self::TypeSchema {
None.into_sql()
}

fn default_schema<C>(conn: &mut C) -> QueryResult<String>
where
C: Connection<Backend = Self>,
Expand Down Expand Up @@ -85,6 +104,7 @@ mod information_schema {
__is_nullable -> VarChar,
ordinal_position -> BigInt,
udt_name -> VarChar,
udt_schema -> VarChar,
column_type -> VarChar,
}
}
Expand Down Expand Up @@ -135,11 +155,17 @@ where
SqlTypeOf<(
columns::column_name,
<Conn::Backend as UsesInformationSchema>::TypeColumn,
<Conn::Backend as UsesInformationSchema>::TypeSchema,
columns::__is_nullable,
)>,
Conn::Backend,
>,
is_aggregate::No: MixedAggregates<
<<Conn::Backend as UsesInformationSchema>::TypeSchema as ValidGrouping<()>>::IsAggregate,
Output = is_aggregate::No,
>,
String: FromSql<sql_types::Text, Conn::Backend>,
Option<String>: FromSql<sql_types::Nullable<sql_types::Text>, Conn::Backend>,
Order<
Filter<
Filter<
Expand All @@ -148,6 +174,7 @@ where
(
columns::column_name,
<Conn::Backend as UsesInformationSchema>::TypeColumn,
<Conn::Backend as UsesInformationSchema>::TypeSchema,
columns::__is_nullable,
),
>,
Expand All @@ -165,6 +192,7 @@ where
(
columns::column_name,
<Conn::Backend as UsesInformationSchema>::TypeColumn,
<Conn::Backend as UsesInformationSchema>::TypeSchema,
columns::__is_nullable,
),
>,
Expand All @@ -174,7 +202,12 @@ where
>,
columns::column_name,
>: QueryFragment<Conn::Backend>,
Conn::Backend: QueryMetadata<(sql_types::Text, sql_types::Text, sql_types::Text)>,
Conn::Backend: QueryMetadata<(
sql_types::Text,
sql_types::Text,
sql_types::Nullable<sql_types::Text>,
sql_types::Text,
)>,
{
use self::information_schema::columns::dsl::*;

Expand All @@ -184,8 +217,9 @@ where
};

let type_column = Conn::Backend::type_column();
let type_schema = Conn::Backend::type_schema();
let query = columns
.select((column_name, type_column, __is_nullable))
.select((column_name, type_column, type_schema, __is_nullable))
.filter(table_name.eq(&table.sql_name))
.filter(table_schema.eq(schema_name));
match column_sorting {
Expand Down Expand Up @@ -512,10 +546,11 @@ mod tests {

let table_1 = TableName::new("table_1", "test_schema");
let table_2 = TableName::new("table_2", "test_schema");
let id = ColumnInformation::new("id", "int4", false);
let text_col = ColumnInformation::new("text_col", "varchar", true);
let not_null = ColumnInformation::new("not_null", "text", false);
let array_col = ColumnInformation::new("array_col", "_varchar", false);
let pg_catalog = Some(String::from("pg_catalog"));
let id = ColumnInformation::new("id", "int4", pg_catalog.clone(), false);
let text_col = ColumnInformation::new("text_col", "varchar", pg_catalog.clone(), true);
let not_null = ColumnInformation::new("not_null", "text", pg_catalog.clone(), false);
let array_col = ColumnInformation::new("array_col", "_varchar", pg_catalog.clone(), false);
assert_eq!(
Ok(vec![id, text_col, not_null]),
get_table_data(&mut connection, &table_1, &ColumnSorting::OrdinalPosition)
Expand Down
2 changes: 2 additions & 0 deletions diesel_cli/src/infer_schema_internals/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ pub fn determine_column_type(
let unsigned = determine_unsigned(&attr.type_name);

Ok(ColumnType {
schema: None,
sql_name: tpe.trim().to_lowercase(),
rust_name: tpe.trim().to_camel_case(),
is_array: false,
is_nullable: attr.nullable,
Expand Down
9 changes: 9 additions & 0 deletions diesel_cli/src/infer_schema_internals/pg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::io::{stderr, Write};

pub fn determine_column_type(
attr: &ColumnInformation,
default_schema: String,
) -> Result<ColumnType, Box<dyn Error + Send + Sync + 'static>> {
let is_array = attr.type_name.starts_with('_');
let tpe = if is_array {
Expand All @@ -30,6 +31,14 @@ pub fn determine_column_type(
}

Ok(ColumnType {
schema: attr.type_schema.as_ref().and_then(|s| {
if s == &default_schema {
None
} else {
Some(s.clone())
}
}),
sql_name: tpe.to_lowercase(),
rust_name: tpe.to_camel_case(),
is_array,
is_nullable: attr.nullable,
Expand Down
4 changes: 3 additions & 1 deletion diesel_cli/src/infer_schema_internals/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ pub fn determine_column_type(
};

Ok(ColumnType {
rust_name: path,
schema: None,
rust_name: path.clone(),
sql_name: path,
is_array: false,
is_nullable: attr.nullable,
is_unsigned: false,
Expand Down
Loading