Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve UEFI stdio #117174

Merged
merged 3 commits into from
Feb 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 94 additions & 37 deletions library/std/src/sys/pal/uefi/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,100 @@ use crate::mem::MaybeUninit;
use crate::os::uefi;
use crate::ptr::NonNull;

const MAX_BUFFER_SIZE: usize = 8192;
pub struct Stdin {
surrogate: Option<u16>,
incomplete_utf8: IncompleteUtf8,
}

struct IncompleteUtf8 {
bytes: [u8; 4],
len: u8,
}

impl IncompleteUtf8 {
pub const fn new() -> IncompleteUtf8 {
IncompleteUtf8 { bytes: [0; 4], len: 0 }
}

// Implemented for use in Stdin::read.
fn read(&mut self, buf: &mut [u8]) -> usize {
// Write to buffer until the buffer is full or we run out of bytes.
let to_write = crate::cmp::min(buf.len(), self.len as usize);
buf[..to_write].copy_from_slice(&self.bytes[..to_write]);

// Rotate the remaining bytes if not enough remaining space in buffer.
if usize::from(self.len) > buf.len() {
self.bytes.copy_within(to_write.., 0);
self.len -= to_write as u8;
} else {
self.len = 0;
}

to_write
}
}

pub struct Stdin;
pub struct Stdout;
pub struct Stderr;

impl Stdin {
pub const fn new() -> Stdin {
Stdin
Stdin { surrogate: None, incomplete_utf8: IncompleteUtf8::new() }
}
}

impl io::Read for Stdin {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let st: NonNull<r_efi::efi::SystemTable> = uefi::env::system_table().cast();
let stdin = unsafe { (*st.as_ptr()).con_in };

// Try reading any pending data
let inp = match read_key_stroke(stdin) {
Ok(x) => x,
Err(e) if e == r_efi::efi::Status::NOT_READY => {
// Wait for keypress for new data
wait_stdin(stdin)?;
read_key_stroke(stdin).map_err(|x| io::Error::from_raw_os_error(x.as_usize()))?
}
Err(e) => {
return Err(io::Error::from_raw_os_error(e.as_usize()));
}
// If there are bytes in the incomplete utf-8, start with those.
// (No-op if there is nothing in the buffer.)
let mut bytes_copied = self.incomplete_utf8.read(buf);

let stdin: *mut r_efi::protocols::simple_text_input::Protocol = unsafe {
let st: NonNull<r_efi::efi::SystemTable> = uefi::env::system_table().cast();
(*st.as_ptr()).con_in
};

// Check if the key is printiable character
if inp.scan_code != 0x00 {
return Err(io::const_io_error!(io::ErrorKind::Interrupted, "Special Key Press"));
if bytes_copied == buf.len() {
return Ok(bytes_copied);
}

// SAFETY: Iterator will have only 1 character since we are reading only 1 Key
// SAFETY: This character will always be UCS-2 and thus no surrogates.
let ch: char = char::decode_utf16([inp.unicode_char]).next().unwrap().unwrap();
if ch.len_utf8() > buf.len() {
return Ok(0);
let ch = simple_text_input_read(stdin)?;
// Only 1 character should be returned.
let mut ch: Vec<Result<char, crate::char::DecodeUtf16Error>> =
if let Some(x) = self.surrogate.take() {
char::decode_utf16([x, ch]).collect()
} else {
char::decode_utf16([ch]).collect()
};

if ch.len() > 1 {
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid utf-16 sequence"));
}

ch.encode_utf8(buf);
match ch.pop().unwrap() {
Err(e) => {
self.surrogate = Some(e.unpaired_surrogate());
}
Ok(x) => {
// This will always be > 0
let buf_free_count = buf.len() - bytes_copied;
assert!(buf_free_count > 0);

if buf_free_count >= x.len_utf8() {
// There is enough space in the buffer for the character.
bytes_copied += x.encode_utf8(&mut buf[bytes_copied..]).len();
} else {
// There is not enough space in the buffer for the character.
// Store the character in the incomplete buffer.
self.incomplete_utf8.len =
x.encode_utf8(&mut self.incomplete_utf8.bytes).len() as u8;
// write partial character to buffer.
bytes_copied += self.incomplete_utf8.read(buf);
}
}
}

Ok(ch.len_utf8())
Ok(bytes_copied)
}
}

Expand Down Expand Up @@ -90,11 +139,11 @@ impl io::Write for Stderr {
}
}

// UCS-2 character should occupy 3 bytes at most in UTF-8
pub const STDIN_BUF_SIZE: usize = 3;
// UTF-16 character should occupy 4 bytes at most in UTF-8
pub const STDIN_BUF_SIZE: usize = 4;

pub fn is_ebadf(_err: &io::Error) -> bool {
true
false
}

pub fn panic_output() -> Option<impl io::Write> {
Expand All @@ -105,19 +154,15 @@ fn write(
protocol: *mut r_efi::protocols::simple_text_output::Protocol,
buf: &[u8],
) -> io::Result<usize> {
let mut utf16 = [0; MAX_BUFFER_SIZE / 2];

// Get valid UTF-8 buffer
let utf8 = match crate::str::from_utf8(buf) {
Ok(x) => x,
Err(e) => unsafe { crate::str::from_utf8_unchecked(&buf[..e.valid_up_to()]) },
};
// Clip UTF-8 buffer to max UTF-16 buffer we support
let utf8 = &utf8[..utf8.floor_char_boundary(utf16.len() - 1)];

for (i, ch) in utf8.encode_utf16().enumerate() {
utf16[i] = ch;
}
let mut utf16: Vec<u16> = utf8.encode_utf16().collect();
// NULL terminate the string
utf16.push(0);

unsafe { simple_text_output(protocol, &mut utf16) }?;

Expand All @@ -132,6 +177,18 @@ unsafe fn simple_text_output(
if res.is_error() { Err(io::Error::from_raw_os_error(res.as_usize())) } else { Ok(()) }
}

fn simple_text_input_read(
stdin: *mut r_efi::protocols::simple_text_input::Protocol,
) -> io::Result<u16> {
loop {
match read_key_stroke(stdin) {
Ok(x) => return Ok(x.unicode_char),
Err(e) if e == r_efi::efi::Status::NOT_READY => wait_stdin(stdin)?,
Err(e) => return Err(io::Error::from_raw_os_error(e.as_usize())),
}
}
}

fn wait_stdin(stdin: *mut r_efi::protocols::simple_text_input::Protocol) -> io::Result<()> {
let boot_services: NonNull<r_efi::efi::BootServices> =
uefi::env::boot_services().unwrap().cast();
Expand Down
Loading