diff --git a/Cargo.toml b/Cargo.toml index 01e7bbc..75db9e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,11 @@ rust-version = "1.47.0" name = "benchmarks" harness = false +[[bench]] +name = "avx2benchmarks" +harness = false +required-features = ["avx2"] + [dev-dependencies] criterion = "0.3.4" rand = "0.6.1" @@ -23,6 +28,7 @@ structopt = "0.3.21" # test fixtures for engine tests rstest = "0.11.0" rstest_reuse = "0.1.3" +lazy_static = "1.4.0" [features] default = ["std"] diff --git a/benches/avx2benchmarks.rs b/benches/avx2benchmarks.rs new file mode 100644 index 0000000..699199f --- /dev/null +++ b/benches/avx2benchmarks.rs @@ -0,0 +1,617 @@ +extern crate base64; +#[macro_use] +extern crate criterion; +extern crate rand; + +#[macro_use] +extern crate lazy_static; + +use std::ops::Deref; + +use base64::display; +use base64::{ + decode_engine, decode_engine_slice, decode_engine_vec, encode_engine, encode_engine_slice, + encode_engine_string, write, engine::DEFAULT_ENGINE, +}; + +use base64::engine::avx2::{AVX2Encoder, AVX2Config}; +use criterion::{black_box, Bencher, BenchmarkId, Criterion, Throughput}; +use rand::{FromEntropy, Rng}; +use std::io::{self, Read, Write}; + +lazy_static! { + static ref AVX2_ENGINE: AVX2Encoder = AVX2Encoder::from_standard(AVX2Config::new()); +} + +fn do_decode_bench(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4); + fill(&mut v); + let encoded = encode_engine(&v, &DEFAULT_ENGINE); + + b.iter(|| { + let orig = decode_engine(&encoded, &DEFAULT_ENGINE); + black_box(&orig); + }); +} +fn do_decode_bench_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4); + fill(&mut v); + let encoded = encode_engine(&v, AVX2_ENGINE.deref()); + + b.iter(|| { + let orig = decode_engine(&encoded, AVX2_ENGINE.deref()); + black_box(&orig); + }); +} + +fn do_decode_bench_reuse_buf(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4); + fill(&mut v); + let encoded = encode_engine(&v, &DEFAULT_ENGINE); + + let mut buf = Vec::new(); + b.iter(|| { + decode_engine_vec(&encoded, &mut buf, &DEFAULT_ENGINE).unwrap(); + black_box(&buf); + buf.clear(); + }); +} + +fn do_decode_bench_reuse_buf_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4); + fill(&mut v); + let encoded = encode_engine(&v, AVX2_ENGINE.deref()); + + let mut buf = Vec::new(); + b.iter(|| { + decode_engine_vec(&encoded, &mut buf, AVX2_ENGINE.deref()).unwrap(); + black_box(&buf); + buf.clear(); + }); +} + +fn do_decode_bench_slice(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4); + fill(&mut v); + let encoded = encode_engine(&v, &DEFAULT_ENGINE); + + let mut buf = Vec::new(); + buf.resize(size, 0); + b.iter(|| { + decode_engine_slice(&encoded, &mut buf, &DEFAULT_ENGINE).unwrap(); + black_box(&buf); + }); +} + +fn do_decode_bench_slice_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4); + fill(&mut v); + let encoded = encode_engine(&v, AVX2_ENGINE.deref()); + + let mut buf = Vec::new(); + buf.resize(size, 0); + b.iter(|| { + decode_engine_slice(&encoded, &mut buf, AVX2_ENGINE.deref()).unwrap(); + black_box(&buf); + }); +} + +fn do_decode_bench_stream(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4); + fill(&mut v); + let encoded = encode_engine(&v, &DEFAULT_ENGINE); + + let mut buf = Vec::new(); + buf.resize(size, 0); + buf.truncate(0); + + b.iter(|| { + let mut cursor = io::Cursor::new(&encoded[..]); + let mut decoder = base64::read::DecoderReader::from(&mut cursor, &DEFAULT_ENGINE); + decoder.read_to_end(&mut buf).unwrap(); + buf.clear(); + black_box(&buf); + }); +} + +fn do_decode_bench_stream_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4); + fill(&mut v); + let encoded = encode_engine(&v, AVX2_ENGINE.deref()); + + let mut buf = Vec::new(); + buf.resize(size, 0); + buf.truncate(0); + + b.iter(|| { + let mut cursor = io::Cursor::new(&encoded[..]); + let mut decoder = base64::read::DecoderReader::from(&mut cursor, AVX2_ENGINE.deref()); + decoder.read_to_end(&mut buf).unwrap(); + buf.clear(); + black_box(&buf); + }); +} + +fn do_encode_bench(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + b.iter(|| { + let e = encode_engine(&v, &DEFAULT_ENGINE); + black_box(&e); + }); +} + +fn do_encode_bench_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + b.iter(|| { + let e = encode_engine(&v, AVX2_ENGINE.deref()); + black_box(&e); + }); +} + +fn do_encode_bench_display(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + b.iter(|| { + let e = format!("{}", display::Base64Display::from(&v, &DEFAULT_ENGINE)); + black_box(&e); + }); +} + +fn do_encode_bench_display_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + b.iter(|| { + let e = format!("{}", display::Base64Display::from(&v, AVX2_ENGINE.deref())); + black_box(&e); + }); +} + +fn do_encode_bench_reuse_buf(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + let mut buf = String::new(); + b.iter(|| { + encode_engine_string(&v, &mut buf, &DEFAULT_ENGINE); + buf.clear(); + }); +} + +fn do_encode_bench_reuse_buf_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + let mut buf = String::new(); + b.iter(|| { + encode_engine_string(&v, &mut buf, AVX2_ENGINE.deref()); + buf.clear(); + }); +} + +fn do_encode_bench_slice(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + let mut buf = Vec::new(); + // conservative estimate of encoded size + buf.resize(v.len() * 2, 0); + b.iter(|| { + encode_engine_slice(&v, &mut buf, &DEFAULT_ENGINE); + }); +} + +fn do_encode_bench_slice_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + let mut buf = Vec::new(); + // conservative estimate of encoded size + buf.resize(v.len() * 2, 0); + b.iter(|| { + encode_engine_slice(&v, &mut buf, AVX2_ENGINE.deref()); + }); +} + +fn do_encode_bench_stream(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + let mut buf = Vec::new(); + + buf.reserve(size * 2); + b.iter(|| { + buf.clear(); + let mut stream_enc = write::EncoderWriter::from(&mut buf, &DEFAULT_ENGINE); + stream_enc.write_all(&v).unwrap(); + stream_enc.flush().unwrap(); + }); +} + +fn do_encode_bench_stream_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + let mut buf = Vec::new(); + + buf.reserve(size * 2); + b.iter(|| { + buf.clear(); + let mut stream_enc = write::EncoderWriter::from(&mut buf, AVX2_ENGINE.deref()); + stream_enc.write_all(&v).unwrap(); + stream_enc.flush().unwrap(); + }); +} + +fn do_encode_bench_string_stream(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + + b.iter(|| { + let mut stream_enc = write::EncoderStringWriter::from(&DEFAULT_ENGINE); + stream_enc.write_all(&v).unwrap(); + stream_enc.flush().unwrap(); + let _ = stream_enc.into_inner(); + }); +} + +fn do_encode_bench_string_stream_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + + b.iter(|| { + let mut stream_enc = write::EncoderStringWriter::from(AVX2_ENGINE.deref()); + stream_enc.write_all(&v).unwrap(); + stream_enc.flush().unwrap(); + let _ = stream_enc.into_inner(); + }); +} + +fn do_encode_bench_string_reuse_buf_stream(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + + let mut buf = String::new(); + b.iter(|| { + buf.clear(); + let mut stream_enc = write::EncoderStringWriter::from_consumer(&mut buf, &DEFAULT_ENGINE); + stream_enc.write_all(&v).unwrap(); + stream_enc.flush().unwrap(); + let _ = stream_enc.into_inner(); + }); +} + +fn do_encode_bench_string_reuse_buf_stream_avx(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size); + fill(&mut v); + + let mut buf = String::new(); + b.iter(|| { + buf.clear(); + let mut stream_enc = write::EncoderStringWriter::from_consumer(&mut buf, AVX2_ENGINE.deref()); + stream_enc.write_all(&v).unwrap(); + stream_enc.flush().unwrap(); + let _ = stream_enc.into_inner(); + }); +} + +fn fill(v: &mut Vec) { + let cap = v.capacity(); + // weak randomness is plenty; we just want to not be completely friendly to the branch predictor + let mut r = rand::rngs::SmallRng::from_entropy(); + while v.len() < cap { + v.push(r.gen::()); + } +} + +const BYTE_SIZES: [usize; 5] = [3, 50, 100, 500, 3 * 1024]; + +// Benchmarks over these byte sizes take longer so we will run fewer samples to +// keep the benchmark runtime reasonable. +const LARGE_BYTE_SIZES: [usize; 3] = [3 * 1024 * 1024, 10 * 1024 * 1024, 30 * 1024 * 1024]; + +fn encode_benchmarks(c: &mut Criterion, label: &str, byte_sizes: &[usize]) { + { + + let mut group_dec = c.benchmark_group(label); + group_dec + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)); + for size in byte_sizes { + group_dec + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input(BenchmarkId::new("encode", size), size, do_encode_bench) + .bench_with_input(BenchmarkId::new("encode_avx", size), size, do_encode_bench_avx); + } + group_dec.finish(); + } + + { + + let mut dis = String::from(label); + dis.push_str("_display"); + let mut group_dis = c.benchmark_group(dis); + group_dis + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)); + for size in byte_sizes { + group_dis + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input( + BenchmarkId::new("encode_display", size), + size, + do_encode_bench_display, + ) + .bench_with_input( + BenchmarkId::new("encode_display_avx", size), + size, + do_encode_bench_display_avx, + ); + } + group_dis.finish(); + } + + { + let mut reu = String::from(label); + reu.push_str("_reuse"); + let mut group_reu = c.benchmark_group(reu); + group_reu + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)); + for size in byte_sizes { + group_reu + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input( + BenchmarkId::new("encode_reuse_buf", size), + size, + do_encode_bench_reuse_buf, + ) + .bench_with_input( + BenchmarkId::new("encode_reuse_buf_avx", size), + size, + do_encode_bench_reuse_buf_avx, + ); + } + group_reu.finish(); + } + + { + + let mut sli = String::from(label); + sli.push_str("_slice"); + let mut group_sli = c.benchmark_group(sli); + group_sli + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)); + for size in byte_sizes { + group_sli + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input( + BenchmarkId::new("encode_slice", size), + size, + do_encode_bench_slice, + ) + .bench_with_input( + BenchmarkId::new("encode_slice_avx", size), + size, + do_encode_bench_slice_avx, + ); + } + group_sli.finish(); + } + + { + + let mut str_ = String::from(label); + str_.push_str("_stream"); + let mut group_str = c.benchmark_group(str_); + group_str + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)); + for size in byte_sizes { + group_str + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input( + BenchmarkId::new("encode_string_stream", size), + size, + do_encode_bench_string_stream, + ) + .bench_with_input( + BenchmarkId::new("encode_string_stream_avx", size), + size, + do_encode_bench_string_stream_avx, + ); + } + group_str.finish(); + } + + { + + let mut buf = String::from(label); + buf.push_str("_buf"); + let mut group_buf = c.benchmark_group(buf); + group_buf + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)); + for size in byte_sizes { + group_buf + .bench_with_input( + BenchmarkId::new("encode_reuse_buf_stream", size), + size, + do_encode_bench_stream, + ) + .bench_with_input( + BenchmarkId::new("encode_reuse_buf_stream_avx", size), + size, + do_encode_bench_stream_avx, + ); + } + group_buf.finish(); + } + + let mut bufstr = String::from(label); + bufstr.push_str("_bufstream"); + let mut group_bufstr = c.benchmark_group(bufstr); + group_bufstr + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)); + for size in byte_sizes { + group_bufstr + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input( + BenchmarkId::new("encode_string_reuse_buf_stream", size), + size, + do_encode_bench_string_reuse_buf_stream, + ) + .bench_with_input( + BenchmarkId::new("encode_string_reuse_buf_stream_avx", size), + size, + do_encode_bench_string_reuse_buf_stream_avx, + ); + } + group_bufstr.finish(); + + +} + +fn decode_benchmarks(c: &mut Criterion, label: &str, byte_sizes: &[usize]) { + { + let mut group_dec = c.benchmark_group(label); + for size in byte_sizes { + group_dec + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)) + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input(BenchmarkId::new("decode", size), size, do_decode_bench) + .bench_with_input(BenchmarkId::new("decode_avx", size), size, do_decode_bench_avx); + } + group_dec.finish(); + } + { + + let mut reu = String::from(label); + reu.push_str("_reuse"); + let mut group_reu = c.benchmark_group(reu); + + for size in byte_sizes { + group_reu + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)) + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input( + BenchmarkId::new("decode_reuse_buf", size), + size, + do_decode_bench_reuse_buf, + ) + .bench_with_input( + BenchmarkId::new("decode_reuse_buf_avx", size), + size, + do_decode_bench_reuse_buf_avx, + ); + } + + group_reu.finish() + } + { + let mut sli = String::from(label); + sli.push_str("_slice"); + let mut group_sli = c.benchmark_group(sli); + for size in byte_sizes { + group_sli + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)) + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input( + BenchmarkId::new("decode_slice", size), + size, + do_decode_bench_slice, + ) + .bench_with_input( + BenchmarkId::new("decode_slice_avx", size), + size, + do_decode_bench_slice_avx, + ); + } + group_sli.finish(); + } + + let mut str_ = String::from(label); + str_.push_str("_stream"); + let mut group_str = c.benchmark_group(str_); + for size in byte_sizes { + group_str + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)) + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input( + BenchmarkId::new("decode_stream", size), + size, + do_decode_bench_stream, + ) + .bench_with_input( + BenchmarkId::new("decode_stream_avx", size), + size, + do_decode_bench_stream_avx, + ); + } + group_str.finish(); + + +} + +fn do_align_bench(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4 + 32); + fill(&mut v); + + let (pre, aligned_u32, post) = unsafe { + v.align_to_mut::() + }; + let aligned: &[u8] = unsafe { + core::mem::transmute(aligned_u32) + }; + assert!(pre.len() == 0); + assert!(post.len() == 0); + + + let encoded = encode_engine(&v, AVX2_ENGINE.deref()); + + let mut buf = Vec::new(); + buf.resize(size, 0); + b.iter(|| { + decode_engine_slice(&encoded, &mut buf, AVX2_ENGINE.deref()).unwrap(); + black_box(&buf); + }); +} +fn do_unalign_bench(b: &mut Bencher, &size: &usize) { + let mut v: Vec = Vec::with_capacity(size * 3 / 4 + 32); + fill(&mut v); + + let encoded = encode_engine(&v[5..], AVX2_ENGINE.deref()); + + let mut buf = Vec::new(); + buf.resize(size, 0); + b.iter(|| { + decode_engine_slice(&encoded, &mut buf, AVX2_ENGINE.deref()).unwrap(); + black_box(&buf); + }); +} + +fn align_benchmarks(c: &mut Criterion, label: &str, byte_sizes: &[usize]) { + let mut group = c.benchmark_group(label); + for size in byte_sizes { + group + .warm_up_time(std::time::Duration::from_millis(500)) + .measurement_time(std::time::Duration::from_secs(3)) + .throughput(Throughput::Bytes(*size as u64)) + .bench_with_input(BenchmarkId::new("aligned", size), size, do_align_bench) + .bench_with_input(BenchmarkId::new("unaligned", size), size, do_unalign_bench); + } + group.finish(); +} + +fn bench(c: &mut Criterion) { + encode_benchmarks(c, "encode_small_input", &BYTE_SIZES[..]); + encode_benchmarks(c, "encode_large_input", &LARGE_BYTE_SIZES[..]); + decode_benchmarks(c, "decode_small_input", &BYTE_SIZES[..]); + decode_benchmarks(c, "decode_large_input", &LARGE_BYTE_SIZES[..]); + + align_benchmarks(c, "align_benchmark", &LARGE_BYTE_SIZES[..]); +} + +criterion_group!(benches, bench); +criterion_main!(benches); diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 8b74839..e31eed6 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -5,6 +5,9 @@ version = "0.0.1" authors = ["Automatically generated"] publish = false +[features] +fuzz-simd = ["base64/avx2"] + [package.metadata] cargo-fuzz = true @@ -12,8 +15,10 @@ cargo-fuzz = true rand = "0.6.1" rand_pcg = "0.1.1" ring = "0.13.5" + [dependencies.base64] path = ".." + [dependencies.libfuzzer-sys] git = "https://github.com/rust-fuzz/libfuzzer-sys.git" @@ -36,3 +41,18 @@ path = "fuzzers/roundtrip_random_config.rs" [[bin]] name = "decode_random" path = "fuzzers/decode_random.rs" + +[[bin]] +name = "roundtrip_avx2" +path = "fuzzers/roundtrip_avx.rs" +required-features = ["fuzz-simd"] + +[[bin]] +name = "decode_random_avx2" +path = "fuzzers/decode_random_avx.rs" +required-features = ["fuzz-simd"] + +[[bin]] +name = "roundtrip_avx_equivalent" +path = "fuzzers/roundtrip_avx_equivalent.rs" +required-features = ["fuzz-simd"] diff --git a/fuzz/fuzzers/decode_random_avx.rs b/fuzz/fuzzers/decode_random_avx.rs new file mode 100644 index 0000000..362943a --- /dev/null +++ b/fuzz/fuzzers/decode_random_avx.rs @@ -0,0 +1,14 @@ +#![no_main] +#[macro_use] extern crate libfuzzer_sys; +extern crate base64; + +use base64::decode_engine; +use base64::engine::avx2::{AVX2Encoder, AVX2Config}; + +fuzz_target!(|data: &[u8]| { + let engine = AVX2Encoder::from_standard(AVX2Config::new()); + + // The data probably isn't valid base64 input, but as long as it returns an error instead + // of crashing, that's correct behavior. + let _ = decode_engine(&data, &engine); +}); diff --git a/fuzz/fuzzers/roundtrip.rs b/fuzz/fuzzers/roundtrip.rs index 2097f2a..cabaca8 100644 --- a/fuzz/fuzzers/roundtrip.rs +++ b/fuzz/fuzzers/roundtrip.rs @@ -8,4 +8,4 @@ fuzz_target!(|data: &[u8]| { let encoded = base64::encode_engine(&data, &DEFAULT_ENGINE); let decoded = base64::decode_engine(&encoded, &DEFAULT_ENGINE).unwrap(); assert_eq!(data, decoded.as_slice()); -}); +}); \ No newline at end of file diff --git a/fuzz/fuzzers/roundtrip_avx.rs b/fuzz/fuzzers/roundtrip_avx.rs new file mode 100644 index 0000000..0f7c54e --- /dev/null +++ b/fuzz/fuzzers/roundtrip_avx.rs @@ -0,0 +1,12 @@ +#![no_main] +#[macro_use] extern crate libfuzzer_sys; +extern crate base64; + +use base64::engine::avx2::{AVX2Encoder, AVX2Config}; +fuzz_target!(|data: &[u8]| { + let engine = AVX2Encoder::from_standard(AVX2Config::new()); + + let encoded = base64::encode_engine(&data, &engine); + let decoded = base64::decode_engine(&encoded, &engine).unwrap(); + assert_eq!(data, decoded.as_slice()); +}); \ No newline at end of file diff --git a/fuzz/fuzzers/roundtrip_avx_equivalent.rs b/fuzz/fuzzers/roundtrip_avx_equivalent.rs new file mode 100644 index 0000000..d842a38 --- /dev/null +++ b/fuzz/fuzzers/roundtrip_avx_equivalent.rs @@ -0,0 +1,18 @@ +#![no_main] +#[macro_use] extern crate libfuzzer_sys; +extern crate base64; + +use base64::engine::DEFAULT_ENGINE; + +use base64::engine::avx2::{AVX2Encoder, AVX2Config}; +fuzz_target!(|data: &[u8]| { + let avx_engine = AVX2Encoder::from_standard(AVX2Config::new()); + + let avx_encoded = base64::encode_engine(&data, &avx_engine); + let def_decoded = base64::decode_engine(&avx_encoded, &DEFAULT_ENGINE).unwrap(); + let def_encoded = base64::encode_engine(&data, &DEFAULT_ENGINE); + let avx_decoded = base64::decode_engine(&def_encoded, &avx_engine).unwrap(); + + assert_eq!(data, def_decoded.as_slice()); + assert_eq!(data, avx_decoded.as_slice()); +}); \ No newline at end of file diff --git a/src/engine/tests.rs b/src/engine/tests.rs index 9ec70f8..accba04 100644 --- a/src/engine/tests.rs +++ b/src/engine/tests.rs @@ -16,12 +16,26 @@ use crate::{ }; // the case::foo syntax includes the "foo" in the generated test method names +#[template] +#[cfg_attr(feature = "avx2", rstest(engine_wrapper, +case::avx2(avx2_tests::AVX2Wrapper {}), +case::fast_portable(FastPortableWrapper {}), +case::naive(NaiveWrapper {}), +))] +#[cfg_attr(not(feature = "avx2"), rstest(engine_wrapper, +case::fast_portable(FastPortableWrapper {}), +case::naive(NaiveWrapper {}), +))] +// Absolutely all engines +fn all_engines(engine_wrapper: E) {} + #[template] #[rstest(engine_wrapper, case::fast_portable(FastPortableWrapper {}), case::naive(NaiveWrapper {}), )] -fn all_engines(engine_wrapper: E) {} +// Engines that can handle a custom alphabet +fn literate_engines(engine_wrapper: E) {} #[apply(all_engines)] fn rfc_test_vectors_std_alphabet(engine_wrapper: E) { @@ -290,7 +304,7 @@ fn decode_detect_invalid_last_symbol_two_bytes(engine_wrapper: } } -#[apply(all_engines)] +#[apply(literate_engines)] fn decode_detect_invalid_last_symbol_when_length_is_also_invalid( engine_wrapper: E, ) { @@ -489,7 +503,7 @@ fn decode_invalid_trailing_bits_ignored_when_configured(engine } } -#[apply(all_engines)] +#[apply(literate_engines)] fn decode_invalid_byte_error(engine_wrapper: E) { let mut rng = rand::rngs::SmallRng::from_entropy(); @@ -939,6 +953,37 @@ impl EngineWrapper for NaiveWrapper { } } + +#[cfg(feature = "avx2")] +mod avx2_tests { + use super::*; + use crate::engine::avx2; + + pub(super) struct AVX2Wrapper {} + + impl EngineWrapper for AVX2Wrapper { + type Engine = avx2::AVX2Encoder; + + fn standard() -> Self::Engine { + avx2::AVX2Encoder::from_standard(avx2::AVX2Config::default()) + } + + fn standard_forgiving() -> Self::Engine { + avx2::AVX2Encoder::from_standard(avx2::AVX2Config::default() + .with_decode_allow_trailing_bits(true)) + } + + fn random(_rng: &mut R) -> Self::Engine { + // The avx alg can't handle custom alphabets yet + avx2::AVX2Encoder::from_standard(avx2::AVX2Config::default()) + } + + fn random_alphabet(rng: &mut R, alphabet: &Alphabet) -> Self::Engine { + unimplemented!() + } + } +} + trait EngineExtensions: Engine { // a convenience wrapper to avoid the separate estimate call in tests fn decode_ez(&self, input: &[u8], output: &mut [u8]) -> Result { diff --git a/tests/decode.rs b/tests/decode.rs index d7e29a7..3f0df20 100644 --- a/tests/decode.rs +++ b/tests/decode.rs @@ -84,3 +84,83 @@ fn decode_imap() { decode_engine(b"+//+", &DEFAULT_ENGINE) ); } + +#[test] +fn decode_urlsafe() { + let engine = FastPortable::from(&alphabet::URL_SAFE, NO_PAD); + let out = decode_engine( + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0\ + -P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn\ + -AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq\ + -wsbKztLW2t7i5uru8vb6_wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t_g4eLj5OXm5-jp6uvs7e7v8PHy\ + 8_T19vf4-fr7_P3-_w==", + &engine + ).unwrap(); + let mut bytes: Vec = (0..255).collect(); + bytes.push(255); + + assert_eq!(out, bytes); +} + +#[cfg(feature = "avx2")] +mod avx2test { + use super::*; + + use base64::engine::avx2::{AVX2Encoder, AVX2Config}; + + #[test] + fn decode_long() { + let engine = AVX2Encoder::from_standard(AVX2Config::new()); + let out = decode_engine( + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0\ + BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+Ag\ + YKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHC\ + w8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/w==", + &engine + ).unwrap(); + println!("{:?}", out); + for (a,b) in out.iter().enumerate() { + assert_eq!(a as u8, *b); + } + } + + #[test] + fn decode_long_err() { + let engine = AVX2Encoder::from_standard(AVX2Config::new()); + let out = decode_engine( + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0.P0\ + BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+Ag\ + YKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHC\ + w8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/w==", + &engine + ).unwrap_err(); + + assert_eq!(DecodeError::InvalidByte(83, '.' as u8), out); + } + + #[test] + fn decode_reject_null() { + let engine = AVX2Encoder::from_standard(AVX2Config::new()); + assert_eq!( + DecodeError::InvalidByte(3, 0x0), + decode_engine("YWx\0pY2U==", &engine).unwrap_err() + ); + } + + #[test] + fn decode_urlsafe() { + let engine = AVX2Encoder::from_url_safe(AVX2Config::new()); + let out = decode_engine( + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0\ + -P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn\ + -AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq\ + -wsbKztLW2t7i5uru8vb6_wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t_g4eLj5OXm5-jp6uvs7e7v8PHy\ + 8_T19vf4-fr7_P3-_w==", + &engine + ).unwrap(); + let mut bytes: Vec = (0..255).collect(); + bytes.push(255); + + assert_eq!(out, bytes); + } +} diff --git a/tests/encode.rs b/tests/encode.rs index 7b3561e..76a3e54 100644 --- a/tests/encode.rs +++ b/tests/encode.rs @@ -4,10 +4,21 @@ use base64::alphabet::URL_SAFE; use base64::engine::fast_portable::{NO_PAD, PAD}; use base64::*; +#[cfg(not(feature = "avx2"))] fn compare_encode(expected: &str, target: &[u8]) { assert_eq!(expected, encode(target)); } +#[cfg(feature = "avx2")] +fn compare_encode(expected: &str, target: &[u8]) { + assert_eq!(expected, encode(target)); + + use base64::engine::avx2::{AVX2Encoder, AVX2Config}; + let engine: AVX2Encoder = AVX2Encoder::from_standard(AVX2Config::new()); + + assert_eq!(expected, encode_engine(target, &engine)); +} + #[test] fn encode_rfc4648_0() { compare_encode("", b""); @@ -111,3 +122,33 @@ fn encode_url_safe_without_padding() { "alice" ); } + +#[cfg(feature = "avx2")] +mod avx2tests { + use super::*; + + use base64::engine::avx2::{AVX2Encoder, AVX2Config}; + + #[test] + fn encode_all_bytes_url() { + let engine: AVX2Encoder = AVX2Encoder::from_url_safe(AVX2Config::new()); + let mut bytes = Vec::::with_capacity(256); + + for i in 0..255 { + bytes.push(i); + } + bytes.push(255); //bug with "overflowing" ranges? + + assert_eq!( + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0\ + -P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn\ + -AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq\ + -wsbKztLW2t7i5uru8vb6_wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t_g4eLj5OXm5-jp6uvs7e7v8PHy\ + 8_T19vf4-fr7_P3-_w==", + encode_engine( + &bytes, + &engine + ) + ); + } +} diff --git a/tests/helpers.rs b/tests/helpers.rs index 5144988..bde4e9c 100644 --- a/tests/helpers.rs +++ b/tests/helpers.rs @@ -2,6 +2,7 @@ extern crate base64; use base64::*; +#[cfg(not(feature = "avx2"))] pub fn compare_decode(expected: &str, target: &str) { assert_eq!( expected, @@ -12,3 +13,28 @@ pub fn compare_decode(expected: &str, target: &str) { String::from_utf8(decode(target.as_bytes()).unwrap()).unwrap() ); } + +#[cfg(feature = "avx2")] +pub fn compare_decode(expected: &str, target: &str) { + let engine = &engine::DEFAULT_ENGINE; + assert_eq!( + expected, + String::from_utf8(decode_engine(target, engine).unwrap()).unwrap() + ); + assert_eq!( + expected, + String::from_utf8(decode_engine(target.as_bytes(), engine).unwrap()).unwrap() + ); + + use base64::engine::avx2::{AVX2Encoder, AVX2Config}; + let engine = AVX2Encoder::from_standard(AVX2Config::new()); + + assert_eq!( + expected, + String::from_utf8(decode_engine(target, &engine).unwrap()).unwrap() + ); + assert_eq!( + expected, + String::from_utf8(decode_engine(target.as_bytes(), &engine).unwrap()).unwrap() + ); +}