diff --git a/common/src/jni/main/cpp/org_conscrypt_NativeCrypto.cpp b/common/src/jni/main/cpp/org_conscrypt_NativeCrypto.cpp index e0e78a4f4..3d668c508 100644 --- a/common/src/jni/main/cpp/org_conscrypt_NativeCrypto.cpp +++ b/common/src/jni/main/cpp/org_conscrypt_NativeCrypto.cpp @@ -33,10 +33,10 @@ #include #else #include + #include #endif #include -#include #include #include #include @@ -6434,6 +6434,123 @@ class AppData { } }; +#ifdef _WIN32 + +/** + * Dark magic helper function that checks, for a given SSL session, whether it + * can SSL_read() or SSL_write() without blocking. Takes into account any + * concurrent attempts to close the SSLSocket from the Java side. This is + * needed to get rid of the hangs that occur when thread #1 closes the SSLSocket + * while thread #2 is sitting in a blocking read or write. The type argument + * specifies whether we are waiting for readability or writability. It expects + * to be passed either SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, since we + * only need to wait in case one of these problems occurs. + * + * @param env + * @param type Either SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE + * @param fdObject The FileDescriptor, since appData->fileDescriptor should be NULL + * @param appData The application data structure with mutex info etc. + * @param timeout_millis The timeout value for select call, with the special value + * 0 meaning no timeout at all (wait indefinitely). Note: This is + * the Java semantics of the timeout value, not the usual + * select() semantics. + * @return The result of the inner select() call, + * THROW_SOCKETEXCEPTION if a SocketException was thrown, -1 on + * additional errors + */ +static int sslSelect(JNIEnv* env, int type, jobject fdObject, AppData* appData, int timeout_millis) { + // This loop is an expanded version of the NET_FAILURE_RETRY + // macro. It cannot simply be used in this case because select + // cannot be restarted without recreating the fd_sets and timeout + // structure. + int result; + fd_set rfds; + fd_set wfds; + do { + NetFd fd(env, fdObject); + if (fd.isClosed()) { + result = THROWN_EXCEPTION; + break; + } + int intFd = fd.get(); + JNI_TRACE("sslSelect type=%s fd=%d appData=%p timeout_millis=%d", + (type == SSL_ERROR_WANT_READ) ? "READ" : "WRITE", intFd, appData, timeout_millis); + + FD_ZERO(&rfds); + FD_ZERO(&wfds); + + if (type == SSL_ERROR_WANT_READ) { + FD_SET(intFd, &rfds); + } else { + FD_SET(intFd, &wfds); + } + + FD_SET(appData->fdsEmergency[0], &rfds); + + int maxFd = (intFd > appData->fdsEmergency[0]) ? intFd : appData->fdsEmergency[0]; + + // Build a struct for the timeout data if we actually want a timeout. + timeval tv; + timeval* ptv; + if (timeout_millis > 0) { + tv.tv_sec = timeout_millis / 1000; + tv.tv_usec = (timeout_millis % 1000) * 1000; + ptv = &tv; + } else { + ptv = NULL; + } + +#ifndef CONSCRYPT_UNBUNDLED + AsynchronousCloseMonitor monitor(intFd); +#else + CompatibilityCloseMonitor monitor(intFd); +#endif + result = select(maxFd + 1, &rfds, &wfds, NULL, ptv); + JNI_TRACE("sslSelect %s fd=%d appData=%p timeout_millis=%d => %d", + (type == SSL_ERROR_WANT_READ) ? "READ" : "WRITE", + fd.get(), appData, timeout_millis, result); + if (result == -1) { + if (fd.isClosed()) { + result = THROWN_EXCEPTION; + break; + } + if (errno != EINTR) { + break; + } + } + } while (result == -1); + + if (MUTEX_LOCK(appData->mutex) == -1) { + return -1; + } + + if (result > 0) { + // We have been woken up by a token in the emergency pipe. We + // can't be sure the token is still in the pipe at this point + // because it could have already been read by the thread that + // originally wrote it if it entered sslSelect and acquired + // the mutex before we did. Thus we cannot safely read from + // the pipe in a blocking way (so we make the pipe + // non-blocking at creation). + if (FD_ISSET(appData->fdsEmergency[0], &rfds)) { + char token; + do { + (void) read(appData->fdsEmergency[0], &token, 1); + } while (errno == EINTR); + } + } + + // Tell the world that there is now one thread less waiting for the + // underlying network. + appData->waitingThreads--; + + MUTEX_UNLOCK(appData->mutex); + + return result; +} + +#else // !defined(_WIN32) + /** * Dark magic helper function that checks, for a given SSL session, whether it * can SSL_read() or SSL_write() without blocking. Takes into account any @@ -6531,6 +6648,7 @@ static int sslSelect(JNIEnv* env, int type, jobject fdObject, AppData* appData, return result; } +#endif // !defined(_WIN32) /** * Helper function that wakes up a thread blocked in select(), in case there is diff --git a/common/src/jni/unbundled/cpp/JNIHelp.cpp b/common/src/jni/unbundled/cpp/JNIHelp.cpp index 7314833d6..1d83f6d1d 100644 --- a/common/src/jni/unbundled/cpp/JNIHelp.cpp +++ b/common/src/jni/unbundled/cpp/JNIHelp.cpp @@ -46,6 +46,43 @@ #include #include +#ifdef _WIN32 + // Windows uses strerror_s instead of strerror_r. + #define strerror_r(errno,buf,len) strerror_s(buf,len,errno) + + // Windows doesn't define this either *sigh*... + int vasprintf(char **ret, const char *format, va_list args) + { + va_list copy; + va_copy(copy, args); + + *ret = NULL; + + int count = vsnprintf(NULL, 0, format, args); + if (count >= 0) + { + char* buffer = (char*) malloc(count + 1); + if (buffer == NULL) + count = -1; + else if ((count = vsnprintf(buffer, count + 1, format, copy)) < 0) + free(buffer); + else + *ret = buffer; + } + va_end(copy); // Each va_start() or va_copy() needs a va_end() + + return count; + } + + int asprintf(char **strp, const char *fmt, ...) { + va_list ap; + va_start(ap, fmt); + int r = vasprintf(strp, fmt, ap); + va_end(ap); + return r; + } +#endif + /** * Equivalent to ScopedLocalRef, but slightly more powerful. */ diff --git a/common/src/jni/unbundled/include/JNIHelp.h b/common/src/jni/unbundled/include/JNIHelp.h index 1fca1db0d..23b4e38b5 100644 --- a/common/src/jni/unbundled/include/JNIHelp.h +++ b/common/src/jni/unbundled/include/JNIHelp.h @@ -28,7 +28,12 @@ #include "jni.h" #include -#include + +#ifdef _WIN32 + #include +#else + #include +#endif #ifndef NELEM # define NELEM(x) ((int) (sizeof(x) / sizeof((x)[0])))