Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdio][windows] Use MBTWC and WCTMB #107110

Merged
merged 1 commit into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions library/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
all(target_vendor = "fortanix", target_env = "sgx"),
feature(slice_index_methods, coerce_unsized, sgx_platform)
)]
#![cfg_attr(windows, feature(round_char_boundary))]
//
// Language features:
#![feature(alloc_error_handler)]
Expand Down
32 changes: 30 additions & 2 deletions library/std/src/sys/windows/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

use crate::ffi::CStr;
use crate::mem;
use crate::os::raw::{c_char, c_int, c_long, c_longlong, c_uint, c_ulong, c_ushort};
use crate::os::raw::{c_char, c_long, c_longlong, c_uint, c_ulong, c_ushort};
use crate::os::windows::io::{BorrowedHandle, HandleOrInvalid, HandleOrNull};
use crate::ptr;
use core::ffi::NonZero_c_ulong;

use libc::{c_void, size_t, wchar_t};

pub use crate::os::raw::c_int;

#[path = "c/errors.rs"] // c.rs is included from two places so we need to specify this
mod errors;
pub use errors::*;
Expand Down Expand Up @@ -47,16 +49,19 @@ pub type ACCESS_MASK = DWORD;

pub type LPBOOL = *mut BOOL;
pub type LPBYTE = *mut BYTE;
pub type LPCCH = *const CHAR;
pub type LPCSTR = *const CHAR;
pub type LPCWCH = *const WCHAR;
pub type LPCWSTR = *const WCHAR;
pub type LPCVOID = *const c_void;
pub type LPDWORD = *mut DWORD;
pub type LPHANDLE = *mut HANDLE;
pub type LPOVERLAPPED = *mut OVERLAPPED;
pub type LPPROCESS_INFORMATION = *mut PROCESS_INFORMATION;
pub type LPSECURITY_ATTRIBUTES = *mut SECURITY_ATTRIBUTES;
pub type LPSTARTUPINFO = *mut STARTUPINFO;
pub type LPSTR = *mut CHAR;
pub type LPVOID = *mut c_void;
pub type LPCVOID = *const c_void;
pub type LPWCH = *mut WCHAR;
pub type LPWIN32_FIND_DATAW = *mut WIN32_FIND_DATAW;
pub type LPWSADATA = *mut WSADATA;
Expand Down Expand Up @@ -132,6 +137,10 @@ pub const MAX_PATH: usize = 260;

pub const FILE_TYPE_PIPE: u32 = 3;

pub const CP_UTF8: DWORD = 65001;
pub const MB_ERR_INVALID_CHARS: DWORD = 0x08;
pub const WC_ERR_INVALID_CHARS: DWORD = 0x80;

#[repr(C)]
#[derive(Copy)]
pub struct WIN32_FIND_DATAW {
Expand Down Expand Up @@ -1155,6 +1164,25 @@ extern "system" {
lpFilePart: *mut LPWSTR,
) -> DWORD;
pub fn GetFileAttributesW(lpFileName: LPCWSTR) -> DWORD;

pub fn MultiByteToWideChar(
CodePage: UINT,
dwFlags: DWORD,
lpMultiByteStr: LPCCH,
cbMultiByte: c_int,
lpWideCharStr: LPWSTR,
cchWideChar: c_int,
) -> c_int;
pub fn WideCharToMultiByte(
CodePage: UINT,
dwFlags: DWORD,
lpWideCharStr: LPCWCH,
cchWideChar: c_int,
lpMultiByteStr: LPSTR,
cbMultiByte: c_int,
lpDefaultChar: LPCCH,
lpUsedDefaultChar: LPBOOL,
) -> c_int;
}

#[link(name = "ws2_32")]
Expand Down
74 changes: 47 additions & 27 deletions library/std/src/sys/windows/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,27 @@ fn write(
}

fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> {
debug_assert!(!utf8.is_empty());
ChrisDenton marked this conversation as resolved.
Show resolved Hide resolved

let mut utf16 = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2];
let mut len_utf16 = 0;
for (chr, dest) in utf8.encode_utf16().zip(utf16.iter_mut()) {
*dest = MaybeUninit::new(chr);
len_utf16 += 1;
}
// Safety: We've initialized `len_utf16` values.
let utf16: &[u16] = unsafe { MaybeUninit::slice_assume_init_ref(&utf16[..len_utf16]) };
let utf8 = &utf8[..utf8.floor_char_boundary(utf16.len())];

let utf16: &[u16] = unsafe {
// Note that this theoretically checks validity twice in the (most common) case
// where the underlying byte sequence is valid utf-8 (given the check in `write()`).
let result = c::MultiByteToWideChar(
c::CP_UTF8, // CodePage
c::MB_ERR_INVALID_CHARS, // dwFlags
utf8.as_ptr() as c::LPCCH, // lpMultiByteStr
utf8.len() as c::c_int, // cbMultiByte
utf16.as_mut_ptr() as c::LPWSTR, // lpWideCharStr
utf16.len() as c::c_int, // cchWideChar
);
assert!(result != 0, "Unexpected error in MultiByteToWideChar");

// Safety: MultiByteToWideChar initializes `result` values.
MaybeUninit::slice_assume_init_ref(&utf16[..result as usize])
};

let mut written = write_u16s(handle, &utf16)?;

Expand All @@ -189,8 +202,8 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usiz
// a missing surrogate can be produced (and also because of the UTF-8 validation above),
// write the missing surrogate out now.
// Buffering it would mean we have to lie about the number of bytes written.
let first_char_remaining = utf16[written];
if first_char_remaining >= 0xDCEE && first_char_remaining <= 0xDFFF {
let first_code_unit_remaining = utf16[written];
if first_code_unit_remaining >= 0xDCEE && first_code_unit_remaining <= 0xDFFF {
// low surrogate
// We just hope this works, and give up otherwise
let _ = write_u16s(handle, &utf16[written..written + 1]);
Expand All @@ -212,6 +225,7 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usiz
}

fn write_u16s(handle: c::HANDLE, data: &[u16]) -> io::Result<usize> {
debug_assert!(data.len() < u32::MAX as usize);
let mut written = 0;
cvt(unsafe {
c::WriteConsoleW(
Expand Down Expand Up @@ -365,26 +379,32 @@ fn read_u16s(handle: c::HANDLE, buf: &mut [MaybeUninit<u16>]) -> io::Result<usiz
Ok(amount as usize)
}

#[allow(unused)]
fn utf16_to_utf8(utf16: &[u16], utf8: &mut [u8]) -> io::Result<usize> {
let mut written = 0;
for chr in char::decode_utf16(utf16.iter().cloned()) {
match chr {
Ok(chr) => {
chr.encode_utf8(&mut utf8[written..]);
written += chr.len_utf8();
}
Err(_) => {
// We can't really do any better than forget all data and return an error.
return Err(io::const_io_error!(
io::ErrorKind::InvalidData,
"Windows stdin in console mode does not support non-UTF-16 input; \
encountered unpaired surrogate",
));
}
}
debug_assert!(utf16.len() <= c::c_int::MAX as usize);
debug_assert!(utf8.len() <= c::c_int::MAX as usize);

let result = unsafe {
c::WideCharToMultiByte(
c::CP_UTF8, // CodePage
c::WC_ERR_INVALID_CHARS, // dwFlags
utf16.as_ptr(), // lpWideCharStr
utf16.len() as c::c_int, // cchWideChar
utf8.as_mut_ptr() as c::LPSTR, // lpMultiByteStr
utf8.len() as c::c_int, // cbMultiByte
ptr::null(), // lpDefaultChar
ptr::null_mut(), // lpUsedDefaultChar
)
};
if result == 0 {
// We can't really do any better than forget all data and return an error.
Err(io::const_io_error!(
io::ErrorKind::InvalidData,
"Windows stdin in console mode does not support non-UTF-16 input; \
encountered unpaired surrogate",
))
} else {
Ok(result as usize)
}
Ok(written)
}

impl IncompleteUtf8 {
Expand Down