diff --git a/napi-inl.h b/napi-inl.h index d5fbb1bf7..8212c0729 100644 --- a/napi-inl.h +++ b/napi-inl.h @@ -3618,6 +3618,159 @@ inline void AsyncWorker::OnWorkComplete( delete self; } +//////////////////////////////////////////////////////////////////////////////// +// ThreadSafeFunction class +//////////////////////////////////////////////////////////////////////////////// + +// static +template +inline ThreadSafeFunction ThreadSafeFunction::New(napi_env env, + const Function& callback, + const Object& resource, + ResourceString resourceName, + size_t maxQueueSize, + size_t initialThreadCount, + DataType* data, + Finalizer finalizeCallback, + Context* context) { + static_assert(details::can_make_string::value + || std::is_convertible::value, + "Resource name should be string convertible type"); + + napi_threadsafe_function tsFunctionValue; + auto* finalizeData = new details::FinalizeData({ + finalizeCallback, context }); + napi_status status = napi_create_threadsafe_function(env, callback, resource, + Value::From(env, resourceName), maxQueueSize, initialThreadCount, data, + details::FinalizeData::WrapperWithHint, + finalizeData, CallJS, &tsFunctionValue); + if (status != napi_ok) { + delete finalizeData; + NAPI_THROW_IF_FAILED(env, status, ThreadSafeFunction()); + } + + return ThreadSafeFunction(env, tsFunctionValue); +} + +inline ThreadSafeFunction::Status ThreadSafeFunction::BlockingCall() const { + return CallInternal(nullptr, napi_tsfn_blocking); +} + +template +inline ThreadSafeFunction::Status ThreadSafeFunction::BlockingCall( + Callback callback) const { + return CallInternal(new CallbackWrapper(callback), napi_tsfn_blocking); +} + +template +inline ThreadSafeFunction::Status ThreadSafeFunction::BlockingCall( + DataType* data, Callback callback) const { + auto wrapper = [data, callback](Env env, Function jsCallback) { + callback(env, jsCallback, data); + }; + return CallInternal(new CallbackWrapper(wrapper), napi_tsfn_blocking); +} + +inline ThreadSafeFunction::Status ThreadSafeFunction::NonBlockingCall() const { + return CallInternal(nullptr, napi_tsfn_nonblocking); +} + +template +inline ThreadSafeFunction::Status ThreadSafeFunction::NonBlockingCall( + Callback callback) const { + return CallInternal(new CallbackWrapper(callback), napi_tsfn_nonblocking); +} + +template +inline ThreadSafeFunction::Status ThreadSafeFunction::NonBlockingCall( + DataType* data, Callback callback) const { + auto wrapper = [data, callback](Env env, Function jsCallback) { + callback(env, jsCallback, data); + }; + return CallInternal(new CallbackWrapper(wrapper), napi_tsfn_nonblocking); +} + +inline bool ThreadSafeFunction::Acquire() const { + return !IsAborted() && napi_acquire_threadsafe_function( + _tsFunctionValue) == napi_ok; +} + +inline bool ThreadSafeFunction::Release() { + return !IsAborted() && napi_release_threadsafe_function( + _tsFunctionValue, napi_tsfn_release) == napi_ok; +} + +inline bool ThreadSafeFunction::Abort() { + if (IsAborted()) { + return false; + } + + napi_status status = napi_release_threadsafe_function( + _tsFunctionValue, napi_tsfn_abort); + + _tsFunctionValue = nullptr; + _env = nullptr; + + return status == napi_ok; +} + +inline bool ThreadSafeFunction::IsAborted() const { + return _env == nullptr || _tsFunctionValue == nullptr; +} + +inline ThreadSafeFunction::ThreadSafeFunction() + : _env(nullptr), + _tsFunctionValue(nullptr) { +} + +inline ThreadSafeFunction::ThreadSafeFunction( + napi_env env, napi_threadsafe_function tsFunctionValue) + : _env(env), + _tsFunctionValue(tsFunctionValue) { +} + +inline ThreadSafeFunction::Status ThreadSafeFunction::CallInternal( + CallbackWrapper* callbackWrapper, + napi_threadsafe_function_call_mode mode) const { + if (IsAborted()) { + return CLOSE; + } + napi_status status = napi_call_threadsafe_function( + _tsFunctionValue, callbackWrapper, mode); + if (status != napi_ok && callbackWrapper != nullptr) { + delete callbackWrapper; + } + + switch (status) { + case napi_ok: + return OK; + case napi_closing: + return CLOSE; + case napi_queue_full: + return FULL; + default: + return ERROR; + } +} + +// static +inline void ThreadSafeFunction::CallJS(napi_env env, + napi_value jsCallback, + void* /* context */, + void* data) { + if (env == nullptr && jsCallback == nullptr) + return; + + if (data != nullptr) { + auto* callbackWrapper = static_cast(data); + (*callbackWrapper)(env, Function(env, jsCallback)); + delete callbackWrapper; + } else { + Function(env, jsCallback).Call({}); + } +} + //////////////////////////////////////////////////////////////////////////////// // Memory Management class //////////////////////////////////////////////////////////////////////////////// diff --git a/napi.h b/napi.h index 13ac4853f..818e2c193 100644 --- a/napi.h +++ b/napi.h @@ -1818,6 +1818,68 @@ namespace Napi { std::string _error; }; + class ThreadSafeFunction { + public: + enum Status { + CLOSE, + FULL, + ERROR, + OK + }; + + template + static ThreadSafeFunction New(napi_env env, + const Function& callback, + const Object& resource, + ResourceString resourceName, + size_t maxQueueSize, + size_t initialThreadCount, + DataType* data, + Finalizer finalizeCallback, + Context* context); + + ThreadSafeFunction(); + + Status BlockingCall() const; + + template + Status BlockingCall(Callback callback) const; + + template + Status BlockingCall(DataType* data, Callback callback) const; + + Status NonBlockingCall() const; + + template + Status NonBlockingCall(Callback callback) const; + + template + Status NonBlockingCall(DataType* data, Callback callback) const; + + bool Acquire() const; + bool Release(); + bool Abort(); + + bool IsAborted() const; + + private: + using CallbackWrapper = std::function; + + ThreadSafeFunction(napi_env env, napi_threadsafe_function tsFunctionValue); + + Status CallInternal(CallbackWrapper* callbackWrapper, + napi_threadsafe_function_call_mode mode) const; + + static void CallJS(napi_env env, + napi_value jsCallback, + void* context, + void* data); + + napi_env _env; + napi_threadsafe_function _tsFunctionValue; + }; + // Memory management. class MemoryManagement { public: diff --git a/test/binding.cc b/test/binding.cc index 0068e7480..122be27f8 100644 --- a/test/binding.cc +++ b/test/binding.cc @@ -32,6 +32,9 @@ Object InitObject(Env env); Object InitObjectDeprecated(Env env); #endif // !NODE_ADDON_API_DISABLE_DEPRECATED Object InitPromise(Env env); +#if (NAPI_VERSION > 3) +Object InitThreadSafeFunction(Env env); +#endif Object InitTypedArray(Env env); Object InitObjectWrap(Env env); Object InitObjectReference(Env env); @@ -69,6 +72,9 @@ Object Init(Env env, Object exports) { exports.Set("object_deprecated", InitObjectDeprecated(env)); #endif // !NODE_ADDON_API_DISABLE_DEPRECATED exports.Set("promise", InitPromise(env)); +#if (NAPI_VERSION > 3) + exports.Set("threadsafe_function", InitThreadSafeFunction(env)); +#endif exports.Set("typedarray", InitTypedArray(env)); exports.Set("objectwrap", InitObjectWrap(env)); exports.Set("objectreference", InitObjectReference(env)); diff --git a/test/binding.gyp b/test/binding.gyp index 77c388074..312d8d1ca 100644 --- a/test/binding.gyp +++ b/test/binding.gyp @@ -31,6 +31,7 @@ 'object/object.cc', 'object/set_property.cc', 'promise.cc', + 'threadsafe_function/threadsafe_function.cc', 'typedarray.cc', 'objectwrap.cc', 'objectreference.cc', diff --git a/test/index.js b/test/index.js index 51fce464c..10553e611 100644 --- a/test/index.js +++ b/test/index.js @@ -34,6 +34,7 @@ let testModules = [ 'object/object_deprecated', 'object/set_property', 'promise', + 'threadsafe_function/threadsafe_function', 'typedarray', 'typedarray-bigint', 'objectwrap', diff --git a/test/threadsafe_function/threadsafe_function.cc b/test/threadsafe_function/threadsafe_function.cc new file mode 100644 index 000000000..cbf70908f --- /dev/null +++ b/test/threadsafe_function/threadsafe_function.cc @@ -0,0 +1,186 @@ +#include +#include "napi.h" + +using namespace Napi; + +constexpr size_t ARRAY_LENGTH = 10; +constexpr size_t MAX_QUEUE_SIZE = 2; + +static uv_thread_t uvThreads[2]; +static ThreadSafeFunction tsfn; + +struct ThreadSafeFunctionInfo { + enum CallType { + DEFAULT, + BLOCKING, + NON_BLOCKING + } type; + bool abort; + bool startSecondary; + FunctionReference jsFinalizeCallback; + uint32_t maxQueueSize; +} tsfnInfo; + +// Thread data to transmit to JS +static int ints[ARRAY_LENGTH]; + +static void SecondaryThread(void* data) { + ThreadSafeFunction* tsFunction = static_cast(data); + + if (!tsFunction->Release()) { + Error::Fatal("SecondaryThread", "ThreadSafeFunction.Release() failed"); + } +} + +// Source thread producing the data +static void DataSourceThread(void* data) { + ThreadSafeFunction* tsFunction = static_cast(data); + + // FIXME: The `info` should be come from "GetContext()" method + ThreadSafeFunctionInfo* info = &tsfnInfo; + + if (info->startSecondary) { + if (!tsFunction->Acquire()) { + Error::Fatal("DataSourceThread", "ThreadSafeFunction.Acquire() failed"); + } + + if (uv_thread_create(&uvThreads[1], SecondaryThread, tsFunction) != 0) { + Error::Fatal("DataSourceThread", "Failed to start secondary thread"); + } + } + + bool queueWasFull = false; + bool queueWasClosing = false; + for (int index = ARRAY_LENGTH - 1; index > -1 && !queueWasClosing; index--) { + ThreadSafeFunction::Status status = ThreadSafeFunction::ERROR; + auto callback = [](Env env, Function jsCallback, int* data) { + jsCallback.Call({ Number::New(env, *data) }); + }; + + switch (info->type) { + case ThreadSafeFunctionInfo::DEFAULT: + status = tsFunction->BlockingCall(); + break; + case ThreadSafeFunctionInfo::BLOCKING: + status = tsFunction->BlockingCall(&ints[index], callback); + break; + case ThreadSafeFunctionInfo::NON_BLOCKING: + status = tsFunction->NonBlockingCall(&ints[index], callback); + break; + } + + if (info->maxQueueSize == 0) { + // Let's make this thread really busy for 200 ms to give the main thread a + // chance to abort. + uint64_t start = uv_hrtime(); + for (; uv_hrtime() - start < 200000000;); + } + + switch (status) { + case ThreadSafeFunction::FULL: + queueWasFull = true; + index++; + // fall through + + case ThreadSafeFunction::OK: + continue; + + case ThreadSafeFunction::CLOSE: + queueWasClosing = true; + break; + + default: + Error::Fatal("DataSourceThread", "ThreadSafeFunction.*Call() failed"); + } + } + + if (info->type == ThreadSafeFunctionInfo::NON_BLOCKING && !queueWasFull) { + Error::Fatal("DataSourceThread", "Queue was never full"); + } + + if (info->abort && !queueWasClosing) { + Error::Fatal("DataSourceThread", "Queue was never closing"); + } + + if (!queueWasClosing && !tsFunction->Release()) { + Error::Fatal("DataSourceThread", "ThreadSafeFunction.Release() failed"); + } +} + +static Value StopThread(const CallbackInfo& info) { + tsfnInfo.jsFinalizeCallback = Napi::Persistent(info[0].As()); + bool abort = info[1].As(); + if (abort) { + tsfn.Abort(); + } else { + tsfn.Release(); + } + return Value(); +} + +// Join the thread and inform JS that we're done. +static void JoinTheThreads(Env /* env */, + uv_thread_t* theThreads, + ThreadSafeFunctionInfo* info) { + uv_thread_join(&theThreads[0]); + if (info->startSecondary) { + uv_thread_join(&theThreads[1]); + } + + info->jsFinalizeCallback.Call({}); + info->jsFinalizeCallback.Reset(); +} + +static Value StartThreadInternal(const CallbackInfo& info, + ThreadSafeFunctionInfo::CallType type) { + tsfnInfo.type = type; + tsfnInfo.abort = info[1].As(); + tsfnInfo.startSecondary = info[2].As(); + tsfnInfo.maxQueueSize = info[3].As().Uint32Value(); + + tsfn = ThreadSafeFunction::New(info.Env(), info[0].As(), Object(), + "Test", tsfnInfo.maxQueueSize, 2, uvThreads, JoinTheThreads, &tsfnInfo); + + if (uv_thread_create(&uvThreads[0], DataSourceThread, &tsfn) != 0) { + Error::Fatal("StartThreadInternal", "Failed to start data source thread"); + } + + return Value(); +} + +static Value Release(const CallbackInfo& /* info */) { + if (!tsfn.Release()) { + Error::Fatal("Release", "ThreadSafeFunction.Release() failed"); + } + return Value(); +} + +static Value StartThread(const CallbackInfo& info) { + return StartThreadInternal(info, ThreadSafeFunctionInfo::BLOCKING); +} + +static Value StartThreadNonblocking(const CallbackInfo& info) { + return StartThreadInternal(info, ThreadSafeFunctionInfo::NON_BLOCKING); +} + +static Value StartThreadNoNative(const CallbackInfo& info) { + return StartThreadInternal(info, ThreadSafeFunctionInfo::DEFAULT); +} + +Object InitThreadSafeFunction(Env env) { + for (size_t index = 0; index < ARRAY_LENGTH; index++) { + ints[index] = index; + } + + Object exports = Object::New(env); + exports["ARRAY_LENGTH"] = Number::New(env, ARRAY_LENGTH); + exports["MAX_QUEUE_SIZE"] = Number::New(env, MAX_QUEUE_SIZE); + exports["startThread"] = Function::New(env, StartThread); + exports["startThreadNoNative"] = Function::New(env, StartThreadNoNative); + exports["startThreadNonblocking"] = + Function::New(env, StartThreadNonblocking); + exports["stopThread"] = Function::New(env, StopThread); + exports["release"] = Function::New(env, Release); + + return exports; +} diff --git a/test/threadsafe_function/threadsafe_function.js b/test/threadsafe_function/threadsafe_function.js new file mode 100644 index 000000000..710c21212 --- /dev/null +++ b/test/threadsafe_function/threadsafe_function.js @@ -0,0 +1,170 @@ +'use strict'; + +const buildType = process.config.target_defaults.default_configuration; +const assert = require('assert'); +const common = require('../common'); + +test(require(`../build/${buildType}/binding.node`)); +test(require(`../build/${buildType}/binding_noexcept.node`)); + +function test(binding) { + const expectedArray = (function(arrayLength) { + const result = []; + for (let index = 0; index < arrayLength; index++) { + result.push(arrayLength - 1 - index); + } + return result; + })(binding.threadsafe_function.ARRAY_LENGTH); + + function testWithJSMarshaller({ + threadStarter, + quitAfter, + abort, + maxQueueSize, + launchSecondary }) { + return new Promise((resolve) => { + const array = []; + binding.threadsafe_function[threadStarter](function testCallback(value) { + array.push(value); + if (array.length === quitAfter) { + setImmediate(() => { + binding.threadsafe_function.stopThread(common.mustCall(() => { + resolve(array); + }), !!abort); + }); + } + }, !!abort, !!launchSecondary, maxQueueSize); + if (threadStarter === 'startThreadNonblocking') { + // Let's make this thread really busy for a short while to ensure that + // the queue fills and the thread receives a napi_queue_full. + const start = Date.now(); + while (Date.now() - start < 200); + } + }); + } + + new Promise(function testWithoutJSMarshaller(resolve) { + let callCount = 0; + binding.threadsafe_function.startThreadNoNative(function testCallback() { + callCount++; + + // The default call-into-JS implementation passes no arguments. + assert.strictEqual(arguments.length, 0); + if (callCount === binding.threadsafe_function.ARRAY_LENGTH) { + setImmediate(() => { + binding.threadsafe_function.stopThread(common.mustCall(() => { + resolve(); + }), false); + }); + } + }, false /* abort */, false /* launchSecondary */, + binding.threadsafe_function.MAX_QUEUE_SIZE); + }) + + // Start the thread in blocking mode, and assert that all values are passed. + // Quit after it's done. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThread', + maxQueueSize: binding.threadsafe_function.MAX_QUEUE_SIZE, + quitAfter: binding.threadsafe_function.ARRAY_LENGTH + })) + .then((result) => assert.deepStrictEqual(result, expectedArray)) + + // Start the thread in blocking mode with an infinite queue, and assert that + // all values are passed. Quit after it's done. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThread', + maxQueueSize: 0, + quitAfter: binding.threadsafe_function.ARRAY_LENGTH + })) + .then((result) => assert.deepStrictEqual(result, expectedArray)) + + // Start the thread in non-blocking mode, and assert that all values are + // passed. Quit after it's done. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThreadNonblocking', + maxQueueSize: binding.threadsafe_function.MAX_QUEUE_SIZE, + quitAfter: binding.threadsafe_function.ARRAY_LENGTH + })) + .then((result) => assert.deepStrictEqual(result, expectedArray)) + + // Start the thread in blocking mode, and assert that all values are passed. + // Quit early, but let the thread finish. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThread', + maxQueueSize: binding.threadsafe_function.MAX_QUEUE_SIZE, + quitAfter: 1 + })) + .then((result) => assert.deepStrictEqual(result, expectedArray)) + + // Start the thread in blocking mode with an infinite queue, and assert that + // all values are passed. Quit early, but let the thread finish. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThread', + maxQueueSize: 0, + quitAfter: 1 + })) + .then((result) => assert.deepStrictEqual(result, expectedArray)) + + + // Start the thread in non-blocking mode, and assert that all values are + // passed. Quit early, but let the thread finish. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThreadNonblocking', + maxQueueSize: binding.threadsafe_function.MAX_QUEUE_SIZE, + quitAfter: 1 + })) + .then((result) => assert.deepStrictEqual(result, expectedArray)) + + // Start the thread in blocking mode, and assert that all values are passed. + // Quit early, but let the thread finish. Launch a secondary thread to test + // the reference counter incrementing functionality. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThread', + quitAfter: 1, + maxQueueSize: binding.threadsafe_function.MAX_QUEUE_SIZE, + launchSecondary: true + })) + .then((result) => assert.deepStrictEqual(result, expectedArray)) + + // Start the thread in non-blocking mode, and assert that all values are + // passed. Quit early, but let the thread finish. Launch a secondary thread + // to test the reference counter incrementing functionality. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThreadNonblocking', + quitAfter: 1, + maxQueueSize: binding.threadsafe_function.MAX_QUEUE_SIZE, + launchSecondary: true + })) + .then((result) => assert.deepStrictEqual(result, expectedArray)) + + // Start the thread in blocking mode, and assert that it could not finish. + // Quit early by aborting. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThread', + quitAfter: 1, + maxQueueSize: binding.threadsafe_function.MAX_QUEUE_SIZE, + abort: true + })) + .then((result) => assert.strictEqual(result.indexOf(0), -1)) + + // Start the thread in blocking mode with an infinite queue, and assert that + // it could not finish. Quit early by aborting. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThread', + quitAfter: 1, + maxQueueSize: 0, + abort: true + })) + .then((result) => assert.strictEqual(result.indexOf(0), -1)) + + // Start the thread in non-blocking mode, and assert that it could not finish. + // Quit early and aborting. + .then(() => testWithJSMarshaller({ + threadStarter: 'startThreadNonblocking', + quitAfter: 1, + maxQueueSize: binding.threadsafe_function.MAX_QUEUE_SIZE, + abort: true + })) + .then((result) => assert.strictEqual(result.indexOf(0), -1)) +}