Skip to content

Commit

Permalink
Merge pull request #237 from elbart/master
Browse files Browse the repository at this point in the history
Introduce optional serde support for model code generation
  • Loading branch information
tyt2y3 authored Oct 14, 2021
2 parents 75882b3 + d930612 commit fad881b
Show file tree
Hide file tree
Showing 11 changed files with 688 additions and 23 deletions.
7 changes: 7 additions & 0 deletions sea-orm-cli/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ pub fn build_cli() -> App<'static, 'static> {
.help("Generate entity file of compact format")
.takes_value(false)
.conflicts_with("EXPANDED_FORMAT"),
)
.arg(
Arg::with_name("WITH_SERDE")
.long("with-serde")
.help("Automatically derive serde Serialize / Deserialize traits for the entity (none, serialize, deserialize, both)")
.takes_value(true)
.default_value("none")
),
)
.setting(AppSettings::SubcommandRequiredElseHelp);
Expand Down
15 changes: 10 additions & 5 deletions sea-orm-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use clap::ArgMatches;
use dotenv::dotenv;
use log::LevelFilter;
use sea_orm_codegen::{EntityTransformer, OutputFile};
use std::{error::Error, fmt::Display, fs, io::Write, path::Path, process::Command};
use sea_orm_codegen::{EntityTransformer, OutputFile, WithSerde};
use std::{error::Error, fmt::Display, fs, io::Write, path::Path, process::Command, str::FromStr};

mod cli;

Expand All @@ -26,13 +26,17 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
let url = args.value_of("DATABASE_URL").unwrap();
let output_dir = args.value_of("OUTPUT_DIR").unwrap();
let include_hidden_tables = args.is_present("INCLUDE_HIDDEN_TABLES");
let tables = args.values_of("TABLES").unwrap_or_default().collect::<Vec<_>>();
let tables = args
.values_of("TABLES")
.unwrap_or_default()
.collect::<Vec<_>>();
let expanded_format = args.is_present("EXPANDED_FORMAT");
let with_serde = args.value_of("WITH_SERDE").unwrap();
let filter_tables = |table: &str| -> bool {
if tables.len() > 0 {
return tables.contains(&table);
}

true
};
let filter_hidden_tables = |table: &str| -> bool {
Expand Down Expand Up @@ -84,7 +88,8 @@ async fn run_generate_command(matches: &ArgMatches<'_>) -> Result<(), Box<dyn Er
panic!("This database is not supported ({})", url)
};

let output = EntityTransformer::transform(table_stmts)?.generate(expanded_format);
let output = EntityTransformer::transform(table_stmts)?
.generate(expanded_format, WithSerde::from_str(with_serde).unwrap());

let dir = Path::new(output_dir);
fs::create_dir_all(dir)?;
Expand Down
235 changes: 217 additions & 18 deletions sea-orm-codegen/src/entity/writer.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::str::FromStr;

use crate::Entity;
use proc_macro2::TokenStream;
use quote::quote;
Expand All @@ -17,25 +19,83 @@ pub struct OutputFile {
pub content: String,
}

#[derive(PartialEq, Debug)]
pub enum WithSerde {
None,
Serialize,
Deserialize,
Both,
}

impl WithSerde {
pub fn extra_derive(&self) -> TokenStream {
let mut extra_derive = match self {
Self::None => {
quote! {}
}
Self::Serialize => {
quote! {
Serialize
}
}
Self::Deserialize => {
quote! {
Deserialize
}
}
Self::Both => {
quote! {
Serialize, Deserialize
}
}
};

if !extra_derive.is_empty() {
extra_derive = quote! { , #extra_derive }
}

extra_derive
}
}

impl FromStr for WithSerde {
type Err = crate::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"none" => Self::None,
"serialize" => Self::Serialize,
"deserialize" => Self::Deserialize,
"both" => Self::Both,
v => {
return Err(crate::Error::TransformError(format!(
"Unsupported enum variant '{}'",
v
)))
}
})
}
}

impl EntityWriter {
pub fn generate(self, expanded_format: bool) -> WriterOutput {
pub fn generate(self, expanded_format: bool, with_serde: WithSerde) -> WriterOutput {
let mut files = Vec::new();
files.extend(self.write_entities(expanded_format));
files.extend(self.write_entities(expanded_format, with_serde));
files.push(self.write_mod());
files.push(self.write_prelude());
WriterOutput { files }
}

pub fn write_entities(&self, expanded_format: bool) -> Vec<OutputFile> {
pub fn write_entities(&self, expanded_format: bool, with_serde: WithSerde) -> Vec<OutputFile> {
self.entities
.iter()
.map(|entity| {
let mut lines = Vec::new();
Self::write_doc_comment(&mut lines);
let code_blocks = if expanded_format {
Self::gen_expanded_code_blocks(entity)
Self::gen_expanded_code_blocks(entity, &with_serde)
} else {
Self::gen_compact_code_blocks(entity)
Self::gen_compact_code_blocks(entity, &with_serde)
};
Self::write(&mut lines, code_blocks);
OutputFile {
Expand Down Expand Up @@ -102,12 +162,12 @@ impl EntityWriter {
lines.push("".to_owned());
}

pub fn gen_expanded_code_blocks(entity: &Entity) -> Vec<TokenStream> {
pub fn gen_expanded_code_blocks(entity: &Entity, with_serde: &WithSerde) -> Vec<TokenStream> {
let mut code_blocks = vec![
Self::gen_import(),
Self::gen_import(with_serde),
Self::gen_entity_struct(),
Self::gen_impl_entity_name(entity),
Self::gen_model_struct(entity),
Self::gen_model_struct(entity, with_serde),
Self::gen_column_enum(entity),
Self::gen_primary_key_enum(entity),
Self::gen_impl_primary_key(entity),
Expand All @@ -121,8 +181,11 @@ impl EntityWriter {
code_blocks
}

pub fn gen_compact_code_blocks(entity: &Entity) -> Vec<TokenStream> {
let mut code_blocks = vec![Self::gen_import(), Self::gen_compact_model_struct(entity)];
pub fn gen_compact_code_blocks(entity: &Entity, with_serde: &WithSerde) -> Vec<TokenStream> {
let mut code_blocks = vec![
Self::gen_import(with_serde),
Self::gen_compact_model_struct(entity, with_serde),
];
let relation_defs = if entity.get_relation_ref_tables_camel_case().is_empty() {
vec![
Self::gen_relation_enum(entity),
Expand All @@ -138,9 +201,33 @@ impl EntityWriter {
code_blocks
}

pub fn gen_import() -> TokenStream {
quote! {
pub fn gen_import(with_serde: &WithSerde) -> TokenStream {
let prelude_import = quote!(
use sea_orm::entity::prelude::*;
);

match with_serde {
WithSerde::None => prelude_import,
WithSerde::Serialize => {
quote! {
#prelude_import
use serde::Serialize;
}
}

WithSerde::Deserialize => {
quote! {
#prelude_import
use serde::Deserialize;
}
}

WithSerde::Both => {
quote! {
#prelude_import
use serde::{Deserialize,Serialize};
}
}
}
}

Expand All @@ -162,11 +249,14 @@ impl EntityWriter {
}
}

pub fn gen_model_struct(entity: &Entity) -> TokenStream {
pub fn gen_model_struct(entity: &Entity, with_serde: &WithSerde) -> TokenStream {
let column_names_snake_case = entity.get_column_names_snake_case();
let column_rs_types = entity.get_column_rs_types();

let extra_derive = with_serde.extra_derive();

quote! {
#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)]
#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel #extra_derive)]
pub struct Model {
#(pub #column_names_snake_case: #column_rs_types,)*
}
Expand Down Expand Up @@ -320,7 +410,7 @@ impl EntityWriter {
}
}

pub fn gen_compact_model_struct(entity: &Entity) -> TokenStream {
pub fn gen_compact_model_struct(entity: &Entity, with_serde: &WithSerde) -> TokenStream {
let table_name = entity.table_name.as_str();
let column_names_snake_case = entity.get_column_names_snake_case();
let column_rs_types = entity.get_column_rs_types();
Expand Down Expand Up @@ -365,8 +455,11 @@ impl EntityWriter {
}
})
.collect();

let extra_derive = with_serde.extra_derive();

quote! {
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[derive(Clone, Debug, PartialEq, DeriveEntityModel #extra_derive)]
#[sea_orm(table_name = #table_name)]
pub struct Model {
#(
Expand Down Expand Up @@ -396,6 +489,7 @@ impl EntityWriter {
mod tests {
use crate::{
Column, ConjunctRelation, Entity, EntityWriter, PrimaryKey, Relation, RelationType,
WithSerde,
};
use pretty_assertions::assert_eq;
use proc_macro2::TokenStream;
Expand Down Expand Up @@ -693,7 +787,7 @@ mod tests {
}
let content = lines.join("");
let expected: TokenStream = content.parse().unwrap();
let generated = EntityWriter::gen_expanded_code_blocks(entity)
let generated = EntityWriter::gen_expanded_code_blocks(entity, &crate::WithSerde::None)
.into_iter()
.skip(1)
.fold(TokenStream::new(), |mut acc, tok| {
Expand Down Expand Up @@ -733,7 +827,7 @@ mod tests {
}
let content = lines.join("");
let expected: TokenStream = content.parse().unwrap();
let generated = EntityWriter::gen_compact_code_blocks(entity)
let generated = EntityWriter::gen_compact_code_blocks(entity, &crate::WithSerde::None)
.into_iter()
.skip(1)
.fold(TokenStream::new(), |mut acc, tok| {
Expand All @@ -745,4 +839,109 @@ mod tests {

Ok(())
}

#[test]
fn test_gen_with_serde() -> io::Result<()> {
let cake_entity = setup().get(0).unwrap().clone();

assert_eq!(cake_entity.get_table_name_snake_case(), "cake");

// Compact code blocks
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/compact_with_serde/cake_none.rs").into(),
WithSerde::None,
),
Box::new(EntityWriter::gen_compact_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/compact_with_serde/cake_serialize.rs").into(),
WithSerde::Serialize,
),
Box::new(EntityWriter::gen_compact_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/compact_with_serde/cake_deserialize.rs").into(),
WithSerde::Deserialize,
),
Box::new(EntityWriter::gen_compact_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/compact_with_serde/cake_both.rs").into(),
WithSerde::Both,
),
Box::new(EntityWriter::gen_compact_code_blocks),
)?;

// Expanded code blocks
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/expanded_with_serde/cake_none.rs").into(),
WithSerde::None,
),
Box::new(EntityWriter::gen_expanded_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/expanded_with_serde/cake_serialize.rs").into(),
WithSerde::Serialize,
),
Box::new(EntityWriter::gen_expanded_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/expanded_with_serde/cake_deserialize.rs").into(),
WithSerde::Deserialize,
),
Box::new(EntityWriter::gen_expanded_code_blocks),
)?;
assert_serde_variant_results(
&cake_entity,
&(
include_str!("../../tests/expanded_with_serde/cake_both.rs").into(),
WithSerde::Both,
),
Box::new(EntityWriter::gen_expanded_code_blocks),
)?;

Ok(())
}

fn assert_serde_variant_results(
cake_entity: &Entity,
entity_serde_variant: &(String, WithSerde),
generator: Box<dyn Fn(&Entity, &WithSerde) -> Vec<TokenStream>>,
) -> io::Result<()> {
let mut reader = BufReader::new(entity_serde_variant.0.as_bytes());
let mut lines: Vec<String> = Vec::new();

reader.read_until(b'\n', &mut Vec::new())?;

let mut line = String::new();
while reader.read_line(&mut line)? > 0 {
lines.push(line.to_owned());
line.clear();
}
let content = lines.join("");
let expected: TokenStream = content.parse().unwrap();
let generated = generator(&cake_entity, &entity_serde_variant.1)
.into_iter()
.fold(TokenStream::new(), |mut acc, tok| {
acc.extend(tok);
acc
});

assert_eq!(expected.to_string(), generated.to_string());
Ok(())
}
}
Loading

0 comments on commit fad881b

Please sign in to comment.