Skip to content

Commit

Permalink
Replace fastrand with getrandom and base64
Browse files Browse the repository at this point in the history
- Instead of setting up a predictable userspace RNG, we get
  unpredictable random bytes directly from the OS. This fixes Stebalien#178.
- To obtain a uniformly distributed alphanumeric string, we convert the
  the random bytes to base64 and throw away any letters we don't want
  (`+` and `/`). With a low probability, this may result in obtaining
  too few alphanumeric letters, in which case we request more randomness
  from the OS until we have enough.
- Because we cannot control the seed anymore, a test manufacturing
  collisions by setting the same seed for several threads had to be removed.
  • Loading branch information
vks committed Aug 16, 2022
1 parent 4e4e323 commit 35aabfc
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 57 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ repository = "https://github.com/Stebalien/tempfile"
description = "A library for managing temporary files and directories."

[dependencies]
base64 = "0.13.0"
cfg-if = "1"
fastrand = "1.6.0"
getrandom = "0.2.7"
remove_dir_all = "0.5"

[target.'cfg(any(unix, target_os = "wasi"))'.dependencies]
Expand Down
35 changes: 31 additions & 4 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
use std::ffi::{OsStr, OsString};
use std::path::{Path, PathBuf};
use std::{io, iter::repeat_with};
use std::io;

use crate::error::IoResultExt;

fn calculate_rand_buf_len(alphanumeric_len: usize) -> usize {
let expected_non_alphanumeric_chars = alphanumeric_len / 32;
(alphanumeric_len + expected_non_alphanumeric_chars) * 3 / 4 + 3
}

fn calculate_base64_len(binary_len: usize) -> usize {
binary_len * 4 / 3 + 4
}

fn fill_with_random_base64(rand_buf: &mut [u8], char_buf: &mut Vec<u8>) {
getrandom::getrandom(rand_buf).expect("calling getrandom failed");
char_buf.resize(calculate_base64_len(rand_buf.len()), 0);
base64::encode_config_slice(rand_buf, base64::STANDARD_NO_PAD, char_buf);
}

fn tmpname(prefix: &OsStr, suffix: &OsStr, rand_len: usize) -> OsString {
let mut buf = OsString::with_capacity(prefix.len() + suffix.len() + rand_len);
buf.push(prefix);
let mut char_buf = [0u8; 4];
for c in repeat_with(fastrand::alphanumeric).take(rand_len) {
buf.push(c.encode_utf8(&mut char_buf));

let mut rand_buf = vec![0; calculate_rand_buf_len(rand_len)];
let mut char_buf = vec![0; calculate_base64_len(rand_buf.len())];
let mut remaining_chars = rand_len;
loop {
fill_with_random_base64(&mut rand_buf, &mut char_buf);
char_buf.retain(|&c| (c != b'+') & (c != b'/') & (c != 0));
if char_buf.len() >= remaining_chars {
buf.push(std::str::from_utf8(&char_buf[..remaining_chars]).unwrap());
break;
} else {
buf.push(std::str::from_utf8(&char_buf).unwrap());
remaining_chars -= char_buf.len();
}
}

buf.push(suffix);
buf
}
Expand Down
52 changes: 0 additions & 52 deletions tests/namedtempfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,55 +387,3 @@ fn test_make_uds() {

assert!(temp_sock.path().exists());
}

#[cfg(unix)]
#[test]
fn test_make_uds_conflict() {
use std::os::unix::net::UnixListener;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Check that retries happen correctly by racing N different threads.

const NTHREADS: usize = 20;

// The number of times our callback was called.
let tries = Arc::new(AtomicUsize::new(0));

let mut threads = Vec::with_capacity(NTHREADS);

for _ in 0..NTHREADS {
let tries = tries.clone();
threads.push(std::thread::spawn(move || {
// Ensure that every thread uses the same seed so we are guaranteed
// to retry. Note that fastrand seeds are thread-local.
fastrand::seed(42);

Builder::new()
.prefix("tmp")
.suffix(".sock")
.rand_bytes(12)
.make(|path| {
tries.fetch_add(1, Ordering::Relaxed);
UnixListener::bind(path)
})
}));
}

// Join all threads, but don't drop the temp file yet. Otherwise, we won't
// get a deterministic number of `tries`.
let sockets: Vec<_> = threads
.into_iter()
.map(|thread| thread.join().unwrap().unwrap())
.collect();

// Number of tries is exactly equal to (n*(n+1))/2.
assert_eq!(
tries.load(Ordering::Relaxed),
(NTHREADS * (NTHREADS + 1)) / 2
);

for socket in sockets {
assert!(socket.path().exists());
}
}

0 comments on commit 35aabfc

Please sign in to comment.