Skip to content

Commit

Permalink
start, add ByteStr type
Browse files Browse the repository at this point in the history
  • Loading branch information
ParkMyCar committed Feb 2, 2024
1 parent c893677 commit 080397e
Show file tree
Hide file tree
Showing 14 changed files with 552 additions and 31 deletions.
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[package]
name = "prost"
version = "0.11.9"
version = "0.11.10"
authors = [
"Dan Burkert <dan@danburkert.com>",
"Lucio Franco <luciofranco14@gmail.com",
"Tokio Contributors <team@tokio.rs>",
"Dan Burkert <dan@danburkert.com>",
"Lucio Franco <luciofranco14@gmail.com",
"Tokio Contributors <team@tokio.rs>",
]
license = "Apache-2.0"
repository = "https://github.com/tokio-rs/prost"
Expand Down Expand Up @@ -48,7 +48,7 @@ std = []

[dependencies]
bytes = { version = "1", default-features = false }
prost-derive = { version = "0.11.9", path = "prost-derive", optional = true }
prost-derive = { version = "0.11.10", path = "prost-derive", optional = true }

[dev-dependencies]
criterion = "0.3"
Expand Down
6 changes: 3 additions & 3 deletions prost-build/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "prost-build"
version = "0.11.9"
version = "0.11.10"
authors = [
"Dan Burkert <dan@danburkert.com>",
"Lucio Franco <luciofranco14@gmail.com>",
Expand All @@ -27,8 +27,8 @@ itertools = { version = "0.10", default-features = false, features = ["use_alloc
log = "0.4"
multimap = { version = "0.8", default-features = false }
petgraph = { version = "0.6", default-features = false }
prost = { version = "0.11.9", path = "..", default-features = false }
prost-types = { version = "0.11.9", path = "../prost-types", default-features = false }
prost = { version = "0.11.10", path = "..", default-features = false }
prost-types = { version = "0.11.10", path = "../prost-types", default-features = false }
tempfile = "3"
lazy_static = "1.4.0"
regex = { version = "1.5.5", default-features = false, features = ["std", "unicode-bool"] }
Expand Down
44 changes: 40 additions & 4 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::ast::{Comments, Method, Service};
use crate::extern_paths::ExternPaths;
use crate::ident::{to_snake, to_upper_camel};
use crate::message_graph::MessageGraph;
use crate::{BytesType, Config, MapType};
use crate::{BytesType, Config, MapType, StringType};

#[derive(PartialEq)]
enum Syntax {
Expand Down Expand Up @@ -350,6 +350,17 @@ impl<'a> CodeGenerator<'a> {
.push_str(&format!("={:?}", bytes_type.annotation()));
}

if type_ == Type::String {
let string_type = self
.config
.string_type
.get_first_field(fq_message_name, field.name())
.copied()
.unwrap_or_default();
self.buf
.push_str(&format!("={:?}", string_type.annotation()));
}

match field.label() {
Label::Optional => {
if optional {
Expand Down Expand Up @@ -862,8 +873,6 @@ impl<'a> CodeGenerator<'a> {
}

fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
let prost_path = self.config.prost_path.as_deref().unwrap_or("::prost");

match field.r#type() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Expand All @@ -872,7 +881,14 @@ impl<'a> CodeGenerator<'a> {
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String => format!("{}::alloc::string::String", prost_path),
Type::String => self
.config
.string_type
.get_first_field(fq_message_name, field.name())
.copied()
.unwrap_or_default()
.rust_type()
.to_owned(),
Type::Bytes => self
.config
.bytes_type
Expand Down Expand Up @@ -1212,6 +1228,26 @@ impl BytesType {
}
}

impl StringType {
/// The `prost-derive` annotation type corresponding to the bytes type.
fn annotation(&self) -> &'static str {
match self {
StringType::String => "string",
StringType::ByteStr => "byte_str",
StringType::ByteStrUnchecked => "byte_str_unchecked",
}
}

/// The fully-qualified Rust type corresponding to the bytes type.
fn rust_type(&self) -> &'static str {
match self {
StringType::String => "::prost::alloc::string::String",
StringType::ByteStr => "::prost::str::ByteStr",
StringType::ByteStrUnchecked => "::prost::str::ByteStrUnchecked",
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
35 changes: 35 additions & 0 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ impl Default for BytesType {
}
}

#[derive(Clone, Copy, Debug, Default, PartialEq)]
enum StringType {
#[default]
String,
ByteStr,
ByteStrUnchecked,
}

/// Configuration options for Protobuf code generation.
///
/// This configuration builder can be used to set non-default code generation options.
Expand All @@ -243,6 +251,7 @@ pub struct Config {
service_generator: Option<Box<dyn ServiceGenerator>>,
map_type: PathMap<MapType>,
bytes_type: PathMap<BytesType>,
string_type: PathMap<StringType>,
type_attributes: PathMap<String>,
message_attributes: PathMap<String>,
enum_attributes: PathMap<String>,
Expand Down Expand Up @@ -389,6 +398,30 @@ impl Config {
self
}

pub fn bytes_str<I, S>(&mut self, paths: I) -> &mut Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
for matcher in paths {
self.string_type
.insert(matcher.as_ref().to_string(), StringType::ByteStr)
}
self
}

pub fn bytes_str_unchecked<I, S>(&mut self, paths: I) -> &mut Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
for matcher in paths {
self.string_type
.insert(matcher.as_ref().to_string(), StringType::ByteStrUnchecked)
}
self
}

/// Add additional attribute to matched fields.
///
/// # Arguments
Expand Down Expand Up @@ -1232,6 +1265,7 @@ impl default::Default for Config {
service_generator: None,
map_type: PathMap::default(),
bytes_type: PathMap::default(),
string_type: PathMap::default(),
type_attributes: PathMap::default(),
message_attributes: PathMap::default(),
enum_attributes: PathMap::default(),
Expand Down Expand Up @@ -1259,6 +1293,7 @@ impl fmt::Debug for Config {
.field("service_generator", &self.service_generator.is_some())
.field("map_type", &self.map_type)
.field("bytes_type", &self.bytes_type)
.field("string_type", &self.string_type)
.field("type_attributes", &self.type_attributes)
.field("field_attributes", &self.field_attributes)
.field("prost_types", &self.prost_types)
Expand Down
2 changes: 1 addition & 1 deletion prost-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "prost-derive"
version = "0.11.9"
version = "0.11.10"
authors = [
"Dan Burkert <dan@danburkert.com>",
"Lucio Franco <luciofranco14@gmail.com>",
Expand Down
2 changes: 1 addition & 1 deletion prost-derive/src/field/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
| scalar::Ty::Sfixed32
| scalar::Ty::Sfixed64
| scalar::Ty::Bool
| scalar::Ty::String => Ok(ty),
| scalar::Ty::String(_) => Ok(ty),
_ => bail!("invalid map key type: {}", s),
}
}
Expand Down
76 changes: 62 additions & 14 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ impl Field {
Kind::Plain(ref default) | Kind::Required(ref default) => {
let default = default.typed();
match self.ty {
Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
Ty::String(..) | Ty::Bytes(..) => quote!(#ident.clear()),
_ => quote!(#ident = #default),
}
}
Expand Down Expand Up @@ -391,13 +391,14 @@ pub enum Ty {
Sfixed32,
Sfixed64,
Bool,
String,
String(StringTy),
Bytes(BytesTy),
Enumeration(Path),
}

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub enum BytesTy {
#[default]
Vec,
Bytes,
}
Expand All @@ -419,6 +420,41 @@ impl BytesTy {
}
}

#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub enum StringTy {
#[default]
String,
ByteStr,
ByteStrUnchecked,
}

impl StringTy {
fn try_from_str(s: &str) -> Result<Self, Error> {
match s {
"string" => Ok(StringTy::String),
"byte_str" => Ok(StringTy::ByteStr),
"byte_str_unchecked" => Ok(StringTy::ByteStrUnchecked),
_ => bail!("Invalid bytes type: {}", s),
}
}

fn rust_type(&self) -> TokenStream {
match self {
StringTy::String => quote! { ::prost::alloc::string::String },
StringTy::ByteStr => quote! { ::prost::str::ByteStr },
StringTy::ByteStrUnchecked => quote! { ::prost::str::ByteStrUnchecked },
}
}

fn module(&self) -> &'static str {
match self {
StringTy::String => "string",
StringTy::ByteStr => "byte_str",
StringTy::ByteStrUnchecked => "byte_str_unchecked",
}
}
}

impl Ty {
pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> {
let ty = match *attr {
Expand All @@ -435,13 +471,18 @@ impl Ty {
Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32,
Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
Meta::Path(ref name) if name.is_ident("string") => Ty::String,
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
Meta::Path(ref name) if name.is_ident("string") => Ty::String(StringTy::default()),
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::default()),
Meta::NameValue(MetaNameValue {
ref path,
lit: Lit::Str(ref l),
..
}) if path.is_ident("bytes") => Ty::Bytes(BytesTy::try_from_str(&l.value())?),
Meta::NameValue(MetaNameValue {
ref path,
lit: Lit::Str(ref l),
..
}) if path.is_ident("string") => Ty::String(StringTy::try_from_str(&l.value())?),
Meta::NameValue(MetaNameValue {
ref path,
lit: Lit::Str(ref l),
Expand Down Expand Up @@ -485,8 +526,8 @@ impl Ty {
"sfixed32" => Ty::Sfixed32,
"sfixed64" => Ty::Sfixed64,
"bool" => Ty::Bool,
"string" => Ty::String,
"bytes" => Ty::Bytes(BytesTy::Vec),
"string" => Ty::String(StringTy::default()),
"bytes" => Ty::Bytes(BytesTy::default()),
s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
let s = &s[enumeration_len..].trim();
match s.chars().next() {
Expand Down Expand Up @@ -521,7 +562,7 @@ impl Ty {
Ty::Sfixed32 => "sfixed32",
Ty::Sfixed64 => "sfixed64",
Ty::Bool => "bool",
Ty::String => "string",
Ty::String(..) => "string",
Ty::Bytes(..) => "bytes",
Ty::Enumeration(..) => "enum",
}
Expand All @@ -530,7 +571,7 @@ impl Ty {
// TODO: rename to 'owned_type'.
pub fn rust_type(&self) -> TokenStream {
match self {
Ty::String => quote!(::prost::alloc::string::String),
Ty::String(ty) => ty.rust_type(),
Ty::Bytes(ty) => ty.rust_type(),
_ => self.rust_ref_type(),
}
Expand All @@ -552,7 +593,7 @@ impl Ty {
Ty::Sfixed32 => quote!(i32),
Ty::Sfixed64 => quote!(i64),
Ty::Bool => quote!(bool),
Ty::String => quote!(&str),
Ty::String(..) => quote!(&str),
Ty::Bytes(..) => quote!(&[u8]),
Ty::Enumeration(..) => quote!(i32),
}
Expand All @@ -561,13 +602,14 @@ impl Ty {
pub fn module(&self) -> Ident {
match *self {
Ty::Enumeration(..) => Ident::new("int32", Span::call_site()),
Ty::String(ref sty) => Ident::new(sty.module(), Span::call_site()),
_ => Ident::new(self.as_str(), Span::call_site()),
}
}

/// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
pub fn is_numeric(&self) -> bool {
!matches!(self, Ty::String | Ty::Bytes(..))
!matches!(self, Ty::String(..) | Ty::Bytes(..))
}
}

Expand Down Expand Up @@ -659,7 +701,13 @@ impl DefaultValue {
Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?),

Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value),
Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()),
Lit::Str(ref lit)
if *ty == Ty::String(StringTy::String)
|| *ty == Ty::String(StringTy::ByteStr)
|| *ty == Ty::String(StringTy::ByteStrUnchecked) =>
{
DefaultValue::String(lit.value())
}
Lit::ByteStr(ref lit)
if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) =>
{
Expand Down Expand Up @@ -768,7 +816,7 @@ impl DefaultValue {
Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0),

Ty::Bool => DefaultValue::Bool(false),
Ty::String => DefaultValue::String(String::new()),
Ty::String(..) => DefaultValue::String(String::new()),
Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
}
Expand All @@ -777,7 +825,7 @@ impl DefaultValue {
pub fn owned(&self) -> TokenStream {
match *self {
DefaultValue::String(ref value) if value.is_empty() => {
quote!(::prost::alloc::string::String::new())
quote!(::core::default::Default::default())
}
DefaultValue::String(ref value) => quote!(#value.into()),
DefaultValue::Bytes(ref value) if value.is_empty() => {
Expand Down
Loading

0 comments on commit 080397e

Please sign in to comment.