Skip to content

Commit

Permalink
fix: read stderr in download (#1486)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene authored Jan 25, 2024
1 parent 7e2a743 commit 9c320e2
Showing 1 changed file with 30 additions and 18 deletions.
48 changes: 30 additions & 18 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines, Read};
use std::io::{BufRead, BufReader, Lines};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
Expand Down Expand Up @@ -489,6 +489,9 @@ fn shard_manager(
// Safetensors load fast
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

// Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
envs.push((
Expand Down Expand Up @@ -573,20 +576,20 @@ fn shard_manager(
thread::spawn(move || {
log_lines(shard_stdout_reader.lines());
});
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in shard_stderr_reader.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});

let mut ready = false;
let start_time = Instant::now();
let mut wait_time = Instant::now();
loop {
// Process exited
if let Some(exit_status) = p.try_wait().unwrap() {
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
for line in shard_stderr_reader.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});
let mut err = String::new();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line;
Expand Down Expand Up @@ -782,6 +785,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();

// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
Expand Down Expand Up @@ -832,12 +838,20 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
}
};

// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
let stdout = BufReader::new(download_stdout);
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());

thread::spawn(move || {
log_lines(download_stdout.lines());
});

let download_stderr = BufReader::new(download_process.stderr.take().unwrap());

// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || {
log_lines(stdout.lines());
for line in download_stderr.lines().flatten() {
err_sender.send(line).unwrap_or(());
}
});

loop {
Expand All @@ -848,12 +862,10 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
}

let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line;
}

if let Some(signal) = status.signal() {
tracing::error!(
"Download process was signaled to shutdown with signal {signal}: {err}"
Expand Down

0 comments on commit 9c320e2

Please sign in to comment.