Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce optional serde support for model code generation #237

Merged
merged 2 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sea-orm-cli/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ pub fn build_cli() -> App<'static, 'static> {
.help("Generate entity file of compact format")
.takes_value(false)
.conflicts_with("EXPANDED_FORMAT"),
)
.arg(
Arg::with_name("WITH_SERDE")
.long("with-serde")
.help("Automatically derive serde Serialize / Deserialize traits for the entity (none, serialize, deserialize, both)")
.takes_value(true)
.default_value("none")
),
)
.setting(AppSettings::SubcommandRequiredElseHelp);
Expand Down
15 changes: 10 additions & 5 deletions sea-orm-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use clap::ArgMatches;
use dotenv::dotenv;
use log::LevelFilter;
use sea_orm_codegen::{EntityTransformer, OutputFile};
use std::{error::Error, fmt::Display, fs, io::Write, path::Path, process::Command};
use sea_orm_codegen::{EntityTransformer, OutputFile, WithSerde};
use std::{error::Error, fmt::Display, fs, io::Write, path::Path, process::Command, str::FromStr};

mod cli;

Expand All @@ -26,13 +26,17 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
let url = args.value_of("DATABASE_URL").unwrap();
let output_dir = args.value_of("OUTPUT_DIR").unwrap();
let include_hidden_tables = args.is_present("INCLUDE_HIDDEN_TABLES");
let tables = args.values_of("TABLES").unwrap_or_default().collect::<Vec<_>>();
let tables = args
.values_of("TABLES")
.unwrap_or_default()
.collect::<Vec<_>>();
let expanded_format = args.is_present("EXPANDED_FORMAT");
let with_serde = args.value_of("WITH_SERDE").unwrap();
let filter_tables = |table: &str| -> bool {
if tables.len() > 0 {
return tables.contains(&table);
}

true
};
let filter_hidden_tables = |table: &str| -> bool {
Expand Down Expand Up @@ -84,7 +88,8 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
panic!("This database is not supported ({})", url)
};

let output = EntityTransformer::transform(table_stmts)?.generate(expanded_format);
let output = EntityTransformer::transform(table_stmts)?
.generate(expanded_format, WithSerde::from_str(with_serde).unwrap());

let dir = Path::new(output_dir);
fs::create_dir_all(dir)?;
Expand Down
235 changes: 217 additions & 18 deletions sea-orm-codegen/src/entity/writer.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::str::FromStr;

use crate::Entity;
use proc_macro2::TokenStream;
use quote::quote;
Expand All @@ -17,25 +19,83 @@ pub struct OutputFile {
pub content: String,
}

#[derive(PartialEq, Debug)]
pub enum WithSerde {
None,
Serialize,
Deserialize,
Both,
}

impl WithSerde {
pub fn extra_derive(&self) -> TokenStream {
let mut extra_derive = match self {
Self::None => {
quote! {}
}
Self::Serialize => {
quote! {
Serialize
}
}
Self::Deserialize => {
quote! {
Deserialize
}
}
Self::Both => {
quote! {
Serialize, Deserialize
}
}
};

if !extra_derive.is_empty() {
extra_derive = quote! { , #extra_derive }
}

extra_derive
}
}

impl FromStr for WithSerde {
type Err = crate::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"none" => Self::None,
"serialize" => Self::Serialize,
"deserialize" => Self::Deserialize,
"both" => Self::Both,
v => {
return Err(crate::Error::TransformError(format!(
"Unsupported enum variant '{}'",
v
)))
}
})
}
}

impl EntityWriter {
pub fn generate(self, expanded_format: bool) -> WriterOutput {
pub fn generate(self, expanded_format: bool, with_serde: WithSerde) -> WriterOutput {
elbart marked this conversation as resolved.
Show resolved Hide resolved
let mut files = Vec::new();
files.extend(self.write_entities(expanded_format));
files.extend(self.write_entities(expanded_format, with_serde));
files.push(self.write_mod());
files.push(self.write_prelude());
WriterOutput { files }
}

pub fn write_entities(&self, expanded_format: bool) -> Vec<OutputFile> {
pub fn write_entities(&self, expanded_format: bool, with_serde: WithSerde) -> Vec<OutputFile> {
self.entities
.iter()
.map(|entity| {
let mut lines = Vec::new();
Self::write_doc_comment(&mut lines);
let code_blocks = if expanded_format {
Self::gen_expanded_code_blocks(entity)
Self::gen_expanded_code_blocks(entity, &with_serde)
} else {
Self::gen_compact_code_blocks(entity)
Self::gen_compact_code_blocks(entity, &with_serde)
};
Self::write(&mut lines, code_blocks);
OutputFile {
Expand Down Expand Up @@ -102,12 +162,12 @@ impl EntityWriter {
lines.push("".to_owned());
}

pub fn gen_expanded_code_blocks(entity: &Entity) -> Vec<TokenStream> {
pub fn gen_expanded_code_blocks(entity: &Entity, with_serde: &WithSerde) -> Vec<TokenStream> {
let mut code_blocks = vec![
Self::gen_import(),
Self::gen_import(with_serde),
Self::gen_entity_struct(),
Self::gen_impl_entity_name(entity),
Self::gen_model_struct(entity),
Self::gen_model_struct(entity, with_serde),
Self::gen_column_enum(entity),
Self::gen_primary_key_enum(entity),
Self::gen_impl_primary_key(entity),
Expand All @@ -121,8 +181,11 @@ impl EntityWriter {
code_blocks
}

pub fn gen_compact_code_blocks(entity: &Entity) -> Vec<TokenStream> {
let mut code_blocks = vec![Self::gen_import(), Self::gen_compact_model_struct(entity)];
pub fn gen_compact_code_blocks(entity: &Entity, with_serde: &WithSerde) -> Vec<TokenStream> {
let mut code_blocks = vec![
Self::gen_import(with_serde),
Self::gen_compact_model_struct(entity, with_serde),
];
let relation_defs = if entity.get_relation_ref_tables_camel_case().is_empty() {
vec![
Self::gen_relation_enum(entity),
Expand All @@ -138,9 +201,33 @@ impl EntityWriter {
code_blocks
}

pub fn gen_import() -> TokenStream {
quote! {
pub fn gen_import(with_serde: &WithSerde) -> TokenStream {
let prelude_import = quote!(
use sea_orm::entity::prelude::*;
);

match with_serde {
WithSerde::None => prelude_import,
WithSerde::Serialize => {
quote! {
#prelude_import
use serde::Serialize;
}
}

WithSerde::Deserialize => {
quote! {
#prelude_import
use serde::Deserialize;
}
}

WithSerde::Both => {
quote! {
#prelude_import
use serde::{Deserialize,Serialize};
}
}
}
}

Expand All @@ -162,11 +249,14 @@ impl EntityWriter {
}
}

pub fn gen_model_struct(entity: &Entity) -> TokenStream {
pub fn gen_model_struct(entity: &Entity, with_serde: &WithSerde) -> TokenStream {
let column_names_snake_case = entity.get_column_names_snake_case();
let column_rs_types = entity.get_column_rs_types();

let extra_derive = with_serde.extra_derive();

quote! {
#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)]
#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel #extra_derive)]
pub struct Model {
#(pub #column_names_snake_case: #column_rs_types,)*
}
Expand Down Expand Up @@ -320,7 +410,7 @@ impl EntityWriter {
}
}

pub fn gen_compact_model_struct(entity: &Entity) -> TokenStream {
pub fn gen_compact_model_struct(entity: &Entity, with_serde: &WithSerde) -> TokenStream {
let table_name = entity.table_name.as_str();
let column_names_snake_case = entity.get_column_names_snake_case();
let column_rs_types = entity.get_column_rs_types();
Expand Down Expand Up @@ -365,8 +455,11 @@ impl EntityWriter {
}
})
.collect();

let extra_derive = with_serde.extra_derive();

quote! {
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[derive(Clone, Debug, PartialEq, DeriveEntityModel #extra_derive)]
#[sea_orm(table_name = #table_name)]
pub struct Model {
#(
Expand Down Expand Up @@ -396,6 +489,7 @@ impl EntityWriter {
mod tests {
use crate::{
Column, ConjunctRelation, Entity, EntityWriter, PrimaryKey, Relation, RelationType,
WithSerde,
};
use pretty_assertions::assert_eq;
use proc_macro2::TokenStream;
Expand Down Expand Up @@ -693,7 +787,7 @@ mod tests {
}
let content = lines.join("");
let expected: TokenStream = content.parse().unwrap();
let generated = EntityWriter::gen_expanded_code_blocks(entity)
let generated = EntityWriter::gen_expanded_code_blocks(entity, &crate::WithSerde::None)
.into_iter()
.skip(1)
.fold(TokenStream::new(), |mut acc, tok| {
Expand Down Expand Up @@ -733,7 +827,7 @@ mod tests {
}
let content = lines.join("");
let expected: TokenStream = content.parse().unwrap();
let generated = EntityWriter::gen_compact_code_blocks(entity)
let generated = EntityWriter::gen_compact_code_blocks(entity, &crate::WithSerde::None)
.into_iter()
.skip(1)
.fold(TokenStream::new(), |mut acc, tok| {
Expand All @@ -745,4 +839,109 @@ mod tests {

Ok(())
}

#[test]
fn test_gen_with_serde() -> io::Result<()> {
let cake_entity = setup().get(0).unwrap().clone();

assert_eq!(cake_entity.get_table_name_snake_case(), "cake");

// Compact code blocks
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/compact_with_serde/cake_none.rs").into(),
WithSerde::None,
),
Box::new(EntityWriter::gen_compact_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/compact_with_serde/cake_serialize.rs").into(),
WithSerde::Serialize,
),
Box::new(EntityWriter::gen_compact_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/compact_with_serde/cake_deserialize.rs").into(),
WithSerde::Deserialize,
),
Box::new(EntityWriter::gen_compact_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/compact_with_serde/cake_both.rs").into(),
WithSerde::Both,
),
Box::new(EntityWriter::gen_compact_code_blocks),
)?;

// Expanded code blocks
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/expanded_with_serde/cake_none.rs").into(),
WithSerde::None,
),
Box::new(EntityWriter::gen_expanded_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/expanded_with_serde/cake_serialize.rs").into(),
WithSerde::Serialize,
),
Box::new(EntityWriter::gen_expanded_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/expanded_with_serde/cake_deserialize.rs").into(),
WithSerde::Deserialize,
),
Box::new(EntityWriter::gen_expanded_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/expanded_with_serde/cake_both.rs").into(),
WithSerde::Both,
),
Box::new(EntityWriter::gen_expanded_code_blocks),
)?;

Ok(())
}

fn assert_serde_variant_results(
cake_entity: &Entity,
entity_serde_variant: &(String, WithSerde),
generator: Box<dyn Fn(&Entity, &WithSerde) -> Vec<TokenStream>>,
) -> io::Result<()> {
let mut reader = BufReader::new(entity_serde_variant.0.as_bytes());
let mut lines: Vec<String> = Vec::new();

reader.read_until(b'\n', &mut Vec::new())?;

let mut line = String::new();
while reader.read_line(&mut line)? > 0 {
lines.push(line.to_owned());
line.clear();
}
let content = lines.join("");
let expected: TokenStream = content.parse().unwrap();
let generated = generator(&cake_entity, &entity_serde_variant.1)
.into_iter()
.fold(TokenStream::new(), |mut acc, tok| {
acc.extend(tok);
acc
});

assert_eq!(expected.to_string(), generated.to_string());
Ok(())
}
}
Loading