diff --git a/lib/cublas/CUBLAS.jl b/lib/cublas/CUBLAS.jl index 30ef9e4df5..4093bac89b 100644 --- a/lib/cublas/CUBLAS.jl +++ b/lib/cublas/CUBLAS.jl @@ -210,56 +210,45 @@ end ## logging -const MAX_LOG_BUFLEN = UInt(1024*1024) -const log_buffer = Vector{UInt8}(undef, MAX_LOG_BUFLEN) -const log_cursor = Threads.Atomic{UInt}(0) -const log_cond = Ref{Base.AsyncCondition}() # root +# CUBLAS calls the log callback multiple times for each message, so we need to buffer them +const log_buffer = IOBuffer() function log_message(ptr) - # NOTE: this function may be called from unmanaged threads (by cublasXt), - # so we can't even allocate, let alone perform I/O. - len = @ccall strlen(ptr::Cstring)::Csize_t - old_cursor = log_cursor[] - new_cursor = old_cursor + len+1 - if new_cursor >= MAX_LOG_BUFLEN - # overrun - return - end + global log_buffer + str = unsafe_string(ptr) - @ccall memmove((pointer(log_buffer)+old_cursor)::Ptr{Nothing}, - pointer(ptr)::Ptr{Nothing}, (len+1)::Csize_t)::Nothing - log_cursor[] = new_cursor # the consumer handles CAS'ing this value + # flush if we've started a new log message + if startswith(str, r"[A-Z]!") + flush_log_messages() + end - # avoid code that depends on the runtime (even the unsafe_convert from ccall does?!) - assume(isassigned(log_cond)) - @ccall uv_async_send(log_cond[].handle::Ptr{Nothing})::Cint + # append the lines to the buffer + println(log_buffer, str) return end -function _log_message(blob) +function flush_log_messages() + global log_buffer + message = String(take!(log_buffer)) + isempty(message) && return + # the message format isn't documented, but it looks like a message starts with a capital # and the severity (e.g. `I!`), and subsequent lines start with a lowercase mark (`!i`) - # - # lines are separated by a \0 if they came in separately, but there may also be multiple - # actual lines separated by \n in each message. - for message in split(blob, r"[\0\n]+(?=[A-Z]!)") - code = message[1] - lines = split(message[3:end], r"[\0\n]+[a-z]!") - submessage = join(lines, '\n') - if code == 'I' - @debug submessage - elseif code == 'W' - @warn submessage - elseif code == 'E' - @error submessage - elseif code == 'F' - error(submessage) - else - @info "Unknown log message, please file an issue.\n$message" - end + code = message[1] + lines = split(message[3:end], r"\n+[a-z]!") + message = join(strip.(lines), '\n') + if code == 'I' + @debug message + elseif code == 'W' + @warn message + elseif code == 'E' + @error message + elseif code == 'F' + error(message) + else + @info "Unknown log message, please file an issue.\n$message" end - return end function __init__() @@ -273,21 +262,9 @@ function __init__() # register a log callback if !Sys.iswindows() && # NVIDIA bug #3321130 && !precompiling && (isdebug(:init, CUBLAS) || Base.JLOptions().debug_level >= 2) - log_cond[] = Base.AsyncCondition() do async_cond - blob = "" - while true - message_length = log_cursor[] - blob = unsafe_string(pointer(log_buffer), message_length) - if Threads.atomic_cas!(log_cursor, message_length, UInt(0)) == message_length - break - end - end - _log_message(blob) - return - end - callback = @cfunction(log_message, Nothing, (Cstring,)) cublasSetLoggerCallback(callback) + atexit(flush_log_messages) end end diff --git a/lib/cudnn/src/cuDNN.jl b/lib/cudnn/src/cuDNN.jl index 966775a6c4..5b4f16623d 100644 --- a/lib/cudnn/src/cuDNN.jl +++ b/lib/cudnn/src/cuDNN.jl @@ -116,10 +116,6 @@ end ## logging -const log_messages = [] -const log_lock = ReentrantLock() -const log_cond = Ref{Any}() # root - function log_message(sev, udata, dbg_ptr, ptr) dbg = unsafe_load(dbg_ptr) @@ -131,20 +127,11 @@ function log_message(sev, udata, dbg_ptr, ptr) end len += 1 end - str = unsafe_string(ptr, len) # XXX: can this yield? - - # print asynchronously - Base.@lock log_lock begin - push!(log_messages, (; sev, dbg, str)) - end - ccall(:uv_async_send, Cint, (Ptr{Cvoid},), udata) + str = unsafe_string(ptr, len) - return -end - -function _log_message(sev, dbg, str) + # split into lines and report lines = split(str, '\0') - msg = join(lines, '\n') + msg = join(strip.(lines), '\n') if sev == CUDNN_SEV_INFO @debug msg elseif sev == CUDNN_SEV_WARNING @@ -154,6 +141,7 @@ function _log_message(sev, dbg, str) elseif sev == CUDNN_SEV_FATAL error(msg) end + return end @@ -182,18 +170,9 @@ function __init__() # register a log callback if !precompiling && (isdebug(:init, cuDNN) || Base.JLOptions().debug_level >= 2) - log_cond[] = Base.AsyncCondition() do async_cond - Base.@lock log_lock begin - while length(log_messages) > 0 - message = popfirst!(log_messages) - _log_message(message...) - end - end - end - callback = @cfunction(log_message, Nothing, (cudnnSeverity_t, Ptr{Cvoid}, Ptr{cudnnDebug_t}, Ptr{UInt8})) - cudnnSetCallback(typemax(UInt32), log_cond[], callback) + cudnnSetCallback(typemax(UInt32), C_NULL, callback) end _initialized[] = true