Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility with diesel master after 72bfb356 #64

Merged
merged 9 commits into from
Dec 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
matrix:
rust:
- stable
- 1.40.0
- 1.48.0

services:
postgres:
Expand Down
131 changes: 84 additions & 47 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,41 @@ use syn::*;
/// # Attributes
///
/// ## Type attributes
///
/// * `#[PgType = "new_enum"]` specifies postgres name for the enum type. If ommitted, uses the enum's name in snake_case.
/// * `#[PgSchema = "schema"]` specifies the postgres schema containing the enum type. If omitted, diesel uses the default search path, but this can cause problems with caching.
/// * `#[DieselType = "NewEnumMapping"]` specifies the name for the diesel type. If omitted, uses the name + `Mapping`.
///
/// * `#[DieselExistingType = "crate::schema::sql_types::NewEnum"]` specifies the name for the corresponding diesel type that was already created by the diesel CLI. If omitted, uses `crate::schema::sql_types::EnumName`.
/// * `#[DieselType = "NewEnumMapping"]` specifies the name for the diesel type to create for Mysql or Sqlite. If omitted, uses the name + `Mapping`.
/// * `#[DbValueStyle = "snake_case"]` specifies a renaming style from each of the rust enum variants to each of the database variants. Either `camelCase`, `kebab-case`, `PascalCase`, `SCREAMING_SNAKE_CASE`, `snake_case`, `verbatim`. If omitted, uses `snake_case`.
///
/// ## Variant attributes
///
/// * `#[db_rename = "variant"]` specifies the db name for a specific variant.
#[proc_macro_derive(DbEnum, attributes(PgType, PgSchema, DieselType, DbValueStyle, db_rename))]
#[proc_macro_derive(
DbEnum,
attributes(DieselType, DieselExistingType, DbValueStyle, db_rename)
)]
pub fn derive(input: TokenStream) -> TokenStream {
let input: DeriveInput = parse_macro_input!(input as DeriveInput);
let db_type =
type_from_attrs(&input.attrs, "PgType").unwrap_or(input.ident.to_string().to_snake_case());
let db_schema =
type_from_attrs(&input.attrs, "PgSchema");
let diesel_mapping =
let diesel_existing_mapping = type_from_attrs(&input.attrs, "DieselExistingType")
.unwrap_or(format!("crate::schema::sql_types::{}", input.ident));
let new_diesel_mapping =
type_from_attrs(&input.attrs, "DieselType").unwrap_or(format!("{}Mapping", input.ident));

// Maintain backwards compatibility by defaulting to snake case.
let case_style =
type_from_attrs(&input.attrs, "DbValueStyle").unwrap_or("snake_case".to_string());
let case_style = CaseStyle::from_string(&case_style);

let diesel_mapping = Ident::new(diesel_mapping.as_ref(), Span::call_site());
let diesel_existing_mapping: proc_macro2::TokenStream =
diesel_existing_mapping.parse().unwrap();
let new_diesel_mapping = Ident::new(new_diesel_mapping.as_ref(), Span::call_site());
let quoted = if let Data::Enum(syn::DataEnum {
variants: data_variants,
..
}) = input.data
{
generate_derive_enum_impls(
&db_type,
db_schema.as_deref(),
&diesel_mapping,
&diesel_existing_mapping,
&new_diesel_mapping,
case_style,
&input.ident,
&data_variants,
Expand Down Expand Up @@ -103,9 +104,8 @@ impl CaseStyle {
}

fn generate_derive_enum_impls(
db_type: &str,
db_schema: Option<&str>,
diesel_mapping: &Ident,
diesel_existing_mapping: &proc_macro2::TokenStream,
new_diesel_mapping: &Ident,
case_style: CaseStyle,
enum_ty: &Ident,
variants: &syn::punctuated::Punctuated<Variant, syn::token::Comma>,
Expand Down Expand Up @@ -137,30 +137,72 @@ fn generate_derive_enum_impls(
let variants_rs: &[proc_macro2::TokenStream] = &variant_ids;
let variants_db: &[LitByteStr] = &variants_db;

let common_impl =
generate_common_impl(db_type, db_schema, diesel_mapping, enum_ty, variants_rs, variants_db);
let (common_diesel_mapping, common_diesel_mapping_use) =
if cfg!(feature = "mysql") || cfg!(feature = "sqlite") {
let new_diesel_mapping_impl = generate_common_diesel_mapping(new_diesel_mapping);
let common_impls_on_new_diesel_mapping = generate_common_impls(
&quote! { #new_diesel_mapping },
enum_ty,
variants_rs,
variants_db,
);
(
quote! {
#new_diesel_mapping_impl
#common_impls_on_new_diesel_mapping
},
quote! {
pub use self::#modname::#new_diesel_mapping;
},
)
} else {
(quote! {}, quote! {})
};

let pg_impl = if cfg!(feature = "postgres") {
generate_postgres_impl(diesel_mapping, enum_ty, variants_rs, variants_db)
let common_impls_on_existing_diesel_mapping =
generate_common_impls(diesel_existing_mapping, enum_ty, variants_rs, variants_db);
let postgres_impl =
generate_postgres_impl(diesel_existing_mapping, enum_ty, variants_rs, variants_db);
quote! {
#common_impls_on_existing_diesel_mapping
#postgres_impl
}
} else {
quote! {}
};
let mysql_impl = if cfg!(feature = "mysql") {
generate_mysql_impl(diesel_mapping, enum_ty, variants_rs, variants_db)
generate_mysql_impl(new_diesel_mapping, enum_ty, variants_rs, variants_db)
} else {
quote! {}
};
let sqlite_impl = if cfg!(feature = "sqlite") {
generate_sqlite_impl(diesel_mapping, enum_ty, variants_rs, variants_db)
generate_sqlite_impl(new_diesel_mapping, enum_ty, variants_rs, variants_db)
} else {
quote! {}
};

let imports = quote! {
use super::*;
use diesel::Queryable;
use diesel::backend::{self, Backend};
use diesel::expression::AsExpression;
use diesel::expression::bound::Bound;
use diesel::row::Row;
use diesel::sql_types::*;
use diesel::serialize::{self, ToSql, IsNull, Output};
use diesel::deserialize::{self, FromSql};
use diesel::query_builder::QueryId;
use std::io::Write;
};

let quoted = quote! {
pub use self::#modname::#diesel_mapping;
#common_diesel_mapping_use
#[allow(non_snake_case)]
mod #modname {
#common_impl
#imports

#common_diesel_mapping
#pg_impl
#mysql_impl
#sqlite_impl
Expand All @@ -181,33 +223,22 @@ fn stylize_value(value: &str, style: CaseStyle) -> String {
}
}

fn generate_common_impl(
db_type: &str,
db_schema: Option<&str>,
diesel_mapping: &Ident,
fn generate_common_diesel_mapping(new_diesel_mapping: &Ident) -> proc_macro2::TokenStream {
quote! {
#[derive(SqlType, Clone)]
#[mysql_type = "Enum"]
#[sqlite_type = "Text"]
pub struct #new_diesel_mapping;
}
}

fn generate_common_impls(
diesel_mapping: &proc_macro2::TokenStream,
enum_ty: &Ident,
variants_rs: &[proc_macro2::TokenStream],
variants_db: &[LitByteStr],
) -> proc_macro2::TokenStream {
let db_schema = db_schema.into_iter();
quote! {
use super::*;
use diesel::Queryable;
use diesel::backend::{self, Backend};
use diesel::expression::AsExpression;
use diesel::expression::bound::Bound;
use diesel::row::Row;
use diesel::sql_types::*;
use diesel::serialize::{self, ToSql, IsNull, Output};
use diesel::deserialize::{self, FromSql};
use diesel::query_builder::QueryId;
use std::io::Write;

#[derive(SqlType, Clone)]
#[postgres(type_name = #db_type, #(type_schema = #db_schema)*)]
#[mysql_type = "Enum"]
#[sqlite_type = "Text"]
pub struct #diesel_mapping;
impl QueryId for #diesel_mapping {
type QueryId = #diesel_mapping;
const HAS_STATIC_QUERY_ID: bool = true;
Expand Down Expand Up @@ -283,7 +314,7 @@ fn generate_common_impl(
}

fn generate_postgres_impl(
diesel_mapping: &Ident,
diesel_mapping: &proc_macro2::TokenStream,
enum_ty: &Ident,
variants_rs: &[proc_macro2::TokenStream],
variants_db: &[LitByteStr],
Expand All @@ -293,6 +324,12 @@ fn generate_postgres_impl(
use super::*;
use diesel::pg::{Pg, PgValue};

impl Clone for #diesel_mapping {
fn clone(&self) -> Self {
#diesel_mapping
}
}

impl FromSql<#diesel_mapping, Pg> for #enum_ty {
fn from_sql(raw: PgValue) -> deserialize::Result<Self> {
match raw.as_bytes() {
Expand Down
Loading