Skip to content

Commit

Permalink
Generate SQL type definitions for unknown types
Browse files Browse the repository at this point in the history
This commit adds support for generating SQL type definitions for unknown
types to `diesel print-schema`. The basic idea is to generate the
corresponding marker type to always end up with an existing schema.
Especially this does not generate any code that is required for
serializing/deserializing rust values.

Additionally this commit tweaks the `table!` macro to import types from
the parent scope. This allows us to just reference those newly generated
types easily + simplifies the handling of other custom type imports in
my opinion. The old behaviour of having as part of the `table!`
definition remains supported.
  • Loading branch information
weiznich committed May 19, 2021
1 parent c0b6130 commit ff2c751
Show file tree
Hide file tree
Showing 26 changed files with 291 additions and 43 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/

* Add support for HAVING clauses.

* Diesel CLI will now generate SQL type definitions for SQL types that are not supported by diesel out of the box. It's possible to disable this behavior via the `generate_missing_sql_type_definitions` config option.

### Removed

* All previously deprecated items have been removed.
Expand Down
3 changes: 3 additions & 0 deletions diesel/src/macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ macro_rules! __diesel_table_impl {
#![allow(dead_code)]
$($imports)*
pub use self::columns::*;
use super::*;

/// Re-exports all of the columns of this table, as well as the
/// table struct renamed to the module name. This is meant to be
Expand Down Expand Up @@ -798,6 +799,8 @@ macro_rules! __diesel_table_impl {
pub mod columns {
use super::table;
$($imports)*
#[allow(unused_imports)]
use super::super::*;

#[allow(non_camel_case_types, dead_code)]
#[derive(Debug, Clone, Copy, $crate::query_builder::QueryId)]
Expand Down
1 change: 0 additions & 1 deletion diesel/src/pg/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ pub mod sql_types {
pub type BigSerial = crate::sql_types::BigInt;

/// The `UUID` SQL type. This type can only be used with `feature = "uuid"`
/// (uuid <=0.6) or `feature = "uuidv07"` (uuid = 0.7)
///
/// ### [`ToSql`] impls
///
Expand Down
7 changes: 7 additions & 0 deletions diesel_cli/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ pub fn build_cli() -> App<'static, 'static> {
.multiple(true)
.number_of_values(1)
.help("A list of types to import for every table, separated by commas"),
)
.arg(
Arg::with_name("generate-custom-type-definitions")
.long("generate-custom-type-definitions")
.takes_value(true)
.possible_values(&["true", "false"])
.help("Generate SQL type definitions for types not provided by diesel"),
);

let config_arg = Arg::with_name("CONFIG_FILE")
Expand Down
6 changes: 6 additions & 0 deletions diesel_cli/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,15 @@ pub struct PrintSchema {
pub patch_file: Option<PathBuf>,
#[serde(default)]
pub import_types: Option<Vec<String>>,
#[serde(default)]
pub generate_missing_sql_type_definitions: Option<bool>,
}

impl PrintSchema {
pub fn generate_missing_sql_type_definitions(&self) -> bool {
self.generate_missing_sql_type_definitions.unwrap_or(true)
}

pub fn schema_name(&self) -> Option<&str> {
self.schema.as_deref()
}
Expand Down
4 changes: 2 additions & 2 deletions diesel_cli/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::fs::{self, File};
use std::io::Write;
use std::path::Path;

enum Backend {
pub enum Backend {
#[cfg(feature = "postgres")]
Pg,
#[cfg(feature = "sqlite")]
Expand All @@ -26,7 +26,7 @@ enum Backend {
}

impl Backend {
fn for_url(database_url: &str) -> Self {
pub fn for_url(database_url: &str) -> Self {
match database_url {
_ if database_url.starts_with("postgres://")
|| database_url.starts_with("postgresql://") =>
Expand Down
3 changes: 2 additions & 1 deletion diesel_cli/src/infer_schema_internals/data_structures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ pub struct ColumnInformation {
pub nullable: bool,
}

#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct ColumnType {
pub rust_name: String,
pub sql_name: String,
pub is_array: bool,
pub is_nullable: bool,
pub is_unsigned: bool,
Expand Down
1 change: 1 addition & 0 deletions diesel_cli/src/infer_schema_internals/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub fn determine_column_type(
let unsigned = determine_unsigned(&attr.type_name);

Ok(ColumnType {
sql_name: tpe.trim().to_lowercase(),
rust_name: tpe.trim().to_camel_case(),
is_array: false,
is_nullable: attr.nullable,
Expand Down
1 change: 1 addition & 0 deletions diesel_cli/src/infer_schema_internals/pg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub fn determine_column_type(
}

Ok(ColumnType {
sql_name: tpe.to_lowercase(),
rust_name: tpe.to_camel_case(),
is_array,
is_nullable: attr.nullable,
Expand Down
3 changes: 2 additions & 1 deletion diesel_cli/src/infer_schema_internals/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ pub fn determine_column_type(
};

Ok(ColumnType {
rust_name: path,
rust_name: path.clone(),
sql_name: path,
is_array: false,
is_nullable: attr.nullable,
is_unsigned: false,
Expand Down
4 changes: 4 additions & 0 deletions diesel_cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,10 @@ fn run_infer_schema(matches: &ArgMatches) -> Result<(), Box<dyn Error + Send + S
config.import_types = Some(types);
}

if let Some(generate_types) = matches.value_of("generate-custom-type-definitions") {
config.generate_missing_sql_type_definitions = Some(generate_types.parse()?);
}

run_print_schema(&database_url, &config, &mut stdout())?;
Ok(())
}
Expand Down
170 changes: 150 additions & 20 deletions diesel_cli/src/print_schema.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use crate::config;

use crate::database::Backend;
use crate::infer_schema_internals::*;

use serde::de::{self, MapAccess, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use serde_regex::Serde as RegexWrapper;
use std::collections::HashSet;
use std::error::Error;
use std::fmt::{self, Display, Formatter, Write};
use std::io::Write as IoWrite;

const SCHEMA_HEADER: &str = "// @generated automatically by Diesel CLI.\n";
const SCHEMA_HEADER: &str = "// @generated automatically by Diesel CLI.";

type Regex = RegexWrapper<::regex::Regex>;

Expand Down Expand Up @@ -64,6 +66,93 @@ pub fn run_print_schema<W: IoWrite>(
Ok(())
}

fn common_diesel_types(types: &mut HashSet<&str>) {
types.insert("Bool");
types.insert("Integer");
types.insert("SmallInt");
types.insert("BigInt");
types.insert("Binary");
types.insert("Text");
types.insert("Double");
types.insert("Float");
types.insert("Numeric");

// hidden type defs
types.insert("Float4");
types.insert("Smallint");
types.insert("Int2");
types.insert("Int4");
types.insert("Int8");
types.insert("Bigint");
types.insert("Float8");
types.insert("Decimal");
types.insert("VarChar");
types.insert("Varchar");
types.insert("Char");
types.insert("Tinytext");
types.insert("Mediumtext");
types.insert("Longtext");
types.insert("Tinyblob");
types.insert("Blob");
types.insert("Mediumblob");
types.insert("Longblob");
types.insert("Varbinary");
types.insert("Bit");
}

#[cfg(feature = "postgres")]
fn pg_diesel_types() -> HashSet<&'static str> {
let mut types = HashSet::new();
types.insert("Cidr");
types.insert("Date");
types.insert("Inet");
types.insert("Jsonb");
types.insert("MacAddr");
types.insert("Money");
types.insert("Oid");
types.insert("Range");
types.insert("Timestamptz");
types.insert("Uuid");
types.insert("Json");
types.insert("Timestamp");
types.insert("Record");
types.insert("Interval");

// hidden type defs
types.insert("Int4range");
types.insert("Int8range");
types.insert("Daterange");
types.insert("Numrange");
types.insert("Tsrange");
types.insert("Tstzrange");
types.insert("SmallSerial");
types.insert("BigSerial");
types.insert("Serial");
types.insert("Bytea");
types.insert("Bpchar");
types.insert("Macaddr");

common_diesel_types(&mut types);
types
}

#[cfg(feature = "mysql")]
fn mysql_diesel_types() -> HashSet<&'static str> {
let mut types = HashSet::new();
common_diesel_types(&mut types);

types.insert("TinyInt");
types.insert("Tinyint");
types
}

#[cfg(feature = "sqlite")]
fn sqlite_diesel_types() -> HashSet<&'static str> {
let mut types = HashSet::new();
common_diesel_types(&mut types);
types
}

pub fn output_schema(
database_url: &str,
config: &config::PrintSchema,
Expand All @@ -78,17 +167,68 @@ pub fn output_schema(
let table_data = table_names
.into_iter()
.map(|t| load_table_data(database_url, t, &config.column_sorting))
.collect::<Result<_, Box<dyn Error + Send + Sync + 'static>>>()?;
.collect::<Result<Vec<_>, Box<dyn Error + Send + Sync + 'static>>>()?;

let mut out = String::new();
writeln!(out, "{}", SCHEMA_HEADER)?;

if let Some(import_types) = config.import_types() {
for import_type in import_types {
writeln!(out, "use {};", import_type)?;
}
}
writeln!(out)?;

if config.generate_missing_sql_type_definitions() {
let backend = Backend::for_url(database_url);
let diesel_provided_types = match backend {
#[cfg(feature = "postgres")]
Backend::Pg => pg_diesel_types(),
#[cfg(feature = "sqlite")]
Backend::Sqlite => sqlite_diesel_types(),
#[cfg(feature = "mysql")]
Backend::Mysql => mysql_diesel_types(),
};

let mut all_types = table_data
.iter()
.flat_map(|t| t.column_data.iter().map(|c| &c.ty))
.filter(|t| !diesel_provided_types.contains(&t.rust_name as &str))
.collect::<Vec<_>>();

all_types.sort_unstable_by_key(|ty| &ty.rust_name);
all_types.dedup_by_key(|ty| &ty.rust_name);

for t in all_types {
match backend {
#[cfg(feature = "postgres")]
Backend::Pg => {
if config.with_docs {
writeln!(out, "/// The `{}` SQL type", t.rust_name)?;
writeln!(out, "///")?;
writeln!(out, "/// (Automatically generated by Diesel.)")?;
}
writeln!(out, "#[derive(diesel::SqlType)]")?;
writeln!(out, "#[postgres(type_name = \"{}\")]", t.sql_name)?;
writeln!(out, "pub struct {};", t.rust_name)?;
writeln!(out)?;
}
#[cfg(feature = "sqlite")]
Backend::Sqlite => {
unreachable!("We only generate a closed set of types for sqlite")
}
#[cfg(feature = "mysql")]
Backend::Mysql => todo!(),
}
}
}

let definitions = TableDefinitions {
tables: table_data,
fk_constraints: foreign_keys,
include_docs: config.with_docs,
import_types: config.import_types(),
};

let mut out = String::new();
writeln!(out, "{}", SCHEMA_HEADER)?;

if let Some(schema_name) = config.schema_name() {
write!(out, "{}", ModuleDefinition(schema_name, definitions))?;
} else {
Expand All @@ -105,7 +245,7 @@ pub fn output_schema(
Ok(out)
}

struct ModuleDefinition<'a>(&'a str, TableDefinitions<'a>);
struct ModuleDefinition<'a>(&'a str, TableDefinitions);

impl<'a> Display for ModuleDefinition<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Expand All @@ -119,14 +259,13 @@ impl<'a> Display for ModuleDefinition<'a> {
}
}

struct TableDefinitions<'a> {
struct TableDefinitions {
tables: Vec<TableData>,
fk_constraints: Vec<ForeignKeyConstraint>,
include_docs: bool,
import_types: Option<&'a [String]>,
}

impl<'a> Display for TableDefinitions<'a> {
impl Display for TableDefinitions {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let mut is_first = true;
for table in &self.tables {
Expand All @@ -141,7 +280,6 @@ impl<'a> Display for TableDefinitions<'a> {
TableDefinition {
table,
include_docs: self.include_docs,
import_types: self.import_types,
}
)?;
}
Expand Down Expand Up @@ -176,7 +314,6 @@ impl<'a> Display for TableDefinitions<'a> {

struct TableDefinition<'a> {
table: &'a TableData,
import_types: Option<&'a [String]>,
include_docs: bool,
}

Expand All @@ -187,13 +324,6 @@ impl<'a> Display for TableDefinition<'a> {
let mut out = PadAdapter::new(f);
writeln!(out)?;

if let Some(types) = self.import_types {
for import in types {
writeln!(out, "use {};", import)?;
}
writeln!(out)?;
}

if self.include_docs {
for d in self.table.docs.lines() {
writeln!(out, "///{}{}", if d.is_empty() { "" } else { " " }, d)?;
Expand Down
Loading

0 comments on commit ff2c751

Please sign in to comment.