Skip to content

Commit 516e1d3

Browse files
authored
der_derive: fix derive(BitString): always encode max length (#1733)
Encoded length is now data-independent
1 parent 5d903ea commit 516e1d3

File tree

5 files changed

+138
-35
lines changed

5 files changed

+138
-35
lines changed

der/src/asn1/bit_string.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! ASN.1 `BIT STRING` support.
22
3-
pub mod fixed_len_bit_string;
3+
pub mod allowed_len_bit_string;
44

55
use crate::{
66
BytesRef, DecodeValue, DerOrd, EncodeValue, Error, ErrorKind, FixedTag, Header, Length, Reader,

der/src/asn1/bit_string/fixed_len_bit_string.rs der/src/asn1/bit_string/allowed_len_bit_string.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ use crate::{Error, ErrorKind, Tag};
2727
/// flag4: bool,
2828
/// }
2929
/// ```
30-
pub trait FixedLenBitString {
30+
pub trait AllowedLenBitString {
3131
/// Implementer must specify how many bits are allowed
3232
const ALLOWED_LEN_RANGE: RangeInclusive<u16>;
3333

3434
/// Returns an error if the bitstring is not in expected length range
3535
fn check_bit_len(bit_len: u16) -> Result<(), Error> {
3636
let allowed_len_range = Self::ALLOWED_LEN_RANGE;
3737

38-
// forces allowed range to eg. 3..=4
38+
// forces allowed range to e.g. 3..=4
3939
if !allowed_len_range.contains(&bit_len) {
4040
Err(ErrorKind::Length {
4141
tag: Tag::BitString,

der/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ mod document;
365365
mod str_owned;
366366

367367
pub use crate::{
368-
asn1::bit_string::fixed_len_bit_string::FixedLenBitString,
368+
asn1::bit_string::allowed_len_bit_string::AllowedLenBitString,
369369
asn1::{AnyRef, Choice, Sequence},
370370
datetime::DateTime,
371371
decode::{Decode, DecodeOwned, DecodeValue},

der/tests/derive.rs

+108-5
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,9 @@ mod bitstring {
778778
assert_eq!(reencoded, BITSTRING_EXAMPLE);
779779
}
780780

781-
/// this BitString will allow only 3..=4 bits
781+
/// this BitString will allow only 3..=4 bits in Decode
782+
///
783+
/// but will always Encode 4 bits
782784
#[derive(BitString)]
783785
pub struct MyBitString3or4 {
784786
pub bit_0: bool,
@@ -844,8 +846,8 @@ mod bitstring {
844846
.to_der()
845847
.unwrap();
846848

847-
// 3 bits used, 5 unused
848-
assert_eq!(encoded_3_zeros, hex!("03 02 05 00"));
849+
// 4 bits used, 4 unused
850+
assert_eq!(encoded_3_zeros, hex!("03 02 04 00"));
849851
}
850852

851853
#[test]
@@ -859,8 +861,8 @@ mod bitstring {
859861
.to_der()
860862
.unwrap();
861863

862-
// 3 bits used, 5 unused
863-
assert_eq!(encoded_3_zeros, hex!("03 02 05 E0"));
864+
// 4 bits used, 4 unused
865+
assert_eq!(encoded_3_zeros, hex!("03 02 04 E0"));
864866
}
865867

866868
#[test]
@@ -892,6 +894,107 @@ mod bitstring {
892894
// 4 bits used, 4 unused
893895
assert_eq!(encoded_4_zeros, hex!("03 02 04 10"));
894896
}
897+
898+
/// ```asn1
899+
/// PasswordFlags ::= BIT STRING {
900+
/// case-sensitive (0),
901+
/// local (1),
902+
/// change-disabled (2),
903+
/// unblock-disabled (3),
904+
/// initialized (4),
905+
/// needs-padding (5),
906+
/// unblockingPassword (6),
907+
/// soPassword (7),
908+
/// disable-allowed (8),
909+
/// integrity-protected (9),
910+
/// confidentiality-protected (10),
911+
/// exchangeRefData (11),
912+
/// resetRetryCounter1 (12),
913+
/// resetRetryCounter2 (13),
914+
/// context-dependent (14),
915+
/// multiStepProtocol (15)
916+
/// }
917+
/// ```
918+
#[derive(Clone, Debug, Eq, PartialEq, BitString)]
919+
pub struct PasswordFlags {
920+
/// case-sensitive (0)
921+
pub case_sensitive: bool,
922+
923+
/// local (1)
924+
pub local: bool,
925+
926+
/// change-disabled (2)
927+
pub change_disabled: bool,
928+
929+
/// unblock-disabled (3)
930+
pub unblock_disabled: bool,
931+
932+
/// initialized (4)
933+
pub initialized: bool,
934+
935+
/// needs-padding (5)
936+
pub needs_padding: bool,
937+
938+
/// unblockingPassword (6)
939+
pub unblocking_password: bool,
940+
941+
/// soPassword (7)
942+
pub so_password: bool,
943+
944+
/// disable-allowed (8)
945+
pub disable_allowed: bool,
946+
947+
/// integrity-protected (9)
948+
pub integrity_protected: bool,
949+
950+
/// confidentiality-protected (10)
951+
pub confidentiality_protected: bool,
952+
953+
/// exchangeRefData (11)
954+
pub exchange_ref_data: bool,
955+
956+
/// Second edition 2016-05-15
957+
/// resetRetryCounter1 (12)
958+
#[asn1(optional = "true")]
959+
pub reset_retry_counter1: bool,
960+
961+
/// resetRetryCounter2 (13)
962+
#[asn1(optional = "true")]
963+
pub reset_retry_counter2: bool,
964+
965+
/// context-dependent (14)
966+
#[asn1(optional = "true")]
967+
pub context_dependent: bool,
968+
969+
/// multiStepProtocol (15)
970+
#[asn1(optional = "true")]
971+
pub multi_step_protocol: bool,
972+
973+
/// fake_bit_for_testing (16)
974+
#[asn1(optional = "true")]
975+
pub fake_bit_for_testing: bool,
976+
}
977+
978+
const PASS_FLAGS_EXAMPLE_IN: &[u8] = &hex!("03 03 04 FF FF");
979+
const PASS_FLAGS_EXAMPLE_OUT: &[u8] = &hex!("03 04 07 FF F0 00");
980+
981+
#[test]
982+
fn decode_short_bitstring_2_bytes() {
983+
let pass_flags = PasswordFlags::from_der(PASS_FLAGS_EXAMPLE_IN).unwrap();
984+
985+
// case-sensitive (0)
986+
assert!(pass_flags.case_sensitive);
987+
988+
// exchangeRefData (11)
989+
assert!(pass_flags.exchange_ref_data);
990+
991+
// resetRetryCounter1 (12)
992+
assert!(!pass_flags.reset_retry_counter1);
993+
994+
let reencoded = pass_flags.to_der().unwrap();
995+
996+
assert_eq!(reencoded, PASS_FLAGS_EXAMPLE_OUT);
997+
}
895998
}
896999
mod infer_default {
8971000
//! When another crate might define a PartialEq for another type, the use of

der_derive/src/bitstring.rs

+26-26
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,26 @@ impl DeriveBitString {
3535

3636
let type_attrs = TypeAttrs::parse(&input.attrs)?;
3737

38-
let fields = data
38+
let fields: Vec<_> = data
3939
.fields
4040
.iter()
4141
.map(|field| BitStringField::new(field, &type_attrs))
4242
.collect::<syn::Result<_>>()?;
4343

44+
let mut started_optionals = false;
45+
for field in &fields {
46+
if !field.attrs.optional {
47+
if started_optionals {
48+
abort!(
49+
input.ident,
50+
"derive `BitString` only supports trailing optional fields one after another",
51+
)
52+
}
53+
} else {
54+
started_optionals = true;
55+
}
56+
}
57+
4458
Ok(Self {
4559
ident: input.ident,
4660
generics: input.generics.clone(),
@@ -75,14 +89,18 @@ impl DeriveBitString {
7589

7690
let mut min_expected_fields: u16 = 0;
7791
let mut max_expected_fields: u16 = 0;
92+
let mut started_optionals = false;
7893
for field in &self.fields {
7994
max_expected_fields += 1;
8095

81-
if !field.attrs.optional {
96+
if field.attrs.optional {
97+
started_optionals = true;
98+
}
99+
if !started_optionals {
82100
min_expected_fields += 1;
83101
}
84102
}
85-
let min_expected_bytes = (min_expected_fields + 7) / 8;
103+
let max_expected_bytes = (max_expected_fields + 7) / 8;
86104

87105
for (i, field) in self.fields.iter().enumerate().rev() {
88106
let field_name = &field.ident;
@@ -115,7 +133,7 @@ impl DeriveBitString {
115133
impl ::der::FixedTag for #ident #ty_generics #where_clause {
116134
const TAG: der::Tag = ::der::Tag::BitString;
117135
}
118-
impl ::der::FixedLenBitString for #ident #ty_generics #where_clause {
136+
impl ::der::AllowedLenBitString for #ident #ty_generics #where_clause {
119137
const ALLOWED_LEN_RANGE: ::core::ops::RangeInclusive<u16> = #min_expected_fields..=#max_expected_fields;
120138
}
121139

@@ -127,7 +145,7 @@ impl DeriveBitString {
127145
header: ::der::Header,
128146
) -> ::core::result::Result<Self, ::der::Error> {
129147
use ::der::{Decode as _, DecodeValue as _, Reader as _};
130-
use ::der::FixedLenBitString as _;
148+
use ::der::AllowedLenBitString as _;
131149

132150

133151
let bs = ::der::asn1::BitStringRef::decode_value(reader, header)?;
@@ -147,33 +165,15 @@ impl DeriveBitString {
147165

148166
impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause {
149167
fn value_len(&self) -> der::Result<der::Length> {
150-
Ok(der::Length::new(#min_expected_bytes + 1))
168+
Ok(der::Length::new(#max_expected_bytes + 1))
151169
}
152170

153171
fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> {
154172
use ::der::Encode as _;
155-
use der::FixedLenBitString as _;
173+
use ::der::AllowedLenBitString as _;
156174

157175
let arr = [#(#encode_bytes),*];
158-
159-
let min_bits = {
160-
let max_bits = *Self::ALLOWED_LEN_RANGE.end();
161-
let last_byte_bits = (max_bits % 8) as u8;
162-
let bs = ::der::asn1::BitStringRef::new(8 - last_byte_bits, &arr)?;
163-
164-
let mut min_bits = *Self::ALLOWED_LEN_RANGE.start();
165-
166-
// find last lit bit
167-
for bit_index in Self::ALLOWED_LEN_RANGE.rev() {
168-
if bs.get(bit_index as usize).unwrap_or_default() {
169-
min_bits = bit_index + 1;
170-
break;
171-
}
172-
}
173-
min_bits
174-
};
175-
176-
let last_byte_bits = (min_bits % 8) as u8;
176+
let last_byte_bits = (#max_expected_fields % 8) as u8;
177177
let bs = ::der::asn1::BitStringRef::new(8 - last_byte_bits, &arr)?;
178178
bs.encode_value(writer)
179179
}

0 commit comments

Comments
 (0)