From 82615de00e877d3a5538508506090cfae721d375 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Fri, 10 Jun 2022 18:33:33 +0000 Subject: [PATCH 1/2] Throw `ArgumentError` if `unsafe_SecretBuffer!()` is passed NULL Previously, if given a NULL `Cstring` we would blithely call `strlen()` on it, which resulted in a segfault. It is better if we throw an exception instead. --- base/secretbuffer.jl | 11 ++++++++++- test/secretbuffer.jl | 7 +++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/base/secretbuffer.jl b/base/secretbuffer.jl index 02a133be088f0..935c50fb80fd6 100644 --- a/base/secretbuffer.jl +++ b/base/secretbuffer.jl @@ -79,8 +79,17 @@ function SecretBuffer!(d::Vector{UInt8}) s end -unsafe_SecretBuffer!(s::Cstring) = unsafe_SecretBuffer!(convert(Ptr{UInt8}, s), Int(ccall(:strlen, Csize_t, (Cstring,), s))) +function unsafe_SecretBuffer!(s::Cstring) + if s == C_NULL + throw(ArgumentError("cannot convert NULL to SecretBuffer")) + end + len = Int(ccall(:strlen, Csize_t, (Cstring,), s)) + unsafe_SecretBuffer!(convert(Ptr{UInt8}, s), len) +end function unsafe_SecretBuffer!(p::Ptr{UInt8}, len=1) + if p == C_NULL + throw(ArgumentError("cannot convert NULL to SecretBuffer")) + end s = SecretBuffer(sizehint=len) for i in 1:len write(s, unsafe_load(p, i)) diff --git a/test/secretbuffer.jl b/test/secretbuffer.jl index df67204dd63ba..976c757deea57 100644 --- a/test/secretbuffer.jl +++ b/test/secretbuffer.jl @@ -122,4 +122,11 @@ using Test @test hash(sb1, UInt(5)) === hash(sb2, UInt(5)) shred!(sb1); shred!(sb2) end + @testset "NULL initialization" begin + null_ptr = Cstring(C_NULL) + @test_throws ArgumentError Base.unsafe_SecretBuffer!(null_ptr) + null_ptr = Ptr{UInt8}(C_NULL) + @test_throws ArgumentError Base.unsafe_SecretBuffer!(null_ptr) + @test_throws ArgumentError Base.unsafe_SecretBuffer!(null_ptr, 0) + end end From 22042bee4b26f4eb05ea31ac35017467288e4941 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Fri, 10 Jun 2022 19:45:18 +0000 Subject: [PATCH 2/2] Add `jl_getch()`, use it from `getpass()` This works around the lack of `getch()` on non-windows platforms, such that we can use the windows-specific `getpass()` on all platforms. This was necessary to prevent breakage from `musl` due to a bad interaction between `with_fake_pty()` and `getpass()`. --- base/util.jl | 15 ++++----------- src/jl_exported_funcs.inc | 1 + src/julia.h | 1 + src/sys.c | 28 ++++++++++++++++++++++++++++ 4 files changed, 34 insertions(+), 11 deletions(-) diff --git a/base/util.jl b/base/util.jl index df9e29790deb6..46e7f36475b98 100644 --- a/base/util.jl +++ b/base/util.jl @@ -257,7 +257,7 @@ graphical interface. """ function getpass end -if Sys.iswindows() +_getch() = UInt8(ccall(:jl_getch, Cint, ())) function getpass(input::TTY, output::IO, prompt::AbstractString) input === stdin || throw(ArgumentError("getpass only works for stdin")) print(output, prompt, ": ") @@ -265,11 +265,11 @@ function getpass(input::TTY, output::IO, prompt::AbstractString) s = SecretBuffer() plen = 0 while true - c = UInt8(ccall(:_getch, Cint, ())) - if c == 0xff || c == UInt8('\n') || c == UInt8('\r') + c = _getch() + if c == 0xff || c == UInt8('\n') || c == UInt8('\r') || c == 0x04 break # EOF or return elseif c == 0x00 || c == 0xe0 - ccall(:_getch, Cint, ()) # ignore function/arrow keys + _getch() # ignore function/arrow keys elseif c == UInt8('\b') && plen > 0 plen -= 1 # delete last character on backspace elseif !iscntrl(Char(c)) && plen < 128 @@ -278,13 +278,6 @@ function getpass(input::TTY, output::IO, prompt::AbstractString) end return seekstart(s) end -else -function getpass(input::TTY, output::IO, prompt::AbstractString) - (input === stdin && output === stdout) || throw(ArgumentError("getpass only works for stdin")) - msg = string(prompt, ": ") - unsafe_SecretBuffer!(ccall(:getpass, Cstring, (Cstring,), msg)) -end -end # allow new getpass methods to be defined if stdin has been # redirected to some custom stream, e.g. in IJulia. diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index ffa12d0d5f040..72d385329ce49 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -199,6 +199,7 @@ XX(jl_generic_function_def) \ XX(jl_gensym) \ XX(jl_getallocationgranularity) \ + XX(jl_getch) \ XX(jl_getnameinfo) \ XX(jl_getpagesize) \ XX(jl_get_ARCH) \ diff --git a/src/julia.h b/src/julia.h index b4c159a93d70a..1dfd6ea239d77 100644 --- a/src/julia.h +++ b/src/julia.h @@ -2057,6 +2057,7 @@ extern JL_DLLEXPORT JL_STREAM *JL_STDERR; JL_DLLEXPORT JL_STREAM *jl_stdout_stream(void); JL_DLLEXPORT JL_STREAM *jl_stdin_stream(void); JL_DLLEXPORT JL_STREAM *jl_stderr_stream(void); +JL_DLLEXPORT int jl_getch(void); // showing and std streams JL_DLLEXPORT void jl_flush_cstdio(void) JL_NOTSAFEPOINT; diff --git a/src/sys.c b/src/sys.c index bc21d065f55a3..2f512888c1873 100644 --- a/src/sys.c +++ b/src/sys.c @@ -27,6 +27,9 @@ #include #include #include + +// For `struct termios` +#include #endif #ifndef _OS_WINDOWS_ @@ -514,6 +517,31 @@ JL_DLLEXPORT JL_STREAM *jl_stdin_stream(void) { return JL_STDIN; } JL_DLLEXPORT JL_STREAM *jl_stdout_stream(void) { return JL_STDOUT; } JL_DLLEXPORT JL_STREAM *jl_stderr_stream(void) { return JL_STDERR; } +// terminal workarounds +JL_DLLEXPORT int jl_getch(void) JL_NOTSAFEPOINT +{ +#if defined(_OS_WINDOWS_) + // Windows has an actual `_getch()`, use that: + return _getch(); +#else + // On all other platforms, we do the POSIX terminal manipulation dance + char c; + int r; + struct termios old_termios = {0}; + struct termios new_termios = {0}; + if (tcgetattr(0, &old_termios) != 0) + return -1; + new_termios = old_termios; + cfmakeraw(&new_termios); + if (tcsetattr(0, TCSADRAIN, &new_termios) != 0) + return -1; + r = read(0, &c, 1); + if (tcsetattr(0, TCSADRAIN, &old_termios) != 0) + return -1; + return r == 1 ? c : -1; +#endif +} + // -- processor native alignment information -- JL_DLLEXPORT void jl_native_alignment(uint_t *int8align, uint_t *int16align, uint_t *int32align,