Skip to content

Commit

Permalink
fix bincode deserialization (#110)
Browse files Browse the repository at this point in the history
* implement and use a new PublicKeyBytesVisitor for deserializing bytes to a PublicKey
* add a test to cover bincode serialization / deserialization
* return the base64 decoding error instead of unwrapping
  • Loading branch information
steverusso authored Apr 20, 2022
1 parent 91aaa91 commit a95288e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ secp256k1-test = { package = "secp256k1", version = "0.20.3", features = ["rand-
clear_on_drop = "0.2"
serde_json = "1.0"
hex-literal = "0.3.3"
bincode = "1.3.3"

[build-dependencies]
libsecp256k1-gen-ecmult = { version = "0.3.0", path = "gen/ecmult" }
Expand Down
31 changes: 26 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,10 @@ impl Serialize for PublicKey {
}

#[cfg(feature = "std")]
struct PublicKeyVisitor;
struct PublicKeyStrVisitor;

#[cfg(feature = "std")]
impl<'de> de::Visitor<'de> for PublicKeyVisitor {
impl<'de> de::Visitor<'de> for PublicKeyStrVisitor {
type Value = PublicKey;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
Expand All @@ -360,7 +360,7 @@ impl<'de> de::Visitor<'de> for PublicKeyVisitor {
where
E: de::Error,
{
let value: &[u8] = &base64::decode(value).unwrap();
let value: &[u8] = &base64::decode(value).map_err(|e| E::custom(e))?;
let key_format = match value.len() {
33 => PublicKeyFormat::Compressed,
64 => PublicKeyFormat::Raw,
Expand All @@ -372,16 +372,37 @@ impl<'de> de::Visitor<'de> for PublicKeyVisitor {
}
}

#[cfg(feature = "std")]
struct PublicKeyBytesVisitor;

#[cfg(feature = "std")]
impl<'de> de::Visitor<'de> for PublicKeyBytesVisitor {
type Value = PublicKey;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str(
"a byte slice that is either 33 (compressed), 64 (raw), or 65 bytes in length",
)
}

fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
PublicKey::parse_slice(value, None).map_err(|_e| E::custom(Error::InvalidPublicKey))
}
}

#[cfg(feature = "std")]
impl<'de> Deserialize<'de> for PublicKey {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: de::Deserializer<'de>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(PublicKeyVisitor)
deserializer.deserialize_str(PublicKeyStrVisitor)
} else {
deserializer.deserialize_bytes(PublicKeyVisitor)
deserializer.deserialize_bytes(PublicKeyBytesVisitor)
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,11 @@ fn test_deserialize_public_key() {
let pkey: PublicKey = serde_json::from_str(&SERIALIZED_DEBUG_PUBLIC_KEY).unwrap();
assert_eq!(pkey, debug_public_key());
}

#[test]
fn test_public_key_bincode_serde() {
let pkey = debug_public_key();
let serialized_pkey: Vec<u8> = bincode::serialize(&pkey).unwrap();
let pkey2 = bincode::deserialize(&serialized_pkey).unwrap();
assert_eq!(pkey, pkey2);
}

0 comments on commit a95288e

Please sign in to comment.