diff --git a/.gitignore b/.gitignore index 66fa6d0a2..f1880c096 100644 --- a/.gitignore +++ b/.gitignore @@ -193,12 +193,6 @@ packages/ *.lib nativebin/ /cs/**/launchSettings.json -/cs/benchmark/Properties/launchSettings.json - -# JetBrains IDE -.idea/ - - # JetBrains cs/.idea/ diff --git a/azure-pipelines-full.yml b/azure-pipelines-full.yml new file mode 100644 index 000000000..8877af88b --- /dev/null +++ b/azure-pipelines-full.yml @@ -0,0 +1,258 @@ +variables: + solution: 'cs/FASTER.sln' + solutionRemote: 'cs/remote/FASTER.remote.sln' + RunAzureTests: 'yes' + +jobs: +- job: 'csharpWindows' + pool: + vmImage: windows-latest + displayName: 'C# (Windows)' + timeoutInMinutes: 75 + + strategy: + maxParallel: 2 + matrix: + AnyCPU-Debug: + buildPlatform: 'Any CPU' + buildConfiguration: 'Debug' + AnyCPU-Release: + buildPlatform: 'Any CPU' + buildConfiguration: 'Release' + x64-Debug: + buildPlatform: 'x64' + buildConfiguration: 'Debug' + x64-Release: + buildPlatform: 'x64' + buildConfiguration: 'Release' + + steps: + - powershell: 'Invoke-WebRequest -OutFile azure-storage-emulator.msi -Uri "https://go.microsoft.com/fwlink/?LinkId=717179&clcid=0x409"' + displayName: 'Download Azure Storage Emulator' + + - powershell: 'msiexec /passive /lvx installation.log /a azure-storage-emulator.msi TARGETDIR="C:\storage-emulator"' + displayName: 'Install Azure Storage Emulator' + + - script: '"C:\Program Files\Microsoft SQL Server\130\Tools\Binn\SqlLocalDB.exe" create "v13.0" 13.0 -s' + displayName: 'Init Test Db' + + - script: '"C:\storage-emulator\root\Microsoft SDKs\Azure\Storage Emulator\AzureStorageEmulator.exe" start' + displayName: 'Start Storage Emulator' + + - task: DotNetCoreCLI@2 + displayName: 'dotnet build $(buildConfiguration)' + inputs: + command: 'build' + projects: '**/*.test.csproj' + arguments: '--configuration $(buildConfiguration)' + + - task: DotNetCoreCLI@2 + displayName: 'dotnet test $(buildConfiguration)' + inputs: + command: test + projects: '**/*.test.csproj' 
+ arguments: '--configuration $(buildConfiguration) -l "console;verbosity=detailed"' + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + inputs: + testRunner: VSTest + testResultsFiles: '**/*.trx' + searchFolder: '$(Agent.TempDirectory)' + +- job: 'cppWindows' + pool: + vmImage: vs2017-win2016 + displayName: 'C++ (Windows)' + + strategy: + maxParallel: 2 + matrix: + x64-Debug: + buildPlatform: 'x64' + buildConfiguration: 'Debug' + x64-Release: + buildPlatform: 'x64' + buildConfiguration: 'Release' + + steps: + - task: CMake@1 + displayName: 'CMake .. -G"Visual Studio 15 2017 Win64"' + inputs: + workingDirectory: 'cc/build' + cmakeArgs: '.. -G"Visual Studio 15 2017 Win64"' + + - task: MSBuild@1 + displayName: 'Build solution cc/build/FASTER.sln' + inputs: + solution: 'cc/build/FASTER.sln' + msbuildArguments: '/m /p:Configuration=$(buildConfiguration) /p:Platform=$(buildPlatform)' + + - script: 'ctest -j 1 --interactive-debug-mode 0 --output-on-failure -C $(buildConfiguration) -R "in_memory"' + workingDirectory: 'cc/build' + displayName: 'Run Ctest' + +- job: 'cppLinux' + pool: + vmImage: ubuntu-18.04 + displayName: 'C++ (Linux)' + + steps: + - script: | + sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test + sudo apt update + sudo apt install -y g++-7 libaio-dev uuid-dev libtbb-dev + displayName: 'Install dependencies' + - script: | + export CXX='g++-7' + cd cc + mkdir -p build/Debug build/Release + cd build/Debug + cmake -DCMAKE_BUILD_TYPE=Debug ../.. + make -j + cd ../../build/Release + cmake -DCMAKE_BUILD_TYPE=Release ../.. 
+ make -j + displayName: 'Compile' + - script: | + CTEST_OUTPUT_ON_FAILURE=1 make test + workingDirectory: 'cc/build/Debug' + displayName: 'Run Tests (Debug)' +- job: 'csharpLinux' + pool: + vmImage: ubuntu-18.04 + displayName: 'C# (Linux)' + + strategy: + maxParallel: 2 + matrix: + AnyCPU-Debug: + buildPlatform: 'Any CPU' + buildConfiguration: 'Debug' + AnyCPU-Release: + buildPlatform: 'Any CPU' + buildConfiguration: 'Release' + + steps: + - task: DotNetCoreCLI@2 + displayName: 'dotnet build $(buildConfiguration)' + inputs: + command: 'build' + projects: '**/*.test.csproj' + arguments: '--configuration $(buildConfiguration)' + + - task: DotNetCoreCLI@2 + displayName: 'dotnet test $(buildConfiguration)' + inputs: + command: test + projects: '**/*.test.csproj' + arguments: '--configuration $(buildConfiguration) -l "console;verbosity=detailed" --filter "TestCategory=Smoke"' + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + inputs: + testResultsFormat: 'VSTest' + testResultsFiles: '*.trx' + searchFolder: '$(Agent.TempDirectory)' + +# - job: 'cppBlobsWindows' +# pool: +# vmImage: vs2017-win2016 +# displayName: 'C++ Blobs (Windows)' + +# strategy: +# maxParallel: 2 +# matrix: +# x64-Debug: +# buildPlatform: 'x64' +# buildConfiguration: 'Debug' +# x64-Release: +# buildPlatform: 'x64' +# buildConfiguration: 'Release' + +# steps: +# - task: CMake@1 +# displayName: 'CMake .. -G"Visual Studio 15 2017 Win64" -DUSE_BLOBS=ON' +# inputs: +# workingDirectory: 'cc/build' +# cmakeArgs: '.. 
-G"Visual Studio 15 2017 Win64" -DUSE_BLOBS=ON' + +# - script: 'git clone https://github.com/microsoft/vcpkg' +# workingDirectory: 'cc/build' +# displayName: 'Download Vcpkg' + +# - script: '.\vcpkg\bootstrap-vcpkg.bat' +# workingDirectory: 'cc/build' +# displayName: 'Install Vcpkg' + +# - script: '.\vcpkg\vcpkg.exe install azure-storage-cpp:x64-windows' +# workingDirectory: 'cc/build' +# displayName: 'Install Azure dependencies' + +# - script: '.\vcpkg\vcpkg.exe integrate install' +# workingDirectory: 'cc/build' +# displayName: 'Integrate vcpkg with msbuild' + +# - task: MSBuild@1 +# displayName: 'Build solution cc/build/FASTER.sln' +# inputs: +# solution: 'cc/build/FASTER.sln' +# msbuildArguments: '/m /p:Configuration=$(buildConfiguration) /p:Platform=$(buildPlatform)' + +# - powershell: 'Invoke-WebRequest -OutFile azure-storage-emulator.msi -Uri "https://go.microsoft.com/fwlink/?LinkId=717179&clcid=0x409"' +# displayName: 'Download Azure Storage Emulator' + +# - powershell: 'msiexec /passive /lvx installation.log /a azure-storage-emulator.msi TARGETDIR="C:\storage-emulator"' +# displayName: 'Install Azure Storage Emulator' + +# - script: '"C:\Program Files\Microsoft SQL Server\130\Tools\Binn\SqlLocalDB.exe" create "v13.0" 13.0 -s' +# displayName: 'Init Test Db' + +# - script: '"C:\storage-emulator\root\Microsoft SDKs\Azure\Storage Emulator\AzureStorageEmulator.exe" start' +# displayName: 'Start Storage Emulator' + +# - script: | +# ctest -j 1 --interactive-debug-mode 0 --output-on-failure -C $(buildConfiguration) -R "azure_test" +# ctest -j 1 --interactive-debug-mode 0 --output-on-failure -C $(buildConfiguration) -R "storage_test" +# ctest -j 1 --interactive-debug-mode 0 --output-on-failure -C $(buildConfiguration) -R "faster_blobs_example" +# workingDirectory: 'cc/build' +# displayName: 'Run Ctest' + +# - job: 'cppBlobsLinux' +# pool: +# vmImage: ubuntu-18.04 +# displayName: 'C++ Blobs (Linux)' + +# steps: +# - script: | +# sudo add-apt-repository -y 
ppa:ubuntu-toolchain-r/test +# sudo apt update +# sudo apt install -y libaio-dev uuid-dev libtbb-dev npm +# displayName: 'Install dependencies' + +# - script: | +# sudo ./scripts/linux/azure/blob.sh +# workingDirectory: 'cc' +# displayName: 'Install Azure dependencies' + +# - script: | +# cd cc +# mkdir -p build/Debug build/Release +# cd build/Debug +# cmake -DCMAKE_BUILD_TYPE=Debug -DUSE_BLOBS=ON ../.. +# make -j +# cd ../../build/Release +# cmake -DCMAKE_BUILD_TYPE=Release -DUSE_BLOBS=ON ../.. +# make -j +# displayName: 'Compile' + +# - script: | +# sudo npm install -g azurite +# azurite -s & +# displayName: 'Install and launch azurite (linux storage emulator)' + +# - script: | +# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib +# CTEST_OUTPUT_ON_FAILURE=1 make test +# workingDirectory: 'cc/build/Debug' +# displayName: 'Run Tests (Debug)' \ No newline at end of file diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 63d378ddf..73c3c3912 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -51,7 +51,7 @@ jobs: inputs: command: test projects: '**/*.test.csproj' - arguments: '--configuration $(buildConfiguration) -l "console;verbosity=detailed"' + arguments: '--configuration $(buildConfiguration) -l "console;verbosity=detailed" --filter:TestCategory=Smoke' - task: PublishTestResults@2 displayName: 'Publish Test Results' @@ -94,33 +94,36 @@ jobs: - job: 'cppLinux' pool: - vmImage: ubuntu-18.04 + vmImage: ubuntu-20.04 displayName: 'C++ (Linux)' steps: - script: | - sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test sudo apt update - sudo apt install -y g++-7 libaio-dev uuid-dev libtbb-dev + sudo apt install -y g++ libaio-dev uuid-dev libtbb-dev displayName: 'Install depdendencies' + - script: | + git clone https://git.kernel.dk/liburing + cd liburing + git checkout liburing-0.7 + ./configure + sudo make install + displayName: Install Liburing - script: | - export CXX='g++-7' cd cc mkdir -p build/Debug build/Release cd build/Debug - cmake 
-DCMAKE_BUILD_TYPE=Debug ../.. + cmake -DCMAKE_BUILD_TYPE=Debug -DUSE_URING=ON ../.. make -j cd ../../build/Release - cmake -DCMAKE_BUILD_TYPE=Release ../.. + cmake -DCMAKE_BUILD_TYPE=Release -DUSE_URING=ON ../.. make -j displayName: 'Compile' - - script: | CTEST_OUTPUT_ON_FAILURE=1 make test workingDirectory: 'cc/build/Debug' displayName: 'Run Tests (Debug)' - - job: 'csharpLinux' pool: vmImage: ubuntu-18.04 @@ -149,7 +152,7 @@ jobs: inputs: command: test projects: '**/*.test.csproj' - arguments: '--configuration $(buildConfiguration) -l "console;verbosity=detailed"' + arguments: '--configuration $(buildConfiguration) -l "console;verbosity=detailed" --filter "TestCategory=Smoke"' - task: PublishTestResults@2 displayName: 'Publish Test Results' diff --git a/cc/CMakeLists.txt b/cc/CMakeLists.txt index 6952b91b8..4f92318b6 100644 --- a/cc/CMakeLists.txt +++ b/cc/CMakeLists.txt @@ -11,6 +11,7 @@ project(FASTER) # a flag `USE_BLOBS` that will link in azure's blob store library so that FASTER # can be used with a blob device for the hybrid log. OPTION(USE_BLOBS "Extend FASTER's hybrid log to blob store" OFF) +OPTION(USE_URING "Enable io_uring based IO handler" OFF) if (MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zi /nologo /Gm- /W3 /WX /EHsc /GS /fp:precise /permissive- /Zc:wchar_t /Zc:forScope /Zc:inline /Gd /TP") @@ -26,6 +27,11 @@ else() set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g -D_DEBUG") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -g") + + if (USE_URING) + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DFASTER_URING") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DFASTER_URING") + endif() endif() #Always set _DEBUG compiler directive when compiling bits regardless of target OS @@ -37,7 +43,7 @@ set_directory_properties(PROPERTIES COMPILE_DEFINITIONS_DEBUG "_DEBUG") configure_file(CMakeLists.txt.in googletest-download/CMakeLists.txt) execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 
RESULT_VARIABLE result - WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/googletest-download ) + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/googletest-download) if(result) message(FATAL_ERROR "CMake step for googletest failed: ${result}") endif() @@ -77,6 +83,9 @@ if(WIN32) set(FASTER_TEST_LINK_LIBS ${FASTER_TEST_LINK_LIBS} rpcrt4) else() set (FASTER_TEST_LINK_LIBS ${FASTER_TEST_LINK_LIBS} stdc++fs uuid tbb gcc aio m stdc++ pthread) + if(USE_URING) + set (FASTER_TEST_LINK_LIBS ${FASTER_TEST_LINK_LIBS} uring) + endif() # Using blob storage. Link in appropriate libraries. if(USE_BLOBS) set (FASTER_TEST_LINK_LIBS ${FASTER_TEST_LINK_LIBS} azurestorage cpprest boost_system crypto ssl) @@ -89,6 +98,9 @@ if(WIN32) set (FASTER_BENCHMARK_LINK_LIBS ${FASTER_LINK_LIBS} rpcrt4 wsock32 Ws2_32) else() set (FASTER_BENCHMARK_LINK_LIBS ${FASTER_BENCHMARK_LINK_LIBS} stdc++fs uuid tbb gcc aio m stdc++ pthread) + if(USE_URING) + set (FASTER_BENCHMARK_LINK_LIBS ${FASTER_BENCHMARK_LINK_LIBS} uring) + endif() endif() #Function to automate building test binaries diff --git a/cc/src/core/address.h b/cc/src/core/address.h index aeb52d3fb..df6655da0 100644 --- a/cc/src/core/address.h +++ b/cc/src/core/address.h @@ -66,6 +66,11 @@ class Address { control_ += delta; return *this; } + inline Address operator+(const uint64_t delta) { + Address addr (*this); + addr += delta; + return addr; + } inline Address operator-(const Address& other) { return control_ - other.control_; } diff --git a/cc/src/core/faster.h b/cc/src/core/faster.h index ad9bc88eb..c5dc99d98 100644 --- a/cc/src/core/faster.h +++ b/cc/src/core/faster.h @@ -1594,15 +1594,25 @@ OperationStatus FasterKv<K, V, D>::InternalContinuePendingRmw(ExecutionContext& Address head_address = hlog.head_address.load(); // Make sure that atomic_entry is OK to update. 
- if(address >= head_address) { + if(address >= head_address && address != pending_context->entry.address()) { record_t* record = reinterpret_cast<record_t*>(hlog.Get(address)); if(!pending_context->is_key_equal(record->key())) { - address = TraceBackForKeyMatchCtxt(*pending_context, record->header.previous_address(), head_address); + Address min_offset = std::max(pending_context->entry.address() + 1, head_address); + address = TraceBackForKeyMatchCtxt(*pending_context, record->header.previous_address(), min_offset); } } + assert(address >= pending_context->entry.address()); // part of the same hash chain if(address > pending_context->entry.address()) { - // We can't trace the current hash bucket entry back to the record we read. + // This handles two mutually exclusive cases. In both cases InternalRmw will be called immediately: + // 1) Found a newer record in the in-memory region (i.e. address >= head_address) + // Calling InternalRmw will result in taking into account the newer version, + // instead of the record we've just read from disk. + // 2) We can't trace the current hash bucket entry back to the record we've read (i.e. address < head_address) + // This is because part of the hash chain now extends to disk, thus we cannot check it right away + // Calling InternalRmw will result in launching a new search in the hash chain, by reading + // the newly introduced entries in the chain, so we won't miss any potential entries with same key. 
+ // pending_context->continue_async(address, expected_entry); return OperationStatus::RETRY_NOW; } diff --git a/cc/src/environment/file_linux.cc b/cc/src/environment/file_linux.cc index 9eb21b9be..461a28d21 100644 --- a/cc/src/environment/file_linux.cc +++ b/cc/src/environment/file_linux.cc @@ -196,6 +196,121 @@ Status QueueFile::ScheduleOperation(FileOperationType operationType, uint8_t* bu return Status::Ok; } +#ifdef FASTER_URING + +bool UringIoHandler::TryComplete() { + struct io_uring_cqe* cqe = nullptr; + cq_lock_.Acquire(); + int res = io_uring_peek_cqe(ring_, &cqe); + if(res == 0 && cqe) { + int io_res = cqe->res; + auto *context = reinterpret_cast<UringIoHandler::IoCallbackContext*>(io_uring_cqe_get_data(cqe)); + io_uring_cqe_seen(ring_, cqe); + cq_lock_.Release(); + Status return_status; + size_t byte_transferred; + if (io_res < 0) { + // Retry if it is failed..... + sq_lock_.Acquire(); + struct io_uring_sqe *sqe = io_uring_get_sqe(ring_); + assert(sqe != 0); + if (context->is_read_) { + io_uring_prep_readv(sqe, context->fd_, &context->vec_, 1, context->offset_); + } else { + io_uring_prep_writev(sqe, context->fd_, &context->vec_, 1, context->offset_); + } + io_uring_sqe_set_data(sqe, context); + int retry_res = io_uring_submit(ring_); + assert(retry_res == 1); + sq_lock_.Release(); + return false; + } else { + return_status = Status::Ok; + byte_transferred = io_res; + } + context->callback(context->caller_context, return_status, byte_transferred); + lss_allocator.Free(context); + return true; + } else { + cq_lock_.Release(); + return false; + } +} + +Status UringFile::Open(FileCreateDisposition create_disposition, const FileOptions& options, + UringIoHandler* handler, bool* exists) { + int flags = 0; + if(options.unbuffered) { + flags |= O_DIRECT; + } + RETURN_NOT_OK(File::Open(flags, create_disposition, exists)); + if(exists && !*exists) { + return Status::Ok; + } + + ring_ = handler->io_uring(); + sq_lock_ = handler->sq_lock(); + return Status::Ok; +} 
+ +Status UringFile::Read(size_t offset, uint32_t length, uint8_t* buffer, + IAsyncContext& context, AsyncIOCallback callback) const { + DCHECK_ALIGNMENT(offset, length, buffer); +#ifdef IO_STATISTICS + ++read_count_; + bytes_read_ += length; +#endif + return const_cast<UringFile*>(this)->ScheduleOperation(FileOperationType::Read, buffer, + offset, length, context, callback); +} + +Status UringFile::Write(size_t offset, uint32_t length, const uint8_t* buffer, + IAsyncContext& context, AsyncIOCallback callback) { + DCHECK_ALIGNMENT(offset, length, buffer); +#ifdef IO_STATISTICS + bytes_written_ += length; +#endif + return ScheduleOperation(FileOperationType::Write, const_cast<uint8_t*>(buffer), offset, length, + context, callback); +} + +Status UringFile::ScheduleOperation(FileOperationType operationType, uint8_t* buffer, + size_t offset, uint32_t length, IAsyncContext& context, + AsyncIOCallback callback) { + auto io_context = alloc_context<UringIoHandler::IoCallbackContext>(sizeof(UringIoHandler::IoCallbackContext)); + if (!io_context.get()) return Status::OutOfMemory; + + IAsyncContext* caller_context_copy; + RETURN_NOT_OK(context.DeepCopy(caller_context_copy)); + + bool is_read = operationType == FileOperationType::Read; + new(io_context.get()) UringIoHandler::IoCallbackContext(is_read, fd_, buffer, length, offset, caller_context_copy, callback); + + sq_lock_->Acquire(); + struct io_uring_sqe *sqe = io_uring_get_sqe(ring_); + assert(sqe != 0); + + if (is_read) { + io_uring_prep_readv(sqe, fd_, &io_context->vec_, 1, offset); + //io_uring_prep_read(sqe, fd_, buffer, length, offset); + } else { + io_uring_prep_writev(sqe, fd_, &io_context->vec_, 1, offset); + //io_uring_prep_write(sqe, fd_, buffer, length, offset); + } + io_uring_sqe_set_data(sqe, io_context.get()); + + int res = io_uring_submit(ring_); + sq_lock_->Release(); + if (res != 1) { + return Status::IOError; + } + + io_context.release(); + return Status::Ok; +} + +#endif + #undef DCHECK_ALIGNMENT } diff 
--git a/cc/src/environment/file_linux.h b/cc/src/environment/file_linux.h index 1b98d9ce7..5dd8d146c 100644 --- a/cc/src/environment/file_linux.h +++ b/cc/src/environment/file_linux.h @@ -11,6 +11,10 @@ #include <sys/stat.h> #include <unistd.h> +#ifdef FASTER_URING +#include <liburing.h> +#endif + #include "../core/async.h" #include "../core/status.h" #include "file_common.h" @@ -250,5 +254,154 @@ class QueueFile : public File { io_context_t io_object_; }; +#ifdef FASTER_URING + +class alignas(64) SpinLock { +public: + SpinLock(): locked_(false) {} + + void Acquire() noexcept { + for (;;) { + if (!locked_.exchange(true, std::memory_order_acquire)) { + return; + } + + while (locked_.load(std::memory_order_relaxed)) { + __builtin_ia32_pause(); + } + } + } + + void Release() noexcept { + locked_.store(false, std::memory_order_release); + } +private: + std::atomic_bool locked_; +}; + +class UringFile; + +/// The UringIoHandler class encapsulates completions for async file I/O, where the completions +/// are put on the io_uring completion queue. class UringIoHandler { + public: + typedef UringFile async_file_t; + + private: + constexpr static int kMaxEvents = 128; + + public: + UringIoHandler() { + ring_ = new struct io_uring(); + int ret = io_uring_queue_init(kMaxEvents, ring_, 0); + assert(ret == 0); + } + + UringIoHandler(size_t max_threads) { + ring_ = new struct io_uring(); + int ret = io_uring_queue_init(kMaxEvents, ring_, 0); + assert(ret == 0); + } + + /// Move constructor + UringIoHandler(UringIoHandler&& other) { + ring_ = other.ring_; + other.ring_ = 0; + } + + ~UringIoHandler() { + if (ring_ != 0) { + io_uring_queue_exit(ring_); + delete ring_; + } + } + + /* + /// Invoked whenever a Linux AIO completes. 
+ static void IoCompletionCallback(io_context_t ctx, struct iocb* iocb, long res, long res2); + */ + struct IoCallbackContext { + IoCallbackContext(bool is_read, int fd, uint8_t* buffer, size_t length, size_t offset, core::IAsyncContext* context_, core::AsyncIOCallback callback_) + : is_read_(is_read) + , fd_(fd) + , vec_{buffer, length} + , offset_(offset) + , caller_context{ context_ } + , callback{ callback_ } {} + + bool is_read_; + + int fd_; + struct iovec vec_; + size_t offset_; + + /// Caller callback context. + core::IAsyncContext* caller_context; + + /// The caller's asynchronous callback function + core::AsyncIOCallback callback; + }; + + inline struct io_uring* io_uring() const { + return ring_; + } + + inline SpinLock* sq_lock() { + return &sq_lock_; + } + + /// Try to execute the next IO completion on the queue, if any. + bool TryComplete(); + + private: + /// The io_uring for all the I/Os + struct io_uring* ring_; + SpinLock sq_lock_, cq_lock_; +}; + +/// The UringFile class encapsulates asynchronous reads and writes, using the specified +/// io_uring +class UringFile : public File { + public: + UringFile() + : File() + , ring_{ nullptr } { + } + UringFile(const std::string& filename) + : File(filename) + , ring_{ nullptr } { + } + /// Move constructor + UringFile(UringFile&& other) + : File(std::move(other)) + , ring_{ other.ring_ } + , sq_lock_{ other.sq_lock_ } { + } + /// Move assignment operator. 
+ UringFile& operator=(UringFile&& other) { + File::operator=(std::move(other)); + ring_ = other.ring_; + sq_lock_ = other.sq_lock_; + return *this; + } + + core::Status Open(FileCreateDisposition create_disposition, const FileOptions& options, + UringIoHandler* handler, bool* exists = nullptr); + + core::Status Read(size_t offset, uint32_t length, uint8_t* buffer, + core::IAsyncContext& context, core::AsyncIOCallback callback) const; + core::Status Write(size_t offset, uint32_t length, const uint8_t* buffer, + core::IAsyncContext& context, core::AsyncIOCallback callback); + + private: + core::Status ScheduleOperation(FileOperationType operationType, uint8_t* buffer, size_t offset, + uint32_t length, core::IAsyncContext& context, core::AsyncIOCallback callback); + + struct io_uring* ring_; + SpinLock* sq_lock_; +}; + +#endif + } } // namespace FASTER::environment diff --git a/cc/test/CMakeLists.txt b/cc/test/CMakeLists.txt index 67df2443d..ac31da059 100644 --- a/cc/test/CMakeLists.txt +++ b/cc/test/CMakeLists.txt @@ -1,10 +1,16 @@ ADD_FASTER_TEST(in_memory_test "") ADD_FASTER_TEST(malloc_fixed_page_size_test "") ADD_FASTER_TEST(paging_queue_test "paging_test.h") +if((NOT MSVC) AND USE_URING) +ADD_FASTER_TEST(paging_uring_test "paging_test.h") +endif() if(MSVC) ADD_FASTER_TEST(paging_threadpool_test "paging_test.h") endif() ADD_FASTER_TEST(recovery_queue_test "recovery_test.h") +if((NOT MSVC) AND USE_URING) +ADD_FASTER_TEST(recovery_uring_test "recovery_test.h") +endif() if(MSVC) ADD_FASTER_TEST(recovery_threadpool_test "recovery_test.h") endif() diff --git a/cc/test/paging_test.h b/cc/test/paging_test.h index 63408badd..713523e92 100644 --- a/cc/test/paging_test.h +++ b/cc/test/paging_test.h @@ -779,6 +779,181 @@ TEST(CLASS, Rmw) { store.StopSession(); } +TEST(CLASS, Rmw_Large) { + class Key { + public: + Key(uint64_t key) + : key_{ key } { + } + + inline static constexpr uint32_t size() { + return static_cast<uint32_t>(sizeof(Key)); + } + inline KeyHash GetHash() 
const { + std::hash<uint64_t> hash_fn; + return KeyHash{ hash_fn(key_) }; + } + + /// Comparison operators. + inline bool operator==(const Key& other) const { + return key_ == other.key_; + } + inline bool operator!=(const Key& other) const { + return key_ != other.key_; + } + + private: + uint64_t key_; + }; + + class RmwContext; + + class Value { + public: + Value() + : counter_{ 0 } + , junk_{ 1 } { + } + + inline static constexpr uint32_t size() { + return static_cast<uint32_t>(sizeof(Value)); + } + + friend class RmwContext; + + private: + std::atomic<uint64_t> counter_; + uint8_t junk_[8016]; + }; + static_assert(sizeof(Value) == 8024, "sizeof(Value) != 8024"); + static_assert(alignof(Value) == 8, "alignof(Value) != 8"); + + class RmwContext : public IAsyncContext { + public: + typedef Key key_t; + typedef Value value_t; + + RmwContext(Key key, uint64_t incr) + : key_{ key } + , incr_{ incr } + , val_{ 0 } { + } + + /// Copy (and deep-copy) constructor. + RmwContext(const RmwContext& other) + : key_{ other.key_ } + , incr_{ other.incr_ } + , val_{ other.val_ } { + } + + inline const Key& key() const { + return key_; + } + inline static constexpr uint32_t value_size() { + return sizeof(value_t); + } + inline static constexpr uint32_t value_size(const Value& old_value) { + return sizeof(value_t); + } + inline void RmwInitial(Value& value) { + value.counter_ = incr_; + val_ = value.counter_; + } + inline void RmwCopy(const Value& old_value, Value& value) { + value.counter_ = old_value.counter_ + incr_; + val_ = value.counter_; + } + inline bool RmwAtomic(Value& value) { + val_ = value.counter_.fetch_add(incr_) + incr_; + return true; + } + + inline uint64_t val() const { + return val_; + } + + protected: + /// The explicit interface requires a DeepCopy_Internal() implementation. 
+ Status DeepCopy_Internal(IAsyncContext*& context_copy) { + return IAsyncContext::DeepCopy_Internal(*this, context_copy); + } + + private: + Key key_; + uint64_t incr_; + + uint64_t val_; + }; + + std::experimental::filesystem::create_directories("logs"); + + typedef FASTER::device::FileSystemDisk<handler_t, (1 << 30)> disk_t; + FasterKv<Key, Value, disk_t> store { 2048, (1 << 20) * 192, "logs", 0.4 }; + + Guid session_id = store.StartSession(); + + constexpr size_t kNumRecords = 50000; + + // Initial RMW. + static std::atomic<uint64_t> records_touched{ 0 }; + for(size_t idx = 0; idx < kNumRecords; ++idx) { + auto callback = [](IAsyncContext* ctxt, Status result) { + CallbackContext<RmwContext> context{ ctxt }; + ASSERT_EQ(Status::Ok, result); + ASSERT_EQ(3, context->val()); + ++records_touched; + }; + + if(idx % 256 == 0) { + store.Refresh(); + } + + RmwContext context{ Key{ idx }, 3 }; + Status result = store.Rmw(context, callback, 1); + if(result == Status::Ok) { + ASSERT_EQ(3, context.val()); + ++records_touched; + } else { + ASSERT_EQ(Status::Pending, result); + } + } + + bool result = store.CompletePending(true); + ASSERT_TRUE(result); + ASSERT_EQ(kNumRecords, records_touched.load()); + + // Second RMW. 
+ records_touched = 0; + for(size_t idx = kNumRecords; idx > 0; --idx) { + auto callback = [](IAsyncContext* ctxt, Status result) { + CallbackContext<RmwContext> context{ ctxt }; + ASSERT_EQ(Status::Ok, result); + ASSERT_EQ(8, context->val()); + ++records_touched; + }; + + if(idx % 256 == 0) { + store.Refresh(); + } + + RmwContext context{ Key{ idx - 1 }, 5 }; + Status result = store.Rmw(context, callback, 1); + if(result == Status::Ok) { + ASSERT_EQ(8, context.val()) << idx - 1; + ++records_touched; + } else { + ASSERT_EQ(Status::Pending, result); + } + } + + ASSERT_LT(records_touched.load(), kNumRecords); + result = store.CompletePending(true); + ASSERT_TRUE(result); + ASSERT_EQ(kNumRecords, records_touched.load()); + + store.StopSession(); +} + TEST(CLASS, Rmw_Concurrent) { class Key { public: @@ -1033,3 +1208,260 @@ TEST(CLASS, Rmw_Concurrent) { thread.join(); } } + +TEST(CLASS, Rmw_Concurrent_Large) { + class Key { + public: + Key(uint64_t key) + : key_{ key } { + } + + inline static constexpr uint32_t size() { + return static_cast<uint32_t>(sizeof(Key)); + } + inline KeyHash GetHash() const { + std::hash<uint64_t> hash_fn; + return KeyHash{ hash_fn(key_) }; + } + + /// Comparison operators. 
+ inline bool operator==(const Key& other) const { + return key_ == other.key_; + } + inline bool operator!=(const Key& other) const { + return key_ != other.key_; + } + + private: + uint64_t key_; + }; + + class RmwContext; + class ReadContext; + + class Value { + public: + Value() + : counter_{ 0 } + , junk_{ 1 } { + } + + inline static constexpr uint32_t size() { + return static_cast<uint32_t>(sizeof(Value)); + } + + friend class RmwContext; + friend class ReadContext; + + private: + std::atomic<uint64_t> counter_; + uint8_t junk_[8016]; + }; + static_assert(sizeof(Value) == 8024, "sizeof(Value) != 8024"); + static_assert(alignof(Value) == 8, "alignof(Value) != 8"); + + class RmwContext : public IAsyncContext { + public: + typedef Key key_t; + typedef Value value_t; + + RmwContext(Key key, uint64_t incr) + : key_{ key } + , incr_{ incr } { + } + + /// Copy (and deep-copy) constructor. + RmwContext(const RmwContext& other) + : key_{ other.key_ } + , incr_{ other.incr_ } { + } + + inline const Key& key() const { + return key_; + } + inline static constexpr uint32_t value_size() { + return sizeof(value_t); + } + inline static constexpr uint32_t value_size(const Value& old_value) { + return sizeof(value_t); + } + inline void RmwInitial(Value& value) { + value.counter_ = incr_; + } + inline void RmwCopy(const Value& old_value, Value& value) { + value.counter_ = old_value.counter_ + incr_; + } + inline bool RmwAtomic(Value& value) { + value.counter_.fetch_add(incr_); + return true; + } + + protected: + /// The explicit interface requires a DeepCopy_Internal() implementation. + Status DeepCopy_Internal(IAsyncContext*& context_copy) { + return IAsyncContext::DeepCopy_Internal(*this, context_copy); + } + + private: + Key key_; + uint64_t incr_; + }; + + class ReadContext : public IAsyncContext { + public: + typedef Key key_t; + typedef Value value_t; + + ReadContext(Key key) + : key_{ key } { + } + + /// Copy (and deep-copy) constructor. 
+ ReadContext(const ReadContext& other) + : key_{ other.key_ } { + } + + /// The implicit and explicit interfaces require a key() accessor. + inline const Key& key() const { + return key_; + } + + inline void Get(const Value& value) { + counter = value.counter_.load(std::memory_order_acquire); + } + inline void GetAtomic(const Value& value) { + counter = value.counter_.load(); + } + + protected: + /// The explicit interface requires a DeepCopy_Internal() implementation. + Status DeepCopy_Internal(IAsyncContext*& context_copy) { + return IAsyncContext::DeepCopy_Internal(*this, context_copy); + } + + private: + Key key_; + public: + uint64_t counter; + }; + + typedef FASTER::device::FileSystemDisk<handler_t, (1 << 30)> disk_t; + static constexpr size_t kNumRecords = 50000; + static constexpr size_t kNumThreads = 2; + + auto rmw_worker = [](FasterKv<Key, Value, disk_t>* store_, uint64_t incr) { + Guid session_id = store_->StartSession(); + for(size_t idx = 0; idx < kNumRecords; ++idx) { + auto callback = [](IAsyncContext* ctxt, Status result) { + CallbackContext<RmwContext> context{ ctxt }; + ASSERT_EQ(Status::Ok, result); + }; + + if(idx % 256 == 0) { + store_->Refresh(); + } + + RmwContext context{ Key{ idx }, incr }; + Status result = store_->Rmw(context, callback, 1); + if(result != Status::Ok) { + ASSERT_EQ(Status::Pending, result); + } + } + bool result = store_->CompletePending(true); + ASSERT_TRUE(result); + store_->StopSession(); + }; + + auto read_worker1 = [](FasterKv<Key, Value, disk_t>* store_, size_t thread_idx) { + Guid session_id = store_->StartSession(); + for(size_t idx = 0; idx < kNumRecords / kNumThreads; ++idx) { + auto callback = [](IAsyncContext* ctxt, Status result) { + CallbackContext<ReadContext> context{ ctxt }; + ASSERT_EQ(Status::Ok, result); + ASSERT_EQ(7 * kNumThreads, context->counter); + }; + + if(idx % 256 == 0) { + store_->Refresh(); + } + + ReadContext context{ Key{ thread_idx* (kNumRecords / kNumThreads) + idx } }; + Status result 
= store_->Read(context, callback, 1); + if(result == Status::Ok) { + ASSERT_EQ(7 * kNumThreads, context.counter); + } else { + ASSERT_EQ(Status::Pending, result); + } + } + bool result = store_->CompletePending(true); + ASSERT_TRUE(result); + store_->StopSession(); + }; + + auto read_worker2 = [](FasterKv<Key, Value, disk_t>* store_, size_t thread_idx) { + Guid session_id = store_->StartSession(); + for(size_t idx = 0; idx < kNumRecords / kNumThreads; ++idx) { + auto callback = [](IAsyncContext* ctxt, Status result) { + CallbackContext<ReadContext> context{ ctxt }; + ASSERT_EQ(Status::Ok, result); + ASSERT_EQ(13 * kNumThreads, context->counter); + }; + + if(idx % 256 == 0) { + store_->Refresh(); + } + + ReadContext context{ Key{ thread_idx* (kNumRecords / kNumThreads) + idx } }; + Status result = store_->Read(context, callback, 1); + if(result == Status::Ok) { + ASSERT_EQ(13 * kNumThreads, context.counter); + } else { + ASSERT_EQ(Status::Pending, result); + } + } + bool result = store_->CompletePending(true); + ASSERT_TRUE(result); + store_->StopSession(); + }; + + std::experimental::filesystem::create_directories("logs"); + + // 192 MB in memory -- rest on disk + FasterKv<Key, Value, disk_t> store { 2048, (1 << 20) * 192, "logs", 0.4 }; + + // Initial RMW. + std::deque<std::thread> threads{}; + for(int64_t idx = 0; idx < kNumThreads; ++idx) { + threads.emplace_back(rmw_worker, &store, 7); + } + for(auto& thread : threads) { + thread.join(); + } + + // Read. + threads.clear(); + for(int64_t idx = 0; idx < kNumThreads; ++idx) { + threads.emplace_back(read_worker1, &store, idx); + } + for(auto& thread : threads) { + thread.join(); + } + + // Second RMW. + threads.clear(); + for(int64_t idx = 0; idx < kNumThreads; ++idx) { + threads.emplace_back(rmw_worker, &store, 6); + } + for(auto& thread : threads) { + thread.join(); + } + + // Read again. 
+ threads.clear(); + for(int64_t idx = 0; idx < kNumThreads; ++idx) { + threads.emplace_back(read_worker2, &store, idx); + } + for(auto& thread : threads) { + thread.join(); + } +} diff --git a/cc/test/paging_uring_test.cc b/cc/test/paging_uring_test.cc new file mode 100644 index 000000000..c66d2cd6b --- /dev/null +++ b/cc/test/paging_uring_test.cc @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include <atomic> +#include <cstdint> +#include <cstring> +#include <deque> +#include <functional> +#include <thread> +#include "gtest/gtest.h" +#include "core/faster.h" +#include "device/file_system_disk.h" + +using namespace FASTER::core; + +typedef FASTER::environment::UringIoHandler handler_t; + +#define CLASS PagingTest_Uring + +#include "paging_test.h" + +#undef CLASS + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/cc/test/recovery_uring_test.cc b/cc/test/recovery_uring_test.cc new file mode 100644 index 000000000..301a52f43 --- /dev/null +++ b/cc/test/recovery_uring_test.cc @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +#include <atomic> +#include <cstdint> +#include <cstring> +#include <deque> +#include <functional> +#include <memory> +#include <random> +#include <thread> +#include "gtest/gtest.h" +#include "core/faster.h" +#include "core/light_epoch.h" +#include "core/thread.h" +#include "device/file_system_disk.h" + +using namespace FASTER::core; + +typedef FASTER::environment::UringIoHandler handler_t; + +#define CLASS RecoveryTest_Uring + +#include "recovery_test.h" + +#undef CLASS + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/cs/libdpr/samples/DprCounters/DprCounters/CounterClientSession.cs b/cs/libdpr/samples/DprCounters/DprCounters/CounterClientSession.cs index f823fb75a..cf3027bad 100644 --- a/cs/libdpr/samples/DprCounters/DprCounters/CounterClientSession.cs +++ b/cs/libdpr/samples/DprCounters/DprCounters/CounterClientSession.cs @@ -69,7 +69,7 @@ public long Increment(Worker worker, long amount, out long result) receivedBytes += socket.Receive(serializationBuffer, receivedBytes, serializationBuffer.Length - receivedBytes, SocketFlags.None); // Forward the DPR response header after we are done - var success = session.ResolveBatch(new Memory<byte>(serializationBuffer, sizeof(int), size - sizeof(long)), out var vector); + var success = session.ResolveBatch(new Span<byte>(serializationBuffer, sizeof(int), size - sizeof(long)), out var vector); // Because we use one-off sockets, resolve batch should never fail. 
Debug.Assert(success); diff --git a/cs/libdpr/src/FASTER.libdpr/clientlib/ClientBatchTracker.cs b/cs/libdpr/src/FASTER.libdpr/clientlib/ClientBatchTracker.cs index 23a973e74..31919b0bc 100644 --- a/cs/libdpr/src/FASTER.libdpr/clientlib/ClientBatchTracker.cs +++ b/cs/libdpr/src/FASTER.libdpr/clientlib/ClientBatchTracker.cs @@ -10,13 +10,12 @@ internal class BatchInfo internal const int MaxHeaderSize = 4096; internal readonly byte[] header; internal bool allocated; - internal int batchId, batchSize; + internal int batchId; internal Worker workerId; internal BatchInfo(int batchId) { this.batchId = batchId; - batchSize = 0; allocated = false; workerId = default; header = new byte[MaxHeaderSize]; @@ -27,6 +26,7 @@ internal BatchInfo(int batchId) // a number of frames that hold batch information as specified in the constructor. internal class ClientBatchTracker : IEnumerable<BatchInfo> { + public const int INVALID_BATCH_ID = -1; private readonly BatchInfo[] buffers; private readonly ConcurrentQueue<int> freeBuffers; diff --git a/cs/libdpr/src/FASTER.libdpr/clientlib/DprClientSession.cs b/cs/libdpr/src/FASTER.libdpr/clientlib/DprClientSession.cs index 2cdb2e0f5..508fe1955 100644 --- a/cs/libdpr/src/FASTER.libdpr/clientlib/DprClientSession.cs +++ b/cs/libdpr/src/FASTER.libdpr/clientlib/DprClientSession.cs @@ -93,7 +93,6 @@ public void IssueBatch(int batchSize, Worker workerId, out Span<byte> header) ref var dprHeader = ref Unsafe.AsRef<DprBatchRequestHeader>(b); // Populate info with relevant request information info.workerId = workerId; - info.batchSize = batchSize; // Populate header with relevant request information dprHeader.batchId = info.batchId; @@ -122,12 +121,13 @@ public void IssueBatch(int batchSize, Worker workerId, out Span<byte> header) /// longer be safe to access if reply is moved, modified, or deallocated. 
/// </param> /// <returns>whether it is safe to proceed with consuming the operation result</returns> - public unsafe bool ResolveBatch(ReadOnlyMemory<byte> reply, out IReadOnlyList<long> versionVector) + public unsafe bool ResolveBatch(ReadOnlySpan<byte> reply, out DprBatchVersionVector versionVector) { - fixed (byte* h = reply.Span) + versionVector = new DprBatchVersionVector(Span<byte>.Empty); + + fixed (byte* h = reply) { ref var responseHeader = ref Unsafe.AsRef<DprBatchResponseHeader>(h); - versionVector = Array.Empty<long>(); var batchInfo = batchTracker.GetBatch(responseHeader.batchId); fixed (byte* b = batchInfo.header) @@ -157,8 +157,9 @@ public unsafe bool ResolveBatch(ReadOnlyMemory<byte> reply, out IReadOnlyList<lo // TODO(Tianyu): Not necessary to iterate through all of the vector, can probably do an optimization // to compute the max on the sender side to avoid iteration. long maxVersion = 0; - foreach (var v in new DprBatchVersionVector(reply)) - maxVersion = Math.Max(maxVersion, v); + versionVector = new DprBatchVersionVector(reply); + for (var i = 0; i < versionVector.Count; i++) + maxVersion = Math.Max(maxVersion, versionVector[i]); core.Utility.MonotonicUpdate(ref clientVersion, maxVersion, out _); versionVector = new DprBatchVersionVector(reply); @@ -188,5 +189,40 @@ public unsafe bool ResolveBatch(ReadOnlyMemory<byte> reply, out IReadOnlyList<lo return true; } + + /// <summary> + /// Consumes a DPR subscription batch reply and update tracking information. This method should be called + /// before exposing the batch; if the method returns false, the results are rolled back and should be + /// discarded. Returns a versionVector that allows callers to inspect the version each operation executed + /// at, in the same order operations appeared in the batch. + /// Thread-safe to invoke with other methods of the object. 
+ /// </summary> + /// <param name="src"> Id of the work this batch is from </param> + /// <param name="header">The received subscription batch header</param> + /// <param name="versionVector"> + /// An IEnumerable holding the version each operation in the batch executed at. Operations are identified by + /// the offset they appeared in the original batch. This object shares scope with the reply Memory and will no + /// longer be safe to access if reply is moved, modified, or deallocated. + /// </param> + /// <returns>whether it is safe to proceed with consuming the operation result</returns> + public unsafe bool ResolveSubscriptionBatch(Worker src, ReadOnlySpan<byte> header, + out DprBatchVersionVector versionVector) + { + CheckWorldlineChange(); + // Wait for a batch slot to become available + BatchInfo info; + // TODO(Tianyu): Probably rewrite with async + while (!batchTracker.TryGetBatchInfo(out info)) + Thread.Yield(); + info.workerId = src; + + fixed (byte* h = header) + { + ref var responseHeader = ref Unsafe.AsRef<DprBatchResponseHeader>(h); + responseHeader.batchId = info.batchId; + } + + return ResolveBatch(header, out versionVector); + } } } \ No newline at end of file diff --git a/cs/libdpr/src/FASTER.libdpr/common/Defs.cs b/cs/libdpr/src/FASTER.libdpr/common/Defs.cs index db6f8494f..c164f935a 100644 --- a/cs/libdpr/src/FASTER.libdpr/common/Defs.cs +++ b/cs/libdpr/src/FASTER.libdpr/common/Defs.cs @@ -8,6 +8,7 @@ namespace FASTER.libdpr [Serializable] public struct Worker { + public static readonly Worker INVALID = new Worker(-1); /// <summary> /// Reserved Worker id for cluster management /// </summary> diff --git a/cs/libdpr/src/FASTER.libdpr/common/DprBatchVersionVector.cs b/cs/libdpr/src/FASTER.libdpr/common/DprBatchVersionVector.cs index b793d01dd..c005b3f42 100644 --- a/cs/libdpr/src/FASTER.libdpr/common/DprBatchVersionVector.cs +++ b/cs/libdpr/src/FASTER.libdpr/common/DprBatchVersionVector.cs @@ -8,84 +8,30 @@ namespace FASTER.libdpr /// <summary> 
/// on-wire format for a vector of version numbers. Does not own or allocate underlying memory. /// </summary> - public unsafe struct DprBatchVersionVector : IReadOnlyList<long> + public unsafe ref struct DprBatchVersionVector { - private readonly ReadOnlyMemory<byte> responseHead; + private readonly ReadOnlySpan<byte> responseHead; /// <summary> /// Construct a new VersionVector to be backed by the given byte* /// </summary> /// <param name="vectorHead"> Reference to the response bytes</param> - public DprBatchVersionVector(ReadOnlyMemory<byte> responseHead) + public DprBatchVersionVector(ReadOnlySpan<byte> responseHead) { this.responseHead = responseHead; - fixed (byte* h = responseHead.Span) + if (responseHead.IsEmpty) + Count = 0; + else { - fixed (byte* b = Unsafe.AsRef<DprBatchResponseHeader>(h).versions) - { - var l = (long*) b; - Count = (int) l[0]; - } - } - } - - private class DprBatchVersionVectorEnumerator : IEnumerator<long> - { - private readonly ReadOnlyMemory<byte> start; - private long index = -1; - - public DprBatchVersionVectorEnumerator(ReadOnlyMemory<byte> start) - { - this.start = start; - } - - public bool MoveNext() - { - fixed (byte* h = start.Span) + fixed (byte* h = responseHead) { fixed (byte* b = Unsafe.AsRef<DprBatchResponseHeader>(h).versions) { var l = (long*) b; - return ++index < l[0]; + Count = (int) l[0]; } } } - - public void Reset() - { - index = 0; - } - - public long Current - { - get - { - fixed (byte* h = start.Span) - { - fixed (byte* b = Unsafe.AsRef<DprBatchResponseHeader>(h).versions) - { - var l = (long*) b; - return l[index + 1]; - } - } - } - } - - object IEnumerator.Current => Current; - - public void Dispose() - { - } - } - - public IEnumerator<long> GetEnumerator() - { - return new DprBatchVersionVectorEnumerator(responseHead); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); } public int Count { get; } @@ -94,7 +40,7 @@ public long this[int index] { get { - fixed (byte* h = 
responseHead.Span) + fixed (byte* h = responseHead) { fixed (byte* b = Unsafe.AsRef<DprBatchResponseHeader>(h).versions) { diff --git a/cs/libdpr/src/FASTER.libdpr/serverlib/DprServer.cs b/cs/libdpr/src/FASTER.libdpr/serverlib/DprServer.cs index a625590b9..5aa88cbc6 100644 --- a/cs/libdpr/src/FASTER.libdpr/serverlib/DprServer.cs +++ b/cs/libdpr/src/FASTER.libdpr/serverlib/DprServer.cs @@ -59,6 +59,15 @@ public DprServer(IDprFinder dprFinder, Worker me, TStateObject stateObject) depSerializationArray = new byte[2 * LightDependencySet.MaxClusterSize * sizeof(long)]; } + /// <summary></summary> + /// <returns> Worker ID of this DprServer instance </returns> + public Worker Me() => state.me; + + /// <summary> + /// At the start (restart) of processing, connect to the rest of the DPR cluster. If the worker restarted from + /// an existing instance, the cluster will detect this and trigger rollback as appropriate across the cluster. + /// Must be invoked exactly once before any other roperations. + /// </summary> public void ConnectToCluster() { var v = state.dprFinder.NewWorker(state.me, stateObject); @@ -76,6 +85,8 @@ public void ConnectToCluster() }, state.dprFinder.SystemWorldLine()); } + /// <summary></summary> + /// <returns> The underlying state object </returns> public TStateObject StateObject() { return stateObject; @@ -96,6 +107,13 @@ private ReadOnlySpan<byte> ComputeDependency(long version) return new ReadOnlySpan<byte>(depSerializationArray, 0, head); } + /// <summary> + /// Check whether this DprServer is due to refresh its view of the cluster or need to perform a checkpoint --- + /// if so, perform the relevant operation(s). Roughly ensures that both are only actually invoked as frequently + /// as the given intervals for checkpoint and refresh. 
+ /// </summary> + /// <param name="checkpointPeriodMilli">The rough interval, in milliseconds, expected between two checkpoints</param> + /// <param name="refreshPeriodMilli">The rough interval, in milliseconds, expected between two refreshes</param> public void TryRefreshAndCheckpoint(long checkpointPeriodMilli, long refreshPeriodMilli) { var currentTime = state.sw.ElapsedMilliseconds; @@ -138,8 +156,8 @@ private void TryAdvanceWorldLineTo(long targetWorldline) /// Invoke before beginning processing of a batch. If the function returns false, the batch must not be /// executed to preserve DPR consistency, and the caller should return the error response (written in the /// response field in case of failure). Otherwise, when the function returns, the batch is safe to execute. - /// In the true case, there must eventually be a matching SignalBatchFinish call for DPR to make - /// progress. + /// In the true case, there must eventually be a matching SignalBatchFinish call on the same thread for DPR + /// to make progress. /// </summary> /// <param name="request">Dpr request message from user</param> /// <param name="response">Dpr response message that will be returned to user in case of failure</param> @@ -205,11 +223,10 @@ public unsafe bool RequestBatchBegin(ref DprBatchRequestHeader request, ref DprB return true; } - /// <summary> - /// Invoke after processing of a batch is complete, and dprBatchHeader has been populated with a batch offset -> - /// executed version mapping. This function must be invoked once for every processed batch for DPR to make - /// progress. + /// Invoke after processing of a batch is complete, and DprBatchVersionTracker has been populated with a + /// batch offset -> executed version mapping. This function must be invoked once for every processed batch + /// on the same thread for DPR to make progress. 
/// </summary> /// <param name="request">Dpr request message from user</param> /// <param name="response">Dpr response message that will be returned to user</param> @@ -238,6 +255,52 @@ public int SignalBatchFinish(ref DprBatchRequestHeader request, Span<byte> respo return responseSize; } + /// <summary> + /// Obtain a version tracker for a subscription batch. Functionally, subscription messages are equivalent to + /// response messages, but do not necessarily have a corresponding DPR request for them. There must eventually + /// be a matching FinishSubscriptionBatch call on the same thread for DPR to make progress. + /// </summary> + /// <returns> Tracker to use for batch execution </returns> + public DprBatchVersionTracker StartSubscriptionBatch() + { + state.worldlineTracker.Enter(); + return trackers.Checkout(); + } + + /// <summary> + /// Invoke after processing of a subscription batch is complete, and DprBatchVersionTracker has been populated + /// with a batch offset -> executed version mapping. This function must be invoked once for every + /// started subscription batch on the same thread for DPR to make progress. 
+ /// </summary> + /// <param name="sessionId"> id of the session this batch is going to </param> + /// <param name="response"> Dpr response message that will be returned to user </param> + /// <param name="tracker"> Tracker used in batch processing</param> + /// <returns> size of header if successful, negative size required to hold response if supplied byte span is too small </returns> + public int FinishSubscriptionBatch(Guid sessionId, Span<byte> response, DprBatchVersionTracker tracker) + { + ref var dprResponse = + ref MemoryMarshal.GetReference(MemoryMarshal.Cast<byte, DprBatchResponseHeader>(response)); + var responseSize = DprBatchResponseHeader.HeaderSize + tracker.EncodingSize(); + if (response.Length < responseSize) return -responseSize; + + // This is the world line we are leaving + var wl = state.worldlineTracker.Version(); + + // Signal batch finished so world-lines can advance. Need to make sure not double-invoked, therefore only + // called after size validation. + state.worldlineTracker.Leave(); + + // Populate response + dprResponse.sessionId = sessionId; + dprResponse.worldLine = wl; + dprResponse.batchId = ClientBatchTracker.INVALID_BATCH_ID; + dprResponse.batchSize = responseSize; + tracker.AppendOntoResponse(ref dprResponse); + + trackers.Return(tracker); + return responseSize; + } + /// <summary> /// Force the execution of a checkpoint ahead of the schedule specified at creation time. /// Resets the checkpoint schedule to happen checkpoint_milli after this invocation. 
diff --git a/cs/libdpr/src/FASTER.libdpr/serverlib/SimpleVersionScheme.cs b/cs/libdpr/src/FASTER.libdpr/serverlib/SimpleVersionScheme.cs index ce1f1640f..7beb2012c 100644 --- a/cs/libdpr/src/FASTER.libdpr/serverlib/SimpleVersionScheme.cs +++ b/cs/libdpr/src/FASTER.libdpr/serverlib/SimpleVersionScheme.cs @@ -1,5 +1,6 @@ using System; using System.Threading; +using FASTER.core; namespace FASTER.libdpr { @@ -9,11 +10,15 @@ namespace FASTER.libdpr /// </summary> public class SimpleVersionScheme { - // One count is reserved for the thread that actually advances the version, each batch also gets one - private readonly CountdownEvent count = new CountdownEvent(1); + private LightEpoch epoch; private long version = 1; private ManualResetEventSlim versionChanged; + public SimpleVersionScheme() + { + epoch = new LightEpoch(); + } + /// <summary> /// Returns the current version /// </summary> @@ -31,24 +36,19 @@ public long Version() /// <returns>current version number</returns> public long Enter() { + epoch.Resume(); + // Temporarily block if a version change is under way --- depending on whether the thread observes + // versionChanged, they are either in the current version or the next while (true) { var ev = versionChanged; - if (ev == null) - { - if (count.TryAddCount()) - // Because version is only changed after count == 0, which we know will never happen before we - // return at this point, it suffices to just read the field. - return version; - // If the count ever reaches 0, version change may have already occured, and we need to - // back away to retry - } - else - { - // Wait for version advance to complete and then try again. - versionChanged.Wait(); - } + if (ev == null) break; + // Allow version change to complete by leaving this epoch. 
+ epoch.Suspend(); + ev.Wait(); + epoch.Resume(); } + return version; } /// <summary> @@ -56,7 +56,7 @@ public long Enter() /// </summary> public void Leave() { - count.Signal(); + epoch.Suspend(); } /// <summary> @@ -72,34 +72,28 @@ public void Leave() /// <returns> Whether the advance was successful </returns> public bool TryAdvanceVersion(Action<long, long> criticalSection, long targetVersion = -1) { - if (targetVersion != -1 && targetVersion <= version) return false; - var ev = new ManualResetEventSlim(); // Compare and exchange to install our advance while (Interlocked.CompareExchange(ref versionChanged, ev, null) != null) { } - // After success, we have exclusive access to update version - var original = version; if (targetVersion != -1 && targetVersion <= version) { - // In this case, advance request is not valid + versionChanged.Set(); versionChanged = null; return false; } - // One count is reserved for the thread that actually advances the version - count.Signal(); - // Wait until all batches in the previous version has been processed - count.Wait(); - - version = targetVersion == -1 ? version + 1 : targetVersion; - criticalSection(original, version); - // Complete the version change - count.Reset(); - ev.Set(); - versionChanged = null; + // Any thread that sees ev will be in v + 1, because the bump happens only after ev is set. + var original = version; + epoch.BumpCurrentEpoch(() => + { + version = targetVersion == -1 ? 
version + 1 : targetVersion; + criticalSection(original, version); + versionChanged.Set(); + versionChanged = null; + }); return true; } } diff --git a/cs/libdpr/test/FASTER.libdpr.test/TestClientObject.cs b/cs/libdpr/test/FASTER.libdpr.test/TestClientObject.cs index 9d7d414d4..8a854aa26 100644 --- a/cs/libdpr/test/FASTER.libdpr.test/TestClientObject.cs +++ b/cs/libdpr/test/FASTER.libdpr.test/TestClientObject.cs @@ -38,7 +38,7 @@ public long ResolveOp(int op) { responses.Remove(op, out var buf); Debug.Assert(buf != null); - session.ResolveBatch(new ReadOnlyMemory<byte>(buf), out var result); + session.ResolveBatch(new ReadOnlySpan<byte>(buf), out var result); bufs.Return(buf); return result[0]; } diff --git a/cs/playground/AsyncStress/SerializedFasterWrapper.cs b/cs/playground/AsyncStress/SerializedFasterWrapper.cs index 6ed253110..f86c72d9a 100644 --- a/cs/playground/AsyncStress/SerializedFasterWrapper.cs +++ b/cs/playground/AsyncStress/SerializedFasterWrapper.cs @@ -95,6 +95,7 @@ internal async ValueTask UpdateAsync<TUpdater, TAsyncResult>(TUpdater updater, K } } Interlocked.Add(ref pendingCount, await updater.CompleteAsync(await task.ConfigureAwait(false))); + _sessionPool.Return(session); } public void Update<TUpdater, TAsyncResult>(TUpdater updater, Key key, Value value) diff --git a/cs/remote/FASTER.remote.sln b/cs/remote/FASTER.remote.sln index 80550104b..290299e1b 100644 --- a/cs/remote/FASTER.remote.sln +++ b/cs/remote/FASTER.remote.sln @@ -59,6 +59,12 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{48000B9F-1 ..\..\docs\_data\navigation.yml = ..\..\docs\_data\navigation.yml EndProjectSection EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "WebClient", "WebClient", "{8D1F793C-DA1E-4C75-B824-E510BB54534E}" + ProjectSection(SolutionItems) = preProject + samples\WebClient\FASTERFunctions.js = samples\WebClient\FASTERFunctions.js + samples\WebClient\WebClient.html = samples\WebClient\WebClient.html + EndProjectSection 
+EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -162,6 +168,7 @@ Global {4053EC35-77A5-4728-B16F-F4FDD1104CAF} = {1065AE7E-DEA5-4E21-AE39-95B93C074B17} {6A49ADD2-DC25-47E1-9D29-5DC6380E880A} = {6B8D1038-C9D5-4111-B5CE-BF64E7D12AE1} {2238A430-8D61-40A3-A23B-B1163A4CCBC6} = {8CF11B91-A6B6-4B81-AD43-2B07CF60F8FF} + {8D1F793C-DA1E-4C75-B824-E510BB54534E} = {1065AE7E-DEA5-4E21-AE39-95B93C074B17} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FB603D60-F72D-4DAD-9349-442A45E20276} diff --git a/cs/remote/benchmark/FASTER.benchmark/Types.cs b/cs/remote/benchmark/FASTER.benchmark/Types.cs index 372e24f1c..fa8462867 100644 --- a/cs/remote/benchmark/FASTER.benchmark/Types.cs +++ b/cs/remote/benchmark/FASTER.benchmark/Types.cs @@ -65,5 +65,16 @@ public void RMWCompletionCallback(ref Key key, ref Input input, ref Output outpu public void UpsertCompletionCallback(ref Key key, ref Value value, Empty ctx) { } + + public void SubscribeKVCallback(ref Key key, ref Input input, ref Output output, Empty ctx, Status status) + { + } + + public void PublishCompletionCallback(ref Key key, ref Value value, Empty ctx) + { + } + public void SubscribeCallback(ref Key key, ref Value value, Empty ctx) + { + } } } diff --git a/cs/remote/samples/FixedLenClient/Functions.cs b/cs/remote/samples/FixedLenClient/Functions.cs index 7bf04cf9d..3122cf0ed 100644 --- a/cs/remote/samples/FixedLenClient/Functions.cs +++ b/cs/remote/samples/FixedLenClient/Functions.cs @@ -31,6 +31,10 @@ public override void ReadCompletionCallback(ref long key, ref long input, ref lo } } + public override void SubscribeKVCallback(ref long key, ref long input, ref long output, byte ctx, Status status) + { + } + public override void RMWCompletionCallback(ref long key, ref long input, ref long output, byte ctx, Status status) { if (ctx == 1) diff --git a/cs/remote/samples/FixedLenClient/Program.cs 
b/cs/remote/samples/FixedLenClient/Program.cs index 0cdd46b9b..c24835e56 100644 --- a/cs/remote/samples/FixedLenClient/Program.cs +++ b/cs/remote/samples/FixedLenClient/Program.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; +using System.Threading; using System.Threading.Tasks; using FASTER.client; @@ -16,6 +17,7 @@ class Program { static void Main(string[] args) { + Environment.SetEnvironmentVariable("DOTNET_SYSTEM_NET_SOCKETS_INLINE_COMPLETIONS", "1"); string ip = "127.0.0.1"; int port = 3278; @@ -31,7 +33,8 @@ static void Main(string[] args) // Create a client session to the FasterKV server. // Sessions are mono-threaded, similar to normal FasterKV sessions. - using var session = client.NewSession(new Functions()); // Uses protocol WireFormat.DefaultFixedLenKV by default + using var session = client.NewSession(new Functions()); + using var session2 = client.NewSession(new Functions()); // Explicit version of NewSession call, where you provide all types, callback functions, and serializer // using var session = client.NewSession<long, long, byte, Functions, FixedLenSerializer<long, long, long, long>>(new Functions(), new FixedLenSerializer<long, long, long, long>()); @@ -39,6 +42,9 @@ static void Main(string[] args) // Samples using sync client API SyncSamples(session); + // Samples using sync subscription client API + SyncSubscriptionSamples(session, session2); + // Samples using async client API AsyncSamples(session).Wait(); @@ -47,6 +53,9 @@ static void Main(string[] args) static void SyncSamples(ClientSession<long, long, long, long, byte, Functions, FixedLenSerializer<long, long, long, long>> session) { + session.Upsert(23, 23 + 10000); + session.CompletePending(true); + for (int i = 0; i < 1000; i++) session.Upsert(i, i + 10000); @@ -94,6 +103,28 @@ static void SyncSamples(ClientSession<long, long, long, long, byte, Functions, F session.CompletePending(true); } + static void SyncSubscriptionSamples(ClientSession<long, long, long, long, byte, 
Functions, FixedLenSerializer<long, long, long, long>> session, ClientSession<long, long, long, long, byte, Functions, FixedLenSerializer<long, long, long, long>> session2) + { + session2.SubscribeKV(23); + session2.CompletePending(true); + + for (int i = 0; i < 1000000; i++) + session.Upsert(23, i + 10); + + // Flushes partially filled batches, does not wait for responses + session.Flush(); + session.CompletePending(true); + + session.RMW(23, 25); + session.CompletePending(true); + + session.Flush(); + session.CompletePending(true); + + Thread.Sleep(1000); + } + + static async Task AsyncSamples(ClientSession<long, long, long, long, byte, Functions, FixedLenSerializer<long, long, long, long>> session) { for (int i = 0; i < 1000; i++) diff --git a/cs/remote/samples/FixedLenServer/FixedLenServer.csproj b/cs/remote/samples/FixedLenServer/FixedLenServer.csproj index c205d6c8e..2a2976167 100644 --- a/cs/remote/samples/FixedLenServer/FixedLenServer.csproj +++ b/cs/remote/samples/FixedLenServer/FixedLenServer.csproj @@ -6,6 +6,7 @@ <AllowUnsafeBlocks>true</AllowUnsafeBlocks> <Platforms>AnyCPU;x64</Platforms> <LangVersion>latest</LangVersion> + <ServerGarbageCollection>true</ServerGarbageCollection> </PropertyGroup> <PropertyGroup Condition="'$(Configuration)' == 'Debug'"> diff --git a/cs/remote/samples/FixedLenServer/Program.cs b/cs/remote/samples/FixedLenServer/Program.cs index 0b2c44586..42f5c7c6c 100644 --- a/cs/remote/samples/FixedLenServer/Program.cs +++ b/cs/remote/samples/FixedLenServer/Program.cs @@ -3,52 +3,31 @@ using System; using System.Threading; -using ServerOptions; using CommandLine; -using FASTER.core; +using FasterServerOptions; using FASTER.server; -using FASTER.common; +using System.Diagnostics; -namespace FixedLenServer +namespace FasterFixedLenServer { /// <summary> - /// This sample creates a FASTER server for fixed-length (struct) keys and values + /// Sample server for fixed-length (blittable) keys and values. 
/// Types are defined in Types.cs; they are 8-byte keys and values in the sample. - /// A binary wire protocol is used. /// </summary> class Program { static void Main(string[] args) { - FixedLenServer(args); - } + Environment.SetEnvironmentVariable("DOTNET_SYSTEM_NET_SOCKETS_INLINE_COMPLETIONS", "1"); + Trace.Listeners.Add(new ConsoleTraceListener()); - static void FixedLenServer(string[] args) - { Console.WriteLine("FASTER fixed-length (binary) KV server"); ParserResult<Options> result = Parser.Default.ParseArguments<Options>(args); if (result.Tag == ParserResultType.NotParsed) return; var opts = result.MapResult(o => o, xs => new Options()); - - opts.GetSettings(out var logSettings, out var checkpointSettings, out var indexSize); - - // We use blittable structs Key and Value to construct a customized server for fixed-length types - var store = new FasterKV<Key, Value>(indexSize, logSettings, checkpointSettings); - if (opts.Recover) store.Recover(); - - // This fixed-length session provider can be used with compatible clients such as FixedLenClient and FASTER.benchmark - // Uses FixedLenSerializer as our in-built serializer for blittable (fixed length) types - var provider = new FasterKVProvider<Key, Value, Input, Output, Functions, FixedLenSerializer<Key, Value, Input, Output>>(store, e => new Functions()); - - // Create server - var server = new FasterServer(opts.Address, opts.Port); - - // Register session provider for WireFormat.DefaultFixedLenKV - // You can register multiple session providers with the same server, with different wire protocol specifications - server.Register(WireFormat.DefaultFixedLenKV, provider); - - // Start server + + using var server = new FixedLenServer<Key, Value, Input, Output, Functions>(opts.GetServerOptions(), e => new Functions()); server.Start(); Console.WriteLine("Started server"); diff --git a/cs/remote/samples/FixedLenServer/Types.cs b/cs/remote/samples/FixedLenServer/Types.cs index 2945ef75c..42b58dae7 100644 --- 
a/cs/remote/samples/FixedLenServer/Types.cs +++ b/cs/remote/samples/FixedLenServer/Types.cs @@ -7,7 +7,7 @@ using System.Threading; using FASTER.core; -namespace FixedLenServer +namespace FasterFixedLenServer { [StructLayout(LayoutKind.Explicit, Size = 8)] public struct Key : IFasterEqualityComparer<Key> @@ -53,7 +53,7 @@ public struct Output } - public struct Functions : IFunctions<Key, Value, Input, Output, long> + public struct Functions : IAdvancedFunctions<Key, Value, Input, Output, long> { // No locking needed for atomic types such as Value public bool SupportsLocking => false; @@ -61,7 +61,7 @@ public struct Functions : IFunctions<Key, Value, Input, Output, long> // Callbacks public void RMWCompletionCallback(ref Key key, ref Input input, ref Output output, long ctx, Status status) { } - public void ReadCompletionCallback(ref Key key, ref Input input, ref Output output, long ctx, Status status) { } + public void ReadCompletionCallback(ref Key key, ref Input input, ref Output output, long ctx, Status status, RecordInfo recordInfo) { } public void UpsertCompletionCallback(ref Key key, ref Value value, long ctx) { } @@ -72,17 +72,23 @@ public void CheckpointCompletionCallback(string sessionId, CommitPoint commitPoi // Read functions [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void SingleReader(ref Key key, ref Input input, ref Value value, ref Output dst) => dst.value = value; + public void SingleReader(ref Key key, ref Input input, ref Value value, ref Output dst, long address) + { + dst.value = value; + } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void ConcurrentReader(ref Key key, ref Input input, ref Value value, ref Output dst) => dst.value = value; + public void ConcurrentReader(ref Key key, ref Input input, ref Value value, ref Output dst, ref RecordInfo recordInfo, long address) + { + dst.value = value; + } // Upsert functions [MethodImpl(MethodImplOptions.AggressiveInlining)] public void SingleWriter(ref Key key, ref 
Value src, ref Value dst) => dst = src; [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool ConcurrentWriter(ref Key key, ref Value src, ref Value dst) + public bool ConcurrentWriter(ref Key key, ref Value src, ref Value dst, ref RecordInfo recordInfo, long address) { dst = src; return true; @@ -97,7 +103,7 @@ public void InitialUpdater(ref Key key, ref Input input, ref Value value, ref Ou } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool InPlaceUpdater(ref Key key, ref Input input, ref Value value, ref Output output) + public bool InPlaceUpdater(ref Key key, ref Input input, ref Value value, ref Output output, ref RecordInfo recordInfo, long address) { Interlocked.Add(ref value.value, input.value); output.value = value; @@ -114,5 +120,7 @@ public void CopyUpdater(ref Key key, ref Input input, ref Value oldValue, ref Va public void Lock(ref RecordInfo recordInfo, ref Key key, ref Value value, LockType lockType, ref long lockContext) { } public bool Unlock(ref RecordInfo recordInfo, ref Key key, ref Value value, LockType lockType, long lockContext) => true; + + public void ConcurrentDeleter(ref Key key, ref Value value, ref RecordInfo recordInfo, long address) { } } } diff --git a/cs/remote/samples/VarLenClient/CustomTypeFunctions.cs b/cs/remote/samples/VarLenClient/CustomTypeFunctions.cs index c752824ab..2041f831b 100644 --- a/cs/remote/samples/VarLenClient/CustomTypeFunctions.cs +++ b/cs/remote/samples/VarLenClient/CustomTypeFunctions.cs @@ -28,5 +28,13 @@ public override void ReadCompletionCallback(ref CustomType key, ref CustomType i throw new Exception("Unexpected user context"); } } + + public override void SubscribeKVCallback(ref CustomType key, ref CustomType input, ref CustomType output, byte ctx, Status status) + { + } + + public override void SubscribeCallback(ref CustomType key, ref CustomType value, byte ctx) + { + } } } diff --git a/cs/remote/samples/VarLenClient/CustomTypeSamples.cs 
b/cs/remote/samples/VarLenClient/CustomTypeSamples.cs index c56835b95..fb6d431ef 100644 --- a/cs/remote/samples/VarLenClient/CustomTypeSamples.cs +++ b/cs/remote/samples/VarLenClient/CustomTypeSamples.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Threading; using System.Threading.Tasks; using FASTER.client; using FASTER.common; @@ -22,14 +23,18 @@ public void Run(string ip, int port) // Create a session to FasterKV server // Sessions are mono-threaded, similar to normal FasterKV sessions - using var session = client.NewSession(new CustomTypeFunctions(), WireFormat.DefaultVarLenKV); - - // Explicit version of NewSession call, where you provide all types, callback functions, and serializer - // using var session = client.NewSession<long, long, long, Functions, BlittableParameterSerializer<long, long, long, long>>(new Functions(), new BlittableParameterSerializer<long, long, long, long>()); + var session = client.NewSession(new CustomTypeFunctions(), WireFormat.DefaultVarLenKV); + var subSession = client.NewSession(new CustomTypeFunctions(), WireFormat.DefaultVarLenKV); // Samples using sync client API SyncVarLenSamples(session); + // Samples using sync client API + SyncVarLenSubscriptionKVSamples(session, subSession); + + // Samples using sync client API + SyncVarLenSubscriptionSamples(session, subSession); + // Samples using async client API AsyncVarLenSamples(session).Wait(); } @@ -53,6 +58,54 @@ void SyncVarLenSamples(ClientSession<CustomType, CustomType, CustomType, CustomT session.CompletePending(true); } + void SyncVarLenSubscriptionKVSamples(ClientSession<CustomType, CustomType, CustomType, CustomType, byte, CustomTypeFunctions, FixedLenSerializer<CustomType, CustomType, CustomType, CustomType>> session, + ClientSession<CustomType, CustomType, CustomType, CustomType, byte, CustomTypeFunctions, FixedLenSerializer<CustomType, CustomType, CustomType, CustomType>> session2) + { + session2.SubscribeKV(new CustomType(23)); + 
session2.CompletePending(true); + + session2.SubscribeKV(new CustomType(24)); + session2.CompletePending(true); + + session2.PSubscribeKV(new CustomType(25)); + session2.CompletePending(true); + + session.Upsert(new CustomType(23), new CustomType(2300)); + session.CompletePending(true); + + session.Upsert(new CustomType(24), new CustomType(2400)); + session.CompletePending(true); + + session.Upsert(new CustomType(25), new CustomType(2500)); + session.CompletePending(true); + + System.Threading.Thread.Sleep(1000); + } + + void SyncVarLenSubscriptionSamples(ClientSession<CustomType, CustomType, CustomType, CustomType, byte, CustomTypeFunctions, FixedLenSerializer<CustomType, CustomType, CustomType, CustomType>> session, + ClientSession<CustomType, CustomType, CustomType, CustomType, byte, CustomTypeFunctions, FixedLenSerializer<CustomType, CustomType, CustomType, CustomType>> session2) + { + session2.Subscribe(new CustomType(23)); + session2.CompletePending(true); + + session2.Subscribe(new CustomType(24)); + session2.CompletePending(true); + + session2.PSubscribe(new CustomType(25)); + session2.CompletePending(true); + + session.Publish(new CustomType(23), new CustomType(2300)); + session.CompletePending(true); + + session.Publish(new CustomType(24), new CustomType(2400)); + session.CompletePending(true); + + session.Publish(new CustomType(25), new CustomType(2500)); + session.CompletePending(true); + + System.Threading.Thread.Sleep(1000); + } + async Task AsyncVarLenSamples(ClientSession<CustomType, CustomType, CustomType, CustomType, byte, CustomTypeFunctions, FixedLenSerializer<CustomType, CustomType, CustomType, CustomType>> session) { // By default, we flush async operations as soon as they are issued diff --git a/cs/remote/samples/VarLenClient/Program.cs b/cs/remote/samples/VarLenClient/Program.cs index 203767c93..a48406f02 100644 --- a/cs/remote/samples/VarLenClient/Program.cs +++ b/cs/remote/samples/VarLenClient/Program.cs @@ -13,6 +13,7 @@ class Program { 
static void Main(string[] args) { + Environment.SetEnvironmentVariable("DOTNET_SYSTEM_NET_SOCKETS_INLINE_COMPLETIONS", "1"); string ip = "127.0.0.1"; int port = 3278; @@ -20,7 +21,7 @@ static void Main(string[] args) ip = args[0]; if (args.Length > 1 && args[1] != "-") port = int.Parse(args[1]); - + new MemoryBenchmark().Run(ip, port); new MemorySamples().Run(ip, port); new CustomTypeSamples().Run(ip, port); diff --git a/cs/remote/samples/VarLenServer/Options.cs b/cs/remote/samples/VarLenServer/Options.cs index 854c6e1e7..88f08e81c 100644 --- a/cs/remote/samples/VarLenServer/Options.cs +++ b/cs/remote/samples/VarLenServer/Options.cs @@ -2,10 +2,9 @@ // Licensed under the MIT license. using CommandLine; -using System; -using FASTER.core; +using FASTER.server; -namespace ServerOptions +namespace FasterServerOptions { class Options { @@ -36,122 +35,24 @@ class Options [Option('r', "recover", Required = false, Default = false, HelpText = "Recover from latest checkpoint.")] public bool Recover { get; set; } - public int MemorySizeBits() - { - long size = ParseSize(MemorySize); - int bits = (int)Math.Floor(Math.Log2(size)); - Console.WriteLine($"Using log memory size of {PrettySize((long)Math.Pow(2, bits))}"); - if (size != Math.Pow(2, bits)) - Console.WriteLine($"Warning: using lower log memory size than specified (power of 2)"); - return bits; - } - - public int PageSizeBits() - { - long size = ParseSize(PageSize); - int bits = (int)Math.Ceiling(Math.Log2(size)); - Console.WriteLine($"Using page size of {PrettySize((long)Math.Pow(2, bits))}"); - if (size != Math.Pow(2, bits)) - Console.WriteLine($"Warning: using lower page size than specified (power of 2)"); - return bits; - } - - public int SegmentSizeBits() - { - long size = ParseSize(SegmentSize); - int bits = (int)Math.Ceiling(Math.Log2(size)); - Console.WriteLine($"Using disk segment size of {PrettySize((long)Math.Pow(2, bits))}"); - if (size != Math.Pow(2, bits)) - Console.WriteLine($"Warning: using lower disk 
segment size than specified (power of 2)"); - return bits; - } - - public int IndexSizeCachelines() - { - long size = ParseSize(IndexSize); - int bits = (int)Math.Ceiling(Math.Log2(size)); - long adjustedSize = 1L << bits; - if (adjustedSize < 64) throw new Exception("Invalid index size"); - Console.WriteLine($"Using hash index size of {PrettySize(adjustedSize)} ({PrettySize(adjustedSize/64)} cache lines)"); - if (size != adjustedSize) - Console.WriteLine($"Warning: using lower hash index size than specified (power of 2)"); - return 1 << (bits - 6); - } + [Option("pubsub", Required = false, Default = true, HelpText = "Enable pub/sub feature on server.")] + public bool EnablePubSub { get; set; } - public void GetSettings(out LogSettings logSettings, out CheckpointSettings checkpointSettings, out int indexSize) + public ServerOptions GetServerOptions() { - logSettings = new LogSettings { PreallocateLog = false }; - - logSettings.PageSizeBits = PageSizeBits(); - logSettings.MemorySizeBits = MemorySizeBits(); - Console.WriteLine($"There are {PrettySize(1 << (logSettings.MemorySizeBits - logSettings.PageSizeBits))} log pages in memory"); - logSettings.SegmentSizeBits = SegmentSizeBits(); - indexSize = IndexSizeCachelines(); - - var device = LogDir == "" ? new NullDevice() : Devices.CreateLogDevice(LogDir + "/hlog", preallocateFile: false); - logSettings.LogDevice = device; - - if (CheckpointDir == null && LogDir == null) - checkpointSettings = null; - else - checkpointSettings = new CheckpointSettings { - CheckPointType = CheckpointType.FoldOver, - CheckpointDir = CheckpointDir ?? 
(LogDir + "/checkpoints") - }; - } - - static long ParseSize(string value) - { - char[] suffix = new char[] { 'k', 'm', 'g', 't', 'p' }; - long result = 0; - foreach (char c in value) - { - if (char.IsDigit(c)) - { - result = result * 10 + (byte)c - '0'; - } - else - { - for (int i = 0; i < suffix.Length; i++) - { - if (char.ToLower(c) == suffix[i]) - { - result *= (long)Math.Pow(1024, i + 1); - return result; - } - } - } - } - return result; - } - - public static string PrettySize(long value) - { - char[] suffix = new char[] { 'k', 'm', 'g', 't', 'p' }; - double v = value; - int exp = 0; - while (v - Math.Floor(v) > 0) - { - if (exp >= 18) - break; - exp += 3; - v *= 1024; - v = Math.Round(v, 12); - } - - while (Math.Floor(v).ToString().Length > 3) + return new ServerOptions { - if (exp <= -18) - break; - exp -= 3; - v /= 1024; - v = Math.Round(v, 12); - } - if (exp > 0) - return v.ToString() + suffix[exp / 3 - 1]; - else if (exp < 0) - return v.ToString() + suffix[-exp / 3 - 1]; - return v.ToString(); + Port = Port, + Address = Address, + MemorySize = MemorySize, + PageSize = PageSize, + IndexSize = IndexSize, + SegmentSize = SegmentSize, + LogDir = LogDir, + CheckpointDir = CheckpointDir, + Recover = Recover, + EnablePubSub = EnablePubSub, + }; } } } \ No newline at end of file diff --git a/cs/remote/samples/VarLenServer/Program.cs b/cs/remote/samples/VarLenServer/Program.cs index 1dade6ebe..311fe9dde 100644 --- a/cs/remote/samples/VarLenServer/Program.cs +++ b/cs/remote/samples/VarLenServer/Program.cs @@ -4,50 +4,28 @@ using System; using System.Threading; using CommandLine; -using ServerOptions; -using FASTER.core; +using FasterServerOptions; using FASTER.server; -using FASTER.common; +using System.Diagnostics; -namespace VarLenServer +namespace FasterVarLenServer { /// <summary> - /// Server for variable-length keys and values. + /// Sample server for variable-length keys and values. 
/// </summary> class Program { static void Main(string[] args) { - VarLenServer(args); - } - + Trace.Listeners.Add(new ConsoleTraceListener()); - static void VarLenServer(string[] args) - { Console.WriteLine("FASTER variable-length KV server"); ParserResult<Options> result = Parser.Default.ParseArguments<Options>(args); if (result.Tag == ParserResultType.NotParsed) return; var opts = result.MapResult(o => o, xs => new Options()); - - opts.GetSettings(out var logSettings, out var checkpointSettings, out var indexSize); - - // Create a new instance of the FasterKV, customized for variable-length blittable data (represented by SpanByte) - // With SpanByte, keys and values are stored inline in the FASTER log as [ 4 byte length | payload ] - var store = new FasterKV<SpanByte, SpanByte>(indexSize, logSettings, checkpointSettings); - if (opts.Recover) store.Recover(); - - // This variable-length session provider can be used with compatible clients such as VarLenClient - var provider = new SpanByteFasterKVProvider(store); - - // Create server - var server = new FasterServer(opts.Address, opts.Port); - - // Register provider as backend provider for WireFormat.DefaultFixedLenKV - // You can register multiple providers with the same server, with different wire protocol specifications - server.Register(WireFormat.DefaultVarLenKV, provider); - - // Start server + + using var server = new VarLenServer(opts.GetServerOptions()); server.Start(); Console.WriteLine("Started server"); diff --git a/cs/remote/samples/WebClient/FASTERFunctions.js b/cs/remote/samples/WebClient/FASTERFunctions.js new file mode 100644 index 000000000..6ee2501a3 --- /dev/null +++ b/cs/remote/samples/WebClient/FASTERFunctions.js @@ -0,0 +1,47 @@ +// JavaScript source code + +function writeToScreen(message) { + output.insertAdjacentHTML("afterbegin", "<p>" + message + "</p>"); +} + +class FASTERFunctions extends CallbackFunctionsBase { + + constructor(client) { + super(); + } + + 
ReadCompletionCallback(keyBytes, outputBytes, status) { + if (status == Status.OK) { + var output = deserialize(outputBytes, 0, outputBytes.length); + writeToScreen("<span> value: " + output + " </span>"); + } + } + + UpsertCompletionCallback(keyBytes, valueBytes, status) { + if (status == Status.OK) { + writeToScreen("<span> PUT OK </span>"); + } + } + + DeleteCompletionCallback(keyBytes, status) { } + + RMWCompletionCallback(keyBytes, outputBytes, status) { } + + SubscribeKVCompletionCallback(keyBytes, outputBytes, status) + { + if (status == Status.OK) { + var key = deserialize(keyBytes, 0, keyBytes.length); + var output = deserialize(outputBytes, 0, outputBytes.length); + writeToScreen("<span> subscribed key: " + key + " value: " + output + " </span>"); + } + } + + SubscribeCompletionCallback(keyBytes, valueBytes, status) + { + if (status == Status.OK) { + var key = deserialize(keyBytes, 0, keyBytes.length); + var value = deserialize(valueBytes, 0, valueBytes.length); + writeToScreen("<span> subscribed key: " + key + " value: " + value + " </span>"); + } + } +} \ No newline at end of file diff --git a/cs/remote/samples/WebClient/WebClient.html b/cs/remote/samples/WebClient/WebClient.html new file mode 100644 index 000000000..ab27f4adb --- /dev/null +++ b/cs/remote/samples/WebClient/WebClient.html @@ -0,0 +1,124 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="utf-8"> + <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"> + <title>FASTER webclient example</title> + <link rel="stylesheet" href="https://cdn.bootcss.com/bootstrap/3.3.0/css/bootstrap.min.css"> +</head> + +<style> + textarea { + vertical-align: bottom; + } + + #output { + overflow: auto; + } + + #output > p { + overflow-wrap: break-word; + } + + #output span { + color: blue; + } + + #output span.error { + color: red; + } +</style> + +<body> + <h2>FASTER webclient test</h2> + + <label><b>Key :</b> </label> <input type="text" id="putKey" /> + 
<label><b>Value :</b> </label> <input type="text" id="putValue" /> + <p> Click <button class="btn btn-xs btn-primary" id="putButton">Put</button> button to Upsert key and value</p><br> + <br> + <label><b>Key :</b> </label> <input type="text" id="getKey" /> + <p> Click <button class="btn btn-xs btn-primary" id="getButton">Get</button> button to Get key and value</p><br> + <br> + <label><b>Key :</b> </label> <input type="text" id="subscribeKey" /> + <p> Click <button class="btn btn-xs btn-primary" id="subButton">Subscribe</button> button to Subscribe to a key</p><br> + <br> + + <div id=output></div> + + <script src="..\..\src\FASTER.client\JavascriptClient\Utils.js"></script> + <script src="..\..\src\FASTER.client\JavascriptClient\ParameterSerializer.js"></script> + <script src="..\..\src\FASTER.client\JavascriptClient\Queue.js"></script> + <script src="..\..\src\FASTER.client\JavascriptClient\CallbackFunctionsBase.js"></script> + <script src="..\..\src\FASTER.client\JavascriptClient\BatchHeader.js"></script> + <script src="..\..\src\FASTER.client\JavascriptClient\ClientSession.js"></script> + <script src="..\..\src\FASTER.client\JavascriptClient\ClientNetworkSession.js"></script> + + <script> + function serialize(str) { + var index = 0; + var byteArr = new ArrayBuffer(str.length); + const byteView = new Uint8Array(byteArr); + for (index = 0; index < str.length; index++) { + var code = str.charCodeAt(index); + byteView.set([code], index); + } + return byteView; + } + + function deserialize(byteArr, startIdx, lenString) { + var strByteArr = []; + const byteView = new Uint8Array(byteArr); + for (var index = 0; index < lenString; index++) { + strByteArr[index] = byteView[startIdx + index]; + } + + var result = ""; + for (var i = 0; i < lenString; i++) { + result += String.fromCharCode(parseInt(strByteArr[i], 10)); + } + + return result; + } + + </script> + + <script src="FASTERFunctions.js"></script> + + <script> + // http://www.websocket.org/echo.html + var buttonPut = 
document.querySelector("#putButton"), + buttonGet = document.querySelector("#getButton"), + buttonSub = document.querySelector("#subButton"); + + buttonPut.addEventListener("click", onClickButtonPut); + buttonGet.addEventListener("click", onClickButtonGet); + buttonSub.addEventListener("click", onClickButtonSub); + + var address = "127.0.0.1"; + var port = 3278; + var remoteSession = new ClientSession(address, port, new FASTERFunctions(this)); + + function onClickButtonPut() { + var keyBytes = serialize(putKey.value); + var valBytes = serialize(putValue.value); + + remoteSession.Upsert(keyBytes, valBytes); + remoteSession.CompletePending(true); + } + + function onClickButtonGet() { + var keyBytes = serialize(getKey.value); + + remoteSession.Read(keyBytes); + remoteSession.CompletePending(true); + } + + function onClickButtonSub() { + var keyBytes = serialize(subscribeKey.value); + + remoteSession.SubscribeKV(keyBytes); + remoteSession.CompletePending(true); + } + </script> +</body> +</html> diff --git a/cs/remote/src/FASTER.client/CallbackFunctionsBase.cs b/cs/remote/src/FASTER.client/CallbackFunctionsBase.cs index c873d7ef4..24db138ed 100644 --- a/cs/remote/src/FASTER.client/CallbackFunctionsBase.cs +++ b/cs/remote/src/FASTER.client/CallbackFunctionsBase.cs @@ -23,5 +23,11 @@ public virtual void ReadCompletionCallback(ref Key key, ref Input input, ref Out public virtual void RMWCompletionCallback(ref Key key, ref Input input, ref Output output, Context ctx, Status status) { } /// <inheritdoc/> public virtual void UpsertCompletionCallback(ref Key key, ref Value value, Context ctx) { } + /// <inheritdoc/> + public virtual void SubscribeKVCallback(ref Key key, ref Input input, ref Output output, Context ctx, Status status) { } + /// <inheritdoc/> + public virtual void PublishCompletionCallback(ref Key key, ref Value value, Context ctx) { } + /// <inheritdoc/> + public virtual void SubscribeCallback(ref Key key, ref Value value, Context ctx) { } } } \ No newline at 
end of file diff --git a/cs/remote/src/FASTER.client/ClientSession.cs b/cs/remote/src/FASTER.client/ClientSession.cs index ab2a029e0..bcca380d5 100644 --- a/cs/remote/src/FASTER.client/ClientSession.cs +++ b/cs/remote/src/FASTER.client/ClientSession.cs @@ -35,15 +35,17 @@ public unsafe sealed partial class ClientSession<Key, Value, Input, Output, Cont readonly int bufferSize; readonly WireFormat wireFormat; readonly MaxSizeSettings maxSizeSettings; + private bool subscriptionSession; bool disposed; - ReusableObject<SeaaBuffer> sendObject; + SeaaBuffer sendObject; byte* offset; int numMessages; int numPendingBatches; readonly ElasticCircularBuffer<(Key, Value, Context)> upsertQueue; readonly ElasticCircularBuffer<(Key, Input, Output, Context)> readrmwQueue; + readonly ElasticCircularBuffer<(Key, Value, Context)> pubsubQueue; readonly ElasticCircularBuffer<TaskCompletionSource<(Status, Output)>> tcsQueue; /// <summary> @@ -64,14 +66,16 @@ public ClientSession(string address, int port, Functions functions, WireFormat w this.bufferSize = BufferSizeUtils.ClientBufferSize(this.maxSizeSettings); this.messageManager = new NetworkSender(bufferSize); this.disposed = false; + this.subscriptionSession = false; upsertQueue = new ElasticCircularBuffer<(Key, Value, Context)>(); readrmwQueue = new ElasticCircularBuffer<(Key, Input, Output, Context)>(); + pubsubQueue = new ElasticCircularBuffer<(Key, Value, Context)>(); tcsQueue = new ElasticCircularBuffer<TaskCompletionSource<(Status, Output)>>(); numPendingBatches = 0; sendObject = messageManager.GetReusableSeaaBuffer(); - offset = sendObject.obj.bufferPtr + sizeof(int) + BatchHeader.Size; + offset = sendObject.bufferPtr + sizeof(int) + BatchHeader.Size; numMessages = 0; sendSocket = GetSendSocket(address, port); } @@ -214,20 +218,75 @@ public Status Delete(ref Key key, Context userContext = default, long serialNo = public Status Delete(Key key, Context userContext = default, long serialNo = 0) => 
InternalDelete(MessageType.Delete, ref key, userContext, serialNo); + /// <summary> + /// SubscribeKV operation + /// </summary> + /// <param name="key">Key</param> + /// <param name="input">Input</param> + /// <param name="userContext">User context</param> + /// <param name="serialNo">Serial number</param> + /// <returns>Status of operation</returns> + public void SubscribeKV(Key key, Input input = default, Context userContext = default, long serialNo = 0) + => InternalSubscribeKV(MessageType.SubscribeKV, ref key, ref input, userContext, serialNo); + + /// <summary> + /// PSubscribeKV operation + /// </summary> + /// <param name="prefix">Key</param> + /// <param name="input">Input</param> + /// <param name="userContext">User context</param> + /// <param name="serialNo">Serial number</param> + /// <returns>Status of operation</returns> + public void PSubscribeKV(Key prefix, Input input = default, Context userContext = default, long serialNo = 0) + => InternalSubscribeKV(MessageType.PSubscribeKV, ref prefix, ref input, userContext, serialNo); + + /// <summary> + /// Publish operation + /// </summary> + /// <param name="key">Key</param> + /// <param name="desiredValue">Desired value</param> + /// <param name="userContext">User context</param> + /// <param name="serialNo">Serial number</param> + /// <returns>Status of operation</returns> + public Status Publish(Key key, Value desiredValue, Context userContext = default, long serialNo = 0) + => InternalPublish(MessageType.Publish, ref key, ref desiredValue, userContext, serialNo); + + /// <summary> + /// Subscribe operation + /// </summary> + /// <param name="key">Key</param> + /// <param name="input">Input</param> + /// <param name="userContext">User context</param> + /// <param name="serialNo">Serial number</param> + /// <returns>Status of operation</returns> + public void Subscribe(Key key, Context userContext = default, long serialNo = 0) + => InternalSubscribe(MessageType.Subscribe, ref key, userContext, 
serialNo); + + /// <summary> + /// PSubscribe operation + /// </summary> + /// <param name="prefix">Key</param> + /// <param name="input">Input</param> + /// <param name="userContext">User context</param> + /// <param name="serialNo">Serial number</param> + /// <returns>Status of operation</returns> + public void PSubscribe(Key prefix, Context userContext = default, long serialNo = 0) + => InternalSubscribe(MessageType.PSubscribe, ref prefix, userContext, serialNo); + /// <summary> /// Flush current buffer of outgoing messages. Does not wait for responses. /// </summary> public void Flush() { - if (offset > sendObject.obj.bufferPtr + sizeof(int) + BatchHeader.Size) + if (offset > sendObject.bufferPtr + sizeof(int) + BatchHeader.Size) { - int payloadSize = (int)(offset - sendObject.obj.bufferPtr); + int payloadSize = (int)(offset - sendObject.bufferPtr); - ((BatchHeader*)(sendObject.obj.bufferPtr + sizeof(int)))->SetNumMessagesProtocol(numMessages, wireFormat); + ((BatchHeader*)(sendObject.bufferPtr + sizeof(int)))->SetNumMessagesProtocol(numMessages, wireFormat); Interlocked.Increment(ref numPendingBatches); // Set packet size in header - *(int*)sendObject.obj.bufferPtr = -(payloadSize - sizeof(int)); + *(int*)sendObject.bufferPtr = -(payloadSize - sizeof(int)); try { @@ -239,7 +298,7 @@ public void Flush() throw; } sendObject = messageManager.GetReusableSeaaBuffer(); - offset = sendObject.obj.bufferPtr + sizeof(int) + BatchHeader.Size; + offset = sendObject.bufferPtr + sizeof(int) + BatchHeader.Size; numMessages = 0; } } @@ -263,7 +322,8 @@ public void CompletePending(bool wait = true) public void Dispose() { disposed = true; - sendObject.Dispose(); + if (sendObject != null) + messageManager.Return(sendObject); sendSocket.Dispose(); messageManager.Dispose(); } @@ -281,6 +341,7 @@ public void Dispose(bool completePending) int lastSeqNo = -1; + readonly Dictionary<int, (Key, Value, Context)> pubsubPendingContext = new(); readonly Dictionary<int, (Key, Input, Output, 
Context)> readRmwPendingContext = new(); readonly Dictionary<int, TaskCompletionSource<(Status, Output)>> readRmwPendingTcs = new(); @@ -292,7 +353,7 @@ internal void ProcessReplies(byte[] buf, int offset) var src = b; var seqNo = ((BatchHeader*)src)->SeqNo; var count = ((BatchHeader*)src)->NumMessages; - if (seqNo != lastSeqNo + 1) + if (seqNo != lastSeqNo + 1 && !subscriptionSession) throw new Exception("Out of order message within session"); lastSeqNo = seqNo; @@ -401,6 +462,110 @@ internal void ProcessReplies(byte[] buf, int offset) tcs.SetResult((status, default)); break; } + case MessageType.SubscribeKV: + { + var status = ReadStatus(ref src); + var p = hrw.ReadPendingSeqNo(ref src); + if (status == Status.OK) + { + readRmwPendingContext.TryGetValue(p, out var result); + result.Item3 = serializer.ReadOutput(ref src); + functions.SubscribeKVCallback(ref result.Item1, ref result.Item2, ref result.Item3, result.Item4, Status.OK); + } + else if (status == Status.NOTFOUND) + { + readRmwPendingContext.TryGetValue(p, out var result); + functions.SubscribeKVCallback(ref result.Item1, ref result.Item2, ref defaultOutput, result.Item4, Status.NOTFOUND); + } + else if (status == Status.PENDING) + { + var result = readrmwQueue.Dequeue(); + readRmwPendingContext.Add(p, result); + } + else + { + throw new Exception("Unexpected status of SubscribeKV"); + } + break; + } + case MessageType.PSubscribeKV: + { + var status = ReadStatus(ref src); + var p = hrw.ReadPendingSeqNo(ref src); + if (status == Status.OK) + { + readRmwPendingContext.TryGetValue(p, out var result); + result.Item1 = serializer.ReadKey(ref src); + result.Item3 = serializer.ReadOutput(ref src); + functions.SubscribeKVCallback(ref result.Item1, ref result.Item2, ref result.Item3, result.Item4, Status.OK); + } + else if (status == Status.NOTFOUND) + { + readRmwPendingContext.TryGetValue(p, out var result); + result.Item1 = serializer.ReadKey(ref src); + functions.SubscribeKVCallback(ref result.Item1, ref 
result.Item2, ref defaultOutput, result.Item4, Status.NOTFOUND); + } + else if (status == Status.PENDING) + { + var result = readrmwQueue.Dequeue(); + readRmwPendingContext.Add(p, result); + } + else + { + throw new Exception("Unexpected status of SubscribeKV"); + } + break; + } + case MessageType.Publish: + { + var status = ReadStatus(ref src); + (Key, Value, Context) result = upsertQueue.Dequeue(); + functions.PublishCompletionCallback(ref result.Item1, ref result.Item2, result.Item3); + break; + } + case MessageType.Subscribe: + { + var status = ReadStatus(ref src); + var p = hrw.ReadPendingSeqNo(ref src); + if (status == Status.OK) + { + pubsubPendingContext.TryGetValue(p, out var result); + result.Item2 = serializer.ReadValue(ref src); + functions.SubscribeCallback(ref result.Item1, ref result.Item2, result.Item3); + } + else if (status == Status.PENDING) + { + var result = pubsubQueue.Dequeue(); + pubsubPendingContext.Add(p, result); + } + else + { + throw new Exception("Unexpected status of SubscribeKV"); + } + break; + } + case MessageType.PSubscribe: + { + var status = ReadStatus(ref src); + var p = hrw.ReadPendingSeqNo(ref src); + if (status == Status.OK) + { + pubsubPendingContext.TryGetValue(p, out var result); + result.Item1 = serializer.ReadKey(ref src); + result.Item2 = serializer.ReadValue(ref src); + functions.SubscribeCallback(ref result.Item1, ref result.Item2, result.Item3); + } + else if (status == Status.PENDING) + { + var result = pubsubQueue.Dequeue(); + pubsubPendingContext.Add(p, result); + } + else + { + throw new Exception("Unexpected status of SubscribeKV"); + } + break; + } case MessageType.PendingResult: { HandlePending(ref src); @@ -491,6 +656,24 @@ private void HandlePending(ref byte* src) result.SetResult((status, default)); break; } + case MessageType.SubscribeKV: + { + var status = ReadStatus(ref src); + if (!readRmwPendingContext.TryGetValue(p, out var result)) + { + Debug.WriteLine("Received unexpected subsription key"); + 
break; + } + + if (status == Status.OK) + { + result.Item3 = serializer.ReadOutput(ref src); + functions.ReadCompletionCallback(ref result.Item1, ref result.Item2, ref result.Item3, result.Item4, status); + } + else + functions.ReadCompletionCallback(ref result.Item1, ref result.Item2, ref defaultOutput, result.Item4, status); + break; + } default: { throw new NotImplementedException(); @@ -539,17 +722,88 @@ private Socket GetSendSocket(string address, int port, int millisecondsTimeout = [MethodImpl(MethodImplOptions.AggressiveInlining)] private unsafe Status InternalRead(MessageType messageType, ref Key key, ref Input input, ref Output output, Context userContext = default, long serialNo = 0) { + Debug.Assert(!subscriptionSession); + while (true) { - byte* end = sendObject.obj.bufferPtr + bufferSize; + byte* end = sendObject.bufferPtr + bufferSize; byte* curr = offset; if (hrw.Write(messageType, ref curr, (int)(end - curr))) - if (serializer.Write(ref key, ref curr, (int)(end - curr))) - if (serializer.Write(ref input, ref curr, (int)(end - curr))) + if (hrw.Write(serialNo, ref curr, (int)(end - curr))) + if (serializer.Write(ref key, ref curr, (int)(end - curr))) + if (serializer.Write(ref input, ref curr, (int)(end - curr))) + { + numMessages++; + offset = curr; + readrmwQueue.Enqueue((key, input, output, userContext)); + return Status.PENDING; + } + Flush(); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe Status InternalSubscribeKV(MessageType messageType, ref Key key, ref Input input, Context userContext = default, long serialNo = 0) + { + subscriptionSession = true; + + while (true) + { + byte* end = sendObject.bufferPtr + bufferSize; + byte* curr = offset; + if (hrw.Write(messageType, ref curr, (int)(end - curr))) + if (hrw.Write(serialNo, ref curr, (int)(end - curr))) + if (serializer.Write(ref key, ref curr, (int)(end - curr))) + if (serializer.Write(ref input, ref curr, (int)(end - curr))) + { + numMessages++; + offset = 
curr; + readrmwQueue.Enqueue((key, input, default, userContext)); + return Status.PENDING; + } + Flush(); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe Status InternalPublish(MessageType messageType, ref Key key, ref Value desiredValue, Context userContext = default, long serialNo = 0) + { + Debug.Assert(!subscriptionSession); + + while (true) + { + byte* end = sendObject.bufferPtr + bufferSize; + byte* curr = offset; + if (hrw.Write(messageType, ref curr, (int)(end - curr))) + if (hrw.Write(serialNo, ref curr, (int)(end - curr))) + if (serializer.Write(ref key, ref curr, (int)(end - curr))) + if (serializer.Write(ref desiredValue, ref curr, (int)(end - curr))) + { + numMessages++; + offset = curr; + upsertQueue.Enqueue((key, desiredValue, userContext)); + return Status.PENDING; + } + Flush(); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe Status InternalSubscribe(MessageType messageType, ref Key key, Context userContext = default, long serialNo = 0) + { + subscriptionSession = true; + + while (true) + { + byte* end = sendObject.bufferPtr + bufferSize; + byte* curr = offset; + if (hrw.Write(messageType, ref curr, (int)(end - curr))) + if (hrw.Write(serialNo, ref curr, (int)(end - curr))) + if (serializer.Write(ref key, ref curr, (int)(end - curr))) { numMessages++; offset = curr; - readrmwQueue.Enqueue((key, input, output, userContext)); + pubsubQueue.Enqueue((key, default, userContext)); return Status.PENDING; } Flush(); @@ -559,19 +813,22 @@ private unsafe Status InternalRead(MessageType messageType, ref Key key, ref Inp [MethodImpl(MethodImplOptions.AggressiveInlining)] private unsafe Status InternalUpsert(MessageType messageType, ref Key key, ref Value desiredValue, Context userContext = default, long serialNo = 0) { + Debug.Assert(!subscriptionSession); + while (true) { - byte* end = sendObject.obj.bufferPtr + bufferSize; + byte* end = sendObject.bufferPtr + bufferSize; byte* curr = offset; if 
(hrw.Write(messageType, ref curr, (int)(end - curr))) - if (serializer.Write(ref key, ref curr, (int)(end - curr))) - if (serializer.Write(ref desiredValue, ref curr, (int)(end - curr))) - { - numMessages++; - offset = curr; - upsertQueue.Enqueue((key, desiredValue, userContext)); - return Status.PENDING; - } + if (hrw.Write(serialNo, ref curr, (int)(end - curr))) + if (serializer.Write(ref key, ref curr, (int)(end - curr))) + if (serializer.Write(ref desiredValue, ref curr, (int)(end - curr))) + { + numMessages++; + offset = curr; + upsertQueue.Enqueue((key, desiredValue, userContext)); + return Status.PENDING; + } Flush(); } } @@ -579,19 +836,22 @@ private unsafe Status InternalUpsert(MessageType messageType, ref Key key, ref V [MethodImpl(MethodImplOptions.AggressiveInlining)] private unsafe Status InternalRMW(MessageType messageType, ref Key key, ref Input input, ref Output output, Context userContext = default, long serialNo = 0) { + Debug.Assert(!subscriptionSession); + while (true) { - byte* end = sendObject.obj.bufferPtr + bufferSize; + byte* end = sendObject.bufferPtr + bufferSize; byte* curr = offset; if (hrw.Write(messageType, ref curr, (int)(end - curr))) - if (serializer.Write(ref key, ref curr, (int)(end - curr))) - if (serializer.Write(ref input, ref curr, (int)(end - curr))) - { - numMessages++; - offset = curr; - readrmwQueue.Enqueue((key, input, output, userContext)); - return Status.PENDING; - } + if (hrw.Write(serialNo, ref curr, (int)(end - curr))) + if (serializer.Write(ref key, ref curr, (int)(end - curr))) + if (serializer.Write(ref input, ref curr, (int)(end - curr))) + { + numMessages++; + offset = curr; + readrmwQueue.Enqueue((key, input, output, userContext)); + return Status.PENDING; + } Flush(); } } @@ -599,18 +859,21 @@ private unsafe Status InternalRMW(MessageType messageType, ref Key key, ref Inpu [MethodImpl(MethodImplOptions.AggressiveInlining)] private unsafe Status InternalDelete(MessageType messageType, ref Key key, Context 
userContext = default, long serialNo = 0) { + Debug.Assert(!subscriptionSession); + while (true) { - byte* end = sendObject.obj.bufferPtr + bufferSize; + byte* end = sendObject.bufferPtr + bufferSize; byte* curr = offset; if (hrw.Write(messageType, ref curr, (int)(end - curr))) - if (serializer.Write(ref key, ref curr, (int)(end - curr))) - { - numMessages++; - offset = curr; - upsertQueue.Enqueue((key, default, userContext)); - return Status.PENDING; - } + if (hrw.Write(serialNo, ref curr, (int)(end - curr))) + if (serializer.Write(ref key, ref curr, (int)(end - curr))) + { + numMessages++; + offset = curr; + upsertQueue.Enqueue((key, default, userContext)); + return Status.PENDING; + } Flush(); } } diff --git a/cs/remote/src/FASTER.client/FixedLenSerializer.cs b/cs/remote/src/FASTER.client/FixedLenSerializer.cs index 5abbe6be5..a83b69344 100644 --- a/cs/remote/src/FASTER.client/FixedLenSerializer.cs +++ b/cs/remote/src/FASTER.client/FixedLenSerializer.cs @@ -54,6 +54,22 @@ public bool Write(ref Input i, ref byte* dst, int length) return true; } + /// <inheritdoc /> + public Key ReadKey(ref byte* src) + { + var _src = src; + src += Unsafe.SizeOf<Key>(); + return Unsafe.AsRef<Key>(_src); + } + + /// <inheritdoc /> + public Value ReadValue(ref byte* src) + { + var _src = src; + src += Unsafe.SizeOf<Value>(); + return Unsafe.AsRef<Value>(_src); + } + /// <inheritdoc /> public Output ReadOutput(ref byte* src) { diff --git a/cs/remote/src/FASTER.client/ICallbackFunctions.cs b/cs/remote/src/FASTER.client/ICallbackFunctions.cs index 09136ba73..7bec17409 100644 --- a/cs/remote/src/FASTER.client/ICallbackFunctions.cs +++ b/cs/remote/src/FASTER.client/ICallbackFunctions.cs @@ -49,5 +49,31 @@ public interface ICallbackFunctions<Key, Value, Input, Output, Context> /// <param name="key"></param> /// <param name="ctx"></param> void DeleteCompletionCallback(ref Key key, Context ctx); + + /// <summary> + /// Subscribe KV callback + /// </summary> + /// <param name="key"></param> 
+ /// /// <param name="input"></param> + /// <param name="output"></param> + /// <param name="ctx"></param> + /// <param name="status"></param> + void SubscribeKVCallback(ref Key key, ref Input input, ref Output output, Context ctx, Status status); + + /// <summary> + /// Publish completion + /// </summary> + /// <param name="key"></param> + /// <param name="value"></param> + /// <param name="ctx"></param> + void PublishCompletionCallback(ref Key key, ref Value value, Context ctx); + + /// <summary> + /// Subscribe callback + /// </summary> + /// <param name="key"></param> + /// <param name="value"></param> + /// <param name="ctx"></param> + void SubscribeCallback(ref Key key, ref Value value, Context ctx); } } \ No newline at end of file diff --git a/cs/remote/src/FASTER.client/JavascriptClient/BatchHeader.js b/cs/remote/src/FASTER.client/JavascriptClient/BatchHeader.js new file mode 100644 index 000000000..4e8971f9c --- /dev/null +++ b/cs/remote/src/FASTER.client/JavascriptClient/BatchHeader.js @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +/// <summary> +/// Header for message batch (Little Endian server) +/// [4 byte seqNo][1 byte protocol][3 byte numMessages] +/// </summary> +var BatchHeader = new Object(); +BatchHeader.Size = 8; +BatchHeader.NumMessages = 0; \ No newline at end of file diff --git a/cs/remote/src/FASTER.client/JavascriptClient/CallbackFunctionsBase.js b/cs/remote/src/FASTER.client/JavascriptClient/CallbackFunctionsBase.js new file mode 100644 index 000000000..1bb6dd076 --- /dev/null +++ b/cs/remote/src/FASTER.client/JavascriptClient/CallbackFunctionsBase.js @@ -0,0 +1,16 @@ +class CallbackFunctionsBase +{ + constructor() { } + + ReadCompletionCallback(key, output, status) { } + + UpsertCompletionCallback(key, value, status) { } + + DeleteCompletionCallback(key, status) { } + + RMWCompletionCallback(key, output, status) { } + + SubscribeKVCompletionCallback(key, output, status) { } + + SubscribeCompletionCallback(key, output, status) { } +} \ No newline at end of file diff --git a/cs/remote/src/FASTER.client/JavascriptClient/ClientNetworkSession.js b/cs/remote/src/FASTER.client/JavascriptClient/ClientNetworkSession.js new file mode 100644 index 000000000..6d5a7f4e4 --- /dev/null +++ b/cs/remote/src/FASTER.client/JavascriptClient/ClientNetworkSession.js @@ -0,0 +1,31 @@ +class ClientNetworkSession { + constructor(session, address, port) { + + this.address = address; + this.port = port; + this.clientSession = session; + var wsUri = "ws://" + this.address + ":" + this.port; + + this.websocket = new WebSocket(wsUri); + var self = this; + + this.websocket.binaryType = 'arraybuffer'; + + this.websocket.onopen = function (e) { + writeToScreen("CONNECTED"); + }; + + this.websocket.onclose = function (e) { + alert("Disconnected"); + }; + + this.websocket.onmessage = function (e) { + const view = new DataView(e.data); + self.clientSession.ProcessReplies(view, e.data); + }; + + this.websocket.onerror = function (e) { + writeToScreen("<span class=error>ERROR:</span> " + e.data); + }; + } +} 
\ No newline at end of file diff --git a/cs/remote/src/FASTER.client/JavascriptClient/ClientSession.js b/cs/remote/src/FASTER.client/JavascriptClient/ClientSession.js new file mode 100644 index 000000000..2e6270fb3 --- /dev/null +++ b/cs/remote/src/FASTER.client/JavascriptClient/ClientSession.js @@ -0,0 +1,282 @@ +class ClientSession { + + constructor(address, port, functions, maxSizeSettings) { + this.functions = functions; + this.readrmwPendingContext = {}; + this.pubsubPendingContext = {}; + this.maxSizeSettings = maxSizeSettings ?? new MaxSizeSettings(); + this.bufferSize = JSUtils.ClientBufferSize(this.maxSizeSettings); + this.serializer = new ParameterSerializer(); + this.readrmwQueue = new Queue(); + this.pubsubQueue = new Queue(); + this.upsertQueue = new Queue(); + this.sendSocket = new ClientNetworkSession(this, address, port); + this.intSerializer = new IntSerializer(); + this.offset = 4 + BatchHeader.Size; + this.numMessages = 0; + this.reusableBuffer = new ArrayBuffer(this.bufferSize); + this.numPendingBatches = 0; + } + + ProcessReplies(view, arrayBuf) { + var count = view.getUint8(8); + var arrIdx = 4 + BatchHeader.Size; + + for (var i = 0; i < count; i++) { + const op = view.getUint8(arrIdx); + arrIdx++; + const status = view.getUint8(arrIdx); + arrIdx++; + var output = []; + + switch (op) { + case MessageType.Read: + var key = this.readrmwQueue.dequeue(); + if (status == Status.OK) { + output = this.serializer.ReadOutput(arrayBuf, arrIdx); + arrIdx += output.length + 4; + this.functions.ReadCompletionCallback(key, output, status); + break; + } else if (status == Status.PENDING) { + var p = this.intSerializer.deserialize(arrayBuf, arrIdx); + arrIdx += 4; + readrmwPendingContext[p] = key; + } else { + this.functions.ReadCompletionCallback(key, output, status); + } + break; + + case MessageType.Upsert: + var keyValue = this.upsertQueue.dequeue(); + this.functions.UpsertCompletionCallback(keyValue[0], keyValue[1], status); + break; + + case 
MessageType.Delete: + var keyValue = this.upsertQueue.dequeue(); + this.functions.DeleteCompletionCallback(keyValue[0], status); + break; + + case MessageType.RMW: + var key = this.readrmwQueue.dequeue(); + if (status == Status.OK || status == Status.NOTFOUND) { + output = this.serializer.ReadOutput(arrayBuf, arrIdx); + arrIdx += output.length + 4; + this.functions.RMWCompletionCallback(key, output, status); + } else if (status == Status.PENDING) { + var p = this.intSerializer.deserialize(arrayBuf, arrIdx); + arrIdx += 4; + readrmwPendingContext[p] = key; + } else { + output = []; + this.functions.RMWCompletionCallback(key, output, status); + } + break; + + case MessageType.SubscribeKV: + var sid = this.intSerializer.deserialize(arrayBuf, arrIdx); + arrIdx += 4; + if (status == Status.OK || status == Status.NOTFOUND) { + var key = this.readrmwPendingContext[sid]; + output = this.serializer.ReadOutput(arrayBuf, arrIdx); + arrIdx += output.length + 4; + this.functions.SubscribeKVCompletionCallback(key, output, status); + } else if (status == Status.PENDING) { + var key = this.readrmwQueue.dequeue(); + this.readrmwPendingContext[sid] = key; + } else { + var key = this.readrmwPendingContext[sid]; + output = []; + this.functions.SubscribeKVCompletionCallback(key, output, status); + } + break; + + case MessageType.Subscribe: + var sid = this.intSerializer.deserialize(arrayBuf, arrIdx); + arrIdx += 4; + if (status == Status.OK || status == Status.NOTFOUND) { + var key = this.pubsubPendingContext[sid]; + output = this.serializer.ReadOutput(arrayBuf, arrIdx); + arrIdx += output.length + 4; + this.functions.SubscribeCompletionCallback(key, output, status); + } else if (status == Status.PENDING) { + var key = this.pubsubQueue.dequeue(); + this.pubsubPendingContext[sid] = key; + } else { + var key = this.pubsubPendingContext[sid]; + value = [] + this.functions.SubscribeCompletionCallback(key, value, status); + } + break; + + case MessageType.HandlePending: + 
HandlePending(view, arrayBuf, arrIdx - 1); + break; + + default: + alert("Wrong reply received"); + } + } + this.numPendingBatches--; + } + + HandlePending(view, arrayBuf, arrIdx) { + var output = []; + var origMessage = view.getUint8(arrIdx++); + var p = this.intSerializer.deserialize(arrayBuf, arrIdx); + arrIdx += 4; + var status = view.getUint8(arrIdx++); + + switch (origMessage) { + case MessageType.Read: + var key = this.readrmwPendingContext[p]; + delete this.readrmwPendingContext[p]; + if (status == Status.OK) { + output = this.serializer.ReadOutput(arrayBuf, arrIdx); + arrIdx += output.length + 4; + } + this.functions.ReadCompletionCallback(key, output, status); + break; + + case MessageType.RMW: + var key = this.readrmwPendingContext[p]; + delete this.readrmwPendingContext[p]; + if (status == Status.OK) { + output = this.serializer.ReadOutput(arrayBuf, arrIdx); + arrIdx += output.length + 4; + } + this.functions.RMWCompletionCallback(key, output, status); + break; + + case MessageType.SubscribeKV: + var key = this.readrmwPendingContext[p]; + if (status == Status.OK) { + output = this.serializer.ReadOutput(arrayBuf, arrIdx); + arrIdx += output.length + 4; + } + this.functions.SubscribeKVCompletionCallback(key, output, status); + break; + + default: + alert("Not implemented exception"); + } + } + + /// <summary> + /// Flush current buffer of outgoing messages. Does not wait for responses. 
+ /// </summary> + Flush() { + if (this.offset > 4 + BatchHeader.Size) { + + var payloadSize = this.offset; + this.intSerializer.serialize(this.reusableBuffer, 4, 0); + this.intSerializer.serialize(this.reusableBuffer, 8, this.numMessages); + this.intSerializer.serialize(this.reusableBuffer, 0, (payloadSize - 4)); + this.numPendingBatches++; + + var sendingView = new Uint8Array(this.reusableBuffer, 0, payloadSize); + this.sendSocket.websocket.send(sendingView); + this.offset = 4 + BatchHeader.Size; + this.numMessages = 0; + } + } + + /// <summary> + /// Flush current buffer of outgoing messages. Spin-wait for all responses to be received and process them. + /// </summary> + CompletePending(wait = true) { + this.Flush(); + } + + Upsert(key, value) { + // OP SEQ NUMBER, OP, LEN(KEY), KEY, LEN(VAL), VAL + var view = new Uint8Array(this.reusableBuffer); + + var arrIdx = this.offset; + view[arrIdx++] = MessageType.Upsert; + + arrIdx = this.serializer.WriteKVI(this.reusableBuffer, arrIdx, key, key.length); + arrIdx = this.serializer.WriteKVI(this.reusableBuffer, arrIdx, value, value.length); + + this.upsertQueue.enqueue([key, value]); + + this.offset = arrIdx; + this.numMessages++; + } + + Read(key) { + // OP SEQ NUMBER, OP, LEN(KEY), KEY + var view = new Uint8Array(this.reusableBuffer); + + var arrIdx = this.offset; + view[arrIdx++] = MessageType.Read; + + arrIdx = this.serializer.WriteKVI(this.reusableBuffer, arrIdx, key, key.length); + + this.readrmwQueue.enqueue(key); + + this.offset = arrIdx; + this.numMessages++; + } + + Delete(key) { + // OP SEQ NUMBER, OP, LEN(KEY), KEY + var view = new Uint8Array(this.reusableBuffer); + + var arrIdx = this.offset; + view[arrIdx++] = MessageType.Delete; + + arrIdx = this.serializer.WriteKVI(this.reusableBuffer, arrIdx, key, key.length); + + var value = []; + this.upsertQueue.enqueue([key, value]); + + this.offset = arrIdx; + this.numMessages++; + } + + RMW(key, input) { + // OP SEQ NUMBER, OP, LEN(KEY), KEY, LEN(OP_ID + INPUT), 
OP_ID, INPUT + var view = new Uint8Array(this.reusableBuffer); + + var arrIdx = this.offset; + view[arrIdx++] = MessageType.RMW; + + arrIdx = this.serializer.WriteKVI(this.reusableBuffer, arrIdx, key, key.length); + arrIdx = this.serializer.WriteKVI(this.reusableBuffer, arrIdx, input, input.length); + + this.readrmwQueue.enqueue(key); + + this.offset = arrIdx; + this.numMessages++; + } + + SubscribeKV(key) { + // OP SEQ NUMBER, OP, LEN(KEY), KEY + var view = new Uint8Array(this.reusableBuffer); + + var arrIdx = this.offset; + view[arrIdx++] = MessageType.SubscribeKV; + + arrIdx = this.serializer.WriteKVI(this.reusableBuffer, arrIdx, key, key.length); + + this.readrmwQueue.enqueue(key); + + this.offset = arrIdx; + this.numMessages++; + } + + Subscribe(key) { + // OP SEQ NUMBER, OP, LEN(KEY), KEY + var view = new Uint8Array(this.reusableBuffer); + + var arrIdx = this.offset; + view[arrIdx++] = MessageType.Subscribe; + + arrIdx = this.serializer.WriteKVI(this.reusableBuffer, arrIdx, key, key.length); + + this.pubsubQueue.enqueue(key); + + this.offset = arrIdx; + this.numMessages++; + } +} \ No newline at end of file diff --git a/cs/remote/src/FASTER.client/JavascriptClient/ParameterSerializer.js b/cs/remote/src/FASTER.client/JavascriptClient/ParameterSerializer.js new file mode 100644 index 000000000..499850129 --- /dev/null +++ b/cs/remote/src/FASTER.client/JavascriptClient/ParameterSerializer.js @@ -0,0 +1,93 @@ +class IntSerializer { + serialize(byteArr, arrIdx, val) { + var index = arrIdx; + const byteView = new Uint8Array(byteArr); + for (index = arrIdx; index < arrIdx + 4; index++) { + var byte = val & 0xff; + byteView.set([byte], index); + val = (val - byte) / 256; + } + + return index; + } + + deserialize(byteArr, arrIdx) { + var index = arrIdx; + const byteView = new Uint8Array(byteArr); + var val = 0; + var tmpVal = 0; + for (index = arrIdx; index < arrIdx + 4; index++) { + var byte = byteView[index]; + tmpVal = byte * (256 ** (index - arrIdx)); + val = val 
+ tmpVal; + } + + return val; + } +} + +class ASCIISerializer { + serialize(byteArr, arrIdx, str) { + var index = 0; + const byteView = new Uint8Array(byteArr); + for (index = 0; index < str.length; index++) { + var code = str.charCodeAt(index); + byteView.set([code], arrIdx + index); + } + return arrIdx + index; + } + + deserialize(byteArr, startIdx, lenString) { + var strByteArr = []; + const byteView = new Uint8Array(byteArr); + for (var index = 0; index < lenString; index++) { + strByteArr[index] = byteView[startIdx + index]; + } + + var result = ""; + for (var i = 0; i < lenString; i++) { + result += String.fromCharCode(parseInt(strByteArr[i], 10)); + // result += String.fromCharCode(parseInt(byteView.getUint8(i), 10)); + } + + return result; + } +} + + +class ParameterSerializer { + + constructor() { + this.intSerializer = new IntSerializer(); + } + + WriteOpSeqNum(byteArr, arrIdx, opSeqNum) { + return this.intSerializer.serialize(byteArr, arrIdx, opSeqNum); + } + + WriteKVI(byteArr, arrIdx, objBytes, lenObj) { + const byteView = new Uint8Array(byteArr); + arrIdx = this.intSerializer.serialize(byteArr, arrIdx, lenObj); + + for (var index = 0; index < lenObj; index++) { + byteView[arrIdx + index] = objBytes[index]; + } + + var endIdx = arrIdx + lenObj; + return endIdx; + } + + ReadOutput(byteArr, arrIdx) { + var lenOutput = this.intSerializer.deserialize(byteArr, arrIdx); + const byteView = new Uint8Array(byteArr); + + arrIdx += 4; + var output = []; + + for (var index = 0; index < lenOutput; index++) { + output[index] = byteView[arrIdx + index]; + } + + return output; + } +} diff --git a/cs/remote/src/FASTER.client/JavascriptClient/Queue.js b/cs/remote/src/FASTER.client/JavascriptClient/Queue.js new file mode 100644 index 000000000..ca799dcb2 --- /dev/null +++ b/cs/remote/src/FASTER.client/JavascriptClient/Queue.js @@ -0,0 +1,26 @@ +function Queue() { + this.elements = []; +} + +Queue.prototype.enqueue = function (e) { + this.elements.push(e); +}; + +// 
remove an element from the front of the queue +Queue.prototype.dequeue = function () { + return this.elements.shift(); +}; + +// check if the queue is empty +Queue.prototype.isEmpty = function () { + return this.elements.length == 0; +}; + +// get the element at the front of the queue +Queue.prototype.peek = function () { + return !this.isEmpty() ? this.elements[0] : undefined; +}; + +Queue.prototype.length = function () { + return this.elements.length; +} \ No newline at end of file diff --git a/cs/remote/src/FASTER.client/JavascriptClient/Utils.js b/cs/remote/src/FASTER.client/JavascriptClient/Utils.js new file mode 100644 index 000000000..af9622ea0 --- /dev/null +++ b/cs/remote/src/FASTER.client/JavascriptClient/Utils.js @@ -0,0 +1,52 @@ +var MaxBatchSize = 1 << 17; + +class MaxSizeSettings +{ + constructor() { + this.MaxKeySize = 4096; + this.MaxValueSize = 4096; + this.MaxInputSize = 4096; + this.MaxOutputSize = 4096; + } +} + +class JSUtils +{ + constructor() { } + + static ClientBufferSize(maxSizeSettings) { + var minSizeUpsert = maxSizeSettings.MaxKeySize + maxSizeSettings.MaxValueSize + 2; + var minSizeReadRmw = maxSizeSettings.MaxKeySize + maxSizeSettings.MaxInputSize + 2; + + // leave enough space for double buffering + var minSize = 2 * (minSizeUpsert < minSizeReadRmw ? minSizeReadRmw : minSizeUpsert) + 4; + + return MaxBatchSize < minSize ? 
minSize : MaxBatchSize; + } + +} + +const Status = { + OK: 0, + NOTFOUND: 1, + PENDING: 2, + ERROR: 3 +}; + +const MessageType = { + Read: 0, + Upsert: 1, + RMW: 2, + Delete: 3, + ReadAsync: 4, + UpsertAsync: 5, + RMWAsync: 6, + DeleteAsync: 7, + SubscribeKV: 8, + PSubscribeKV: 9, + Subscribe: 10, + Publish: 11, + PSubscribe: 12, + PendingResult: 13, +}; +Object.freeze(MessageType); diff --git a/cs/remote/src/FASTER.client/MemoryFunctionsBase.cs b/cs/remote/src/FASTER.client/MemoryFunctionsBase.cs index 4c13f652d..cf8528050 100644 --- a/cs/remote/src/FASTER.client/MemoryFunctionsBase.cs +++ b/cs/remote/src/FASTER.client/MemoryFunctionsBase.cs @@ -24,5 +24,11 @@ public virtual void RMWCompletionCallback(ref ReadOnlyMemory<T> key, ref ReadOnl /// <inheritdoc /> public virtual void UpsertCompletionCallback(ref ReadOnlyMemory<T> key, ref ReadOnlyMemory<T> value, byte ctx) { } + /// <inheritdoc /> + public virtual void SubscribeKVCallback(ref ReadOnlyMemory<T> key, ref ReadOnlyMemory<T> input, ref (IMemoryOwner<T>, int) output, byte ctx, Status status) { } + /// <inheritdoc/> + public virtual void PublishCompletionCallback(ref ReadOnlyMemory<T> key, ref ReadOnlyMemory<T> value, byte ctx) { } + /// <inheritdoc/> + public virtual void SubscribeCallback(ref ReadOnlyMemory<T> key, ref ReadOnlyMemory<T> value, byte ctx) { } } } \ No newline at end of file diff --git a/cs/remote/src/FASTER.client/MemoryParameterSerializer.cs b/cs/remote/src/FASTER.client/MemoryParameterSerializer.cs index 42e759ae9..02497c64f 100644 --- a/cs/remote/src/FASTER.client/MemoryParameterSerializer.cs +++ b/cs/remote/src/FASTER.client/MemoryParameterSerializer.cs @@ -26,6 +26,28 @@ public MemoryParameterSerializer(MemoryPool<T> memoryPool = default) this.memoryPool = memoryPool ?? 
MemoryPool<T>.Shared; } + /// <inheritdoc /> + public ReadOnlyMemory<T> ReadKey(ref byte* src) + { + var len = (*(int*)src) / sizeof(T); + var mem = memoryPool.Rent(len); + new ReadOnlySpan<byte>(src + sizeof(int), (*(int*)src)).CopyTo( + MemoryMarshal.Cast<T, byte>(mem.Memory.Span)); + src += sizeof(int) + (*(int*)src); + return (mem.Memory.Slice(0, len)); + } + + /// <inheritdoc /> + public ReadOnlyMemory<T> ReadValue(ref byte* src) + { + var len = (*(int*)src) / sizeof(T); + var mem = memoryPool.Rent(len); + new ReadOnlySpan<byte>(src + sizeof(int), (*(int*)src)).CopyTo( + MemoryMarshal.Cast<T, byte>(mem.Memory.Span)); + src += sizeof(int) + (*(int*)src); + return (mem.Memory.Slice(0, len)); + } + /// <inheritdoc /> public (IMemoryOwner<T>, int) ReadOutput(ref byte* src) { diff --git a/cs/remote/src/FASTER.common/HeaderReaderWriter.cs b/cs/remote/src/FASTER.common/HeaderReaderWriter.cs index 151dff300..0833fff4d 100644 --- a/cs/remote/src/FASTER.common/HeaderReaderWriter.cs +++ b/cs/remote/src/FASTER.common/HeaderReaderWriter.cs @@ -38,6 +38,15 @@ public unsafe bool Write(MessageType s, ref byte* dst, int length) return true; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public unsafe bool Write(long seqNum, ref byte* dst, int length) + { + if (length < sizeof(long)) return false; + *(long*) dst = seqNum; + dst += sizeof(long); + return true; + } + /// <summary> /// Read message type /// </summary> @@ -48,5 +57,13 @@ public unsafe MessageType ReadMessageType(ref byte* dst) { return (MessageType)(*dst++); } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public unsafe long ReadSerialNum(ref byte* dst) + { + var result = *(long*) dst; + dst += sizeof(long); + return result; + } } } \ No newline at end of file diff --git a/cs/remote/src/FASTER.common/IClientSerializer.cs b/cs/remote/src/FASTER.common/IClientSerializer.cs index 0fc2aff66..c0982a725 100644 --- a/cs/remote/src/FASTER.common/IClientSerializer.cs +++ 
b/cs/remote/src/FASTER.common/IClientSerializer.cs @@ -38,12 +38,26 @@ public unsafe interface IClientSerializer<Key, Value, Input, Output> /// <param name="length">Space (bytes) available at destination</param> /// <returns>True if write succeeded, false if not (insufficient space)</returns> bool Write(ref Input i, ref byte* dst, int length); - + /// <summary> /// Read output from source memory location, increment pointer by amount read /// </summary> /// <param name="src">Source memory location</param> /// <returns>Output</returns> Output ReadOutput(ref byte* src); + + /// <summary> + /// Read key from source memory location, increment pointer by amount read + /// </summary> + /// <param name="src">Source memory location</param> + /// <returns>Key</returns> + Key ReadKey(ref byte* src); + + /// <summary> + /// Read key from source memory location, increment pointer by amount read + /// </summary> + /// <param name="src">Source memory location</param> + /// <returns>Key</returns> + Value ReadValue(ref byte* src); } } \ No newline at end of file diff --git a/cs/remote/src/FASTER.common/IKeyInputSerializer.cs b/cs/remote/src/FASTER.common/IKeyInputSerializer.cs new file mode 100644 index 000000000..110890b6b --- /dev/null +++ b/cs/remote/src/FASTER.common/IKeyInputSerializer.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +namespace FASTER.common +{ + /// <summary> + /// Serializer interface for keys, needed for pub-sub + /// </summary> + /// <typeparam name="Key">Key</typeparam> + /// <typeparam name="Input">Input</typeparam> + public unsafe interface IKeyInputSerializer<Key, Input> : IKeySerializer<Key> + { + /// <summary> + /// Read input by reference, from given location + /// </summary> + /// <param name="src">Memory location</param> + /// <returns>Input</returns> + ref Input ReadInputByRef(ref byte* src); + } +} diff --git a/cs/remote/src/FASTER.common/IKeySerializer.cs b/cs/remote/src/FASTER.common/IKeySerializer.cs new file mode 100644 index 000000000..21570c69b --- /dev/null +++ b/cs/remote/src/FASTER.common/IKeySerializer.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +namespace FASTER.common +{ + /// <summary> + /// Serializer interface for keys, needed for pub-sub + /// </summary> + /// <typeparam name="Key">Key</typeparam> + public unsafe interface IKeySerializer<Key> + { + /// <summary> + /// Read key by reference, from given location + /// </summary> + /// <param name="src">Memory location</param> + /// <returns>Key</returns> + ref Key ReadKeyByRef(ref byte* src); + + /// <summary> + /// Match pattern with key used for pub-sub + /// </summary> + /// <param name="k">key to be published</param> + /// <param name="asciiKey">whether key is ascii</param> + /// <param name="pattern">pattern to check</param> + /// <param name="asciiPattern">whether pattern is ascii</param> + bool Match(ref Key k, bool asciiKey, ref Key pattern, bool asciiPattern); + } +} diff --git a/cs/remote/src/FASTER.common/IServerSerializer.cs b/cs/remote/src/FASTER.common/IServerSerializer.cs index 6a456b2b3..89802f7ee 100644 --- a/cs/remote/src/FASTER.common/IServerSerializer.cs +++ b/cs/remote/src/FASTER.common/IServerSerializer.cs @@ -12,6 +12,24 @@ namespace FASTER.common /// <typeparam name="Output">Output</typeparam> 
public unsafe interface IServerSerializer<Key, Value, Input, Output> { + /// <summary> + /// Write element to given destination, with length bytes of space available + /// </summary> + /// <param name="k">Element to write</param> + /// <param name="dst">Destination memory</param> + /// <param name="length">Space (bytes) available at destination</param> + /// <returns>True if write succeeded, false if not (insufficient space)</returns> + bool Write(ref Key k, ref byte* dst, int length); + + /// <summary> + /// Write element to given destination, with length bytes of space available + /// </summary> + /// <param name="v">Element to write</param> + /// <param name="dst">Destination memory</param> + /// <param name="length">Space (bytes) available at destination</param> + /// <returns>True if write succeeded, false if not (insufficient space)</returns> + bool Write(ref Value v, ref byte* dst, int length); + /// <summary> /// Write element to given destination, with length bytes of space available /// </summary> diff --git a/cs/remote/src/FASTER.common/MessageType.cs b/cs/remote/src/FASTER.common/MessageType.cs index e73e6177f..a492c45b2 100644 --- a/cs/remote/src/FASTER.common/MessageType.cs +++ b/cs/remote/src/FASTER.common/MessageType.cs @@ -41,9 +41,35 @@ public enum MessageType : byte /// An async request to delete some value in a remote Faster instance /// </summary> DeleteAsync, + + /// <summary> + /// A request to subscribe to some key in a remote Faster instance + /// </summary> + SubscribeKV, + + /// <summary> + /// A request to subscribe to some key prefix in a remote Faster instance + /// </summary> + PSubscribeKV, + + /// <summary> + /// A request to subscribe to some key + /// </summary> + Subscribe, + + /// <summary> + /// A request to publish to some key, value pair + /// </summary> + Publish, + + /// <summary> + /// A request to subscribe to some key prefix + /// </summary> + PSubscribe, + /// <summary> /// Pending result /// </summary> - PendingResult 
+ PendingResult, } } \ No newline at end of file diff --git a/cs/remote/src/FASTER.common/NetworkSender.cs b/cs/remote/src/FASTER.common/NetworkSender.cs index 801e7b41d..997ad324b 100644 --- a/cs/remote/src/FASTER.common/NetworkSender.cs +++ b/cs/remote/src/FASTER.common/NetworkSender.cs @@ -35,7 +35,7 @@ public void Dispose() /// Get reusable SocketAsyncEventArgs buffer /// </summary> /// <returns></returns> - public ReusableObject<SeaaBuffer> GetReusableSeaaBuffer() => reusableSeaaBuffer.Checkout(); + public SeaaBuffer GetReusableSeaaBuffer() => reusableSeaaBuffer.Checkout(); /// <summary> /// Send @@ -44,21 +44,27 @@ public void Dispose() /// <param name="sendObject">Reusable SocketAsyncEventArgs buffer</param> /// <param name="offset">Offset</param> /// <param name="size">Size in bytes</param> - public unsafe void Send(Socket socket, ReusableObject<SeaaBuffer> sendObject, int offset, int size) + public unsafe void Send(Socket socket, SeaaBuffer sendObject, int offset, int size) { // Reset send buffer - sendObject.obj.socketEventAsyncArgs.SetBuffer(offset, size); + sendObject.socketEventAsyncArgs.SetBuffer(offset, size); // Set user context to reusable object handle for disposal when send is done - sendObject.obj.socketEventAsyncArgs.UserToken = sendObject; - if (!socket.SendAsync(sendObject.obj.socketEventAsyncArgs)) - SeaaBuffer_Completed(null, sendObject.obj.socketEventAsyncArgs); + sendObject.socketEventAsyncArgs.UserToken = sendObject; + if (!socket.SendAsync(sendObject.socketEventAsyncArgs)) + SeaaBuffer_Completed(null, sendObject.socketEventAsyncArgs); } private void SeaaBuffer_Completed(object sender, SocketAsyncEventArgs e) { - ((ReusableObject<SeaaBuffer>)e.UserToken).Dispose(); + reusableSeaaBuffer.Return((SeaaBuffer)e.UserToken); } + /// <summary> + /// Return to pool + /// </summary> + /// <param name="obj"></param> + public void Return(SeaaBuffer obj) => reusableSeaaBuffer.Return(obj); + /// <summary> /// Receive /// </summary> diff --git 
a/cs/remote/src/FASTER.common/SimpleObjectPool.cs b/cs/remote/src/FASTER.common/SimpleObjectPool.cs index f2751fa17..a278e4d48 100644 --- a/cs/remote/src/FASTER.common/SimpleObjectPool.cs +++ b/cs/remote/src/FASTER.common/SimpleObjectPool.cs @@ -17,7 +17,6 @@ internal class SimpleObjectPool<T> : IDisposable where T : class, IDisposable private readonly Func<T> factory; private readonly LightConcurrentStack<T> stack; private int allocatedObjects; - private readonly int maxObjects; /// <summary> /// Constructor @@ -27,8 +26,7 @@ internal class SimpleObjectPool<T> : IDisposable where T : class, IDisposable public SimpleObjectPool(Func<T> factory, int maxObjects = 128) { this.factory = factory; - this.maxObjects = maxObjects; - stack = new LightConcurrentStack<T>(); + stack = new LightConcurrentStack<T>(maxObjects); allocatedObjects = 0; } @@ -46,19 +44,23 @@ public void Dispose() } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ReusableObject<T> Checkout() + public T Checkout() { if (!stack.TryPop(out var obj)) { - if (allocatedObjects < maxObjects) - { - Interlocked.Increment(ref allocatedObjects); - return new ReusableObject<T>(factory(), stack); - } - // Overflow objects are simply discarded after use - return new ReusableObject<T>(factory(), null); + Interlocked.Increment(ref allocatedObjects); + return factory(); + } + return obj; + } + + public void Return(T obj) + { + if (!stack.TryPush(obj)) + { + obj.Dispose(); + Interlocked.Decrement(ref allocatedObjects); } - return new ReusableObject<T>(obj, stack); } } } \ No newline at end of file diff --git a/cs/remote/src/FASTER.common/WireFormat.cs b/cs/remote/src/FASTER.common/WireFormat.cs index c5f5a5e8a..284f6c33d 100644 --- a/cs/remote/src/FASTER.common/WireFormat.cs +++ b/cs/remote/src/FASTER.common/WireFormat.cs @@ -20,6 +20,10 @@ public enum WireFormat : byte DefaultFixedLenKV = 1, /// <summary> + /// Similar to DefaultVarLenKV but with WebSocket headers (binary) + /// </summary> + WebSocket = 2, 
+ /// ASCII wire format (non-binary protocol) /// </summary> ASCII = 255 diff --git a/cs/remote/src/FASTER.server/BinaryServerSession.cs b/cs/remote/src/FASTER.server/BinaryServerSession.cs index 27383313a..90d2e0581 100644 --- a/cs/remote/src/FASTER.server/BinaryServerSession.cs +++ b/cs/remote/src/FASTER.server/BinaryServerSession.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. using System; @@ -11,7 +11,7 @@ namespace FASTER.server { internal unsafe sealed class BinaryServerSession<Key, Value, Input, Output, Functions, ParameterSerializer> : FasterKVServerSessionBase<Key, Value, Input, Output, Functions, ParameterSerializer> - where Functions : IFunctions<Key, Value, Input, Output, long> + where Functions : IAdvancedFunctions<Key, Value, Input, Output, long> where ParameterSerializer : IServerSerializer<Key, Value, Input, Output> { readonly HeaderReaderWriter hrw; @@ -20,9 +20,16 @@ internal unsafe sealed class BinaryServerSession<Key, Value, Input, Output, Func int seqNo, pendingSeqNo, msgnum, start; byte* dcurr; - public BinaryServerSession(Socket socket, FasterKV<Key, Value> store, Functions functions, ParameterSerializer serializer, MaxSizeSettings maxSizeSettings) - : base(socket, store, functions, serializer, maxSizeSettings) + readonly SubscribeKVBroker<Key, Value, Input, IKeyInputSerializer<Key, Input>> subscribeKVBroker; + readonly SubscribeBroker<Key, Value, IKeySerializer<Key>> subscribeBroker; + + + public BinaryServerSession(Socket socket, FasterKV<Key, Value> store, Functions functions, ParameterSerializer serializer, MaxSizeSettings maxSizeSettings, SubscribeKVBroker<Key, Value, Input, IKeyInputSerializer<Key, Input>> subscribeKVBroker, SubscribeBroker<Key, Value, IKeySerializer<Key>> subscribeBroker) + : base(socket, store, functions, null, serializer, maxSizeSettings) { + this.subscribeKVBroker = subscribeKVBroker; + 
this.subscribeBroker = subscribeBroker; + readHead = 0; // Reserve minimum 4 bytes to send pending sequence number as output @@ -50,8 +57,8 @@ public override int TryConsumeMessages(byte[] buf) public override void CompleteRead(ref Output output, long ctx, Status status) { - byte* d = responseObject.obj.bufferPtr; - var dend = d + responseObject.obj.buffer.Length; + byte* d = responseObject.bufferPtr; + var dend = d + responseObject.buffer.Length; if ((int)(dend - dcurr) < 7 + maxSizeSettings.MaxOutputSize) SendAndReset(ref d, ref dend); @@ -67,8 +74,8 @@ public override void CompleteRead(ref Output output, long ctx, Status status) public override void CompleteRMW(ref Output output, long ctx, Status status) { - byte* d = responseObject.obj.bufferPtr; - var dend = d + responseObject.obj.buffer.Length; + byte* d = responseObject.bufferPtr; + var dend = d + responseObject.buffer.Length; if ((int)(dend - dcurr) < 7 + maxSizeSettings.MaxOutputSize) SendAndReset(ref d, ref dend); @@ -108,8 +115,8 @@ private unsafe void ProcessBatch(byte[] buf, int offset) fixed (byte* b = &buf[offset]) { - byte* d = responseObject.obj.bufferPtr; - var dend = d + responseObject.obj.buffer.Length; + byte* d = responseObject.bufferPtr; + var dend = d + responseObject.buffer.Length; dcurr = d + sizeof(int); // reserve space for size int origPendingSeqNo = pendingSeqNo; @@ -126,6 +133,7 @@ private unsafe void ProcessBatch(byte[] buf, int offset) for (msgnum = 0; msgnum < num; msgnum++) { var message = (MessageType)(*src++); + var serialNum = hrw.ReadSerialNum(ref src); switch (message) { case MessageType.Upsert: @@ -133,9 +141,13 @@ private unsafe void ProcessBatch(byte[] buf, int offset) if ((int)(dend - dcurr) < 2) SendAndReset(ref d, ref dend); - status = session.Upsert(ref serializer.ReadKeyByRef(ref src), ref serializer.ReadValueByRef(ref src)); + var keyPtr = src; + status = session.Upsert(ref serializer.ReadKeyByRef(ref src), ref serializer.ReadValueByRef(ref src), serialNo: 
serialNum); + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); Write(ref status, ref dcurr, (int)(dend - dcurr)); + + subscribeKVBroker?.Publish(keyPtr); break; case MessageType.Read: @@ -145,7 +157,7 @@ private unsafe void ProcessBatch(byte[] buf, int offset) long ctx = ((long)message << 32) | (long)pendingSeqNo; status = session.Read(ref serializer.ReadKeyByRef(ref src), ref serializer.ReadInputByRef(ref src), - ref serializer.AsRefOutput(dcurr + 2, (int)(dend - dcurr)), ctx, 0); + ref serializer.AsRefOutput(dcurr + 2, (int)(dend - dcurr)), ctx, serialNum); hrw.Write(message, ref dcurr, (int)(dend - dcurr)); Write(ref status, ref dcurr, (int)(dend - dcurr)); @@ -161,9 +173,11 @@ private unsafe void ProcessBatch(byte[] buf, int offset) if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) SendAndReset(ref d, ref dend); + keyPtr = src; + ctx = ((long)message << 32) | (long)pendingSeqNo; status = session.RMW(ref serializer.ReadKeyByRef(ref src), ref serializer.ReadInputByRef(ref src), - ref serializer.AsRefOutput(dcurr + 2, (int)(dend - dcurr)), ctx); + ref serializer.AsRefOutput(dcurr + 2, (int)(dend - dcurr)), ctx, serialNum); hrw.Write(message, ref dcurr, (int)(dend - dcurr)); Write(ref status, ref dcurr, (int)(dend - dcurr)); @@ -171,6 +185,7 @@ private unsafe void ProcessBatch(byte[] buf, int offset) Write(pendingSeqNo++, ref dcurr, (int)(dend - dcurr)); else if (status == Status.OK || status == Status.NOTFOUND) serializer.SkipOutput(ref dcurr); + subscribeKVBroker?.Publish(keyPtr); break; case MessageType.Delete: @@ -178,13 +193,18 @@ private unsafe void ProcessBatch(byte[] buf, int offset) if ((int)(dend - dcurr) < 2) SendAndReset(ref d, ref dend); - status = session.Delete(ref serializer.ReadKeyByRef(ref src)); + keyPtr = src; + status = session.Delete(ref serializer.ReadKeyByRef(ref src), serialNo: serialNum); + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); Write(ref status, ref dcurr, (int)(dend - dcurr)); + + 
subscribeKVBroker?.Publish(keyPtr); break; default: - throw new NotImplementedException(); + if (!HandlePubSub(message, ref src, ref d, ref dend)) throw new NotImplementedException(); + break; } } @@ -195,7 +215,95 @@ private unsafe void ProcessBatch(byte[] buf, int offset) if (msgnum - start > 0) Send(d); else - responseObject.Dispose(); + { + messageManager.Return(responseObject); + responseObject = null; + } + } + } + + /// <inheritdoc /> + public unsafe override void Publish(ref byte* keyPtr, int keyLength, ref byte* valPtr, int valLength, ref byte* inputPtr, int sid) + => Publish(ref keyPtr, keyLength, ref valPtr, ref inputPtr, sid, false); + + /// <inheritdoc /> + public unsafe override void PrefixPublish(byte* prefixPtr, int prefixLength, ref byte* keyPtr, int keyLength, ref byte* valPtr, int valLength, ref byte* inputPtr, int sid) + => Publish(ref keyPtr, keyLength, ref valPtr, ref inputPtr, sid, true); + + private unsafe void Publish(ref byte* keyPtr, int keyLength, ref byte* valPtr, ref byte* inputPtr, int sid, bool prefix) + { + MessageType message; + + if (valPtr == null) + { + message = MessageType.SubscribeKV; + if (prefix) + message = MessageType.PSubscribeKV; + } + else + { + message = MessageType.Subscribe; + if (prefix) + message = MessageType.PSubscribe; + } + + var respObj = messageManager.GetReusableSeaaBuffer(); + + ref Key key = ref serializer.ReadKeyByRef(ref keyPtr); + + byte* d = respObj.bufferPtr; + var dend = d + respObj.buffer.Length; + var dcurr = d + sizeof(int); // reserve space for size + byte* outputDcurr; + + dcurr += BatchHeader.Size; + + long ctx = ((long)message << 32) | (long)sid; + + if (prefix) + outputDcurr = dcurr + 6 + keyLength; + else + outputDcurr = dcurr + 6; + + var status = Status.OK; + if (valPtr == null) + status = session.Read(ref key, ref serializer.ReadInputByRef(ref inputPtr), ref serializer.AsRefOutput(outputDcurr, (int)(dend - dcurr)), ctx, 0); + + if (status != Status.PENDING) + { + // Write six bytes 
(message | status | sid) + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); + if (prefix) + serializer.Write(ref key, ref dcurr, (int)(dend - dcurr)); + if (valPtr != null) + { + ref Value value = ref serializer.ReadValueByRef(ref valPtr); + serializer.Write(ref value, ref dcurr, (int)(dend - dcurr)); + } + else if (status == Status.OK) + serializer.SkipOutput(ref dcurr); + } + else + { + throw new Exception("Pending reads not supported with pub/sub"); + } + + // Send replies + var dstart = d + sizeof(int); + Unsafe.AsRef<BatchHeader>(dstart).NumMessages = 1; + Unsafe.AsRef<BatchHeader>(dstart).SeqNo = 0; + int payloadSize = (int)(dcurr - d); + // Set packet size in header + *(int*)respObj.bufferPtr = -(payloadSize - sizeof(int)); + try + { + messageManager.Send(socket, respObj, 0, payloadSize); + } + catch + { + messageManager.Return(respObj); } } @@ -221,8 +329,8 @@ private void SendAndReset(ref byte* d, ref byte* dend) { Send(d); GetResponseObject(); - d = responseObject.obj.bufferPtr; - dend = d + responseObject.obj.buffer.Length; + d = responseObject.bufferPtr; + dend = d + responseObject.buffer.Length; dcurr = d + sizeof(int); start = msgnum; } @@ -235,9 +343,117 @@ private void Send(byte* d) Unsafe.AsRef<BatchHeader>(dstart).SeqNo = seqNo++; int payloadSize = (int)(dcurr - d); // Set packet size in header - *(int*)responseObject.obj.bufferPtr = -(payloadSize - sizeof(int)); + *(int*)responseObject.bufferPtr = -(payloadSize - sizeof(int)); SendResponse(payloadSize); - responseObject.obj = null; + } + + private bool HandlePubSub(MessageType message, ref byte* src, ref byte* d, ref byte* dend) + { + switch (message) + { + case MessageType.SubscribeKV: + if (subscribeKVBroker == null) return false; + + if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + var keyStart = src; + serializer.ReadKeyByRef(ref src); + + var 
inputStart = src; + serializer.ReadInputByRef(ref src); + + int sid = subscribeKVBroker.Subscribe(ref keyStart, ref inputStart, this); + var status = Status.PENDING; + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); + break; + + case MessageType.PSubscribeKV: + if (subscribeKVBroker == null) return false; + + if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + if (subscribeKVBroker == null) + break; + + keyStart = src; + serializer.ReadKeyByRef(ref src); + + inputStart = src; + serializer.ReadInputByRef(ref src); + + sid = subscribeKVBroker.PSubscribe(ref keyStart, ref inputStart, this); + status = Status.PENDING; + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); + break; + + case MessageType.Publish: + if (subscribeBroker == null) return false; + + if ((int)(dend - dcurr) < 2) + SendAndReset(ref d, ref dend); + + var keyPtr = src; + ref Key key = ref serializer.ReadKeyByRef(ref src); + byte* valPtr = src; + ref Value val = ref serializer.ReadValueByRef(ref src); + int valueLength = (int)(src - valPtr); + + status = Status.OK; + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + + if (subscribeBroker != null) + subscribeBroker.Publish(keyPtr, valPtr, valueLength); + break; + + case MessageType.Subscribe: + if (subscribeBroker == null) return false; + + if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + keyStart = src; + serializer.ReadKeyByRef(ref src); + + sid = subscribeBroker.Subscribe(ref keyStart, this); + status = Status.PENDING; + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); + break; + + case 
MessageType.PSubscribe: + if (subscribeBroker == null) return false; + + if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + keyStart = src; + serializer.ReadKeyByRef(ref src); + + sid = subscribeBroker.PSubscribe(ref keyStart, this); + status = Status.PENDING; + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); + break; + + default: + return false; + } + return true; + } + + public override void Dispose() + { + subscribeBroker?.RemoveSubscription(this); + subscribeKVBroker?.RemoveSubscription(this); } } } diff --git a/cs/remote/src/FASTER.server/ByteArrayComparer.cs b/cs/remote/src/FASTER.server/ByteArrayComparer.cs new file mode 100644 index 000000000..090273282 --- /dev/null +++ b/cs/remote/src/FASTER.server/ByteArrayComparer.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using FASTER.core; + +namespace FASTER.server +{ + /// <summary> + /// Byte array equality comparer + /// </summary> + public class ByteArrayComparer : IEqualityComparer<byte[]> + { + /// <summary> + /// Equals + /// </summary> + /// <param name="left"></param> + /// <param name="right"></param> + /// <returns></returns> + public bool Equals(byte[] left, byte[] right) + => new ReadOnlySpan<byte>(left).SequenceEqual(new ReadOnlySpan<byte>(right)); + + /// <summary> + /// Get hash code + /// </summary> + /// <param name="key"></param> + /// <returns></returns> + public unsafe int GetHashCode(byte[] key) + { + fixed (byte* k = key) + { + return (int)Utility.HashBytes(k, key.Length); + } + } + } +} diff --git a/cs/remote/src/FASTER.server/ConnectionArgs.cs b/cs/remote/src/FASTER.server/ConnectionArgs.cs index 7a21a4666..e77706a5e 100644 --- a/cs/remote/src/FASTER.server/ConnectionArgs.cs +++ b/cs/remote/src/FASTER.server/ConnectionArgs.cs @@ -9,5 +9,6 @@ internal class ConnectionArgs { public Socket socket; public IServerSession session; + public int bytesRead; } } diff --git a/cs/remote/src/FASTER.server/FASTER.server.csproj b/cs/remote/src/FASTER.server/FASTER.server.csproj index b5088cf9b..3cc22c43b 100644 --- a/cs/remote/src/FASTER.server/FASTER.server.csproj +++ b/cs/remote/src/FASTER.server/FASTER.server.csproj @@ -27,7 +27,7 @@ <ItemGroup> <ProjectReference Include="..\FASTER.common\FASTER.common.csproj" /> <ProjectReference Include="..\..\..\src\core\FASTER.core.csproj" /> - <PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All"/> + <PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All" /> </ItemGroup> </Project> diff --git a/cs/remote/src/FASTER.server/FasterKVProvider.cs b/cs/remote/src/FASTER.server/FasterKVProvider.cs index d38d4d7f1..f1e056e7c 100644 --- a/cs/remote/src/FASTER.server/FasterKVProvider.cs +++ 
b/cs/remote/src/FASTER.server/FasterKVProvider.cs @@ -15,33 +15,47 @@ namespace FASTER.server /// <typeparam name="Functions"></typeparam> /// <typeparam name="ParameterSerializer"></typeparam> public sealed class FasterKVProvider<Key, Value, Input, Output, Functions, ParameterSerializer> : ISessionProvider - where Functions : IFunctions<Key, Value, Input, Output, long> + where Functions : IAdvancedFunctions<Key, Value, Input, Output, long> where ParameterSerializer : IServerSerializer<Key, Value, Input, Output> { readonly FasterKV<Key, Value> store; readonly Func<WireFormat, Functions> functionsGen; readonly ParameterSerializer serializer; readonly MaxSizeSettings maxSizeSettings; + readonly SubscribeKVBroker<Key, Value, Input, IKeyInputSerializer<Key, Input>> subscribeKVBroker; + readonly SubscribeBroker<Key, Value, IKeySerializer<Key>> subscribeBroker; /// <summary> /// Create FasterKV backend /// </summary> /// <param name="store"></param> /// <param name="functionsGen"></param> + /// <param name="subscribeKVBroker"></param> + /// <param name="subscribeBroker"></param> /// <param name="serializer"></param> /// <param name="maxSizeSettings"></param> - public FasterKVProvider(FasterKV<Key, Value> store, Func<WireFormat, Functions> functionsGen, ParameterSerializer serializer = default, MaxSizeSettings maxSizeSettings = default) + public FasterKVProvider(FasterKV<Key, Value> store, Func<WireFormat, Functions> functionsGen, SubscribeKVBroker<Key, Value, Input, IKeyInputSerializer<Key, Input>> subscribeKVBroker = null, SubscribeBroker<Key, Value, IKeySerializer<Key>> subscribeBroker = null, ParameterSerializer serializer = default, MaxSizeSettings maxSizeSettings = default) { this.store = store; this.functionsGen = functionsGen; this.serializer = serializer; this.maxSizeSettings = maxSizeSettings ?? 
new MaxSizeSettings(); + this.subscribeKVBroker = subscribeKVBroker; + this.subscribeBroker = subscribeBroker; } /// <inheritdoc /> public IServerSession GetSession(WireFormat wireFormat, Socket socket) { - return new BinaryServerSession<Key, Value, Input, Output, Functions, ParameterSerializer>(socket, store, functionsGen(wireFormat), serializer, maxSizeSettings); + switch (wireFormat) + { + case WireFormat.WebSocket: + return new WebsocketServerSession<Key, Value, Input, Output, Functions, ParameterSerializer> + (socket, store, functionsGen(wireFormat), serializer, maxSizeSettings, subscribeKVBroker, subscribeBroker); + default: + return new BinaryServerSession<Key, Value, Input, Output, Functions, ParameterSerializer> + (socket, store, functionsGen(wireFormat), serializer, maxSizeSettings, subscribeKVBroker, subscribeBroker); + } } } } \ No newline at end of file diff --git a/cs/remote/src/FASTER.server/FasterKVServerSessionBase.cs b/cs/remote/src/FASTER.server/FasterKVServerSessionBase.cs index 17f5bca72..8e6f83efe 100644 --- a/cs/remote/src/FASTER.server/FasterKVServerSessionBase.cs +++ b/cs/remote/src/FASTER.server/FasterKVServerSessionBase.cs @@ -4,25 +4,23 @@ namespace FASTER.server { - internal abstract class FasterKVServerSessionBase<Key, Value, Input, Output, Functions, ParameterSerializer> : ServerSessionBase - where Functions : IFunctions<Key, Value, Input, Output, long> + + internal abstract class FasterKVServerSessionBase<Key, Value, Input, Output, Functions, ParameterSerializer> : FasterKVServerSessionBase<Output> + where Functions : IAdvancedFunctions<Key, Value, Input, Output, long> where ParameterSerializer : IServerSerializer<Key, Value, Input, Output> { - protected readonly ClientSession<Key, Value, Input, Output, long, ServerKVFunctions<Key, Value, Input, Output, Functions, ParameterSerializer>> session; + protected readonly AdvancedClientSession<Key, Value, Input, Output, long, ServerKVFunctions<Key, Value, Input, Output, Functions>> session; 
protected readonly ParameterSerializer serializer; public FasterKVServerSessionBase(Socket socket, FasterKV<Key, Value> store, Functions functions, + SessionVariableLengthStructSettings<Value, Input> sessionVariableLengthStructSettings, ParameterSerializer serializer, MaxSizeSettings maxSizeSettings) : base(socket, maxSizeSettings) { - session = store.For(new ServerKVFunctions<Key, Value, Input, Output, Functions, ParameterSerializer>(functions, this)) - .NewSession<ServerKVFunctions<Key, Value, Input, Output, Functions, ParameterSerializer>>(); + session = store.For(new ServerKVFunctions<Key, Value, Input, Output, Functions>(functions, this)) + .NewSession<ServerKVFunctions<Key, Value, Input, Output, Functions>>(sessionVariableLengthStructSettings: sessionVariableLengthStructSettings); this.serializer = serializer; } - - public abstract void CompleteRead(ref Output output, long ctx, Status status); - public abstract void CompleteRMW(ref Output output, long ctx, Status status); - public override void Dispose() { @@ -30,4 +28,12 @@ public override void Dispose() base.Dispose(); } } + + internal abstract class FasterKVServerSessionBase<Output> : ServerSessionBase + { + public FasterKVServerSessionBase(Socket socket, MaxSizeSettings maxSizeSettings) : base(socket, maxSizeSettings) { } + + public abstract void CompleteRead(ref Output output, long ctx, Status status); + public abstract void CompleteRMW(ref Output output, long ctx, Status status); + } } \ No newline at end of file diff --git a/cs/remote/src/FASTER.server/FasterServer.cs b/cs/remote/src/FASTER.server/FasterServer.cs index 018836935..3886366b8 100644 --- a/cs/remote/src/FASTER.server/FasterServer.cs +++ b/cs/remote/src/FASTER.server/FasterServer.cs @@ -43,7 +43,7 @@ public FasterServer(string address, int port, int networkBufferSize = default) if (networkBufferSize == default) this.networkBufferSize = BufferSizeUtils.ClientBufferSize(new MaxSizeSettings()); - var ip = IPAddress.Parse(address); + var ip = 
address == null ? IPAddress.Any : IPAddress.Parse(address); var endPoint = new IPEndPoint(ip, port); servSocket = new Socket(ip.AddressFamily, SocketType.Stream, ProtocolType.Tcp); servSocket.Bind(endPoint); @@ -98,7 +98,7 @@ private bool HandleNewConnection(SocketAsyncEventArgs e) return false; } - // Ok to create new event args on accept because we assume a connection to be long-running + // Ok to create new event args on accept because we assume a connection to be long-running var receiveEventArgs = new SocketAsyncEventArgs(); var buffer = new byte[networkBufferSize]; receiveEventArgs.SetBuffer(buffer, 0, networkBufferSize); @@ -120,37 +120,47 @@ private bool HandleNewConnection(SocketAsyncEventArgs e) private void AcceptEventArg_Completed(object sender, SocketAsyncEventArgs e) { - do + try { - if (!HandleNewConnection(e)) break; - e.AcceptSocket = null; - } while (!servSocket.AcceptAsync(e)); + do + { + if (!HandleNewConnection(e)) break; + e.AcceptSocket = null; + } while (!servSocket.AcceptAsync(e)); + } + // socket disposed + catch (ObjectDisposedException) { } } private void DisposeActiveSessions() { - while (activeSessionCount > 0) + while (true) { - foreach (var kvp in activeSessions) + while (activeSessionCount > 0) { - var _session = kvp.Key; - if (_session != null) + foreach (var kvp in activeSessions) { - if (activeSessions.TryRemove(_session, out _)) + var _session = kvp.Key; + if (_session != null) { - _session.Dispose(); - Interlocked.Decrement(ref activeSessionCount); + if (activeSessions.TryRemove(_session, out _)) + { + _session.Dispose(); + Interlocked.Decrement(ref activeSessionCount); + } } } + Thread.Yield(); } - Thread.Yield(); + if (Interlocked.CompareExchange(ref activeSessionCount, int.MinValue, 0) == 0) + break; } } [MethodImpl(MethodImplOptions.AggressiveInlining)] private bool HandleReceiveCompletion(SocketAsyncEventArgs e) { - var connArgs = (ConnectionArgs) e.UserToken; + var connArgs = (ConnectionArgs)e.UserToken; if 
(e.BytesTransferred == 0 || e.SocketError != SocketError.Success || disposed) { DisposeConnectionSession(e); @@ -159,21 +169,35 @@ private bool HandleReceiveCompletion(SocketAsyncEventArgs e) if (connArgs.session == null) { - if (!CreateSession(e)) - return false; + return CreateSession(e); } connArgs.session.AddBytesRead(e.BytesTransferred); var newHead = connArgs.session.TryConsumeMessages(e.Buffer); - e.SetBuffer(newHead, e.Buffer.Length - newHead); + if (newHead == e.Buffer.Length) + { + // Need to grow input buffer + var newBuffer = new byte[e.Buffer.Length * 2]; + Array.Copy(e.Buffer, newBuffer, e.Buffer.Length); + e.SetBuffer(newBuffer, newHead, newBuffer.Length - newHead); + } + else + e.SetBuffer(newHead, e.Buffer.Length - newHead); return true; } private unsafe bool CreateSession(SocketAsyncEventArgs e) { - var connArgs = (ConnectionArgs) e.UserToken; + var connArgs = (ConnectionArgs)e.UserToken; - if (e.BytesTransferred < 4) return false; + connArgs.bytesRead += e.BytesTransferred; + + // We need at least 4 bytes to determine session + if (connArgs.bytesRead < 4) + { + e.SetBuffer(connArgs.bytesRead, e.Buffer.Length - connArgs.bytesRead); + return true; + } WireFormat protocol; @@ -181,10 +205,19 @@ private unsafe bool CreateSession(SocketAsyncEventArgs e) // This results in a fourth byte value (little endian) > 127, denoting a non-ASCII wire format. 
if (e.Buffer[3] > 127) { - if (e.BytesTransferred < 4 + BatchHeader.Size) return false; + + if (connArgs.bytesRead < 4 + BatchHeader.Size) + { + e.SetBuffer(connArgs.bytesRead, e.Buffer.Length - connArgs.bytesRead); + return true; + } fixed (void* bh = &e.Buffer[4]) protocol = ((BatchHeader*)bh)->Protocol; } + else if (e.Buffer[0] == 71 && e.Buffer[1] == 69 && e.Buffer[2] == 84) + { + protocol = WireFormat.WebSocket; + } else { protocol = WireFormat.ASCII; @@ -192,27 +225,44 @@ private unsafe bool CreateSession(SocketAsyncEventArgs e) if (!sessionProviders.TryGetValue(protocol, out var provider)) { - throw new FasterException($"Unsupported wire format {protocol}"); + Console.WriteLine($"Unsupported incoming wire format {protocol}"); + DisposeConnectionSession(e); + return false; } - connArgs.session = provider.GetSession(protocol, connArgs.socket); + if (Interlocked.Increment(ref activeSessionCount) <= 0) + { + DisposeConnectionSession(e); + return false; + } - if (activeSessions.TryAdd(connArgs.session, default)) - Interlocked.Increment(ref activeSessionCount); - else - throw new Exception("Unexpected: unable to add session to activeSessions"); + connArgs.session = provider.GetSession(protocol, connArgs.socket); + + activeSessions.TryAdd(connArgs.session, default); if (disposed) { DisposeConnectionSession(e); return false; } + + connArgs.session.AddBytesRead(connArgs.bytesRead); + var _newHead = connArgs.session.TryConsumeMessages(e.Buffer); + if (_newHead == e.Buffer.Length) + { + // Need to grow input buffer + var newBuffer = new byte[e.Buffer.Length * 2]; + Array.Copy(e.Buffer, newBuffer, e.Buffer.Length); + e.SetBuffer(newBuffer, _newHead, newBuffer.Length - _newHead); + } + else + e.SetBuffer(_newHead, e.Buffer.Length - _newHead); return true; } private void DisposeConnectionSession(SocketAsyncEventArgs e) { - var connArgs = (ConnectionArgs) e.UserToken; + var connArgs = (ConnectionArgs)e.UserToken; connArgs.socket.Dispose(); e.Dispose(); var _session = 
connArgs.session; @@ -237,8 +287,11 @@ private void RecvEventArg_Completed(object sender, SocketAsyncEventArgs e) if (!HandleReceiveCompletion(e)) break; } while (!connArgs.socket.ReceiveAsync(e)); } - // ignore session socket disposed due to server dispose - catch (ObjectDisposedException) { } + // socket disposed + catch (ObjectDisposedException) + { + DisposeConnectionSession(e); + } } } } diff --git a/cs/remote/src/FASTER.server/FixedLenSerializer.cs b/cs/remote/src/FASTER.server/FixedLenSerializer.cs index 6a1ea7104..8840316d0 100644 --- a/cs/remote/src/FASTER.server/FixedLenSerializer.cs +++ b/cs/remote/src/FASTER.server/FixedLenSerializer.cs @@ -54,6 +54,24 @@ public ref Input ReadInputByRef(ref byte* src) return ref Unsafe.AsRef<Input>(_src); } + /// <inheritdoc /> + public bool Write(ref Key k, ref byte* dst, int length) + { + if (length < Unsafe.SizeOf<Key>()) return false; + Unsafe.AsRef<Key>(dst) = k; + dst += Unsafe.SizeOf<Key>(); + return true; + } + + /// <inheritdoc /> + public bool Write(ref Value v, ref byte* dst, int length) + { + if (length < Unsafe.SizeOf<Value>()) return false; + Unsafe.AsRef<Value>(dst) = v; + dst += Unsafe.SizeOf<Value>(); + return true; + } + /// <inheritdoc /> public bool Write(ref Output o, ref byte* dst, int length) { @@ -100,5 +118,14 @@ private static bool IsBlittable<T>() } return true; } + + /// <inheritdoc /> + public bool Match(ref Key k, ref Key pattern) + { + if (k.Equals(pattern)) + return true; + + return false; + } } } \ No newline at end of file diff --git a/cs/remote/src/FASTER.server/Providers/SpanByteFasterKVProvider.cs b/cs/remote/src/FASTER.server/Providers/SpanByteFasterKVProvider.cs new file mode 100644 index 000000000..89cc5ca9e --- /dev/null +++ b/cs/remote/src/FASTER.server/Providers/SpanByteFasterKVProvider.cs @@ -0,0 +1,78 @@ +using System.Net.Sockets; +using FASTER.common; +using FASTER.core; + +namespace FASTER.server +{ + /// <summary> + /// Session provider for FasterKV store based on + /// 
[K, V, I, O, C] = [SpanByte, SpanByte, SpanByte, SpanByteAndMemory, long] + /// </summary> + public class SpanByteFasterKVProvider : ISessionProvider + { + /// <summary> + /// Store + /// </summary> + protected readonly FasterKV<SpanByte, SpanByte> store; + + /// <summary> + /// Serializer + /// </summary> + protected readonly SpanByteServerSerializer serializer; + + /// <summary> + /// KV broker + /// </summary> + protected readonly SubscribeKVBroker<SpanByte, SpanByte, SpanByte, IKeyInputSerializer<SpanByte, SpanByte>> kvBroker; + + /// <summary> + /// Broker + /// </summary> + protected readonly SubscribeBroker<SpanByte, SpanByte, IKeySerializer<SpanByte>> broker; + + /// <summary> + /// Size settings + /// </summary> + protected readonly MaxSizeSettings maxSizeSettings; + + /// <summary> + /// Create SpanByte FasterKV backend + /// </summary> + /// <param name="store"></param> + /// <param name="kvBroker"></param> + /// <param name="broker"></param> + /// <param name="recoverStore"></param> + /// <param name="maxSizeSettings"></param> + public SpanByteFasterKVProvider(FasterKV<SpanByte, SpanByte> store, SubscribeKVBroker<SpanByte, SpanByte, SpanByte, IKeyInputSerializer<SpanByte, SpanByte>> kvBroker = null, SubscribeBroker<SpanByte, SpanByte, IKeySerializer<SpanByte>> broker = null, bool recoverStore = false, MaxSizeSettings maxSizeSettings = default) + { + this.store = store; + if (recoverStore) + { + try + { + store.Recover(); + } + catch + { } + } + this.kvBroker = kvBroker; + this.broker = broker; + this.serializer = new SpanByteServerSerializer(); + this.maxSizeSettings = maxSizeSettings ?? 
new MaxSizeSettings(); + } + + /// <inheritdoc /> + public virtual IServerSession GetSession(WireFormat wireFormat, Socket socket) + { + switch (wireFormat) + { + case WireFormat.WebSocket: + return new WebsocketServerSession<SpanByte, SpanByte, SpanByte, SpanByteAndMemory, SpanByteFunctionsForServer<long>, SpanByteServerSerializer> + (socket, store, new SpanByteFunctionsForServer<long>(), serializer, maxSizeSettings, kvBroker, broker); + default: + return new BinaryServerSession<SpanByte, SpanByte, SpanByte, SpanByteAndMemory, SpanByteFunctionsForServer<long>, SpanByteServerSerializer> + (socket, store, new SpanByteFunctionsForServer<long>(), serializer, maxSizeSettings, kvBroker, broker); + } + } + } +} \ No newline at end of file diff --git a/cs/remote/src/FASTER.server/PubSub/FixedLenKeySerializer.cs b/cs/remote/src/FASTER.server/PubSub/FixedLenKeySerializer.cs new file mode 100644 index 000000000..040061ada --- /dev/null +++ b/cs/remote/src/FASTER.server/PubSub/FixedLenKeySerializer.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +using System.Runtime.CompilerServices; +using FASTER.common; + +namespace FASTER.server +{ + /// <summary> + /// Serializer for SpanByte. Used only on server-side. 
+ /// </summary> + public unsafe sealed class FixedLenKeySerializer<Key, Input> : IKeyInputSerializer<Key, Input> + { + /// <summary> + /// Constructor + /// </summary> + public FixedLenKeySerializer() + { + } + + /// <inheritdoc /> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ref Key ReadKeyByRef(ref byte* src) + { + var _src = (void*)src; + src += Unsafe.SizeOf<Key>(); + return ref Unsafe.AsRef<Key>(_src); + } + + /// <inheritdoc /> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ref Input ReadInputByRef(ref byte* src) + { + var _src = (void*)src; + src += Unsafe.SizeOf<Input>(); + return ref Unsafe.AsRef<Input>(_src); + } + + /// <inheritdoc /> + public bool Match(ref Key k, bool asciiKey, ref Key pattern, bool asciiPattern) + { + if (k.Equals(pattern)) + return true; + return false; + } + } +} \ No newline at end of file diff --git a/cs/remote/src/FASTER.server/PubSub/SpanByteKeySerializer.cs b/cs/remote/src/FASTER.server/PubSub/SpanByteKeySerializer.cs new file mode 100644 index 000000000..6dcf516cb --- /dev/null +++ b/cs/remote/src/FASTER.server/PubSub/SpanByteKeySerializer.cs @@ -0,0 +1,206 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +using System.Runtime.CompilerServices; +using FASTER.core; +using FASTER.common; +using System; + +namespace FASTER.server +{ + /// <summary> + /// Serializer for SpanByte. Used only on server-side. 
+ /// </summary> + public unsafe sealed class SpanByteKeySerializer : IKeyInputSerializer<SpanByte, SpanByte> + { + readonly SpanByteVarLenStruct settings; + + /// <summary> + /// Constructor + /// </summary> + public SpanByteKeySerializer() + { + settings = new SpanByteVarLenStruct(); + } + + /// <inheritdoc /> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ref SpanByte ReadKeyByRef(ref byte* src) + { + ref var ret = ref Unsafe.AsRef<SpanByte>(src); + src += settings.GetLength(ref ret); + return ref ret; + } + + /// <inheritdoc /> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ref SpanByte ReadInputByRef(ref byte* src) + { + ref var ret = ref Unsafe.AsRef<SpanByte>(src); + src += settings.GetLength(ref ret); + return ref ret; + } + + /// <inheritdoc /> + public bool Match(ref SpanByte k, bool asciiKey, ref SpanByte pattern, bool asciiPattern) + { + if (asciiKey && asciiPattern) + { + return GlobMatching(pattern.ToPointer(), pattern.LengthWithoutMetadata, k.ToPointer(), k.LengthWithoutMetadata); + } + + if (pattern.LengthWithoutMetadata > k.LengthWithoutMetadata) + return false; + return pattern.AsReadOnlySpan().SequenceEqual(k.AsReadOnlySpan().Slice(0, pattern.LengthWithoutMetadata)); + } + + /* Glob-style pattern matching. 
*/ + private static bool GlobMatching(byte* pattern, int patternLen, byte* key, int stringLen, bool nocase = false) + { + while (patternLen > 0 && stringLen > 0) + { + switch (pattern[0]) + { + case (byte)'*': + while (patternLen > 0 && pattern[1] == '*') + { + pattern++; + patternLen--; + } + if (patternLen == 1) + return true; /* match */ + while (stringLen > 0) + { + if (GlobMatching(pattern + 1, patternLen - 1, key, stringLen, nocase)) + return true; /* match */ + key++; + stringLen--; + } + return false; /* no match */ + + case (byte)'?': + key++; + stringLen--; + break; + + case (byte)'[': + { + bool not, match; + pattern++; + patternLen--; + not = (pattern[0] == '^'); + if (not) + { + pattern++; + patternLen--; + } + match = false; + while (true) + { + if (pattern[0] == '\\' && patternLen >= 2) + { + pattern++; + patternLen--; + if (pattern[0] == key[0]) + match = true; + } + else if (pattern[0] == ']') + { + break; + } + else if (patternLen == 0) + { + pattern--; + patternLen++; + break; + } + else if (patternLen >= 3 && pattern[1] == '-') + { + int start = pattern[0]; + int end = pattern[2]; + int c = key[0]; + if (start > end) + { + int t = start; + start = end; + end = t; + } + if (nocase) + { + start = char.ToLower((char)start); + end = char.ToLower((char)end); + c = char.ToLower((char)c); + } + pattern += 2; + patternLen -= 2; + if (c >= start && c <= end) + match = true; + } + else + { + if (!nocase) + { + if (pattern[0] == key[0]) + match = true; + } + else + { + if (char.ToLower((char)pattern[0]) == char.ToLower((char)key[0])) + match = true; + } + } + pattern++; + patternLen--; + } + + if (not) + match = !match; + if (!match) + return false; /* no match */ + key++; + stringLen--; + break; + } + + case (byte)'\\': + if (patternLen >= 2) + { + pattern++; + patternLen--; + } + goto default; + + /* fall through */ + default: + if (!nocase) + { + if (pattern[0] != key[0]) + return false; /* no match */ + } + else + { + if (char.ToLower((char)pattern[0]) 
!= char.ToLower((char)key[0])) + return false; /* no match */ + } + key++; + stringLen--; + break; + } + pattern++; + patternLen--; + if (stringLen == 0) + { + while (*pattern == '*') + { + pattern++; + patternLen--; + } + break; + } + } + if (patternLen == 0 && stringLen == 0) + return true; + return false; + } + } +} \ No newline at end of file diff --git a/cs/remote/src/FASTER.server/PubSub/SubscribeBroker.cs b/cs/remote/src/FASTER.server/PubSub/SubscribeBroker.cs new file mode 100644 index 000000000..2b52beb48 --- /dev/null +++ b/cs/remote/src/FASTER.server/PubSub/SubscribeBroker.cs @@ -0,0 +1,417 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using FASTER.common; +using FASTER.core; + +namespace FASTER.server +{ + /// <summary> + /// Broker used for PUB-SUB to FASTER KV store. There is a broker per FasterKV instance. + /// A single broker can be used with multiple FasterKVProviders. 
+ /// </summary> + /// <typeparam name="Key"></typeparam> + /// <typeparam name="Value"></typeparam> + /// <typeparam name="KeyValueSerializer"></typeparam> + public sealed class SubscribeBroker<Key, Value, KeyValueSerializer> : IDisposable + where KeyValueSerializer : IKeySerializer<Key> + { + private int sid = 0; + private ConcurrentDictionary<byte[], ConcurrentDictionary<int, ServerSessionBase>> subscriptions; + private ConcurrentDictionary<byte[], (bool, ConcurrentDictionary<int, ServerSessionBase>)> prefixSubscriptions; + private AsyncQueue<(byte[], byte[])> publishQueue; + readonly IKeySerializer<Key> keySerializer; + readonly FasterLog log; + readonly IDevice device; + readonly CancellationTokenSource cts = new(); + readonly ManualResetEvent done = new(true); + bool disposed = false; + + /// <summary> + /// Constructor + /// </summary> + /// <param name="keySerializer">Serializer for Prefix Match and serializing Key</param> + /// <param name="logDir">Directory where the log will be stored</param> + /// <param name="startFresh">start the log from scratch, do not continue</param> + public SubscribeBroker(IKeySerializer<Key> keySerializer, string logDir, bool startFresh = true) + { + this.keySerializer = keySerializer; + device = logDir == null ? 
new NullDevice() : Devices.CreateLogDevice(logDir + "/pubsubkv", preallocateFile: false); + device.Initialize((long)(1 << 30) * 64); + log = new FasterLog(new FasterLogSettings { LogDevice = device }); + if (startFresh) + log.TruncateUntil(log.CommittedUntilAddress); + } + + /// <summary> + /// Remove all subscriptions for a session, + /// called during dispose of server session + /// </summary> + /// <param name="session">server session</param> + public unsafe void RemoveSubscription(IServerSession session) + { + if (subscriptions != null) + { + foreach (var subscribedkey in subscriptions.Keys) + { + fixed (byte* keyPtr = &subscribedkey[0]) + this.Unsubscribe(keyPtr, (ServerSessionBase)session); + } + } + + if (prefixSubscriptions != null) + { + foreach (var subscribedkey in prefixSubscriptions.Keys) + { + fixed (byte* keyPtr = &subscribedkey[0]) + this.PUnsubscribe(keyPtr, (ServerSessionBase)session); + } + } + } + + private unsafe int Broadcast(byte[] key, byte* valPtr, int valLength, bool ascii) + { + int numSubscribers = 0; + + fixed (byte* ptr = &key[0]) + { + byte* keyPtr = ptr; + + if (subscriptions != null) + { + bool foundSubscription = subscriptions.TryGetValue(key, out var subscriptionServerSessionDict); + if (foundSubscription) + { + foreach (var sub in subscriptionServerSessionDict) + { + byte* keyBytePtr = ptr; + byte* nullBytePtr = null; + byte* valBytePtr = valPtr; + sub.Value.Publish(ref keyBytePtr, key.Length, ref valBytePtr, valLength, ref nullBytePtr, sub.Key); + numSubscribers++; + } + } + } + + if (prefixSubscriptions != null) + { + foreach (var kvp in prefixSubscriptions) + { + fixed (byte* subscribedPrefixPtr = &kvp.Key[0]) + { + byte* subPrefixPtr = subscribedPrefixPtr; + byte* reqKeyPtr = ptr; + + bool match = keySerializer.Match(ref keySerializer.ReadKeyByRef(ref reqKeyPtr), ascii, + ref keySerializer.ReadKeyByRef(ref subPrefixPtr), kvp.Value.Item1); + if (match) + { + foreach (var sub in kvp.Value.Item2) + { + byte* keyBytePtr = ptr; + 
byte* nullBytePtr = null; + sub.Value.PrefixPublish(subscribedPrefixPtr, kvp.Key.Length, ref keyBytePtr, key.Length, ref valPtr, valLength, ref nullBytePtr, sub.Key); + numSubscribers++; + } + } + } + } + } + } + return numSubscribers; + } + + private async Task Start(CancellationToken cancellationToken = default) + { + try + { + var uniqueKeys = new Dictionary<byte[], (byte[], byte[])>(new ByteArrayComparer()); + long truncateUntilAddress = log.BeginAddress; + + while (true) + { + if (disposed) + break; + + using var iter = log.Scan(log.BeginAddress, long.MaxValue, scanUncommitted: true); + await iter.WaitAsync(cancellationToken).ConfigureAwait(false); + while (iter.GetNext(out byte[] subscriptionKeyValueAscii, out _, out long currentAddress, out long nextAddress)) + { + if (currentAddress >= long.MaxValue) return; + + byte[] subscriptionKey; + byte[] subscriptionValue; + byte[] ascii; + + unsafe + { + fixed (byte* subscriptionKeyValueAsciiPtr = &subscriptionKeyValueAscii[0]) + { + var keyPtr = subscriptionKeyValueAsciiPtr; + keySerializer.ReadKeyByRef(ref keyPtr); + int subscriptionKeyLength = (int)(keyPtr - subscriptionKeyValueAsciiPtr); + int subscriptionValueLength = subscriptionKeyValueAscii.Length - (subscriptionKeyLength + sizeof(bool)); + subscriptionKey = new byte[subscriptionKeyLength]; + subscriptionValue = new byte[subscriptionValueLength]; + ascii = new byte[sizeof(bool)]; + + fixed (byte* subscriptionKeyPtr = &subscriptionKey[0], subscriptionValuePtr = &subscriptionValue[0], asciiPtr = &ascii[0]) + { + Buffer.MemoryCopy(subscriptionKeyValueAsciiPtr, subscriptionKeyPtr, subscriptionKeyLength, subscriptionKeyLength); + Buffer.MemoryCopy(subscriptionKeyValueAsciiPtr + subscriptionKeyLength, subscriptionValuePtr, subscriptionValueLength, subscriptionValueLength); + Buffer.MemoryCopy(subscriptionKeyValueAsciiPtr + subscriptionKeyLength + subscriptionValueLength, asciiPtr, sizeof(bool), sizeof(bool)); + } + } + } + truncateUntilAddress = nextAddress; + if 
(!uniqueKeys.ContainsKey(subscriptionKey)) + uniqueKeys.Add(subscriptionKey, (subscriptionValue, ascii)); + } + + if (truncateUntilAddress > log.BeginAddress) + log.TruncateUntil(truncateUntilAddress); + + unsafe + { + var enumerator = uniqueKeys.GetEnumerator(); + while (enumerator.MoveNext()) + { + byte[] keyBytes = enumerator.Current.Key; + byte[] valBytes = enumerator.Current.Value.Item1; + byte[] asciiBytes = enumerator.Current.Value.Item2; + bool ascii = asciiBytes[0] != 0; + + fixed (byte* valPtr = valBytes) + Broadcast(keyBytes, valPtr, valBytes.Length, ascii); + } + uniqueKeys.Clear(); + } + } + } + finally + { + done.Set(); + } + } + + /// <summary> + /// Subscribe to a particular Key + /// </summary> + /// <param name="key">Key to subscribe to</param> + /// <param name="session">Server session</param> + /// <returns></returns> + public unsafe int Subscribe(ref byte* key, ServerSessionBase session) + { + var start = key; + keySerializer.ReadKeyByRef(ref key); + var id = Interlocked.Increment(ref sid); + if (Interlocked.CompareExchange(ref publishQueue, new AsyncQueue<(byte[], byte[])>(), null) == null) + { + done.Reset(); + subscriptions = new ConcurrentDictionary<byte[], ConcurrentDictionary<int, ServerSessionBase>>(new ByteArrayComparer()); + prefixSubscriptions = new ConcurrentDictionary<byte[], (bool, ConcurrentDictionary<int, ServerSessionBase>)>(new ByteArrayComparer()); + Task.Run(() => Start(cts.Token)); + } + else + { + while (prefixSubscriptions == null) Thread.Yield(); + } + var subscriptionKey = new Span<byte>(start, (int)(key - start)).ToArray(); + subscriptions.TryAdd(subscriptionKey, new ConcurrentDictionary<int, ServerSessionBase>()); + if (subscriptions.TryGetValue(subscriptionKey, out var val)) + val.TryAdd(sid, session); + return id; + } + + /// <summary> + /// Subscribe to a particular prefix + /// </summary> + /// <param name="prefix">prefix to subscribe to</param> + /// <param name="session">Server session</param> + /// <param 
name="ascii">is key ascii?</param> + /// <returns></returns> + public unsafe int PSubscribe(ref byte* prefix, ServerSessionBase session, bool ascii = false) + { + var start = prefix; + keySerializer.ReadKeyByRef(ref prefix); + var id = Interlocked.Increment(ref sid); + if (Interlocked.CompareExchange(ref publishQueue, new AsyncQueue<(byte[], byte[])>(), null) == null) + { + done.Reset(); + subscriptions = new ConcurrentDictionary<byte[], ConcurrentDictionary<int, ServerSessionBase>>(new ByteArrayComparer()); + prefixSubscriptions = new ConcurrentDictionary<byte[], (bool, ConcurrentDictionary<int, ServerSessionBase>)>(new ByteArrayComparer()); + Task.Run(() => Start(cts.Token)); + } + else + { + while (prefixSubscriptions == null) Thread.Yield(); + } + var subscriptionPrefix = new Span<byte>(start, (int)(prefix - start)).ToArray(); + prefixSubscriptions.TryAdd(subscriptionPrefix, (ascii, new ConcurrentDictionary<int, ServerSessionBase>())); + if (prefixSubscriptions.TryGetValue(subscriptionPrefix, out var val)) + val.Item2.TryAdd(sid, session); + return id; + } + + /// <summary> + /// Unsubscribe to a particular Key + /// </summary> + /// <param name="key">Key to subscribe to</param> + /// <param name="session">Server session</param> + /// <returns></returns> + public unsafe bool Unsubscribe(byte* key, ServerSessionBase session) + { + bool ret = false; + var start = key; + var subscriptionKey = new Span<byte>(start, (int)(key - start)).ToArray(); + if (subscriptions != null) + { + if (subscriptions.TryGetValue(subscriptionKey, out var subscriptionDict)) + { + foreach (var sid in subscriptionDict.Keys) + { + if (subscriptionDict[sid] == session) + { + subscriptionDict.TryRemove(sid, out _); + ret = true; + break; + } + } + } + } + return ret; + } + + /// <summary> + /// Unsubscribe to a particular pattern + /// </summary> + /// <param name="key">Pattern to subscribe to</param> + /// <param name="session">Server session</param> + /// <returns></returns> + public 
unsafe void PUnsubscribe(byte* key, ServerSessionBase session) + { + var start = key; + var subscriptionKey = new Span<byte>(start, (int)(key - start)).ToArray(); + if (prefixSubscriptions != null) + { + if (prefixSubscriptions.ContainsKey(subscriptionKey)) + { + { + prefixSubscriptions.TryGetValue(subscriptionKey, out var subscriptionDict); + foreach (var sid in subscriptionDict.Item2.Keys) + { + if (subscriptionDict.Item2[sid] == session) + { + subscriptionDict.Item2.TryRemove(sid, out _); + break; + } + } + } + } + } + } + + + /// <summary> + /// List all subscriptions made by a session + /// </summary> + /// <param name="session"></param> + /// <returns></returns> + public unsafe List<byte[]> ListAllSubscriptions(ServerSessionBase session) + { + List<byte[]> sessionSubscriptions = new(); + if (subscriptions != null) + { + foreach (var subscription in subscriptions) + { + if (subscription.Value.Values.Contains(session)) + sessionSubscriptions.Add(subscription.Key); + } + } + return sessionSubscriptions; + } + + /// <summary> + /// List all pattern subscriptions made by a session + /// </summary> + /// <param name="session"></param> + /// <returns></returns> + public unsafe List<byte[]> ListAllPSubscriptions(ServerSessionBase session) + { + List<byte[]> sessionPSubscriptions = new(); + foreach (var psubscription in prefixSubscriptions) + { + if (psubscription.Value.Item2.Values.Contains(session)) + sessionPSubscriptions.Add(psubscription.Key); + } + + return sessionPSubscriptions; + } + + /// <summary> + /// Publish the update made to key to all the subscribers, synchronously + /// </summary> + /// <param name="key">key that has been updated</param> + /// <param name="value">value that has been updated</param> + /// <param name="valueLength">value length that has been updated</param> + /// <param name="ascii">whether ascii</param> + public unsafe int PublishNow(byte* key, byte* value, int valueLength, bool ascii) + { + if (subscriptions == null && 
prefixSubscriptions == null) return 0; + + var start = key; + ref Key k = ref keySerializer.ReadKeyByRef(ref key); + var keyBytes = new Span<byte>(start, (int)(key - start)).ToArray(); + int numSubscribedSessions = Broadcast(keyBytes, value, valueLength, ascii); + return numSubscribedSessions; + } + + /// <summary> + /// Publish the update made to key to all the subscribers, asynchronously + /// </summary> + /// <param name="key">key that has been updated</param> + /// <param name="value">value that has been updated</param> + /// <param name="valueLength">value length that has been updated</param> + /// <param name="ascii">is payload ascii</param> + public unsafe void Publish(byte* key, byte* value, int valueLength, bool ascii = false) + { + if (subscriptions == null && prefixSubscriptions == null) return; + + var start = key; + ref Key k = ref keySerializer.ReadKeyByRef(ref key); + // TODO: this needs to be a single atomic enqueue + byte[] logEntryBytes = new byte[(key - start) + valueLength + sizeof(bool)]; + fixed (byte* logEntryBytePtr = &logEntryBytes[0]) + { + byte* dst = logEntryBytePtr; + Buffer.MemoryCopy(start, dst, (key - start), (key - start)); + dst += (key - start); + Buffer.MemoryCopy(value, dst, valueLength, valueLength); + dst += valueLength; + byte* asciiPtr = (byte*)&ascii; + Buffer.MemoryCopy(asciiPtr, dst, sizeof(bool), sizeof(bool)); + } + + log.Enqueue(logEntryBytes); + log.RefreshUncommitted(); + } + + /// <inheritdoc /> + public void Dispose() + { + disposed = true; + cts.Cancel(); + done.WaitOne(); + subscriptions?.Clear(); + prefixSubscriptions?.Clear(); + log.Dispose(); + device.Dispose(); + } + } +} diff --git a/cs/remote/src/FASTER.server/PubSub/SubscribeKVBroker.cs b/cs/remote/src/FASTER.server/PubSub/SubscribeKVBroker.cs new file mode 100644 index 000000000..3e0a23ec9 --- /dev/null +++ b/cs/remote/src/FASTER.server/PubSub/SubscribeKVBroker.cs @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT license. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using FASTER.common; +using FASTER.core; + +namespace FASTER.server +{ + /// <summary> + /// Broker used for PUB-SUB to FASTER KV store. There is a broker per FasterKV instance. + /// A single broker can be used with multiple FasterKVProviders. + /// </summary> + /// <typeparam name="Key"></typeparam> + /// <typeparam name="Value"></typeparam> + /// <typeparam name="Input"></typeparam> + /// <typeparam name="KeyInputSerializer"></typeparam> + public sealed class SubscribeKVBroker<Key, Value, Input, KeyInputSerializer> : IDisposable + where KeyInputSerializer : IKeyInputSerializer<Key, Input> + { + private int sid = 0; + private ConcurrentDictionary<byte[], ConcurrentDictionary<int, (ServerSessionBase, byte[])>> subscriptions; + private ConcurrentDictionary<byte[], ConcurrentDictionary<int, (ServerSessionBase, byte[])>> prefixSubscriptions; + private AsyncQueue<byte[]> publishQueue; + readonly IKeyInputSerializer<Key, Input> keyInputSerializer; + readonly FasterLog log; + readonly IDevice device; + readonly CancellationTokenSource cts = new(); + readonly ManualResetEvent done = new(true); + bool disposed = false; + + /// <summary> + /// Constructor + /// </summary> + /// <param name="keyInputSerializer">Serializer for Prefix Match and serializing Key and Input</param> + /// <param name="logDir">Directory where the log will be stored</param> + /// <param name="startFresh">start the log from scratch, do not continue</param> + public SubscribeKVBroker(IKeyInputSerializer<Key, Input> keyInputSerializer, string logDir, bool startFresh = true) + { + this.keyInputSerializer = keyInputSerializer; + device = logDir == null ? 
new NullDevice() : Devices.CreateLogDevice(logDir + "/pubsubkv", preallocateFile: false); + device.Initialize((long)(1 << 30) * 64); + log = new FasterLog(new FasterLogSettings { LogDevice = device }); + if (startFresh) + log.TruncateUntil(log.CommittedUntilAddress); + } + + /// <summary> + /// Remove all subscriptions for a session, + /// called during dispose of server session + /// </summary> + /// <param name="session">server session</param> + public void RemoveSubscription(IServerSession session) + { + if (subscriptions != null) + { + foreach (var kvp in subscriptions) + { + foreach (var sub in kvp.Value) + { + if (sub.Value.Item1 == session) + { + kvp.Value.TryRemove(sub.Key, out _); + break; + } + } + } + } + + if (prefixSubscriptions != null) + { + foreach (var kvp in prefixSubscriptions) + { + foreach (var sub in kvp.Value) + { + if (sub.Value.Item1 == session) + { + kvp.Value.TryRemove(sub.Key, out _); + break; + } + } + } + } + } + + internal async Task Start(CancellationToken cancellationToken = default) + { + try + { + var uniqueKeys = new HashSet<byte[]>(new ByteArrayComparer()); + var uniqueKeySubscriptions = new List<(ServerSessionBase, int, bool)>(); + long truncateUntilAddress = log.BeginAddress; + + while (true) + { + if (disposed) + break; + + using var iter = log.Scan(log.BeginAddress, long.MaxValue, scanUncommitted: true); + await iter.WaitAsync(cancellationToken).ConfigureAwait(false); + while (iter.GetNext(out byte[] subscriptionKey, out int entryLength, out long currentAddress, out long nextAddress)) + { + if (currentAddress >= long.MaxValue) return; + uniqueKeys.Add(subscriptionKey); + truncateUntilAddress = nextAddress; + } + + if (truncateUntilAddress > log.BeginAddress) + log.TruncateUntil(truncateUntilAddress); + + unsafe + { + foreach (var keyBytes in uniqueKeys) + { + fixed (byte* ptr = &keyBytes[0]) + { + byte* keyPtr = ptr; + bool foundSubscription = subscriptions.TryGetValue(keyBytes, out var subscriptionServerSessionDict); + if 
(foundSubscription) + { + foreach (var sub in subscriptionServerSessionDict) + { + byte* keyBytePtr = ptr; + var serverSession = sub.Value.Item1; + byte* nullBytePtr = null; + + fixed (byte* inputPtr = &sub.Value.Item2[0]) + { + byte* inputBytePtr = inputPtr; + serverSession.Publish(ref keyBytePtr, keyBytes.Length, ref nullBytePtr, 0, ref inputBytePtr, sub.Key); + } + } + } + + foreach (var kvp in prefixSubscriptions) + { + var subscribedPrefixBytes = kvp.Key; + var prefixSubscriptionServerSessionDict = kvp.Value; + fixed (byte* subscribedPrefixPtr = &subscribedPrefixBytes[0]) + { + byte* subPrefixPtr = subscribedPrefixPtr; + byte* reqKeyPtr = ptr; + + bool match = keyInputSerializer.Match(ref keyInputSerializer.ReadKeyByRef(ref reqKeyPtr), false, + ref keyInputSerializer.ReadKeyByRef(ref subPrefixPtr), false); + if (match) + { + foreach (var sub in prefixSubscriptionServerSessionDict) + { + byte* keyBytePtr = ptr; + var serverSession = sub.Value.Item1; + byte* nullBytrPtr = null; + + fixed (byte* inputPtr = &sub.Value.Item2[0]) + { + byte* inputBytePtr = inputPtr; + serverSession.Publish(ref keyBytePtr, keyBytes.Length, ref nullBytrPtr, 0, ref inputBytePtr, sub.Key); + } + } + } + } + } + } + uniqueKeySubscriptions.Clear(); + } + uniqueKeys.Clear(); + } + } + } + finally + { + done.Set(); + } + } + + /// <summary> + /// Subscribe to a particular Key + /// </summary> + /// <param name="key">Key to subscribe to</param> + /// <param name="input">Input from subscriber</param> + /// <param name="session">Server session</param> + /// <returns></returns> + public unsafe int Subscribe(ref byte* key, ref byte* input, ServerSessionBase session) + { + var start = key; + var inputStart = input; + keyInputSerializer.ReadKeyByRef(ref key); + keyInputSerializer.ReadInputByRef(ref input); + var id = Interlocked.Increment(ref sid); + if (Interlocked.CompareExchange(ref publishQueue, new AsyncQueue<byte[]>(), null) == null) + { + done.Reset(); + subscriptions = new 
ConcurrentDictionary<byte[], ConcurrentDictionary<int, (ServerSessionBase, byte[])>>(new ByteArrayComparer()); + prefixSubscriptions = new ConcurrentDictionary<byte[], ConcurrentDictionary<int, (ServerSessionBase, byte[])>>(new ByteArrayComparer()); + Task.Run(() => Start(cts.Token)); + } + else + { + while (prefixSubscriptions == null) Thread.Yield(); + } + var subscriptionKey = new Span<byte>(start, (int)(key - start)).ToArray(); + var subscriptionInput = new Span<byte>(inputStart, (int)(input - inputStart)).ToArray(); + subscriptions.TryAdd(subscriptionKey, new ConcurrentDictionary<int, (ServerSessionBase, byte[])>()); + if (subscriptions.TryGetValue(subscriptionKey, out var val)) + val.TryAdd(sid, (session, subscriptionInput)); + return id; + } + + /// <summary> + /// Subscribe to a particular prefix + /// </summary> + /// <param name="prefix">prefix to subscribe to</param> + /// <param name="input">Input from subscriber</param> + /// <param name="session">Server session</param> + /// <returns></returns> + public unsafe int PSubscribe(ref byte* prefix, ref byte* input, ServerSessionBase session) + { + var start = prefix; + var inputStart = input; + keyInputSerializer.ReadKeyByRef(ref prefix); + keyInputSerializer.ReadInputByRef(ref input); + var id = Interlocked.Increment(ref sid); + if (Interlocked.CompareExchange(ref publishQueue, new AsyncQueue<byte[]>(), null) == null) + { + done.Reset(); + subscriptions = new ConcurrentDictionary<byte[], ConcurrentDictionary<int, (ServerSessionBase, byte[])>>(new ByteArrayComparer()); + prefixSubscriptions = new ConcurrentDictionary<byte[], ConcurrentDictionary<int, (ServerSessionBase, byte[])>>(new ByteArrayComparer()); + Task.Run(() => Start(cts.Token)); + } + else + { + while (prefixSubscriptions == null) Thread.Yield(); + } + var subscriptionPrefix = new Span<byte>(start, (int)(prefix - start)).ToArray(); + var subscriptionInput = new Span<byte>(inputStart, (int)(input - inputStart)).ToArray(); + 
prefixSubscriptions.TryAdd(subscriptionPrefix, new ConcurrentDictionary<int, (ServerSessionBase, byte[])>()); + if (prefixSubscriptions.TryGetValue(subscriptionPrefix, out var val)) + val.TryAdd(sid, (session, subscriptionInput)); + return id; + } + + /// <summary> + /// Publish the update made to key to all the subscribers + /// </summary> + /// <param name="key">key that has been updated</param> + public unsafe void Publish(byte* key) + { + if (subscriptions == null && prefixSubscriptions == null) return; + + var start = key; + ref Key k = ref keyInputSerializer.ReadKeyByRef(ref key); + log.Enqueue(new Span<byte>(start, (int)(key - start))); + log.RefreshUncommitted(); + } + + /// <inheritdoc /> + public void Dispose() + { + disposed = true; + cts.Cancel(); + done.WaitOne(); + subscriptions?.Clear(); + prefixSubscriptions?.Clear(); + log.Dispose(); + device.Dispose(); + } + } +} diff --git a/cs/remote/src/FASTER.server/ServerKVFunctions.cs b/cs/remote/src/FASTER.server/ServerKVFunctions.cs index 5e415a3f0..4b53e0276 100644 --- a/cs/remote/src/FASTER.server/ServerKVFunctions.cs +++ b/cs/remote/src/FASTER.server/ServerKVFunctions.cs @@ -2,20 +2,18 @@ // Licensed under the MIT license. 
using FASTER.core; -using FASTER.common; namespace FASTER.server { - internal struct ServerKVFunctions<Key, Value, Input, Output, Functions, ParameterSerializer> : IFunctions<Key, Value, Input, Output, long> - where Functions : IFunctions<Key, Value, Input, Output, long> - where ParameterSerializer : IServerSerializer<Key, Value, Input, Output> + internal struct ServerKVFunctions<Key, Value, Input, Output, Functions> : IAdvancedFunctions<Key, Value, Input, Output, long> + where Functions : IAdvancedFunctions<Key, Value, Input, Output, long> { private readonly Functions functions; - private readonly FasterKVServerSessionBase<Key, Value, Input, Output, Functions, ParameterSerializer> serverNetworkSession; + private readonly FasterKVServerSessionBase<Output> serverNetworkSession; public bool SupportsLocking => functions.SupportsLocking; - public ServerKVFunctions(Functions functions, FasterKVServerSessionBase<Key, Value, Input, Output, Functions, ParameterSerializer> serverNetworkSession) + public ServerKVFunctions(Functions functions, FasterKVServerSessionBase<Output> serverNetworkSession) { this.functions = functions; this.serverNetworkSession = serverNetworkSession; @@ -24,11 +22,14 @@ public ServerKVFunctions(Functions functions, FasterKVServerSessionBase<Key, Val public void CheckpointCompletionCallback(string sessionId, CommitPoint commitPoint) => functions.CheckpointCompletionCallback(sessionId, commitPoint); - public void ConcurrentReader(ref Key key, ref Input input, ref Value value, ref Output dst) - => functions.ConcurrentReader(ref key, ref input, ref value, ref dst); + public void ConcurrentDeleter(ref Key key, ref Value value, ref RecordInfo recordInfo, long address) + => functions.ConcurrentDeleter(ref key, ref value, ref recordInfo, address); - public bool ConcurrentWriter(ref Key key, ref Value src, ref Value dst) - => functions.ConcurrentWriter(ref key, ref src, ref dst); + public void ConcurrentReader(ref Key key, ref Input input, ref Value value, 
ref Output dst, ref RecordInfo recordInfo, long address) + => functions.ConcurrentReader(ref key, ref input, ref value, ref dst, ref recordInfo, address); + + public bool ConcurrentWriter(ref Key key, ref Value src, ref Value dst, ref RecordInfo recordInfo, long address) + => functions.ConcurrentWriter(ref key, ref src, ref dst, ref recordInfo, address); public bool NeedCopyUpdate(ref Key key, ref Input input, ref Value oldValue, ref Output output) => functions.NeedCopyUpdate(ref key, ref input, ref oldValue, ref output); @@ -41,14 +42,14 @@ public void DeleteCompletionCallback(ref Key key, long ctx) public void InitialUpdater(ref Key key, ref Input input, ref Value value, ref Output output) => functions.InitialUpdater(ref key, ref input, ref value, ref output); + + public bool InPlaceUpdater(ref Key key, ref Input input, ref Value value, ref Output output, ref RecordInfo recordInfo, long address) + => functions.InPlaceUpdater(ref key, ref input, ref value, ref output, ref recordInfo, address); - public bool InPlaceUpdater(ref Key key, ref Input input, ref Value value, ref Output output) - => functions.InPlaceUpdater(ref key, ref input, ref value, ref output); - - public void ReadCompletionCallback(ref Key key, ref Input input, ref Output output, long ctx, Status status) + public void ReadCompletionCallback(ref Key key, ref Input input, ref Output output, long ctx, Status status, RecordInfo recordInfo) { serverNetworkSession.CompleteRead(ref output, ctx, status); - functions.ReadCompletionCallback(ref key, ref input, ref output, ctx, status); + functions.ReadCompletionCallback(ref key, ref input, ref output, ctx, status, recordInfo); } public void RMWCompletionCallback(ref Key key, ref Input input, ref Output output, long ctx, Status status) @@ -57,8 +58,8 @@ public void RMWCompletionCallback(ref Key key, ref Input input, ref Output outpu functions.RMWCompletionCallback(ref key, ref input, ref output, ctx, status); } - public void SingleReader(ref Key key, ref 
Input input, ref Value value, ref Output dst) - => functions.SingleReader(ref key, ref input, ref value, ref dst); + public void SingleReader(ref Key key, ref Input input, ref Value value, ref Output dst, long address) + => functions.SingleReader(ref key, ref input, ref value, ref dst, address); public void SingleWriter(ref Key key, ref Value src, ref Value dst) => functions.SingleWriter(ref key, ref src, ref dst); diff --git a/cs/remote/src/FASTER.server/ServerSessionBase.cs b/cs/remote/src/FASTER.server/ServerSessionBase.cs index db02a74b8..68bbf97d9 100644 --- a/cs/remote/src/FASTER.server/ServerSessionBase.cs +++ b/cs/remote/src/FASTER.server/ServerSessionBase.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. using System; @@ -25,17 +25,20 @@ public abstract class ServerSessionBase : IServerSession /// <summary> /// Response object /// </summary> - protected ReusableObject<SeaaBuffer> responseObject; + protected SeaaBuffer responseObject; /// <summary> /// Bytes read /// </summary> protected int bytesRead; - private readonly NetworkSender messageManager; + /// <summary> + /// Message manager + /// </summary> + protected readonly NetworkSender messageManager; + private readonly int serverBufferSize; - /// <summary> /// Create new instance /// </summary> @@ -59,7 +62,7 @@ public ServerSessionBase(Socket socket, MaxSizeSettings maxSizeSettings) /// <summary> /// Get response object /// </summary> - protected void GetResponseObject() { if (responseObject.obj == null) responseObject = messageManager.GetReusableSeaaBuffer(); } + protected void GetResponseObject() { if (responseObject == null) responseObject = messageManager.GetReusableSeaaBuffer(); } /// <summary> /// Send response @@ -67,13 +70,15 @@ public ServerSessionBase(Socket socket, MaxSizeSettings maxSizeSettings) /// <param name="size"></param> protected void SendResponse(int size) 
{ + var _r = responseObject; + responseObject = null; try { - messageManager.Send(socket, responseObject, 0, size); + messageManager.Send(socket, _r, 0, size); } catch { - responseObject.Dispose(); + messageManager.Return(_r); } } @@ -84,25 +89,63 @@ protected void SendResponse(int size) /// <param name="size"></param> protected void SendResponse(int offset, int size) { + var _r = responseObject; + responseObject = null; try { - messageManager.Send(socket, responseObject, offset, size); + messageManager.Send(socket, _r, offset, size); } catch { - responseObject.Dispose(); + messageManager.Return(_r); } } + /// <summary> + /// Publish an update to a key to all the subscribers of the key + /// </summary> + /// <param name="keyPtr"></param> + /// <param name="keyLength"></param> + /// <param name="valPtr"></param> + /// <param name="valLength"></param> + /// <param name="inputPtr"></param> + /// <param name="sid"></param> + public abstract unsafe void Publish(ref byte* keyPtr, int keyLength, ref byte* valPtr, int valLength, ref byte* inputPtr, int sid); + + /// <summary> + /// Publish an update to a key to all the (prefix) subscribers of the key + /// </summary> + /// <param name="prefixPtr"></param> + /// <param name="prefixLength"></param> + /// <param name="keyPtr"></param> + /// <param name="keyLength"></param> + /// <param name="valPtr"></param> + /// <param name="valLength"></param> + /// <param name="inputPtr"></param> + /// <param name="sid"></param> + public abstract unsafe void PrefixPublish(byte* prefixPtr, int prefixLength, ref byte* keyPtr, int keyLength, ref byte* valPtr, int valLength, ref byte* inputPtr, int sid); + /// <summary> /// Dispose /// </summary> public virtual void Dispose() { socket.Dispose(); - if (responseObject.obj != null) - responseObject.Dispose(); + var _r = responseObject; + if (_r != null) + messageManager.Return(_r); + messageManager.Dispose(); + } + + /// <summary> + /// Wait for ongoing outgoing calls to complete + /// 
</summary> + public virtual void CompleteSends() + { + var _r = responseObject; + if (_r != null) + messageManager.Return(_r); messageManager.Dispose(); } } diff --git a/cs/remote/src/FASTER.server/Servers/FixedLenServer.cs b/cs/remote/src/FASTER.server/Servers/FixedLenServer.cs new file mode 100644 index 000000000..ecae78eb7 --- /dev/null +++ b/cs/remote/src/FASTER.server/Servers/FixedLenServer.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +using System; +using FASTER.core; +using FASTER.common; + +namespace FASTER.server +{ + /// <summary> + /// FASTER server for variable-length data + /// </summary> + public sealed class FixedLenServer<Key, Value, Input, Output, Functions> : GenericServer<Key, Value, Input, Output, Functions, FixedLenSerializer<Key, Value, Input, Output>> + where Key : unmanaged + where Value : unmanaged + where Input : unmanaged + where Output : unmanaged + where Functions : IAdvancedFunctions<Key, Value, Input, Output, long> + { + /// <summary> + /// Create server instance; use Start to start the server. + /// </summary> + /// <param name="opts"></param> + /// <param name="functionsGen"></param> + /// <param name="maxSizeSettings"></param> + public FixedLenServer(ServerOptions opts, Func<WireFormat, Functions> functionsGen, MaxSizeSettings maxSizeSettings = default) + : base(opts, functionsGen, new FixedLenSerializer<Key, Value, Input, Output>(), new FixedLenKeySerializer<Key, Input>(), maxSizeSettings) + { + } + } +} diff --git a/cs/remote/src/FASTER.server/Servers/GenericServer.cs b/cs/remote/src/FASTER.server/Servers/GenericServer.cs new file mode 100644 index 000000000..4230c9661 --- /dev/null +++ b/cs/remote/src/FASTER.server/Servers/GenericServer.cs @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +using System; +using System.IO; +using FASTER.core; +using FASTER.common; + +namespace FASTER.server +{ + /// <summary> + /// FASTER server for generic types + /// </summary> + public class GenericServer<Key, Value, Input, Output, Functions, ParameterSerializer> : IDisposable + where Functions : IAdvancedFunctions<Key, Value, Input, Output, long> + where ParameterSerializer : IServerSerializer<Key, Value, Input, Output> + { + readonly ServerOptions opts; + readonly FasterServer server; + readonly FasterKV<Key, Value> store; + readonly FasterKVProvider<Key, Value, Input, Output, Functions, ParameterSerializer> provider; + readonly SubscribeKVBroker<Key, Value, Input, IKeyInputSerializer<Key, Input>> kvBroker; + readonly SubscribeBroker<Key, Value, IKeySerializer<Key>> broker; + + /// <summary> + /// Create server instance; use Start to start the server. + /// </summary> + /// <param name="opts"></param> + /// <param name="functionsGen"></param> + /// <param name="serializer"></param> + /// <param name="keyInputSerializer"></param> + /// <param name="maxSizeSettings"></param> + public GenericServer(ServerOptions opts, Func<WireFormat, Functions> functionsGen, ParameterSerializer serializer, IKeyInputSerializer<Key, Input> keyInputSerializer, MaxSizeSettings maxSizeSettings = default) + { + this.opts = opts; + + if (opts.LogDir != null && !Directory.Exists(opts.LogDir)) + Directory.CreateDirectory(opts.LogDir); + + if (opts.CheckpointDir != null && !Directory.Exists(opts.CheckpointDir)) + Directory.CreateDirectory(opts.CheckpointDir); + + opts.GetSettings(out var logSettings, out var checkpointSettings, out var indexSize); + store = new FasterKV<Key, Value>(indexSize, logSettings, checkpointSettings); + + if (opts.Recover) + { + try + { + store.Recover(); + } + catch { } + } + + if (opts.EnablePubSub) + { + kvBroker = new SubscribeKVBroker<Key, Value, Input, IKeyInputSerializer<Key, Input>>(keyInputSerializer, null, true); + broker = new SubscribeBroker<Key, Value, 
IKeySerializer<Key>>(keyInputSerializer, null, true); + } + + // Create session provider for VarLen + provider = new FasterKVProvider<Key, Value, Input, Output, Functions, ParameterSerializer>(store, functionsGen, kvBroker, broker, serializer, maxSizeSettings); + + server = new FasterServer(opts.Address, opts.Port); + server.Register(WireFormat.DefaultFixedLenKV, provider); + } + + /// <summary> + /// Start server instance + /// </summary> + public void Start() => server.Start(); + + /// <summary> + /// Dispose store (including log and checkpoint directory) + /// </summary> + public void Dispose() + { + InternalDispose(); + DeleteDirectory(opts.LogDir); + DeleteDirectory(opts.CheckpointDir); + } + + /// <summary> + /// Dipose, optionally deleting logs and checkpoints + /// </summary> + /// <param name="deleteDir">Whether to delete logs and checkpoints</param> + public void Dispose(bool deleteDir = true) + { + InternalDispose(); + if (deleteDir) DeleteDirectory(opts.LogDir); + if (deleteDir) DeleteDirectory(opts.CheckpointDir); + } + + private void InternalDispose() + { + server.Dispose(); + broker?.Dispose(); + kvBroker?.Dispose(); + store.Dispose(); + } + + private static void DeleteDirectory(string path) + { + if (path == null) return; + + foreach (string directory in Directory.GetDirectories(path)) + { + DeleteDirectory(directory); + } + + // Exceptions may happen due to a handle briefly remaining held after Dispose(). + try + { + Directory.Delete(path, true); + } + catch (Exception ex) when (ex is IOException || + ex is UnauthorizedAccessException) + { + try + { + Directory.Delete(path, true); + } + catch { } + } + } + } +} diff --git a/cs/remote/src/FASTER.server/Servers/ServerOptions.cs b/cs/remote/src/FASTER.server/Servers/ServerOptions.cs new file mode 100644 index 000000000..0c5647ebe --- /dev/null +++ b/cs/remote/src/FASTER.server/Servers/ServerOptions.cs @@ -0,0 +1,241 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.IO; +using FASTER.core; + +namespace FASTER.server +{ + /// <summary> + /// Options when creating FASTER server + /// </summary> + public class ServerOptions + { + /// <summary> + /// Port to run server on + /// </summary> + public int Port = 3278; + + /// <summary> + /// IP address to bind server to + /// </summary> + public string Address = "127.0.0.1"; + + /// <summary> + /// Total log memory used in bytes (rounds down to power of 2) + /// </summary> + public string MemorySize = "16g"; + + /// <summary> + /// Size of each page in bytes (rounds down to power of 2) + /// </summary> + public string PageSize = "32m"; + + /// <summary> + /// Size of each log segment in bytes on disk (rounds down to power of 2) + /// </summary> + public string SegmentSize = "1g"; + + /// <summary> + /// Size of hash index in bytes (rounds down to power of 2) + /// </summary> + public string IndexSize = "8g"; + + /// <summary> + /// Storage directory for data (hybrid log). Runs memory-only if unspecified. + /// </summary> + public string LogDir = null; + + /// <summary> + /// Storage directory for checkpoints. Uses 'checkpoints' folder under logdir if unspecified. + /// </summary> + public string CheckpointDir = null; + + /// <summary> + /// Recover from latest checkpoint. + /// </summary> + public bool Recover = false; + + /// <summary> + /// Enable pub/sub feature on server. 
+ /// </summary> + public bool EnablePubSub = true; + + /// <summary> + /// Constructor + /// </summary> + public ServerOptions() + { + } + + internal int MemorySizeBits() + { + long size = ParseSize(MemorySize); + long adjustedSize = PreviousPowerOf2(size); + if (size != adjustedSize) + Trace.WriteLine($"Warning: using lower log memory size than specified (power of 2)"); + return (int)Math.Log(adjustedSize, 2); + } + + internal int PageSizeBits() + { + long size = ParseSize(PageSize); + long adjustedSize = PreviousPowerOf2(size); + if (size != adjustedSize) + Trace.WriteLine($"Warning: using lower page size than specified (power of 2)"); + return (int)Math.Log(adjustedSize, 2); + } + + internal int SegmentSizeBits() + { + long size = ParseSize(SegmentSize); + long adjustedSize = PreviousPowerOf2(size); + if (size != adjustedSize) + Trace.WriteLine($"Warning: using lower disk segment size than specified (power of 2)"); + return (int)Math.Log(adjustedSize, 2); + } + + internal int IndexSizeCachelines() + { + long size = ParseSize(IndexSize); + long adjustedSize = PreviousPowerOf2(size); + if (adjustedSize < 64 || adjustedSize > (1L << 37)) throw new Exception("Invalid index size"); + if (size != adjustedSize) + Trace.WriteLine($"Warning: using lower hash index size than specified (power of 2)"); + return (int)(adjustedSize / 64); + } + + internal void GetSettings(out LogSettings logSettings, out CheckpointSettings checkpointSettings, out int indexSize) + { + logSettings = new LogSettings { PreallocateLog = false }; + + logSettings.PageSizeBits = PageSizeBits(); + Trace.WriteLine($"[Store] Using page size of {PrettySize((long)Math.Pow(2, logSettings.PageSizeBits))}"); + + logSettings.MemorySizeBits = MemorySizeBits(); + Trace.WriteLine($"[Store] Using log memory size of {PrettySize((long)Math.Pow(2, logSettings.MemorySizeBits))}"); + + Trace.WriteLine($"[Store] There are {PrettySize(1 << (logSettings.MemorySizeBits - logSettings.PageSizeBits))} log pages in memory"); 
+ + logSettings.SegmentSizeBits = SegmentSizeBits(); + Trace.WriteLine($"[Store] Using disk segment size of {PrettySize((long)Math.Pow(2, logSettings.SegmentSizeBits))}"); + + indexSize = IndexSizeCachelines(); + Trace.WriteLine($"[Store] Using hash index size of {PrettySize(indexSize * 64L)} ({PrettySize(indexSize)} cache lines)"); + + if (LogDir == null) + LogDir = Directory.GetCurrentDirectory(); + + var device = LogDir == "" ? new NullDevice() : Devices.CreateLogDevice(LogDir + "/Store/hlog"); + logSettings.LogDevice = device; + + checkpointSettings = new CheckpointSettings + { + CheckPointType = CheckpointType.Snapshot, + CheckpointDir = CheckpointDir ?? (LogDir + "/Store/checkpoints"), + RemoveOutdated = true, + }; + } + + internal void GetObjectStoreSettings(out LogSettings objLogSettings, out CheckpointSettings objCheckpointSettings, out int objIndexSize) + { + objLogSettings = new LogSettings { PreallocateLog = false }; + + objLogSettings.PageSizeBits = PageSizeBits(); + Trace.WriteLine($"[Object Store] Using page size of {PrettySize((long)Math.Pow(2, objLogSettings.PageSizeBits))}"); + + objLogSettings.MemorySizeBits = MemorySizeBits(); + Trace.WriteLine($"[Object Store] Using log memory size of {PrettySize((long)Math.Pow(2, objLogSettings.MemorySizeBits))}"); + + Trace.WriteLine($"[Object Store] There are {PrettySize(1 << (objLogSettings.MemorySizeBits - objLogSettings.PageSizeBits))} log pages in memory"); + + objLogSettings.SegmentSizeBits = SegmentSizeBits(); + Trace.WriteLine($"[Object Store] Using disk segment size of {PrettySize((long)Math.Pow(2, objLogSettings.SegmentSizeBits))}"); + + objIndexSize = IndexSizeCachelines() / 64; + Trace.WriteLine($"[Object Store] Using hash index size of {PrettySize(objIndexSize * 64L)} ({PrettySize(objIndexSize)} cache lines)"); + + if (LogDir == null) + LogDir = Directory.GetCurrentDirectory(); + + var device = LogDir == "" ? 
new NullDevice() : Devices.CreateLogDevice(LogDir + "/ObjectStore/hlog"); + objLogSettings.LogDevice = device; + var objdevice = LogDir == "" ? new NullDevice() : Devices.CreateLogDevice(LogDir + "/ObjectStore/hlog.obj"); + objLogSettings.ObjectLogDevice = objdevice; + + objCheckpointSettings = new CheckpointSettings + { + CheckPointType = CheckpointType.Snapshot, + CheckpointDir = CheckpointDir ?? (LogDir + "/ObjectStore/checkpoints"), + RemoveOutdated = true, + }; + } + + private static long ParseSize(string value) + { + char[] suffix = new char[] { 'k', 'm', 'g', 't', 'p' }; + long result = 0; + foreach (char c in value) + { + if (char.IsDigit(c)) + { + result = result * 10 + (byte)c - '0'; + } + else + { + for (int i = 0; i < suffix.Length; i++) + { + if (char.ToLower(c) == suffix[i]) + { + result *= (long)Math.Pow(1024, i + 1); + return result; + } + } + } + } + return result; + } + + private static string PrettySize(long value) + { + char[] suffix = new char[] { 'k', 'm', 'g', 't', 'p' }; + double v = value; + int exp = 0; + while (v - Math.Floor(v) > 0) + { + if (exp >= 18) + break; + exp += 3; + v *= 1024; + v = Math.Round(v, 12); + } + + while (Math.Floor(v).ToString().Length > 3) + { + if (exp <= -18) + break; + exp -= 3; + v /= 1024; + v = Math.Round(v, 12); + } + if (exp > 0) + return v.ToString() + suffix[exp / 3 - 1]; + else if (exp < 0) + return v.ToString() + suffix[-exp / 3 - 1]; + return v.ToString(); + } + + private long PreviousPowerOf2(long v) + { + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + return v - (v >> 1); + } + } +} \ No newline at end of file diff --git a/cs/remote/src/FASTER.server/Servers/VarLenServer.cs b/cs/remote/src/FASTER.server/Servers/VarLenServer.cs new file mode 100644 index 000000000..6ece966a4 --- /dev/null +++ b/cs/remote/src/FASTER.server/Servers/VarLenServer.cs @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT license. + +using System; +using System.IO; +using FASTER.core; +using FASTER.common; + +namespace FASTER.server +{ + /// <summary> + /// FASTER server for variable-length data + /// </summary> + public sealed class VarLenServer : IDisposable + { + readonly ServerOptions opts; + readonly FasterServer server; + readonly FasterKV<SpanByte, SpanByte> store; + readonly SpanByteFasterKVProvider provider; + readonly SubscribeKVBroker<SpanByte, SpanByte, SpanByte, IKeyInputSerializer<SpanByte, SpanByte>> kvBroker; + readonly SubscribeBroker<SpanByte, SpanByte, IKeySerializer<SpanByte>> broker; + readonly LogSettings logSettings; + + /// <summary> + /// Create server instance; use Start to start the server. + /// </summary> + /// <param name="opts"></param> + public VarLenServer(ServerOptions opts) + { + this.opts = opts; + + if (opts.LogDir != null && !Directory.Exists(opts.LogDir)) + Directory.CreateDirectory(opts.LogDir); + + if (opts.CheckpointDir != null && !Directory.Exists(opts.CheckpointDir)) + Directory.CreateDirectory(opts.CheckpointDir); + + opts.GetSettings(out logSettings, out var checkpointSettings, out var indexSize); + store = new FasterKV<SpanByte, SpanByte>(indexSize, logSettings, checkpointSettings); + + if (opts.EnablePubSub) + { + kvBroker = new SubscribeKVBroker<SpanByte, SpanByte, SpanByte, IKeyInputSerializer<SpanByte, SpanByte>>(new SpanByteKeySerializer(), null, true); + broker = new SubscribeBroker<SpanByte, SpanByte, IKeySerializer<SpanByte>>(new SpanByteKeySerializer(), null, true); + } + + // Create session provider for VarLen + provider = new SpanByteFasterKVProvider(store, kvBroker, broker, opts.Recover); + + server = new FasterServer(opts.Address, opts.Port); + server.Register(WireFormat.DefaultVarLenKV, provider); + server.Register(WireFormat.WebSocket, provider); + } + + /// <summary> + /// Start server instance + /// </summary> + public void Start() => server.Start(); + + /// <summary> + /// Dispose store 
(including log and checkpoint directory) + /// </summary> + public void Dispose() + { + InternalDispose(); + DeleteDirectory(opts.LogDir); + DeleteDirectory(opts.CheckpointDir); + } + + /// <summary> + /// Dipose, optionally deleting logs and checkpoints + /// </summary> + /// <param name="deleteDir">Whether to delete logs and checkpoints</param> + public void Dispose(bool deleteDir = true) + { + InternalDispose(); + if (deleteDir) DeleteDirectory(opts.LogDir); + if (deleteDir) DeleteDirectory(opts.CheckpointDir); + } + + private void InternalDispose() + { + server.Dispose(); + broker?.Dispose(); + kvBroker?.Dispose(); + store.Dispose(); + logSettings.LogDevice.Dispose(); + } + + private static void DeleteDirectory(string path) + { + if (path == null) return; + + foreach (string directory in Directory.GetDirectories(path)) + { + DeleteDirectory(directory); + } + + // Exceptions may happen due to a handle briefly remaining held after Dispose(). + try + { + Directory.Delete(path, true); + } + catch (Exception ex) when (ex is IOException || + ex is UnauthorizedAccessException) + { + try + { + Directory.Delete(path, true); + } + catch { } + } + } + } +} diff --git a/cs/remote/src/FASTER.server/SpanByteClientSerializer.cs b/cs/remote/src/FASTER.server/SpanByteClientSerializer.cs index e2768a276..6ea9c3856 100644 --- a/cs/remote/src/FASTER.server/SpanByteClientSerializer.cs +++ b/cs/remote/src/FASTER.server/SpanByteClientSerializer.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
using System; @@ -34,6 +34,20 @@ public SpanByteAndMemory ReadOutput(ref byte* src) return new SpanByteAndMemory(mem, length); } + /// <inheritdoc /> + public SpanByte ReadKey(ref byte* src) + { + int length = *(int*)src; + return SpanByte.FromPointer(src, length); + } + + /// <inheritdoc /> + public SpanByte ReadValue(ref byte* src) + { + int length = *(int*)src; + return SpanByte.FromPointer(src, length); + } + /// <inheritdoc /> public bool Write(ref SpanByte k, ref byte* dst, int length) { diff --git a/cs/remote/src/FASTER.server/SpanByteFasterKVProvider.cs b/cs/remote/src/FASTER.server/SpanByteFasterKVProvider.cs deleted file mode 100644 index 2623c8c3a..000000000 --- a/cs/remote/src/FASTER.server/SpanByteFasterKVProvider.cs +++ /dev/null @@ -1,43 +0,0 @@ -using System; -using System.Net.Sockets; -using FASTER.common; -using FASTER.core; - -namespace FASTER.server -{ - - /// <summary> - /// Session provider for FasterKV store based on - /// [K, V, I, O, C] = [SpanByte, SpanByte, SpanByte, SpanByteAndMemory, long] - /// </summary> - public sealed class SpanByteFasterKVProvider : ISessionProvider, IDisposable - { - readonly FasterKV<SpanByte, SpanByte> store; - readonly SpanByteServerSerializer serializer; - readonly MaxSizeSettings maxSizeSettings; - - /// <summary> - /// Create SpanByte FasterKV backend - /// </summary> - /// <param name="store"></param> - /// <param name="maxSizeSettings"></param> - public SpanByteFasterKVProvider(FasterKV<SpanByte, SpanByte> store, MaxSizeSettings maxSizeSettings = default) - { - this.store = store; - this.serializer = new SpanByteServerSerializer(); - this.maxSizeSettings = maxSizeSettings ?? 
new MaxSizeSettings(); - } - - /// <inheritdoc /> - public IServerSession GetSession(WireFormat wireFormat, Socket socket) - { - return new BinaryServerSession<SpanByte, SpanByte, SpanByte, SpanByteAndMemory, SpanByteFunctionsForServer<long>, SpanByteServerSerializer> - (socket, store, new SpanByteFunctionsForServer<long>(wireFormat), serializer, maxSizeSettings); - } - - /// <inheritdoc /> - public void Dispose() - { - } - } -} \ No newline at end of file diff --git a/cs/remote/src/FASTER.server/SpanByteFunctionsForServer.cs b/cs/remote/src/FASTER.server/SpanByteFunctionsForServer.cs index 03f51ccd2..79f3a906a 100644 --- a/cs/remote/src/FASTER.server/SpanByteFunctionsForServer.cs +++ b/cs/remote/src/FASTER.server/SpanByteFunctionsForServer.cs @@ -10,35 +10,27 @@ namespace FASTER.server /// <summary> /// Callback functions using SpanByteAndMemory output, for SpanByte key, value, input /// </summary> - public class SpanByteFunctionsForServer<Context> : SpanByteFunctions<SpanByte, SpanByteAndMemory, Context> + public class SpanByteFunctionsForServer<Context> : SpanByteAdvancedFunctions<SpanByte, SpanByteAndMemory, Context> { - readonly WireFormat wireFormat; - readonly MemoryPool<byte> memoryPool; + /// <summary> + /// Memory pool + /// </summary> + readonly protected MemoryPool<byte> memoryPool; /// <summary> /// Constructor /// </summary> - /// <param name="wireFormat"></param> /// <param name="memoryPool"></param> - public SpanByteFunctionsForServer(WireFormat wireFormat = default, MemoryPool<byte> memoryPool = default) : base(true) + public SpanByteFunctionsForServer(MemoryPool<byte> memoryPool = default) : base(true) { - this.wireFormat = wireFormat; this.memoryPool = memoryPool ?? 
MemoryPool<byte>.Shared; } /// <inheritdoc /> - public unsafe override void SingleReader(ref SpanByte key, ref SpanByte input, ref SpanByte value, ref SpanByteAndMemory dst) - { - if (wireFormat != WireFormat.ASCII) - CopyWithHeaderTo(ref value, ref dst, memoryPool); - } + public override void SingleReader(ref SpanByte key, ref SpanByte input, ref SpanByte value, ref SpanByteAndMemory dst, long address) => CopyWithHeaderTo(ref value, ref dst, memoryPool); /// <inheritdoc /> - public unsafe override void ConcurrentReader(ref SpanByte key, ref SpanByte input, ref SpanByte value, ref SpanByteAndMemory dst) - { - if (wireFormat != WireFormat.ASCII) - CopyWithHeaderTo(ref value, ref dst, memoryPool); - } + public override void ConcurrentReader(ref SpanByte key, ref SpanByte input, ref SpanByte value, ref SpanByteAndMemory dst, ref RecordInfo recordInfo, long address) => CopyWithHeaderTo(ref value, ref dst, memoryPool); /// <summary> /// Copy to given SpanByteAndMemory (header length and payload copied to actual span/memory) @@ -46,7 +38,7 @@ public unsafe override void ConcurrentReader(ref SpanByte key, ref SpanByte inpu /// <param name="src"></param> /// <param name="dst"></param> /// <param name="memoryPool"></param> - private unsafe void CopyWithHeaderTo(ref SpanByte src, ref SpanByteAndMemory dst, MemoryPool<byte> memoryPool) + private static unsafe bool CopyWithHeaderTo(ref SpanByte src, ref SpanByteAndMemory dst, MemoryPool<byte> memoryPool) { if (dst.IsSpanByte) { @@ -57,7 +49,7 @@ private unsafe void CopyWithHeaderTo(ref SpanByte src, ref SpanByteAndMemory dst fixed (byte* ptr = span) *(int*)ptr = src.Length; src.AsReadOnlySpan().CopyTo(span.Slice(sizeof(int))); - return; + return true; } dst.ConvertToHeap(); } @@ -68,6 +60,7 @@ private unsafe void CopyWithHeaderTo(ref SpanByte src, ref SpanByteAndMemory dst fixed (byte* ptr = dst.Memory.Memory.Span) *(int*)ptr = src.Length; src.AsReadOnlySpan().CopyTo(dst.Memory.Memory.Span.Slice(sizeof(int))); + return true; } 
} -} \ No newline at end of file +} diff --git a/cs/remote/src/FASTER.server/SpanByteServerSerializer.cs b/cs/remote/src/FASTER.server/SpanByteServerSerializer.cs index 064b0e3a1..3bc7217f2 100644 --- a/cs/remote/src/FASTER.server/SpanByteServerSerializer.cs +++ b/cs/remote/src/FASTER.server/SpanByteServerSerializer.cs @@ -57,6 +57,20 @@ public ref SpanByte ReadInputByRef(ref byte* src) return ref ret; } + /// <inheritdoc /> + public bool Write(ref SpanByte k, ref byte* dst, int length) + { + if (k.Length > length) return false; + + *(int*)dst = k.Length; + dst += sizeof(int); + var dest = new SpanByte(k.Length, (IntPtr)dst); + k.CopyTo(ref dest); + dst += k.Length; + return true; + } + + /// <inheritdoc /> public bool Write(ref SpanByteAndMemory k, ref byte* dst, int length) { @@ -73,7 +87,6 @@ public bool Write(ref SpanByteAndMemory k, ref byte* dst, int length) /// <inheritdoc /> public ref SpanByteAndMemory AsRefOutput(byte* src, int length) { - *(int*)src = length - sizeof(int); output = SpanByteAndMemory.FromFixedSpan(new Span<byte>(src, length)); return ref output; } diff --git a/cs/remote/src/FASTER.server/WebsocketServerSession.cs b/cs/remote/src/FASTER.server/WebsocketServerSession.cs new file mode 100644 index 000000000..e8edce724 --- /dev/null +++ b/cs/remote/src/FASTER.server/WebsocketServerSession.cs @@ -0,0 +1,691 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +using System; +using System.Diagnostics; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Text.RegularExpressions; +using FASTER.common; +using FASTER.core; +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace FASTER.server +{ + internal struct Decoder + { + public int msgLen; + public int maskStart; + public int dataStart; + }; + + internal unsafe sealed class WebsocketServerSession<Key, Value, Input, Output, Functions, ParameterSerializer> + : FasterKVServerSessionBase<Key, Value, Input, Output, Functions, ParameterSerializer> + where Functions : IAdvancedFunctions<Key, Value, Input, Output, long> + where ParameterSerializer : IServerSerializer<Key, Value, Input, Output> + { + readonly HeaderReaderWriter hrw; + GCHandle recvHandle; + byte* recvBufferPtr; + int readHead; + + int pendingSeqNo, msgnum, start; + byte* dcurr; + + readonly SubscribeKVBroker<Key, Value, Input, IKeyInputSerializer<Key, Input>> subscribeKVBroker; + readonly SubscribeBroker<Key, Value, IKeySerializer<Key>> subscribeBroker; + + public WebsocketServerSession(Socket socket, FasterKV<Key, Value> store, Functions functions, ParameterSerializer serializer, MaxSizeSettings maxSizeSettings, SubscribeKVBroker<Key, Value, Input, IKeyInputSerializer<Key, Input>> subscribeKVBroker, SubscribeBroker<Key, Value, IKeySerializer<Key>> subscribeBroker) + : base(socket, store, functions, null, serializer, maxSizeSettings) + { + this.subscribeKVBroker = subscribeKVBroker; + this.subscribeBroker = subscribeBroker; + + readHead = 0; + + // Reserve minimum 4 bytes to send pending sequence number as output + if (this.maxSizeSettings.MaxOutputSize < sizeof(int)) + this.maxSizeSettings.MaxOutputSize = sizeof(int); + } + + public override int TryConsumeMessages(byte[] buf) + { + if (recvBufferPtr == null) + { + recvHandle = GCHandle.Alloc(buf, GCHandleType.Pinned); + recvBufferPtr = 
(byte*)recvHandle.AddrOfPinnedObject(); + } + + while (TryReadMessages(buf, out var offset)) + { + bool completeWSCommand = ProcessBatch(buf, offset); + if (!completeWSCommand) + return bytesRead; + } + + // The bytes left in the current buffer not consumed by previous operations + var bytesLeft = bytesRead - readHead; + if (bytesLeft != bytesRead) + { + // Shift them to the head of the array so we can reset the buffer to a consistent state + Array.Copy(buf, readHead, buf, 0, bytesLeft); + bytesRead = bytesLeft; + readHead = 0; + } + + return bytesRead; + } + + public override void CompleteRead(ref Output output, long ctx, core.Status status) + { + byte* d = responseObject.bufferPtr; + var dend = d + responseObject.buffer.Length; + + if ((int)(dend - dcurr) < 7 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + hrw.Write(MessageType.PendingResult, ref dcurr, (int)(dend - dcurr)); + hrw.Write((MessageType)(ctx >> 32), ref dcurr, (int)(dend - dcurr)); + Write((int)(ctx & 0xffffffff), ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + if (status != core.Status.NOTFOUND) + serializer.Write(ref output, ref dcurr, (int)(dend - dcurr)); + msgnum++; + } + + public override void CompleteRMW(ref Output output, long ctx, Status status) + { + byte* d = responseObject.bufferPtr; + var dend = d + responseObject.buffer.Length; + + if ((int)(dend - dcurr) < 7 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + hrw.Write(MessageType.PendingResult, ref dcurr, (int)(dend - dcurr)); + hrw.Write((MessageType)(ctx >> 32), ref dcurr, (int)(dend - dcurr)); + Write((int)(ctx & 0xffffffff), ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + if (status == Status.OK || status == Status.NOTFOUND) + serializer.Write(ref output, ref dcurr, (int)(dend - dcurr)); + msgnum++; + + int packetLen = (int)((dcurr - 10) - d); + CreateSendPacketHeader(ref d, packetLen); + } + + + private bool 
TryReadMessages(byte[] buf, out int offset) + { + offset = default; + + var bytesAvailable = bytesRead - readHead; + // Need to at least have read off of size field on the message + if (bytesAvailable < sizeof(int)) return false; + + offset = readHead; + return true; + } + + private unsafe void CreateSendPacketHeader(ref byte* d, int payloadLen) + { + if (payloadLen < 126) + { + d += 8; + } + else if (payloadLen < 65536) + { + d += 6; + } + byte* dcurr = d; + + *dcurr = 0b10000010; + dcurr++; + if (payloadLen < 126) + { + *dcurr = (byte)(payloadLen & 0b01111111); + dcurr++; + } + else if (payloadLen < 65536) + { + *dcurr = (byte)(0b01111110); + dcurr++; + byte[] payloadLenBytes = BitConverter.GetBytes((UInt16)payloadLen); + if (BitConverter.IsLittleEndian) + Array.Reverse(payloadLenBytes); + + *dcurr++ = payloadLenBytes[0]; + *dcurr++ = payloadLenBytes[1]; + } + else + { + *dcurr = (byte)(0b01111111); + dcurr++; + byte[] payloadLenBytes = BitConverter.GetBytes((UInt64)payloadLen); + if (BitConverter.IsLittleEndian) + Array.Reverse(payloadLenBytes); + + *dcurr++ = (byte)(payloadLenBytes[0] & 0b01111111); + *dcurr++ = payloadLenBytes[1]; + *dcurr++ = payloadLenBytes[2]; + *dcurr++ = payloadLenBytes[3]; + *dcurr++ = payloadLenBytes[4]; + *dcurr++ = payloadLenBytes[5]; + *dcurr++ = payloadLenBytes[6]; + *dcurr++ = payloadLenBytes[7]; + } + } + + private unsafe bool ProcessBatch(byte[] buf, int offset) + { + bool completeWSCommand = true; + GetResponseObject(); + + fixed (byte* b = &buf[offset]) + { + byte* d = responseObject.bufferPtr; + var dend = d + responseObject.buffer.Length; + dcurr = d; // reserve space for size + var bytesAvailable = bytesRead - readHead; + var _origReadHead = readHead; + int msglen = 0; + byte[] decoded = Array.Empty<byte>(); + var ptr = recvBufferPtr + readHead; + var totalMsgLen = 0; + List<Decoder> decoderInfoList = new(); + + if (buf[offset] == 71 && buf[offset + 1] == 69 && buf[offset + 2] == 84) + { + // 1. 
Obtain the value of the "Sec-WebSocket-Key" request header without any leading or trailing whitespace + // 2. Concatenate it with "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (a special GUID specified by RFC 6455) + // 3. Compute SHA-1 and Base64 hash of the new value + // 4. Write the hash back as the value of "Sec-WebSocket-Accept" response header in an HTTP response + string s = Encoding.UTF8.GetString(buf, offset, buf.Length - offset); + string swk = Regex.Match(s, "Sec-WebSocket-Key: (.*)").Groups[1].Value.Trim(); + string swka = swk + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + byte[] swkaSha1 = System.Security.Cryptography.SHA1.Create().ComputeHash(Encoding.UTF8.GetBytes(swka)); + string swkaSha1Base64 = Convert.ToBase64String(swkaSha1); + + // HTTP/1.1 defines the sequence CR LF as the end-of-line marker + byte[] response = Encoding.UTF8.GetBytes( + "HTTP/1.1 101 Switching Protocols\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: websocket\r\n" + + "Sec-WebSocket-Accept: " + swkaSha1Base64 + "\r\n\r\n"); + + fixed (byte* responsePtr = &response[0]) + Buffer.MemoryCopy(responsePtr, dcurr, response.Length, response.Length); + + dcurr += response.Length; + + SendResponse((int)(d - responseObject.bufferPtr), (int)(dcurr - d)); + readHead = bytesRead; + return completeWSCommand; + + } + else + { + var decoderInfo = new Decoder(); + + bool fin = (buf[offset] & 0b10000000) != 0, + mask = (buf[offset + 1] & 0b10000000) != 0; // must be true, "All messages from the client to the server have this bit set" + + int opcode = buf[offset] & 0b00001111; // expecting 1 - text message + offset++; + + msglen = buf[offset] - 128; // & 0111 1111 + + if (msglen < 125) + { + offset++; + } + else if (msglen == 126) + { + msglen = BitConverter.ToUInt16(new byte[] { buf[offset + 2], buf[offset + 1] }, 0); + offset += 3; + } + else if (msglen == 127) + { + msglen = (int)BitConverter.ToUInt64(new byte[] { buf[offset + 8], buf[offset + 7], buf[offset + 6], buf[offset + 5], buf[offset + 4], 
buf[offset + 3], buf[offset + 2], buf[offset + 1] }, 0); + offset += 9; + } + + if (msglen == 0) + Console.WriteLine("msglen == 0"); + + + decoderInfo.maskStart = offset; + decoderInfo.msgLen = msglen; + decoderInfo.dataStart = offset + 4; + decoderInfoList.Add(decoderInfo); + totalMsgLen += msglen; + offset += 4; + + if (fin == false) + { + byte[] decodedClientMsgLen = new byte[sizeof(Int32)]; + byte[] clientMsgLenMask = new byte[4] { buf[decoderInfo.maskStart], buf[decoderInfo.maskStart + 1], buf[decoderInfo.maskStart + 2], buf[decoderInfo.maskStart + 3] }; + for (int i = 0; i < sizeof(Int32); ++i) + decodedClientMsgLen[i] = (byte)(buf[decoderInfo.dataStart + i] ^ clientMsgLenMask[i % 4]); + var clientMsgLen = (int)BitConverter.ToInt32(decodedClientMsgLen, 0); + if (clientMsgLen > bytesRead) + return false; + } + + var nextBufOffset = offset; + + while (fin == false) + { + nextBufOffset += msglen; + + fin = ((buf[nextBufOffset]) & 0b10000000) != 0; + + nextBufOffset++; + var nextMsgLen = buf[nextBufOffset] - 128; // & 0111 1111 + + offset++; + nextBufOffset++; + + if (nextMsgLen < 125) + { + nextBufOffset++; + offset++; + } + else if (nextMsgLen == 126) + { + offset += 3; + nextMsgLen = BitConverter.ToUInt16(new byte[] { buf[nextBufOffset + 1], buf[nextBufOffset] }, 0); + nextBufOffset += 2; + } + else if (nextMsgLen == 127) + { + offset += 9; + nextMsgLen = (int)BitConverter.ToUInt64(new byte[] { buf[nextBufOffset + 7], buf[nextBufOffset + 6], buf[nextBufOffset + 5], buf[nextBufOffset + 4], buf[nextBufOffset + 3], buf[nextBufOffset + 2], buf[nextBufOffset + 1], buf[nextBufOffset] }, 0); + nextBufOffset += 8; + } + + var nextDecoderInfo = new Decoder(); + nextDecoderInfo.msgLen = nextMsgLen; + nextDecoderInfo.maskStart = nextBufOffset; + nextDecoderInfo.dataStart = nextBufOffset + 4; + decoderInfoList.Add(nextDecoderInfo); + totalMsgLen += nextMsgLen; + offset += 4; + } + + completeWSCommand = true; + + var decodedIndex = 0; + decoded = new byte[totalMsgLen]; + 
for (int decoderListIdx = 0; decoderListIdx < decoderInfoList.Count; decoderListIdx++) + { + { + var decoderInfoElem = decoderInfoList[decoderListIdx]; + byte[] masks = new byte[4] { buf[decoderInfoElem.maskStart], buf[decoderInfoElem.maskStart + 1], buf[decoderInfoElem.maskStart + 2], buf[decoderInfoElem.maskStart + 3] }; + + for (int i = 0; i < decoderInfoElem.msgLen; ++i) + decoded[decodedIndex++] = (byte)(buf[decoderInfoElem.dataStart + i] ^ masks[i % 4]); + } + } + + offset += totalMsgLen; + readHead = offset; + } + + dcurr = d; + dcurr += 10; + dcurr += sizeof(int); // reserve space for size + int origPendingSeqNo = pendingSeqNo; + + dcurr += BatchHeader.Size; + start = 0; + msgnum = 0; + + fixed (byte* ptr1 = &decoded[4]) + { + var src = ptr1; + ref var header = ref Unsafe.AsRef<BatchHeader>(src); + int num = *(int*)(src + 4); + src += BatchHeader.Size; + Status status = default; + + for (msgnum = 0; msgnum < num; msgnum++) + { + var message = (MessageType)(*src++); + + switch (message) + { + case MessageType.Upsert: + case MessageType.UpsertAsync: + if ((int)(dend - dcurr) < 2) + SendAndReset(ref d, ref dend); + + var keyPtr = src; + status = session.Upsert(ref serializer.ReadKeyByRef(ref src), ref serializer.ReadValueByRef(ref src)); + + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + + if (subscribeKVBroker != null) + subscribeKVBroker.Publish(keyPtr); + break; + + case MessageType.Read: + case MessageType.ReadAsync: + if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + long ctx = ((long)message << 32) | (long)pendingSeqNo; + status = session.Read(ref serializer.ReadKeyByRef(ref src), ref serializer.ReadInputByRef(ref src), + ref serializer.AsRefOutput(dcurr + 2, (int)(dend - dcurr)), ctx, 0); + + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + + if (status == core.Status.PENDING) + 
Write(pendingSeqNo++, ref dcurr, (int)(dend - dcurr)); + else if (status == core.Status.OK) + serializer.SkipOutput(ref dcurr); + + break; + + case MessageType.RMW: + case MessageType.RMWAsync: + if ((int)(dend - dcurr) < 2) + SendAndReset(ref d, ref dend); + + keyPtr = src; + + ctx = ((long)message << 32) | (long)pendingSeqNo; + status = session.RMW(ref serializer.ReadKeyByRef(ref src), ref serializer.ReadInputByRef(ref src), ctx); + + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + if (status == Status.PENDING) + Write(pendingSeqNo++, ref dcurr, (int)(dend - dcurr)); + + if (subscribeKVBroker != null) + subscribeKVBroker.Publish(keyPtr); + break; + + case MessageType.Delete: + case MessageType.DeleteAsync: + if ((int)(dend - dcurr) < 2) + SendAndReset(ref d, ref dend); + + keyPtr = src; + + status = session.Delete(ref serializer.ReadKeyByRef(ref src)); + + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + + if (subscribeKVBroker != null) + subscribeKVBroker.Publish(keyPtr); + break; + + case MessageType.SubscribeKV: + Debug.Assert(subscribeKVBroker != null); + + if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + var keyStart = src; + ref Key key = ref serializer.ReadKeyByRef(ref src); + + var inputStart = src; + ref Input input = ref serializer.ReadInputByRef(ref src); + + int sid = subscribeKVBroker.Subscribe(ref keyStart, ref inputStart, this); + status = Status.PENDING; + + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); + serializer.Write(ref key, ref dcurr, (int)(dend - dcurr)); + + break; + + case MessageType.PSubscribeKV: + Debug.Assert(subscribeKVBroker != null); + + if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + keyStart = src; + key = ref 
serializer.ReadKeyByRef(ref src); + + inputStart = src; + input = ref serializer.ReadInputByRef(ref src); + + sid = subscribeKVBroker.PSubscribe(ref keyStart, ref inputStart, this); + status = Status.PENDING; + + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); + serializer.Write(ref key, ref dcurr, (int)(dend - dcurr)); + + break; + + case MessageType.Publish: + Debug.Assert(subscribeBroker != null); + + if ((int)(dend - dcurr) < 2) + SendAndReset(ref d, ref dend); + + keyPtr = src; + key = ref serializer.ReadKeyByRef(ref src); + byte* valPtr = src; + ref Value val = ref serializer.ReadValueByRef(ref src); + int valueLength = (int)(src - valPtr); + + status = Status.OK; + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + + if (subscribeBroker != null) + subscribeBroker.Publish(keyPtr, valPtr, valueLength); + break; + + case MessageType.Subscribe: + Debug.Assert(subscribeBroker != null); + + if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + keyStart = src; + serializer.ReadKeyByRef(ref src); + + sid = subscribeBroker.Subscribe(ref keyStart, this); + status = Status.PENDING; + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); + break; + + case MessageType.PSubscribe: + Debug.Assert(subscribeBroker != null); + + if ((int)(dend - dcurr) < 2 + maxSizeSettings.MaxOutputSize) + SendAndReset(ref d, ref dend); + + keyStart = src; + serializer.ReadKeyByRef(ref src); + + sid = subscribeBroker.PSubscribe(ref keyStart, this); + status = Status.PENDING; + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); + break; + + default: + throw new NotImplementedException(); + } + } + } 
+ + if (origPendingSeqNo != pendingSeqNo) + session.CompletePending(true); + + // Send replies + if (msgnum - start > 0) + Send(d); + else + { + messageManager.Return(responseObject); + responseObject = null; + } + } + + return completeWSCommand; + } + + /// <inheritdoc /> + public unsafe override void Publish(ref byte* keyPtr, int keyLength, ref byte* valPtr, int valLength, ref byte* inputPtr, int sid) + => Publish(ref keyPtr, keyLength, ref valPtr, ref inputPtr, sid, false); + + /// <inheritdoc /> + public unsafe override void PrefixPublish(byte* prefixPtr, int prefixLength, ref byte* keyPtr, int keyLength, ref byte* valPtr, int valLength, ref byte* inputPtr, int sid) + => Publish(ref keyPtr, keyLength, ref valPtr, ref inputPtr, sid, true); + + private unsafe void Publish(ref byte* keyPtr, int keyLength, ref byte* valPtr, ref byte* inputPtr, int sid, bool prefix) + { + MessageType message; + + if (valPtr == null) + { + message = MessageType.SubscribeKV; + if (prefix) + message = MessageType.PSubscribeKV; + } + else + { + message = MessageType.Subscribe; + if (prefix) + message = MessageType.PSubscribe; + } + + var respObj = messageManager.GetReusableSeaaBuffer(); + + ref Key key = ref serializer.ReadKeyByRef(ref keyPtr); + + byte* d = respObj.bufferPtr; + var dend = d + respObj.buffer.Length; + var dcurr = d + sizeof(int); // reserve space for size + byte* outputDcurr; + + dcurr += BatchHeader.Size; + + long ctx = ((long)message << 32) | (long)sid; + + if (prefix) + outputDcurr = dcurr + 6 + keyLength; + else + outputDcurr = dcurr + 6; + + var status = Status.OK; + if (valPtr == null) + status = session.Read(ref key, ref serializer.ReadInputByRef(ref inputPtr), ref serializer.AsRefOutput(outputDcurr, (int)(dend - dcurr)), ctx, 0); + + if (status != Status.PENDING) + { + // Write six bytes (message | status | sid) + hrw.Write(message, ref dcurr, (int)(dend - dcurr)); + Write(ref status, ref dcurr, (int)(dend - dcurr)); + Write(sid, ref dcurr, (int)(dend - dcurr)); 
+ if (prefix) + serializer.Write(ref key, ref dcurr, (int)(dend - dcurr)); + if (valPtr != null) + { + ref Value value = ref serializer.ReadValueByRef(ref valPtr); + serializer.Write(ref value, ref dcurr, (int)(dend - dcurr)); + } + else if (status == Status.OK) + serializer.SkipOutput(ref dcurr); + } + else + { + throw new Exception("Pending reads not supported with pub/sub"); + } + + // Send replies + var dstart = d + sizeof(int); + Unsafe.AsRef<BatchHeader>(dstart).NumMessages = 1; + Unsafe.AsRef<BatchHeader>(dstart).SeqNo = 0; + int payloadSize = (int)(dcurr - d); + // Set packet size in header + *(int*)respObj.bufferPtr = -(payloadSize - sizeof(int)); + try + { + messageManager.Send(socket, respObj, 0, payloadSize); + } + catch + { + messageManager.Return(respObj); + } + } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe bool WriteOpSeqId(ref int o, ref byte* dst, int length) + { + if (length < sizeof(int)) return false; + *(int*)dst = o; + dst += sizeof(int); + return true; + } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe bool Write(ref Status s, ref byte* dst, int length) + { + if (length < 1) return false; + *dst++ = (byte)s; + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe bool Write(int seqNo, ref byte* dst, int length) + { + if (length < sizeof(int)) return false; + *(int*)dst = seqNo; + dst += sizeof(int); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void SendAndReset(ref byte* d, ref byte* dend) + { + Send(d); + GetResponseObject(); + d = responseObject.bufferPtr; + dend = d + responseObject.buffer.Length; + dcurr = d; + dcurr += 10; + dcurr += sizeof(int); // reserve space for size + start = msgnum; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void Send(byte* d) + { + if ((int)(dcurr - d) > 0) + { + int packetLen = (int)((dcurr - 10) - d); + var dtemp = d + 10; + var dstart = dtemp + 
sizeof(int); + + CreateSendPacketHeader(ref d, packetLen); + + *(int*)dtemp = (packetLen - sizeof(int)); + *(int*)dstart = 0; + *(int*)(dstart + sizeof(int)) = (msgnum - start); + SendResponse((int)(d - responseObject.bufferPtr), (int)(dcurr - d)); + } + } + } +} diff --git a/cs/remote/test/FASTER.remote.test/FixedLenBinaryPubSubTests.cs b/cs/remote/test/FASTER.remote.test/FixedLenBinaryPubSubTests.cs new file mode 100644 index 000000000..006c9a464 --- /dev/null +++ b/cs/remote/test/FASTER.remote.test/FixedLenBinaryPubSubTests.cs @@ -0,0 +1,92 @@ +using FASTER.core; +using FASTER.server; +using NUnit.Framework; + +namespace FASTER.remote.test +{ + [TestFixture] + public class FixedLenBinaryPubSubTests + { + FixedLenServer<long, long, long, long, AdvancedSimpleFunctions<long, long, long>> server; + FixedLenClient<long, long> client; + + [SetUp] + public void Setup() + { + server = TestUtils.CreateFixedLenServer(TestContext.CurrentContext.TestDirectory + "/FixedLenBinaryPubSubTests", (a, b) => a + b, enablePubSub: true); + server.Start(); + client = new FixedLenClient<long, long>(); + } + + [TearDown] + public void TearDown() + { + client.Dispose(); + server.Dispose(); + } + + [Test] + [Repeat(10)] + public void SubscribeKVTest() + { + var f = new FixedLenClientFunctions(); + using var session = client.GetSession(f); + using var subSession = client.GetSession(f); + + subSession.SubscribeKV(10); + subSession.CompletePending(true); + session.Upsert(10, 23); + session.CompletePending(true); + + f.WaitSubscribe(); + } + + [Test] + [Repeat(10)] + public void PrefixSubscribeKVTest() + { + var f = new FixedLenClientFunctions(); + using var session = client.GetSession(f); + using var subSession = client.GetSession(f); + + subSession.PSubscribeKV(10); + subSession.CompletePending(true); + session.Upsert(10, 23); + session.CompletePending(true); + + f.WaitSubscribe(); + } + + [Test] + [Repeat(10)] + public void SubscribeTest() + { + var f = new FixedLenClientFunctions(); + 
using var session = client.GetSession(f); + using var subSession = client.GetSession(f); + + subSession.Subscribe(10); + subSession.CompletePending(true); + session.Publish(10, 23); + session.CompletePending(true); + + f.WaitSubscribe(); + } + + [Test] + [Repeat(10)] + public void PrefixSubscribeTest() + { + var f = new FixedLenClientFunctions(); + using var session = client.GetSession(f); + using var subSession = client.GetSession(f); + + subSession.PSubscribe(10); + subSession.CompletePending(true); + session.Publish(10, 23); + session.CompletePending(true); + + f.WaitSubscribe(); + } + } +} \ No newline at end of file diff --git a/cs/remote/test/FASTER.remote.test/FixedLenBinaryTests.cs b/cs/remote/test/FASTER.remote.test/FixedLenBinaryTests.cs index c41ca00c6..cffa42f26 100644 --- a/cs/remote/test/FASTER.remote.test/FixedLenBinaryTests.cs +++ b/cs/remote/test/FASTER.remote.test/FixedLenBinaryTests.cs @@ -1,17 +1,20 @@ -using NUnit.Framework; +using FASTER.core; +using FASTER.server; +using NUnit.Framework; namespace FASTER.remote.test { [TestFixture] public class FixedLenBinaryTests { - FixedLenServer<long, long> server; + FixedLenServer<long, long, long, long, AdvancedSimpleFunctions<long, long, long>> server; FixedLenClient<long, long> client; [SetUp] public void Setup() { - server = new FixedLenServer<long, long>(TestContext.CurrentContext.TestDirectory + "/FixedLenBinaryTests", (a, b) => a + b); + server = TestUtils.CreateFixedLenServer(TestContext.CurrentContext.TestDirectory + "/FixedLenBinaryTests", (a, b) => a + b); + server.Start(); client = new FixedLenClient<long, long>(); } diff --git a/cs/remote/test/FASTER.remote.test/FixedLenClient.cs b/cs/remote/test/FASTER.remote.test/FixedLenClient.cs index 893c21e78..fd92802ba 100644 --- a/cs/remote/test/FASTER.remote.test/FixedLenClient.cs +++ b/cs/remote/test/FASTER.remote.test/FixedLenClient.cs @@ -2,6 +2,7 @@ using FASTER.common; using NUnit.Framework; using System; +using System.Threading; namespace 
FASTER.remote.test { @@ -24,6 +25,9 @@ public void Dispose() public ClientSession<long, long, long, long, long, FixedLenClientFunctions, FixedLenSerializer<long, long, long, long>> GetSession() => client.NewSession<long, long, long, FixedLenClientFunctions, FixedLenSerializer<long, long, long, long>>(new FixedLenClientFunctions(), WireFormat.DefaultFixedLenKV); + + public ClientSession<long, long, long, long, long, FixedLenClientFunctions, FixedLenSerializer<long, long, long, long>> GetSession(FixedLenClientFunctions f) + => client.NewSession<long, long, long, FixedLenClientFunctions, FixedLenSerializer<long, long, long, long>>(f, WireFormat.DefaultFixedLenKV); } /// <summary> @@ -31,10 +35,32 @@ public ClientSession<long, long, long, long, long, FixedLenClientFunctions, Fixe /// </summary> sealed class FixedLenClientFunctions : CallbackFunctionsBase<long, long, long, long, long> { + readonly ManualResetEvent evt = new ManualResetEvent(false); + public override void ReadCompletionCallback(ref long key, ref long input, ref long output, long ctx, Status status) { Assert.IsTrue(status == Status.OK); Assert.IsTrue(output == ctx); } + + /// <inheritdoc /> + public override void SubscribeKVCallback(ref long key, ref long input, ref long output, long ctx, Status status) + { + Assert.IsTrue(status == Status.OK); + Assert.IsTrue(output == 23); + evt.Set(); + } + + /// <inheritdoc /> + public override void SubscribeCallback(ref long key, ref long value, long ctx) + { + Assert.IsTrue(value == 23); + evt.Set(); + } + + public void WaitSubscribe() + { + evt.WaitOne(); + } } } diff --git a/cs/remote/test/FASTER.remote.test/FixedLenServer.cs b/cs/remote/test/FASTER.remote.test/FixedLenServer.cs deleted file mode 100644 index 71fac40df..000000000 --- a/cs/remote/test/FASTER.remote.test/FixedLenServer.cs +++ /dev/null @@ -1,70 +0,0 @@ -using FASTER.core; -using FASTER.server; -using FASTER.common; -using System; -using System.IO; - -namespace FASTER.remote.test -{ - class 
FixedLenServer<Key, Value> : IDisposable - where Key : unmanaged - where Value : unmanaged - { - readonly string folderName; - readonly FasterServer server; - readonly FasterKV<Key, Value> store; - - public FixedLenServer(string folderName, Func<Value, Value, Value> merger, string address = "127.0.0.1", int port = 33278) - { - this.folderName = folderName; - GetSettings(folderName, out var logSettings, out var checkpointSettings, out var indexSize); - - // We use blittable structs Key and Value to construct a costomized server for fixed-length types - store = new FasterKV<Key, Value>(indexSize, logSettings, checkpointSettings); - - // Create session provider for FixedLen - var provider = new FasterKVProvider<Key, Value, Value, Value, FixedLenServerFunctions<Key, Value>, FixedLenSerializer<Key, Value, Value, Value>>(store, e => new FixedLenServerFunctions<Key, Value>(merger)); - - server = new FasterServer(address, port); - server.Register(WireFormat.DefaultFixedLenKV, provider); - server.Start(); - } - - public void Dispose() - { - server.Dispose(); - store.Dispose(); - new DirectoryInfo(folderName).Delete(true); - } - - private static void GetSettings(string LogDir, out LogSettings logSettings, out CheckpointSettings checkpointSettings, out int indexSize) - { - logSettings = new LogSettings { PreallocateLog = false }; - - logSettings.PageSizeBits = 20; - logSettings.MemorySizeBits = 25; - logSettings.SegmentSizeBits = 30; - indexSize = 1 << 20; - - var device = LogDir == "" ? new NullDevice() : Devices.CreateLogDevice(LogDir + "/hlog", preallocateFile: false); - logSettings.LogDevice = device; - - string CheckpointDir = null; - if (CheckpointDir == null && LogDir == null) - checkpointSettings = null; - else - checkpointSettings = new CheckpointSettings - { - CheckPointType = CheckpointType.FoldOver, - CheckpointDir = CheckpointDir ?? 
(LogDir + "/checkpoints") - }; - } - } - - sealed class FixedLenServerFunctions<Key, Value> : SimpleFunctions<Key, Value, long> - { - public FixedLenServerFunctions(Func<Value, Value, Value> merger) : base(merger) - { - } - } -} diff --git a/cs/remote/test/FASTER.remote.test/TestUtils.cs b/cs/remote/test/FASTER.remote.test/TestUtils.cs new file mode 100644 index 000000000..b771c17dd --- /dev/null +++ b/cs/remote/test/FASTER.remote.test/TestUtils.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +using System; +using FASTER.core; +using FASTER.server; + +namespace FASTER.remote.test +{ + internal static class TestUtils + { + /// <summary> + /// Address + /// </summary> + public static string Address = "127.0.0.1"; + + /// <summary> + /// Port + /// </summary> + public static int Port = 33278; + + /// <summary> + /// Create VarLenServer + /// </summary> + public static FixedLenServer<long, long, long, long, AdvancedSimpleFunctions<long, long, long>> CreateFixedLenServer(string logDir, Func<long, long, long> merger, bool enablePubSub = true, bool tryRecover = false) + { + ServerOptions opts = new() + { + LogDir = logDir, + Address = Address, + Port = Port, + EnablePubSub = enablePubSub, + Recover = tryRecover, + IndexSize = "1m", + }; + return new FixedLenServer<long, long, long, long, AdvancedSimpleFunctions<long, long, long>>(opts, e => new AdvancedSimpleFunctions<long, long, long>(merger)); + } + + /// <summary> + /// Create VarLenServer + /// </summary> + public static VarLenServer CreateVarLenServer(string logDir, bool enablePubSub = true, bool tryRecover = false) + { + ServerOptions opts = new() + { + LogDir = logDir, + Address = Address, + Port = Port, + EnablePubSub = enablePubSub, + Recover = tryRecover, + IndexSize = "1m", + }; + return new VarLenServer(opts); + } + } +} \ No newline at end of file diff --git a/cs/remote/test/FASTER.remote.test/VarLenBinaryPubSubTests.cs 
b/cs/remote/test/FASTER.remote.test/VarLenBinaryPubSubTests.cs new file mode 100644 index 000000000..db49151f4 --- /dev/null +++ b/cs/remote/test/FASTER.remote.test/VarLenBinaryPubSubTests.cs @@ -0,0 +1,127 @@ +using System; +using NUnit.Framework; +using FASTER.server; + +namespace FASTER.remote.test +{ + [TestFixture] + public class VarLenBinaryPubSubTests + { + VarLenServer server; + VarLenMemoryClient client; + + [SetUp] + public void Setup() + { + server = TestUtils.CreateVarLenServer(TestContext.CurrentContext.TestDirectory + "/VarLenBinaryTests", enablePubSub: true); + server.Start(); + client = new VarLenMemoryClient(); + } + + [TearDown] + public void TearDown() + { + client.Dispose(); + server.Dispose(); + } + + [Test] + [Repeat(10)] + public void SubscribeKVTest() + { + Random r = new Random(23); + + var f = new MemoryFunctions(); + using var session = client.GetSession(f); + using var subSession = client.GetSession(f); + var key = new Memory<int>(new int[2 + r.Next(50)]); + var value = new Memory<int>(new int[1 + r.Next(50)]); + key.Span[0] = r.Next(100); + key.Span[1] = value.Length; + value.Span.Fill(key.Span[0]); + + subSession.SubscribeKV(key); + subSession.CompletePending(true); + session.Upsert(key, value); + session.CompletePending(true); + + f.WaitSubscribe(); + } + + [Test] + [Repeat(10)] + public void PSubscribeKVTest() + { + Random r = new Random(23); + var f = new MemoryFunctions(); + using var session = client.GetSession(f); + using var subSession = client.GetSession(f); + var key = new Memory<int>(new int[2 + r.Next(50)]); + var value = new Memory<int>(new int[1 + r.Next(50)]); + int randomNum = r.Next(100); + key.Span[0] = randomNum; + key.Span[1] = value.Length; + value.Span.Fill(key.Span[0]); + + var upsertKey = new Memory<int>(new int[100]); + upsertKey.Span[0] = randomNum; + upsertKey.Span[1] = value.Length; + + subSession.PSubscribeKV(key); + subSession.CompletePending(true); + session.Upsert(upsertKey, value); + 
session.CompletePending(true); + + f.WaitSubscribe(); + } + + [Test] + [Repeat(10)] + public void SubscribeTest() + { + Random r = new Random(23); + var f = new MemoryFunctions(); + using var session = client.GetSession(f); + using var subSession = client.GetSession(f); + var key = new Memory<int>(new int[2 + r.Next(50)]); + var value = new Memory<int>(new int[1 + r.Next(50)]); + key.Span[0] = r.Next(100); + key.Span[1] = value.Length; + value.Span.Fill(key.Span[0]); + + subSession.Subscribe(key); + subSession.CompletePending(true); + session.Publish(key, value); + session.CompletePending(true); + + f.WaitSubscribe(); + } + + [Test] + [Repeat(10)] + public void PSubscribeTest() + { + Random r = new Random(23); + var f = new MemoryFunctions(); + using var session = client.GetSession(f); + using var subSession = client.GetSession(f); + var key = new Memory<int>(new int[2 + r.Next(50)]); + var value = new Memory<int>(new int[1 + r.Next(50)]); + int randomNum = r.Next(100); + key.Span[0] = randomNum; + key.Span[1] = value.Length; + value.Span.Fill(key.Span[0]); + + var upsertKey = new Memory<int>(new int[100]); + upsertKey.Span[0] = randomNum; + upsertKey.Span[1] = value.Length; + + subSession.PSubscribe(key); + subSession.CompletePending(true); + session.Publish(upsertKey, value); + session.CompletePending(true); + + f.WaitSubscribe(); + } + } +} \ No newline at end of file diff --git a/cs/remote/test/FASTER.remote.test/VarLenBinaryTests.cs b/cs/remote/test/FASTER.remote.test/VarLenBinaryTests.cs index 54632ec51..bcad9f689 100644 --- a/cs/remote/test/FASTER.remote.test/VarLenBinaryTests.cs +++ b/cs/remote/test/FASTER.remote.test/VarLenBinaryTests.cs @@ -1,4 +1,5 @@ using System; +using FASTER.server; using NUnit.Framework; namespace FASTER.remote.test @@ -12,7 +13,8 @@ public class VarLenBinaryTests [SetUp] public void Setup() { - server = new VarLenServer(TestContext.CurrentContext.TestDirectory + "/VarLenBinaryTests"); + server = 
TestUtils.CreateVarLenServer(TestContext.CurrentContext.TestDirectory + "/VarLenBinaryTests"); + server.Start(); client = new VarLenMemoryClient(); } diff --git a/cs/remote/test/FASTER.remote.test/VarLenMemoryClient.cs b/cs/remote/test/FASTER.remote.test/VarLenMemoryClient.cs index 2b161d1cf..33a2c3ff8 100644 --- a/cs/remote/test/FASTER.remote.test/VarLenMemoryClient.cs +++ b/cs/remote/test/FASTER.remote.test/VarLenMemoryClient.cs @@ -1,5 +1,6 @@ using System; using System.Buffers; +using System.Threading; using FASTER.client; using FASTER.common; using NUnit.Framework; @@ -22,6 +23,9 @@ public void Dispose() public ClientSession<ReadOnlyMemory<int>, ReadOnlyMemory<int>, ReadOnlyMemory<int>, (IMemoryOwner<int>, int), long, MemoryFunctions, MemoryParameterSerializer<int>> GetSession() => client.NewSession<ReadOnlyMemory<int>, (IMemoryOwner<int>, int), long, MemoryFunctions, MemoryParameterSerializer<int>>(new MemoryFunctions(), WireFormat.DefaultVarLenKV, new MemoryParameterSerializer<int>()); + + public ClientSession<ReadOnlyMemory<int>, ReadOnlyMemory<int>, ReadOnlyMemory<int>, (IMemoryOwner<int>, int), long, MemoryFunctions, MemoryParameterSerializer<int>> GetSession(MemoryFunctions f) + => client.NewSession<ReadOnlyMemory<int>, (IMemoryOwner<int>, int), long, MemoryFunctions, MemoryParameterSerializer<int>>(f, WireFormat.DefaultVarLenKV, new MemoryParameterSerializer<int>()); } /// <summary> @@ -29,6 +33,8 @@ public void Dispose() /// </summary> public class MemoryFunctions : ICallbackFunctions<ReadOnlyMemory<int>, ReadOnlyMemory<int>, ReadOnlyMemory<int>, (IMemoryOwner<int>, int), long> { + readonly ManualResetEvent evt = new ManualResetEvent(false); + /// <inheritdoc /> public virtual void DeleteCompletionCallback(ref ReadOnlyMemory<int> key, long ctx) { } @@ -57,5 +63,50 @@ public virtual void RMWCompletionCallback(ref ReadOnlyMemory<int> key, ref ReadO /// <inheritdoc /> public virtual void UpsertCompletionCallback(ref ReadOnlyMemory<int> key, ref 
ReadOnlyMemory<int> value, long ctx) { } + + /// <inheritdoc /> + public virtual void SubscribeKVCallback(ref ReadOnlyMemory<int> key, ref ReadOnlyMemory<int> input, ref (IMemoryOwner<int>, int) output, long ctx, Status status) + { + try + { + Assert.IsTrue(status == Status.OK); + int check = key.Span[0]; + int len = key.Span[1]; + Assert.IsTrue(output.Item2 == len); + Memory<int> expected = new Memory<int>(new int[len]); + expected.Span.Fill(check); + Assert.IsTrue(expected.Span.SequenceEqual(output.Item1.Memory.Span.Slice(0, output.Item2))); + } + finally + { + evt.Set(); + output.Item1.Dispose(); + } + } + + /// <inheritdoc /> + public void PublishCompletionCallback(ref ReadOnlyMemory<int> key, ref ReadOnlyMemory<int> value, long ctx) { } + + /// <inheritdoc /> + public void SubscribeCallback(ref ReadOnlyMemory<int> key, ref ReadOnlyMemory<int> value, long ctx) + { + try + { + int check = key.Span[0]; + int len = key.Span[1]; + Memory<int> expected = new Memory<int>(new int[len]); + expected.Span.Fill(check); + Assert.IsTrue(expected.Span.SequenceEqual(value.Span.Slice(0, value.Length))); + } + finally + { + evt.Set(); + } + } + + public void WaitSubscribe() + { + evt.WaitOne(); + } } } diff --git a/cs/remote/test/FASTER.remote.test/VarLenServer.cs b/cs/remote/test/FASTER.remote.test/VarLenServer.cs deleted file mode 100644 index bf25a9416..000000000 --- a/cs/remote/test/FASTER.remote.test/VarLenServer.cs +++ /dev/null @@ -1,62 +0,0 @@ -using System; -using System.IO; -using FASTER.core; -using FASTER.server; -using FASTER.common; - -namespace FASTER.remote.test -{ - class VarLenServer : IDisposable - { - readonly string folderName; - readonly FasterServer server; - readonly FasterKV<SpanByte, SpanByte> store; - readonly SpanByteFasterKVProvider provider; - - public VarLenServer(string folderName, string address = "127.0.0.1", int port = 33278) - { - this.folderName = folderName; - GetSettings(folderName, out var logSettings, out var checkpointSettings, out var 
indexSize); - - store = new FasterKV<SpanByte, SpanByte>(indexSize, logSettings, checkpointSettings); - - // Create session provider for VarLen - provider = new SpanByteFasterKVProvider(store); - - server = new FasterServer(address, port); - server.Register(WireFormat.DefaultVarLenKV, provider); - server.Start(); - } - - public void Dispose() - { - server.Dispose(); - provider.Dispose(); - store.Dispose(); - new DirectoryInfo(folderName).Delete(true); - } - - private static void GetSettings(string LogDir, out LogSettings logSettings, out CheckpointSettings checkpointSettings, out int indexSize) - { - logSettings = new LogSettings { PreallocateLog = false }; - - logSettings.PageSizeBits = 20; - logSettings.MemorySizeBits = 25; - logSettings.SegmentSizeBits = 30; - indexSize = 1 << 20; - - var device = LogDir == "" ? new NullDevice() : Devices.CreateLogDevice(LogDir + "/hlog", preallocateFile: false); - logSettings.LogDevice = device; - - string CheckpointDir = null; - if (CheckpointDir == null && LogDir == null) - checkpointSettings = null; - else - checkpointSettings = new CheckpointSettings - { - CheckPointType = CheckpointType.FoldOver, - CheckpointDir = CheckpointDir ?? 
(LogDir + "/checkpoints") - }; - } - } -} diff --git a/cs/src/core/Allocator/AllocatorBase.cs b/cs/src/core/Allocator/AllocatorBase.cs index 105178234..37f745f50 100644 --- a/cs/src/core/Allocator/AllocatorBase.cs +++ b/cs/src/core/Allocator/AllocatorBase.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; +using System.IO; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading; @@ -398,6 +399,12 @@ public abstract (int, int) GetInitialRecordSize<Input, FasterSession>(ref Key ke /// <param name="localSegmentOffsets"></param> protected abstract void WriteAsyncToDevice<TContext>(long startPage, long flushPage, int pageSize, DeviceIOCompletionCallback callback, PageAsyncFlushResult<TContext> result, IDevice device, IDevice objectLogDevice, long[] localSegmentOffsets); + private protected void VerifyCompatibleSectorSize(IDevice device) + { + if (this.sectorSize % device.SectorSize != 0) + throw new FasterException($"Allocator with sector size {sectorSize} cannot flush to device with sector size {device.SectorSize}"); + } + /// <summary> /// Delta flush /// </summary> @@ -466,7 +473,7 @@ internal unsafe virtual void AsyncFlushDeltaToDevice(long startAddress, long end deltaLog.Seal(destOffset); } - internal unsafe void ApplyDelta(DeltaLog log, long startPage, long endPage) + internal unsafe void ApplyDelta(DeltaLog log, long startPage, long endPage, long recoverTo) { if (log == null) return; @@ -474,22 +481,49 @@ internal unsafe void ApplyDelta(DeltaLog log, long startPage, long endPage) long endLogicalAddress = GetStartLogicalAddress(endPage); log.Reset(); - while (log.GetNext(out long physicalAddress, out int entryLength, out int type)) + while (log.GetNext(out long physicalAddress, out int entryLength, out var type)) { - if (type != 0) continue; // consider only delta records - long endAddress = physicalAddress + entryLength; - while (physicalAddress < endAddress) + switch (type) { - long address = 
*(long*)physicalAddress; - physicalAddress += sizeof(long); - int size = *(int*)physicalAddress; - physicalAddress += sizeof(int); - if (address >= startLogicalAddress && address < endLogicalAddress) - { - var destination = GetPhysicalAddress(address); - Buffer.MemoryCopy((void*)physicalAddress, (void*)destination, size, size); - } - physicalAddress += size; + case DeltaLogEntryType.DELTA: + // Delta records + long endAddress = physicalAddress + entryLength; + while (physicalAddress < endAddress) + { + long address = *(long*)physicalAddress; + physicalAddress += sizeof(long); + int size = *(int*)physicalAddress; + physicalAddress += sizeof(int); + if (address >= startLogicalAddress && address < endLogicalAddress) + { + var destination = GetPhysicalAddress(address); + Buffer.MemoryCopy((void*)physicalAddress, (void*)destination, size, size); + } + physicalAddress += size; + } + break; + case DeltaLogEntryType.CHECKPOINT_METADATA: + if (recoverTo != -1) + { + // Only read metadata if we need to stop at a specific version + var metadata = new byte[entryLength]; + unsafe + { + fixed (byte* m = metadata) + Buffer.MemoryCopy((void*) physicalAddress, m, entryLength, entryLength); + } + + HybridLogRecoveryInfo recoveryInfo = new(); + using StreamReader s = new(new MemoryStream(metadata)); + recoveryInfo.Initialize(s); + // Finish recovery if only specific versions are requested + if (recoveryInfo.version == recoverTo) return; + } + + break; + default: + throw new FasterException("Unexpected entry type"); + } } } @@ -709,7 +743,7 @@ public AllocatorBase(LogSettings settings, IFasterEqualityComparer<Key> comparer SegmentBufferSize = 1 + (LogTotalSizeBytes / SegmentSize < 1 ? 
1 : (int)(LogTotalSizeBytes / SegmentSize)); if (SegmentSize < PageSize) - throw new FasterException("Segment must be at least of page size"); + throw new FasterException($"Segment ({SegmentSize.ToString()}) must be at least of page size ({PageSize.ToString()})"); PageStatusIndicator = new FullPageStatus[BufferSize]; @@ -739,7 +773,7 @@ public AllocatorBase(LogSettings settings, IFasterEqualityComparer<Key> comparer /// <param name="firstValidAddress"></param> protected void Initialize(long firstValidAddress) { - Debug.Assert(firstValidAddress <= PageSize); + Debug.Assert(firstValidAddress <= PageSize, $"firstValidAddress {firstValidAddress} shoulld be <= PageSize {PageSize}"); bufferPool = new SectorAlignedBufferPool(1, sectorSize); @@ -1855,9 +1889,10 @@ private unsafe void AsyncGetFromDiskCallback(uint errorCode, uint numBytes, obje } else { - // Keys are not same. I/O is not complete + // Keys are not same. I/O is not complete. Follow the chain to the previous record and issue a request for it if + // it is in the range to resolve, else surface "not found". 
ctx.logicalAddress = GetInfoFromBytePointer(record).PreviousAddress; - if (ctx.logicalAddress >= BeginAddress) + if (ctx.logicalAddress >= BeginAddress && ctx.logicalAddress >= ctx.minAddress) { ctx.record.Return(); ctx.record = ctx.objBuffer = default; @@ -1910,7 +1945,7 @@ private void AsyncFlushPageCallback(uint errorCode, uint numBytes, object contex { if (errorCode != 0) { - errorList.Add(result.fromAddress); + errorList.Add(result.fromAddress, errorCode); } Utility.MonotonicUpdate(ref PageStatusIndicator[result.page % BufferSize].LastFlushedUntilAddress, result.untilAddress, out _); ShiftFlushedUntilAddress(); diff --git a/cs/src/core/Allocator/AsyncIOContext.cs b/cs/src/core/Allocator/AsyncIOContext.cs index 0342815e6..cc3c90509 100644 --- a/cs/src/core/Allocator/AsyncIOContext.cs +++ b/cs/src/core/Allocator/AsyncIOContext.cs @@ -37,6 +37,11 @@ public unsafe struct AsyncIOContext<Key, Value> /// </summary> public long logicalAddress; + /// <summary> + /// Minimum Logical address to resolve Key in + /// </summary> + public long minAddress; + /// <summary> /// Record buffer /// </summary> diff --git a/cs/src/core/Allocator/BlittableAllocator.cs b/cs/src/core/Allocator/BlittableAllocator.cs index 1bad1c9db..1189d3cc5 100644 --- a/cs/src/core/Allocator/BlittableAllocator.cs +++ b/cs/src/core/Allocator/BlittableAllocator.cs @@ -141,7 +141,7 @@ internal override void AllocatePage(int index) } var adjustedSize = PageSize + 2 * sectorSize; - byte[] tmp = new byte[adjustedSize]; + var tmp = new byte[adjustedSize]; Array.Clear(tmp, 0, adjustedSize); handles[index] = GCHandle.Alloc(tmp, GCHandleType.Pinned); @@ -181,6 +181,7 @@ protected override void WriteAsyncToDevice<TContext> (long startPage, long flushPage, int pageSize, DeviceIOCompletionCallback callback, PageAsyncFlushResult<TContext> asyncResult, IDevice device, IDevice objectLogDevice, long[] localSegmentOffsets) { + base.VerifyCompatibleSectorSize(device); var alignedPageSize = (pageSize + (sectorSize - 1)) 
& ~(sectorSize - 1); WriteAsync((IntPtr)pointers[flushPage % BufferSize], diff --git a/cs/src/core/Allocator/ErrorList.cs b/cs/src/core/Allocator/ErrorList.cs index 59d8b48ae..a776e9af3 100644 --- a/cs/src/core/Allocator/ErrorList.cs +++ b/cs/src/core/Allocator/ErrorList.cs @@ -8,14 +8,14 @@ namespace FASTER.core { class ErrorList { - private readonly List<long> errorList; + private readonly List<(long address, uint errorCode)> errorList; - public ErrorList() => errorList = new List<long>(); + public ErrorList() => errorList = new(); - public void Add(long address) + public void Add(long address, uint errorCode) { lock (errorList) - errorList.Add(address); + errorList.Add((address, errorCode)); } public uint CheckAndWait(long oldFlushedUntilAddress, long currentFlushedUntilAddress) @@ -29,11 +29,11 @@ public uint CheckAndWait(long oldFlushedUntilAddress, long currentFlushedUntilAd { for (int i = 0; i < errorList.Count; i++) { - if (errorList[i] >= oldFlushedUntilAddress && errorList[i] < currentFlushedUntilAddress) + if (errorList[i].address >= oldFlushedUntilAddress && errorList[i].address < currentFlushedUntilAddress) { - errorCode = 1; + errorCode = errorList[i].errorCode; } - else if (errorList[i] < oldFlushedUntilAddress) + else if (errorList[i].address < oldFlushedUntilAddress) { done = false; // spin barrier for other threads during exception Thread.Yield(); @@ -50,7 +50,7 @@ public void RemoveUntil(long currentFlushedUntilAddress) { for (int i = 0; i < errorList.Count; i++) { - if (errorList[i] < currentFlushedUntilAddress) + if (errorList[i].address < currentFlushedUntilAddress) { errorList.RemoveAt(i); } diff --git a/cs/src/core/Allocator/GenericAllocator.cs b/cs/src/core/Allocator/GenericAllocator.cs index ffa13e010..c00d2811d 100644 --- a/cs/src/core/Allocator/GenericAllocator.cs +++ b/cs/src/core/Allocator/GenericAllocator.cs @@ -21,7 +21,6 @@ public struct Record<Key, Value> public Value value; } - public unsafe sealed class GenericAllocator<Key, 
Value> : AllocatorBase<Key, Value> { // Circular buffer definition @@ -29,8 +28,8 @@ public unsafe sealed class GenericAllocator<Key, Value> : AllocatorBase<Key, Val // Object log related variables private readonly IDevice objectLogDevice; - // Size of object chunks beign written to storage - private const int ObjectBlockSize = 100 * (1 << 20); + // Size of object chunks being written to storage + private readonly int ObjectBlockSize = 100 * (1 << 20); // Tail offsets per segment, in object log public readonly long[] segmentOffsets; // Record sizes @@ -274,6 +273,9 @@ protected override void WriteAsyncToDevice<TContext> (long startPage, long flushPage, int pageSize, DeviceIOCompletionCallback callback, PageAsyncFlushResult<TContext> asyncResult, IDevice device, IDevice objectLogDevice, long[] localSegmentOffsets) { + base.VerifyCompatibleSectorSize(device); + base.VerifyCompatibleSectorSize(objectLogDevice); + bool epochTaken = false; if (!epoch.ThisInstanceProtected()) { @@ -303,8 +305,6 @@ protected override void WriteAsyncToDevice<TContext> } } - - internal override void ClearPage(long page, int offset) { Array.Clear(values[page % BufferSize], offset / recordSize, values[page % BufferSize].Length - offset / recordSize); @@ -741,7 +741,7 @@ internal void AsyncReadPagesFromDeviceToFrame<TContext>( } - #region Page handlers for objects +#region Page handlers for objects /// <summary> /// Deseialize part of page from stream /// </summary> @@ -981,7 +981,7 @@ public override bool ValueHasObjects() { return SerializerSettings.valueSerializer != null; } - #endregion +#endregion public override IHeapContainer<Key> GetKeyContainer(ref Key key) => new StandardHeapContainer<Key>(ref key); public override IHeapContainer<Value> GetValueContainer(ref Value value) => new StandardHeapContainer<Value>(ref value); diff --git a/cs/src/core/Allocator/VarLenBlittableAllocator.cs b/cs/src/core/Allocator/VarLenBlittableAllocator.cs index f33df6430..9c8db19f4 100644 --- 
a/cs/src/core/Allocator/VarLenBlittableAllocator.cs +++ b/cs/src/core/Allocator/VarLenBlittableAllocator.cs @@ -286,6 +286,7 @@ protected override void WriteAsyncToDevice<TContext> (long startPage, long flushPage, int pageSize, DeviceIOCompletionCallback callback, PageAsyncFlushResult<TContext> asyncResult, IDevice device, IDevice objectLogDevice, long[] localSegmentOffsets) { + base.VerifyCompatibleSectorSize(device); var alignedPageSize = (pageSize + (sectorSize - 1)) & ~(sectorSize - 1); WriteAsync((IntPtr)pointers[flushPage % BufferSize], @@ -304,7 +305,6 @@ public override long GetStartLogicalAddress(long page) return page << LogPageSizeBits; } - /// <summary> /// Get first valid logical address /// </summary> diff --git a/cs/src/core/Async/ReadAsync.cs b/cs/src/core/Async/ReadAsync.cs index 9df14986e..a6b472c82 100644 --- a/cs/src/core/Async/ReadAsync.cs +++ b/cs/src/core/Async/ReadAsync.cs @@ -156,7 +156,7 @@ internal ValueTask<ReadAsyncResult<Input, Output, Context>> ReadAsync<Input, Out ref Key key, ref Input input, long startAddress, Context context, long serialNo, CancellationToken token, byte operationFlags = 0) { var pcontext = default(PendingContext<Input, Output, Context>); - pcontext.operationFlags = operationFlags; + pcontext.SetOperationFlags(operationFlags, startAddress); var diskRequest = default(AsyncIOContext<Key, Value>); Output output = default; diff --git a/cs/src/core/ClientSession/AdvancedClientSession.cs b/cs/src/core/ClientSession/AdvancedClientSession.cs index e4862605d..b770aae93 100644 --- a/cs/src/core/ClientSession/AdvancedClientSession.cs +++ b/cs/src/core/ClientSession/AdvancedClientSession.cs @@ -339,7 +339,6 @@ public ValueTask<FasterKV<Key, Value>.ReadAsyncResult<Input, Output, Context>> R return fht.ReadAsync(this.FasterSession, this.ctx, ref key, ref input, Constants.kInvalidAddress, context, serialNo, token); } - /// <summary> /// Async read operation. 
May return uncommitted results; to ensure reading of committed results, complete the read and then call WaitForCommitAsync. /// </summary> @@ -703,7 +702,7 @@ public ValueTask<FasterKV<Key, Value>.DeleteAsyncResult<Input, Output, Context>> /// Experimental feature /// Checks whether specified record is present in memory /// (between HeadAddress and tail, or between fromAddress - /// and tail) + /// and tail), including tombstones. /// </summary> /// <param name="key">Key of the record.</param> /// <param name="logicalAddress">Logical address of record, if found</param> diff --git a/cs/src/core/ClientSession/ClientSession.cs b/cs/src/core/ClientSession/ClientSession.cs index aaacabe08..965b00e94 100644 --- a/cs/src/core/ClientSession/ClientSession.cs +++ b/cs/src/core/ClientSession/ClientSession.cs @@ -140,6 +140,11 @@ private void UpdateVarlen(ref IVariableLengthStruct<Value, Input> variableLength /// </summary> public long SerialNo => ctx.serialNum; + /// <summary> + /// Current version number of the session + /// </summary> + public long Version => ctx.version; + /// <summary> /// Dispose session /// </summary> @@ -239,14 +244,35 @@ public Status Read(Key key, out Output output, Context userContext = default, lo return (Read(ref key, ref input, ref output, userContext, serialNo), output); } -#if DEBUG - internal const string AdvancedOnlyMethodErr = "This method is not available on non-Advanced ClientSessions"; - - /// <summary>This method is not available for non-Advanced ClientSessions, because ReadCompletionCallback does not have RecordInfo.</summary> - [Obsolete(AdvancedOnlyMethodErr)] - public Status Read(ref Key key, ref Input input, ref Output output, ref RecordInfo recordInfo, ReadFlags readFlags = ReadFlags.None, Context userContext = default, long serialNo = 0) - => throw new FasterException(AdvancedOnlyMethodErr); -#endif // DEBUG; + /// <summary> + /// Read operation that accepts a <paramref name="recordInfo"/> ref argument to start the lookup at 
instead of starting at the hash table entry for <paramref name="key"/>, + /// and is updated with the record header for the found record (which contains previous address in the hash chain for this key; this can + /// be used as <paramref name="recordInfo"/> in a subsequent call to iterate all records for <paramref name="key"/>). + /// </summary> + /// <param name="key">The key to look up</param> + /// <param name="input">Input to help extract the retrieved value into <paramref name="output"/></param> + /// <param name="output">The location to place the retrieved value</param> + /// <param name="recordInfo">On input contains the address to start at in its <see cref="RecordInfo.PreviousAddress"/>; if this is Constants.kInvalidAddress, the + /// search starts with the key as in other forms of Read. On output, receives a copy of the record's header, which can be passed + /// in a subsequent call, thereby enumerating all records in a key's hash chain.</param> + /// <param name="readFlags">Flags for controlling operations within the read, such as ReadCache interaction</param> + /// <param name="userContext">User application context passed in case the read goes pending due to IO</param> + /// <param name="serialNo">The serial number of the operation (used in recovery)</param> + /// <returns><paramref name="output"/> is populated by the <see cref="IFunctions{Key, Value, Context}"/> implementation</returns> + /// <remarks>This method on non-Advanced ClientSessions is not suitable for read loops, because ReadCompletionCallback does not have RecordInfo.</remarks> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Status Read(ref Key key, ref Input input, ref Output output, ref RecordInfo recordInfo, ReadFlags readFlags = ReadFlags.None, Context userContext = default, long serialNo = 0) + { + if (SupportAsync) UnsafeResumeThread(); + try + { + return fht.ContextRead(ref key, ref input, ref output, ref recordInfo, readFlags, userContext, FasterSession, serialNo, 
ctx); + } + finally + { + if (SupportAsync) UnsafeSuspendThread(); + } + } /// <summary> /// Read operation that accepts an <paramref name="address"/> argument to lookup at, instead of a key. @@ -366,13 +392,36 @@ public ValueTask<FasterKV<Key, Value>.ReadAsyncResult<Input, Output, Context>> R return fht.ReadAsync(this.FasterSession, this.ctx, ref key, ref input, Constants.kInvalidAddress, context, serialNo, token); } -#if DEBUG - /// <summary>For consistency with Read(.., ref RecordInfo, ...), this method is not available for non-Advanced ClientSessions.</summary> - [Obsolete(AdvancedOnlyMethodErr)] + /// <summary> + /// Async read operation that accepts a <paramref name="startAddress"/> to start the lookup at instead of starting at the hash table entry for <paramref name="key"/>, + /// and returns the <see cref="RecordInfo"/> for the found record (which contains previous address in the hash chain for this key; this can + /// be used as <paramref name="startAddress"/> in a subsequent call to iterate all records for <paramref name="key"/>). 
+ /// </summary> + /// <param name="key">The key to look up</param> + /// <param name="input">Input to help extract the retrieved value into output</param> + /// <param name="startAddress">Start at this address rather than the address in the hash table for <paramref name="key"/>"/></param> + /// <param name="readFlags">Flags for controlling operations within the read, such as ReadCache interaction</param> + /// <param name="userContext">User application context passed in case the read goes pending due to IO</param> + /// <param name="serialNo">The serial number of the operation (used in recovery)</param> + /// <param name="cancellationToken">Token to cancel the operation</param> + /// <returns><see cref="ValueTask"/> wrapping <see cref="FasterKV{Key, Value}.ReadAsyncResult{Input, Output, Context}"/></returns> + /// <remarks>The caller must await the return value to obtain the result, then call one of + /// <list type="bullet"> + /// <item>result.<see cref="FasterKV{Key, Value}.ReadAsyncResult{Input, Output, Context}.Complete()"/></item> + /// <item>result.<see cref="FasterKV{Key, Value}.ReadAsyncResult{Input, Output, Context}.Complete(out RecordInfo)"/></item> + /// </list> + /// to complete the read operation and obtain the result status, the output that is populated by the + /// <see cref="IFunctions{Key, Value, Context}"/> implementation, and optionally a copy of the header for the retrieved record + /// <para>This method on non-Advanced ClientSessions is not suitable for read loops, because ReadCompletionCallback does not have RecordInfo.</para> + /// </remarks> + [MethodImpl(MethodImplOptions.AggressiveInlining)] public ValueTask<FasterKV<Key, Value>.ReadAsyncResult<Input, Output, Context>> ReadAsync(ref Key key, ref Input input, long startAddress, ReadFlags readFlags = ReadFlags.None, Context userContext = default, long serialNo = 0, CancellationToken cancellationToken = default) - => throw new FasterException(AdvancedOnlyMethodErr); -#endif + { + 
Debug.Assert(SupportAsync, NotAsyncSessionErr); + var operationFlags = FasterKV<Key, Value>.PendingContext<Input, Output, Context>.GetOperationFlags(readFlags); + return fht.ReadAsync(this.FasterSession, this.ctx, ref key, ref input, startAddress, userContext, serialNo, cancellationToken, operationFlags); + } /// <summary> /// Async Read operation that accepts an <paramref name="address"/> argument to lookup at, instead of a key. @@ -663,7 +712,7 @@ public ValueTask<FasterKV<Key, Value>.DeleteAsyncResult<Input, Output, Context>> /// Experimental feature /// Checks whether specified record is present in memory /// (between HeadAddress and tail, or between fromAddress - /// and tail) + /// and tail), including tombstones. /// </summary> /// <param name="key">Key of the record.</param> /// <param name="logicalAddress">Logical address of record, if found</param> @@ -890,6 +939,28 @@ public long Compact<CompactionFunctions>(long untilAddress, bool shiftBeginAddre return fht.Log.Compact<Input, Output, Context, Functions, CompactionFunctions>(functions, compactionFunctions, untilAddress, shiftBeginAddress); } + /// <summary> + /// Insert key and value with the record info preserved. + /// Succeed only if logical address of the key isn't greater than foundLogicalAddress; otherwise give up and return. 
+ /// </summary> + /// <param name="key"></param> + /// <param name="desiredValue"></param> + /// <param name="recordInfo"></param> + /// <param name="foundLogicalAddress"></param> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void CopyToTail(ref Key key, ref Value desiredValue, ref RecordInfo recordInfo, long foundLogicalAddress) + { + if (SupportAsync) UnsafeResumeThread(); + try + { + fht.InternalCopyToTail(ref key, ref desiredValue, ref recordInfo, foundLogicalAddress, FasterSession, ctx, noReadCache: true); + } + finally + { + if (SupportAsync) UnsafeSuspendThread(); + } + } + /// <summary> /// Iterator for all (distinct) live key-values stored in FASTER /// </summary> diff --git a/cs/src/core/Device/LocalMemoryDevice.cs b/cs/src/core/Device/LocalMemoryDevice.cs index 66c2ca0ef..fcb440705 100644 --- a/cs/src/core/Device/LocalMemoryDevice.cs +++ b/cs/src/core/Device/LocalMemoryDevice.cs @@ -48,7 +48,7 @@ public LocalMemoryDevice(long capacity, long sz_segment, int parallelism, int la : base(fileName, sector_size, capacity) { if (capacity == Devices.CAPACITY_UNSPECIFIED) throw new Exception("Local memory device must have a capacity!"); - Console.WriteLine("LocalMemoryDevice: Creating a " + capacity + " sized local memory device."); + Debug.WriteLine("LocalMemoryDevice: Creating a " + capacity + " sized local memory device."); num_segments = (int)(capacity / sz_segment); this.sz_segment = sz_segment; this.latencyTicks = latencyMs * TimeSpan.TicksPerMillisecond; @@ -75,7 +75,7 @@ public LocalMemoryDevice(long capacity, long sz_segment, int parallelism, int la ioProcessors[i].Start(); } - Console.WriteLine("LocalMemoryDevice: " + ram_segments.Length + " pinned in-memory segments created, each with " + sz_segment + " bytes"); + Debug.WriteLine("LocalMemoryDevice: " + ram_segments.Length + " pinned in-memory segments created, each with " + sz_segment + " bytes"); } private void ProcessIOQueue(ConcurrentQueue<IORequestLocalMemory> q) @@ -150,7 +150,7 
@@ public override void WriteAsync(IntPtr sourceAddress, var req = new IORequestLocalMemory { srcAddress = (void*)sourceAddress, - dstAddress = ram_segments[segmentId] + destinationAddress, + dstAddress = ram_segments[segmentId % parallelism] + destinationAddress, bytes = numBytesToWrite, callback = callback, context = context diff --git a/cs/src/core/Index/CheckpointManagement/DeviceLogCommitCheckpointManager.cs b/cs/src/core/Index/CheckpointManagement/DeviceLogCommitCheckpointManager.cs index fd77e59dc..cd06ce56c 100644 --- a/cs/src/core/Index/CheckpointManagement/DeviceLogCommitCheckpointManager.cs +++ b/cs/src/core/Index/CheckpointManagement/DeviceLogCommitCheckpointManager.cs @@ -15,6 +15,9 @@ namespace FASTER.core /// </summary> public class DeviceLogCommitCheckpointManager : ILogCommitManager, ICheckpointManager { + const int indexTokenCount = 2; + const int logTokenCount = 1; + private readonly INamedDeviceFactory deviceFactory; private readonly ICheckpointNamingScheme checkpointNamingScheme; private readonly SemaphoreSlim semaphore; @@ -60,7 +63,7 @@ public DeviceLogCommitCheckpointManager(INamedDeviceFactory deviceFactory, IChec // later log checkpoint to work with indexTokenHistory = new Guid[2]; // We only keep the latest log checkpoint - logTokenHistory = new Guid[1]; + logTokenHistory = new Guid[2]; } this._disposed = false; @@ -73,6 +76,15 @@ public void PurgeAll() deviceFactory.Delete(new FileDescriptor { directoryName = "" }); } + /// <inheritdoc /> + public void Purge(Guid token) + { + // Try both because we do not know which type the guid denotes + deviceFactory.Delete(checkpointNamingScheme.LogCheckpointBase(token)); + deviceFactory.Delete(checkpointNamingScheme.IndexCheckpointBase(token)); + + } + /// <summary> /// Create new instance of log commit manager /// </summary> @@ -206,7 +218,7 @@ public unsafe void CommitIndexCheckpoint(Guid indexToken, byte[] commitMetadata) { var prior = indexTokenHistory[indexTokenHistoryOffset]; 
indexTokenHistory[indexTokenHistoryOffset] = indexToken; - indexTokenHistoryOffset = (indexTokenHistoryOffset + 1) % indexTokenHistory.Length; + indexTokenHistoryOffset = (indexTokenHistoryOffset + 1) % indexTokenCount; if (prior != default) deviceFactory.Delete(checkpointNamingScheme.IndexCheckpointBase(prior)); } @@ -253,7 +265,7 @@ public unsafe void CommitLogCheckpoint(Guid logToken, byte[] commitMetadata) { var prior = logTokenHistory[logTokenHistoryOffset]; logTokenHistory[logTokenHistoryOffset] = logToken; - logTokenHistoryOffset = (logTokenHistoryOffset + 1) % logTokenHistory.Length; + logTokenHistoryOffset = (logTokenHistoryOffset + 1) % logTokenCount; if (prior != default) deviceFactory.Delete(checkpointNamingScheme.LogCheckpointBase(prior)); } @@ -265,7 +277,7 @@ public unsafe void CommitLogIncrementalCheckpoint(Guid logToken, int version, by deltaLog.Allocate(out int length, out long physicalAddress); if (length < commitMetadata.Length) { - deltaLog.Seal(0, type: 1); + deltaLog.Seal(0, DeltaLogEntryType.CHECKPOINT_METADATA); deltaLog.Allocate(out length, out physicalAddress); if (length < commitMetadata.Length) { @@ -277,7 +289,7 @@ public unsafe void CommitLogIncrementalCheckpoint(Guid logToken, int version, by { Buffer.MemoryCopy(ptr, (void*)physicalAddress, commitMetadata.Length, commitMetadata.Length); } - deltaLog.Seal(commitMetadata.Length, type: 1); + deltaLog.Seal(commitMetadata.Length, DeltaLogEntryType.CHECKPOINT_METADATA); deltaLog.FlushAsync().Wait(); } @@ -288,25 +300,42 @@ public IEnumerable<Guid> GetLogCheckpointTokens() } /// <inheritdoc /> - public byte[] GetLogCheckpointMetadata(Guid logToken, DeltaLog deltaLog) + public byte[] GetLogCheckpointMetadata(Guid logToken, DeltaLog deltaLog, bool scanDelta, long recoverTo) { byte[] metadata = null; - if (deltaLog != null) + if (deltaLog != null && scanDelta) { // Try to get latest valid metadata from delta-log deltaLog.Reset(); - while (deltaLog.GetNext(out long physicalAddress, out int 
entryLength, out int type)) + while (deltaLog.GetNext(out long physicalAddress, out int entryLength, out var type)) { - if (type != 1) continue; // consider only metadata records - long endAddress = physicalAddress + entryLength; - metadata = new byte[entryLength]; - unsafe + switch (type) { - fixed (byte* m = metadata) - Buffer.MemoryCopy((void*)physicalAddress, m, entryLength, entryLength); + case DeltaLogEntryType.DELTA: + // consider only metadata records + continue; + case DeltaLogEntryType.CHECKPOINT_METADATA: + metadata = new byte[entryLength]; + unsafe + { + fixed (byte* m = metadata) + Buffer.MemoryCopy((void*)physicalAddress, m, entryLength, entryLength); + } + HybridLogRecoveryInfo recoveryInfo = new(); + using (StreamReader s = new(new MemoryStream(metadata))) { + recoveryInfo.Initialize(s); + // Finish recovery if only specific versions are requested + if (recoveryInfo.version == recoverTo || recoveryInfo.version < recoverTo && recoveryInfo.nextVersion > recoverTo) goto LoopEnd; + } + continue; + default: + throw new FasterException("Unexpected entry type"); } + LoopEnd: + break; } if (metadata != null) return metadata; + } var device = deviceFactory.Get(checkpointNamingScheme.LogCheckpointMetadata(logToken)); @@ -320,7 +349,7 @@ public byte[] GetLogCheckpointMetadata(Guid logToken, DeltaLog deltaLog) else ReadInto(device, 0, out body, size + sizeof(int)); device.Dispose(); - return new Span<byte>(body).Slice(sizeof(int)).ToArray(); + return body.AsSpan().Slice(sizeof(int), size).ToArray(); } /// <inheritdoc /> @@ -366,12 +395,12 @@ public void OnRecovery(Guid indexToken, Guid logToken) if (indexToken != default) { indexTokenHistory[indexTokenHistoryOffset] = indexToken; - indexTokenHistoryOffset = (indexTokenHistoryOffset + 1) % indexTokenHistory.Length; + indexTokenHistoryOffset = (indexTokenHistoryOffset + 1) % indexTokenCount; } if (logToken != default) { logTokenHistory[logTokenHistoryOffset] = logToken; - logTokenHistoryOffset = 
(logTokenHistoryOffset + 1) % logTokenHistory.Length; + logTokenHistoryOffset = (logTokenHistoryOffset + 1) % logTokenCount; } // Purge all log checkpoints that were not used for recovery diff --git a/cs/src/core/Index/Common/AddressInfo.cs b/cs/src/core/Index/Common/AddressInfo.cs index a33045ded..80d57d52c 100644 --- a/cs/src/core/Index/Common/AddressInfo.cs +++ b/cs/src/core/Index/Common/AddressInfo.cs @@ -80,6 +80,7 @@ readonly get } set { + var orig_word = word; var _word = (long)word; _word &= ~kAddressMask; _word |= (value & kAddressMask); diff --git a/cs/src/core/Index/Common/Contexts.cs b/cs/src/core/Index/Common/Contexts.cs index ab6f0cb10..80ba150bf 100644 --- a/cs/src/core/Index/Common/Contexts.cs +++ b/cs/src/core/Index/Common/Contexts.cs @@ -87,11 +87,15 @@ internal struct PendingContext<Input, Output, Context> internal byte operationFlags; internal RecordInfo recordInfo; + internal long minAddress; + // Note: Must be kept in sync with corresponding ReadFlags enum values internal const byte kSkipReadCache = 0x01; - internal const byte kNoKey = 0x02; - internal const byte kSkipCopyReadsToTail = 0x04; - internal const byte kIsAsync = 0x08; + internal const byte kMinAddress = 0x02; + + internal const byte kNoKey = 0x10; + internal const byte kSkipCopyReadsToTail = 0x20; + internal const byte kIsAsync = 0x40; [MethodImpl(MethodImplOptions.AggressiveInlining)] internal IHeapContainer<Key> DetachKey() @@ -113,7 +117,8 @@ internal IHeapContainer<Input> DetachInput() internal static byte GetOperationFlags(ReadFlags readFlags, bool noKey = false) { Debug.Assert((byte)ReadFlags.SkipReadCache == kSkipReadCache); - byte flags = (byte)(readFlags & ReadFlags.SkipReadCache); + Debug.Assert((byte)ReadFlags.MinAddress == kMinAddress); + byte flags = (byte)(readFlags & (ReadFlags.SkipReadCache | ReadFlags.MinAddress)); if (noKey) flags |= kNoKey; // This is always set true for the Read overloads (Reads by address) that call this method. 
@@ -121,6 +126,18 @@ internal static byte GetOperationFlags(ReadFlags readFlags, bool noKey = false) return flags; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void SetOperationFlags(ReadFlags readFlags, long address, bool noKey = false) + => this.SetOperationFlags(GetOperationFlags(readFlags, noKey), address); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void SetOperationFlags(byte flags, long address) + { + this.operationFlags = flags; + if (this.HasMinAddress) + this.minAddress = address; + } + internal bool NoKey { get => (operationFlags & kNoKey) != 0; @@ -133,6 +150,12 @@ internal bool SkipReadCache set => operationFlags = value ? (byte)(operationFlags | kSkipReadCache) : (byte)(operationFlags & ~kSkipReadCache); } + internal bool HasMinAddress + { + get => (operationFlags & kMinAddress) != 0; + set => operationFlags = value ? (byte)(operationFlags | kMinAddress) : (byte)(operationFlags & ~kMinAddress); + } + internal bool SkipCopyReadsToTail { get => (operationFlags & kSkipCopyReadsToTail) != 0; @@ -405,15 +428,42 @@ public void Initialize(StreamReader reader) /// <param name="token"></param> /// <param name="checkpointManager"></param> /// <param name="deltaLog"></param> - /// <returns></returns> - internal void Recover(Guid token, ICheckpointManager checkpointManager, DeltaLog deltaLog = null) + /// <param name = "scanDelta"> + /// whether to scan the delta log to obtain the latest info contained in an incremental snapshot checkpoint. + /// If false, this will recover the base snapshot info but avoid potentially expensive scans. 
+ /// </param> + /// <param name="recoverTo"> specific version to recover to, if using delta log</param> + internal void Recover(Guid token, ICheckpointManager checkpointManager, DeltaLog deltaLog = null, bool scanDelta = false, long recoverTo = -1) { - var metadata = checkpointManager.GetLogCheckpointMetadata(token, deltaLog); + var metadata = checkpointManager.GetLogCheckpointMetadata(token, deltaLog, scanDelta, recoverTo); if (metadata == null) throw new FasterException("Invalid log commit metadata for ID " + token.ToString()); - - using (StreamReader s = new(new MemoryStream(metadata))) - Initialize(s); + using StreamReader s = new(new MemoryStream(metadata)); + Initialize(s); + } + + /// <summary> + /// Recover info from token + /// </summary> + /// <param name="token"></param> + /// <param name="checkpointManager"></param> + /// <param name="deltaLog"></param> + /// <param name="commitCookie"> Any user-specified commit cookie written as part of the checkpoint </param> + /// <param name = "scanDelta"> + /// whether to scan the delta log to obtain the latest info contained in an incremental snapshot checkpoint. + /// If false, this will recover the base snapshot info but avoid potentially expensive scans. + /// </param> + /// <param name="recoverTo"> specific version to recover to, if using delta log</param> + + internal void Recover(Guid token, ICheckpointManager checkpointManager, out byte[] commitCookie, DeltaLog deltaLog = null, bool scanDelta = false, long recoverTo = -1) + { + var metadata = checkpointManager.GetLogCheckpointMetadata(token, deltaLog, scanDelta, recoverTo); + if (metadata == null) + throw new FasterException("Invalid log commit metadata for ID " + token.ToString()); + using StreamReader s = new(new MemoryStream(metadata)); + Initialize(s); + var cookie = s.ReadToEnd(); + commitCookie = cookie.Length == 0 ? 
null : Convert.FromBase64String(cookie); } /// <summary> @@ -501,7 +551,7 @@ public readonly void DebugPrint() } } - internal struct HybridLogCheckpointInfo + internal struct HybridLogCheckpointInfo : IDisposable { public HybridLogRecoveryInfo info; public IDevice snapshotFileDevice; @@ -517,7 +567,28 @@ public void Initialize(Guid token, int _version, ICheckpointManager checkpointMa checkpointManager.InitializeLogCheckpoint(token); } - public void Recover(Guid token, ICheckpointManager checkpointManager, int deltaLogPageSizeBits) + public void Dispose() + { + snapshotFileDevice?.Dispose(); + snapshotFileObjectLogDevice?.Dispose(); + deltaLog?.Dispose(); + deltaFileDevice?.Dispose(); + this = default; + } + + public HybridLogCheckpointInfo Transfer() + { + // Ownership transfer of handles across struct copies + var dest = this; + dest.snapshotFileDevice = default; + dest.snapshotFileObjectLogDevice = default; + this.deltaLog = default; + this.deltaFileDevice = default; + return dest; + } + + public void Recover(Guid token, ICheckpointManager checkpointManager, int deltaLogPageSizeBits, + bool scanDelta, long recoverTo) { deltaFileDevice = checkpointManager.GetDeltaLogDevice(token); deltaFileDevice.Initialize(-1); @@ -525,7 +596,7 @@ public void Recover(Guid token, ICheckpointManager checkpointManager, int deltaL { deltaLog = new DeltaLog(deltaFileDevice, deltaLogPageSizeBits, -1); deltaLog.InitializeForReads(); - info.Recover(token, checkpointManager, deltaLog); + info.Recover(token, checkpointManager, deltaLog, scanDelta, recoverTo); } else { @@ -533,12 +604,21 @@ public void Recover(Guid token, ICheckpointManager checkpointManager, int deltaL } } - public void Reset() + public void Recover(Guid token, ICheckpointManager checkpointManager, int deltaLogPageSizeBits, + out byte[] commitCookie, bool scanDelta = false, long recoverTo = -1) { - flushedSemaphore = null; - info = default; - snapshotFileDevice?.Dispose(); - snapshotFileObjectLogDevice?.Dispose(); + 
deltaFileDevice = checkpointManager.GetDeltaLogDevice(token); + deltaFileDevice.Initialize(-1); + if (deltaFileDevice.GetFileSize(0) > 0) + { + deltaLog = new DeltaLog(deltaFileDevice, deltaLogPageSizeBits, -1); + deltaLog.InitializeForReads(); + info.Recover(token, checkpointManager, out commitCookie, deltaLog, scanDelta, recoverTo); + } + else + { + info.Recover(token, checkpointManager, out commitCookie); + } } public bool IsDefault() @@ -607,6 +687,7 @@ public void Initialize(StreamReader reader) public void Recover(Guid guid, ICheckpointManager checkpointManager) { + this.token = guid; var metadata = checkpointManager.GetIndexCheckpointMetadata(guid); if (metadata == null) throw new FasterException("Invalid index commit metadata for ID " + guid.ToString()); @@ -687,7 +768,8 @@ public void Recover(Guid token, ICheckpointManager checkpointManager) public void Reset() { info = default; - main_ht_device.Dispose(); + main_ht_device?.Dispose(); + main_ht_device = null; } public bool IsDefault() diff --git a/cs/src/core/Index/FASTER/FASTER.cs b/cs/src/core/Index/FASTER/FASTER.cs index a97f971e0..e040ba7c3 100644 --- a/cs/src/core/Index/FASTER/FASTER.cs +++ b/cs/src/core/Index/FASTER/FASTER.cs @@ -16,6 +16,7 @@ namespace FASTER.core /// <summary> /// Flags for the Read-by-address methods /// </summary> + /// <remarks>Note: must be kept in sync with corresponding PendingContext k* values</remarks> [Flags] public enum ReadFlags { @@ -24,6 +25,9 @@ public enum ReadFlags /// <summary>Skip the ReadCache when reading, including not inserting to ReadCache when pending reads are complete</summary> SkipReadCache = 0x00000001, + + /// <summary>The minimum address at which to resolve the Key; return <see cref="Status.NOTFOUND"/> if the key is not found at this address or higher</summary> + MinAddress = 0x00000002, } public partial class FasterKV<Key, Value> : FasterBase, @@ -221,33 +225,43 @@ public FasterKV(long size, LogSettings logSettings, Initialize(size, sectorSize); 
systemState = default; - systemState.phase = Phase.REST; - systemState.version = 1; + systemState.Phase = Phase.REST; + systemState.Version = 1; } /// <summary> /// Initiate full checkpoint /// </summary> /// <param name="token">Checkpoint token</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns> /// Whether we successfully initiated the checkpoint (initiation may /// fail if we are already taking a checkpoint or performing some other /// operation such as growing the index). Use CompleteCheckpointAsync to wait completion. /// </returns> - public bool TakeFullCheckpoint(out Guid token) - => TakeFullCheckpoint(out token, this.FoldOverSnapshot ? CheckpointType.FoldOver : CheckpointType.Snapshot); + public bool TakeFullCheckpoint(out Guid token, long targetVersion = -1) + => TakeFullCheckpoint(out token, this.FoldOverSnapshot ? CheckpointType.FoldOver : CheckpointType.Snapshot, targetVersion); /// <summary> /// Initiate full checkpoint /// </summary> /// <param name="token">Checkpoint token</param> /// <param name="checkpointType">Checkpoint type</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns> /// Whether we successfully initiated the checkpoint (initiation may /// fail if we are already taking a checkpoint or performing some other /// operation such as growing the index). Use CompleteCheckpointAsync to wait completion. 
/// </returns> - public bool TakeFullCheckpoint(out Guid token, CheckpointType checkpointType) + public bool TakeFullCheckpoint(out Guid token, CheckpointType checkpointType, long targetVersion = -1) { ISynchronizationTask backend; if (checkpointType == CheckpointType.FoldOver) @@ -257,7 +271,7 @@ public bool TakeFullCheckpoint(out Guid token, CheckpointType checkpointType) else throw new FasterException("Unsupported full checkpoint type"); - var result = StartStateMachine(new FullCheckpointStateMachine(backend, -1)); + var result = StartStateMachine(new FullCheckpointStateMachine(backend, targetVersion)); if (result) token = _hybridLogCheckpointToken; else @@ -270,6 +284,11 @@ public bool TakeFullCheckpoint(out Guid token, CheckpointType checkpointType) /// </summary> /// <param name="checkpointType">Checkpoint type</param> /// <param name="cancellationToken">Cancellation token</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. 
+ /// </param> /// <returns> /// (bool success, Guid token) /// success: Whether we successfully initiated the checkpoint (initiation may @@ -278,9 +297,10 @@ public bool TakeFullCheckpoint(out Guid token, CheckpointType checkpointType) /// token: Token for taken checkpoint /// Await task to complete checkpoint, if initiated successfully /// </returns> - public async ValueTask<(bool success, Guid token)> TakeFullCheckpointAsync(CheckpointType checkpointType, CancellationToken cancellationToken = default) + public async ValueTask<(bool success, Guid token)> TakeFullCheckpointAsync(CheckpointType checkpointType, + CancellationToken cancellationToken = default, long targetVersion = -1) { - var success = TakeFullCheckpoint(out Guid token, checkpointType); + var success = TakeFullCheckpoint(out Guid token, checkpointType, targetVersion); if (success) await CompleteCheckpointAsync(cancellationToken).ConfigureAwait(false); @@ -326,8 +346,13 @@ public bool TakeIndexCheckpoint(out Guid token) /// Initiate log-only checkpoint /// </summary> /// <param name="token">Checkpoint token</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns>Whether we could initiate the checkpoint. 
Use CompleteCheckpointAsync to wait completion.</returns> - public bool TakeHybridLogCheckpoint(out Guid token) + public bool TakeHybridLogCheckpoint(out Guid token, long targetVersion = -1) { ISynchronizationTask backend; if (FoldOverSnapshot) @@ -335,7 +360,7 @@ public bool TakeHybridLogCheckpoint(out Guid token) else backend = new SnapshotCheckpointTask(); - var result = StartStateMachine(new HybridLogCheckpointStateMachine(backend, -1)); + var result = StartStateMachine(new HybridLogCheckpointStateMachine(backend, targetVersion)); token = _hybridLogCheckpointToken; return result; } @@ -346,8 +371,14 @@ public bool TakeHybridLogCheckpoint(out Guid token) /// <param name="token">Checkpoint token</param> /// <param name="checkpointType">Checkpoint type</param> /// <param name="tryIncremental">For snapshot, try to store as incremental delta over last snapshot</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns>Whether we could initiate the checkpoint. 
Use CompleteCheckpointAsync to wait completion.</returns> - public bool TakeHybridLogCheckpoint(out Guid token, CheckpointType checkpointType, bool tryIncremental = false) + public bool TakeHybridLogCheckpoint(out Guid token, CheckpointType checkpointType, bool tryIncremental = false, + long targetVersion = -1) { ISynchronizationTask backend; if (checkpointType == CheckpointType.FoldOver) @@ -362,7 +393,7 @@ public bool TakeHybridLogCheckpoint(out Guid token, CheckpointType checkpointTyp else throw new FasterException("Unsupported checkpoint type"); - var result = StartStateMachine(new HybridLogCheckpointStateMachine(backend, -1)); + var result = StartStateMachine(new HybridLogCheckpointStateMachine(backend, targetVersion)); token = _hybridLogCheckpointToken; return result; } @@ -373,6 +404,11 @@ public bool TakeHybridLogCheckpoint(out Guid token, CheckpointType checkpointTyp /// <param name="checkpointType">Checkpoint type</param> /// <param name="tryIncremental">For snapshot, try to store as incremental delta over last snapshot</param> /// <param name="cancellationToken">Cancellation token</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. 
+ /// </param> /// <returns> /// (bool success, Guid token) /// success: Whether we successfully initiated the checkpoint (initiation may @@ -381,9 +417,10 @@ public bool TakeHybridLogCheckpoint(out Guid token, CheckpointType checkpointTyp /// token: Token for taken checkpoint /// Await task to complete checkpoint, if initiated successfully /// </returns> - public async ValueTask<(bool success, Guid token)> TakeHybridLogCheckpointAsync(CheckpointType checkpointType, bool tryIncremental = false, CancellationToken cancellationToken = default) + public async ValueTask<(bool success, Guid token)> TakeHybridLogCheckpointAsync(CheckpointType checkpointType, + bool tryIncremental = false, CancellationToken cancellationToken = default, long targetVersion = -1) { - var success = TakeHybridLogCheckpoint(out Guid token, checkpointType, tryIncremental); + var success = TakeHybridLogCheckpoint(out Guid token, checkpointType, tryIncremental, targetVersion); if (success) await CompleteCheckpointAsync(cancellationToken).ConfigureAwait(false); @@ -396,9 +433,12 @@ public bool TakeHybridLogCheckpoint(out Guid token, CheckpointType checkpointTyp /// </summary> /// <param name="numPagesToPreload">Number of pages to preload into memory (beyond what needs to be read for recovery)</param> /// <param name="undoNextVersion">Whether records with versions beyond checkpoint version need to be undone (and invalidated on log)</param> - public void Recover(int numPagesToPreload = -1, bool undoNextVersion = true) + /// <param name="recoverTo"> specific version requested or -1 for latest version. FASTER will recover to the largest version number checkpointed that's smaller than the required version. 
</param> + + public void Recover(int numPagesToPreload = -1, bool undoNextVersion = true, long recoverTo = -1) { - InternalRecoverFromLatestCheckpoints(numPagesToPreload, undoNextVersion); + FindRecoveryInfo(recoverTo, out var recoveredHlcInfo, out var recoveredIcInfo); + InternalRecover(recoveredIcInfo, recoveredHlcInfo, numPagesToPreload, undoNextVersion, recoverTo); } /// <summary> @@ -406,9 +446,14 @@ public void Recover(int numPagesToPreload = -1, bool undoNextVersion = true) /// </summary> /// <param name="numPagesToPreload">Number of pages to preload into memory (beyond what needs to be read for recovery)</param> /// <param name="undoNextVersion">Whether records with versions beyond checkpoint version need to be undone (and invalidated on log)</param> + /// <param name="recoverTo"> specific version requested or -1 for latest version. FASTER will recover to the largest version number checkpointed that's smaller than the required version.</param> /// <param name="cancellationToken">Cancellation token</param> - public ValueTask RecoverAsync(int numPagesToPreload = -1, bool undoNextVersion = true, CancellationToken cancellationToken = default) - => InternalRecoverFromLatestCheckpointsAsync(numPagesToPreload, undoNextVersion, cancellationToken); + public ValueTask RecoverAsync(int numPagesToPreload = -1, bool undoNextVersion = true, long recoverTo = -1, + CancellationToken cancellationToken = default) + { + FindRecoveryInfo(recoverTo, out var recoveredHlcInfo, out var recoveredIcInfo); + return InternalRecoverAsync(recoveredIcInfo, recoveredHlcInfo, numPagesToPreload, undoNextVersion, recoverTo, cancellationToken); + } /// <summary> /// Recover from specific token (blocking operation) @@ -418,7 +463,7 @@ public ValueTask RecoverAsync(int numPagesToPreload = -1, bool undoNextVersion = /// <param name="undoNextVersion">Whether records with versions beyond checkpoint version need to be undone (and invalidated on log)</param> public void Recover(Guid 
fullCheckpointToken, int numPagesToPreload = -1, bool undoNextVersion = true) { - InternalRecover(fullCheckpointToken, fullCheckpointToken, numPagesToPreload, undoNextVersion); + InternalRecover(fullCheckpointToken, fullCheckpointToken, numPagesToPreload, undoNextVersion, -1); } /// <summary> @@ -429,7 +474,7 @@ public void Recover(Guid fullCheckpointToken, int numPagesToPreload = -1, bool u /// <param name="undoNextVersion">Whether records with versions beyond checkpoint version need to be undone (and invalidated on log)</param> /// <param name="cancellationToken">Cancellation token</param> public ValueTask RecoverAsync(Guid fullCheckpointToken, int numPagesToPreload = -1, bool undoNextVersion = true, CancellationToken cancellationToken = default) - => InternalRecoverAsync(fullCheckpointToken, fullCheckpointToken, numPagesToPreload, undoNextVersion, cancellationToken); + => InternalRecoverAsync(fullCheckpointToken, fullCheckpointToken, numPagesToPreload, undoNextVersion, -1, cancellationToken); /// <summary> /// Recover from specific index and log token (blocking operation) @@ -440,7 +485,7 @@ public ValueTask RecoverAsync(Guid fullCheckpointToken, int numPagesToPreload = /// <param name="undoNextVersion">Whether records with versions beyond checkpoint version need to be undone (and invalidated on log)</param> public void Recover(Guid indexCheckpointToken, Guid hybridLogCheckpointToken, int numPagesToPreload = -1, bool undoNextVersion = true) { - InternalRecover(indexCheckpointToken, hybridLogCheckpointToken, numPagesToPreload, undoNextVersion); + InternalRecover(indexCheckpointToken, hybridLogCheckpointToken, numPagesToPreload, undoNextVersion, -1); } /// <summary> @@ -452,7 +497,7 @@ public void Recover(Guid indexCheckpointToken, Guid hybridLogCheckpointToken, in /// <param name="undoNextVersion">Whether records with versions beyond checkpoint version need to be undone (and invalidated on log)</param> /// <param name="cancellationToken">Cancellation 
token</param> public ValueTask RecoverAsync(Guid indexCheckpointToken, Guid hybridLogCheckpointToken, int numPagesToPreload = -1, bool undoNextVersion = true, CancellationToken cancellationToken = default) - => InternalRecoverAsync(indexCheckpointToken, hybridLogCheckpointToken, numPagesToPreload, undoNextVersion, cancellationToken); + => InternalRecoverAsync(indexCheckpointToken, hybridLogCheckpointToken, numPagesToPreload, undoNextVersion, -1, cancellationToken); /// <summary> /// Wait for ongoing checkpoint to complete @@ -468,13 +513,21 @@ public async ValueTask CompleteCheckpointAsync(CancellationToken token = default while (true) { var systemState = this.systemState; - if (systemState.phase == Phase.REST || systemState.phase == Phase.PREPARE_GROW || - systemState.phase == Phase.IN_PROGRESS_GROW) + if (systemState.Phase == Phase.REST || systemState.Phase == Phase.PREPARE_GROW || + systemState.Phase == Phase.IN_PROGRESS_GROW) return; List<ValueTask> valueTasks = new(); - - ThreadStateMachineStep<Empty, Empty, Empty, NullFasterSession>(null, NullFasterSession.Instance, valueTasks, token); + try + { + ThreadStateMachineStep<Empty, Empty, Empty, NullFasterSession>(null, NullFasterSession.Instance, valueTasks, token); + } + catch (Exception) + { + this._indexCheckpoint.Reset(); + this._hybridLogCheckpoint.Dispose(); + throw; + } if (valueTasks.Count == 0) continue; // we need to re-check loop, so we return only when we are at REST @@ -517,7 +570,7 @@ internal Status ContextRead<Input, Output, Context, FasterSession>(ref Key key, where FasterSession : IFasterSession<Key, Value, Input, Output, Context> { var pcontext = default(PendingContext<Input, Output, Context>); - pcontext.operationFlags = PendingContext<Input, Output, Context>.GetOperationFlags(readFlags); + pcontext.SetOperationFlags(readFlags, recordInfo.PreviousAddress); var internalStatus = InternalRead(ref key, ref input, ref output, recordInfo.PreviousAddress, ref context, ref pcontext, fasterSession, 
sessionCtx, serialNo); Debug.Assert(internalStatus != OperationStatus.RETRY_NOW); @@ -544,7 +597,7 @@ internal Status ContextReadAtAddress<Input, Output, Context, FasterSession>(long where FasterSession : IFasterSession<Key, Value, Input, Output, Context> { var pcontext = default(PendingContext<Input, Output, Context>); - pcontext.operationFlags = PendingContext<Input, Output, Context>.GetOperationFlags(readFlags, noKey: true); + pcontext.SetOperationFlags(readFlags, address, noKey: true); Key key = default; var internalStatus = InternalRead(ref key, ref input, ref output, address, ref context, ref pcontext, fasterSession, sessionCtx, serialNo); Debug.Assert(internalStatus != OperationStatus.RETRY_NOW); @@ -669,7 +722,7 @@ public bool GrowIndex() while (true) { SystemState _systemState = SystemState.Copy(ref systemState); - if (_systemState.phase == Phase.IN_PROGRESS_GROW) + if (_systemState.Phase == Phase.IN_PROGRESS_GROW) { SplitBuckets(0); epoch.ProtectAndDrain(); @@ -677,7 +730,7 @@ public bool GrowIndex() else { SystemState.RemoveIntermediate(ref _systemState); - if (_systemState.phase != Phase.PREPARE_GROW && _systemState.phase != Phase.IN_PROGRESS_GROW) + if (_systemState.Phase != Phase.PREPARE_GROW && _systemState.Phase != Phase.IN_PROGRESS_GROW) { return true; } @@ -698,13 +751,12 @@ public void Dispose() Free(); hlog.Dispose(); readcache?.Dispose(); - _lastSnapshotCheckpoint.deltaLog?.Dispose(); - _lastSnapshotCheckpoint.deltaFileDevice?.Dispose(); + _lastSnapshotCheckpoint.Dispose(); if (disposeCheckpointManager) checkpointManager?.Dispose(); } - private void UpdateVarLen(ref VariableLengthStructSettings<Key, Value> variableLengthStructSettings) + private static void UpdateVarLen(ref VariableLengthStructSettings<Key, Value> variableLengthStructSettings) { if (typeof(Key) == typeof(SpanByte)) { diff --git a/cs/src/core/Index/FASTER/FASTERImpl.cs b/cs/src/core/Index/FASTER/FASTERImpl.cs index fe7e0271c..2752cd1b3 100644 --- 
a/cs/src/core/Index/FASTER/FASTERImpl.cs +++ b/cs/src/core/Index/FASTER/FASTERImpl.cs @@ -13,15 +13,44 @@ namespace FASTER.core { public unsafe partial class FasterKV<Key, Value> : FasterBase, IFasterKV<Key, Value> { + /// <summary> + /// This is a wrapper for checking the record's version instead of just peeking at the latest record at the tail of the bucket. + /// By calling with the address of the traced record, we can prevent a different key sharing the same bucket from deceiving + /// the operation to think that the version of the key has reached v+1 and thus to incorrectly update in place. + /// </summary> + /// <typeparam name="Input"></typeparam> + /// <typeparam name="Output"></typeparam> + /// <typeparam name="Context"></typeparam> + /// <param name="logicalAddress">The logical address of the traced record for the key</param> + /// <param name="sessionCtx"></param> + /// <returns></returns> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private bool CheckEntryVersionNew<Input, Output, Context>(long logicalAddress, FasterExecutionContext<Input, Output, Context> sessionCtx) + { + HashBucketEntry entry = default; + entry.word = logicalAddress; + return CheckBucketVersionNew(ref entry, sessionCtx); + } + + /// <summary> + /// Check the version of the passed-in entry. + /// The semantics of this function are to check the tail of a bucket (indicated by entry), so we name it this way. 
+ /// </summary> + /// <typeparam name="Input"></typeparam> + /// <typeparam name="Output"></typeparam> + /// <typeparam name="Context"></typeparam> + /// <param name="entry">the last entry of a bucket</param> + /// <param name="sessionCtx"></param> + /// <returns></returns> [MethodImpl(MethodImplOptions.AggressiveInlining)] - private bool InVersionNew<Input, Output, Context>(ref HashBucketEntry entry, FasterExecutionContext<Input, Output, Context> sessionCtx) + private bool CheckBucketVersionNew<Input, Output, Context>(ref HashBucketEntry entry, FasterExecutionContext<Input, Output, Context> sessionCtx) { // A version shift can only in an address after the checkpoint starts, as v_new threads RCU entries to the tail. if (entry.Address < _hybridLogCheckpoint.info.startLogicalAddress) return false; // Otherwise, check if the version suffix of the entry matches v_new. return GetLatestRecordVersion(ref entry, sessionCtx.version) == RecordInfo.GetShortVersion(currentSyncStateMachine.ToVersion()); } - + internal enum LatchOperation : byte { None, @@ -94,11 +123,11 @@ internal OperationStatus InternalRead<Input, Output, Context, Functions>( OperationStatus status; long logicalAddress; - var usePreviousAddress = startAddress != Constants.kInvalidAddress; + var useStartAddress = startAddress != Constants.kInvalidAddress && !pendingContext.HasMinAddress; bool tagExists; - if (!usePreviousAddress) + if (!useStartAddress) { - tagExists = FindTag(hash, tag, ref bucket, ref slot, ref entry); + tagExists = FindTag(hash, tag, ref bucket, ref slot, ref entry) && entry.Address >= pendingContext.minAddress; } else { @@ -112,13 +141,13 @@ internal OperationStatus InternalRead<Input, Output, Context, Functions>( if (UseReadCache) { - if (pendingContext.SkipReadCache) + if (pendingContext.SkipReadCache || pendingContext.NoKey) { SkipReadCache(ref logicalAddress); } else if (ReadFromCache(ref key, ref logicalAddress, ref physicalAddress)) { - if (sessionCtx.phase == Phase.PREPARE && 
InVersionNew(ref entry, sessionCtx)) + if (sessionCtx.phase == Phase.PREPARE && CheckBucketVersionNew(ref entry, sessionCtx)) { status = OperationStatus.CPR_SHIFT_DETECTED; goto CreatePendingContext; // Pivot thread @@ -160,7 +189,7 @@ internal OperationStatus InternalRead<Input, Output, Context, Functions>( } #endregion - if (sessionCtx.phase == Phase.PREPARE && InVersionNew(ref entry, sessionCtx)) + if (sessionCtx.phase == Phase.PREPARE && CheckBucketVersionNew(ref entry, sessionCtx)) { status = OperationStatus.CPR_SHIFT_DETECTED; goto CreatePendingContext; // Pivot thread @@ -188,10 +217,10 @@ internal OperationStatus InternalRead<Input, Output, Context, Functions>( if (!pendingContext.recordInfo.Tombstone) { fasterSession.SingleReader(ref key, ref input, ref hlog.GetValue(physicalAddress), ref output, logicalAddress); - if (CopyReadsToTail == CopyReadsToTail.FromReadOnly) + if (CopyReadsToTail == CopyReadsToTail.FromReadOnly && !pendingContext.SkipCopyReadsToTail) { var container = hlog.GetValueContainer(ref hlog.GetValue(physicalAddress)); - InternalUpsert(ref key, ref container.Get(), ref userContext, ref pendingContext, fasterSession, sessionCtx, lsn); + InternalTryCopyToTail(ref key, ref container.Get(), ref pendingContext.recordInfo, logicalAddress, fasterSession, sessionCtx); container.Dispose(); } return OperationStatus.SUCCESS; @@ -209,7 +238,7 @@ internal OperationStatus InternalRead<Input, Output, Context, Functions>( if (sessionCtx.phase == Phase.PREPARE) { Debug.Assert(heldOperation != LatchOperation.Exclusive); - if (usePreviousAddress) + if (useStartAddress) { Debug.Assert(heldOperation == LatchOperation.None); } @@ -370,7 +399,7 @@ internal OperationStatus InternalUpsert<Input, Output, Context, FasterSession>( #region Entry latch operation if (sessionCtx.phase != Phase.REST) { - latchDestination = AcquireLatchUpsert(sessionCtx, bucket, ref status, ref latchOperation, ref entry); + latchDestination = AcquireLatchUpsert(sessionCtx, bucket, ref 
status, ref latchOperation, ref entry, logicalAddress); } #endregion @@ -444,7 +473,7 @@ internal OperationStatus InternalUpsert<Input, Output, Context, FasterSession>( } private LatchDestination AcquireLatchUpsert<Input, Output, Context>(FasterExecutionContext<Input, Output, Context> sessionCtx, HashBucket* bucket, ref OperationStatus status, - ref LatchOperation latchOperation, ref HashBucketEntry entry) + ref LatchOperation latchOperation, ref HashBucketEntry entry, long logicalAddress) { switch (sessionCtx.phase) { @@ -454,7 +483,11 @@ private LatchDestination AcquireLatchUpsert<Input, Output, Context>(FasterExecut { // Set to release shared latch (default) latchOperation = LatchOperation.Shared; - if (InVersionNew(ref entry, sessionCtx)) + // Here (and in InternalRead, AcquireLatchRMW, and InternalDelete) we still check the tail record of the bucket (entry.Address) + // rather than the traced record (logicalAddress), because I'm worried that the implementation + // may not allow in-place updates for version v when the bucket arrives v+1. + // This is safer but potentially unnecessary. 
+ if (CheckBucketVersionNew(ref entry, sessionCtx)) { status = OperationStatus.CPR_SHIFT_DETECTED; return LatchDestination.CreatePendingContext; // Pivot Thread @@ -469,7 +502,7 @@ private LatchDestination AcquireLatchUpsert<Input, Output, Context>(FasterExecut } case Phase.IN_PROGRESS: { - if (!InVersionNew(ref entry, sessionCtx)) + if (!CheckEntryVersionNew(logicalAddress, sessionCtx)) { if (HashBucket.TryAcquireExclusiveLatch(bucket)) { @@ -487,7 +520,7 @@ private LatchDestination AcquireLatchUpsert<Input, Output, Context>(FasterExecut } case Phase.WAIT_PENDING: { - if (!InVersionNew(ref entry, sessionCtx)) + if (!CheckEntryVersionNew(logicalAddress, sessionCtx)) { if (HashBucket.NoSharedLatches(bucket)) { @@ -503,7 +536,7 @@ private LatchDestination AcquireLatchUpsert<Input, Output, Context>(FasterExecut } case Phase.WAIT_FLUSH: { - if (!InVersionNew(ref entry, sessionCtx)) + if (!CheckEntryVersionNew(logicalAddress, sessionCtx)) { return LatchDestination.CreateNewRecord; // Create a (v+1) record } @@ -812,7 +845,7 @@ private LatchDestination AcquireLatchRMW<Input, Output, Context>(PendingContext< { // Set to release shared latch (default) latchOperation = LatchOperation.Shared; - if (InVersionNew(ref entry, sessionCtx)) + if (CheckBucketVersionNew(ref entry, sessionCtx)) { status = OperationStatus.CPR_SHIFT_DETECTED; return LatchDestination.CreatePendingContext; // Pivot Thread @@ -827,7 +860,7 @@ private LatchDestination AcquireLatchRMW<Input, Output, Context>(PendingContext< } case Phase.IN_PROGRESS: { - if (!InVersionNew(ref entry, sessionCtx)) + if (!CheckEntryVersionNew(logicalAddress, sessionCtx)) { Debug.Assert(pendingContext.heldLatch != LatchOperation.Shared); if (pendingContext.heldLatch == LatchOperation.Exclusive || HashBucket.TryAcquireExclusiveLatch(bucket)) @@ -847,7 +880,7 @@ private LatchDestination AcquireLatchRMW<Input, Output, Context>(PendingContext< } case Phase.WAIT_PENDING: { - if (!InVersionNew(ref entry, sessionCtx)) + if 
(!CheckEntryVersionNew(logicalAddress, sessionCtx)) { if (HashBucket.NoSharedLatches(bucket)) { @@ -864,7 +897,7 @@ private LatchDestination AcquireLatchRMW<Input, Output, Context>(PendingContext< } case Phase.WAIT_FLUSH: { - if (!InVersionNew(ref entry, sessionCtx)) + if (!CheckEntryVersionNew(logicalAddress, sessionCtx)) { if (logicalAddress >= hlog.HeadAddress) return LatchDestination.CreateNewRecord; // Create a (v+1) record @@ -1045,7 +1078,7 @@ internal OperationStatus InternalDelete<Input, Output, Context, FasterSession>( { // Set to release shared latch (default) latchOperation = LatchOperation.Shared; - if (InVersionNew(ref entry, sessionCtx)) + if (CheckBucketVersionNew(ref entry, sessionCtx)) { status = OperationStatus.CPR_SHIFT_DETECTED; goto CreatePendingContext; // Pivot Thread @@ -1060,7 +1093,7 @@ internal OperationStatus InternalDelete<Input, Output, Context, FasterSession>( } case Phase.IN_PROGRESS: { - if (!InVersionNew(ref entry, sessionCtx)) + if (!CheckEntryVersionNew(logicalAddress, sessionCtx)) { if (HashBucket.TryAcquireExclusiveLatch(bucket)) { @@ -1078,7 +1111,7 @@ internal OperationStatus InternalDelete<Input, Output, Context, FasterSession>( } case Phase.WAIT_PENDING: { - if (!InVersionNew(ref entry, sessionCtx)) + if (!CheckEntryVersionNew(logicalAddress, sessionCtx)) { if (HashBucket.NoSharedLatches(bucket)) { @@ -1094,7 +1127,7 @@ internal OperationStatus InternalDelete<Input, Output, Context, FasterSession>( } case Phase.WAIT_FLUSH: { - if (!InVersionNew(ref entry, sessionCtx)) + if (!CheckEntryVersionNew(logicalAddress, sessionCtx)) { goto CreateNewRecord; // Create a (v+1) record } @@ -1369,97 +1402,12 @@ internal void InternalContinuePendingReadCopyToTail<Input, Output, Context, Fast { Debug.Assert(RelaxedCPR || pendingContext.version == opCtx.version); - var bucket = default(HashBucket*); - var slot = default(int); - var logicalAddress = Constants.kInvalidAddress; - var physicalAddress = default(long); - - // If NoKey, we do not 
have the key in the initial call and must use the key from the satisfied request. ref Key key = ref pendingContext.NoKey ? ref hlog.GetContextRecordKey(ref request) : ref pendingContext.key.Get(); - - var hash = comparer.GetHashCode64(ref key); - - var tag = (ushort)((ulong)hash >> Constants.kHashTagShift); - -#region Trace back record in in-memory HybridLog - var entry = default(HashBucketEntry); - FindOrCreateTag(hash, tag, ref bucket, ref slot, ref entry, hlog.BeginAddress); - logicalAddress = entry.word & Constants.kAddressMask; - - if (UseReadCache) - SkipReadCache(ref logicalAddress); - var latestLogicalAddress = logicalAddress; - - if (logicalAddress >= hlog.HeadAddress) - { - physicalAddress = hlog.GetPhysicalAddress(logicalAddress); - if (!comparer.Equals(ref key, ref hlog.GetKey(physicalAddress))) - { - logicalAddress = hlog.GetInfo(physicalAddress).PreviousAddress; - TraceBackForKeyMatch(ref key, - logicalAddress, - hlog.HeadAddress, - out logicalAddress, - out physicalAddress); - } - } -#endregion - - if (logicalAddress > pendingContext.entry.Address) - { - // Give up early - return; - } - -#region Create new copy in mutable region - physicalAddress = (long)request.record.GetValidPointer(); - var (actualSize, allocatedSize) = hlog.GetRecordSize(physicalAddress); - - long newLogicalAddress, newPhysicalAddress; - if (UseReadCache) - { - BlockAllocateReadCache(allocatedSize, out newLogicalAddress, currentCtx, fasterSession); - newPhysicalAddress = readcache.GetPhysicalAddress(newLogicalAddress); - RecordInfo.WriteInfo(ref readcache.GetInfo(newPhysicalAddress), opCtx.version, - tombstone:false, invalidBit:false, - entry.Address); - readcache.Serialize(ref key, newPhysicalAddress); - fasterSession.SingleWriter(ref key, - ref hlog.GetContextRecordValue(ref request), - ref readcache.GetValue(newPhysicalAddress, newPhysicalAddress + actualSize)); - } - else - { - BlockAllocate(allocatedSize, out newLogicalAddress, currentCtx, fasterSession); - 
newPhysicalAddress = hlog.GetPhysicalAddress(newLogicalAddress); - RecordInfo.WriteInfo(ref hlog.GetInfo(newPhysicalAddress), opCtx.version, - tombstone:false, invalidBit:false, - latestLogicalAddress); - hlog.Serialize(ref key, newPhysicalAddress); - fasterSession.SingleWriter(ref key, - ref hlog.GetContextRecordValue(ref request), - ref hlog.GetValue(newPhysicalAddress, newPhysicalAddress + actualSize)); - } - - - var updatedEntry = default(HashBucketEntry); - updatedEntry.Tag = tag; - updatedEntry.Address = newLogicalAddress & Constants.kAddressMask; - updatedEntry.Pending = entry.Pending; - updatedEntry.Tentative = false; - updatedEntry.ReadCache = UseReadCache; - - var foundEntry = default(HashBucketEntry); - foundEntry.word = Interlocked.CompareExchange( - ref bucket->bucket_entries[slot], - updatedEntry.word, - entry.word); - if (foundEntry.word != entry.word) - { - if (!UseReadCache) hlog.GetInfo(newPhysicalAddress).Invalid = true; - // We don't retry, just give up - } -#endregion + byte* physicalAddress = request.record.GetValidPointer(); + long logicalAddress = pendingContext.entry.Address; + ref RecordInfo oldRecordInfo = ref hlog.GetInfoFromBytePointer(physicalAddress); + + InternalTryCopyToTail(opCtx, ref key, ref hlog.GetContextRecordValue(ref request), ref oldRecordInfo, logicalAddress, fasterSession, currentCtx); } /// <summary> @@ -1716,6 +1664,7 @@ ref pendingContext.input.Get(), request.id = pendingContext.id; request.request_key = pendingContext.key; request.logicalAddress = pendingContext.logicalAddress; + request.minAddress = pendingContext.minAddress; request.record = default; if (asyncOp) request.asyncOperation = new TaskCompletionSource<AsyncIOContext<Key, Value>>(TaskCreationOptions.RunContinuationsAsynchronously); @@ -1786,7 +1735,7 @@ private void HeavyEnter<Input, Output, Context, FasterSession>(long hash, Faster { // We spin-wait as a simplification // Could instead do a "heavy operation" here - while (systemState.phase != 
Phase.IN_PROGRESS_GROW) + while (systemState.Phase != Phase.IN_PROGRESS_GROW) Thread.SpinWait(100); InternalRefresh(ctx, session); } @@ -1890,6 +1839,172 @@ private bool TraceBackForKeyMatch( foundPhysicalAddress = Constants.kInvalidAddress; return false; } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal OperationStatus InternalCopyToTail<Input, Output, Context, FasterSession>( + ref Key key, ref Value value, + ref RecordInfo recordInfo, + long expectedLogicalAddress, + FasterSession fasterSession, + FasterExecutionContext<Input, Output, Context> currentCtx, + bool noReadCache = false) + where FasterSession : IFasterSession<Key, Value, Input, Output, Context> + { + OperationStatus internalStatus; + do + internalStatus = InternalTryCopyToTail(currentCtx, ref key, ref value, ref recordInfo, expectedLogicalAddress, fasterSession, currentCtx, noReadCache); + while (internalStatus == OperationStatus.RETRY_NOW); + return internalStatus; + } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal OperationStatus InternalTryCopyToTail<Input, Output, Context, FasterSession>( + ref Key key, ref Value value, + ref RecordInfo recordInfo, + long foundLogicalAddress, + FasterSession fasterSession, + FasterExecutionContext<Input, Output, Context> currentCtx, + bool noReadCache = false) + where FasterSession : IFasterSession<Key, Value, Input, Output, Context> + => InternalTryCopyToTail(currentCtx, ref key, ref value, ref recordInfo, foundLogicalAddress, fasterSession, currentCtx, noReadCache); + + /// <summary> + /// Helper function for trying to copy existing immutable records (at foundLogicalAddress) to the tail, + /// used in <see cref="InternalRead{Input, Output, Context, Functions}(ref Key, ref Input, ref Output, long, ref Context, ref PendingContext{Input, Output, Context}, Functions, FasterExecutionContext{Input, Output, Context}, long)"/> + /// <see cref="InternalContinuePendingReadCopyToTail{Input, Output, Context, 
FasterSession}(FasterExecutionContext{Input, Output, Context}, AsyncIOContext{Key, Value}, ref PendingContext{Input, Output, Context}, FasterSession, FasterExecutionContext{Input, Output, Context})"/>, + /// and <see cref="ClientSession{Key, Value, Input, Output, Context, Functions}.CopyToTail(ref Key, ref Value, ref RecordInfo, long)"/> + /// + /// Succeed only if the record for the same key hasn't changed. + /// </summary> + /// <typeparam name="Input"></typeparam> + /// <typeparam name="Output"></typeparam> + /// <typeparam name="Context"></typeparam> + /// <typeparam name="FasterSession"></typeparam> + /// <param name="opCtx"> + /// The thread(or session) context to execute operation in. + /// It's different from currentCtx only when the function is used in InternalContinuePendingReadCopyToTail + /// </param> + /// <param name="key"></param> + /// <param name="value"></param> + /// <param name="recordInfo"></param> + /// <param name="expectedLogicalAddress"> + /// The expected address of the record being copied. + /// </param> + /// <param name="fasterSession"></param> + /// <param name="currentCtx"></param> + /// <param name="noReadCache"> + /// If true, it won't clutter read cache. + /// Otherwise, it still checks UseReadCache to determine whether to buffer in read cache. + /// It is useful in Compact. + /// </param> + /// <returns> + /// NOTFOUND: didn't find the expected record of the same key that isn't greater than expectedLogicalAddress + /// RETRY_NOW: failed. 
+ /// SUCCESS: + /// </returns> + internal OperationStatus InternalTryCopyToTail<Input, Output, Context, FasterSession>( + FasterExecutionContext<Input, Output, Context> opCtx, + ref Key key, ref Value value, + ref RecordInfo recordInfo, + long expectedLogicalAddress, + FasterSession fasterSession, + FasterExecutionContext<Input, Output, Context> currentCtx, + bool noReadCache = false) + where FasterSession : IFasterSession<Key, Value, Input, Output, Context> + { + Debug.Assert(expectedLogicalAddress >= hlog.BeginAddress); + var bucket = default(HashBucket*); + var slot = default(int); + + var hash = comparer.GetHashCode64(ref key); + var tag = (ushort)((ulong)hash >> Constants.kHashTagShift); + + var entry = default(HashBucketEntry); + FindOrCreateTag(hash, tag, ref bucket, ref slot, ref entry, hlog.BeginAddress); + var logicalAddress = entry.Address; + var physicalAddress = default(long); + if (UseReadCache) + SkipReadCache(ref logicalAddress); + var latestLogicalAddress = logicalAddress; + + if (logicalAddress >= hlog.HeadAddress) + { + physicalAddress = hlog.GetPhysicalAddress(logicalAddress); + if (!comparer.Equals(ref key, ref hlog.GetKey(physicalAddress))) + { + logicalAddress = hlog.GetInfo(physicalAddress).PreviousAddress; + TraceBackForKeyMatch(ref key, + logicalAddress, + hlog.HeadAddress, + out logicalAddress, + out physicalAddress); + } + } + + if (logicalAddress > expectedLogicalAddress || logicalAddress < hlog.BeginAddress) + { + // We give up early. + // Note: In Compact, expectedLogicalAddress may not exactly match the source of this copy operation, + // but instead only an upper bound. 
+ return OperationStatus.NOTFOUND; + } + #region Create new copy in mutable region + var (actualSize, allocatedSize) = hlog.GetRecordSize(ref key, ref value); + + long newLogicalAddress, newPhysicalAddress; + bool copyToReadCache = !noReadCache && UseReadCache; + if (copyToReadCache) + { + BlockAllocateReadCache(allocatedSize, out newLogicalAddress, currentCtx, fasterSession); + newPhysicalAddress = readcache.GetPhysicalAddress(newLogicalAddress); + RecordInfo.WriteInfo(ref readcache.GetInfo(newPhysicalAddress), opCtx.version, + tombstone: false, invalidBit: false, + entry.Address); + readcache.Serialize(ref key, newPhysicalAddress); + fasterSession.SingleWriter(ref key, + ref value, + ref readcache.GetValue(newPhysicalAddress, newPhysicalAddress + actualSize)); + } + else + { + BlockAllocate(allocatedSize, out newLogicalAddress, currentCtx, fasterSession); + newPhysicalAddress = hlog.GetPhysicalAddress(newLogicalAddress); + RecordInfo.WriteInfo(ref hlog.GetInfo(newPhysicalAddress), opCtx.version, + tombstone: false, invalidBit: false, + latestLogicalAddress); + hlog.Serialize(ref key, newPhysicalAddress); + fasterSession.SingleWriter(ref key, + ref value, + ref hlog.GetValue(newPhysicalAddress, newPhysicalAddress + actualSize)); + } + + + var updatedEntry = default(HashBucketEntry); + updatedEntry.Tag = tag; + updatedEntry.Address = newLogicalAddress & Constants.kAddressMask; + updatedEntry.Pending = entry.Pending; + updatedEntry.Tentative = false; + updatedEntry.ReadCache = copyToReadCache; + + var foundEntry = default(HashBucketEntry); + foundEntry.word = Interlocked.CompareExchange( + ref bucket->bucket_entries[slot], + updatedEntry.word, + entry.word); + if (foundEntry.word != entry.word) + { + if (!copyToReadCache) hlog.GetInfo(newPhysicalAddress).Invalid = true; + // Note: only Compact actually retries; + // other operations, i.e., copy to tail during reads, just give up if the first try fails + return OperationStatus.RETRY_NOW; + } + else + return 
OperationStatus.SUCCESS; + #endregion + } + #endregion #region Split Index diff --git a/cs/src/core/Index/FASTER/FASTERLegacy.cs b/cs/src/core/Index/FASTER/FASTERLegacy.cs index 6bf230bc7..e8d55a9fd 100644 --- a/cs/src/core/Index/FASTER/FASTERLegacy.cs +++ b/cs/src/core/Index/FASTER/FASTERLegacy.cs @@ -211,7 +211,7 @@ public bool CompleteCheckpoint(bool spinWait = false) do { CompletePending(); - if (_fasterKV.systemState.phase == Phase.REST) + if (_fasterKV.systemState.Phase == Phase.REST) { CompletePending(); return true; @@ -228,7 +228,7 @@ private Guid InternalAcquire() { _fasterKV.epoch.Resume(); _threadCtx.InitializeThread(); - Phase phase = _fasterKV.systemState.phase; + Phase phase = _fasterKV.systemState.Phase; if (phase != Phase.REST) { throw new FasterException("Can acquire only in REST phase!"); diff --git a/cs/src/core/Index/FASTER/FASTERThread.cs b/cs/src/core/Index/FASTER/FASTERThread.cs index e4afe021a..60d7f7be6 100644 --- a/cs/src/core/Index/FASTER/FASTERThread.cs +++ b/cs/src/core/Index/FASTER/FASTERThread.cs @@ -23,7 +23,7 @@ internal CommitPoint InternalContinue<Input, Output, Context>(string guid, out F // We have recovered the corresponding session. 
// Now obtain the session by first locking the rest phase var currentState = SystemState.Copy(ref systemState); - if (currentState.phase == Phase.REST) + if (currentState.Phase == Phase.REST) { var intermediateState = SystemState.MakeIntermediate(currentState); if (MakeTransition(currentState, intermediateState)) @@ -70,7 +70,7 @@ internal void InternalRefresh<Input, Output, Context, FasterSession>(FasterExecu // We check if we are in normal mode var newPhaseInfo = SystemState.Copy(ref systemState); - if (ctx.phase == Phase.REST && newPhaseInfo.phase == Phase.REST && ctx.version == newPhaseInfo.version) + if (ctx.phase == Phase.REST && newPhaseInfo.Phase == Phase.REST && ctx.version == newPhaseInfo.Version) { return; } @@ -177,7 +177,7 @@ internal bool InternalCompletePending<Input, Output, Context, FasterSession>( } } - internal bool InRestPhase() => systemState.phase == Phase.REST; + internal bool InRestPhase() => systemState.Phase == Phase.REST; #region Complete Retry Requests internal void InternalCompleteRetryRequests<Input, Output, Context, FasterSession>( diff --git a/cs/src/core/Index/FASTER/LogAccessor.cs b/cs/src/core/Index/FASTER/LogAccessor.cs index 6cf55e0ac..e93275d33 100644 --- a/cs/src/core/Index/FASTER/LogAccessor.cs +++ b/cs/src/core/Index/FASTER/LogAccessor.cs @@ -302,11 +302,13 @@ public long Compact<Input, Output, Context, Functions, CompactionFunctions>(Func if (untilAddress > fht.Log.SafeReadOnlyAddress) throw new FasterException("Can compact only until Log.SafeReadOnlyAddress"); var originalUntilAddress = untilAddress; + var expectedAddress = untilAddress; var lf = new LogCompactionFunctions<Key, Value, Input, Output, Context, Functions>(functions); using var fhtSession = fht.For(lf).NewSession<LogCompactionFunctions<Key, Value, Input, Output, Context, Functions>>(); VariableLengthStructSettings<Key, Value> variableLengthStructSettings = null; + VariableLengthStructSettings<Key, long> variableLengthStructSettingsKaddr = null; if (allocator 
is VariableLengthBlittableAllocator<Key, Value> varLen) { variableLengthStructSettings = new VariableLengthStructSettings<Key, Value> @@ -314,6 +316,11 @@ public long Compact<Input, Output, Context, Functions, CompactionFunctions>(Func keyLength = varLen.KeyLength, valueLength = varLen.ValueLength, }; + variableLengthStructSettingsKaddr = new VariableLengthStructSettings<Key, long> + { + keyLength = varLen.KeyLength, + valueLength = null, + }; } using (var tempKv = new FasterKV<Key, Value>(fht.IndexSize, new LogSettings { LogDevice = new NullDevice(), ObjectLogDevice = new NullDevice() }, comparer: fht.Comparer, variableLengthStructSettings: variableLengthStructSettings)) @@ -327,12 +334,23 @@ public long Compact<Input, Output, Context, Functions, CompactionFunctions>(Func ref var value = ref iter1.GetValue(); if (recordInfo.Tombstone || cf.IsDeleted(key, value)) + { tempKvSession.Delete(ref key, default, 0); + } else + { tempKvSession.Upsert(ref key, ref value, default, 0); + // below is to get and preserve information in RecordInfo, if we need. + /*tempKvSession.ContainsKeyInMemory(ref key, out long logicalAddress); + long physicalAddress = tempKv.hlog.GetPhysicalAddress(logicalAddress); + ref var tempRecordInfo = ref tempKv.hlog.GetInfo(physicalAddress); + RecordInfo.WriteInfo(ref tempRecordInfo, + tempRecordInfo.Version, false, false, tempRecordInfo.PreviousAddress);*/ + } } // Ensure address is at record boundary untilAddress = originalUntilAddress = iter1.NextAddress; + expectedAddress = untilAddress; } // Scan until SafeReadOnlyAddress @@ -374,8 +392,10 @@ public long Compact<Input, Output, Context, Functions, CompactionFunctions>(Func // Possibly deleted key (once ContainsKeyInMemory is updated to check Tombstones) continue; } - - fhtSession.Upsert(ref iter3.GetKey(), ref iter3.GetValue(), default, 0); + // Note: we use untilAddress as expectedAddress here. 
+ // As long as there's no record of the same key whose address is greater than untilAddress, + // i.e., the last address that this compact covers, we are safe to copy the old record to the tail. + fhtSession.CopyToTail(ref iter3.GetKey(), ref iter3.GetValue(), ref recordInfo, expectedAddress); } } } diff --git a/cs/src/core/Index/FasterLog/FasterLog.cs b/cs/src/core/Index/FasterLog/FasterLog.cs index 78dbb2a6c..99a3536db 100644 --- a/cs/src/core/Index/FasterLog/FasterLog.cs +++ b/cs/src/core/Index/FasterLog/FasterLog.cs @@ -4,6 +4,7 @@ #pragma warning disable 0162 using System; +using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; @@ -843,6 +844,67 @@ public FasterLogScanIterator Scan(long beginAddress, long endAddress, string nam return GetRecordAndFree(ctx.record); } + /// <summary> + /// Random read record from log as IMemoryOwner<byte>, at given address + /// </summary> + /// <param name="address">Logical address to read from</param> + /// <param name="memoryPool">MemoryPool to rent the destination buffer from</param> + /// <param name="estimatedLength">Estimated length of entry, if known</param> + /// <param name="token">Cancellation token</param> + /// <returns></returns> + public async ValueTask<(IMemoryOwner<byte>, int)> ReadAsync(long address, MemoryPool<byte> memoryPool, int estimatedLength = 0, CancellationToken token = default) + { + token.ThrowIfCancellationRequested(); + epoch.Resume(); + if (address >= CommittedUntilAddress || address < BeginAddress) + { + epoch.Suspend(); + return default; + } + var ctx = new SimpleReadContext + { + logicalAddress = address, + completedRead = new SemaphoreSlim(0) + }; + unsafe + { + allocator.AsyncReadRecordToMemory(address, headerSize + estimatedLength, AsyncGetFromDiskCallback, ref ctx); + } + epoch.Suspend(); + await ctx.completedRead.WaitAsync(token).ConfigureAwait(false); + return GetRecordAsMemoryOwnerAndFree(ctx.record, memoryPool); + } + + 
/// <summary> + /// Random read record from log, at given address + /// </summary> + /// <param name="address">Logical address to read from</param> + /// <param name="token">Cancellation token</param> + /// <returns></returns> + public async ValueTask<int> ReadRecordLengthAsync(long address, CancellationToken token = default) + { + token.ThrowIfCancellationRequested(); + epoch.Resume(); + if (address >= CommittedUntilAddress || address < BeginAddress) + { + epoch.Suspend(); + return default; + } + var ctx = new SimpleReadContext + { + logicalAddress = address, + completedRead = new SemaphoreSlim(0) + }; + unsafe + { + allocator.AsyncReadRecordToMemory(address, headerSize, AsyncGetHeaderOnlyFromDiskCallback, ref ctx); + } + epoch.Suspend(); + await ctx.completedRead.WaitAsync(token).ConfigureAwait(false); + + return GetRecordLengthAndFree(ctx.record); + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private int Align(int length) { @@ -1172,6 +1234,29 @@ private unsafe void AsyncGetFromDiskCallback(uint errorCode, uint numBytes, obje } } + private void AsyncGetHeaderOnlyFromDiskCallback(uint errorCode, uint numBytes, object context) + { + var ctx = (SimpleReadContext)context; + + if (errorCode != 0) + { + Trace.TraceError("AsyncGetFromDiskCallback error: {0}", errorCode); + ctx.record.Return(); + ctx.record = null; + ctx.completedRead.Release(); + } + else + { + if (ctx.record.available_bytes < headerSize) + { + Debug.WriteLine("No record header present at address: " + ctx.logicalAddress); + ctx.record.Return(); + ctx.record = null; + } + ctx.completedRead.Release(); + } + } + private (byte[], int) GetRecordAndFree(SectorAlignedMemory record) { if (record == null) @@ -1197,6 +1282,54 @@ private unsafe void AsyncGetFromDiskCallback(uint errorCode, uint numBytes, obje return (result, length); } + private (IMemoryOwner<byte>, int) GetRecordAsMemoryOwnerAndFree(SectorAlignedMemory record, MemoryPool<byte> memoryPool) + { + if (record == null) + return (null, 0); 
+ + IMemoryOwner<byte> result; + int length; + unsafe + { + var ptr = record.GetValidPointer(); + length = GetLength(ptr); + if (!VerifyChecksum(ptr, length)) + { + throw new FasterException("Checksum failed for read"); + } + result = memoryPool.Rent(length); + + fixed (byte* bp = result.Memory.Span) + { + Buffer.MemoryCopy(ptr + headerSize, bp, length, length); + } + } + + record.Return(); + return (result, length); + } + + private int GetRecordLengthAndFree(SectorAlignedMemory record) + { + if (record == null) + return 0; + + int length; + unsafe + { + var ptr = record.GetValidPointer(); + length = GetLength(ptr); + + if (!VerifyChecksum(ptr, length)) + { + throw new FasterException("Checksum failed for read"); + } + } + + record.Return(); + return length; + } + private long CommitInternal(bool spinWait = false) { if (readOnlyMode) diff --git a/cs/src/core/Index/FasterLog/FasterLogIterator.cs b/cs/src/core/Index/FasterLog/FasterLogIterator.cs index 707e916cc..77395d25e 100644 --- a/cs/src/core/Index/FasterLog/FasterLogIterator.cs +++ b/cs/src/core/Index/FasterLog/FasterLogIterator.cs @@ -291,6 +291,20 @@ public void CompleteUntil(long address) Utility.MonotonicUpdate(ref requestedCompletedUntilAddress, address, out _); } + /// <summary> + /// Mark iterator complete until the end of the record at specified + /// address. Info is not persisted until a subsequent commit operation + /// on the log. Note: this is slower than CompleteUntil() because the + /// record's length needs to be looked up first. 
+ /// </summary> + /// <param name="recordStartAddress"></param> + /// <param name="token"></param> + public async ValueTask CompleteUntilRecordAtAsync(long recordStartAddress, CancellationToken token = default) + { + int len = await fasterLog.ReadRecordLengthAsync(recordStartAddress, token: token); + CompleteUntil(recordStartAddress + headerSize + len); + } + internal void UpdateCompletedUntilAddress(long address) { Utility.MonotonicUpdate(ref CompletedUntilAddress, address, out _); diff --git a/cs/src/core/Index/Interfaces/IFasterKV.cs b/cs/src/core/Index/Interfaces/IFasterKV.cs index 940e4afd4..30c61c442 100644 --- a/cs/src/core/Index/Interfaces/IFasterKV.cs +++ b/cs/src/core/Index/Interfaces/IFasterKV.cs @@ -78,31 +78,46 @@ public AdvancedClientSession<Key, Value, Input, Output, Context, IAdvancedFuncti /// Initiate full (index + log) checkpoint of FASTER /// </summary> /// <param name="token">Token describing checkpoint</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns>Whether we successfully initiated the checkpoint (initiation may fail if we are already taking a checkpoint or performing some other /// operation such as growing the index). 
Use CompleteCheckpointAsync to await completion.</returns> /// <remarks>Uses the checkpoint type specified in the <see cref="CheckpointSettings"/></remarks> - bool TakeFullCheckpoint(out Guid token); + bool TakeFullCheckpoint(out Guid token, long targetVersion = -1); /// <summary> /// Initiate full (index + log) checkpoint of FASTER /// </summary> /// <param name="token">Token describing checkpoint</param> /// <param name="checkpointType">The checkpoint type to use (ignores the checkpoint type specified in the <see cref="CheckpointSettings"/>)</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns>Whether we successfully initiated the checkpoint (initiation mayfail if we are already taking a checkpoint or performing some other /// operation such as growing the index). Use CompleteCheckpointAsync to await completion.</returns> - public bool TakeFullCheckpoint(out Guid token, CheckpointType checkpointType); + public bool TakeFullCheckpoint(out Guid token, CheckpointType checkpointType, long targetVersion = -1); /// <summary> /// Take full (index + log) checkpoint of FASTER asynchronously /// </summary> /// <param name="checkpointType">The checkpoint type to use (ignores the checkpoint type specified in the <see cref="CheckpointSettings"/>)</param> /// <param name="cancellationToken">A token to cancel the operation</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. 
If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns>A (bool success, Guid token) tuple. /// success: Whether we successfully initiated the checkpoint (initiation may fail if we are already taking a checkpoint or performing some other /// operation such as growing the index). /// token: Token for taken checkpoint. /// Await the task to complete checkpoint, if initiated successfully</returns> - public ValueTask<(bool success, Guid token)> TakeFullCheckpointAsync(CheckpointType checkpointType, CancellationToken cancellationToken = default); + public ValueTask<(bool success, Guid token)> TakeFullCheckpointAsync(CheckpointType checkpointType, CancellationToken cancellationToken = default, long targetVersion = -1); /// <summary> /// Initiate checkpoint of FASTER index only (not log) @@ -126,8 +141,13 @@ public AdvancedClientSession<Key, Value, Input, Output, Context, IAdvancedFuncti /// Initiate checkpoint of FASTER log only (not index) /// </summary> /// <param name="token">Token describing checkpoint</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns>Whether we could initiate the checkpoint. 
Use CompleteCheckpointAsync to await completion.</returns> - bool TakeHybridLogCheckpoint(out Guid token); + bool TakeHybridLogCheckpoint(out Guid token, long targetVersion = -1); /// <summary> /// Take asynchronous checkpoint of FASTER log only (not index) @@ -135,9 +155,14 @@ public AdvancedClientSession<Key, Value, Input, Output, Context, IAdvancedFuncti /// <param name="token">Token describing checkpoint</param> /// <param name="checkpointType">The checkpoint type to use (ignores the checkpoint type specified in the <see cref="CheckpointSettings"/>)</param> /// <param name="tryIncremental">For snapshot, try to store as incremental delta over last snapshot</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns>Whether we successfully initiated the checkpoint (initiation mayfail if we are already taking a checkpoint or performing some other /// operation such as growing the index). 
Use CompleteCheckpointAsync to await completion.</returns> - public bool TakeHybridLogCheckpoint(out Guid token, CheckpointType checkpointType, bool tryIncremental = false); + public bool TakeHybridLogCheckpoint(out Guid token, CheckpointType checkpointType, bool tryIncremental = false, long targetVersion = -1); /// <summary> /// Initiate checkpoint of FASTER log only (not index) @@ -145,27 +170,35 @@ public AdvancedClientSession<Key, Value, Input, Output, Context, IAdvancedFuncti /// <param name="checkpointType">The checkpoint type to use (ignores the checkpoint type specified in the <see cref="CheckpointSettings"/>)</param> /// <param name="tryIncremental">For snapshot, try to store as incremental delta over last snapshot</param> /// <param name="cancellationToken">A token to cancel the operation</param> + /// <param name="targetVersion"> + /// intended version number of the next version. Checkpoint will not execute if supplied version is not larger + /// than current version. Actual new version may have version number greater than supplied number. If the supplied + /// number is -1, checkpoint will unconditionally create a new version. + /// </param> /// <returns>A (bool success, Guid token) tuple. /// success: Whether we successfully initiated the checkpoint (initiation may fail if we are already taking a checkpoint or performing some other /// operation such as growing the index). /// token: Token for taken checkpoint. 
/// Await the task to complete checkpoint, if initiated successfully</returns> - public ValueTask<(bool success, Guid token)> TakeHybridLogCheckpointAsync(CheckpointType checkpointType, bool tryIncremental = false, CancellationToken cancellationToken = default); + public ValueTask<(bool success, Guid token)> TakeHybridLogCheckpointAsync(CheckpointType checkpointType, bool tryIncremental = false, CancellationToken cancellationToken = default, long targetVersion = -1); /// <summary> /// Recover from last successful index and log checkpoints /// </summary> /// <param name="numPagesToPreload">Number of pages to preload into memory after recovery</param> /// <param name="undoNextVersion">Whether records with version after checkpoint version need to be undone (and invalidated on log)</param> - void Recover(int numPagesToPreload = -1, bool undoNextVersion = true); + /// <param name="recoverTo"> specific version requested within the checkpoint, if checkpoint supports multiple versions (e.g. incremental snapshot checkpoints), or -1 for latest version</param> + + void Recover(int numPagesToPreload = -1, bool undoNextVersion = true, long recoverTo = -1); /// <summary> /// Asynchronously recover from last successful index and log checkpoint /// </summary> /// <param name="numPagesToPreload">Number of pages to preload into memory after recovery</param> /// <param name="undoNextVersion">Whether records with version after checkpoint version need to be undone (and invalidated on log)</param> + /// <param name="recoverTo"> specific version requested within the checkpoint, if checkpoint supports multiple versions (e.g. 
incremental snapshot checkpoints), or -1 for latest version</param> /// <param name="cancellationToken">Cancellation token</param> - ValueTask RecoverAsync(int numPagesToPreload = -1, bool undoNextVersion = true, CancellationToken cancellationToken = default); + ValueTask RecoverAsync(int numPagesToPreload = -1, bool undoNextVersion = true, long recoverTo = -1, CancellationToken cancellationToken = default); /// <summary> /// Recover using full checkpoint token diff --git a/cs/src/core/Index/Recovery/Checkpoint.cs b/cs/src/core/Index/Recovery/Checkpoint.cs index 09e1598d4..7409521f4 100644 --- a/cs/src/core/Index/Recovery/Checkpoint.cs +++ b/cs/src/core/Index/Recovery/Checkpoint.cs @@ -6,6 +6,8 @@ //#define WAIT_FOR_INDEX_CHECKPOINT using System; +using System.Linq; +using System.Text; using System.Threading; using System.Threading.Tasks; @@ -59,12 +61,24 @@ internal void AcquireSharedLatchesForAllPendingRequests<Input, Output, Context>( internal void WriteHybridLogMetaInfo() { - checkpointManager.CommitLogCheckpoint(_hybridLogCheckpointToken, _hybridLogCheckpoint.info.ToByteArray()); + var metadata = _hybridLogCheckpoint.info.ToByteArray(); + if (CommitCookie != null && CommitCookie.Length != 0) + { + var convertedCookie = Convert.ToBase64String(CommitCookie); + metadata = metadata.Concat(Encoding.Default.GetBytes(convertedCookie)).ToArray(); + } + checkpointManager.CommitLogCheckpoint(_hybridLogCheckpointToken, metadata); } internal void WriteHybridLogIncrementalMetaInfo(DeltaLog deltaLog) { - checkpointManager.CommitLogIncrementalCheckpoint(_hybridLogCheckpointToken, _hybridLogCheckpoint.info.version, _hybridLogCheckpoint.info.ToByteArray(), deltaLog); + var metadata = _hybridLogCheckpoint.info.ToByteArray(); + if (CommitCookie != null && CommitCookie.Length != 0) + { + var convertedCookie = Convert.ToBase64String(CommitCookie); + metadata = metadata.Concat(Encoding.Default.GetBytes(convertedCookie)).ToArray(); + } + 
checkpointManager.CommitLogIncrementalCheckpoint(_hybridLogCheckpointToken, _hybridLogCheckpoint.info.version, metadata, deltaLog); } internal void WriteIndexMetaInfo() diff --git a/cs/src/core/Index/Recovery/DeltaLog.cs b/cs/src/core/Index/Recovery/DeltaLog.cs index 385fdb072..60c4c2a2d 100644 --- a/cs/src/core/Index/Recovery/DeltaLog.cs +++ b/cs/src/core/Index/Recovery/DeltaLog.cs @@ -10,6 +10,22 @@ namespace FASTER.core { + /// <summary> + /// The type of a record in the delta (incremental) log + /// </summary> + public enum DeltaLogEntryType : int + { + /// <summary> + /// The entry is a delta record + /// </summary> + DELTA, + + /// <summary> + /// The entry is checkpoint metadata + /// </summary> + CHECKPOINT_METADATA + } + [StructLayout(LayoutKind.Explicit, Size = DeltaLog.HeaderSize)] struct DeltalogHeader { @@ -18,7 +34,7 @@ struct DeltalogHeader [FieldOffset(8)] public int Length; [FieldOffset(12)] - public int Type; + public DeltaLogEntryType Type; } /// <summary> @@ -177,14 +193,14 @@ private long Align(long length) /// <param name="entryLength"></param> /// <param name="type"></param> /// <returns></returns> - public unsafe bool GetNext(out long physicalAddress, out int entryLength, out int type) + public unsafe bool GetNext(out long physicalAddress, out int entryLength, out DeltaLogEntryType type) { while (true) { physicalAddress = 0; entryLength = 0; currentAddress = nextAddress; - type = 0; + type = DeltaLogEntryType.DELTA; var _currentPage = currentAddress >> LogPageSizeBits; var _currentFrame = _currentPage % frameSize; @@ -264,7 +280,7 @@ private unsafe static bool VerifyBlockChecksum(byte* ptr, int length) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private unsafe static void SetBlockHeader(int length, int type, byte* dest) + private unsafe static void SetBlockHeader(int length, DeltaLogEntryType type, byte* dest) { ref var header = ref GetHeader((long)dest); header.Length = length; @@ -301,7 +317,7 @@ public unsafe void Allocate(out 
int maxEntryLength, out long physicalAddress) /// </summary> /// <param name="entryLength">Entry length</param> /// <param name="type">Optional record type</param> - public unsafe void Seal(int entryLength, int type = 0) + public unsafe void Seal(int entryLength, DeltaLogEntryType type = DeltaLogEntryType.DELTA) { if (entryLength > 0) { diff --git a/cs/src/core/Index/Recovery/ICheckpointManager.cs b/cs/src/core/Index/Recovery/ICheckpointManager.cs index 685c9a500..a2ccfc506 100644 --- a/cs/src/core/Index/Recovery/ICheckpointManager.cs +++ b/cs/src/core/Index/Recovery/ICheckpointManager.cs @@ -79,8 +79,10 @@ public interface ICheckpointManager : IDisposable /// </summary> /// <param name="logToken">Token</param> /// <param name="deltaLog">Delta log</param> + /// <param name="scanDelta"> whether or not to scan through the delta log to acquire latest entry</param> + /// <param name="recoverTo"> version upper bound to scan for in the delta log. Function will return the largest version metadata no greater than the given version.</param> /// <returns>Metadata, or null if invalid</returns> - byte[] GetLogCheckpointMetadata(Guid logToken, DeltaLog deltaLog); + byte[] GetLogCheckpointMetadata(Guid logToken, DeltaLog deltaLog, bool scanDelta, long recoverTo); /// <summary> /// Get list of index checkpoint tokens, in order of usage preference @@ -94,7 +96,6 @@ public interface ICheckpointManager : IDisposable /// <returns></returns> public IEnumerable<Guid> GetLogCheckpointTokens(); - /// <summary> /// Provide device to store index checkpoint (including overflow buckets) /// </summary> @@ -123,6 +124,11 @@ public interface ICheckpointManager : IDisposable /// <returns></returns> IDevice GetDeltaLogDevice(Guid token); + /// <summary> + /// Cleanup all data (subfolder) related to the given guid by this manager + /// </summary> + public void Purge(Guid token); + /// <summary> /// Cleanup all data (subfolder) related to checkpoints by this manager /// </summary> diff --git 
a/cs/src/core/Index/Recovery/LocalCheckpointManager.cs b/cs/src/core/Index/Recovery/LocalCheckpointManager.cs index dc0db7142..079826f17 100644 --- a/cs/src/core/Index/Recovery/LocalCheckpointManager.cs +++ b/cs/src/core/Index/Recovery/LocalCheckpointManager.cs @@ -31,6 +31,14 @@ public void PurgeAll() { try { new DirectoryInfo(directoryConfiguration.checkpointDir).Delete(true); } catch { } } + + /// <inheritdoc /> + public void Purge(Guid token) + { + // Try both because we don't know which one + try { new DirectoryInfo(directoryConfiguration.GetHybridLogCheckpointFolder(token)).Delete(true); } catch { } + try { new DirectoryInfo(directoryConfiguration.GetIndexCheckpointFolder(token)).Delete(true); } catch { } + } /// <summary> /// Initialize index checkpoint @@ -114,26 +122,45 @@ public byte[] GetIndexCheckpointMetadata(Guid indexToken) /// </summary> /// <param name="logToken">Token</param> /// <param name="deltaLog">Delta log</param> + /// <param name="scanDelta"> whether or not to scan through the delta log to acquire latest entry</param> + /// <param name="recoverTo"> version upper bound to scan for in the delta log. 
Function will return the largest version metadata no greater than the given version.</param> /// <returns>Metadata, or null if invalid</returns> - public byte[] GetLogCheckpointMetadata(Guid logToken, DeltaLog deltaLog) + public byte[] GetLogCheckpointMetadata(Guid logToken, DeltaLog deltaLog, bool scanDelta, long recoverTo) { byte[] metadata = null; - if (deltaLog != null) + if (scanDelta && deltaLog != null) { // Get latest valid metadata from delta-log deltaLog.Reset(); - while (deltaLog.GetNext(out long physicalAddress, out int entryLength, out int type)) + while (deltaLog.GetNext(out long physicalAddress, out int entryLength, out DeltaLogEntryType type)) { - if (type != 1) continue; // consider only metadata records - long endAddress = physicalAddress + entryLength; - metadata = new byte[entryLength]; - unsafe + switch (type) { - fixed (byte* m = metadata) - { - Buffer.MemoryCopy((void*)physicalAddress, m, entryLength, entryLength); - } + case DeltaLogEntryType.DELTA: + // consider only metadata records + continue; + case DeltaLogEntryType.CHECKPOINT_METADATA: + metadata = new byte[entryLength]; + unsafe + { + fixed (byte* m = metadata) + { + Buffer.MemoryCopy((void*)physicalAddress, m, entryLength, entryLength); + } + } + HybridLogRecoveryInfo recoveryInfo = new(); + using (StreamReader s = new(new MemoryStream(metadata))) { + recoveryInfo.Initialize(s); + // Finish recovery if only specific versions are requested + if (recoveryInfo.version == recoverTo || recoveryInfo.version < recoverTo && recoveryInfo.nextVersion > recoverTo) goto LoopEnd; + } + continue; + default: + throw new FasterException("Unexpected entry type"); } + LoopEnd: + break; + } if (metadata != null) return metadata; } @@ -255,7 +282,7 @@ public unsafe void CommitLogIncrementalCheckpoint(Guid logToken, int version, by deltaLog.Allocate(out int length, out long physicalAddress); if (length < commitMetadata.Length) { - deltaLog.Seal(0, type: 1); + deltaLog.Seal(0, 
DeltaLogEntryType.CHECKPOINT_METADATA); deltaLog.Allocate(out length, out physicalAddress); if (length < commitMetadata.Length) { @@ -267,7 +294,7 @@ public unsafe void CommitLogIncrementalCheckpoint(Guid logToken, int version, by { Buffer.MemoryCopy(ptr, (void*)physicalAddress, commitMetadata.Length, commitMetadata.Length); } - deltaLog.Seal(commitMetadata.Length, type: 1); + deltaLog.Seal(commitMetadata.Length, DeltaLogEntryType.CHECKPOINT_METADATA); deltaLog.FlushAsync().Wait(); } diff --git a/cs/src/core/Index/Recovery/Recovery.cs b/cs/src/core/Index/Recovery/Recovery.cs index 7bc258fae..47a52d536 100644 --- a/cs/src/core/Index/Recovery/Recovery.cs +++ b/cs/src/core/Index/Recovery/Recovery.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.Diagnostics; using System.Runtime.CompilerServices; using System.Threading; @@ -17,7 +18,6 @@ internal enum FlushStatus { Pending, Done }; internal class RecoveryStatus { - public long startPage; public long endPage; public long snapshotEndPage; public long untilAddress; @@ -27,7 +27,6 @@ internal class RecoveryStatus public IDevice recoveryDevice; public long recoveryDevicePageOffset; public IDevice objectLogRecoveryDevice; - public IDevice deltaRecoveryDevice; // These are circular buffers of 'capacity' size; the indexing wraps due to hlog.GetPageIndexForPage(). 
public ReadStatus[] readStatus; @@ -37,11 +36,9 @@ internal class RecoveryStatus private readonly SemaphoreSlim flushSemaphore = new(0); public RecoveryStatus(int capacity, - long startPage, long endPage, long untilAddress, CheckpointType checkpointType) { this.capacity = capacity; - this.startPage = startPage; this.endPage = endPage; this.untilAddress = untilAddress; this.checkpointType = checkpointType; @@ -101,35 +98,59 @@ internal void Dispose() { recoveryDevice.Dispose(); objectLogRecoveryDevice.Dispose(); - deltaRecoveryDevice?.Dispose(); } } public partial class FasterKV<Key, Value> : FasterBase, IFasterKV<Key, Value> { - private void InternalRecoverFromLatestCheckpoints(int numPagesToPreload, bool undoNextVersion) + private void FindRecoveryInfo(long requestedVersion, out HybridLogCheckpointInfo recoveredHlcInfo, + out IndexCheckpointInfo recoveredICInfo) { - GetRecoveryInfoFromLatestCheckpoints(out HybridLogCheckpointInfo recoveredHLCInfo, out IndexCheckpointInfo recoveredICInfo); - InternalRecover(recoveredICInfo, recoveredHLCInfo, numPagesToPreload, undoNextVersion); - } - - private ValueTask InternalRecoverFromLatestCheckpointsAsync(int numPagesToPreload, bool undoNextVersion, CancellationToken cancellationToken) - { - GetRecoveryInfoFromLatestCheckpoints(out HybridLogCheckpointInfo recoveredHLCInfo, out IndexCheckpointInfo recoveredICInfo); - return InternalRecoverAsync(recoveredICInfo, recoveredHLCInfo, numPagesToPreload, undoNextVersion, cancellationToken); - } - private void GetRecoveryInfoFromLatestCheckpoints(out HybridLogCheckpointInfo recoveredHLCInfo, out IndexCheckpointInfo recoveredICInfo) - { Debug.WriteLine("********* Primary Recovery Information ********"); - recoveredHLCInfo = default; + HybridLogCheckpointInfo current, closest = default; + Guid closestToken = default; + long closestVersion = long.MaxValue; + byte[] cookie = default; + + // Traverse through all current tokens to find either the largest version or the version that's 
closest to + // but smaller than the requested version. Need to iterate through all unpruned versions because file system + // is not guaranteed to return tokens in order of freshness. foreach (var hybridLogToken in checkpointManager.GetLogCheckpointTokens()) { try { - recoveredHLCInfo = new HybridLogCheckpointInfo(); - recoveredHLCInfo.Recover(hybridLogToken, checkpointManager, hlog.LogPageSizeBits); + current = new HybridLogCheckpointInfo(); + current.Recover(hybridLogToken, checkpointManager, hlog.LogPageSizeBits, + out var currCookie, false); + var distanceToTarget = (requestedVersion == -1 ? long.MaxValue : requestedVersion) - current.info.version; + // This is larger than intended version, cannot recover to this. + if (distanceToTarget < 0) continue; + // We have found the exact version to recover to --- the above conditional establishes that the + // checkpointed version is <= requested version, and if next version is larger than requestedVersion, + // there cannot be any closer version. 
+ if (current.info.nextVersion > requestedVersion) + { + closest = current; + closestToken = hybridLogToken; + cookie = currCookie; + break; + } + + // Otherwise, write it down and wait to see if there's a closer one; + if (distanceToTarget < closestVersion) + { + closestVersion = distanceToTarget; + closest.Dispose(); + closest = current; + closestToken = hybridLogToken; + cookie = currCookie; + } + else + { + current.Dispose(); + } } catch { @@ -137,13 +158,20 @@ private void GetRecoveryInfoFromLatestCheckpoints(out HybridLogCheckpointInfo re } Debug.WriteLine("HybridLog Checkpoint: {0}", hybridLogToken); - break; } - if (recoveredHLCInfo.IsDefault()) - throw new FasterException("Unable to find valid index token"); + recoveredHlcInfo = closest; + recoveredCommitCookie = cookie; + if (recoveredHlcInfo.IsDefault()) + throw new FasterException("Unable to find valid HybridLog token"); - recoveredHLCInfo.info.DebugPrint(); + if (recoveredHlcInfo.deltaLog != null) + { + recoveredHlcInfo.Dispose(); + // need to actually scan delta log now + recoveredHlcInfo.Recover(closestToken, checkpointManager, hlog.LogPageSizeBits, out _, true); + } + recoveredHlcInfo.info.DebugPrint(); recoveredICInfo = default; foreach (var indexToken in checkpointManager.GetIndexCheckpointTokens()) @@ -159,7 +187,7 @@ private void GetRecoveryInfoFromLatestCheckpoints(out HybridLogCheckpointInfo re continue; } - if (!IsCompatible(recoveredICInfo.info, recoveredHLCInfo.info)) + if (!IsCompatible(recoveredICInfo.info, recoveredHlcInfo.info)) { recoveredICInfo = default; continue; @@ -183,16 +211,20 @@ private bool IsCompatible(in IndexRecoveryInfo indexInfo, in HybridLogRecoveryIn return l1 <= l2; } - private void InternalRecover(Guid indexToken, Guid hybridLogToken, int numPagesToPreload, bool undoNextVersion) + private void InternalRecover(Guid indexToken, Guid hybridLogToken, int numPagesToPreload, bool undoNextVersion, long recoverTo) { GetRecoveryInfo(indexToken, hybridLogToken, out 
HybridLogCheckpointInfo recoveredHLCInfo, out IndexCheckpointInfo recoveredICInfo); - InternalRecover(recoveredICInfo, recoveredHLCInfo, numPagesToPreload, undoNextVersion); + if (recoverTo != -1 && recoveredHLCInfo.deltaLog == null) + { + throw new FasterException("Recovering to a specific version within a token is only supported for incremental snapshots"); + } + InternalRecover(recoveredICInfo, recoveredHLCInfo, numPagesToPreload, undoNextVersion, recoverTo); } - private ValueTask InternalRecoverAsync(Guid indexToken, Guid hybridLogToken, int numPagesToPreload, bool undoNextVersion, CancellationToken cancellationToken) + private ValueTask InternalRecoverAsync(Guid indexToken, Guid hybridLogToken, int numPagesToPreload, bool undoNextVersion, long recoverTo, CancellationToken cancellationToken) { GetRecoveryInfo(indexToken, hybridLogToken, out HybridLogCheckpointInfo recoveredHLCInfo, out IndexCheckpointInfo recoveredICInfo); - return InternalRecoverAsync(recoveredICInfo, recoveredHLCInfo, numPagesToPreload, undoNextVersion, cancellationToken); + return InternalRecoverAsync(recoveredICInfo, recoveredHLCInfo, numPagesToPreload, undoNextVersion, recoverTo, cancellationToken); } private void GetRecoveryInfo(Guid indexToken, Guid hybridLogToken, out HybridLogCheckpointInfo recoveredHLCInfo, out IndexCheckpointInfo recoveredICInfo) @@ -204,7 +236,7 @@ private void GetRecoveryInfo(Guid indexToken, Guid hybridLogToken, out HybridLog // Recovery appropriate context information recoveredHLCInfo = new HybridLogCheckpointInfo(); - recoveredHLCInfo.Recover(hybridLogToken, checkpointManager, hlog.LogPageSizeBits); + recoveredHLCInfo.Recover(hybridLogToken, checkpointManager, hlog.LogPageSizeBits, out recoveredCommitCookie, true); recoveredHLCInfo.info.DebugPrint(); try { @@ -234,7 +266,7 @@ private void GetRecoveryInfo(Guid indexToken, Guid hybridLogToken, out HybridLog } } - private void InternalRecover(IndexCheckpointInfo recoveredICInfo, HybridLogCheckpointInfo 
recoveredHLCInfo, int numPagesToPreload, bool undoNextVersion) + private void InternalRecover(IndexCheckpointInfo recoveredICInfo, HybridLogCheckpointInfo recoveredHLCInfo, int numPagesToPreload, bool undoNextVersion, long recoverTo) { if (!RecoverToInitialPage(recoveredICInfo, recoveredHLCInfo, out long recoverFromAddress)) RecoverFuzzyIndex(recoveredICInfo); @@ -257,7 +289,7 @@ private void InternalRecover(IndexCheckpointInfo recoveredICInfo, HybridLogCheck // First recover from index starting point (fromAddress) to snapshot starting point (flushedLogicalAddress) RecoverHybridLog(scanFromAddress, recoverFromAddress, recoveredHLCInfo.info.flushedLogicalAddress, recoveredHLCInfo.info.nextVersion, CheckpointType.Snapshot, undoNextVersion); // Then recover snapshot into mutable region - RecoverHybridLogFromSnapshotFile(recoveredHLCInfo.info.flushedLogicalAddress, recoverFromAddress, recoveredHLCInfo.info.finalLogicalAddress, recoveredHLCInfo.info.startLogicalAddress, recoveredHLCInfo.info.snapshotFinalLogicalAddress, recoveredHLCInfo.info.nextVersion, recoveredHLCInfo.info.guid, undoNextVersion, recoveredHLCInfo.deltaLog); + RecoverHybridLogFromSnapshotFile(recoveredHLCInfo.info.flushedLogicalAddress, recoverFromAddress, recoveredHLCInfo.info.finalLogicalAddress, recoveredHLCInfo.info.startLogicalAddress, recoveredHLCInfo.info.snapshotFinalLogicalAddress, recoveredHLCInfo.info.nextVersion, recoveredHLCInfo.info.guid, undoNextVersion, recoveredHLCInfo.deltaLog, recoverTo); readOnlyAddress = recoveredHLCInfo.info.flushedLogicalAddress; } @@ -272,14 +304,11 @@ private void InternalRecover(IndexCheckpointInfo recoveredICInfo, HybridLogCheck // Recover session information hlog.RecoveryReset(tailAddress, headAddress, recoveredHLCInfo.info.beginAddress, readOnlyAddress); _recoveredSessions = recoveredHLCInfo.info.continueTokens; - - recoveredHLCInfo.deltaLog?.Dispose(); - recoveredHLCInfo.deltaFileDevice?.Dispose(); - checkpointManager.OnRecovery(recoveredICInfo.info.token, 
recoveredHLCInfo.info.guid); + recoveredHLCInfo.Dispose(); } - private async ValueTask InternalRecoverAsync(IndexCheckpointInfo recoveredICInfo, HybridLogCheckpointInfo recoveredHLCInfo, int numPagesToPreload, bool undoNextVersion, CancellationToken cancellationToken) + private async ValueTask InternalRecoverAsync(IndexCheckpointInfo recoveredICInfo, HybridLogCheckpointInfo recoveredHLCInfo, int numPagesToPreload, bool undoNextVersion, long recoverTo, CancellationToken cancellationToken) { if (!RecoverToInitialPage(recoveredICInfo, recoveredHLCInfo, out long recoverFromAddress)) await RecoverFuzzyIndexAsync(recoveredICInfo, cancellationToken).ConfigureAwait(false); @@ -303,7 +332,7 @@ private async ValueTask InternalRecoverAsync(IndexCheckpointInfo recoveredICInfo await RecoverHybridLogAsync (scanFromAddress, recoverFromAddress, recoveredHLCInfo.info.flushedLogicalAddress, recoveredHLCInfo.info.nextVersion, CheckpointType.Snapshot, undoNextVersion, cancellationToken).ConfigureAwait(false); // Then recover snapshot into mutable region await RecoverHybridLogFromSnapshotFileAsync(recoveredHLCInfo.info.flushedLogicalAddress, recoverFromAddress, recoveredHLCInfo.info.finalLogicalAddress, recoveredHLCInfo.info.startLogicalAddress, - recoveredHLCInfo.info.snapshotFinalLogicalAddress, recoveredHLCInfo.info.nextVersion, recoveredHLCInfo.info.guid, undoNextVersion, recoveredHLCInfo.deltaLog, cancellationToken).ConfigureAwait(false); + recoveredHLCInfo.info.snapshotFinalLogicalAddress, recoveredHLCInfo.info.nextVersion, recoveredHLCInfo.info.guid, undoNextVersion, recoveredHLCInfo.deltaLog, recoverTo, cancellationToken).ConfigureAwait(false); readOnlyAddress = recoveredHLCInfo.info.flushedLogicalAddress; } @@ -319,8 +348,7 @@ await RecoverHybridLogFromSnapshotFileAsync(recoveredHLCInfo.info.flushedLogical hlog.RecoveryReset(tailAddress, headAddress, recoveredHLCInfo.info.beginAddress, readOnlyAddress); _recoveredSessions = recoveredHLCInfo.info.continueTokens; - 
recoveredHLCInfo.deltaLog?.Dispose(); - recoveredHLCInfo.deltaFileDevice?.Dispose(); + recoveredHLCInfo.Dispose(); } /// <summary> @@ -336,8 +364,8 @@ private bool RecoverToInitialPage(IndexCheckpointInfo recoveredICInfo, HybridLog currentSyncStateMachine = null; // Set new system state after recovery - systemState.phase = Phase.REST; - systemState.version = recoveredHLCInfo.info.version + 1; + systemState.Phase = Phase.REST; + systemState.Version = recoveredHLCInfo.info.version + 1; if (!recoveredICInfo.IsDefault() && recoveryCountdown != null) { @@ -499,7 +527,7 @@ private RecoveryStatus GetPageRangesToRead(long scanFromAddress, long untilAddre capacity = hlog.GetCapacityNumPages(); int totalPagesToRead = (int)(endPage - startPage); numPagesToReadFirst = Math.Min(capacity, totalPagesToRead); - return new RecoveryStatus(capacity, startPage, endPage, untilAddress, checkpointType); + return new RecoveryStatus(capacity, endPage, untilAddress, checkpointType); } private bool ProcessReadPage(long recoverFromAddress, long untilAddress, int nextVersion, bool undoNextVersion, RecoveryStatus recoveryStatus, long page, int pageIndex) @@ -542,7 +570,7 @@ private async ValueTask WaitUntilAllPagesHaveBeenFlushedAsync(long startPage, lo await recoveryStatus.WaitFlushAsync(hlog.GetPageIndexForPage(page), cancellationToken).ConfigureAwait(false); } - private void RecoverHybridLogFromSnapshotFile(long scanFromAddress, long recoverFromAddress, long untilAddress, long snapshotStartAddress, long snapshotEndAddress, int nextVersion, Guid guid, bool undoNextVersion, DeltaLog deltaLog) + private void RecoverHybridLogFromSnapshotFile(long scanFromAddress, long recoverFromAddress, long untilAddress, long snapshotStartAddress, long snapshotEndAddress, int nextVersion, Guid guid, bool undoNextVersion, DeltaLog deltaLog, long recoverTo) { GetSnapshotPageRangesToRead(scanFromAddress, untilAddress, snapshotStartAddress, snapshotEndAddress, guid, out long startPage, out long endPage, out long 
snapshotEndPage, out int capacity, out var recoveryStatus, out int numPagesToReadFirst); @@ -573,7 +601,7 @@ private void RecoverHybridLogFromSnapshotFile(long scanFromAddress, long recover } // Apply delta - hlog.ApplyDelta(deltaLog, page, end); + hlog.ApplyDelta(deltaLog, page, end, recoverTo); for (long p = page; p < end; p++) { @@ -599,7 +627,7 @@ private void RecoverHybridLogFromSnapshotFile(long scanFromAddress, long recover recoveryStatus.Dispose(); } - private async ValueTask RecoverHybridLogFromSnapshotFileAsync(long scanFromAddress, long recoverFromAddress, long untilAddress, long snapshotStartAddress, long snapshotEndAddress, int nextVersion, Guid guid, bool undoNextVersion, DeltaLog deltaLog, CancellationToken cancellationToken) + private async ValueTask RecoverHybridLogFromSnapshotFileAsync(long scanFromAddress, long recoverFromAddress, long untilAddress, long snapshotStartAddress, long snapshotEndAddress, int nextVersion, Guid guid, bool undoNextVersion, DeltaLog deltaLog, long recoverTo, CancellationToken cancellationToken) { GetSnapshotPageRangesToRead(scanFromAddress, untilAddress, snapshotStartAddress, snapshotEndAddress, guid, out long startPage, out long endPage, out long snapshotEndPage, out int capacity, out var recoveryStatus, out int numPagesToReadFirst); @@ -630,7 +658,7 @@ private async ValueTask RecoverHybridLogFromSnapshotFileAsync(long scanFromAddre } // Apply delta - hlog.ApplyDelta(deltaLog, page, end); + hlog.ApplyDelta(deltaLog, page, end, recoverTo); for (long p = page; p < end; p++) { @@ -673,16 +701,13 @@ private void GetSnapshotPageRangesToRead(long fromAddress, long untilAddress, lo capacity = hlog.GetCapacityNumPages(); var recoveryDevice = checkpointManager.GetSnapshotLogDevice(guid); var objectLogRecoveryDevice = checkpointManager.GetSnapshotObjectLogDevice(guid); - var deltaRecoveryDevice = checkpointManager.GetDeltaLogDevice(guid); recoveryDevice.Initialize(hlog.GetSegmentSize()); objectLogRecoveryDevice.Initialize(-1); - 
deltaRecoveryDevice.Initialize(-1); - recoveryStatus = new RecoveryStatus(capacity, startPage, endPage, untilAddress, CheckpointType.Snapshot) + recoveryStatus = new RecoveryStatus(capacity, endPage, untilAddress, CheckpointType.Snapshot) { recoveryDevice = recoveryDevice, objectLogRecoveryDevice = objectLogRecoveryDevice, - deltaRecoveryDevice = deltaRecoveryDevice, recoveryDevicePageOffset = snapshotStartPage, snapshotEndPage = snapshotEndPage }; @@ -919,7 +944,7 @@ private bool RestoreHybridLogInitializePages(long beginAddress, long headAddress tailPage = GetPage(fromAddress); headPage = GetPage(headAddress); - recoveryStatus = new RecoveryStatus(GetCapacityNumPages(), headPage, tailPage, untilAddress, 0); + recoveryStatus = new RecoveryStatus(GetCapacityNumPages(), tailPage, untilAddress, 0); for (int i = 0; i < recoveryStatus.capacity; i++) { recoveryStatus.readStatus[i] = ReadStatus.Done; diff --git a/cs/src/core/Index/Synchronization/FasterStateMachine.cs b/cs/src/core/Index/Synchronization/FasterStateMachine.cs index 594c8e2ab..663b15204 100644 --- a/cs/src/core/Index/Synchronization/FasterStateMachine.cs +++ b/cs/src/core/Index/Synchronization/FasterStateMachine.cs @@ -18,9 +18,44 @@ public partial class FasterKV<Key, Value> // The current state machine in the system. The value could be stale and point to the previous state machine // if no state machine is active at this time. 
private ISynchronizationStateMachine currentSyncStateMachine; + private List<IStateMachineCallback> callbacks = new List<IStateMachineCallback>(); + internal long lastVersion; - internal SystemState SystemState => systemState; + /// <summary> + /// Any additional (user specified) metadata to write out with commit + /// </summary> + public byte[] CommitCookie { get; set; } + + private byte[] recoveredCommitCookie; + /// <summary> + /// User-specified commit cookie persisted with last recovered commit + /// </summary> + public byte[] RecoveredCommitCookie => recoveredCommitCookie; + /// <summary> + /// Get the current state machine state of the system + /// </summary> + public SystemState SystemState => systemState; + + /// <summary> + /// Version number of the last checkpointed state + /// </summary> + public long LastCheckpointedVersion => lastVersion; + + /// <summary> + /// Current version number of the store + /// </summary> + public long CurrentVersion => systemState.Version; + + /// <summary> + /// Registers the given callback to be invoked for every state machine transition. Not safe to call with + /// concurrent FASTER operations. Note that registered callbacks execute as part of the critical + /// section of FASTER's state transitions. Excessive synchronization or expensive computation in the callback + /// may slow or halt state machine execution. For advanced users only. + /// </summary> + /// <param name="callback"> callback to register </param> + public void UnsafeRegisterCallback(IStateMachineCallback callback) => callbacks.Add(callback); + /// <summary> /// Attempt to start the given state machine in the system if no other state machine is active. 
/// </summary> @@ -41,9 +76,9 @@ private bool StartStateMachine(ISynchronizationStateMachine stateMachine) [MethodImpl(MethodImplOptions.AggressiveInlining)] private bool MakeTransition(SystemState expectedState, SystemState nextState) { - if (Interlocked.CompareExchange(ref systemState.word, nextState.word, expectedState.word) != - expectedState.word) return false; - Debug.WriteLine("Moved to {0}, {1}", nextState.phase, nextState.version); + if (Interlocked.CompareExchange(ref systemState.Word, nextState.Word, expectedState.Word) != + expectedState.Word) return false; + Debug.WriteLine("Moved to {0}, {1}", nextState.Phase, nextState.Version); return true; } @@ -64,13 +99,17 @@ internal void GlobalStateMachineStep(SystemState expectedState) // Execute custom task logic currentSyncStateMachine.GlobalBeforeEnteringState(nextState, this); + // Execute any additional callbacks in critical section + foreach (var callback in callbacks) + callback.BeforeEnteringState(nextState, this); + var success = MakeTransition(intermediate, nextState); // Guaranteed to succeed, because other threads will always block while the system is in intermediate. Debug.Assert(success); currentSyncStateMachine.GlobalAfterEnteringState(nextState, this); // Mark the state machine done as we exit the state machine. - if (nextState.phase == Phase.REST) stateMachineActive = 0; + if (nextState.Phase == Phase.REST) stateMachineActive = 0; } @@ -78,9 +117,9 @@ internal void GlobalStateMachineStep(SystemState expectedState) [MethodImpl(MethodImplOptions.AggressiveInlining)] private SystemState StartOfCurrentCycle(SystemState currentGlobalState) { - return currentGlobalState.phase < Phase.REST - ? SystemState.Make(Phase.REST, currentGlobalState.version - 1) - : SystemState.Make(Phase.REST, currentGlobalState.version); + return currentGlobalState.Phase < Phase.REST + ? 
SystemState.Make(Phase.REST, currentGlobalState.Version - 1) + : SystemState.Make(Phase.REST, currentGlobalState.Version); } // Given the current thread state and global state, fast forward the thread state to the @@ -88,8 +127,8 @@ private SystemState StartOfCurrentCycle(SystemState currentGlobalState) [MethodImpl(MethodImplOptions.AggressiveInlining)] private SystemState FastForwardToCurrentCycle(SystemState threadState, SystemState targetStartState) { - if (threadState.version < targetStartState.version || - threadState.version == targetStartState.version && threadState.phase < targetStartState.phase) + if (threadState.Version < targetStartState.Version || + threadState.Version == targetStartState.Version && threadState.Phase < targetStartState.Phase) { return targetStartState; } @@ -109,7 +148,7 @@ internal bool SameCycle<Input, Output, Context>(FasterExecutionContext<Input, Ou { var _systemState = SystemState.Copy(ref systemState); SystemState.RemoveIntermediate(ref _systemState); - return StartOfCurrentCycle(threadState).version == StartOfCurrentCycle(_systemState).version; + return StartOfCurrentCycle(threadState).Version == StartOfCurrentCycle(_systemState).Version; } return ctx.threadStateMachine == currentSyncStateMachine; } @@ -150,7 +189,7 @@ private void ThreadStateMachineStep<Input, Output, Context, FasterSession>( #region Get returning thread to start of current cycle, issuing completion callbacks if needed if (ctx != null) { - if (ctx.version < targetStartState.version) + if (ctx.version < targetStartState.Version) { // Issue CPR callback for full session if (ctx.serialNum != -1) @@ -175,7 +214,7 @@ private void ThreadStateMachineStep<Input, Output, Context, FasterSession>( fasterSession?.CheckpointCompletionCallback(ctx.guid, commitPoint); } } - if ((ctx.version == targetStartState.version) && (ctx.phase < Phase.REST) && !(ctx.threadStateMachine is IndexSnapshotStateMachine)) + if ((ctx.version == targetStartState.Version) && (ctx.phase < 
Phase.REST) && !(ctx.threadStateMachine is IndexSnapshotStateMachine)) { IssueCompletionCallback(ctx, fasterSession); } @@ -184,12 +223,12 @@ private void ThreadStateMachineStep<Input, Output, Context, FasterSession>( // No state machine associated with target, or target is in REST phase: // we can directly fast forward session to target state - if (currentTask == null || targetState.phase == Phase.REST) + if (currentTask == null || targetState.Phase == Phase.REST) { if (ctx != null) { - ctx.phase = targetState.phase; - ctx.version = targetState.version; + ctx.phase = targetState.Phase; + ctx.version = targetState.Version; ctx.threadStateMachine = currentTask; } return; @@ -218,22 +257,22 @@ private void ThreadStateMachineStep<Input, Output, Context, FasterSession>( do { Debug.Assert( - (threadState.version < targetState.version) || - (threadState.version == targetState.version && - (threadState.phase <= targetState.phase || currentTask is IndexSnapshotStateMachine) + (threadState.Version < targetState.Version) || + (threadState.Version == targetState.Version && + (threadState.Phase <= targetState.Phase || currentTask is IndexSnapshotStateMachine) )); currentTask.OnThreadEnteringState(threadState, previousState, this, ctx, fasterSession, valueTasks, token); if (ctx != null) { - ctx.phase = threadState.phase; - ctx.version = threadState.version; + ctx.phase = threadState.Phase; + ctx.version = threadState.Version; } - previousState.word = threadState.word; + previousState.Word = threadState.Word; threadState = currentTask.NextState(threadState); - if (systemState.word != targetState.word) + if (systemState.Word != targetState.Word) { var tmp = SystemState.Copy(ref systemState); if (currentSyncStateMachine == currentTask) @@ -242,7 +281,7 @@ private void ThreadStateMachineStep<Input, Output, Context, FasterSession>( SystemState.RemoveIntermediate(ref targetState); } } - } while (previousState.word != targetState.word); + } while (previousState.Word != 
targetState.Word); #endregion return; diff --git a/cs/src/core/Index/Synchronization/FullCheckpointStateMachine.cs b/cs/src/core/Index/Synchronization/FullCheckpointStateMachine.cs index 9018655fb..942c7ad2e 100644 --- a/cs/src/core/Index/Synchronization/FullCheckpointStateMachine.cs +++ b/cs/src/core/Index/Synchronization/FullCheckpointStateMachine.cs @@ -16,7 +16,7 @@ public void GlobalBeforeEnteringState<Key, Value>( SystemState next, FasterKV<Key, Value> faster) { - switch (next.phase) + switch (next.Phase) { case Phase.PREP_INDEX_CHECKPOINT: Debug.Assert(faster._indexCheckpoint.IsDefault() && @@ -25,7 +25,7 @@ public void GlobalBeforeEnteringState<Key, Value>( faster._indexCheckpointToken = fullCheckpointToken; faster._hybridLogCheckpointToken = fullCheckpointToken; faster.InitializeIndexCheckpoint(faster._indexCheckpointToken); - faster.InitializeHybridLogCheckpoint(faster._hybridLogCheckpointToken, next.version); + faster.InitializeHybridLogCheckpoint(faster._hybridLogCheckpointToken, next.Version); break; case Phase.WAIT_FLUSH: faster._indexCheckpoint.info.num_buckets = faster.overflowBucketsAllocator.GetMaxValidAddress(); @@ -79,19 +79,19 @@ public FullCheckpointStateMachine(ISynchronizationTask checkpointBackend, long t public override SystemState NextState(SystemState start) { var result = SystemState.Copy(ref start); - switch (start.phase) + switch (start.Phase) { case Phase.REST: - result.phase = Phase.PREP_INDEX_CHECKPOINT; + result.Phase = Phase.PREP_INDEX_CHECKPOINT; break; case Phase.PREP_INDEX_CHECKPOINT: - result.phase = Phase.PREPARE; + result.Phase = Phase.PREPARE; break; case Phase.WAIT_PENDING: - result.phase = Phase.WAIT_INDEX_CHECKPOINT; + result.Phase = Phase.WAIT_INDEX_CHECKPOINT; break; case Phase.WAIT_INDEX_CHECKPOINT: - result.phase = Phase.WAIT_FLUSH; + result.Phase = Phase.WAIT_FLUSH; break; default: result = base.NextState(start); diff --git a/cs/src/core/Index/Synchronization/HybridLogCheckpointTask.cs 
b/cs/src/core/Index/Synchronization/HybridLogCheckpointTask.cs index 13a043633..8e50fee92 100644 --- a/cs/src/core/Index/Synchronization/HybridLogCheckpointTask.cs +++ b/cs/src/core/Index/Synchronization/HybridLogCheckpointTask.cs @@ -12,32 +12,35 @@ namespace FASTER.core /// </summary> internal abstract class HybridLogCheckpointOrchestrationTask : ISynchronizationTask { + private long lastVersion; /// <inheritdoc /> public virtual void GlobalBeforeEnteringState<Key, Value>(SystemState next, FasterKV<Key, Value> faster) { - switch (next.phase) + switch (next.Phase) { case Phase.PREPARE: + lastVersion = faster.systemState.Version; if (faster._hybridLogCheckpoint.IsDefault()) { faster._hybridLogCheckpointToken = Guid.NewGuid(); - faster.InitializeHybridLogCheckpoint(faster._hybridLogCheckpointToken, next.version); + faster.InitializeHybridLogCheckpoint(faster._hybridLogCheckpointToken, next.Version); } - faster._hybridLogCheckpoint.info.version = next.version; + faster._hybridLogCheckpoint.info.version = next.Version; faster.ObtainCurrentTailAddress(ref faster._hybridLogCheckpoint.info.startLogicalAddress); break; case Phase.WAIT_FLUSH: faster._hybridLogCheckpoint.info.headAddress = faster.hlog.HeadAddress; faster._hybridLogCheckpoint.info.beginAddress = faster.hlog.BeginAddress; - faster._hybridLogCheckpoint.info.nextVersion = next.version; + faster._hybridLogCheckpoint.info.nextVersion = next.Version; break; case Phase.PERSISTENCE_CALLBACK: CollectMetadata(next, faster); faster.WriteHybridLogMetaInfo(); + faster.lastVersion = lastVersion; break; case Phase.REST: - faster._hybridLogCheckpoint.Reset(); + faster._hybridLogCheckpoint.Dispose(); var nextTcs = new TaskCompletionSource<LinkedCheckpointInfo>(TaskCreationOptions.RunContinuationsAsynchronously); faster.checkpointTcs.SetResult(new LinkedCheckpointInfo { NextTask = nextTcs.Task }); faster.checkpointTcs = nextTcs; @@ -61,7 +64,7 @@ protected void CollectMetadata<Key, Value>(SystemState next, FasterKV<Key, Value 
lock (faster._activeSessions) // write dormant sessions to checkpoint foreach (var kvp in faster._activeSessions) - kvp.Value.AtomicSwitch(next.version - 1); + kvp.Value.AtomicSwitch(next.Version - 1); } /// <inheritdoc /> @@ -80,7 +83,7 @@ public virtual void OnThreadState<Key, Value, Input, Output, Context, FasterSess CancellationToken token = default) where FasterSession : IFasterSession { - if (current.phase != Phase.PERSISTENCE_CALLBACK) return; + if (current.Phase != Phase.PERSISTENCE_CALLBACK) return; if (ctx != null) { @@ -90,10 +93,10 @@ public virtual void OnThreadState<Key, Value, Input, Output, Context, FasterSess ctx.prevCtx.markers[EpochPhaseIdx.CheckpointCompletionCallback] = true; } - faster.epoch.Mark(EpochPhaseIdx.CheckpointCompletionCallback, current.version); + faster.epoch.Mark(EpochPhaseIdx.CheckpointCompletionCallback, current.Version); } - if (faster.epoch.CheckIsComplete(EpochPhaseIdx.CheckpointCompletionCallback, current.version)) + if (faster.epoch.CheckIsComplete(EpochPhaseIdx.CheckpointCompletionCallback, current.Version)) faster.GlobalStateMachineStep(current); } } @@ -111,13 +114,11 @@ public override void GlobalBeforeEnteringState<Key, Value>(SystemState next, { base.GlobalBeforeEnteringState(next, faster); - if (next.phase == Phase.PREPARE) + if (next.Phase == Phase.PREPARE) { - faster._lastSnapshotCheckpoint.deltaFileDevice?.Dispose(); - faster._lastSnapshotCheckpoint.deltaLog?.Dispose(); - faster._lastSnapshotCheckpoint = default; + faster._lastSnapshotCheckpoint.Dispose(); } - if (next.phase != Phase.WAIT_FLUSH) return; + if (next.Phase != Phase.WAIT_FLUSH) return; faster.hlog.ShiftReadOnlyToTail(out var tailAddress, out faster._hybridLogCheckpoint.flushedSemaphore); @@ -136,7 +137,7 @@ public override void OnThreadState<Key, Value, Input, Output, Context, FasterSes { base.OnThreadState(current, prev, faster, ctx, fasterSession, valueTasks, token); - if (current.phase != Phase.WAIT_FLUSH) return; + if (current.Phase != 
Phase.WAIT_FLUSH) return; if (ctx == null || !ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush]) { @@ -157,9 +158,9 @@ public override void OnThreadState<Key, Value, Input, Output, Context, FasterSes } if (ctx != null) - faster.epoch.Mark(EpochPhaseIdx.WaitFlush, current.version); + faster.epoch.Mark(EpochPhaseIdx.WaitFlush, current.Version); - if (faster.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.version)) + if (faster.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.Version)) faster.GlobalStateMachineStep(current); } } @@ -174,12 +175,10 @@ internal sealed class SnapshotCheckpointTask : HybridLogCheckpointOrchestrationT /// <inheritdoc /> public override void GlobalBeforeEnteringState<Key, Value>(SystemState next, FasterKV<Key, Value> faster) { - switch (next.phase) + switch (next.Phase) { case Phase.PREPARE: - faster._lastSnapshotCheckpoint.deltaFileDevice?.Dispose(); - faster._lastSnapshotCheckpoint.deltaLog?.Dispose(); - faster._lastSnapshotCheckpoint = default; + faster._lastSnapshotCheckpoint.Dispose(); base.GlobalBeforeEnteringState(next, faster); faster._hybridLogCheckpoint.info.startLogicalAddress = faster.hlog.FlushedUntilAddress; faster._hybridLogCheckpoint.info.useSnapshotFile = 1; @@ -220,7 +219,7 @@ public override void GlobalBeforeEnteringState<Key, Value>(SystemState next, Fas // update flushed-until address to the latest faster._hybridLogCheckpoint.info.flushedLogicalAddress = faster.hlog.FlushedUntilAddress; base.GlobalBeforeEnteringState(next, faster); - faster._lastSnapshotCheckpoint = faster._hybridLogCheckpoint; + faster._lastSnapshotCheckpoint = faster._hybridLogCheckpoint.Transfer(); break; default: base.GlobalBeforeEnteringState(next, faster); @@ -239,7 +238,7 @@ public override void OnThreadState<Key, Value, Input, Output, Context, FasterSes { base.OnThreadState(current, prev, faster, ctx, fasterSession, valueTasks, token); - if (current.phase != Phase.WAIT_FLUSH) return; + if (current.Phase != Phase.WAIT_FLUSH) return; if 
(ctx == null || !ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush]) { @@ -261,9 +260,9 @@ public override void OnThreadState<Key, Value, Input, Output, Context, FasterSes } if (ctx != null) - faster.epoch.Mark(EpochPhaseIdx.WaitFlush, current.version); + faster.epoch.Mark(EpochPhaseIdx.WaitFlush, current.Version); - if (faster.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.version)) + if (faster.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.Version)) faster.GlobalStateMachineStep(current); } } @@ -278,13 +277,13 @@ internal sealed class IncrementalSnapshotCheckpointTask : HybridLogCheckpointOrc /// <inheritdoc /> public override void GlobalBeforeEnteringState<Key, Value>(SystemState next, FasterKV<Key, Value> faster) { - switch (next.phase) + switch (next.Phase) { case Phase.PREPARE: faster._hybridLogCheckpoint = faster._lastSnapshotCheckpoint; base.GlobalBeforeEnteringState(next, faster); faster._hybridLogCheckpoint.info.startLogicalAddress = faster.hlog.FlushedUntilAddress; - faster._hybridLogCheckpoint.prevVersion = next.version; + faster._hybridLogCheckpoint.prevVersion = next.Version; break; case Phase.WAIT_FLUSH: base.GlobalBeforeEnteringState(next, faster); @@ -311,7 +310,8 @@ public override void GlobalBeforeEnteringState<Key, Value>(SystemState next, Fas CollectMetadata(next, faster); faster.WriteHybridLogIncrementalMetaInfo(faster._hybridLogCheckpoint.deltaLog); faster._hybridLogCheckpoint.info.deltaTailAddress = faster._hybridLogCheckpoint.deltaLog.TailAddress; - faster._lastSnapshotCheckpoint = faster._hybridLogCheckpoint; + faster._lastSnapshotCheckpoint = faster._hybridLogCheckpoint.Transfer(); + faster._hybridLogCheckpoint.Dispose(); break; } } @@ -327,7 +327,7 @@ public override void OnThreadState<Key, Value, Input, Output, Context, FasterSes { base.OnThreadState(current, prev, faster, ctx, fasterSession, valueTasks, token); - if (current.phase != Phase.WAIT_FLUSH) return; + if (current.Phase != Phase.WAIT_FLUSH) return; if (ctx == null 
|| !ctx.prevCtx.markers[EpochPhaseIdx.WaitFlush]) { @@ -349,9 +349,9 @@ public override void OnThreadState<Key, Value, Input, Output, Context, FasterSes } if (ctx != null) - faster.epoch.Mark(EpochPhaseIdx.WaitFlush, current.version); + faster.epoch.Mark(EpochPhaseIdx.WaitFlush, current.Version); - if (faster.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.version)) + if (faster.epoch.CheckIsComplete(EpochPhaseIdx.WaitFlush, current.Version)) faster.GlobalStateMachineStep(current); } } @@ -382,16 +382,16 @@ protected HybridLogCheckpointStateMachine(long targetVersion, params ISynchroniz public override SystemState NextState(SystemState start) { var result = SystemState.Copy(ref start); - switch (start.phase) + switch (start.Phase) { case Phase.WAIT_PENDING: - result.phase = Phase.WAIT_FLUSH; + result.Phase = Phase.WAIT_FLUSH; break; case Phase.WAIT_FLUSH: - result.phase = Phase.PERSISTENCE_CALLBACK; + result.Phase = Phase.PERSISTENCE_CALLBACK; break; case Phase.PERSISTENCE_CALLBACK: - result.phase = Phase.REST; + result.Phase = Phase.REST; break; default: result = base.NextState(start); diff --git a/cs/src/core/Index/Synchronization/IStateMachineCallback.cs b/cs/src/core/Index/Synchronization/IStateMachineCallback.cs new file mode 100644 index 000000000..70a2b84db --- /dev/null +++ b/cs/src/core/Index/Synchronization/IStateMachineCallback.cs @@ -0,0 +1,17 @@ +namespace FASTER.core +{ + /// <summary> + /// Encapsulates custom logic to be executed as part of FASTER's state machine logic + /// </summary> + public interface IStateMachineCallback + { + /// <summary> + /// Invoked immediately before every state transition. 
+ /// </summary> + /// <param name="next"> next system state </param> + /// <param name="faster"> reference to FASTER K-V </param> + /// <typeparam name="Key">Key Type</typeparam> + /// <typeparam name="Value">Value Type</typeparam> + void BeforeEnteringState<Key, Value>(SystemState next, FasterKV<Key, Value> faster); + } +} \ No newline at end of file diff --git a/cs/src/core/Index/Synchronization/IndexResizeStateMachine.cs b/cs/src/core/Index/Synchronization/IndexResizeStateMachine.cs index a43d0ae47..7b71d58f9 100644 --- a/cs/src/core/Index/Synchronization/IndexResizeStateMachine.cs +++ b/cs/src/core/Index/Synchronization/IndexResizeStateMachine.cs @@ -14,7 +14,7 @@ public void GlobalBeforeEnteringState<Key, Value>( SystemState next, FasterKV<Key, Value> faster) { - switch (next.phase) + switch (next.Phase) { case Phase.PREPARE_GROW: // nothing to do @@ -45,7 +45,7 @@ public void GlobalAfterEnteringState<Key, Value>( SystemState next, FasterKV<Key, Value> faster) { - switch (next.phase) + switch (next.Phase) { case Phase.PREPARE_GROW: faster.epoch.BumpCurrentEpoch(() => faster.GlobalStateMachineStep(next)); @@ -70,7 +70,7 @@ public void OnThreadState<Key, Value, Input, Output, Context, FasterSession>( CancellationToken token = default) where FasterSession : IFasterSession { - switch (current.phase) + switch (current.Phase) { case Phase.PREPARE_GROW: case Phase.IN_PROGRESS_GROW: @@ -96,16 +96,16 @@ public IndexResizeStateMachine() : base(new IndexResizeTask()) {} public override SystemState NextState(SystemState start) { var nextState = SystemState.Copy(ref start); - switch (start.phase) + switch (start.Phase) { case Phase.REST: - nextState.phase = Phase.PREPARE_GROW; + nextState.Phase = Phase.PREPARE_GROW; break; case Phase.PREPARE_GROW: - nextState.phase = Phase.IN_PROGRESS_GROW; + nextState.Phase = Phase.IN_PROGRESS_GROW; break; case Phase.IN_PROGRESS_GROW: - nextState.phase = Phase.REST; + nextState.Phase = Phase.REST; break; default: throw new 
FasterException("Invalid Enum Argument"); diff --git a/cs/src/core/Index/Synchronization/IndexSnapshotStateMachine.cs b/cs/src/core/Index/Synchronization/IndexSnapshotStateMachine.cs index e4796d5f9..bc9f351ea 100644 --- a/cs/src/core/Index/Synchronization/IndexSnapshotStateMachine.cs +++ b/cs/src/core/Index/Synchronization/IndexSnapshotStateMachine.cs @@ -16,7 +16,7 @@ public void GlobalBeforeEnteringState<Key, Value>( SystemState next, FasterKV<Key, Value> faster) { - switch (next.phase) + switch (next.Phase) { case Phase.PREP_INDEX_CHECKPOINT: if (faster._indexCheckpoint.IsDefault()) @@ -67,7 +67,7 @@ public void OnThreadState<Key, Value, Input, Output, Context, FasterSession>( CancellationToken token = default) where FasterSession : IFasterSession { - switch (current.phase) + switch (current.Phase) { case Phase.PREP_INDEX_CHECKPOINT: faster.GlobalStateMachineStep(current); @@ -109,16 +109,16 @@ public IndexSnapshotStateMachine() : base(new IndexSnapshotTask()) public override SystemState NextState(SystemState start) { var result = SystemState.Copy(ref start); - switch (start.phase) + switch (start.Phase) { case Phase.REST: - result.phase = Phase.PREP_INDEX_CHECKPOINT; + result.Phase = Phase.PREP_INDEX_CHECKPOINT; break; case Phase.PREP_INDEX_CHECKPOINT: - result.phase = Phase.WAIT_INDEX_ONLY_CHECKPOINT; + result.Phase = Phase.WAIT_INDEX_ONLY_CHECKPOINT; break; case Phase.WAIT_INDEX_ONLY_CHECKPOINT: - result.phase = Phase.REST; + result.Phase = Phase.REST; break; default: throw new FasterException("Invalid Enum Argument"); diff --git a/cs/src/core/Index/Synchronization/StateTransitions.cs b/cs/src/core/Index/Synchronization/StateTransitions.cs index 552331bb7..ace3af510 100644 --- a/cs/src/core/Index/Synchronization/StateTransitions.cs +++ b/cs/src/core/Index/Synchronization/StateTransitions.cs @@ -21,68 +21,160 @@ internal struct ResizeInfo public long word; } - internal enum Phase : int { - IN_PROGRESS, + /// <summary> + /// The current phase of a 
state-machine operation such as a checkpoint + /// </summary> + public enum Phase : int { + /// <summary>In-progress phase, entering (v+1) version</summary> + IN_PROGRESS, + + /// <summary>Wait-pending phase, waiting for pending (v) operations to complete</summary> WAIT_PENDING, + + /// <summary>Wait for an index checkpoint to finish</summary> WAIT_INDEX_CHECKPOINT, + + /// <summary>Wait for data flush to complete</summary> WAIT_FLUSH, + + /// <summary>After flush has completed, write metadata to persistent storage and issue user callbacks</summary> PERSISTENCE_CALLBACK, + + /// <summary>The default phase; no state-machine operation is operating</summary> REST, + + /// <summary>Prepare for an index checkpoint</summary> PREP_INDEX_CHECKPOINT, + + /// <summary>Wait for an index-only checkpoint to complete</summary> WAIT_INDEX_ONLY_CHECKPOINT, + + /// <summary>Prepare for a checkpoint, still in (v) version</summary> PREPARE, - PREPARE_GROW, - IN_PROGRESS_GROW, + + /// <summary>Prepare to resize the index</summary> + PREPARE_GROW, + + /// <summary>Index resizing is in progress</summary> + IN_PROGRESS_GROW, + + /// <summary>Internal intermediate state of state machine</summary> INTERMEDIATE = 16, }; + /// <summary> + /// The current state of a state-machine operation such as a checkpoint. 
+ /// </summary> [StructLayout(LayoutKind.Explicit, Size = 8)] - internal struct SystemState + public struct SystemState { + /// <summary> + /// The current <see cref="Phase"/> of the operation + /// </summary> [FieldOffset(0)] - public Phase phase; + public Phase Phase; + /// <summary> + /// The version of the database when this operation is complete + /// </summary> [FieldOffset(4)] - public int version; + public int Version; + /// <summary> + /// The word containing information in bitfields + /// </summary> [FieldOffset(0)] - public long word; + internal long Word; + /// <summary> + /// Copy the <paramref name="other"/> <see cref="SystemState"/> into this <see cref="SystemState"/> + /// </summary> [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static SystemState Copy(ref SystemState other) + internal static SystemState Copy(ref SystemState other) { var info = default(SystemState); - info.word = other.word; + info.Word = other.Word; return info; } + /// <summary> + /// Create a <see cref="SystemState"/> with the specified values + /// </summary> [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static SystemState Make(Phase status, int version) + internal static SystemState Make(Phase status, int version) { var info = default(SystemState); - info.phase = status; - info.version = version; + info.Phase = status; + info.Version = version; return info; } + /// <summary> + /// Create a copy of the passed <see cref="SystemState"/> that is marked with the <see cref="Phase.INTERMEDIATE"/> phase + /// </summary> [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static SystemState MakeIntermediate(SystemState state) - => Make(state.phase | Phase.INTERMEDIATE, state.version); + internal static SystemState MakeIntermediate(SystemState state) + => Make(state.Phase | Phase.INTERMEDIATE, state.Version); + /// <summary> + /// Create a copy of the passed <see cref="SystemState"/> that is not marked with the <see cref="Phase.INTERMEDIATE"/> phase + 
/// </summary> [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void RemoveIntermediate(ref SystemState state) + internal static void RemoveIntermediate(ref SystemState state) { - state.phase &= ~Phase.INTERMEDIATE; + state.Phase &= ~Phase.INTERMEDIATE; } - public static bool Equal(SystemState s1, SystemState s2) + /// <summary> + /// Compare two <see cref="SystemState"/>s for equality + /// </summary> + internal static bool Equal(SystemState s1, SystemState s2) { - return s1.word == s2.word; + return s1.Word == s2.Word; } + /// <inheritdoc/> public override string ToString() { - return $"[{phase},{version}]"; + return $"[{Phase},{Version}]"; + } + + /// <summary> + /// Compare the current <see cref="SystemState"/> to <paramref name="obj"/> for equality if obj is also a <see cref="SystemState"/> + /// </summary> + public override bool Equals(object obj) + { + return obj is SystemState other && Equals(other); + } + + /// <inheritdoc/> + public override int GetHashCode() + { + return Word.GetHashCode(); + } + + /// <summary> + /// Compare the current <see cref="SystemState"/> to <paramref name="other"/> for equality + /// </summary> + private bool Equals(SystemState other) + { + return Word == other.Word; + } + + /// <summary> + /// Equals + /// </summary> + public static bool operator ==(SystemState left, SystemState right) + { + return left.Equals(right); + } + + /// <summary> + /// Not Equals + /// </summary> + public static bool operator !=(SystemState left, SystemState right) + { + return !(left == right); } } } diff --git a/cs/src/core/Index/Synchronization/VersionChangeStateMachine.cs b/cs/src/core/Index/Synchronization/VersionChangeStateMachine.cs index f143321b5..9706d76a6 100644 --- a/cs/src/core/Index/Synchronization/VersionChangeStateMachine.cs +++ b/cs/src/core/Index/Synchronization/VersionChangeStateMachine.cs @@ -37,7 +37,7 @@ public void OnThreadState<Key, Value, Input, Output, Context, FasterSession>( CancellationToken token = 
default) where FasterSession : IFasterSession { - switch (current.phase) + switch (current.Phase) { case Phase.PREPARE: if (ctx != null) @@ -49,17 +49,17 @@ public void OnThreadState<Key, Value, Input, Output, Context, FasterSession>( ctx.markers[EpochPhaseIdx.Prepare] = true; } - faster.epoch.Mark(EpochPhaseIdx.Prepare, current.version); + faster.epoch.Mark(EpochPhaseIdx.Prepare, current.Version); } - if (faster.epoch.CheckIsComplete(EpochPhaseIdx.Prepare, current.version)) + if (faster.epoch.CheckIsComplete(EpochPhaseIdx.Prepare, current.Version)) faster.GlobalStateMachineStep(current); break; case Phase.IN_PROGRESS: if (ctx != null) { // Need to be very careful here as threadCtx is changing - var _ctx = prev.phase == Phase.IN_PROGRESS ? ctx.prevCtx : ctx; + var _ctx = prev.Phase == Phase.IN_PROGRESS ? ctx.prevCtx : ctx; var tokens = faster._hybridLogCheckpoint.info.checkpointTokens; if (!faster.SameCycle(ctx, current) || tokens == null) return; @@ -73,11 +73,11 @@ public void OnThreadState<Key, Value, Input, Output, Context, FasterSession>( ctx.prevCtx.markers[EpochPhaseIdx.InProgress] = true; } - faster.epoch.Mark(EpochPhaseIdx.InProgress, current.version); + faster.epoch.Mark(EpochPhaseIdx.InProgress, current.Version); } // Has to be prevCtx, not ctx - if (faster.epoch.CheckIsComplete(EpochPhaseIdx.InProgress, current.version)) + if (faster.epoch.CheckIsComplete(EpochPhaseIdx.InProgress, current.Version)) faster.GlobalStateMachineStep(current); break; case Phase.WAIT_PENDING: @@ -91,10 +91,10 @@ public void OnThreadState<Key, Value, Input, Output, Context, FasterSession>( break; } - faster.epoch.Mark(EpochPhaseIdx.WaitPending, current.version); + faster.epoch.Mark(EpochPhaseIdx.WaitPending, current.Version); } - if (faster.epoch.CheckIsComplete(EpochPhaseIdx.WaitPending, current.version)) + if (faster.epoch.CheckIsComplete(EpochPhaseIdx.WaitPending, current.Version)) faster.GlobalStateMachineStep(current); break; case Phase.REST: @@ -114,7 +114,7 @@ public 
void GlobalBeforeEnteringState<Key, Value>( SystemState next, FasterKV<Key, Value> faster) { - if (next.phase == Phase.REST) + if (next.Phase == Phase.REST) // Before leaving the checkpoint, make sure all previous versions are read-only. faster.hlog.ShiftReadOnlyToTail(out _, out _); } @@ -144,7 +144,7 @@ public void OnThreadState<Key, Value, Input, Output, Context, FasterSession>( /// </summary> internal class VersionChangeStateMachine : SynchronizationStateMachineBase { - private readonly long targetVersion; + private long targetVersion; /// <summary> /// Construct a new VersionChangeStateMachine with the given tasks. Does not load any tasks by default. @@ -166,23 +166,30 @@ protected VersionChangeStateMachine(long targetVersion = -1, params ISynchroniza public override SystemState NextState(SystemState start) { var nextState = SystemState.Copy(ref start); - switch (start.phase) + switch (start.Phase) { case Phase.REST: - nextState.phase = Phase.PREPARE; + nextState.Phase = Phase.PREPARE; break; case Phase.PREPARE: - nextState.phase = Phase.IN_PROGRESS; + nextState.Phase = Phase.IN_PROGRESS; + // 13 bits of 1s --- FASTER records only store 13 bits of version number, and we need to ensure that + // the next version is distinguishable from the last in those 13 bits. + var bitMask = (1L << 13) - 1; + // If they are not distinguishable, simply increment target version to resolve this + if (((targetVersion - start.Version) & bitMask) == 0) + targetVersion++; + // TODO: Move to long for system state as well. - SetToVersion(targetVersion == -1 ? start.version + 1 : targetVersion); - nextState.version = (int) ToVersion(); + SetToVersion(targetVersion == -1 ? 
start.Version + 1 : targetVersion); + nextState.Version = (int) ToVersion(); break; case Phase.IN_PROGRESS: // This phase has no effect if using relaxed CPR model - nextState.phase = Phase.WAIT_PENDING; + nextState.Phase = Phase.WAIT_PENDING; break; case Phase.WAIT_PENDING: - nextState.phase = Phase.REST; + nextState.Phase = Phase.REST; break; default: throw new FasterException("Invalid Enum Argument"); diff --git a/cs/src/core/VarLen/SpanByte.cs b/cs/src/core/VarLen/SpanByte.cs index 30f709379..d75150753 100644 --- a/cs/src/core/VarLen/SpanByte.cs +++ b/cs/src/core/VarLen/SpanByte.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -10,14 +11,19 @@ namespace FASTER.core { /// <summary> /// Represents a pinned variable length byte array that is viewable as a fixed (pinned) Span<byte> - /// Format: [4-byte (int) length of payload][payload bytes...] + /// Format: [4-byte (int) length of payload][[optional 8-byte metadata] payload bytes...] 
+ /// First 4 bits of length are used as a mask for various properties, so max length is 256MB /// </summary> [StructLayout(LayoutKind.Explicit)] public unsafe struct SpanByte { // Byte #30 and #31 are used for read-only and lock respectively - const int kTypeBitMask = 1 << 29; - const int kHeaderMask = 7 << 29; + // Byte #29 is used to denote unserialized (1) or serialized (0) data + const int kUnserializedBitMask = 1 << 29; + // Byte #28 is used to denote extra metadata present (1) or absent (0) in payload + const int kExtraMetadataBitMask = 1 << 28; + // Mask for header + const int kHeaderMask = 0xf << 28; /// <summary> /// Length of the payload @@ -34,7 +40,18 @@ public unsafe struct SpanByte internal IntPtr Pointer => payload; /// <summary> - /// Length of payload + /// Get payload pointer + /// </summary> + public byte* ToPointer() + { + if (Serialized) + return MetadataSize + (byte*)Unsafe.AsPointer(ref payload); + else + return MetadataSize + (byte*)payload; + } + + /// <summary> + /// Length of payload, including metadata if any /// </summary> public int Length { @@ -42,16 +59,25 @@ public int Length set { length = (length & kHeaderMask) | value; } } + /// <summary> + /// Length of payload, not including metadata if any + /// </summary> + public int LengthWithoutMetadata => (length & ~kHeaderMask) - MetadataSize; + /// <summary> /// Format of structure /// </summary> - public bool Serialized => (length & kTypeBitMask) == 0; + public bool Serialized => (length & kUnserializedBitMask) == 0; /// <summary> - /// Total serialized size in bytes, including header + /// Total serialized size in bytes, including header and metadata if any /// </summary> - public int TotalSize => Length + sizeof(int); + public int TotalSize => sizeof(int) + Length; + /// <summary> + /// Size of metadata header, if any (returns 0 or 8) + /// </summary> + public int MetadataSize => (length & kExtraMetadataBitMask) >> (28 - 3); /// <summary> /// Constructor @@ -60,20 +86,64 @@ public 
int Length /// <param name="payload"></param> public SpanByte(int length, IntPtr payload) { - this.length = length | kTypeBitMask; + Debug.Assert(length <= ~kHeaderMask); + this.length = length | kUnserializedBitMask; this.payload = payload; } + /// <summary> + /// Extra metadata header + /// </summary> + public long ExtraMetadata + { + get + { + if (Serialized) + return MetadataSize > 0 ? *(long*)Unsafe.AsPointer(ref payload) : 0; + else + return MetadataSize > 0 ? *(long*)payload : 0; + } + set + { + if (value > 0) + { + length |= kExtraMetadataBitMask; + Debug.Assert(Length > MetadataSize); + if (Serialized) + *(long*)Unsafe.AsPointer(ref payload) = value; + else + *(long*)payload = value; + } + } + } + + /// <summary> + /// Mark SpanByte as having 8-byte metadata in header of payload + /// </summary> + public void MarkExtraMetadata() + { + Debug.Assert(Length >= 8); + length |= kExtraMetadataBitMask; + } + + /// <summary> + /// Unmark SpanByte as having 8-byte metadata in header of payload + /// </summary> + public void UnmarkExtraMetadata() + { + length &= ~kExtraMetadataBitMask; + } + /// <summary> /// Check or set struct as invalid /// </summary> public bool Invalid { - get { return ((length & kTypeBitMask) != 0) && payload == IntPtr.Zero; } + get { return ((length & kUnserializedBitMask) != 0) && payload == IntPtr.Zero; } set { if (value) { - length |= kTypeBitMask; + length |= kUnserializedBitMask; payload = IntPtr.Zero; } else @@ -84,27 +154,51 @@ public bool Invalid } /// <summary> - /// Get Span<byte> for this SpanByte's payload + /// Get Span<byte> for this SpanByte's payload (excluding metadata if any) /// </summary> /// <returns></returns> public Span<byte> AsSpan() { if (Serialized) - return new Span<byte>(Unsafe.AsPointer(ref payload), Length); + return new Span<byte>(MetadataSize + (byte*)Unsafe.AsPointer(ref payload), Length - MetadataSize); else - return new Span<byte>((void*)payload, Length); + return new Span<byte>(MetadataSize + 
(byte*)payload, Length - MetadataSize); } /// <summary> - /// Get ReadOnlySpan<byte> for this SpanByte's payload + /// Get ReadOnlySpan<byte> for this SpanByte's payload (excluding metadata if any) /// </summary> /// <returns></returns> public ReadOnlySpan<byte> AsReadOnlySpan() { if (Serialized) - return new ReadOnlySpan<byte>(Unsafe.AsPointer(ref payload), Length); + return new ReadOnlySpan<byte>(MetadataSize + (byte*)Unsafe.AsPointer(ref payload), Length - MetadataSize); + else + return new ReadOnlySpan<byte>(MetadataSize + (byte*)payload, Length - MetadataSize); + } + + /// <summary> + /// Get Span<byte> for this SpanByte's payload (including metadata if any) + /// </summary> + /// <returns></returns> + public Span<byte> AsSpanWithMetadata() + { + if (Serialized) + return new Span<byte>((byte*)Unsafe.AsPointer(ref payload), Length); else - return new ReadOnlySpan<byte>((void*)payload, Length); + return new Span<byte>((byte*)payload, Length); + } + + /// <summary> + /// Get ReadOnlySpan<byte> for this SpanByte's payload (including metadata if any) + /// </summary> + /// <returns></returns> + public ReadOnlySpan<byte> AsReadOnlySpanWithMetadata() + { + if (Serialized) + return new ReadOnlySpan<byte>((byte*)Unsafe.AsPointer(ref payload), Length); + else + return new ReadOnlySpan<byte>((byte*)payload, Length); } /// <summary> @@ -115,7 +209,7 @@ public ReadOnlySpan<byte> AsReadOnlySpan() public SpanByte Deserialize() { if (!Serialized) return this; - return new SpanByte(Length, (IntPtr)Unsafe.AsPointer(ref payload)); + return new SpanByte(Length - MetadataSize, (IntPtr)(MetadataSize + (byte*)Unsafe.AsPointer(ref payload))); } /// <summary> @@ -125,6 +219,8 @@ public SpanByte Deserialize() /// <returns></returns> public static ref SpanByte Reinterpret(Span<byte> span) { + Debug.Assert(span.Length - sizeof(int) <= ~kHeaderMask); + fixed (byte* ptr = span) { *(int*)ptr = span.Length - sizeof(int); @@ -254,7 +350,8 @@ public bool TryCopyTo(ref SpanByte dst) /// <param 
name="dst"></param> public void CopyTo(ref SpanByte dst) { - AsReadOnlySpan().CopyTo(dst.AsSpan()); + dst.ExtraMetadata = ExtraMetadata; + AsReadOnlySpan().CopyTo(dst.AsSpan()); } /// <summary> @@ -315,7 +412,9 @@ public void CopyWithHeaderTo(ref SpanByteAndMemory dst, MemoryPool<byte> memoryP var span = dst.SpanByte.AsSpan(); fixed (byte* ptr = span) *(int*)ptr = Length; - AsReadOnlySpan().CopyTo(span.Slice(sizeof(int))); + dst.SpanByte.ExtraMetadata = ExtraMetadata; + + AsReadOnlySpan().CopyTo(span.Slice(sizeof(int) + MetadataSize)); return; } dst.ConvertToHeap(); @@ -325,7 +424,8 @@ public void CopyWithHeaderTo(ref SpanByteAndMemory dst, MemoryPool<byte> memoryP dst.Length = TotalSize; fixed (byte* ptr = dst.Memory.Memory.Span) *(int*)ptr = Length; - AsReadOnlySpan().CopyTo(dst.Memory.Memory.Span.Slice(sizeof(int))); + dst.SpanByte.ExtraMetadata = ExtraMetadata; + AsReadOnlySpan().CopyTo(dst.Memory.Memory.Span.Slice(sizeof(int) + MetadataSize)); } /// <summary> @@ -334,13 +434,14 @@ public void CopyWithHeaderTo(ref SpanByteAndMemory dst, MemoryPool<byte> memoryP /// <param name="destination"></param> public void CopyTo(byte* destination) { - *(int*)destination = Length; if (Serialized) { + *(int*)destination = length; Buffer.MemoryCopy(Unsafe.AsPointer(ref payload), destination + sizeof(int), Length, Length); } else { + *(int*)destination = length & ~kUnserializedBitMask; Buffer.MemoryCopy((void*)payload, destination + sizeof(int), Length, Length); } } diff --git a/cs/src/core/VarLen/SpanByteAdvancedFunctions.cs b/cs/src/core/VarLen/SpanByteAdvancedFunctions.cs new file mode 100644 index 000000000..ed7275eaa --- /dev/null +++ b/cs/src/core/VarLen/SpanByteAdvancedFunctions.cs @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +using System.Buffers; + +namespace FASTER.core +{ + /// <summary> + /// Callback functions for SpanByte key, value + /// </summary> + public class SpanByteAdvancedFunctions<Key, Output, Context> : AdvancedFunctionsBase<Key, SpanByte, SpanByte, Output, Context> + { + /// <summary> + /// Constructor + /// </summary> + /// <param name="locking"></param> + public SpanByteAdvancedFunctions(bool locking = false) : base(locking) { } + + /// <inheritdoc /> + public override void SingleWriter(ref Key key, ref SpanByte src, ref SpanByte dst) + { + src.CopyTo(ref dst); + } + + /// <inheritdoc /> + public override bool ConcurrentWriter(ref Key key, ref SpanByte src, ref SpanByte dst, ref RecordInfo recordInfo, long address) + { + if (locking) dst.SpinLock(); + + // We can write the source (src) data to the existing destination (dst) in-place, + // only if there is sufficient space + if (dst.Length < src.Length || dst.IsMarkedReadOnly()) + { + dst.MarkReadOnly(); + if (locking) dst.Unlock(); + return false; + } + + // Option 1: write the source data, leaving the destination size unchanged. You will need + // to manage the actual space used by the value if you stop here. + src.CopyTo(ref dst); + + // We can adjust the length header on the serialized log, if we wish. + // This method will also zero out the extra space to retain log scan correctness.
+ dst.ShrinkSerializedLength(src.Length); + + if (locking) dst.Unlock(); + return true; + } + + /// <inheritdoc/> + public override void InitialUpdater(ref Key key, ref SpanByte input, ref SpanByte value, ref Output output) + { + input.CopyTo(ref value); + } + + /// <inheritdoc/> + public override void CopyUpdater(ref Key key, ref SpanByte input, ref SpanByte oldValue, ref SpanByte newValue, ref Output output) + { + oldValue.CopyTo(ref newValue); + } + + /// <inheritdoc/> + public override bool InPlaceUpdater(ref Key key, ref SpanByte input, ref SpanByte value, ref Output output, ref RecordInfo recordInfo, long address) + { + // The default implementation of IPU simply writes input to destination, if there is space + return ConcurrentWriter(ref key, ref input, ref value, ref recordInfo, address); + } + } + + /// <summary> + /// Callback functions using SpanByteAndMemory output, for SpanByte key, value, input + /// </summary> + public class SpanByteAdvancedFunctions<Context> : SpanByteAdvancedFunctions<SpanByte, SpanByteAndMemory, Context> + { + readonly MemoryPool<byte> memoryPool; + + /// <summary> + /// Constructor + /// </summary> + /// <param name="memoryPool"></param> + /// <param name="locking"></param> + public SpanByteAdvancedFunctions(MemoryPool<byte> memoryPool = default, bool locking = false) : base(locking) + { + this.memoryPool = memoryPool ?? 
MemoryPool<byte>.Shared; + } + + /// <inheritdoc /> + public unsafe override void SingleReader(ref SpanByte key, ref SpanByte input, ref SpanByte value, ref SpanByteAndMemory dst, long address) + { + value.CopyTo(ref dst, memoryPool); + } + + /// <inheritdoc /> + public unsafe override void ConcurrentReader(ref SpanByte key, ref SpanByte input, ref SpanByte value, ref SpanByteAndMemory dst, ref RecordInfo recordInfo, long address) + { + value.CopyTo(ref dst, memoryPool); + } + + /// <inheritdoc /> + public override bool SupportsLocking => locking; + + /// <inheritdoc /> + public override void Lock(ref RecordInfo recordInfo, ref SpanByte key, ref SpanByte value, LockType lockType, ref long lockContext) + { + value.SpinLock(); + } + + /// <inheritdoc /> + public override bool Unlock(ref RecordInfo recordInfo, ref SpanByte key, ref SpanByte value, LockType lockType, long lockContext) + { + value.Unlock(); + return true; + } + } + + /// <summary> + /// Callback functions for SpanByte with byte[] output, for SpanByte key, value, input + /// </summary> + public class SpanByteAdvancedFunctions_ByteArrayOutput<Context> : SpanByteAdvancedFunctions<SpanByte, byte[], Context> + { + /// <summary> + /// Constructor + /// </summary> + /// <param name="locking"></param> + public SpanByteAdvancedFunctions_ByteArrayOutput(bool locking = false) : base(locking) { } + + /// <inheritdoc /> + public override void SingleReader(ref SpanByte key, ref SpanByte input, ref SpanByte value, ref byte[] dst, long address) + { + dst = value.ToByteArray(); + } + + /// <inheritdoc /> + public override void ConcurrentReader(ref SpanByte key, ref SpanByte input, ref SpanByte value, ref byte[] dst, ref RecordInfo recordInfo, long address) + { + dst = value.ToByteArray(); + } + + /// <inheritdoc /> + public override bool SupportsLocking => locking; + + /// <inheritdoc /> + public override void Lock(ref RecordInfo recordInfo, ref SpanByte key, ref SpanByte value, LockType lockType, ref long lockContext) 
+ { + value.SpinLock(); + } + + /// <inheritdoc /> + public override bool Unlock(ref RecordInfo recordInfo, ref SpanByte key, ref SpanByte value, LockType lockType, long lockContext) + { + value.Unlock(); + return true; + } + } +} diff --git a/cs/src/core/VarLen/SpanByteComparer.cs b/cs/src/core/VarLen/SpanByteComparer.cs index 27467ca1a..c441681ba 100644 --- a/cs/src/core/VarLen/SpanByteComparer.cs +++ b/cs/src/core/VarLen/SpanByteComparer.cs @@ -18,7 +18,7 @@ public unsafe long GetHashCode64(ref SpanByte spanByte) if (spanByte.Serialized) { byte* ptr = (byte*)Unsafe.AsPointer(ref spanByte); - return Utility.HashBytes(ptr + sizeof(int), *(int*)ptr); + return Utility.HashBytes(ptr + sizeof(int), spanByte.Length); } else { @@ -27,11 +27,12 @@ public unsafe long GetHashCode64(ref SpanByte spanByte) } } - /// <inheritdoc /> - [MethodImpl(MethodImplOptions.AggressiveInlining)] + /// <inheritdoc /> + [MethodImpl(MethodImplOptions.AggressiveInlining)] public unsafe bool Equals(ref SpanByte k1, ref SpanByte k2) { - return k1.AsReadOnlySpan().SequenceEqual(k2.AsReadOnlySpan()); + return k1.AsReadOnlySpanWithMetadata().SequenceEqual(k2.AsReadOnlySpanWithMetadata()) + && (k1.MetadataSize == k2.MetadataSize); } } } diff --git a/cs/src/devices/AzureStorageDevice/AzureStorageDevice.cs b/cs/src/devices/AzureStorageDevice/AzureStorageDevice.cs index 8a8be98a3..9954c3338 100644 --- a/cs/src/devices/AzureStorageDevice/AzureStorageDevice.cs +++ b/cs/src/devices/AzureStorageDevice/AzureStorageDevice.cs @@ -206,12 +206,14 @@ public override void RemoveSegmentAsync(int segment, AsyncCallback callback, IAs if (!this.BlobManager.CancellationToken.IsCancellationRequested) { var t = pageBlob.DeleteAsync(cancellationToken: this.BlobManager.CancellationToken); - t.GetAwaiter().GetResult(); // REVIEW: this method cannot avoid GetAwaiter - if (t.IsFaulted) + t.GetAwaiter().OnCompleted(() => // REVIEW: this method cannot avoid GetAwaiter { - 
this.BlobManager?.HandleBlobError(nameof(RemoveSegmentAsync), "could not remove page blob for segment", pageBlob?.Name, t.Exception, false); - } - callback(result); + if (t.IsFaulted) + { + this.BlobManager?.HandleBlobError(nameof(RemoveSegmentAsync), "could not remove page blob for segment", pageBlob?.Name, t.Exception, false); + } + callback(result); + }); } } } @@ -363,10 +365,10 @@ public override unsafe void ReadAsync(int segmentId, ulong sourceAddress, IntPtr throw exception; } - this.ReadFromBlobUnsafeAsync(blobEntry.PageBlob, (long)sourceAddress, (long)destinationAddress, readLength) - .ContinueWith((Task t) => - { - if (t.IsFaulted) + var t = this.ReadFromBlobUnsafeAsync(blobEntry.PageBlob, (long)sourceAddress, (long)destinationAddress, readLength); + t.GetAwaiter().OnCompleted(() => // REVIEW: this method cannot avoid GetAwaiter + { + if (t.IsFaulted) { Debug.WriteLine("AzureStorageDevice.ReadAsync Returned (Failure)"); callback(uint.MaxValue, readLength, context); @@ -400,7 +402,7 @@ public override void WriteAsync(IntPtr sourceAddress, int segmentId, ulong desti // If no blob exists for the segment, we must first create the segment asynchronouly. (Create call takes ~70 ms by measurement) // After creation is done, we can call write. - var _ = entry.CreateAsync(size, pageBlob); + entry.CreateAsync(size, pageBlob).GetAwaiter().GetResult(); // REVIEW: this method cannot avoid GetAwaiter } // Otherwise, some other thread beat us to it. Okay to use their blobs. 
blobEntry = blobs[segmentId]; @@ -422,20 +424,22 @@ private void TryWriteAsync(BlobEntry blobEntry, IntPtr sourceAddress, ulong dest private unsafe void WriteToBlobAsync(CloudPageBlob blob, IntPtr sourceAddress, ulong destinationAddress, uint numBytesToWrite, DeviceIOCompletionCallback callback, object context) { - this.WriteToBlobAsync(blob, sourceAddress, (long)destinationAddress, numBytesToWrite) - .ContinueWith((Task t) => + Debug.WriteLine($"AzureStorageDevice.WriteToBlobAsync Called target={blob.Name}"); + + var t = this.WriteToBlobAsync(blob, sourceAddress, (long)destinationAddress, numBytesToWrite); + t.GetAwaiter().OnCompleted(() => // REVIEW: this method cannot avoid GetAwaiter + { + if (t.IsFaulted) { - if (t.IsFaulted) - { - Debug.WriteLine("AzureStorageDevice.WriteAsync Returned (Failure)"); - callback(uint.MaxValue, numBytesToWrite, context); - } - else - { - Debug.WriteLine("AzureStorageDevice.WriteAsync Returned"); - callback(0, numBytesToWrite, context); - } - }); + Debug.WriteLine("AzureStorageDevice.WriteAsync Returned (Failure)"); + callback(uint.MaxValue, numBytesToWrite, context); + } + else + { + Debug.WriteLine("AzureStorageDevice.WriteAsync Returned"); + callback(0, numBytesToWrite, context); + } + }); } private async Task WriteToBlobAsync(CloudPageBlob blob, IntPtr sourceAddress, long destinationAddress, uint numBytesToWrite) diff --git a/cs/test/AsyncLargeObjectTests.cs b/cs/test/AsyncLargeObjectTests.cs index 84cd84fca..92bca0b3c 100644 --- a/cs/test/AsyncLargeObjectTests.cs +++ b/cs/test/AsyncLargeObjectTests.cs @@ -4,12 +4,10 @@ using System; using System.Threading.Tasks; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test.async { - [TestFixture] internal class LargeObjectTests { @@ -17,17 +15,13 @@ internal class LargeObjectTests private FasterKV<MyKey, MyLargeValue> fht2; private IDevice log, objlog; private string test_path; - private readonly MyLargeFunctions functions = new MyLargeFunctions(); + 
private readonly MyLargeFunctions functions = new(); [SetUp] public void Setup() { - if (test_path == null) - { - test_path = TestContext.CurrentContext.TestDirectory + "/" + Path.GetRandomFileName(); - if (!Directory.Exists(test_path)) - Directory.CreateDirectory(test_path); - } + test_path = TestUtils.MethodTestDir; + TestUtils.RecreateDirectory(test_path); } [TearDown] @@ -36,19 +30,17 @@ public void TearDown() TestUtils.DeleteDirectory(test_path); } - [TestCase(CheckpointType.FoldOver)] - [TestCase(CheckpointType.Snapshot)] + [Test] [Category("FasterKV")] - public async Task LargeObjectTest(CheckpointType checkpointType) + public async Task LargeObjectTest([Values]CheckpointType checkpointType) { MyInput input = default; - MyLargeOutput output = new MyLargeOutput(); + MyLargeOutput output = new (); log = Devices.CreateLogDevice(test_path + "/LargeObjectTest.log"); objlog = Devices.CreateLogDevice(test_path + "/LargeObjectTest.obj.log"); - fht1 = new FasterKV<MyKey, MyLargeValue> - (128, + fht1 = new (128, new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, PageSizeBits = 21, MemorySizeBits = 26 }, new CheckpointSettings { CheckpointDir = test_path, CheckPointType = checkpointType }, new SerializerSettings<MyKey, MyLargeValue> { keySerializer = () => new MyKeySerializer(), valueSerializer = () => new MyLargeValueSerializer() } @@ -78,8 +70,7 @@ public async Task LargeObjectTest(CheckpointType checkpointType) log = Devices.CreateLogDevice(test_path + "/LargeObjectTest.log"); objlog = Devices.CreateLogDevice(test_path + "/LargeObjectTest.obj.log"); - fht2 = new FasterKV<MyKey, MyLargeValue> - (128, + fht2 = new(128, new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, PageSizeBits = 21, MemorySizeBits = 26 }, new CheckpointSettings { CheckpointDir = test_path, CheckPointType = checkpointType }, new SerializerSettings<MyKey, MyLargeValue> { keySerializer = () => new MyKeySerializer(), valueSerializer = 
() => new MyLargeValueSerializer() } @@ -99,7 +90,7 @@ public async Task LargeObjectTest(CheckpointType checkpointType) { for (int i = 0; i < output.value.value.Length; i++) { - Assert.IsTrue(output.value.value[i] == (byte)(output.value.value.Length + i)); + Assert.AreEqual((byte)(output.value.value.Length + i), output.value.value[i]); } } } diff --git a/cs/test/AsyncTests.cs b/cs/test/AsyncTests.cs index 8244bfc71..05b60013f 100644 --- a/cs/test/AsyncTests.cs +++ b/cs/test/AsyncTests.cs @@ -11,9 +11,8 @@ namespace FASTER.test.async { - [TestFixture] - public class RecoveryTests + public class AsyncRecoveryTests { private FasterKV<AdId, NumClicks> fht1; private FasterKV<AdId, NumClicks> fht2; @@ -23,23 +22,27 @@ public class RecoveryTests [TestCase(CheckpointType.FoldOver)] [TestCase(CheckpointType.Snapshot)] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] + [Category("Smoke")] + public async Task AsyncRecoveryTest1(CheckpointType checkpointType) { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/SimpleRecoveryTest2.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait:true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/AsyncRecoveryTest1.log", deleteOnClose: true); - Directory.CreateDirectory(TestContext.CurrentContext.TestDirectory + "/checkpoints4"); + string testPath = TestUtils.MethodTestDir + "/checkpoints4"; + Directory.CreateDirectory(testPath); fht1 = new FasterKV<AdId, NumClicks> (128, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, PageSizeBits = 10, MemorySizeBits = 13 }, - checkpointSettings: new CheckpointSettings { CheckpointDir = TestContext.CurrentContext.TestDirectory + "/checkpoints4", CheckPointType = checkpointType } + checkpointSettings: new CheckpointSettings { CheckpointDir = testPath, CheckPointType = checkpointType } ); fht2 = new FasterKV<AdId, NumClicks> (128, logSettings: new LogSettings { LogDevice = log, 
MutableFraction = 0.1, PageSizeBits = 10, MemorySizeBits = 13 }, - checkpointSettings: new CheckpointSettings { CheckpointDir = TestContext.CurrentContext.TestDirectory + "/checkpoints4", CheckPointType = checkpointType } + checkpointSettings: new CheckpointSettings { CheckpointDir = testPath, CheckPointType = checkpointType } ); int numOps = 5000; @@ -88,7 +91,7 @@ public async Task AsyncRecoveryTest1(CheckpointType checkpointType) var guid = s1.ID; using (var s3 = fht2.For(functions).ResumeSession<AdSimpleFunctions>(guid, out CommitPoint lsn)) { - Assert.IsTrue(lsn.UntilSerialNo == numOps - 1); + Assert.AreEqual(numOps - 1, lsn.UntilSerialNo); for (int key = 0; key < numOps; key++) { @@ -98,14 +101,14 @@ public async Task AsyncRecoveryTest1(CheckpointType checkpointType) s3.CompletePending(true,true); else { - Assert.IsTrue(output.value.numClicks == key); + Assert.AreEqual(key, output.value.numClicks); } } } fht2.Dispose(); log.Dispose(); - new DirectoryInfo(TestContext.CurrentContext.TestDirectory + "/checkpoints4").Delete(true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } } @@ -113,8 +116,8 @@ public class AdSimpleFunctions : FunctionsBase<AdId, NumClicks, AdInput, Output, { public override void ReadCompletionCallback(ref AdId key, ref AdInput input, ref Output output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.numClicks == key.adId); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key.adId, output.value.numClicks); } // Read functions diff --git a/cs/test/BasicDiskFASTERTests.cs b/cs/test/BasicDiskFASTERTests.cs index 37705a98b..0e1ff6505 100644 --- a/cs/test/BasicDiskFASTERTests.cs +++ b/cs/test/BasicDiskFASTERTests.cs @@ -2,16 +2,9 @@ // Licensed under the MIT license. 
using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; using FASTER.devices; -using System.Diagnostics; namespace FASTER.test { @@ -19,48 +12,49 @@ namespace FASTER.test internal class BasicStorageFASTERTests { private FasterKV<KeyStruct, ValueStruct> fht; - public const string EMULATED_STORAGE_STRING = "UseDevelopmentStorage=true;"; - public const string TEST_CONTAINER = "test"; [Test] [Category("FasterKV")] public void LocalStorageWriteRead() { - TestDeviceWriteRead(Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BasicDiskFASTERTests.log", deleteOnClose: true)); + TestDeviceWriteRead(Devices.CreateLogDevice(TestUtils.MethodTestDir + "/BasicDiskFASTERTests.log", deleteOnClose: true)); } [Test] [Category("FasterKV")] + [Category("Smoke")] public void PageBlobWriteRead() { - if ("yes".Equals(Environment.GetEnvironmentVariable("RunAzureTests"))) - TestDeviceWriteRead(new AzureStorageDevice(EMULATED_STORAGE_STRING, TEST_CONTAINER, "PageBlobWriteRead", "BasicDiskFASTERTests")); + TestUtils.IgnoreIfNotRunningAzureTests(); + TestDeviceWriteRead(new AzureStorageDevice(TestUtils.AzureEmulatedStorageString, TestUtils.AzureTestContainer, TestUtils.AzureTestDirectory, "BasicDiskFASTERTests")); } [Test] [Category("FasterKV")] + [Category("Smoke")] public void PageBlobWriteReadWithLease() { - if ("yes".Equals(Environment.GetEnvironmentVariable("RunAzureTests"))) - TestDeviceWriteRead(new AzureStorageDevice(EMULATED_STORAGE_STRING, TEST_CONTAINER, "PageBlobWriteRead", "BasicDiskFASTERTests",null,true,true)); + TestUtils.IgnoreIfNotRunningAzureTests(); + TestDeviceWriteRead(new AzureStorageDevice(TestUtils.AzureEmulatedStorageString, TestUtils.AzureTestContainer, TestUtils.AzureTestDirectory, "BasicDiskFASTERTests",null,true,true)); } - [Test] [Category("FasterKV")] + [Category("Smoke")] public void 
TieredWriteRead() { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); IDevice tested; - IDevice localDevice = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BasicDiskFASTERTests.log", deleteOnClose: true, capacity: 1 << 30); - if ("yes".Equals(Environment.GetEnvironmentVariable("RunAzureTests"))) + IDevice localDevice = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/BasicDiskFASTERTests.log", deleteOnClose: true, capacity: 1 << 30); + if (TestUtils.IsRunningAzureTests) { - IDevice cloudDevice = new AzureStorageDevice(EMULATED_STORAGE_STRING, TEST_CONTAINER, "TieredWriteRead", "BasicDiskFASTERTests"); + IDevice cloudDevice = new AzureStorageDevice(TestUtils.AzureEmulatedStorageString, TestUtils.AzureTestContainer, TestUtils.AzureTestDirectory, "BasicDiskFASTERTests"); tested = new TieredStorageDevice(1, localDevice, cloudDevice); } else { // If no Azure is enabled, just use another disk - IDevice localDevice2 = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BasicDiskFASTERTests2.log", deleteOnClose: true, capacity: 1 << 30); + IDevice localDevice2 = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/BasicDiskFASTERTests2.log", deleteOnClose: true, capacity: 1 << 30); tested = new TieredStorageDevice(1, localDevice, localDevice2); } @@ -69,10 +63,11 @@ public void TieredWriteRead() [Test] [Category("FasterKV")] + [Category("Smoke")] public void ShardedWriteRead() { - IDevice localDevice1 = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BasicDiskFASTERTests1.log", deleteOnClose: true, capacity: 1 << 30); - IDevice localDevice2 = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BasicDiskFASTERTests2.log", deleteOnClose: true, capacity: 1 << 30); + IDevice localDevice1 = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/BasicDiskFASTERTests1.log", deleteOnClose: true, capacity: 1 << 30); + IDevice localDevice2 = Devices.CreateLogDevice(TestUtils.MethodTestDir + 
"/BasicDiskFASTERTests2.log", deleteOnClose: true, capacity: 1 << 30); var device = new ShardedStorageDevice(new UniformPartitionScheme(512, localDevice1, localDevice2)); TestDeviceWriteRead(device); } @@ -119,13 +114,13 @@ void TestDeviceWriteRead(IDevice log) { if (i < 100) { - Assert.IsTrue(output.value.vfield1 == value.vfield1 + 1); - Assert.IsTrue(output.value.vfield2 == value.vfield2 + 1); + Assert.AreEqual(value.vfield1 + 1, output.value.vfield1); + Assert.AreEqual(value.vfield2 + 1, output.value.vfield2); } else { - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } } } @@ -134,6 +129,7 @@ void TestDeviceWriteRead(IDevice log) fht.Dispose(); fht = null; log.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } } } diff --git a/cs/test/BasicFASTERTests.cs b/cs/test/BasicFASTERTests.cs index b9cd3f12c..3c1eadb8e 100644 --- a/cs/test/BasicFASTERTests.cs +++ b/cs/test/BasicFASTERTests.cs @@ -8,7 +8,6 @@ namespace FASTER.test { - //** NOTE - more detailed / in depth Read tests in ReadAddressTests.cs //** These tests ensure the basics are fully covered @@ -18,29 +17,63 @@ internal class BasicFASTERTests private FasterKV<KeyStruct, ValueStruct> fht; private ClientSession<KeyStruct, ValueStruct, InputStruct, OutputStruct, Empty, Functions> session; private IDevice log; + private string path; + TestUtils.DeviceType deviceType; [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BasicFasterTests.log", deleteOnClose: true); - fht = new FasterKV<KeyStruct, ValueStruct> - (128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); - session = fht.For(new Functions()).NewSession<Functions>(); + path = TestUtils.MethodTestDir + "/"; + + // Clean up log files from previous test runs in case they weren't cleaned up + 
TestUtils.DeleteDirectory(path, wait: true); } [TearDown] public void TearDown() { - session.Dispose(); - fht.Dispose(); + session?.Dispose(); + session = null; + fht?.Dispose(); fht = null; - log.Dispose(); + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(path); + } + + private void AssertCompleted(Status expected, Status actual) + { + if (actual == Status.PENDING) + (actual, _) = CompletePendingResult(); + Assert.AreEqual(expected, actual); + } + + private (Status status, OutputStruct output) CompletePendingResult() + { + session.CompletePendingWithOutputs(out var completedOutputs); + return TestUtils.GetSinglePendingResult(completedOutputs); + } + + private static (Status status, OutputStruct output) CompletePendingResult(CompletedOutputIterator<KeyStruct, ValueStruct, InputStruct, OutputStruct, Empty> completedOutputs) + { + Assert.IsTrue(completedOutputs.Next()); + var result = (completedOutputs.Current.Status, completedOutputs.Current.Output); + Assert.IsFalse(completedOutputs.Next()); + completedOutputs.Dispose(); + return result; } [Test] [Category("FasterKV")] - public void NativeInMemWriteRead() + [Category("Smoke")] + public void NativeInMemWriteRead([Values] TestUtils.DeviceType deviceType) { + string filename = path + "NativeInMemWriteRead" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, PageSizeBits = 10, MemorySizeBits = 12, SegmentSizeBits = 22 }); + session = fht.For(new Functions()).NewSession<Functions>(); + InputStruct input = default; OutputStruct output = default; @@ -50,23 +83,22 @@ public void NativeInMemWriteRead() session.Upsert(ref key1, ref value, Empty.Default, 0); var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } - - 
Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + AssertCompleted(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } [Test] [Category("FasterKV")] - public void NativeInMemWriteReadDelete() + [Category("Smoke")] + public void NativeInMemWriteReadDelete([Values] TestUtils.DeviceType deviceType) { + string filename = path + "NativeInMemWriteReadDelete" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, PageSizeBits = 10, MemorySizeBits = 12, SegmentSizeBits = 22 }); + session = fht.For(new Functions()).NewSession<Functions>(); + InputStruct input = default; OutputStruct output = default; @@ -75,28 +107,12 @@ public void NativeInMemWriteReadDelete() session.Upsert(ref key1, ref value, Empty.Default, 0); var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } + AssertCompleted(Status.OK, status); session.Delete(ref key1, Empty.Default, 0); status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.NOTFOUND); - } + AssertCompleted(Status.NOTFOUND, status); var key2 = new KeyStruct { kfield1 = 14, kfield2 = 15 }; var value2 = new ValueStruct { vfield1 = 24, vfield2 = 25 }; @@ -104,26 +120,29 @@ public void NativeInMemWriteReadDelete() session.Upsert(ref key2, ref value2, Empty.Default, 0); status = session.Read(ref key2, ref input, ref output, Empty.Default, 0); - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } - - Assert.IsTrue(output.value.vfield1 
== value2.vfield1); - Assert.IsTrue(output.value.vfield2 == value2.vfield2); + AssertCompleted(Status.OK, status); + Assert.AreEqual(value2.vfield1, output.value.vfield1); + Assert.AreEqual(value2.vfield2, output.value.vfield2); } [Test] [Category("FasterKV")] + [Category("Smoke")] public void NativeInMemWriteReadDelete2() { + // Just set this one since Write Read Delete already does all four devices + deviceType = TestUtils.DeviceType.MLSD; + const int count = 10; + string filename = path + "NativeInMemWriteReadDelete2" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + // (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + (128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); + session = fht.For(new Functions()).NewSession<Functions>(); + InputStruct input = default; OutputStruct output = default; @@ -147,15 +166,7 @@ public void NativeInMemWriteReadDelete2() var value = new ValueStruct { vfield1 = i, vfield2 = 24 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.NOTFOUND); - } + AssertCompleted(Status.NOTFOUND, status); session.Upsert(ref key1, ref value, Empty.Default, 0); } @@ -163,28 +174,32 @@ public void NativeInMemWriteReadDelete2() for (int i = 0; i < 10 * count; i++) { var key1 = new KeyStruct { kfield1 = i, kfield2 = 14 }; - var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } + AssertCompleted(Status.OK, status); } } [Test] [Category("FasterKV")] + [Category("Smoke")] public unsafe void NativeInMemWriteRead2() { + // Just use this one instead of all four devices since InMemWriteRead covers all four 
devices + deviceType = TestUtils.DeviceType.MLSD; + + int count = 200; + + string filename = path + "NativeInMemWriteRead2" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + // (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + (128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); + session = fht.For(new Functions()).NewSession<Functions>(); + InputStruct input = default; Random r = new Random(10); - for (int c = 0; c < 1000; c++) + for (int c = 0; c < count; c++) { var i = r.Next(10000); var key1 = new KeyStruct { kfield1 = i, kfield2 = i + 1 }; @@ -194,7 +209,7 @@ public unsafe void NativeInMemWriteRead2() r = new Random(10); - for (int c = 0; c < 1000; c++) + for (int c = 0; c < count; c++) { var i = r.Next(10000); OutputStruct output = default; @@ -206,31 +221,40 @@ public unsafe void NativeInMemWriteRead2() session.CompletePending(true); } - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } // Clean up and retry - should not find now fht.Log.ShiftBeginAddress(fht.Log.TailAddress); r = new Random(10); - for (int c = 0; c < 1000; c++) + for (int c = 0; c < count; c++) { var i = r.Next(10000); OutputStruct output = default; var key1 = new KeyStruct { kfield1 = i, kfield2 = i + 1 }; - Assert.IsTrue(session.Read(ref key1, ref input, ref output, Empty.Default, 0) == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, session.Read(ref key1, ref input, ref output, Empty.Default, 0)); } } [Test] [Category("FasterKV")] - public unsafe void TestShiftHeadAddress() + [Category("Smoke")] + public unsafe void TestShiftHeadAddress([Values] TestUtils.DeviceType deviceType) { InputStruct input = default; + int count = 200; + + string filename = path + 
"TestShiftHeadAddress" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + session = fht.For(new Functions()).NewSession<Functions>(); + Random r = new Random(10); - for (int c = 0; c < 1000; c++) + for (int c = 0; c < count; c++) { var i = r.Next(10000); var key1 = new KeyStruct { kfield1 = i, kfield2 = i + 1 }; @@ -240,7 +264,7 @@ public unsafe void TestShiftHeadAddress() r = new Random(10); - for (int c = 0; c < 1000; c++) + for (int c = 0; c < count; c++) { var i = r.Next(10000); OutputStruct output = default; @@ -252,31 +276,39 @@ public unsafe void TestShiftHeadAddress() session.CompletePending(true); } - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } // Shift head and retry - should not find in main memory now fht.Log.FlushAndEvict(true); r = new Random(10); - for (int c = 0; c < 1000; c++) + for (int c = 0; c < count; c++) { var i = r.Next(10000); OutputStruct output = default; var key1 = new KeyStruct { kfield1 = i, kfield2 = i + 1 }; - Assert.IsTrue(session.Read(ref key1, ref input, ref output, Empty.Default, 0) == Status.PENDING); + Status foundStatus = session.Read(ref key1, ref input, ref output, Empty.Default, 0); + Assert.AreEqual(Status.PENDING, foundStatus); session.CompletePending(true); } } [Test] [Category("FasterKV")] - public unsafe void NativeInMemRMWRefKeys() + [Category("Smoke")] + public unsafe void NativeInMemRMWRefKeys([Values] TestUtils.DeviceType deviceType) { InputStruct input = default; OutputStruct output = default; + string filename = path + "NativeInMemRMWRefKeys" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new 
FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + session = fht.For(new Functions()).NewSession<Functions>(); + var nums = Enumerable.Range(0, 1000).ToArray(); var rnd = new Random(11); for (int i = 0; i < nums.Length; ++i) @@ -322,40 +354,29 @@ public unsafe void NativeInMemRMWRefKeys() status = session.Read(ref key, ref input, ref output, Empty.Default, 0); - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } - Assert.IsTrue(output.value.vfield1 == 2 * value.vfield1, "found " + output.value.vfield1 + ", expected " + 2 * value.vfield1); - Assert.IsTrue(output.value.vfield2 == 2 * value.vfield2); + AssertCompleted(Status.OK, status); + Assert.AreEqual(2 * value.vfield1, output.value.vfield1); + Assert.AreEqual(2 * value.vfield2, output.value.vfield2); } key = new KeyStruct { kfield1 = nums.Length, kfield2 = nums.Length + 1 }; status = session.Read(ref key, ref input, ref output, Empty.Default, 0); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.NOTFOUND); - } + AssertCompleted(Status.NOTFOUND, status); } - // Tests the overload where no reference params used: key,input,userContext,serialNo - [Test] [Category("FasterKV")] - public unsafe void NativeInMemRMWNoRefKeys() + public unsafe void NativeInMemRMWNoRefKeys([Values] TestUtils.DeviceType deviceType) { InputStruct input = default; + string filename = path + "NativeInMemRMWNoRefKeys" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + session = fht.For(new Functions()).NewSession<Functions>(); + var nums = Enumerable.Range(0, 1000).ToArray(); var rnd = new Random(11); for (int i = 0; i < 
nums.Length; ++i) @@ -394,129 +415,114 @@ public unsafe void NativeInMemRMWNoRefKeys() status = session.Read(ref key, ref input, ref output, Empty.Default, 0); - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } - Assert.IsTrue(output.value.vfield1 == 2 * value.vfield1, "found " + output.value.vfield1 + ", expected " + 2 * value.vfield1); - Assert.IsTrue(output.value.vfield2 == 2 * value.vfield2); + AssertCompleted(Status.OK, status); + Assert.AreEqual(2 * value.vfield1, output.value.vfield1); + Assert.AreEqual(2 * value.vfield2, output.value.vfield2); } key = new KeyStruct { kfield1 = nums.Length, kfield2 = nums.Length + 1 }; status = session.Read(ref key, ref input, ref output, Empty.Default, 0); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.NOTFOUND); - } + AssertCompleted(Status.NOTFOUND, status); } - - // Tests the overload of .Read(key, input, out output, context, serialNo) [Test] [Category("FasterKV")] - public void ReadNoRefKeyInputOutput() + [Category("Smoke")] + public void ReadNoRefKeyInputOutput([Values] TestUtils.DeviceType deviceType) { InputStruct input = default; - OutputStruct output = default; + + string filename = path + "ReadNoRefKeyInputOutput" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + session = fht.For(new Functions()).NewSession<Functions>(); var key1 = new KeyStruct { kfield1 = 13, kfield2 = 14 }; var value = new ValueStruct { vfield1 = 23, vfield2 = 24 }; session.Upsert(ref key1, ref value, Empty.Default, 0); - var status = session.Read(key1, input, out output, Empty.Default, 111); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == 
Status.OK); - } + var status = session.Read(key1, input, out OutputStruct output, Empty.Default, 111); + AssertCompleted(Status.OK, status); // Verify the read data - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); - Assert.IsTrue(13 == key1.kfield1); - Assert.IsTrue(14 == key1.kfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); + Assert.AreEqual(key1.kfield1, 13); + Assert.AreEqual(key1.kfield2, 14); } - // Test the overload call of .Read (key, out output, userContext, serialNo) [Test] [Category("FasterKV")] - public void ReadNoRefKey() + public void ReadNoRefKey([Values] TestUtils.DeviceType deviceType) { - OutputStruct output = default; + string filename = path + "ReadNoRefKey" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + session = fht.For(new Functions()).NewSession<Functions>(); var key1 = new KeyStruct { kfield1 = 13, kfield2 = 14 }; var value = new ValueStruct { vfield1 = 23, vfield2 = 24 }; session.Upsert(ref key1, ref value, Empty.Default, 0); - var status = session.Read(key1, out output, Empty.Default, 1); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } + var status = session.Read(key1, out OutputStruct output, Empty.Default, 1); + AssertCompleted(Status.OK, status); // Verify the read data - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); - Assert.IsTrue(13 == key1.kfield1); - Assert.IsTrue(14 == key1.kfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); + Assert.AreEqual(key1.kfield1, 13); + 
Assert.AreEqual(key1.kfield2, 14); } // Test the overload call of .Read (ref key, ref output, userContext, serialNo) [Test] [Category("FasterKV")] - public void ReadWithoutInput() + [Category("Smoke")] + public void ReadWithoutInput([Values] TestUtils.DeviceType deviceType) { + string filename = path + "ReadWithoutInput" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + session = fht.For(new Functions()).NewSession<Functions>(); + OutputStruct output = default; var key1 = new KeyStruct { kfield1 = 13, kfield2 = 14 }; var value = new ValueStruct { vfield1 = 23, vfield2 = 24 }; session.Upsert(ref key1, ref value, Empty.Default, 0); - var status = session.Read(ref key1, ref output, Empty.Default,99); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } + var status = session.Read(ref key1, ref output, Empty.Default, 99); + AssertCompleted(Status.OK, status); // Verify the read data - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); - Assert.IsTrue(13 == key1.kfield1); - Assert.IsTrue(14 == key1.kfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); + Assert.AreEqual(key1.kfield1, 13); + Assert.AreEqual(key1.kfield2, 14); } // Test the overload call of .Read (ref key, ref input, ref output, ref recordInfo, userContext: context) [Test] [Category("FasterKV")] + [Category("Smoke")] public void ReadWithoutSerialID() { + // Just checking without Serial ID so one device type is enough + deviceType = TestUtils.DeviceType.MLSD; + + string filename = path + "ReadWithoutSerialID" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new 
FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); + session = fht.For(new Functions()).NewSession<Functions>(); + InputStruct input = default; OutputStruct output = default; @@ -525,55 +531,55 @@ public void ReadWithoutSerialID() session.Upsert(ref key1, ref value, Empty.Default, 0); var status = session.Read(ref key1, ref input, ref output, Empty.Default); + AssertCompleted(Status.OK, status); - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } - - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); - Assert.IsTrue(13 == key1.kfield1); - Assert.IsTrue(14 == key1.kfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); + Assert.AreEqual(key1.kfield1, 13); + Assert.AreEqual(key1.kfield2, 14); } - // Test the overload call of .Read (key) [Test] [Category("FasterKV")] - public void ReadBareMinParams() + [Category("Smoke")] + public void ReadBareMinParams([Values] TestUtils.DeviceType deviceType) { + string filename = path + "ReadBareMinParams" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + session = fht.For(new Functions()).NewSession<Functions>(); + var key1 = new KeyStruct { kfield1 = 13, kfield2 = 14 }; var value = new ValueStruct { vfield1 = 23, vfield2 = 24 }; session.Upsert(ref key1, ref value, Empty.Default, 0); - var status = session.Read(key1); - - if (status.Item1 == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status.Item1 == Status.OK); - } + var (status, output) = session.Read(key1); + AssertCompleted(Status.OK, status); - Assert.IsTrue(status.Item2.value.vfield1 == 
value.vfield1); - Assert.IsTrue(status.Item2.value.vfield2 == value.vfield2); - Assert.IsTrue(13 == key1.kfield1); - Assert.IsTrue(14 == key1.kfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); + Assert.AreEqual(key1.kfield1, 13); + Assert.AreEqual(key1.kfield2, 14); } // Test the ReadAtAddress where ReadFlags = ReadFlags.none [Test] [Category("FasterKV")] + [Category("Smoke")] public void ReadAtAddressReadFlagsNone() { + // Just functional test of ReadFlag so one device is enough + deviceType = TestUtils.DeviceType.MLSD; + + string filename = path + "ReadAtAddressReadFlagsNone" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); + session = fht.For(new Functions()).NewSession<Functions>(); + InputStruct input = default; OutputStruct output = default; @@ -582,132 +588,194 @@ public void ReadAtAddressReadFlagsNone() var readAtAddress = fht.Log.BeginAddress; session.Upsert(ref key1, ref value, Empty.Default, 0); - var status = session.ReadAtAddress(readAtAddress, ref input, ref output, ReadFlags.None,Empty.Default,0); + var status = session.ReadAtAddress(readAtAddress, ref input, ref output, ReadFlags.None, Empty.Default, 0); + AssertCompleted(Status.OK, status); + + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); + Assert.AreEqual(key1.kfield1, 13); + Assert.AreEqual(key1.kfield2, 14); + } + + // Test the ReadAtAddress where ReadFlags = ReadFlags.SkipReadCache + + class SkipReadCacheFunctions : AdvancedFunctions // Must use AdvancedFunctions for the address parameters to the callbacks + { + internal long expectedReadAddress; - if (status == Status.PENDING) + public override void SingleReader(ref KeyStruct key, ref InputStruct input, ref ValueStruct value, ref OutputStruct dst, long 
address) + => Assign(ref value, ref dst, address); + + public override void ConcurrentReader(ref KeyStruct key, ref InputStruct input, ref ValueStruct value, ref OutputStruct dst, ref RecordInfo recordInfo, long address) + => Assign(ref value, ref dst, address); + + void Assign(ref ValueStruct value, ref OutputStruct dst, long address) { - session.CompletePending(true); + dst.value = value; + Assert.AreEqual(expectedReadAddress, address); + expectedReadAddress = -1; // show that the test executed } - else + public override void ReadCompletionCallback(ref KeyStruct key, ref InputStruct input, ref OutputStruct output, Empty ctx, Status status, RecordInfo recordInfo) { - Assert.IsTrue(status == Status.OK); + // Do no data verifications here; they're done in the test } - - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); - Assert.IsTrue(13 == key1.kfield1); - Assert.IsTrue(14 == key1.kfield2); } - // Test the ReadAtAddress where ReadFlags = ReadFlags.SkipReadCache - [Test] [Category("FasterKV")] + [Category("Smoke")] public void ReadAtAddressReadFlagsSkipReadCache() { + // Another ReadFlag functional test so one device is enough + deviceType = TestUtils.DeviceType.MLSD; + + string filename = path + "ReadAtAddressReadFlagsSkipReadCache" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 29, ReadCacheSettings = new ReadCacheSettings() }); + + SkipReadCacheFunctions functions = new(); + using var skipReadCacheSession = fht.For(functions).NewSession<SkipReadCacheFunctions>(); + InputStruct input = default; OutputStruct output = default; var key1 = new KeyStruct { kfield1 = 13, kfield2 = 14 }; var value = new ValueStruct { vfield1 = 23, vfield2 = 24 }; var readAtAddress = fht.Log.BeginAddress; + Status status; + skipReadCacheSession.Upsert(ref key1, ref value, 
Empty.Default, 0); - - session.Upsert(ref key1, ref value, Empty.Default, 0); - //**** TODO: When Bug Fixed ... use the invalidAddress line - // Bug #136259 - // Ah—slight bug here.I took a quick look to verify that the logicalAddress passed to SingleReader was kInvalidAddress(0), - //and while I got that right for the SingleWriter call, I missed it on the SingleReader. - //This is because we streamlined it to no longer expose RecordAccessor.IsReadCacheAddress, and I missed it here. - // test that the record retrieved on Read variants that do find the record in the readCache - //pass Constants.kInvalidAddress as the ‘address’ parameter to SingleReader. - // Reading the same record using Read(…, ref RecordInfo…) - //and ReadFlags.SkipReadCache should return a valid record there. - //For now, write the same test, and instead of testing for address == kInvalidAddress, - //test for (address & Constants.kReadCacheBitMask) != 0. - - - //var status = session.ReadAtAddress(invalidAddress, ref input, ref output, ReadFlags.SkipReadCache); - var status = session.ReadAtAddress(readAtAddress, ref input, ref output, ReadFlags.SkipReadCache); - - if (status == Status.PENDING) + void VerifyOutput() { - session.CompletePending(true); + Assert.AreEqual(-1, functions.expectedReadAddress); // make sure the test executed + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); + Assert.AreEqual(13, key1.kfield1); + Assert.AreEqual(14, key1.kfield2); } - else + + void VerifyResult() { - Assert.IsTrue(status == Status.OK); + if (status == Status.PENDING) + { + skipReadCacheSession.CompletePendingWithOutputs(out var completedOutputs, wait: true); + (status, output) = TestUtils.GetSinglePendingResult(completedOutputs); + } + Assert.AreEqual(Status.OK, status); + VerifyOutput(); } - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); - Assert.IsTrue(13 == key1.kfield1); - 
Assert.IsTrue(14 == key1.kfield2); + // This will just be an ordinary read, as the record is in memory. + functions.expectedReadAddress = readAtAddress; + status = skipReadCacheSession.Read(ref key1, ref input, ref output); + Assert.AreEqual(Status.OK, status); + VerifyOutput(); + + // ReadCache is used when the record is read from disk. + fht.Log.FlushAndEvict(wait:true); + + // SkipReadCache is primarily for indexing, so a read during index scan does not result in a readcache update. + // Reading at a normal logical address will not use the readcache, because the "readcache" bit is not set in that logical address. + // And we cannot get a readcache address, since reads satisfied from the readcache pass kInvalidAddress to functions. + // Therefore, we test here simply that we do not put it in the readcache when we tell it not to. + + // Do not put it into the read cache. + functions.expectedReadAddress = readAtAddress; + RecordInfo recordInfo = new() { PreviousAddress = readAtAddress }; + status = skipReadCacheSession.Read(ref key1, ref input, ref output, ref recordInfo, ReadFlags.SkipReadCache); + VerifyResult(); + + Assert.AreEqual(fht.ReadCache.BeginAddress, fht.ReadCache.TailAddress); + + // Put it into the read cache. + functions.expectedReadAddress = readAtAddress; + recordInfo.PreviousAddress = readAtAddress; // Read*() sets this to the record's PreviousAddress (so caller can follow the chain), so reinitialize it. + status = skipReadCacheSession.Read(ref key1, ref input, ref output, ref recordInfo); + VerifyResult(); + + Assert.Less(fht.ReadCache.BeginAddress, fht.ReadCache.TailAddress); + + // Now this will read from the read cache. 
+ functions.expectedReadAddress = Constants.kInvalidAddress; + status = skipReadCacheSession.Read(ref key1, ref input, ref output); + Assert.AreEqual(Status.OK, status); + VerifyOutput(); } // Simple Upsert test where ref key and ref value but nothing else set [Test] [Category("FasterKV")] - public void UpsertDefaultsTest() + [Category("Smoke")] + public void UpsertDefaultsTest([Values] TestUtils.DeviceType deviceType) { + string filename = path + "UpsertDefaultsTest" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 22, SegmentSizeBits = 22, PageSizeBits = 10 }); + session = fht.For(new Functions()).NewSession<Functions>(); + InputStruct input = default; OutputStruct output = default; var key1 = new KeyStruct { kfield1 = 13, kfield2 = 14 }; var value = new ValueStruct { vfield1 = 23, vfield2 = 24 }; - Assert.IsTrue(fht.EntryCount == 0); + Assert.AreEqual(0, fht.EntryCount); session.Upsert(ref key1, ref value); var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); + AssertCompleted(Status.OK, status); - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } - - Assert.IsTrue(fht.EntryCount == 1); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(1, fht.EntryCount); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } // Simple Upsert test of overload where not using Ref for key and value and setting all parameters [Test] [Category("FasterKV")] + [Category("Smoke")] public void UpsertNoRefNoDefaultsTest() { + // Just checking more parameter values so one device is enough + deviceType = TestUtils.DeviceType.MLSD; + + string filename = path + "UpsertNoRefNoDefaultsTest" + deviceType.ToString() 
+ ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); + session = fht.For(new Functions()).NewSession<Functions>(); + InputStruct input = default; OutputStruct output = default; var key1 = new KeyStruct { kfield1 = 13, kfield2 = 14 }; var value = new ValueStruct { vfield1 = 23, vfield2 = 24 }; - session.Upsert(key1, value, Empty.Default,0); - var status = session.Read(ref key1, ref input, ref output, Empty.Default,0); - - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } + session.Upsert(key1, value, Empty.Default, 0); + var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); + AssertCompleted(Status.OK, status); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } // Upsert Test using Serial Numbers ... 
based on the VersionedRead Sample [Test] [Category("FasterKV")] + [Category("Smoke")] public void UpsertSerialNumberTest() { + // Simple Upsert of Serial Number test so one device is enough + deviceType = TestUtils.DeviceType.MLSD; + + string filename = path + "UpsertSerialNumberTest" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); + session = fht.For(new Functions()).NewSession<Functions>(); int numKeys = 100; int keyMod = 10; @@ -730,17 +798,9 @@ public void UpsertSerialNumberTest() { var status = session.Read(ref key, ref input, ref output, serialNo: maxLap + 1); - if (status == Status.PENDING) - { - session.CompletePending(true); - } - else - { - Assert.IsTrue(status == Status.OK); - } - - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + AssertCompleted(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } } @@ -751,18 +811,18 @@ public void UpsertSerialNumberTest() [Category("FasterKV")] public static void KVBasicsSampleEndToEndInDocs() { - string testDir = $"{TestContext.CurrentContext.TestDirectory}"; - using var log = Devices.CreateLogDevice($"{testDir}/hlog.log", deleteOnClose: true); + string testDir = TestUtils.MethodTestDir; + using var log = Devices.CreateLogDevice($"{testDir}/hlog.log", deleteOnClose: false); using var store = new FasterKV<long, long>(1L << 20, new LogSettings { LogDevice = log }); using var s = store.NewSession(new SimpleFunctions<long, long>()); long key = 1, value = 1, input = 10, output = 0; s.Upsert(ref key, ref value); s.Read(ref key, ref output); - Assert.IsTrue(output == value); + Assert.AreEqual(value, output); s.RMW(ref key, ref input); s.RMW(ref key, ref input); s.Read(ref key, ref output); - Assert.IsTrue(output == 10); + 
Assert.AreEqual(10, output); } } -} +} \ No newline at end of file diff --git a/cs/test/BlittableIterationTests.cs b/cs/test/BlittableIterationTests.cs index 7f86fa5bb..e8f5f0cb8 100644 --- a/cs/test/BlittableIterationTests.cs +++ b/cs/test/BlittableIterationTests.cs @@ -1,48 +1,50 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test { - [TestFixture] internal class BlittableIterationTests { private FasterKV<KeyStruct, ValueStruct> fht; private IDevice log; + private string path; [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BlittableIterationTests.log", deleteOnClose: true); - fht = new FasterKV<KeyStruct, ValueStruct> - (1L << 20, new LogSettings { LogDevice = log, MemorySizeBits = 15, PageSizeBits = 9 }); + path = TestUtils.MethodTestDir + "/"; + + // Clean up log files from previous test runs in case they weren't cleaned up + TestUtils.DeleteDirectory(path, wait:true); } [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(path); } [Test] [Category("FasterKV")] - public void BlittableIterationTest1() + [Category("Smoke")] + public void BlittableIterationTest1([Values] TestUtils.DeviceType deviceType) { + string filename = path + "BlittableIterationTest1" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<KeyStruct, ValueStruct> + (1L << 20, new LogSettings { LogDevice = log, MemorySizeBits = 15, PageSizeBits = 9, SegmentSizeBits = 22 }); + using var session = fht.For(new FunctionsCompaction()).NewSession<FunctionsCompaction>(); - const int totalRecords = 2000; + 
const int totalRecords = 500; var start = fht.Log.TailAddress; for (int i = 0; i < totalRecords; i++) @@ -57,12 +59,11 @@ public void BlittableIterationTest1() while (iter.GetNext(out var recordInfo)) { count++; - Assert.IsTrue(iter.GetValue().vfield1 == iter.GetKey().kfield1); + Assert.AreEqual(iter.GetKey().kfield1, iter.GetValue().vfield1); } iter.Dispose(); - Assert.IsTrue(count == totalRecords); - + Assert.AreEqual(totalRecords, count); for (int i = 0; i < totalRecords; i++) { @@ -76,11 +77,11 @@ public void BlittableIterationTest1() while (iter.GetNext(out var recordInfo)) { count++; - Assert.IsTrue(iter.GetValue().vfield1 == iter.GetKey().kfield1 * 2); + Assert.AreEqual(iter.GetKey().kfield1 * 2, iter.GetValue().vfield1); } iter.Dispose(); - Assert.IsTrue(count == totalRecords); + Assert.AreEqual(totalRecords, count); for (int i = totalRecords/2; i < totalRecords; i++) { @@ -97,7 +98,7 @@ public void BlittableIterationTest1() } iter.Dispose(); - Assert.IsTrue(count == totalRecords); + Assert.AreEqual(totalRecords, count); for (int i = 0; i < totalRecords; i+=2) { @@ -114,7 +115,7 @@ public void BlittableIterationTest1() } iter.Dispose(); - Assert.IsTrue(count == totalRecords); + Assert.AreEqual(totalRecords, count); for (int i = 0; i < totalRecords; i += 2) { @@ -131,7 +132,7 @@ public void BlittableIterationTest1() } iter.Dispose(); - Assert.IsTrue(count == totalRecords / 2); + Assert.AreEqual(totalRecords / 2, count); for (int i = 0; i < totalRecords; i++) { @@ -145,12 +146,11 @@ public void BlittableIterationTest1() while (iter.GetNext(out var recordInfo)) { count++; - Assert.IsTrue(iter.GetValue().vfield1 == iter.GetKey().kfield1 * 3); + Assert.AreEqual(iter.GetKey().kfield1 * 3, iter.GetValue().vfield1); } iter.Dispose(); - Assert.IsTrue(count == totalRecords); - + Assert.AreEqual(totalRecords, count); } } } diff --git a/cs/test/BlittableLogCompactionTests.cs b/cs/test/BlittableLogCompactionTests.cs index 5a18143ac..99f3f7300 100644 --- 
a/cs/test/BlittableLogCompactionTests.cs +++ b/cs/test/BlittableLogCompactionTests.cs @@ -1,19 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test { - [TestFixture] internal class BlittableLogCompactionTests { @@ -23,7 +15,8 @@ internal class BlittableLogCompactionTests [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BlittableLogCompactionTests.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait:true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/BlittableLogCompactionTests.log", deleteOnClose: true); fht = new FasterKV<KeyStruct, ValueStruct> (1L << 20, new LogSettings { LogDevice = log, MemorySizeBits = 15, PageSizeBits = 9 }); } @@ -31,14 +24,18 @@ public void Setup() [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("FasterKV")] [Category("Compaction")] + [Category("Smoke")] + public void BlittableLogCompactionTest1() { using var session = fht.For(new FunctionsCompaction()).NewSession<FunctionsCompaction>(); @@ -60,7 +57,7 @@ public void BlittableLogCompactionTest1() } compactUntil = session.Compact(compactUntil, true); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); + Assert.AreEqual(compactUntil, fht.Log.BeginAddress); // Read 2000 keys - all should be present for (int i = 0; i < totalRecords; i++) @@ -74,9 +71,9 @@ public void BlittableLogCompactionTest1() session.CompletePending(true); else { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - 
Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } } } @@ -117,8 +114,8 @@ public void BlittableLogCompactionTest2() var tail = fht.Log.TailAddress; compactUntil = session.Compact(compactUntil, true); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); - Assert.IsTrue(fht.Log.TailAddress == tail); + Assert.AreEqual(compactUntil, fht.Log.BeginAddress); + Assert.AreEqual(tail, fht.Log.TailAddress); // Read 2000 keys - all should be present for (int i = 0; i < totalRecords; i++) @@ -132,9 +129,9 @@ public void BlittableLogCompactionTest2() session.CompletePending(true); else { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } } } @@ -172,7 +169,7 @@ public void BlittableLogCompactionTest3() var tail = fht.Log.TailAddress; compactUntil = session.Compact(compactUntil, true); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); + Assert.AreEqual(compactUntil, fht.Log.BeginAddress); // Read 2000 keys - all should be present for (int i = 0; i < totalRecords; i++) @@ -190,13 +187,13 @@ public void BlittableLogCompactionTest3() { if (ctx == 0) { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } else { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } } @@ -205,6 +202,7 @@ public void BlittableLogCompactionTest3() [Test] [Category("FasterKV")] [Category("Compaction")] + 
[Category("Smoke")] public void BlittableLogCompactionCustomFunctionsTest1() { @@ -230,7 +228,7 @@ public void BlittableLogCompactionCustomFunctionsTest1() // Only leave records with even vfield1 compactUntil = session.Compact(compactUntil, true, default(EvenCompactionFunctions)); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); + Assert.AreEqual(compactUntil, fht.Log.BeginAddress); // Read 2000 keys - all should be present for (var i = 0; i < totalRecords; i++) @@ -250,13 +248,13 @@ public void BlittableLogCompactionCustomFunctionsTest1() { if (ctx == 0) { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } else { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } } @@ -296,9 +294,9 @@ public void BlittableLogCompactionCustomFunctionsTest2() } else { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } } diff --git a/cs/test/BlittableLogScanTests.cs b/cs/test/BlittableLogScanTests.cs index f25b9ae84..fba9303e4 100644 --- a/cs/test/BlittableLogScanTests.cs +++ b/cs/test/BlittableLogScanTests.cs @@ -2,18 +2,11 @@ // Licensed under the MIT license. 
using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test { - [TestFixture] internal class BlittableFASTERScanTests { @@ -24,7 +17,8 @@ internal class BlittableFASTERScanTests [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BlittableFASTERScanTests.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait:true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/BlittableFASTERScanTests.log", deleteOnClose: true); fht = new FasterKV<KeyStruct, ValueStruct> (1L << 20, new LogSettings { LogDevice = log, MemorySizeBits = 15, PageSizeBits = 9 }); } @@ -32,13 +26,17 @@ public void Setup() [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("FasterKV")] + [Category("Smoke")] + public void BlittableDiskWriteScan() { using var session = fht.For(new Functions()).NewSession<Functions>(); @@ -59,26 +57,26 @@ public void BlittableDiskWriteScan() int val = 0; while (iter.GetNext(out _, out KeyStruct key, out ValueStruct value)) { - Assert.IsTrue(key.kfield1 == val); - Assert.IsTrue(key.kfield2 == val + 1); - Assert.IsTrue(value.vfield1 == val); - Assert.IsTrue(value.vfield2 == val + 1); + Assert.AreEqual(val, key.kfield1); + Assert.AreEqual(val + 1, key.kfield2); + Assert.AreEqual(val, value.vfield1); + Assert.AreEqual(val + 1, value.vfield2); val++; } - Assert.IsTrue(totalRecords == val); + Assert.AreEqual(val, totalRecords); iter = fht.Log.Scan(start, fht.Log.TailAddress, ScanBufferingMode.DoublePageBuffering); val = 0; while (iter.GetNext(out RecordInfo recordInfo, out KeyStruct key, out ValueStruct value)) { - Assert.IsTrue(key.kfield1 == val); - 
Assert.IsTrue(key.kfield2 == val + 1); - Assert.IsTrue(value.vfield1 == val); - Assert.IsTrue(value.vfield2 == val + 1); + Assert.AreEqual(val, key.kfield1); + Assert.AreEqual(val + 1, key.kfield2); + Assert.AreEqual(val, value.vfield1); + Assert.AreEqual(val + 1, value.vfield2); val++; } - Assert.IsTrue(totalRecords == val); + Assert.AreEqual(val, totalRecords); s.Dispose(); } @@ -89,7 +87,7 @@ class LogObserver : IObserver<IFasterScanIterator<KeyStruct, ValueStruct>> public void OnCompleted() { - Assert.IsTrue(val == totalRecords); + Assert.AreEqual(totalRecords, val); } public void OnError(Exception error) @@ -100,10 +98,10 @@ public void OnNext(IFasterScanIterator<KeyStruct, ValueStruct> iter) { while (iter.GetNext(out _, out KeyStruct key, out ValueStruct value)) { - Assert.IsTrue(key.kfield1 == val); - Assert.IsTrue(key.kfield2 == val + 1); - Assert.IsTrue(value.vfield1 == val); - Assert.IsTrue(value.vfield2 == val + 1); + Assert.AreEqual(val, key.kfield1); + Assert.AreEqual(val + 1, key.kfield2); + Assert.AreEqual(val, value.vfield1); + Assert.AreEqual(val + 1, value.vfield2); val++; } } diff --git a/cs/test/CheckpointManagerTests.cs b/cs/test/CheckpointManagerTests.cs new file mode 100644 index 000000000..c448d8fcd --- /dev/null +++ b/cs/test/CheckpointManagerTests.cs @@ -0,0 +1,141 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Security.Cryptography; +using System.Threading.Tasks; +using FASTER.core; +using FASTER.devices; +using FASTER.test.recovery; +using NUnit.Framework; + +namespace FASTER.test +{ + public class CheckpointManagerTests + { + private Random random = new Random(); + + [Test] + [Category("CheckpointRestore")] + [Category("Smoke")] + public async Task CheckpointManagerPurgeCheck([Values] DeviceMode deviceMode) + { + ICheckpointManager checkpointManager; + if (deviceMode == DeviceMode.Local) + { + checkpointManager = new DeviceLogCommitCheckpointManager( + new LocalStorageNamedDeviceFactory(), + new 
DefaultCheckpointNamingScheme(TestUtils.MethodTestDir + + "/checkpoints/")); // PurgeAll deletes this directory + } + else + { + TestUtils.IgnoreIfNotRunningAzureTests(); + checkpointManager = new DeviceLogCommitCheckpointManager( + new AzureStorageNamedDeviceFactory(TestUtils.AzureEmulatedStorageString), + new DefaultCheckpointNamingScheme( + $"{TestUtils.AzureTestContainer}/{TestUtils.AzureTestDirectory}")); + } + + var path = TestUtils.MethodTestDir + "/"; + using var log = Devices.CreateLogDevice(path + "hlog.log", deleteOnClose: true); + TestUtils.RecreateDirectory(path); + + using var fht = new FasterKV<long, long> + (1 << 10, + logSettings: new LogSettings + { + LogDevice = log, MutableFraction = 1, PageSizeBits = 10, MemorySizeBits = 20, + ReadCacheSettings = null + }, + checkpointSettings: new CheckpointSettings { CheckpointManager = checkpointManager } + ); + using var s = fht.NewSession(new SimpleFunctions<long, long>()); + + var logCheckpoints = new Dictionary<Guid, int>(); + var indexCheckpoints = new Dictionary<Guid, int>(); + var fullCheckpoints = new Dictionary<Guid, int>(); + + for (var i = 0; i < 10; i++) + { + // Do some dummy update + s.Upsert(0, random.Next()); + + var checkpointType = random.Next(5); + Guid result = default; + switch (checkpointType) + { + case 0: + fht.TakeHybridLogCheckpoint(out result, CheckpointType.FoldOver); + logCheckpoints.Add(result, 0); + break; + case 1: + fht.TakeHybridLogCheckpoint(out result, CheckpointType.Snapshot); + logCheckpoints.Add(result, 0); + break; + case 2: + fht.TakeIndexCheckpoint(out result); + indexCheckpoints.Add(result, 0); + break; + case 3: + fht.TakeFullCheckpoint(out result, CheckpointType.FoldOver); + fullCheckpoints.Add(result, 0); + break; + case 4: + fht.TakeFullCheckpoint(out result, CheckpointType.Snapshot); + fullCheckpoints.Add(result, 0); + break; + default: + Assert.True(false); + break; + } + + await fht.CompleteCheckpointAsync(); + } + + 
Assert.AreEqual(checkpointManager.GetLogCheckpointTokens().ToDictionary(guid => guid, _ => 0), + logCheckpoints.Union(fullCheckpoints).ToDictionary(e => e.Key, e => e.Value)); + Assert.AreEqual(checkpointManager.GetIndexCheckpointTokens().ToDictionary(guid => guid, _ => 0), + indexCheckpoints.Union(fullCheckpoints).ToDictionary(e => e.Key, e => e.Value)); + + if (logCheckpoints.Count != 0) + { + var guid = logCheckpoints.First().Key; + checkpointManager.Purge(guid); + logCheckpoints.Remove(guid); + Assert.AreEqual(checkpointManager.GetLogCheckpointTokens().ToDictionary(guid => guid, _ => 0), + logCheckpoints.Union(fullCheckpoints).ToDictionary(e => e.Key, e => e.Value)); + Assert.AreEqual(checkpointManager.GetIndexCheckpointTokens().ToDictionary(guid => guid, _ => 0), + indexCheckpoints.Union(fullCheckpoints).ToDictionary(e => e.Key, e => e.Value)); + } + + if (indexCheckpoints.Count != 0) + { + var guid = indexCheckpoints.First().Key; + checkpointManager.Purge(guid); + indexCheckpoints.Remove(guid); + Assert.AreEqual(checkpointManager.GetLogCheckpointTokens().ToDictionary(guid => guid, _ => 0), + logCheckpoints.Union(fullCheckpoints).ToDictionary(e => e.Key, e => e.Value)); + Assert.AreEqual(checkpointManager.GetIndexCheckpointTokens().ToDictionary(guid => guid, _ => 0), + indexCheckpoints.Union(fullCheckpoints).ToDictionary(e => e.Key, e => e.Value)); + } + + + if (fullCheckpoints.Count != 0) + { + var guid = fullCheckpoints.First().Key; + checkpointManager.Purge(guid); + fullCheckpoints.Remove(guid); + Assert.AreEqual(checkpointManager.GetLogCheckpointTokens().ToDictionary(guid => guid, _ => 0), + logCheckpoints.Union(fullCheckpoints).ToDictionary(e => e.Key, e => e.Value)); + Assert.AreEqual(checkpointManager.GetIndexCheckpointTokens().ToDictionary(guid => guid, _ => 0), + indexCheckpoints.Union(fullCheckpoints).ToDictionary(e => e.Key, e => e.Value)); + } + + checkpointManager.PurgeAll(); + Assert.IsEmpty(checkpointManager.GetLogCheckpointTokens()); + 
Assert.IsEmpty(checkpointManager.GetIndexCheckpointTokens()); + + checkpointManager.Dispose(); + } + } +} \ No newline at end of file diff --git a/cs/test/CompletePendingTests.cs b/cs/test/CompletePendingTests.cs index 5d3b6ffcd..f18d3c5e4 100644 --- a/cs/test/CompletePendingTests.cs +++ b/cs/test/CompletePendingTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -using System; using System.Collections.Generic; using System.Threading.Tasks; using FASTER.core; @@ -14,11 +13,18 @@ class CompletePendingTests { private FasterKV<KeyStruct, ValueStruct> fht; private IDevice log; + private string path; + [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/CompletePendingTests.log", preallocateFile: true, deleteOnClose: true); + path = TestUtils.MethodTestDir + "/"; + + // Clean up log files from previous test runs in case they weren't cleaned up + TestUtils.DeleteDirectory(path, wait:true); + + log = Devices.CreateLogDevice(path + "/CompletePendingTests.log", preallocateFile: true, deleteOnClose: true); fht = new FasterKV<KeyStruct, ValueStruct>(128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); } @@ -29,6 +35,7 @@ public void TearDown() fht = null; log?.Dispose(); log = null; + TestUtils.DeleteDirectory(path, wait: true); } const int numRecords = 1000; diff --git a/cs/test/ComponentRecoveryTests.cs b/cs/test/ComponentRecoveryTests.cs index 93a5d55ef..d4e04ce12 100644 --- a/cs/test/ComponentRecoveryTests.cs +++ b/cs/test/ComponentRecoveryTests.cs @@ -8,15 +8,16 @@ namespace FASTER.test.recovery { - [TestFixture] internal class ComponentRecoveryTests { private static unsafe void Setup_MallocFixedPageSizeRecoveryTest(out int seed, out IDevice device, out int numBucketsToAdd, out long[] logicalAddresses, out ulong numBytesWritten) { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait:true); + seed = 123; var rand1 = new Random(seed); 
- device = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/MallocFixedPageSizeRecoveryTest.dat", deleteOnClose: true); + device = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/MallocFixedPageSizeRecoveryTest.dat", deleteOnClose: true); var allocator = new MallocFixedPageSize<HashBucket>(); //do something @@ -41,9 +42,9 @@ private static unsafe void Setup_MallocFixedPageSizeRecoveryTest(out int seed, o allocator.Dispose(); } - private static unsafe void Finish_MallocFixedPageSizeRecoveryTest(int seed, int numBucketsToAdd, long[] logicalAddresses, ulong numBytesWritten, MallocFixedPageSize<HashBucket> recoveredAllocator, ulong numBytesRead) + private static unsafe void Finish_MallocFixedPageSizeRecoveryTest(int seed, IDevice device, int numBucketsToAdd, long[] logicalAddresses, ulong numBytesWritten, MallocFixedPageSize<HashBucket> recoveredAllocator, ulong numBytesRead) { - Assert.IsTrue(numBytesWritten == numBytesRead); + Assert.AreEqual(numBytesRead, numBytesWritten); var rand2 = new Random(seed); for (int i = 0; i < numBucketsToAdd; i++) @@ -52,15 +53,18 @@ private static unsafe void Finish_MallocFixedPageSizeRecoveryTest(int seed, int var bucket = (HashBucket*)recoveredAllocator.GetPhysicalAddress(logicalAddress); for (int j = 0; j < Constants.kOverflowBucketIndex; j++) { - Assert.IsTrue(bucket->bucket_entries[j] == rand2.Next()); + Assert.AreEqual(rand2.Next(), bucket->bucket_entries[j]); } } recoveredAllocator.Dispose(); + device.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("CheckpointRestore")] + [Category("Smoke")] public void MallocFixedPageSizeRecoveryTest() { Setup_MallocFixedPageSizeRecoveryTest(out int seed, out IDevice device, out int numBucketsToAdd, out long[] logicalAddresses, out ulong numBytesWritten); @@ -71,11 +75,12 @@ public void MallocFixedPageSizeRecoveryTest() //wait until complete recoveredAllocator.IsRecoveryCompleted(true); - Finish_MallocFixedPageSizeRecoveryTest(seed, 
numBucketsToAdd, logicalAddresses, numBytesWritten, recoveredAllocator, numBytesRead); + Finish_MallocFixedPageSizeRecoveryTest(seed, device, numBucketsToAdd, logicalAddresses, numBytesWritten, recoveredAllocator, numBytesRead); } [Test] [Category("CheckpointRestore")] + [Category("Smoke")] public async Task MallocFixedPageSizeRecoveryAsyncTest() { Setup_MallocFixedPageSizeRecoveryTest(out int seed, out IDevice device, out int numBucketsToAdd, out long[] logicalAddresses, out ulong numBytesWritten); @@ -83,16 +88,18 @@ public async Task MallocFixedPageSizeRecoveryAsyncTest() var recoveredAllocator = new MallocFixedPageSize<HashBucket>(); ulong numBytesRead = await recoveredAllocator.RecoverAsync(device, 0, numBucketsToAdd, numBytesWritten, cancellationToken: default); - Finish_MallocFixedPageSizeRecoveryTest(seed, numBucketsToAdd, logicalAddresses, numBytesWritten, recoveredAllocator, numBytesRead); + Finish_MallocFixedPageSizeRecoveryTest(seed, device, numBucketsToAdd, logicalAddresses, numBytesWritten, recoveredAllocator, numBytesRead); } private static unsafe void Setup_FuzzyIndexRecoveryTest(out int seed, out int size, out long numAdds, out IDevice ht_device, out IDevice ofb_device, out FasterBase hash_table1, out ulong ht_num_bytes_written, out ulong ofb_num_bytes_written, out int num_ofb_buckets) { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + seed = 123; size = 1 << 16; numAdds = 1 << 18; - ht_device = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/TestFuzzyIndexRecoveryht.dat", deleteOnClose: true); - ofb_device = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/TestFuzzyIndexRecoveryofb.dat", deleteOnClose: true); + ht_device = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/TestFuzzyIndexRecoveryht.dat", deleteOnClose: true); + ofb_device = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/TestFuzzyIndexRecoveryofb.dat", deleteOnClose: true); hash_table1 = new FasterBase(); 
hash_table1.Initialize(size, 512); @@ -122,7 +129,7 @@ private static unsafe void Setup_FuzzyIndexRecoveryTest(out int seed, out int si hash_table1.IsIndexFuzzyCheckpointCompletedAsync().AsTask().Wait(); } - private static unsafe void Finish_FuzzyIndexRecoveryTest(int seed, long numAdds, FasterBase hash_table1, FasterBase hash_table2) + private static unsafe void Finish_FuzzyIndexRecoveryTest(int seed, long numAdds, IDevice ht_device, IDevice ofb_device, FasterBase hash_table1, FasterBase hash_table2) { var keyGenerator2 = new Random(seed); @@ -142,20 +149,25 @@ private static unsafe void Finish_FuzzyIndexRecoveryTest(int seed, long numAdds, var exists1 = hash_table1.FindTag(hash, tag, ref bucket1, ref slot1, ref entry1); var exists2 = hash_table2.FindTag(hash, tag, ref bucket2, ref slot2, ref entry2); - Assert.IsTrue(exists1 == exists2); + Assert.AreEqual(exists2, exists1); if (exists1) { - Assert.IsTrue(entry1.word == entry2.word); + Assert.AreEqual(entry2.word, entry1.word); } } hash_table1.Free(); hash_table2.Free(); + + ht_device.Dispose(); + ofb_device.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("CheckpointRestore")] + [Category("Smoke")] public unsafe void FuzzyIndexRecoveryTest() { Setup_FuzzyIndexRecoveryTest(out int seed, out int size, out long numAdds, out IDevice ht_device, out IDevice ofb_device, out FasterBase hash_table1, @@ -169,11 +181,12 @@ public unsafe void FuzzyIndexRecoveryTest() //wait until complete hash_table2.IsFuzzyIndexRecoveryComplete(true); - Finish_FuzzyIndexRecoveryTest(seed, numAdds, hash_table1, hash_table2); + Finish_FuzzyIndexRecoveryTest(seed, numAdds, ht_device, ofb_device, hash_table1, hash_table2); } [Test] [Category("CheckpointRestore")] + [Category("Smoke")] public async Task FuzzyIndexRecoveryAsyncTest() { Setup_FuzzyIndexRecoveryTest(out int seed, out int size, out long numAdds, out IDevice ht_device, out IDevice ofb_device, out FasterBase hash_table1, @@ -184,7 +197,7 @@ public 
async Task FuzzyIndexRecoveryAsyncTest() await hash_table2.RecoverFuzzyIndexAsync(0, ht_device, ht_num_bytes_written, ofb_device, num_ofb_buckets, ofb_num_bytes_written, cancellationToken: default); - Finish_FuzzyIndexRecoveryTest(seed, numAdds, hash_table1, hash_table2); + Finish_FuzzyIndexRecoveryTest(seed, numAdds, ht_device, ofb_device, hash_table1, hash_table2); } } } diff --git a/cs/test/DeltaLogTests.cs b/cs/test/DeltaLogTests.cs index ede9b6b95..5dac77fbe 100644 --- a/cs/test/DeltaLogTests.cs +++ b/cs/test/DeltaLogTests.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using System; -using System.IO; using FASTER.core; using NUnit.Framework; @@ -11,66 +10,80 @@ namespace FASTER.test [TestFixture] internal class DeltaLogStandAloneTests { + private FasterLog log; + private IDevice device; + private string path; + + [SetUp] + public void Setup() + { + path = TestUtils.MethodTestDir + "/"; + + // Clean up log files from previous test runs in case they weren't cleaned up + TestUtils.DeleteDirectory(path, wait: true); + } + + [TearDown] + public void TearDown() + { + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; + + // Clean up log files + TestUtils.DeleteDirectory(path, wait: true); + } [Test] - public void DeltaLogTest1() + [Category("FasterLog")] + [Category("Smoke")] + public void DeltaLogTest1([Values] TestUtils.DeviceType deviceType) { - int TotalCount = 1000; - string path = TestContext.CurrentContext.TestDirectory + "/" + TestContext.CurrentContext.Test.Name + "/"; - DirectoryInfo di = Directory.CreateDirectory(path); - using (IDevice device = Devices.CreateLogDevice(path + TestContext.CurrentContext.Test.Name + "/delta.log", deleteOnClose: false)) - { - device.Initialize(-1); - using DeltaLog deltaLog = new DeltaLog(device, 12, 0); - Random r = new Random(20); - int i; + const int TotalCount = 200; + string filename = $"{path}delta_{deviceType}.log"; + TestUtils.RecreateDirectory(path); + + device = 
TestUtils.CreateTestDevice(deviceType, filename); + device.Initialize(-1); + using DeltaLog deltaLog = new DeltaLog(device, 12, 0); + Random r = new (20); + int i; - var bufferPool = new SectorAlignedBufferPool(1, (int)device.SectorSize); - deltaLog.InitializeForWrites(bufferPool); - for (i = 0; i < TotalCount; i++) + SectorAlignedBufferPool bufferPool = new(1, (int)device.SectorSize); + deltaLog.InitializeForWrites(bufferPool); + for (i = 0; i < TotalCount; i++) + { + int _len = 1 + r.Next(254); + long address; + while (true) { - int len = 1 + r.Next(254); - long address; - while (true) - { - deltaLog.Allocate(out int maxLen, out address); - if (len <= maxLen) break; - deltaLog.Seal(0); - } - for (int j = 0; j < len; j++) - { - unsafe { *(byte*)(address + j) = (byte)len; } - } - deltaLog.Seal(len, i); + deltaLog.Allocate(out int maxLen, out address); + if (_len <= maxLen) break; + deltaLog.Seal(0); } - deltaLog.FlushAsync().Wait(); - - deltaLog.InitializeForReads(); - i = 0; - r = new Random(20); - while (deltaLog.GetNext(out long address, out int len, out int type)) + for (int j = 0; j < _len; j++) { - int _len = 1 + r.Next(254); - Assert.IsTrue(type == i); - Assert.IsTrue(_len == len); - for (int j = 0; j < len; j++) - { - unsafe { Assert.IsTrue(*(byte*)(address + j) == (byte)_len); }; - } - i++; + unsafe { *(byte*)(address + j) = (byte)_len; } } - Assert.IsTrue(i == TotalCount); - bufferPool.Free(); + deltaLog.Seal(_len, i % 2 == 0 ? DeltaLogEntryType.DELTA : DeltaLogEntryType.CHECKPOINT_METADATA); } - while (true) + deltaLog.FlushAsync().Wait(); + + deltaLog.InitializeForReads(); + r = new (20); + for (i = 0; deltaLog.GetNext(out long address, out int len, out var type); i++) { - try + int _len = 1 + r.Next(254); + Assert.AreEqual( i % 2 == 0 ? 
DeltaLogEntryType.DELTA : DeltaLogEntryType.CHECKPOINT_METADATA, type); + Assert.AreEqual(len, _len); + for (int j = 0; j < len; j++) { - di.Delete(recursive: true); - break; + unsafe { Assert.AreEqual((byte)_len, *(byte*)(address + j)); }; } - catch { } } + Assert.AreEqual(TotalCount, i, $"i={i} and TotalCount={TotalCount}"); + bufferPool.Free(); } } } diff --git a/cs/test/DeviceFasterLogTests.cs b/cs/test/DeviceFasterLogTests.cs index b7364d832..e1fc8f55b 100644 --- a/cs/test/DeviceFasterLogTests.cs +++ b/cs/test/DeviceFasterLogTests.cs @@ -13,66 +13,93 @@ namespace FASTER.test { - [TestFixture] internal class DeviceFasterLogTests { const int entryLength = 100; - const int numEntries = 100000; + const int numEntries = 1000; private FasterLog log; - public const string EMULATED_STORAGE_STRING = "UseDevelopmentStorage=true;"; - public const string TEST_CONTAINER = "test"; + static readonly byte[] entry = new byte[100]; [Test] [Category("FasterLog")] - public async ValueTask PageBlobFasterLogTest1([Values] LogChecksumType logChecksum, [Values]FasterLogTests.IteratorType iteratorType) + public async ValueTask PageBlobFasterLogTest1([Values] LogChecksumType logChecksum, [Values]FasterLogTestBase.IteratorType iteratorType) { - if ("yes".Equals(Environment.GetEnvironmentVariable("RunAzureTests"))) - { - var device = new AzureStorageDevice(EMULATED_STORAGE_STRING, $"{TEST_CONTAINER}", "PageBlobFasterLogTest1", "fasterlog.log", deleteOnClose: true); - var checkpointManager = new DeviceLogCommitCheckpointManager( - new AzureStorageNamedDeviceFactory(EMULATED_STORAGE_STRING), - new DefaultCheckpointNamingScheme($"{TEST_CONTAINER}/PageBlobFasterLogTest1")); - await FasterLogTest1(logChecksum, device, checkpointManager, iteratorType); - device.Dispose(); - checkpointManager.PurgeAll(); - checkpointManager.Dispose(); - } + TestUtils.IgnoreIfNotRunningAzureTests(); + var device = new AzureStorageDevice(TestUtils.AzureEmulatedStorageString, $"{TestUtils.AzureTestContainer}", 
TestUtils.AzureTestDirectory, "fasterlog.log", deleteOnClose: true); + var checkpointManager = new DeviceLogCommitCheckpointManager( + new AzureStorageNamedDeviceFactory(TestUtils.AzureEmulatedStorageString), + new DefaultCheckpointNamingScheme($"{TestUtils.AzureTestContainer}/{TestUtils.AzureTestDirectory}")); + await FasterLogTest1(logChecksum, device, checkpointManager, iteratorType); + device.Dispose(); + checkpointManager.PurgeAll(); + checkpointManager.Dispose(); + } + + [Test] + [Category("FasterLog")] + public async ValueTask PageBlobFasterLogTestWithLease([Values] LogChecksumType logChecksum, [Values] FasterLogTestBase.IteratorType iteratorType) + { + // Set up the blob manager so can set lease to it + TestUtils.IgnoreIfNotRunningAzureTests(); + CloudStorageAccount storageAccount = CloudStorageAccount.DevelopmentStorageAccount; + var cloudBlobClient = storageAccount.CreateCloudBlobClient(); + CloudBlobContainer blobContainer = cloudBlobClient.GetContainerReference("test-container"); + blobContainer.CreateIfNotExists(); + var mycloudBlobDir = blobContainer.GetDirectoryReference(@"BlobManager/MyLeaseTest1"); + + var blobMgr = new DefaultBlobManager(true, mycloudBlobDir); + var device = new AzureStorageDevice(TestUtils.AzureEmulatedStorageString, $"{TestUtils.AzureTestContainer}", TestUtils.AzureTestDirectory, "fasterlogLease.log", deleteOnClose: true, underLease: true, blobManager: blobMgr); + + var checkpointManager = new DeviceLogCommitCheckpointManager( + new AzureStorageNamedDeviceFactory(TestUtils.AzureEmulatedStorageString), + new DefaultCheckpointNamingScheme($"{TestUtils.AzureTestContainer}/{TestUtils.AzureTestDirectory}")); + await FasterLogTest1(logChecksum, device, checkpointManager, iteratorType); + device.Dispose(); + checkpointManager.PurgeAll(); + checkpointManager.Dispose(); + blobContainer.Delete(); } + [Test] [Category("FasterLog")] - public async ValueTask PageBlobFasterLogTestWithLease([Values] LogChecksumType logChecksum, [Values] 
FasterLogTests.IteratorType iteratorType) + public void BasicHighLatencyDeviceTest() { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + + // Create devices \ log for test for in memory device + using LocalMemoryDevice device = new LocalMemoryDevice(1L << 28, 1L << 25, 2, latencyMs: 20); + using FasterLog LocalMemorylog = new FasterLog(new FasterLogSettings { LogDevice = device, PageSizeBits = 80, MemorySizeBits = 20, GetMemory = null, SegmentSizeBits = 80, MutableFraction = 0.2, LogCommitManager = null }); - // Need this environment variable set AND Azure Storage Emulator running - if ("yes".Equals(Environment.GetEnvironmentVariable("RunAzureTests"))) + int entryLength = 10; + + // Set Default entry data + for (int i = 0; i < entryLength; i++) { - // Set up the blob manager so can set lease to it - CloudStorageAccount storageAccount = CloudStorageAccount.DevelopmentStorageAccount; - var cloudBlobClient = storageAccount.CreateCloudBlobClient(); - CloudBlobContainer blobContainer = cloudBlobClient.GetContainerReference("test-container"); - blobContainer.CreateIfNotExists(); - var mycloudBlobDir = blobContainer.GetDirectoryReference(@"BlobManager/MyLeaseTest1"); - - var blobMgr = new DefaultBlobManager(true, mycloudBlobDir); - var device = new AzureStorageDevice(EMULATED_STORAGE_STRING, $"{TEST_CONTAINER}", "PageBlobFasterLogTestWithLease", "fasterlogLease.log", deleteOnClose: true, underLease: true, blobManager: blobMgr); - - var checkpointManager = new DeviceLogCommitCheckpointManager( - new AzureStorageNamedDeviceFactory(EMULATED_STORAGE_STRING), - new DefaultCheckpointNamingScheme($"{TEST_CONTAINER}/PageBlobFasterLogTestWithLease")); - await FasterLogTest1(logChecksum, device, checkpointManager, iteratorType); - device.Dispose(); - checkpointManager.PurgeAll(); - checkpointManager.Dispose(); - blobContainer.Delete(); + entry[i] = (byte)i; + LocalMemorylog.Enqueue(entry); + } + + // Commit to the log + LocalMemorylog.Commit(true); + + // Read the 
log just to verify was actually committed + int currentEntry = 0; + using (var iter = LocalMemorylog.Scan(0, 100_000_000)) + { + while (iter.GetNext(out byte[] result, out _, out _)) + { + Assert.IsTrue(result[currentEntry] == currentEntry, "Fail - Result[" + currentEntry.ToString() + "]: is not same as " + currentEntry.ToString()); + currentEntry++; + } } } - private async ValueTask FasterLogTest1(LogChecksumType logChecksum, IDevice device, ILogCommitManager logCommitManager, FasterLogTests.IteratorType iteratorType) + private async ValueTask FasterLogTest1(LogChecksumType logChecksum, IDevice device, ILogCommitManager logCommitManager, FasterLogTestBase.IteratorType iteratorType) { var logSettings = new FasterLogSettings { PageSizeBits = 20, SegmentSizeBits = 20, LogDevice = device, LogChecksum = logChecksum, LogCommitManager = logCommitManager }; - log = FasterLogTests.IsAsync(iteratorType) ? await FasterLog.CreateAsync(logSettings) : new FasterLog(logSettings); + log = FasterLogTestBase.IsAsync(iteratorType) ? 
await FasterLog.CreateAsync(logSettings) : new FasterLog(logSettings); byte[] entry = new byte[entryLength]; for (int i = 0; i < entryLength; i++) @@ -86,11 +113,11 @@ private async ValueTask FasterLogTest1(LogChecksumType logChecksum, IDevice devi using (var iter = log.Scan(0, long.MaxValue)) { - var counter = new FasterLogTests.Counter(log); + var counter = new FasterLogTestBase.Counter(log); switch (iteratorType) { - case FasterLogTests.IteratorType.AsyncByteVector: + case FasterLogTestBase.IteratorType.AsyncByteVector: await foreach ((byte[] result, _, _, long nextAddress) in iter.GetAsyncEnumerable()) { Assert.IsTrue(result.SequenceEqual(entry)); @@ -103,7 +130,7 @@ private async ValueTask FasterLogTest1(LogChecksumType logChecksum, IDevice devi break; } break; - case FasterLogTests.IteratorType.AsyncMemoryOwner: + case FasterLogTestBase.IteratorType.AsyncMemoryOwner: await foreach ((IMemoryOwner<byte> result, int _, long _, long nextAddress) in iter.GetAsyncEnumerable(MemoryPool<byte>.Shared)) { Assert.IsTrue(result.Memory.Span.ToArray().Take(entry.Length).SequenceEqual(entry)); @@ -117,7 +144,7 @@ private async ValueTask FasterLogTest1(LogChecksumType logChecksum, IDevice devi break; } break; - case FasterLogTests.IteratorType.Sync: + case FasterLogTestBase.IteratorType.Sync: while (iter.GetNext(out byte[] result, out _, out _)) { Assert.IsTrue(result.SequenceEqual(entry)); diff --git a/cs/test/EnqueueAndWaitForCommit.cs b/cs/test/EnqueueAndWaitForCommit.cs index 55aa8a3ab..a0d94da60 100644 --- a/cs/test/EnqueueAndWaitForCommit.cs +++ b/cs/test/EnqueueAndWaitForCommit.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
using System; -using System.IO; using System.Threading.Tasks; using FASTER.core; using NUnit.Framework; @@ -16,9 +15,9 @@ internal class EnqWaitCommitTest public FasterLog log; public IDevice device; - static readonly byte[] entry = new byte[entryLength]; - static readonly ReadOnlySpanBatch spanBatch = new ReadOnlySpanBatch(numEntries); - private string commitPath; + static byte[] entry; + static ReadOnlySpanBatch spanBatch; + private string path; public enum EnqueueIteratorType { @@ -38,26 +37,29 @@ private struct ReadOnlySpanBatch : IReadOnlySpanBatch [SetUp] public void Setup() { - commitPath = TestContext.CurrentContext.TestDirectory + "/" + TestContext.CurrentContext.Test.Name + "/"; + entry = new byte[entryLength]; + spanBatch = new(numEntries); + + path = TestUtils.MethodTestDir + "/"; // Clean up log files from previous test runs in case they weren't cleaned up - try { new DirectoryInfo(commitPath).Delete(true); } - catch { } + TestUtils.DeleteDirectory(path, wait:true); // Create devices \ log for test - device = Devices.CreateLogDevice(commitPath + "EnqueueAndWaitForCommit.log", deleteOnClose: true); + device = Devices.CreateLogDevice(path + "EnqueueAndWaitForCommit.log", deleteOnClose: true); log = new FasterLog(new FasterLogSettings { LogDevice = device }); } [TearDown] public void TearDown() { - log.Dispose(); - device.Dispose(); + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; // Clean up log files - try { new DirectoryInfo(commitPath).Delete(true); } - catch { } + TestUtils.DeleteDirectory(path); } [Test] @@ -85,8 +87,8 @@ public async ValueTask EnqueueWaitCommitBasicTest([Values] EnqueueIteratorType i int currentEntry = 0; while (iter.GetNext(out byte[] result, out _, out _)) { - Assert.IsTrue(currentEntry < entryLength); - Assert.IsTrue(result[currentEntry] == (byte)currentEntry, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " not match expected:" + currentEntry); + Assert.Less(currentEntry, 
entryLength); + Assert.AreEqual((byte)currentEntry, result[currentEntry]); currentEntry++; } diff --git a/cs/test/EnqueueTests.cs b/cs/test/EnqueueTests.cs index 24f660766..fe770d158 100644 --- a/cs/test/EnqueueTests.cs +++ b/cs/test/EnqueueTests.cs @@ -1,9 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. using System; -using System.Buffers; -using System.Collections.Generic; -using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -13,15 +10,13 @@ namespace FASTER.test { - [TestFixture] internal class EnqueueTests { private FasterLog log; private IDevice device; - static readonly byte[] entry = new byte[100]; - static readonly ReadOnlySpanBatch spanBatch = new ReadOnlySpanBatch(10000); - private string commitPath; + static byte[] entry; + private string path; public enum EnqueueIteratorType { @@ -41,44 +36,44 @@ private struct ReadOnlySpanBatch : IReadOnlySpanBatch [SetUp] public void Setup() { - - commitPath = TestContext.CurrentContext.TestDirectory + "/" + TestContext.CurrentContext.Test.Name + "/"; + entry = new byte[100]; + path = TestUtils.MethodTestDir + "/"; // Clean up log files from previous test runs in case they weren't cleaned up - if (Directory.Exists(commitPath)) - Directory.Delete(commitPath, true); - - // Create devices \ log for test - device = Devices.CreateLogDevice(commitPath + "Enqueue.log", deleteOnClose: true); - log = new FasterLog(new FasterLogSettings { LogDevice = device }); + TestUtils.DeleteDirectory(path, wait:true); } [TearDown] public void TearDown() { - log.Dispose(); - device.Dispose(); + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; // Clean up log files - if (Directory.Exists(commitPath)) - Directory.Delete(commitPath, true); + TestUtils.DeleteDirectory(path); } - [Test] [Category("FasterLog")] [Category("Smoke")] - public void EnqueueBasicTest([Values] EnqueueIteratorType iteratorType) + public void 
EnqueueBasicTest([Values] EnqueueIteratorType iteratorType, [Values] TestUtils.DeviceType deviceType) { + int entryLength = 20; - int numEntries = 1000; + int numEntries = 500; int entryFlag = 9999; + string filename = path + "Enqueue"+deviceType.ToString()+".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path }); // Needs to match what is set in TestUtils.CreateTestDevice + // Reduce SpanBatch to make sure entry fits on page if (iteratorType == EnqueueIteratorType.SpanBatch) { - entryLength = 10; - numEntries = 500; + entryLength = 5; + numEntries = 200; } // Set Default entry data @@ -124,9 +119,6 @@ public void EnqueueBasicTest([Values] EnqueueIteratorType iteratorType) // Commit to the log log.Commit(true); - // flag to make sure data has been checked - bool datacheckrun = false; - // Read the log - Look for the flag so know each entry is unique int currentEntry = 0; using (var iter = log.Scan(0, 100_000_000)) @@ -135,17 +127,14 @@ public void EnqueueBasicTest([Values] EnqueueIteratorType iteratorType) { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - // Span Batch only added first entry several times so have separate verification if (iteratorType == EnqueueIteratorType.SpanBatch) { - Assert.IsTrue(result[0] == (byte)entryFlag, "Fail - Result[0]:" + result[0].ToString() + " entryFlag:" + entryFlag); + Assert.AreEqual((byte)entryFlag, result[0]); } else { - Assert.IsTrue(result[currentEntry] == (byte)entryFlag, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " entryFlag:" + entryFlag); + Assert.AreEqual((byte)entryFlag, result[currentEntry]); } currentEntry++; @@ -153,19 +142,28 @@ public void EnqueueBasicTest([Values] EnqueueIteratorType iteratorType) } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- 
data loop after log.Scan never entered so wasn't verified. "); + // Make sure expected length (entryLength) is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); + } [Test] [Category("FasterLog")] - public async Task EnqueueAsyncBasicTest() + [Category("Smoke")] + public async Task EnqueueAsyncBasicTest([Values] TestUtils.DeviceType deviceType) { - bool datacheckrun = false; + const int expectedEntryCount = 10; + string filename = path + "EnqueueAsyncBasic" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device,SegmentSizeBits = 22, LogCommitDir = path }); + +#if WINDOWS + if (deviceType == TestUtils.DeviceType.EmulatedAzure) + return; +#endif CancellationToken cancellationToken = default; ReadOnlyMemory<byte> readOnlyMemoryEntry = entry; ReadOnlySpanBatch spanBatch = new ReadOnlySpanBatch(5); @@ -189,9 +187,6 @@ public async Task EnqueueAsyncBasicTest() while (iter.GetNext(out byte[] result, out _, out _)) { - // set check flag to show got in here - datacheckrun = true; - // Verify based on which input read switch (currentEntry) { @@ -217,9 +212,9 @@ public async Task EnqueueAsyncBasicTest() } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. 
"); + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(expectedEntryCount, currentEntry); + } } diff --git a/cs/test/FASTER.test.csproj b/cs/test/FASTER.test.csproj index 3507d6ef7..f5d51ae7e 100644 --- a/cs/test/FASTER.test.csproj +++ b/cs/test/FASTER.test.csproj @@ -21,7 +21,7 @@ <DelaySign>false</DelaySign> <DocumentationFile>bin\$(Platform)\$(Configuration)\$(TargetFramework)\$(AssemblyName).xml</DocumentationFile> </PropertyGroup> - + <PropertyGroup Condition="'$(Configuration)' == 'Debug'"> <DefineConstants>TRACE;DEBUG</DefineConstants> <DebugType>full</DebugType> @@ -38,8 +38,23 @@ <PropertyGroup> <NoWarn>1701;1702;1591</NoWarn> </PropertyGroup> - - <ItemGroup> + + <PropertyGroup> + <IsWindows Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Windows)))' == 'true'">true</IsWindows> + <IsOSX Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::OSX)))' == 'true'">true</IsOSX> + <IsLinux Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Linux)))' == 'true'">true</IsLinux> + </PropertyGroup> + <PropertyGroup Condition="'$(IsWindows)'=='true'"> + <DefineConstants>WINDOWS</DefineConstants> + </PropertyGroup> + <PropertyGroup Condition="'$(IsOSX)'=='true'"> + <DefineConstants>OSX</DefineConstants> + </PropertyGroup> + <PropertyGroup Condition="'$(IsLinux)'=='true'"> + <DefineConstants>LINUX</DefineConstants> + </PropertyGroup> + + <ItemGroup> <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.1" /> <PackageReference Include="NUnit" Version="3.12.0" /> <PackageReference Include="NUnit3TestAdapter" Version="3.17.0"> diff --git a/cs/test/FasterLogAndDeviceConfigTests.cs b/cs/test/FasterLogAndDeviceConfigTests.cs index 934973216..7f20b848e 100644 --- 
a/cs/test/FasterLogAndDeviceConfigTests.cs +++ b/cs/test/FasterLogAndDeviceConfigTests.cs @@ -1,48 +1,32 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -using System; -using System.Buffers; -using System.Collections.Generic; using System.IO; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; using FASTER.core; using NUnit.Framework; - namespace FASTER.test { - //* NOTE: //* A lot of various usage of Log config and Device config are in FasterLog.cs so the test here //* is for areas / parameters not covered by the tests in other areas of the test system //* For completeness, setting other parameters too where possible //* However, the verification is pretty light. Just makes sure log file created and things be added and read from it - - - - [TestFixture] internal class LogAndDeviceConfigTests { private FasterLog log; private IDevice device; - private string path = Path.GetTempPath() + "DeviceConfigTests/"; + private string path; static readonly byte[] entry = new byte[100]; - [SetUp] public void Setup() { + path = TestUtils.MethodTestDir + "/"; + // Clean up log files from previous test runs in case they weren't cleaned up - try - { - if (Directory.Exists(path)) - Directory.Delete(path, true); - } - catch {} + TestUtils.DeleteDirectory(path, wait:true); // Create devices \ log for test device = Devices.CreateLogDevice(path + "DeviceConfig", deleteOnClose: true, recoverDevice: true, preallocateFile: true, capacity: 1 << 30); @@ -52,24 +36,20 @@ public void Setup() [TearDown] public void TearDown() { - log.Dispose(); - device.Dispose(); + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; // Clean up log files - try - { - if (Directory.Exists(path)) - Directory.Delete(path, true); - } - catch { } + TestUtils.DeleteDirectory(path); } - [Test] [Category("FasterLog")] + [Category("Smoke")] public void DeviceAndLogConfig() { - int entryLength = 10; // Set Default entry data @@ 
-92,15 +72,10 @@ public void DeviceAndLogConfig() { while (iter.GetNext(out byte[] result, out _, out _)) { - Assert.IsTrue(result[currentEntry] == currentEntry, "Fail - Result[" + currentEntry.ToString() + "]: is not same as "+currentEntry.ToString() ); - + Assert.AreEqual(currentEntry, result[currentEntry]); currentEntry++; } } } - } - } - - diff --git a/cs/test/FasterLogRecoverReadOnlyTests.cs b/cs/test/FasterLogRecoverReadOnlyTests.cs index 44f2a683d..734e6bc2c 100644 --- a/cs/test/FasterLogRecoverReadOnlyTests.cs +++ b/cs/test/FasterLogRecoverReadOnlyTests.cs @@ -27,10 +27,12 @@ public class FasterLogRecoverReadOnlyTests [SetUp] public void Setup() { - path = Path.GetTempPath() + "RecoverReadOnlyTest/"; + path = TestUtils.MethodTestDir + "/"; deviceName = path + "testlog"; - if (Directory.Exists(path)) - TestUtils.DeleteDirectory(path); + + // Clean up log files from previous test runs in case they weren't cleaned up + TestUtils.DeleteDirectory(path, wait:true); + cts = new CancellationTokenSource(); done = new SemaphoreSlim(0); } @@ -38,8 +40,11 @@ public void Setup() [TearDown] public void TearDown() { + cts?.Dispose(); + cts = default; + done?.Dispose(); + done = default; TestUtils.DeleteDirectory(path); - cts.Dispose(); } [Test] diff --git a/cs/test/FasterLogResumeTests.cs b/cs/test/FasterLogResumeTests.cs index 6750e1da1..8ba7a7d21 100644 --- a/cs/test/FasterLogResumeTests.cs +++ b/cs/test/FasterLogResumeTests.cs @@ -9,33 +9,30 @@ using NUnit.Framework; using System.Threading; - namespace FASTER.test { [TestFixture] internal class FasterLogResumeTests { private IDevice device; - private string commitPath; + private string path; [SetUp] public void Setup() { - commitPath = TestContext.CurrentContext.TestDirectory + "/" + TestContext.CurrentContext.Test.Name + "/"; + path = TestUtils.MethodTestDir + "/"; - if (Directory.Exists(commitPath)) - TestUtils.DeleteDirectory(commitPath); + TestUtils.DeleteDirectory(path, wait:true); - device = 
Devices.CreateLogDevice(commitPath + "fasterlog.log", deleteOnClose: true); + device = Devices.CreateLogDevice(path + "fasterlog.log", deleteOnClose: true); } [TearDown] public void TearDown() { - device.Dispose(); - - if (Directory.Exists(commitPath)) - TestUtils.DeleteDirectory(commitPath); + device?.Dispose(); + device = null; + TestUtils.DeleteDirectory(path); } [Test] @@ -71,6 +68,39 @@ public async Task FasterLogResumePersistedReaderSpec([Values] LogChecksumType lo } } + [Test] + [Category("FasterLog")] + public async Task FasterLogResumeViaCompleteUntilRecordAtSpec([Values] LogChecksumType logChecksum) + { + CancellationToken cancellationToken = default; + + var input1 = new byte[] { 0, 1, 2, 3 }; + var input2 = new byte[] { 4, 5, 6, 7, 8, 9, 10 }; + var input3 = new byte[] { 11, 12 }; + string readerName = "abc"; + + using (var l = new FasterLog(new FasterLogSettings { LogDevice = device, PageSizeBits = 16, MemorySizeBits = 16, LogChecksum = logChecksum })) + { + await l.EnqueueAsync(input1, cancellationToken); + await l.EnqueueAsync(input2); + await l.EnqueueAsync(input3); + await l.CommitAsync(); + + using var originalIterator = l.Scan(0, long.MaxValue, readerName); + Assert.IsTrue(originalIterator.GetNext(out _, out _, out long recordAddress, out _)); + await originalIterator.CompleteUntilRecordAtAsync(recordAddress); + Assert.IsTrue(originalIterator.GetNext(out _, out _, out _, out _)); // move the reader ahead + await l.CommitAsync(); + } + + using (var l = new FasterLog(new FasterLogSettings { LogDevice = device, PageSizeBits = 16, MemorySizeBits = 16, LogChecksum = logChecksum })) + { + using var recoveredIterator = l.Scan(0, long.MaxValue, readerName); + Assert.IsTrue(recoveredIterator.GetNext(out byte[] outBuf, out _, out _, out _)); + Assert.True(input2.SequenceEqual(outBuf)); // we should have read in input2, not input1 or input3 + } + } + [Test] [Category("FasterLog")] public async Task FasterLogResumePersistedReader2([Values] LogChecksumType 
logChecksum, [Values] bool overwriteLogCommits, [Values] bool removeOutdated) @@ -80,7 +110,7 @@ public async Task FasterLogResumePersistedReader2([Values] LogChecksumType logCh var input3 = new byte[] { 11, 12 }; string readerName = "abc"; - using (var logCommitManager = new DeviceLogCommitCheckpointManager(new LocalStorageNamedDeviceFactory(), new DefaultCheckpointNamingScheme(commitPath), overwriteLogCommits, removeOutdated)) + using (var logCommitManager = new DeviceLogCommitCheckpointManager(new LocalStorageNamedDeviceFactory(), new DefaultCheckpointNamingScheme(path), overwriteLogCommits, removeOutdated)) { long originalCompleted; @@ -124,7 +154,7 @@ public async Task FasterLogResumePersistedReader3([Values] LogChecksumType logCh var input3 = new byte[] { 11, 12 }; string readerName = "abcd"; - using (var logCommitManager = new DeviceLogCommitCheckpointManager(new LocalStorageNamedDeviceFactory(), new DefaultCheckpointNamingScheme(commitPath), overwriteLogCommits, removeOutdated)) + using (var logCommitManager = new DeviceLogCommitCheckpointManager(new LocalStorageNamedDeviceFactory(), new DefaultCheckpointNamingScheme(path), overwriteLogCommits, removeOutdated)) { long originalCompleted; diff --git a/cs/test/FasterLogScanTests.cs b/cs/test/FasterLogScanTests.cs index c8bded883..1a2d82dab 100644 --- a/cs/test/FasterLogScanTests.cs +++ b/cs/test/FasterLogScanTests.cs @@ -1,54 +1,57 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-using System; -using System.Buffers; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; using FASTER.core; using NUnit.Framework; -using System.Text; namespace FASTER.test { - [TestFixture] internal class FasterLogScanTests { - private FasterLog log; private IDevice device; private FasterLog logUncommitted; private IDevice deviceUnCommitted; - private string path = Path.GetTempPath() + "ScanTests/"; - static readonly byte[] entry = new byte[100]; - static int entryLength = 100; - static int numEntries = 1000; - static int entryFlag = 9999; + private string path; + static byte[] entry; + const int entryLength = 100; + const int numEntries = 1000; + static readonly int entryFlag = 9999; // Create and populate the log file so can do various scans [SetUp] public void Setup() { + entry = new byte[100]; + path = TestUtils.MethodTestDir + "/"; + // Clean up log files from previous test runs in case they weren't cleaned up - try { new DirectoryInfo(path).Delete(true); } - catch {} + TestUtils.DeleteDirectory(path, wait:true); + } - // Set up the Devices \ logs - device = Devices.CreateLogDevice(path + "LogScan", deleteOnClose: true); - log = new FasterLog(new FasterLogSettings { LogDevice = device }); - deviceUnCommitted = Devices.CreateLogDevice(path + "LogScanUncommitted", deleteOnClose: true); - logUncommitted = new FasterLog(new FasterLogSettings { LogDevice = deviceUnCommitted }); + [TearDown] + public void TearDown() + { + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; + deviceUnCommitted?.Dispose(); + deviceUnCommitted = null; + logUncommitted?.Dispose(); + logUncommitted = null; + // Clean up log files + TestUtils.DeleteDirectory(path); + } + + public void PopulateLog(FasterLog log) + { //****** Populate log for Basic data for tests // Set Default entry data for (int i = 0; i < entryLength; i++) - { entry[i] = (byte)i; - } // Enqueue but set each Entry in a way that 
can differentiate between entries for (int i = 0; i < numEntries; i++) @@ -67,14 +70,14 @@ public void Setup() // Commit to the log log.Commit(true); + } - + public void PopulateUncommittedLog(FasterLog logUncommitted) + { //****** Populate uncommitted log / device for ScanUncommittedTest // Set Default entry data for (int j = 0; j < entryLength; j++) - { entry[j] = (byte)j; - } // Enqueue but set each Entry in a way that can differentiate between entries for (int j = 0; j < numEntries; j++) @@ -93,33 +96,22 @@ public void Setup() // refresh uncommitted so can see it when scan - do NOT commit though logUncommitted.RefreshUncommitted(true); - - } - - [TearDown] - public void TearDown() - { - log.Dispose(); - device.Dispose(); - deviceUnCommitted.Dispose(); - logUncommitted.Dispose(); - - // Clean up log files - try { new DirectoryInfo(path).Delete(true); } - catch { } } [Test] [Category("FasterLog")] - public void ScanBasicDefaultTest() + [Category("Smoke")] + public void ScanBasicDefaultTest([Values] TestUtils.DeviceType deviceType) { + // Create log and device here (not in setup) because using DeviceType Enum which can't be used in Setup + string filename = path + "LogScanDefault" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path }); + PopulateLog(log); // Basic default scan from start to end // Indirectly used in other tests, but good to have the basic test here for completeness - // flag to make sure data has been checked - bool datacheckrun = false; - // Read the log - Look for the flag so know each entry is unique int currentEntry = 0; using (var iter = log.Scan(0, 100_000_000)) @@ -128,31 +120,29 @@ public void ScanBasicDefaultTest() { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - // Span Batch only added first entry several times so have separate 
verification - Assert.IsTrue(result[currentEntry] == (byte)entryFlag, "Fail - Result["+ currentEntry.ToString() + "]:" + result[0].ToString() + " entryFlag:" + entryFlag); - + Assert.AreEqual((byte)entryFlag, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. "); + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); + } [Test] [Category("FasterLog")] - public void ScanNoDefaultTest() + public void ScanNoDefaultTest([Values] TestUtils.DeviceType deviceType) { - // Test where all params are set just to make sure handles it ok - // flag to make sure data has been checked - bool datacheckrun = false; + // Create log and device here (not in setup) because using DeviceType Enum which can't be used in Setup + string filename = path + "LogScanNoDefault" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path }); + PopulateLog(log); // Read the log - Look for the flag so know each entry is unique int currentEntry = 0; @@ -162,32 +152,29 @@ public void ScanNoDefaultTest() { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - // Span Batch only added first entry several times so have separate verification - Assert.IsTrue(result[currentEntry] == (byte)entryFlag, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " entryFlag:" + entryFlag); - + Assert.AreEqual((byte)entryFlag, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. 
"); + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); } - [Test] [Category("FasterLog")] - public void ScanByNameTest() + [Category("Smoke")] + public void ScanByNameTest([Values] TestUtils.DeviceType deviceType) { - //You can persist iterators(or more precisely, their CompletedUntilAddress) as part of a commit by simply naming them during their creation. - // flag to make sure data has been checked - bool datacheckrun = false; + // Create log and device here (not in setup) because using DeviceType Enum which can't be used in Setup + string filename = path + "LogScanByName" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path }); + PopulateLog(log); // Read the log - Look for the flag so know each entry is unique int currentEntry = 0; @@ -197,31 +184,29 @@ public void ScanByNameTest() { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - // Span Batch only added first entry several times so have separate verification - Assert.IsTrue(result[currentEntry] == (byte)entryFlag, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " entryFlag:" + entryFlag); - + Assert.AreEqual((byte)entryFlag, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. 
"); + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); } - [Test] [Category("FasterLog")] - public void ScanWithoutRecoverTest() + [Category("Smoke")] + public void ScanWithoutRecoverTest([Values] TestUtils.DeviceType deviceType) { // You may also force an iterator to start at the specified begin address, i.e., without recovering: recover parameter = false - // flag to make sure data has been checked - bool datacheckrun = false; + // Create log and device here (not in setup) because using DeviceType Enum which can't be used in Setup + string filename = path + "LogScanWithoutRecover" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path }); + PopulateLog(log); // Read the log int currentEntry = 9; // since starting at specified address of 1000, need to set current entry as 9 so verification starts at proper spot @@ -231,30 +216,29 @@ public void ScanWithoutRecoverTest() { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - // Span Batch only added first entry several times so have separate verification - Assert.IsTrue(result[currentEntry] == (byte)entryFlag, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " entryFlag:" + entryFlag); - + Assert.AreEqual((byte)entryFlag, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. 
"); + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); } [Test] [Category("FasterLog")] - public void ScanBufferingModeDoublePageTest() + [Category("Smoke")] + public void ScanBufferingModeDoublePageTest([Values] TestUtils.DeviceType deviceType) { // Same as default, but do it just to make sure have test in case default changes - // flag to make sure data has been checked - bool datacheckrun = false; + // Create log and device here (not in setup) because using DeviceType Enum which can't be used in Setup + string filename = path + "LogScanDoublePage" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path }); + PopulateLog(log); // Read the log - Look for the flag so know each entry is unique int currentEntry = 0; @@ -264,29 +248,27 @@ public void ScanBufferingModeDoublePageTest() { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - // Span Batch only added first entry several times so have separate verification - Assert.IsTrue(result[currentEntry] == (byte)entryFlag, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " entryFlag:" + entryFlag); - + Assert.AreEqual((byte)entryFlag, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. 
"); + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); } [Test] [Category("FasterLog")] - public void ScanBufferingModeSinglePageTest() + [Category("Smoke")] + public void ScanBufferingModeSinglePageTest([Values] TestUtils.DeviceType deviceType) { - - // flag to make sure data has been checked - bool datacheckrun = false; + // Create log and device here (not in setup) because using DeviceType Enum which can't be used in Setup + string filename = path + "LogScanSinglePage" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path }); + PopulateLog(log); // Read the log - Look for the flag so know each entry is unique int currentEntry = 0; @@ -296,58 +278,46 @@ public void ScanBufferingModeSinglePageTest() { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - // Span Batch only added first entry several times so have separate verification - Assert.IsTrue(result[currentEntry] == (byte)entryFlag, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " entryFlag:" + entryFlag); - + Assert.AreEqual((byte)entryFlag, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. 
"); + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); } - [Test] [Category("FasterLog")] - public void ScanUncommittedTest() + [Category("Smoke")] + public void ScanUncommittedTest([Values] TestUtils.DeviceType deviceType) { - - // flag to make sure data has been checked - bool datacheckrun = false; + // Create log and device here (not in setup) because using DeviceType Enum which can't be used in Setup + string filename = path + "LogScan" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path }); + PopulateUncommittedLog(log); // Setting scanUnCommitted to true is actual test here. // Read the log - Look for the flag so know each entry is unique and still reads uncommitted int currentEntry = 0; - using (var iter = logUncommitted.Scan(0, 100_000_000, scanUncommitted: true)) + using (var iter = log.Scan(0, 100_000_000, scanUncommitted: true)) { while (iter.GetNext(out byte[] result, out _, out _)) { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - // Span Batch only added first entry several times so have separate verification - Assert.IsTrue(result[currentEntry] == (byte)entryFlag, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " entryFlag:" + entryFlag); - + Assert.AreEqual((byte)entryFlag, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. 
"); + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); } - - } } - - diff --git a/cs/test/FasterLogTests.cs b/cs/test/FasterLogTests.cs index a92fc61f8..b3918bc86 100644 --- a/cs/test/FasterLogTests.cs +++ b/cs/test/FasterLogTests.cs @@ -16,15 +16,21 @@ namespace FASTER.test [TestFixture] internal class FasterLogStandAloneTests { - [Test] - public void TestDisposeReleasesFileLocksWithInprogressCommit() + [Category("FasterLog")] + [Category("Smoke")] + + public void TestDisposeReleasesFileLocksWithInprogressCommit([Values] TestUtils.DeviceType deviceType) { - string commitPath = TestContext.CurrentContext.TestDirectory + "/" + TestContext.CurrentContext.Test.Name + "/"; - DirectoryInfo di = Directory.CreateDirectory(commitPath); - IDevice device = Devices.CreateLogDevice(commitPath + "testDisposeReleasesFileLocksWithInprogressCommit.log", preallocateFile: true, deleteOnClose: false); - FasterLog fasterLog = new FasterLog(new FasterLogSettings { LogDevice = device, LogChecksum = LogChecksumType.PerEntry }); + string path = TestUtils.MethodTestDir + "/"; + string filename = path + "TestDisposeRelease" + deviceType.ToString() + ".log"; + + DirectoryInfo di = Directory.CreateDirectory(path); + IDevice device = TestUtils.CreateTestDevice(deviceType, filename); + FasterLog fasterLog = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path, LogChecksum = LogChecksumType.PerEntry }); + Assert.IsTrue(fasterLog.TryEnqueue(new byte[100], out long beginAddress)); + fasterLog.Commit(spinWait: false); fasterLog.Dispose(); device.Dispose(); @@ -40,21 +46,21 @@ public void TestDisposeReleasesFileLocksWithInprogressCommit() } } - [TestFixture] - internal class FasterLogTests + // This test base class allows splitting up the tests into separate fixtures that can be run in parallel + internal class FasterLogTestBase { - const int 
entryLength = 100; - const int numEntries = 100000;//1000000; - const int numSpanEntries = 500; // really slows down if go too many - private FasterLog log; - private IDevice device; - private string commitPath; - private DeviceLogCommitCheckpointManager manager; - - static readonly byte[] entry = new byte[100]; - static readonly ReadOnlySpanBatch spanBatch = new ReadOnlySpanBatch(10000); - - private struct ReadOnlySpanBatch : IReadOnlySpanBatch + protected const int entryLength = 100; + protected const int numEntries = 100000;//1000000; + protected const int numSpanEntries = 500; // really slows down if go too many + protected FasterLog log; + protected IDevice device; + protected string path; + protected DeviceLogCommitCheckpointManager manager; + + protected static readonly byte[] entry = new byte[100]; + protected static readonly ReadOnlySpanBatch spanBatch = new ReadOnlySpanBatch(10000); + + protected struct ReadOnlySpanBatch : IReadOnlySpanBatch { private readonly int batchSize; public ReadOnlySpanBatch(int batchSize) => this.batchSize = batchSize; @@ -62,41 +68,26 @@ private struct ReadOnlySpanBatch : IReadOnlySpanBatch public int TotalEntries() => batchSize; } - [SetUp] - public void Setup() + protected void BaseSetup() { - commitPath = TestContext.CurrentContext.TestDirectory + "/" + TestContext.CurrentContext.Test.Name + "/"; + path = TestUtils.MethodTestDir + "/"; // Clean up log files from previous test runs in case they weren't cleaned up - try - { - if (Directory.Exists(commitPath)) - Directory.Delete(commitPath, true); - } - catch { } - - device = Devices.CreateLogDevice(commitPath + "fasterlog.log", deleteOnClose: true); - manager = new DeviceLogCommitCheckpointManager(new LocalStorageNamedDeviceFactory(deleteOnClose: true), new DefaultCheckpointNamingScheme(commitPath)); + TestUtils.DeleteDirectory(path, wait: true); - } + manager = new DeviceLogCommitCheckpointManager(new LocalStorageNamedDeviceFactory(deleteOnClose: true), new 
DefaultCheckpointNamingScheme(path)); + } - [TearDown] - public void TearDown() + protected void BaseTearDown() { - if (log is not null) - log.Dispose(); - manager.Dispose(); - device.Dispose(); - - // Saw timing issues on release build where fasterlog.log was not quite freed up before deleting which caused long delays - Thread.Sleep(1000); - try - { - if (Directory.Exists(commitPath)) - Directory.Delete(commitPath, true); - } - catch { } - + log?.Dispose(); + log = null; + manager?.Dispose(); + manager = null; + device?.Dispose(); + device = null; + + TestUtils.DeleteDirectory(path); } internal class Counter @@ -129,7 +120,7 @@ public enum IteratorType internal static bool IsAsync(IteratorType iterType) => iterType == IteratorType.AsyncByteVector || iterType == IteratorType.AsyncMemoryOwner; - private async ValueTask AssertGetNext(IAsyncEnumerator<(byte[] entry, int entryLength, long currentAddress, long nextAddress)> asyncByteVectorIter, + protected async ValueTask AssertGetNext(IAsyncEnumerator<(byte[] entry, int entryLength, long currentAddress, long nextAddress)> asyncByteVectorIter, IAsyncEnumerator<(IMemoryOwner<byte> entry, int entryLength, long currentAddress, long nextAddress)> asyncMemoryOwnerIter, FasterLogScanIterator iter, byte[] expectedData = default, bool verifyAtEnd = false) { @@ -141,7 +132,7 @@ private async ValueTask AssertGetNext(IAsyncEnumerator<(byte[] entry, int entryL // MoveNextAsync() would hang here waiting for more entries if (verifyAtEnd) - Assert.IsTrue(asyncByteVectorIter.Current.nextAddress == log.TailAddress); + Assert.AreEqual(log.TailAddress, asyncByteVectorIter.Current.nextAddress); return; } @@ -154,7 +145,7 @@ private async ValueTask AssertGetNext(IAsyncEnumerator<(byte[] entry, int entryL // MoveNextAsync() would hang here waiting for more entries if (verifyAtEnd) - Assert.IsTrue(asyncMemoryOwnerIter.Current.nextAddress == log.TailAddress); + Assert.AreEqual(log.TailAddress, asyncMemoryOwnerIter.Current.nextAddress); return; 
} @@ -165,10 +156,34 @@ private async ValueTask AssertGetNext(IAsyncEnumerator<(byte[] entry, int entryL Assert.IsFalse(iter.GetNext(out _, out _, out _)); } + protected static async Task LogWriterAsync(FasterLog log, byte[] entry) + { + CancellationTokenSource cts = new CancellationTokenSource(); + CancellationToken token = cts.Token; + + // Enter in some entries then wait on this separate thread + await log.EnqueueAsync(entry); + await log.EnqueueAsync(entry); + var commitTask = await log.CommitAsync(null, token); + await log.EnqueueAsync(entry); + await log.CommitAsync(commitTask, token); + } + } + + [TestFixture] + internal class FasterLogGeneralTests : FasterLogTestBase + { + [SetUp] + public void Setup() => base.BaseSetup(); + + [TearDown] + public void TearDown() => base.BaseTearDown(); + [Test] [Category("FasterLog")] - public async ValueTask FasterLogTest1([Values]LogChecksumType logChecksum, [Values]IteratorType iteratorType) + public async ValueTask FasterLogTest1([Values] LogChecksumType logChecksum, [Values] IteratorType iteratorType) { + device = Devices.CreateLogDevice(path + "fasterlog.log", deleteOnClose: true); var logSettings = new FasterLogSettings { LogDevice = device, LogChecksum = logChecksum, LogCommitManager = manager }; log = IsAsync(iteratorType) ? 
await FasterLog.CreateAsync(logSettings) : new FasterLog(logSettings); @@ -214,16 +229,28 @@ public async ValueTask FasterLogTest1([Values]LogChecksumType logChecksum, [Valu Assert.Fail("Unknown IteratorType"); break; } - Assert.IsTrue(counter.count == numEntries); + Assert.AreEqual(numEntries, counter.count); } + } + + + [TestFixture] + internal class FasterLogEnqueueTests : FasterLogTestBase + { + [SetUp] + public void Setup() => base.BaseSetup(); + + [TearDown] + public void TearDown() => base.BaseTearDown(); [Test] [Category("FasterLog")] - public async ValueTask TryEnqueue1([Values]LogChecksumType logChecksum, [Values]IteratorType iteratorType) + public async ValueTask TryEnqueue1([Values] LogChecksumType logChecksum, [Values] IteratorType iteratorType) { CancellationTokenSource cts = new CancellationTokenSource(); CancellationToken token = cts.Token; + device = Devices.CreateLogDevice(path + "fasterlog.log", deleteOnClose: true); var logSettings = new FasterLogSettings { LogDevice = device, LogChecksum = logChecksum, LogCommitManager = manager }; log = IsAsync(iteratorType) ? 
await FasterLog.CreateAsync(logSettings) : new FasterLog(logSettings); @@ -254,19 +281,24 @@ public async ValueTask TryEnqueue1([Values]LogChecksumType logChecksum, [Values] Assert.IsFalse(waitingReader.IsCompleted); await log.CommitAsync(token); - while (!waitingReader.IsCompleted); + while (!waitingReader.IsCompleted) ; Assert.IsTrue(waitingReader.IsCompleted); - await AssertGetNext(asyncByteVectorIter, asyncMemoryOwnerIter, iter, data1, verifyAtEnd :true); + await AssertGetNext(asyncByteVectorIter, asyncMemoryOwnerIter, iter, data1, verifyAtEnd: true); } } } [Test] [Category("FasterLog")] - public async ValueTask TryEnqueue2([Values]LogChecksumType logChecksum, [Values]IteratorType iteratorType) + [Category("Smoke")] + + public async ValueTask TryEnqueue2([Values] LogChecksumType logChecksum, [Values] IteratorType iteratorType, [Values] TestUtils.DeviceType deviceType) { - var logSettings = new FasterLogSettings { LogDevice = device, PageSizeBits = 14, LogChecksum = logChecksum, LogCommitManager = manager }; + string filename = path + "TryEnqueue2" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + + var logSettings = new FasterLogSettings { LogDevice = device, PageSizeBits = 14, LogChecksum = logChecksum, LogCommitManager = manager, SegmentSizeBits = 22 }; log = IsAsync(iteratorType) ? 
await FasterLog.CreateAsync(logSettings) : new FasterLog(logSettings); const int dataLength = 10000; @@ -319,12 +351,27 @@ public async ValueTask TryEnqueue2([Values]LogChecksumType logChecksum, [Values] break; } } + } + + [TestFixture] + internal class FasterLogTruncateTests : FasterLogTestBase + { + [SetUp] + public void Setup() => base.BaseSetup(); + + [TearDown] + public void TearDown() => base.BaseTearDown(); [Test] [Category("FasterLog")] - public async ValueTask TruncateUntilBasic([Values]LogChecksumType logChecksum, [Values]IteratorType iteratorType) + [Category("Smoke")] + + public async ValueTask TruncateUntilBasic([Values] LogChecksumType logChecksum, [Values] IteratorType iteratorType, [Values] TestUtils.DeviceType deviceType) { - var logSettings = new FasterLogSettings { LogDevice = device, PageSizeBits = 14, LogChecksum = logChecksum, LogCommitManager = manager }; + string filename = path + "TruncateUntilBasic" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + + var logSettings = new FasterLogSettings { LogDevice = device, PageSizeBits = 14, LogChecksum = logChecksum, LogCommitManager = manager, SegmentSizeBits = 22 }; log = IsAsync(iteratorType) ? await FasterLog.CreateAsync(logSettings) : new FasterLog(logSettings); byte[] data1 = new byte[100]; @@ -335,11 +382,11 @@ public async ValueTask TruncateUntilBasic([Values]LogChecksumType logChecksum, [ log.Enqueue(data1); } - Assert.IsTrue(log.CommittedUntilAddress == log.BeginAddress); + Assert.AreEqual(log.BeginAddress, log.CommittedUntilAddress); await log.CommitAsync(); - Assert.IsTrue(log.CommittedUntilAddress == log.TailAddress); - Assert.IsTrue(log.CommittedBeginAddress == log.BeginAddress); + Assert.AreEqual(log.TailAddress, log.CommittedUntilAddress); + Assert.AreEqual(log.BeginAddress, log.CommittedBeginAddress); using var iter = log.Scan(0, long.MaxValue); var asyncByteVectorIter = iteratorType == IteratorType.AsyncByteVector ? 
iter.GetAsyncEnumerable().GetAsyncEnumerator() : default; @@ -349,25 +396,29 @@ public async ValueTask TruncateUntilBasic([Values]LogChecksumType logChecksum, [ log.TruncateUntil(iter.NextAddress); - Assert.IsTrue(log.CommittedUntilAddress == log.TailAddress); - Assert.IsTrue(log.CommittedBeginAddress < log.BeginAddress); - Assert.IsTrue(iter.NextAddress == log.BeginAddress); + Assert.AreEqual(log.TailAddress, log.CommittedUntilAddress); + Assert.Less(log.CommittedBeginAddress, log.BeginAddress); + Assert.AreEqual(log.BeginAddress, iter.NextAddress); await log.CommitAsync(); - Assert.IsTrue(log.CommittedUntilAddress == log.TailAddress); - Assert.IsTrue(log.CommittedBeginAddress == log.BeginAddress); + Assert.AreEqual(log.TailAddress, log.CommittedUntilAddress); + Assert.AreEqual(log.BeginAddress, log.CommittedBeginAddress); } [Test] [Category("FasterLog")] - public async ValueTask EnqueueAndWaitForCommitAsyncBasicTest([Values]LogChecksumType logChecksum) + [Category("Smoke")] + + public async ValueTask EnqueueAndWaitForCommitAsyncBasicTest([Values] LogChecksumType logChecksum, [Values] TestUtils.DeviceType deviceType) { CancellationToken cancellationToken = default; ReadOnlySpanBatch spanBatch = new ReadOnlySpanBatch(numSpanEntries); - log = new FasterLog(new FasterLogSettings { LogDevice = device, PageSizeBits = 16, MemorySizeBits = 16, LogChecksum = logChecksum, LogCommitManager = manager }); + string filename = path + "EnqueueAndWaitForCommitAsyncBasicTest" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, PageSizeBits = 16, MemorySizeBits = 16, LogChecksum = logChecksum, LogCommitManager = manager, SegmentSizeBits = 22 }); int headerSize = logChecksum == LogChecksumType.None ? 
4 : 12; bool _disposed = false; @@ -405,8 +456,9 @@ public async ValueTask EnqueueAndWaitForCommitAsyncBasicTest([Values]LogChecksum [Test] [Category("FasterLog")] - public async ValueTask TruncateUntil2([Values] LogChecksumType logChecksum, [Values]IteratorType iteratorType) + public async ValueTask TruncateUntil2([Values] LogChecksumType logChecksum, [Values] IteratorType iteratorType) { + device = Devices.CreateLogDevice(path + "fasterlog.log", deleteOnClose: true); var logSettings = new FasterLogSettings { LogDevice = device, MemorySizeBits = 20, PageSizeBits = 14, LogChecksum = logChecksum, LogCommitManager = manager }; log = IsAsync(iteratorType) ? await FasterLog.CreateAsync(logSettings) : new FasterLog(logSettings); @@ -418,9 +470,9 @@ public async ValueTask TruncateUntil2([Values] LogChecksumType logChecksum, [Val log.Enqueue(data1); } log.RefreshUncommitted(); - Assert.IsTrue(log.SafeTailAddress == log.TailAddress); + Assert.AreEqual(log.TailAddress, log.SafeTailAddress); - Assert.IsTrue(log.CommittedUntilAddress < log.SafeTailAddress); + Assert.Less(log.CommittedUntilAddress, log.SafeTailAddress); using var iter = log.Scan(0, long.MaxValue, scanUncommitted: true); var asyncByteVectorIter = iteratorType == IteratorType.AsyncByteVector ? 
iter.GetAsyncEnumerable().GetAsyncEnumerator() : default; @@ -431,7 +483,7 @@ public async ValueTask TruncateUntil2([Values] LogChecksumType logChecksum, [Val case IteratorType.Sync: while (iter.GetNext(out _, out _, out _)) log.TruncateUntil(iter.NextAddress); - Assert.IsTrue(iter.NextAddress == log.SafeTailAddress); + Assert.AreEqual(log.SafeTailAddress, iter.NextAddress); break; case IteratorType.AsyncByteVector: { @@ -474,6 +526,7 @@ public async ValueTask TruncateUntil2([Values] LogChecksumType logChecksum, [Val [Category("FasterLog")] public async ValueTask TruncateUntilPageStart([Values] LogChecksumType logChecksum, [Values] IteratorType iteratorType) { + device = Devices.CreateLogDevice(path + "fasterlog.log", deleteOnClose: true); log = new FasterLog(new FasterLogSettings { LogDevice = device, MemorySizeBits = 20, PageSizeBits = 14, LogChecksum = logChecksum, LogCommitManager = manager }); byte[] data1 = new byte[1000]; for (int i = 0; i < 100; i++) data1[i] = (byte)i; @@ -483,9 +536,9 @@ public async ValueTask TruncateUntilPageStart([Values] LogChecksumType logChecks log.Enqueue(data1); } log.RefreshUncommitted(); - Assert.IsTrue(log.SafeTailAddress == log.TailAddress); + Assert.AreEqual(log.TailAddress, log.SafeTailAddress); - Assert.IsTrue(log.CommittedUntilAddress < log.SafeTailAddress); + Assert.Less(log.CommittedUntilAddress, log.SafeTailAddress); using (var iter = log.Scan(0, long.MaxValue, scanUncommitted: true)) { @@ -497,7 +550,7 @@ public async ValueTask TruncateUntilPageStart([Values] LogChecksumType logChecks case IteratorType.Sync: while (iter.GetNext(out _, out _, out _)) log.TruncateUntilPageStart(iter.NextAddress); - Assert.IsTrue(iter.NextAddress == log.SafeTailAddress); + Assert.AreEqual(log.SafeTailAddress, iter.NextAddress); break; case IteratorType.AsyncByteVector: { @@ -539,9 +592,12 @@ public async ValueTask TruncateUntilPageStart([Values] LogChecksumType logChecks [Test] [Category("FasterLog")] - public void CommitNoSpinWait() + 
[Category("Smoke")] + public void CommitNoSpinWait([Values] TestUtils.DeviceType deviceType) { - log = new FasterLog(new FasterLogSettings { LogDevice = device, LogCommitManager = manager }); + string filename = path + "CommitNoSpinWait" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, LogCommitManager = manager, SegmentSizeBits = 22 }); int commitFalseEntries = 100; @@ -554,15 +610,18 @@ public void CommitNoSpinWait() log.Enqueue(entry); } - // Main point of the test ... If true, spin-wait until commit completes. Otherwise, issue commit and return immediately. - // There won't be that much difference from True to False here as the True case is so quick. However, it is a good basic check - // to make sure it isn't crashing and that it does actually commit it - // Seen timing issues on CI machine when doing false to true ... so just take a second to let it settle - log.Commit(false); - Thread.Sleep(4000); + //******* + // Main point of the test ... If commit(true) (like other tests do) it waits until commit completes before moving on. + // If set to false, it will fire and forget the commit and return immediately (which is the way this test is set up). + // There won't be that much difference from True to False here as the True case is so quick but there can be issues if start checking right after commit without giving time to commit. + // Also, it is a good basic check to make sure it isn't crashing and that it does actually commit it + // Can take two approaches + // 1) Just give it a few seconds to commit before checking as literally takes a second or so to commit. If commit isn't finished after this slight delay, then we have issues and it will show as a fail when reads and it needs investigated + // 2) Check right away but if it fails, check again and repeat until it is done committing. 
No need to add this extra complexity into the test and so this is appropriate use of a sleep + //******* - // flag to make sure data has been checked - bool datacheckrun = false; + log.Commit(false); + Thread.Sleep(5000); // Read the log - Look for the flag so know each entry is unique int currentEntry = 0; @@ -572,39 +631,35 @@ public void CommitNoSpinWait() { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - - Assert.IsTrue(result[currentEntry] == (byte)currentEntry, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " currentEntry:" + currentEntry); - + Assert.AreEqual((byte)currentEntry, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. "); - log.Dispose(); - } + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); + + } [Test] [Category("FasterLog")] - public async ValueTask CommitAsyncPrevTask() + [Category("Smoke")] + public async ValueTask CommitAsyncPrevTask([Values] TestUtils.DeviceType deviceType) { - CancellationTokenSource cts = new CancellationTokenSource(); CancellationToken token = cts.Token; - Task currentTask; - var logSettings = new FasterLogSettings { LogDevice = device, LogCommitManager = manager }; + string filename = $"{path}/CommitAsyncPrevTask_{deviceType}.log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + var logSettings = new FasterLogSettings { LogDevice = device, LogCommitManager = manager, SegmentSizeBits = 22 }; log = await FasterLog.CreateAsync(logSettings); - // make it small since launching each on separate threads - int entryLength = 10; + const int entryLength = 10; + int expectedEntries = 3; // Not entry length because this is number of enqueues called // Set Default entry data for 
(int i = 0; i < entryLength; i++) @@ -613,20 +668,18 @@ public async ValueTask CommitAsyncPrevTask() } // Enqueue and AsyncCommit in a separate thread (wait there until commit is done though). - currentTask = Task.Run(() => LogWriterAsync(log, entry), token); - - // Give all a second or so to queue up and to help with timing issues - shouldn't need but timing issues - Thread.Sleep(2000); + Task currentTask = Task.Run(() => LogWriterAsync(log, entry), token); // Commit to the log currentTask.Wait(4000, token); // double check to make sure finished - seen cases where timing kept running even after commit done + bool wasCanceled = false; if (currentTask.Status != TaskStatus.RanToCompletion) + { + wasCanceled = true; cts.Cancel(); - - // flag to make sure data has been checked - bool datacheckrun = false; + } // Read the log to make sure all entries are put in int currentEntry = 0; @@ -636,19 +689,14 @@ public async ValueTask CommitAsyncPrevTask() { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - - Assert.IsTrue(result[currentEntry] == (byte)currentEntry, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " not match expected:" + currentEntry); - + Assert.AreEqual((byte)currentEntry, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. "); + // Make sure expected entries is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(expectedEntries, currentEntry); // NOTE: seeing issues where task is not running to completion on Release builds // This is a final check to make sure task finished. 
If didn't then assert @@ -656,35 +704,23 @@ public async ValueTask CommitAsyncPrevTask() // case of task not stopping if (currentTask.Status != TaskStatus.RanToCompletion) { - Assert.Fail("Final Status check Failure -- Task should be 'RanToCompletion' but current Status is:" + currentTask.Status); + Assert.Fail($"Final Status check Failure -- Task should be 'RanToCompletion' but current Status is: {currentTask.Status}; wasCanceled = {wasCanceled}"); } } - static async Task LogWriterAsync(FasterLog log, byte[] entry) - { - - CancellationTokenSource cts = new CancellationTokenSource(); - CancellationToken token = cts.Token; - - - // Enter in some entries then wait on this separate thread - await log.EnqueueAsync(entry); - await log.EnqueueAsync(entry); - var commitTask = await log.CommitAsync(null,token); - await log.EnqueueAsync(entry); - await log.CommitAsync(commitTask,token); - } - [Test] [Category("FasterLog")] - public async ValueTask RefreshUncommittedAsyncTest([Values] IteratorType iteratorType) + [Category("Smoke")] + public async ValueTask RefreshUncommittedAsyncTest([Values] IteratorType iteratorType, [Values] TestUtils.DeviceType deviceType) { - CancellationTokenSource cts = new CancellationTokenSource(); CancellationToken token = cts.Token; - log = new FasterLog(new FasterLogSettings { LogDevice = device, MemorySizeBits = 20, PageSizeBits = 14, LogCommitManager = manager }); + string filename = path + "RefreshUncommittedAsyncTest" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + + log = new FasterLog(new FasterLogSettings { LogDevice = device, MemorySizeBits = 20, PageSizeBits = 14, LogCommitManager = manager, SegmentSizeBits = 22 }); byte[] data1 = new byte[1000]; for (int i = 0; i < 100; i++) data1[i] = (byte)i; @@ -696,8 +732,8 @@ public async ValueTask RefreshUncommittedAsyncTest([Values] IteratorType iterato // Actual tess is here await log.RefreshUncommittedAsync(); - Assert.IsTrue(log.SafeTailAddress 
== log.TailAddress); - Assert.IsTrue(log.CommittedUntilAddress < log.SafeTailAddress); + Assert.AreEqual(log.TailAddress, log.SafeTailAddress); + Assert.Less(log.CommittedUntilAddress, log.SafeTailAddress); using (var iter = log.Scan(0, long.MaxValue, scanUncommitted: true)) { @@ -709,7 +745,7 @@ public async ValueTask RefreshUncommittedAsyncTest([Values] IteratorType iterato case IteratorType.Sync: while (iter.GetNext(out _, out _, out _)) log.TruncateUntilPageStart(iter.NextAddress); - Assert.IsTrue(iter.NextAddress == log.SafeTailAddress); + Assert.AreEqual(log.SafeTailAddress, iter.NextAddress); break; case IteratorType.AsyncByteVector: { @@ -740,7 +776,7 @@ public async ValueTask RefreshUncommittedAsyncTest([Values] IteratorType iterato if (!IsAsync(iteratorType)) Assert.IsFalse(iter.GetNext(out _, out _, out _)); - // Actual tess is here + // Actual test is here await log.RefreshUncommittedAsync(token); await AssertGetNext(asyncByteVectorIter, asyncMemoryOwnerIter, iter, data1, verifyAtEnd: true); @@ -748,9 +784,5 @@ public async ValueTask RefreshUncommittedAsyncTest([Values] IteratorType iterato log.Dispose(); } - - - - } } diff --git a/cs/test/FunctionPerSessionTests.cs b/cs/test/FunctionPerSessionTests.cs index d89695caf..dedca489d 100644 --- a/cs/test/FunctionPerSessionTests.cs +++ b/cs/test/FunctionPerSessionTests.cs @@ -1,7 +1,6 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ using System.Threading; using System.Threading.Tasks; using FASTER.core; @@ -107,7 +106,8 @@ public class FunctionPerSessionTests [SetUp] public void Setup() { - _log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/FunctionPerSessionTests1.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait:true); + _log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/FunctionPerSessionTests1.log", deleteOnClose: true); _faster = new FasterKV<int, RefCountedValue>(128, new LogSettings() { @@ -122,9 +122,11 @@ public void Setup() [TearDown] public void TearDown() { - _faster.Dispose(); + _faster?.Dispose(); _faster = null; - _log.Dispose(); + _log?.Dispose(); + _log = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] diff --git a/cs/test/GenericByteArrayTests.cs b/cs/test/GenericByteArrayTests.cs index 454b22539..07d8c5f28 100644 --- a/cs/test/GenericByteArrayTests.cs +++ b/cs/test/GenericByteArrayTests.cs @@ -19,11 +19,11 @@ internal class GenericByteArrayTests [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericStringTests.log", deleteOnClose: true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericStringTests.obj.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait:true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/GenericStringTests.log", deleteOnClose: true); + objlog = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/GenericStringTests.obj.log", deleteOnClose: true); - fht - = new FasterKV<byte[], byte[]>( + fht = new FasterKV<byte[], byte[]>( 1L << 20, // size of hash table in #cache lines; 64 bytes per cache line new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, MemorySizeBits = 14, PageSizeBits = 9 }, // log device comparer: new ByteArrayEC() @@ -35,13 +35,17 @@ public void Setup() [TearDown] public void 
TearDown() { - session.Dispose(); - fht.Dispose(); + session?.Dispose(); + session = null; + fht?.Dispose(); fht = null; - log.Dispose(); - objlog.Dispose(); - } + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } private byte[] GetByteArray(int i) { @@ -50,6 +54,7 @@ private byte[] GetByteArray(int i) [Test] [Category("FasterKV")] + [Category("Smoke")] public void ByteArrayBasicTest() { const int totalRecords = 2000; diff --git a/cs/test/GenericDiskDeleteTests.cs b/cs/test/GenericDiskDeleteTests.cs index d2e58bc98..c6191de65 100644 --- a/cs/test/GenericDiskDeleteTests.cs +++ b/cs/test/GenericDiskDeleteTests.cs @@ -6,7 +6,6 @@ namespace FASTER.test { - [TestFixture] internal class GenericDiskDeleteTests { @@ -17,8 +16,9 @@ internal class GenericDiskDeleteTests [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericDiskDeleteTests.log", deleteOnClose: true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericDiskDeleteTests.obj.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait:true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/GenericDiskDeleteTests.log", deleteOnClose: true); + objlog = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/GenericDiskDeleteTests.obj.log", deleteOnClose: true); fht = new FasterKV<MyKey, MyValue> (128, @@ -32,16 +32,22 @@ public void Setup() [TearDown] public void TearDown() { - session.Dispose(); - fht.Dispose(); + session?.Dispose(); + session = null; + fht?.Dispose(); fht = null; - log.Dispose(); - objlog.Dispose(); - } + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } [Test] [Category("FasterKV")] + [Category("Smoke")] + public void DiskDeleteBasicTest1() { const int totalRecords = 2000; @@ -66,7 +72,7 @@ public void 
DiskDeleteBasicTest1() } else { - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(value.value, output.value.value); } } @@ -90,7 +96,7 @@ public void DiskDeleteBasicTest1() } else { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } @@ -102,7 +108,7 @@ public void DiskDeleteBasicTest1() if (recordInfo.Tombstone) val++; } - Assert.IsTrue(totalRecords == val); + Assert.AreEqual(val, totalRecords); } @@ -127,25 +133,25 @@ public void DiskDeleteBasicTest2() var input = new MyInput { value = 1000 }; var output = new MyOutput(); var status = session.Read(ref key100, ref input, ref output, 1, 0); - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); status = session.Upsert(ref key100, ref value100, 0, 0); - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); status = session.Read(ref key100, ref input, ref output, 0, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value100.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value100.value, output.value.value); session.Delete(ref key100, 0, 0); session.Delete(ref key200, 0, 0); // This RMW should create new initial value, since item is deleted status = session.RMW(ref key200, ref input, 1, 0); - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); status = session.Read(ref key200, ref input, ref output, 0, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == input.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(input.value, output.value.value); // Delete key 200 again session.Delete(ref key200, 0, 0); @@ -158,17 +164,17 @@ public void DiskDeleteBasicTest2() session.Upsert(ref _key, ref _value, 0, 0); } status = session.Read(ref key100, ref input, ref output, 1, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); // This 
RMW should create new initial value, since item is deleted status = session.RMW(ref key200, ref input, 1, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); status = session.Read(ref key200, ref input, ref output, 0, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == input.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(input.value, output.value.value); } } } diff --git a/cs/test/GenericIterationTests.cs b/cs/test/GenericIterationTests.cs index 5dabbb8db..4e29f9eec 100644 --- a/cs/test/GenericIterationTests.cs +++ b/cs/test/GenericIterationTests.cs @@ -1,19 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test { - [TestFixture] internal class GenericIterationTests { @@ -24,8 +16,9 @@ internal class GenericIterationTests [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericIterationTests.log", deleteOnClose: true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericIterationTests.obj.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/GenericIterationTests.log", deleteOnClose: true); + objlog = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/GenericIterationTests.obj.log", deleteOnClose: true); fht = new FasterKV<MyKey, MyValue> (128, @@ -39,15 +32,22 @@ public void Setup() [TearDown] public void TearDown() { - session.Dispose(); - fht.Dispose(); + session?.Dispose(); + session = null; + fht?.Dispose(); fht = null; - log.Dispose(); - objlog.Dispose(); + log?.Dispose(); + log = 
null; + objlog?.Dispose(); + objlog = null; + + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("FasterKV")] + [Category("Smoke")] + public void IterationBasicTest() { using var session = fht.For(new MyFunctionsDelete()).NewSession<MyFunctionsDelete>(); @@ -67,12 +67,11 @@ public void IterationBasicTest() while (iter.GetNext(out var recordInfo)) { count++; - Assert.IsTrue(iter.GetValue().value == iter.GetKey().key); + Assert.AreEqual(iter.GetKey().key, iter.GetValue().value); } iter.Dispose(); - Assert.IsTrue(count == totalRecords); - + Assert.AreEqual(totalRecords, count); for (int i = 0; i < totalRecords; i++) { @@ -86,11 +85,11 @@ public void IterationBasicTest() while (iter.GetNext(out var recordInfo)) { count++; - Assert.IsTrue(iter.GetValue().value == iter.GetKey().key * 2); + Assert.AreEqual(iter.GetKey().key * 2, iter.GetValue().value); } iter.Dispose(); - Assert.IsTrue(count == totalRecords); + Assert.AreEqual(totalRecords, count); for (int i = totalRecords / 2; i < totalRecords; i++) { @@ -107,7 +106,7 @@ public void IterationBasicTest() } iter.Dispose(); - Assert.IsTrue(count == totalRecords); + Assert.AreEqual(totalRecords, count); for (int i = 0; i < totalRecords; i += 2) { @@ -124,7 +123,7 @@ public void IterationBasicTest() } iter.Dispose(); - Assert.IsTrue(count == totalRecords); + Assert.AreEqual(totalRecords, count); for (int i = 0; i < totalRecords; i += 2) { @@ -141,7 +140,7 @@ public void IterationBasicTest() } iter.Dispose(); - Assert.IsTrue(count == totalRecords / 2); + Assert.AreEqual(totalRecords / 2, count); for (int i = 0; i < totalRecords; i++) { @@ -155,12 +154,11 @@ public void IterationBasicTest() while (iter.GetNext(out var recordInfo)) { count++; - Assert.IsTrue(iter.GetValue().value == iter.GetKey().key * 3); + Assert.AreEqual(iter.GetKey().key * 3, iter.GetValue().value); } iter.Dispose(); - Assert.IsTrue(count == totalRecords); - + Assert.AreEqual(totalRecords, count); } } } diff --git 
a/cs/test/GenericLogCompactionTests.cs b/cs/test/GenericLogCompactionTests.cs index 91b6ae4f0..d50083f0d 100644 --- a/cs/test/GenericLogCompactionTests.cs +++ b/cs/test/GenericLogCompactionTests.cs @@ -6,53 +6,83 @@ namespace FASTER.test { - - [TestFixture] internal class GenericLogCompactionTests { private FasterKV<MyKey, MyValue> fht; private ClientSession<MyKey, MyValue, MyInput, MyOutput, int, MyFunctionsDelete> session; private IDevice log, objlog; + private string path; [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericLogCompactionTests.log", deleteOnClose: true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericLogCompactionTests.obj.log", deleteOnClose: true); - - fht = new FasterKV<MyKey, MyValue> - (128, - logSettings: new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, MemorySizeBits = 14, PageSizeBits = 9 }, - checkpointSettings: new CheckpointSettings { CheckPointType = CheckpointType.FoldOver }, - serializerSettings: new SerializerSettings<MyKey, MyValue> { keySerializer = () => new MyKeySerializer(), valueSerializer = () => new MyValueSerializer() } - ); + path = TestUtils.MethodTestDir + "/"; + + // Clean up log files from previous test runs in case they weren't cleaned up + TestUtils.DeleteDirectory(path, wait:true); + + if (TestContext.CurrentContext.Test.Arguments.Length == 0) + { + // Default log creation + log = Devices.CreateLogDevice(path + "/GenericLogCompactionTests.log", deleteOnClose: true); + objlog = Devices.CreateLogDevice(path + "/GenericLogCompactionTests.obj.log", deleteOnClose: true); + + fht = new FasterKV<MyKey, MyValue> + (128, + logSettings: new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, MemorySizeBits = 14, PageSizeBits = 9 }, + checkpointSettings: new CheckpointSettings { CheckPointType = CheckpointType.FoldOver }, + serializerSettings: new 
SerializerSettings<MyKey, MyValue> { keySerializer = () => new MyKeySerializer(), valueSerializer = () => new MyValueSerializer() } + ); + } + else + { + // For this class, deviceType is the only parameter. Using this to illustrate the approach; NUnit doesn't provide metadata for arguments, + // so for multi-parameter tests it is probably better to stay with the "separate SetUp method" approach. + var deviceType = (TestUtils.DeviceType)TestContext.CurrentContext.Test.Arguments[0]; + + log = TestUtils.CreateTestDevice(deviceType, $"{path}LogCompactBasicTest_{deviceType}.log"); + objlog = TestUtils.CreateTestDevice(deviceType, $"{path}LogCompactBasicTest_{deviceType}.obj.log"); + + fht = new FasterKV<MyKey, MyValue> + (128, + logSettings: new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, MemorySizeBits = 14, PageSizeBits = 9, SegmentSizeBits = 22 }, + checkpointSettings: new CheckpointSettings { CheckPointType = CheckpointType.FoldOver }, + serializerSettings: new SerializerSettings<MyKey, MyValue> { keySerializer = () => new MyKeySerializer(), valueSerializer = () => new MyValueSerializer() } + ); + } session = fht.For(new MyFunctionsDelete()).NewSession<MyFunctionsDelete>(); } [TearDown] public void TearDown() { - session.Dispose(); - fht.Dispose(); + session?.Dispose(); + session = null; + fht?.Dispose(); fht = null; - log.Dispose(); - objlog.Dispose(); + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + + TestUtils.DeleteDirectory(path); } // Basic test that where shift begin address to untilAddress after compact [Test] [Category("FasterKV")] - public void LogCompactBasicTest() + [Category("Smoke")] + public void LogCompactBasicTest([Values] TestUtils.DeviceType deviceType) { MyInput input = new MyInput(); - const int totalRecords = 2000; + const int totalRecords = 500; long compactUntil = 0; for (int i = 0; i < totalRecords; i++) { - if (i == 1000) + if (i == 250) compactUntil = fht.Log.TailAddress; var key1 
= new MyKey { key = i }; @@ -61,9 +91,9 @@ public void LogCompactBasicTest() } compactUntil = session.Compact(compactUntil, true); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); + Assert.AreEqual(compactUntil, fht.Log.BeginAddress); - // Read 2000 keys - all should be present + // Read all keys - all should be present for (int i = 0; i < totalRecords; i++) { MyOutput output = new MyOutput(); @@ -73,18 +103,24 @@ public void LogCompactBasicTest() var status = session.Read(ref key1, ref input, ref output, 0, 0); if (status == Status.PENDING) - session.CompletePending(true); - else { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + session.CompletePendingWithOutputs(out var completedOutputs, wait: true); + Assert.IsTrue(completedOutputs.Next()); + Assert.AreEqual(Status.OK, completedOutputs.Current.Status); + output = completedOutputs.Current.Output; + Assert.IsFalse(completedOutputs.Next()); + completedOutputs.Dispose(); } + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } } // Basic test where DO NOT shift begin address to untilAddress after compact [Test] [Category("FasterKV")] + [Category("Compaction")] + [Category("Smoke")] public void LogCompactNotShiftBeginAddrTest() { MyInput input = new MyInput(); @@ -119,15 +155,15 @@ public void LogCompactNotShiftBeginAddrTest() session.CompletePending(true); else { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } } } - [Test] [Category("FasterKV")] + [Category("Compaction")] public void LogCompactTestNewEntries() { MyInput input = new MyInput(); @@ -157,8 +193,8 @@ public void LogCompactTestNewEntries() var tail = fht.Log.TailAddress; compactUntil = session.Compact(compactUntil, true); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); - Assert.IsTrue(fht.Log.TailAddress == tail); + 
Assert.AreEqual(compactUntil, fht.Log.BeginAddress); + Assert.AreEqual(tail, fht.Log.TailAddress); // Read 2000 keys - all should be present for (int i = 0; i < totalRecords; i++) @@ -172,14 +208,16 @@ public void LogCompactTestNewEntries() session.CompletePending(true); else { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } } } [Test] [Category("FasterKV")] + [Category("Compaction")] + [Category("Smoke")] public void LogCompactAfterDeleteTest() { MyInput input = new MyInput(); @@ -205,9 +243,9 @@ public void LogCompactAfterDeleteTest() } compactUntil = session.Compact(compactUntil, true); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); + Assert.AreEqual(compactUntil, fht.Log.BeginAddress); - // Read 2000 keys - all should be present + // Read keys - all should be present for (int i = 0; i < totalRecords; i++) { MyOutput output = new MyOutput(); @@ -223,12 +261,12 @@ public void LogCompactAfterDeleteTest() { if (ctx == 0) { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } else { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } } @@ -236,6 +274,8 @@ public void LogCompactAfterDeleteTest() [Test] [Category("FasterKV")] + [Category("Compaction")] + public void LogCompactBasicCustomFctnTest() { MyInput input = new MyInput(); @@ -254,7 +294,7 @@ public void LogCompactBasicCustomFctnTest() } compactUntil = session.Compact(compactUntil, true, default(EvenCompactionFunctions)); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); + Assert.AreEqual(compactUntil, fht.Log.BeginAddress); // Read 2000 keys - all should be present for (var i = 0; i < totalRecords; i++) @@ -274,12 +314,12 @@ public void LogCompactBasicCustomFctnTest() { if (ctx == 0) { - 
Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } else { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } } @@ -288,6 +328,8 @@ public void LogCompactBasicCustomFctnTest() // Same as basic test of Custom Functions BUT this will NOT shift begin address to untilAddress after compact [Test] [Category("FasterKV")] + [Category("Compaction")] + public void LogCompactCustomFctnNotShiftBeginTest() { MyInput input = new MyInput(); @@ -306,11 +348,11 @@ public void LogCompactCustomFctnNotShiftBeginTest() } compactUntil = session.Compact(compactUntil, false, default(EvenCompactionFunctions)); - Assert.IsFalse(fht.Log.BeginAddress == compactUntil); + Assert.AreNotEqual(compactUntil, fht.Log.BeginAddress); // Verified that begin address not changed so now compact and change Begin to untilAddress compactUntil = session.Compact(compactUntil, true, default(EvenCompactionFunctions)); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); + Assert.AreEqual(compactUntil, fht.Log.BeginAddress); // Read 2000 keys - all should be present for (var i = 0; i < totalRecords; i++) @@ -330,12 +372,12 @@ public void LogCompactCustomFctnNotShiftBeginTest() { if (ctx == 0) { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } else { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } } @@ -343,6 +385,8 @@ public void LogCompactCustomFctnNotShiftBeginTest() [Test] [Category("FasterKV")] + [Category("Compaction")] + public void LogCompactCopyInPlaceCustomFctnTest() { // Update: irrelevant as session compaction no longer uses Copy/CopyInPlace @@ -376,8 +420,8 @@ public void LogCompactCopyInPlaceCustomFctnTest() } else { - Assert.IsTrue(status == Status.OK); - 
Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } } diff --git a/cs/test/GenericLogScanTests.cs b/cs/test/GenericLogScanTests.cs index 36ecbfc14..e0c28bfd5 100644 --- a/cs/test/GenericLogScanTests.cs +++ b/cs/test/GenericLogScanTests.cs @@ -2,56 +2,56 @@ // Licensed under the MIT license. using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test { - [TestFixture] internal class GenericFASTERScanTests { private FasterKV<MyKey, MyValue> fht; private IDevice log, objlog; - const int totalRecords = 2000; + const int totalRecords = 250; + private string path; [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericFASTERScanTests.log", deleteOnClose: true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericFASTERScanTests.obj.log", deleteOnClose: true); + path = TestUtils.MethodTestDir + "/"; - fht = new FasterKV<MyKey, MyValue> - (128, - logSettings: new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, MemorySizeBits = 15, PageSizeBits = 9 }, - checkpointSettings: new CheckpointSettings { CheckPointType = CheckpointType.FoldOver }, - serializerSettings: new SerializerSettings<MyKey, MyValue> { keySerializer = () => new MyKeySerializer(), valueSerializer = () => new MyValueSerializer() } - ); + // Clean up log files from previous test runs in case they weren't cleaned up + TestUtils.DeleteDirectory(path, wait: true); } [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); - objlog.Dispose(); - } + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + TestUtils.DeleteDirectory(path); + } [Test] [Category("FasterKV")] - public void 
DiskWriteScanBasicTest() + [Category("Smoke")] + public void DiskWriteScanBasicTest([Values] TestUtils.DeviceType deviceType) { - using var session = fht.For(new MyFunctions()).NewSession<MyFunctions>(); + log = TestUtils.CreateTestDevice(deviceType, $"{path}DiskWriteScanBasicTest_{deviceType}.log"); + objlog = TestUtils.CreateTestDevice(deviceType, $"{path}DiskWriteScanBasicTest_{deviceType}.obj.log"); + fht = new (128, + logSettings: new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, MemorySizeBits = 15, PageSizeBits = 9, SegmentSizeBits = 22 }, + checkpointSettings: new CheckpointSettings { CheckPointType = CheckpointType.FoldOver }, + serializerSettings: new SerializerSettings<MyKey, MyValue> { keySerializer = () => new MyKeySerializer(), valueSerializer = () => new MyValueSerializer() } + ); - var s = fht.Log.Subscribe(new LogObserver()); + using var session = fht.For(new MyFunctions()).NewSession<MyFunctions>(); + using var s = fht.Log.Subscribe(new LogObserver()); var start = fht.Log.TailAddress; for (int i = 0; i < totalRecords; i++) @@ -59,35 +59,32 @@ public void DiskWriteScanBasicTest() var _key = new MyKey { key = i }; var _value = new MyValue { value = i }; session.Upsert(ref _key, ref _value, Empty.Default, 0); - if (i % 100 == 0) fht.Log.FlushAndEvict(true); + if (i % 100 == 0) + fht.Log.FlushAndEvict(true); } fht.Log.FlushAndEvict(true); + using (var iter = fht.Log.Scan(start, fht.Log.TailAddress, ScanBufferingMode.SinglePageBuffering)) { - - int val = 0; - while (iter.GetNext(out RecordInfo recordInfo, out MyKey key, out MyValue value)) + int val; + for (val = 0; iter.GetNext(out RecordInfo recordInfo, out MyKey key, out MyValue value); ++val) { - Assert.IsTrue(key.key == val); - Assert.IsTrue(value.value == val); - val++; + Assert.AreEqual(val, key.key, $"log scan 1: key"); + Assert.AreEqual(val, value.value, $"log scan 1: value"); } - Assert.IsTrue(totalRecords == val); + Assert.AreEqual(val, totalRecords, $"log 
scan 1: totalRecords"); } using (var iter = fht.Log.Scan(start, fht.Log.TailAddress, ScanBufferingMode.DoublePageBuffering)) { - int val = 0; - while (iter.GetNext(out RecordInfo recordInfo, out MyKey key, out MyValue value)) + int val; + for (val = 0; iter.GetNext(out RecordInfo recordInfo, out MyKey key, out MyValue value); ++val) { - Assert.IsTrue(key.key == val); - Assert.IsTrue(value.value == val); - val++; + Assert.AreEqual(val, key.key, $"log scan 2: key"); + Assert.AreEqual(val, value.value, $"log scan 2: value"); } - Assert.IsTrue(totalRecords == val); + Assert.AreEqual(val, totalRecords, $"log scan 2: totalRecords"); } - - s.Dispose(); } class LogObserver : IObserver<IFasterScanIterator<MyKey, MyValue>> @@ -96,7 +93,7 @@ class LogObserver : IObserver<IFasterScanIterator<MyKey, MyValue>> public void OnCompleted() { - Assert.IsTrue(val == totalRecords); + Assert.AreEqual(val == totalRecords, $"LogObserver.OnCompleted: totalRecords"); } public void OnError(Exception error) @@ -107,8 +104,8 @@ public void OnNext(IFasterScanIterator<MyKey, MyValue> iter) { while (iter.GetNext(out _, out MyKey key, out MyValue value)) { - Assert.IsTrue(key.key == val); - Assert.IsTrue(value.value == val); + Assert.AreEqual(val, key.key, $"LogObserver.OnNext: key"); + Assert.AreEqual(val, value.value, $"LogObserver.OnNext: value"); val++; } } diff --git a/cs/test/GenericStringTests.cs b/cs/test/GenericStringTests.cs index 48499927b..0a906cfcd 100644 --- a/cs/test/GenericStringTests.cs +++ b/cs/test/GenericStringTests.cs @@ -3,49 +3,60 @@ using FASTER.core; using NUnit.Framework; -using System; namespace FASTER.test { - [TestFixture] internal class GenericStringTests { private FasterKV<string, string> fht; private ClientSession<string, string, string, string, Empty, MyFuncs> session; private IDevice log, objlog; + private string path; [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericStringTests.log", deleteOnClose: 
true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericStringTests.obj.log", deleteOnClose: true); - - fht - = new FasterKV<string, string>( - 1L << 20, // size of hash table in #cache lines; 64 bytes per cache line - new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, MemorySizeBits = 14, PageSizeBits = 9 } // log device - ); + path = TestUtils.MethodTestDir + "/"; - session = fht.For(new MyFuncs()).NewSession<MyFuncs>(); + // Clean up log files from previous test runs in case they weren't cleaned up + TestUtils.DeleteDirectory(path, wait: true); } [TearDown] public void TearDown() { - session.Dispose(); - fht.Dispose(); + session?.Dispose(); + session = null; + fht?.Dispose(); fht = null; - log.Dispose(); - objlog.Dispose(); - } + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + TestUtils.DeleteDirectory(path); + } [Test] [Category("FasterKV")] - public void StringBasicTest() + [Category("Smoke")] + public void StringBasicTest([Values] TestUtils.DeviceType deviceType) { - const int totalRecords = 2000; + string logfilename = path + "GenericStringTests" + deviceType.ToString() + ".log"; + string objlogfilename = path + "GenericStringTests" + deviceType.ToString() + ".obj.log"; + + log = TestUtils.CreateTestDevice(deviceType, logfilename); + objlog = TestUtils.CreateTestDevice(deviceType, objlogfilename); + + fht = new FasterKV<string, string>( + 1L << 20, // size of hash table in #cache lines; 64 bytes per cache line + new LogSettings { LogDevice = log, ObjectLogDevice = objlog, MutableFraction = 0.1, MemorySizeBits = 14, PageSizeBits = 9, SegmentSizeBits = 22 } // log device + ); + + session = fht.For(new MyFuncs()).NewSession<MyFuncs>(); + + const int totalRecords = 200; for (int i = 0; i < totalRecords; i++) { var _key = $"{i}"; @@ -53,7 +64,7 @@ public void StringBasicTest() session.Upsert(ref _key, ref _value, Empty.Default, 0); } session.CompletePending(true); - 
Assert.IsTrue(fht.EntryCount == totalRecords); + Assert.AreEqual(totalRecords, fht.EntryCount); for (int i = 0; i < totalRecords; i++) { @@ -68,7 +79,7 @@ public void StringBasicTest() } else { - Assert.IsTrue(output == value); + Assert.AreEqual(value, output); } } } @@ -77,7 +88,7 @@ class MyFuncs : SimpleFunctions<string, string> { public override void ReadCompletionCallback(ref string key, ref string input, ref string output, Empty ctx, Status status) { - Assert.IsTrue(output == key); + Assert.AreEqual(key, output); } } } diff --git a/cs/test/LargeObjectTests.cs b/cs/test/LargeObjectTests.cs index 8afa69337..12ea3382a 100644 --- a/cs/test/LargeObjectTests.cs +++ b/cs/test/LargeObjectTests.cs @@ -2,18 +2,12 @@ // Licensed under the MIT license. using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; using System.IO; using NUnit.Framework; namespace FASTER.test.largeobjects { - [TestFixture] internal class LargeObjectTests { @@ -25,12 +19,8 @@ internal class LargeObjectTests [SetUp] public void Setup() { - if (test_path == null) - { - test_path = TestContext.CurrentContext.TestDirectory + "/" + Path.GetRandomFileName(); - if (!Directory.Exists(test_path)) - Directory.CreateDirectory(test_path); - } + test_path = TestUtils.MethodTestDir; + TestUtils.RecreateDirectory(test_path); } [TearDown] @@ -100,7 +90,7 @@ public void LargeObjectTest(CheckpointType checkpointType) { for (int i = 0; i < output.value.value.Length; i++) { - Assert.IsTrue(output.value.value[i] == (byte)(output.value.value.Length+i)); + Assert.AreEqual((byte)(output.value.value.Length+i), output.value.value[i]); } } } diff --git a/cs/test/LockTests.cs b/cs/test/LockTests.cs index 1c97b74eb..35d131579 100644 --- a/cs/test/LockTests.cs +++ b/cs/test/LockTests.cs @@ -46,7 +46,8 @@ public override bool Unlock(ref RecordInfo recordInfo, ref int key, ref int valu [SetUp] public void Setup() { - log 
= Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/GenericStringTests.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/GenericStringTests.log", deleteOnClose: true); fkv = new FasterKV<int, int>(1L << 20, new LogSettings { LogDevice = log, ObjectLogDevice = null }); session = fkv.For(new Functions()).NewSession<Functions>(); } @@ -54,12 +55,14 @@ public void Setup() [TearDown] public void TearDown() { - session.Dispose(); + session?.Dispose(); session = null; - fkv.Dispose(); + fkv?.Dispose(); fkv = null; - log.Dispose(); + log?.Dispose(); log = null; + + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] diff --git a/cs/test/LogReadAsyncTests.cs b/cs/test/LogReadAsyncTests.cs index a99f986df..22270ddc1 100644 --- a/cs/test/LogReadAsyncTests.cs +++ b/cs/test/LogReadAsyncTests.cs @@ -1,25 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-using System; + using System.Buffers; -using System.Collections.Generic; -using System.IO; -using System.Linq; using System.Threading; -using System.Threading.Tasks; using FASTER.core; using NUnit.Framework; - namespace FASTER.test { - [TestFixture] internal class LogReadAsyncTests { private FasterLog log; private IDevice device; - private string path = Path.GetTempPath() + "LogReadAsync/"; + private string path; public enum ParameterDefaultsIteratorType { @@ -31,33 +25,37 @@ public enum ParameterDefaultsIteratorType [SetUp] public void Setup() { + path = TestUtils.MethodTestDir + "/"; + // Clean up log files from previous test runs in case they weren't cleaned up - try { new DirectoryInfo(path).Delete(true); } - catch { } + TestUtils.DeleteDirectory(path, wait:true); - // Create devices \ log for test - device = Devices.CreateLogDevice(path + "LogReadAsync", deleteOnClose: true); - log = new FasterLog(new FasterLogSettings { LogDevice = device }); } [TearDown] public void TearDown() { - log.Dispose(); - device.Dispose(); + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; // Clean up log files - try { new DirectoryInfo(path).Delete(true); } - catch { } + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("FasterLog")] - public void LogReadAsyncBasicTest([Values] ParameterDefaultsIteratorType iteratorType) + [Category("Smoke")] + public void LogReadAsyncBasicTest([Values] ParameterDefaultsIteratorType iteratorType, [Values] TestUtils.DeviceType deviceType) { - int entryLength = 100; - int numEntries = 1000000; + int entryLength = 20; + int numEntries = 500; int entryFlag = 9999; + string filename = path + "LogReadAsync" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device,SegmentSizeBits = 22, LogCommitDir = path }); + byte[] entry = new byte[entryLength]; // Set Default entry data @@ -83,6 +81,7 @@ public void 
LogReadAsyncBasicTest([Values] ParameterDefaultsIteratorType iterato // Commit to the log log.Commit(true); + // Read one entry based on different parameters for AsyncReadOnly and verify switch (iteratorType) { @@ -93,9 +92,9 @@ public void LogReadAsyncBasicTest([Values] ParameterDefaultsIteratorType iterato var foundEntry = record.Result.Item1[1]; // 1 var foundTotal = record.Result.Item2; - Assert.IsTrue(foundFlagged == (byte)entryFlag, "Fail reading data - Found Flagged Entry:" + foundFlagged.ToString() + " Expected Flagged entry:" + entryFlag); - Assert.IsTrue(foundEntry == 1, "Fail reading data - Found Normal Entry:" + foundEntry.ToString() + " Expected Value: 1"); - Assert.IsTrue(foundTotal == 100, "Fail reading data - Found Total:" + foundTotal.ToString() + " Expected Total: 100"); + Assert.AreEqual((byte)entryFlag, foundFlagged, $"Fail reading Flagged Entry"); + Assert.AreEqual(1, foundEntry, $"Fail reading Normal Entry"); + Assert.AreEqual(entryLength, foundTotal, $"Fail reading Total"); break; case ParameterDefaultsIteratorType.LengthParam: @@ -105,9 +104,9 @@ record = log.ReadAsync(log.BeginAddress, 208); foundEntry = record.Result.Item1[1]; // 1 foundTotal = record.Result.Item2; - Assert.IsTrue(foundFlagged == (byte)entryFlag, "Fail reading data - Found Flagged Entry:" + foundFlagged.ToString() + " Expected Flagged entry:" + entryFlag); - Assert.IsTrue(foundEntry == 1, "Fail reading data - Found Normal Entry:" + foundEntry.ToString() + " Expected Value: 1"); - Assert.IsTrue(foundTotal == 100, "Fail reading data - Found Total:" + foundTotal.ToString() + " Expected Total: 100"); + Assert.AreEqual((byte)entryFlag, foundFlagged, $"Fail reading Flagged Entry"); + Assert.AreEqual(1, foundEntry, $"Fail reading Normal Entry"); + Assert.AreEqual(entryLength, foundTotal, $"Fail readingTotal"); break; case ParameterDefaultsIteratorType.TokenParam: @@ -119,15 +118,27 @@ record = log.ReadAsync(log.BeginAddress, 104, cts); foundEntry = record.Result.Item1[1]; // 1 
foundTotal = record.Result.Item2; - Assert.IsTrue(foundFlagged == (byte)entryFlag, "Fail reading data - Found Flagged Entry:" + foundFlagged.ToString() + " Expected Flagged entry:" + entryFlag); - Assert.IsTrue(foundEntry == 1, "Fail reading data - Found Normal Entry:" + foundEntry.ToString() + " Expected Value: 1"); - Assert.IsTrue(foundTotal == 100, "Fail reading data - Found Total:" + foundTotal.ToString() + " Expected Total: 100"); + Assert.AreEqual((byte)entryFlag, foundFlagged, $"Fail readingFlagged Entry"); + Assert.AreEqual(1, foundEntry, $"Fail reading Normal Entry"); + Assert.AreEqual(entryLength, foundTotal, $"Fail reading Total"); + + // Read one entry as IMemoryOwner and verify + var recordMemoryOwner = log.ReadAsync(log.BeginAddress, MemoryPool<byte>.Shared, 104, cts); + var foundFlaggedMem = recordMemoryOwner.Result.Item1.Memory.Span[0]; // 15 + var foundEntryMem = recordMemoryOwner.Result.Item1.Memory.Span[1]; // 1 + var foundTotalMem = recordMemoryOwner.Result.Item2; + + Assert.IsTrue(foundFlagged == foundFlaggedMem, $"MemoryPool-based ReadAsync result does not match that of the byte array one. value: {foundFlaggedMem} expected: {foundFlagged}"); + Assert.IsTrue(foundEntry == foundEntryMem, $"MemoryPool-based ReadAsync result does not match that of the byte array one. value: {foundEntryMem} expected: {foundEntry}"); + Assert.IsTrue(foundTotal == foundTotalMem, $"MemoryPool-based ReadAsync result does not match that of the byte array one. 
value: {foundTotalMem} expected: {foundTotal}"); break; default: Assert.Fail("Unknown case ParameterDefaultsIteratorType.DefaultParams:"); break; } + + } } diff --git a/cs/test/LowMemAsyncTests.cs b/cs/test/LowMemAsyncTests.cs index 02c67bcc8..0045f726a 100644 --- a/cs/test/LowMemAsyncTests.cs +++ b/cs/test/LowMemAsyncTests.cs @@ -8,7 +8,6 @@ namespace FASTER.test.async { - [TestFixture] public class LowMemAsyncTests { @@ -20,9 +19,9 @@ public class LowMemAsyncTests [SetUp] public void Setup() { - path = TestContext.CurrentContext.TestDirectory + $"/{TestContext.CurrentContext.Test.ClassName}/"; + path = TestUtils.MethodTestDir; + TestUtils.DeleteDirectory(path, wait: true); log = new LocalMemoryDevice(1L << 30, 1L << 25, 1, latencyMs: 20); - // log = Devices.CreateLogDevice(path + "Async.log", deleteOnClose: true); Directory.CreateDirectory(path); fht1 = new FasterKV<long, long> (1L << 10, @@ -34,9 +33,11 @@ public void Setup() [TearDown] public void TearDown() { - fht1.Dispose(); - log.Dispose(); - new DirectoryInfo(path).Delete(true); + fht1?.Dispose(); + fht1 = null; + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(path); } private static async Task Populate(ClientSession<long, long, long, long, Empty, IFunctions<long, long, long, long, Empty>> s1) @@ -66,8 +67,9 @@ private static async Task Populate(ClientSession<long, long, long, long, Empty, } [Test] - [Category("FasterKV"), Category("Stress")] - public async Task ConcurrentUpsertReadAsyncTest() + [Category("FasterKV")] + [Category("Stress")] + public async Task LowMemConcurrentUpsertReadAsyncTest() { await Task.Yield(); using var s1 = fht1.NewSession(new SimpleFunctions<long, long>((a, b) => a + b)); @@ -82,13 +84,15 @@ public async Task ConcurrentUpsertReadAsyncTest() for (long key = 0; key < numOps; key++) { var (status, output) = (await readtasks[key].ConfigureAwait(false)).Complete(); - Assert.IsTrue(status == Status.OK && output == key); + Assert.AreEqual(Status.OK, status); + 
Assert.AreEqual(key, output); } } [Test] - [Category("FasterKV"), Category("Stress")] - public async Task ConcurrentUpsertRMWReadAsyncTest() + [Category("FasterKV")] + [Category("Stress")] + public async Task LowMemConcurrentUpsertRMWReadAsyncTest() { await Task.Yield(); using var s1 = fht1.NewSession(new SimpleFunctions<long, long>((a, b) => a + b)); @@ -122,7 +126,8 @@ public async Task ConcurrentUpsertRMWReadAsyncTest() for (long key = 0; key < numOps; key++) { var (status, output) = (await readtasks[key].ConfigureAwait(false)).Complete(); - Assert.IsTrue(status == Status.OK && output == key + key); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key + key, output); } } } diff --git a/cs/test/ManagedLocalStorageTests.cs b/cs/test/ManagedLocalStorageTests.cs index 8f0b1867a..3e16be654 100644 --- a/cs/test/ManagedLocalStorageTests.cs +++ b/cs/test/ManagedLocalStorageTests.cs @@ -1,19 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-using System; -using System.Buffers; -using System.Collections.Generic; + using System.IO; -using System.Linq; using System.Threading; -using System.Threading.Tasks; using FASTER.core; using NUnit.Framework; - namespace FASTER.test { - [TestFixture] internal class ManageLocalStorageTests { @@ -22,44 +16,45 @@ internal class ManageLocalStorageTests private FasterLog logFullParams; private IDevice deviceFullParams; static readonly byte[] entry = new byte[100]; - private string commitPath; - + private string path; [SetUp] public void Setup() { - commitPath = TestContext.CurrentContext.TestDirectory + "/" + TestContext.CurrentContext.Test.Name + "/"; + path = TestUtils.MethodTestDir + "/"; // Clean up log files from previous test runs in case they weren't cleaned up // We loop to ensure clean-up as deleteOnClose does not always work for MLSD - while (Directory.Exists(commitPath)) - Directory.Delete(commitPath, true); + TestUtils.DeleteDirectory(path, wait:true); // Create devices \ log for test - device = new ManagedLocalStorageDevice(commitPath + "ManagedLocalStore.log", deleteOnClose: true); + device = new ManagedLocalStorageDevice(path + "ManagedLocalStore.log", deleteOnClose: true); log = new FasterLog(new FasterLogSettings { LogDevice = device, PageSizeBits = 12, MemorySizeBits = 14 }); - deviceFullParams = new ManagedLocalStorageDevice(commitPath + "ManagedLocalStoreFullParams.log", deleteOnClose: false, recoverDevice: true, preallocateFile: true, capacity: 1 << 30); + deviceFullParams = new ManagedLocalStorageDevice(path + "ManagedLocalStoreFullParams.log", deleteOnClose: false, recoverDevice: true, preallocateFile: true, capacity: 1 << 30); logFullParams = new FasterLog(new FasterLogSettings { LogDevice = device, PageSizeBits = 12, MemorySizeBits = 14 }); - } [TearDown] public void TearDown() { - log.Dispose(); - device.Dispose(); - logFullParams.Dispose(); - deviceFullParams.Dispose(); - - // Clean up log files - if (Directory.Exists(commitPath)) - 
Directory.Delete(commitPath, true); + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; + logFullParams?.Dispose(); + logFullParams = null; + deviceFullParams?.Dispose(); + deviceFullParams = null; + + // Clean up log + TestUtils.DeleteDirectory(path); } [Test] [Category("FasterLog")] + [Category("Smoke")] public void ManagedLocalStoreBasicTest() { int entryLength = 20; @@ -80,7 +75,6 @@ public void ManagedLocalStoreBasicTest() { while (!disposeCommitThread) { - Thread.Sleep(10); log.Commit(true); } }); @@ -118,9 +112,8 @@ public void ManagedLocalStoreBasicTest() // Final commit to the log log.Commit(true); - - // flag to make sure data has been checked - bool datacheckrun = false; + + int currentEntry = 0; Thread[] th2 = new Thread[numIterThreads]; for (int t = 0; t < numIterThreads; t++) @@ -129,21 +122,17 @@ public void ManagedLocalStoreBasicTest() new Thread(() => { // Read the log - Look for the flag so know each entry is unique - int currentEntry = 0; using (var iter = log.Scan(0, long.MaxValue)) { while (iter.GetNext(out byte[] result, out _, out _)) { - // set check flag to show got in here - datacheckrun = true; - if (numEnqueueThreads == 1) - Assert.IsTrue(result[0] == (byte)currentEntry, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString()); + Assert.AreEqual((byte)currentEntry, result[0]); currentEntry++; } } - Assert.IsTrue(currentEntry == numEntries * numEnqueueThreads); + Assert.AreEqual(numEntries * numEnqueueThreads, currentEntry); }); } @@ -152,9 +141,8 @@ public void ManagedLocalStoreBasicTest() for (int t = 0; t < numIterThreads; t++) th2[t].Join(); - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. 
"); + // Make sure number of entries is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(numEntries, currentEntry); } [Test] @@ -175,24 +163,18 @@ public void ManagedLocalStoreFullParamsTest() logFullParams.Commit(true); // Verify - Assert.IsTrue(File.Exists(commitPath + "/log-commits/commit.0.0")); - Assert.IsTrue(File.Exists(commitPath + "/ManagedLocalStore.log.0")); + Assert.IsTrue(File.Exists(path + "/log-commits/commit.0.0")); + Assert.IsTrue(File.Exists(path + "/ManagedLocalStore.log.0")); // Read the log just to verify can actually read it int currentEntry = 0; - using (var iter = logFullParams.Scan(0, 100_000_000)) + using var iter = logFullParams.Scan(0, 100_000_000); + while (iter.GetNext(out byte[] result, out _, out _)) { - while (iter.GetNext(out byte[] result, out _, out _)) - { - Assert.IsTrue(result[currentEntry] == currentEntry, "Fail - Result[" + currentEntry.ToString() + "]: is not same as " + currentEntry.ToString()); - - currentEntry++; - } + Assert.AreEqual(currentEntry, result[currentEntry]); + currentEntry++; } } - - - } } diff --git a/cs/test/MemoryLogCompactionTests.cs b/cs/test/MemoryLogCompactionTests.cs index 50b69edc9..cadac217e 100644 --- a/cs/test/MemoryLogCompactionTests.cs +++ b/cs/test/MemoryLogCompactionTests.cs @@ -9,39 +9,50 @@ namespace FASTER.test { - [TestFixture] internal class MemoryLogCompactionTests { private FasterKV<ReadOnlyMemory<int>, Memory<int>> fht; private IDevice log; + private string path; [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/MemoryLogCompactionTests1.log", deleteOnClose: true); - fht = new FasterKV<ReadOnlyMemory<int>, Memory<int>> - (1L << 20, new LogSettings { LogDevice = log, MemorySizeBits = 15, PageSizeBits = 12 }); + + path = TestUtils.MethodTestDir + "/"; + + // Clean up log files from previous test runs in case they weren't cleaned up + TestUtils.DeleteDirectory(path, wait: true); } 
[TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(path); } [Test] [Category("FasterKV")] - public void MemoryLogCompactionTest1() + [Category("Compaction")] + public void MemoryLogCompactionTest1([Values] TestUtils.DeviceType deviceType) { + + string filename = path + "MemoryLogCompactionTests1" + deviceType.ToString() + ".log"; + log = TestUtils.CreateTestDevice(deviceType, filename); + fht = new FasterKV<ReadOnlyMemory<int>, Memory<int>> + (1L << 20, new LogSettings { LogDevice = log, MemorySizeBits = 12, PageSizeBits = 10, SegmentSizeBits = 22 }); + using var session = fht.For(new MemoryCompaction()).NewSession<MemoryCompaction>(); var key = new Memory<int>(new int[20]); var value = new Memory<int>(new int[20]); - const int totalRecords = 2000; + const int totalRecords = 200; var start = fht.Log.TailAddress; for (int i = 0; i < totalRecords; i++) @@ -49,38 +60,38 @@ public void MemoryLogCompactionTest1() key.Span.Fill(i); value.Span.Fill(i); session.Upsert(key, value); - if (i < 50) + if (i < 5) session.Delete(key); // in-place delete } - for (int i = 50; i < 100; i++) + for (int i = 5; i < 10; i++) { key.Span.Fill(i); value.Span.Fill(i); session.Delete(key); // tombstone inserted } - // Compact 20% of log: + // Compact log var compactUntil = fht.Log.BeginAddress + (fht.Log.TailAddress - fht.Log.BeginAddress) / 5; compactUntil = session.Compact(compactUntil, true); - Assert.IsTrue(fht.Log.BeginAddress == compactUntil); + Assert.AreEqual(compactUntil, fht.Log.BeginAddress); - // Read 2000 keys - all but first 100 (deleted) should be present + // Read total keys - all but first 5 (deleted) should be present for (int i = 0; i < totalRecords; i++) { key.Span.Fill(i); - var (status, output) = session.Read(key, userContext: i < 100 ? 1 : 0); + var (status, output) = session.Read(key, userContext: i < 10 ? 
1 : 0); if (status == Status.PENDING) session.CompletePending(true); else { - if (i < 100) - Assert.IsTrue(status == Status.NOTFOUND); + if (i < 10) + Assert.AreEqual(Status.NOTFOUND, status); else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); Assert.IsTrue(output.Item1.Memory.Span.Slice(0, output.Item2).SequenceEqual(key.Span)); output.Item1.Dispose(); } @@ -94,10 +105,10 @@ public void MemoryLogCompactionTest1() while (iter.GetNext(out RecordInfo recordInfo)) { var k = iter.GetKey(); - Assert.IsTrue(k.Span[0] >= 100); + Assert.GreaterOrEqual(k.Span[0], 10); count++; } - Assert.IsTrue(count == 1900); + Assert.AreEqual(190, count); } // Test iteration of all log records @@ -107,11 +118,11 @@ public void MemoryLogCompactionTest1() while (iter.GetNext(out RecordInfo recordInfo)) { var k = iter.GetKey(); - Assert.IsTrue(k.Span[0] >= 50); + Assert.GreaterOrEqual(k.Span[0], 5); count++; } - // Includes 1900 live records + 50 deleted records - Assert.IsTrue(count == 1950); + // Includes 190 live records + 5 deleted records + Assert.AreEqual(195, count); } } } @@ -120,7 +131,7 @@ public class MemoryCompaction : MemoryFunctions<ReadOnlyMemory<int>, int, int> { public override void RMWCompletionCallback(ref ReadOnlyMemory<int> key, ref Memory<int> input, ref (IMemoryOwner<int>, int) output, int ctx, Status status) { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } public override void ReadCompletionCallback(ref ReadOnlyMemory<int> key, ref Memory<int> input, ref (IMemoryOwner<int>, int) output, int ctx, Status status) @@ -129,12 +140,12 @@ public override void ReadCompletionCallback(ref ReadOnlyMemory<int> key, ref Mem { if (ctx == 0) { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); Assert.IsTrue(output.Item1.Memory.Span.Slice(0, output.Item2).SequenceEqual(key.Span)); } else { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } finally diff 
--git a/cs/test/MiscFASTERTests.cs b/cs/test/MiscFASTERTests.cs index 86797a9b8..b9eb27c38 100644 --- a/cs/test/MiscFASTERTests.cs +++ b/cs/test/MiscFASTERTests.cs @@ -1,19 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test { - [TestFixture] internal class MiscFASTERTests { @@ -23,8 +15,9 @@ internal class MiscFASTERTests [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/MiscFASTERTests.log", deleteOnClose: true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/MiscFASTERTests.obj.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/MiscFASTERTests.log", deleteOnClose: true); + objlog = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/MiscFASTERTests.obj.log", deleteOnClose: true); fht = new FasterKV<int, MyValue> (128, @@ -37,15 +30,18 @@ public void Setup() [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); - objlog.Dispose(); + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } - [Test] [Category("FasterKV")] + [Category("Smoke")] public void MixedTest1() { using var session = fht.For(new MixedFunctions()).NewSession<MixedFunctions>(); @@ -61,10 +57,10 @@ public void MixedTest1() session.RMW(ref key2, ref input2, Empty.Default, 0); session.Read(ref key, ref input1, ref output, Empty.Default, 0); - Assert.IsTrue(output.value.value == input1.value); + Assert.AreEqual(input1.value, output.value.value); session.Read(ref key2, ref input2, ref output, Empty.Default, 0); - 
Assert.IsTrue(output.value.value == input2.value); + Assert.AreEqual(input2.value, output.value.value); } [Test] @@ -90,10 +86,10 @@ public void MixedTest2() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(g1.value.value == 23); + Assert.AreEqual(23, g1.value.value); key2 = 99999; status = session.Read(ref key2, ref input, ref g1, Empty.Default, 0); @@ -104,7 +100,7 @@ public void MixedTest2() } else { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } @@ -118,7 +114,7 @@ public void ShouldCreateNewRecordIfConcurrentWriterReturnsFalse() var log = default(IDevice); try { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/hlog1.log", deleteOnClose: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/hlog1.log", deleteOnClose: true); using var fht = new FasterKV<KeyStruct, ValueStruct> (128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); using var session = fht.NewSession(copyOnWrite); diff --git a/cs/test/NativeReadCacheTests.cs b/cs/test/NativeReadCacheTests.cs index 884d8b73a..58473aa3a 100644 --- a/cs/test/NativeReadCacheTests.cs +++ b/cs/test/NativeReadCacheTests.cs @@ -1,19 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test { - [TestFixture] internal class NativeReadCacheTests { @@ -23,8 +15,9 @@ internal class NativeReadCacheTests [SetUp] public void Setup() { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); var readCacheSettings = new ReadCacheSettings { MemorySizeBits = 15, PageSizeBits = 10 }; - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/NativeReadCacheTests.log", deleteOnClose: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/NativeReadCacheTests.log", deleteOnClose: true); fht = new FasterKV<KeyStruct, ValueStruct> (1L<<20, new LogSettings { LogDevice = log, MemorySizeBits = 15, PageSizeBits = 10, ReadCacheSettings = readCacheSettings }); } @@ -32,13 +25,16 @@ public void Setup() [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("FasterKV")] + [Category("Smoke")] public void NativeDiskWriteReadCache() { using var session = fht.NewSession(new Functions()); @@ -64,7 +60,7 @@ public void NativeDiskWriteReadCache() var value = new ValueStruct { vfield1 = i, vfield2 = i + 1 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); } @@ -76,9 +72,9 @@ public void NativeDiskWriteReadCache() var value = new ValueStruct { vfield1 = i, vfield2 = i + 1 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, 
status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } // Evict the read cache entirely @@ -92,7 +88,7 @@ public void NativeDiskWriteReadCache() var value = new ValueStruct { vfield1 = i, vfield2 = i + 1 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); } @@ -104,9 +100,9 @@ public void NativeDiskWriteReadCache() var value = new ValueStruct { vfield1 = i, vfield2 = i + 1 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } // Upsert to overwrite the read cache @@ -130,8 +126,8 @@ public void NativeDiskWriteReadCache() } else { - Assert.IsTrue(output.value.vfield1 == i + 1); - Assert.IsTrue(output.value.vfield2 == i + 2); + Assert.AreEqual(i + 1, output.value.vfield1); + Assert.AreEqual(i + 2, output.value.vfield2); } } @@ -143,9 +139,9 @@ public void NativeDiskWriteReadCache() var value = new ValueStruct { vfield1 = i + 1, vfield2 = i + 2 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } } @@ -176,7 +172,7 @@ public void NativeDiskWriteReadCache2() var value = new ValueStruct { vfield1 = i, vfield2 = i + 1 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - 
Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); } @@ -188,9 +184,9 @@ public void NativeDiskWriteReadCache2() var value = new ValueStruct { vfield1 = i, vfield2 = i + 1 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } // Evict the read cache entirely @@ -204,7 +200,7 @@ public void NativeDiskWriteReadCache2() var value = new ValueStruct { vfield1 = i, vfield2 = i + 1 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); } @@ -216,9 +212,9 @@ public void NativeDiskWriteReadCache2() var value = new ValueStruct { vfield1 = i, vfield2 = i + 1 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } } } diff --git a/cs/test/NeedCopyUpdateTests.cs b/cs/test/NeedCopyUpdateTests.cs index 8ca61eff8..cd030c0c6 100644 --- a/cs/test/NeedCopyUpdateTests.cs +++ b/cs/test/NeedCopyUpdateTests.cs @@ -1,19 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test { - [TestFixture] internal class NeedCopyUpdateTests { @@ -23,8 +15,9 @@ internal class NeedCopyUpdateTests [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/NeedCopyUpdateTests.log", deleteOnClose: true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/NeedCopyUpdateTests.obj.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/NeedCopyUpdateTests.log", deleteOnClose: true); + objlog = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/NeedCopyUpdateTests.obj.log", deleteOnClose: true); fht = new FasterKV<int, RMWValue> (128, @@ -37,15 +30,18 @@ public void Setup() [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); - objlog.Dispose(); + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } - [Test] [Category("FasterKV")] + [Category("Smoke")] public void TryAddTest() { using var session = fht.For(new TryAddTestFunctions()).NewSession<TryAddTestFunctions>(); @@ -56,33 +52,33 @@ public void TryAddTest() var value2 = new RMWValue { value = 2 }; status = session.RMW(ref key, ref value1); // InitialUpdater + NOTFOUND - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); Assert.IsTrue(value1.flag); // InitialUpdater is called status = session.RMW(ref key, ref value2); // InPlaceUpdater + OK - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); fht.Log.Flush(true); status = session.RMW(ref key, ref value2); // NeedCopyUpdate + OK - Assert.IsTrue(status == Status.OK); + 
Assert.AreEqual(Status.OK, status); fht.Log.FlushAndEvict(true); status = session.RMW(ref key, ref value2, Status.OK, 0); // PENDING + NeedCopyUpdate + OK - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); // Test stored value. Should be value1 var output = new RMWValue(); status = session.Read(ref key, ref value1, ref output, Status.OK, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); status = session.Delete(ref key); - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); session.CompletePending(true); fht.Log.FlushAndEvict(true); status = session.RMW(ref key, ref value2, Status.NOTFOUND, 0); // PENDING + InitialUpdater + NOTFOUND - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); } } @@ -124,7 +120,7 @@ public override void CopyUpdater(ref int key, ref RMWValue input, ref RMWValue o public override void RMWCompletionCallback(ref int key, ref RMWValue input, ref RMWValue output, Status ctx, Status status) { - Assert.IsTrue(status == ctx); + Assert.AreEqual(ctx, status); if (status == Status.NOTFOUND) Assert.IsTrue(input.flag); // InitialUpdater is called. @@ -132,7 +128,7 @@ public override void RMWCompletionCallback(ref int key, ref RMWValue input, ref public override void ReadCompletionCallback(ref int key, ref RMWValue input, ref RMWValue output, Status ctx, Status status) { - Assert.IsTrue(input.value == output.value); + Assert.AreEqual(output.value, input.value); } } } \ No newline at end of file diff --git a/cs/test/ObjectFASTERTests.cs b/cs/test/ObjectFASTERTests.cs index d1d49ae25..1d57d5461 100644 --- a/cs/test/ObjectFASTERTests.cs +++ b/cs/test/ObjectFASTERTests.cs @@ -1,19 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-using System; -using System.Text; -using System.Threading; using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test { - [TestFixture] internal class ObjectFASTERTests { @@ -23,8 +16,9 @@ internal class ObjectFASTERTests [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/ObjectFASTERTests.log", deleteOnClose: true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/ObjectFASTERTests.obj.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/ObjectFASTERTests.log", deleteOnClose: true); + objlog = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/ObjectFASTERTests.obj.log", deleteOnClose: true); fht = new FasterKV<MyKey, MyValue> (128, @@ -37,13 +31,18 @@ public void Setup() [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("FasterKV")] + [Category("Smoke")] public void ObjectInMemWriteRead() { using var session = fht.NewSession(new MyFunctions()); @@ -56,7 +55,7 @@ public void ObjectInMemWriteRead() session.Upsert(ref key1, ref value, Empty.Default, 0); session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(value.value, output.value.value); } [Test] @@ -77,16 +76,17 @@ public void ObjectInMemWriteRead2() session.Read(ref key1, ref input1, ref output, Empty.Default, 0); - Assert.IsTrue(output.value.value == input1.value); + Assert.AreEqual(input1.value, output.value.value); session.Read(ref key2, ref input2, ref output, Empty.Default, 0); - Assert.IsTrue(output.value.value == input2.value); + 
Assert.AreEqual(input2.value, output.value.value); } [Test] [Category("FasterKV")] + [Category("Smoke")] public void ObjectDiskWriteRead() { using var session = fht.NewSession(new MyFunctions()); @@ -110,10 +110,10 @@ public void ObjectDiskWriteRead() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(g1.value.value == 23); + Assert.AreEqual(23, g1.value.value); key2 = new MyKey { key = 99999 }; status = session.Read(ref key2, ref input, ref g1, Empty.Default, 0); @@ -124,7 +124,7 @@ public void ObjectDiskWriteRead() } else { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } // Update first 100 using RMW from storage @@ -151,13 +151,13 @@ public void ObjectDiskWriteRead() { if (i < 100) { - Assert.IsTrue(output.value.value == value.value + 1); - Assert.IsTrue(output.value.value == value.value + 1); + Assert.AreEqual(value.value + 1, output.value.value); + Assert.AreEqual(value.value + 1, output.value.value); } else { - Assert.IsTrue(output.value.value == value.value); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(value.value, output.value.value); + Assert.AreEqual(value.value, output.value.value); } } } @@ -184,21 +184,21 @@ public async Task ReadAsyncObjectDiskWriteRead() var input = new MyInput(); var readResult = await session.ReadAsync(ref key1, ref input, Empty.Default); var result = readResult.Complete(); - Assert.IsTrue(result.Item1 == Status.OK); - Assert.IsTrue(result.Item2.value.value == 1989); + Assert.AreEqual(Status.OK, result.status); + Assert.AreEqual(1989, result.output.value.value); var key2 = new MyKey { key = 23 }; readResult = await session.ReadAsync(ref key2, ref input, Empty.Default); result = readResult.Complete(); - Assert.IsTrue(result.Item1 == Status.OK); - Assert.IsTrue(result.Item2.value.value == 23); + Assert.AreEqual(Status.OK, result.status); + Assert.AreEqual(23, result.output.value.value); var key3 = new MyKey { key = 
9999 }; readResult = await session.ReadAsync(ref key3, ref input, Empty.Default); result = readResult.Complete(); - Assert.IsTrue(result.Item1 == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, result.status); // Update last 100 using RMW in memory for (int i = 1900; i < 2000; i++) @@ -228,11 +228,11 @@ public async Task ReadAsyncObjectDiskWriteRead() readResult = await session.ReadAsync(ref key, ref input, Empty.Default); result = readResult.Complete(); - Assert.IsTrue(result.Item1 == Status.OK); + Assert.AreEqual(Status.OK, result.status); if (i < 100 || i >= 1900) - Assert.IsTrue(result.Item2.value.value == value.value + 1); + Assert.AreEqual(value.value + 1, result.output.value.value); else - Assert.IsTrue(result.Item2.value.value == value.value); + Assert.AreEqual(value.value, result.output.value.value); } } } diff --git a/cs/test/ObjectReadCacheTests.cs b/cs/test/ObjectReadCacheTests.cs index 75fff70ee..780aea558 100644 --- a/cs/test/ObjectReadCacheTests.cs +++ b/cs/test/ObjectReadCacheTests.cs @@ -6,7 +6,6 @@ namespace FASTER.test { - [TestFixture] internal class ObjectReadCacheTests { @@ -16,9 +15,10 @@ internal class ObjectReadCacheTests [SetUp] public void Setup() { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); var readCacheSettings = new ReadCacheSettings { MemorySizeBits = 15, PageSizeBits = 10 }; - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/ObjectReadCacheTests.log", deleteOnClose: true); - objlog = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/ObjectReadCacheTests.obj.log", deleteOnClose: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/ObjectReadCacheTests.log", deleteOnClose: true); + objlog = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/ObjectReadCacheTests.obj.log", deleteOnClose: true); fht = new FasterKV<MyKey, MyValue> (128, @@ -31,14 +31,18 @@ public void Setup() [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = 
null; - log.Dispose(); - objlog.Dispose(); + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("FasterKV")] + [Category("Smoke")] public void ObjectDiskWriteReadCache() { using var session = fht.NewSession(new MyFunctions()); @@ -64,7 +68,7 @@ public void ObjectDiskWriteReadCache() var value = new MyValue { value = i }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); } @@ -76,8 +80,8 @@ public void ObjectDiskWriteReadCache() var value = new MyValue { value = i }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } // Evict the read cache entirely @@ -91,7 +95,7 @@ public void ObjectDiskWriteReadCache() var value = new MyValue { value = i }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); } @@ -103,8 +107,8 @@ public void ObjectDiskWriteReadCache() var value = new MyValue { value = i }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } @@ -134,8 +138,8 @@ public void ObjectDiskWriteReadCache() var value = new MyValue { value = i + 1 }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, 
output.value.value); } } @@ -166,7 +170,7 @@ public void ObjectDiskWriteReadCache2() var value = new MyValue { value = i }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); } @@ -178,8 +182,8 @@ public void ObjectDiskWriteReadCache2() var value = new MyValue { value = i }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } // Evict the read cache entirely @@ -193,7 +197,7 @@ public void ObjectDiskWriteReadCache2() var value = new MyValue { value = i }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.PENDING); + Assert.AreEqual(Status.PENDING, status); session.CompletePending(true); } @@ -205,8 +209,8 @@ public void ObjectDiskWriteReadCache2() var value = new MyValue { value = i }; var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.value == value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(value.value, output.value.value); } } } diff --git a/cs/test/ObjectRecoveryTest.cs b/cs/test/ObjectRecoveryTest.cs index c7e37b133..28c3384dc 100644 --- a/cs/test/ObjectRecoveryTest.cs +++ b/cs/test/ObjectRecoveryTest.cs @@ -29,14 +29,13 @@ internal class ObjectRecoveryTests private IDevice log, objlog; [SetUp] - public void Setup() + public void Setup() => Setup(deleteDir: true); + + public void Setup(bool deleteDir) { - if (test_path == null) - { - test_path = TestContext.CurrentContext.TestDirectory + "/" + Path.GetRandomFileName(); - if (!Directory.Exists(test_path)) - Directory.CreateDirectory(test_path); - } + test_path = TestUtils.MethodTestDir; + if 
(deleteDir) + TestUtils.RecreateDirectory(test_path); log = Devices.CreateLogDevice(test_path + "/ObjectRecoveryTests.log", false); objlog = Devices.CreateLogDevice(test_path + "/ObjectRecoveryTests.obj.log", false); @@ -51,25 +50,33 @@ public void Setup() } [TearDown] - public void TearDown() + public void TearDown() => TearDown(deleteDir: true); + + public void TearDown(bool deleteDir) { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); - objlog.Dispose(); - TestUtils.DeleteDirectory(test_path); + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + + if (deleteDir) + TestUtils.DeleteDirectory(test_path); + } + + private void PrepareToRecover() + { + TearDown(deleteDir: false); + Setup(deleteDir: false); } [Test] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public async ValueTask ObjectRecoveryTest1([Values]bool isAsync) { Populate(); - fht.Dispose(); - fht = null; - log.Dispose(); - objlog.Dispose(); - Setup(); + PrepareToRecover(); if (isAsync) await fht.RecoverAsync(token, token); @@ -195,12 +202,7 @@ public unsafe void Verify(Guid cprVersion, Guid indexVersion) // Assert if expected is same as found for (long i = 0; i < numUniqueKeys; i++) { - Assert.IsTrue( - expected[i] == outputArray[i].value.numClicks, - "Debug error for AdId {0}: Expected ({1}), Found({2})", - inputArray[i].Item1.adId, - expected[i], - outputArray[i].value.numClicks); + Assert.AreEqual(expected[i], outputArray[i].value.numClicks, $"AdId {inputArray[i].Item1.adId}"); } } } diff --git a/cs/test/ObjectRecoveryTest2.cs b/cs/test/ObjectRecoveryTest2.cs index 3192d3508..12b8aa288 100644 --- a/cs/test/ObjectRecoveryTest2.cs +++ b/cs/test/ObjectRecoveryTest2.cs @@ -18,9 +18,8 @@ public class ObjectRecoveryTests2 [SetUp] public void Setup() { - FasterFolderPath = TestContext.CurrentContext.TestDirectory + "/" + Path.GetRandomFileName(); - if (!Directory.Exists(FasterFolderPath)) - Directory.CreateDirectory(FasterFolderPath); + 
FasterFolderPath = TestUtils.MethodTestDir; + TestUtils.RecreateDirectory(FasterFolderPath); } [TearDown] @@ -31,6 +30,9 @@ public void TearDown() [Test] [Category("FasterKV")] + [Category("CheckpointRestore")] + [Category("Smoke")] + public async ValueTask ObjectRecoveryTest2( [Values]CheckpointType checkpointType, [Range(100, 700, 300)] int iterations, @@ -129,8 +131,8 @@ private void Read(ClientSession<MyKey, MyValue, MyInput, MyOutput, MyContext, My context.FinalizeRead(ref status, ref g1); } - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(g1.value.value == i.ToString()); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(i.ToString(), g1.value.value); } if (delete) @@ -147,7 +149,7 @@ private void Read(ClientSession<MyKey, MyValue, MyInput, MyOutput, MyContext, My context.FinalizeRead(ref status, ref output); } - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } } diff --git a/cs/test/ObjectRecoveryTest3.cs b/cs/test/ObjectRecoveryTest3.cs index 256feb68f..b672421cb 100644 --- a/cs/test/ObjectRecoveryTest3.cs +++ b/cs/test/ObjectRecoveryTest3.cs @@ -11,7 +11,6 @@ namespace FASTER.test.recovery.objects { - [TestFixture] public class ObjectRecoveryTests3 { @@ -21,9 +20,8 @@ public class ObjectRecoveryTests3 [SetUp] public void Setup() { - FasterFolderPath = TestContext.CurrentContext.TestDirectory + "/" + Path.GetRandomFileName(); - if (!Directory.Exists(FasterFolderPath)) - Directory.CreateDirectory(FasterFolderPath); + FasterFolderPath = TestUtils.MethodTestDir; + TestUtils.RecreateDirectory(FasterFolderPath); } [TearDown] @@ -33,7 +31,7 @@ public void TearDown() } [Test] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public async ValueTask ObjectRecoveryTest3( [Values]CheckpointType checkpointType, [Values(1000)] int iterations, @@ -138,8 +136,8 @@ private void Read(ClientSession<MyKey, MyValue, MyInput, MyOutput, MyContext, My context.FinalizeRead(ref status, ref g1); } - 
Assert.IsTrue(status == Status.OK); - Assert.IsTrue(g1.value.value == i.ToString()); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(i.ToString(), g1.value.value); } if (delete) @@ -156,7 +154,7 @@ private void Read(ClientSession<MyKey, MyValue, MyInput, MyOutput, MyContext, My context.FinalizeRead(ref status, ref output); } - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } } diff --git a/cs/test/ObjectRecoveryTestTypes.cs b/cs/test/ObjectRecoveryTestTypes.cs index 9954d08a6..5be2db8c8 100644 --- a/cs/test/ObjectRecoveryTestTypes.cs +++ b/cs/test/ObjectRecoveryTestTypes.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#pragma warning disable 1591 - using System.Threading; using FASTER.core; @@ -36,7 +34,6 @@ public override void Serialize(ref AdId obj) } } - public class Input { public AdId adId; diff --git a/cs/test/ObjectTestTypes.cs b/cs/test/ObjectTestTypes.cs index 223a98ba1..fd883443c 100644 --- a/cs/test/ObjectTestTypes.cs +++ b/cs/test/ObjectTestTypes.cs @@ -1,18 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-#pragma warning disable 1591 - -using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.Runtime.CompilerServices; -using System.IO; -using System.Diagnostics; using NUnit.Framework; namespace FASTER.test @@ -21,58 +10,50 @@ public class MyKey : IFasterEqualityComparer<MyKey> { public int key; - public long GetHashCode64(ref MyKey key) - { - return Utility.GetHashCode(key.key); - } + public long GetHashCode64(ref MyKey key) => Utility.GetHashCode(key.key); - public bool Equals(ref MyKey k1, ref MyKey k2) - { - return k1.key == k2.key; - } + public bool Equals(ref MyKey k1, ref MyKey k2) => k1.key == k2.key; + + public override string ToString() => this.key.ToString(); } public class MyKeySerializer : BinaryObjectSerializer<MyKey> { - public override void Deserialize(out MyKey obj) - { - obj = new MyKey(); - obj.key = reader.ReadInt32(); - } + public override void Deserialize(out MyKey obj) => obj = new MyKey { key = reader.ReadInt32() }; - public override void Serialize(ref MyKey obj) - { - writer.Write(obj.key); - } + public override void Serialize(ref MyKey obj) => writer.Write(obj.key); } - public class MyValue + public class MyValue : IFasterEqualityComparer<MyValue> { public int value; + + public long GetHashCode64(ref MyValue k) => Utility.GetHashCode(k.value); + + public bool Equals(ref MyValue k1, ref MyValue k2) => k1.value == k2.value; + + public override string ToString() => this.value.ToString(); } public class MyValueSerializer : BinaryObjectSerializer<MyValue> { - public override void Deserialize(out MyValue obj) - { - obj = new MyValue(); - obj.value = reader.ReadInt32(); - } + public override void Deserialize(out MyValue obj) => obj = new MyValue { value = reader.ReadInt32() }; - public override void Serialize(ref MyValue obj) - { - writer.Write(obj.value); - } + public override void Serialize(ref MyValue obj) => 
writer.Write(obj.value); } public class MyInput { public int value; + + public override string ToString() => this.value.ToString(); } public class MyOutput { public MyValue value; + + public override string ToString() => this.value.ToString(); } public class MyFunctions : FunctionsBase<MyKey, MyValue, MyInput, MyOutput, Empty> @@ -111,13 +92,13 @@ public override bool ConcurrentWriter(ref MyKey key, ref MyValue src, ref MyValu public override void ReadCompletionCallback(ref MyKey key, ref MyInput input, ref MyOutput output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(key.key == output.value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(output.value.value, key.key); } public override void RMWCompletionCallback(ref MyKey key, ref MyInput input, ref MyOutput output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } public override void SingleReader(ref MyKey key, ref MyInput input, ref MyValue value, ref MyOutput dst) @@ -133,6 +114,64 @@ public override void SingleWriter(ref MyKey key, ref MyValue src, ref MyValue ds } } + public class MyFunctions2 : FunctionsBase<MyValue, MyValue, MyInput, MyOutput, Empty> + { + public override void InitialUpdater(ref MyValue key, ref MyInput input, ref MyValue value, ref MyOutput output) + { + value = new MyValue { value = input.value }; + } + + public override bool InPlaceUpdater(ref MyValue key, ref MyInput input, ref MyValue value, ref MyOutput output) + { + value.value += input.value; + return true; + } + + public override bool NeedCopyUpdate(ref MyValue key, ref MyInput input, ref MyValue oldValue, ref MyOutput output) => true; + + public override void CopyUpdater(ref MyValue key, ref MyInput input, ref MyValue oldValue, ref MyValue newValue, ref MyOutput output) + { + newValue = new MyValue { value = oldValue.value + input.value }; + } + + public override void ConcurrentReader(ref MyValue key, ref MyInput input, ref 
MyValue value, ref MyOutput dst) + { + if (dst == default) + dst = new MyOutput(); + + dst.value = value; + } + + public override bool ConcurrentWriter(ref MyValue key, ref MyValue src, ref MyValue dst) + { + dst.value = src.value; + return true; + } + + public override void ReadCompletionCallback(ref MyValue key, ref MyInput input, ref MyOutput output, Empty ctx, Status status) + { + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key.value, output.value.value); + } + + public override void RMWCompletionCallback(ref MyValue key, ref MyInput input, ref MyOutput output, Empty ctx, Status status) + { + Assert.AreEqual(Status.OK, status); + } + + public override void SingleReader(ref MyValue key, ref MyInput input, ref MyValue value, ref MyOutput dst) + { + if (dst == default) + dst = new MyOutput(); + dst.value = value; + } + + public override void SingleWriter(ref MyValue key, ref MyValue src, ref MyValue dst) + { + dst = src; + } + } + public class MyFunctionsDelete : FunctionsBase<MyKey, MyValue, MyInput, MyOutput, int> { public override void InitialUpdater(ref MyKey key, ref MyInput input, ref MyValue value, ref MyOutput output) @@ -170,21 +209,21 @@ public override void ReadCompletionCallback(ref MyKey key, ref MyInput input, re { if (ctx == 0) { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(key.key == output.value.value); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key.key, output.value.value); } else if (ctx == 1) { - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } public override void RMWCompletionCallback(ref MyKey key, ref MyInput input, ref MyOutput output, int ctx, Status status) { if (ctx == 0) - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); else if (ctx == 1) - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } public override void SingleReader(ref MyKey key, ref MyInput input, ref MyValue value, ref MyOutput dst) @@ 
-287,10 +326,10 @@ public class MyLargeFunctions : FunctionsBase<MyKey, MyLargeValue, MyInput, MyLa { public override void ReadCompletionCallback(ref MyKey key, ref MyInput input, ref MyLargeOutput output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); for (int i = 0; i < output.value.value.Length; i++) { - Assert.IsTrue(output.value.value[i] == (byte)(output.value.value.Length + i)); + Assert.AreEqual((byte)(output.value.value.Length + i), output.value.value[i]); } } diff --git a/cs/test/Properties/AssemblyInfo.cs b/cs/test/Properties/AssemblyInfo.cs index 5fe1d74d0..d586faa81 100644 --- a/cs/test/Properties/AssemblyInfo.cs +++ b/cs/test/Properties/AssemblyInfo.cs @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +using NUnit.Framework; using System.Reflection; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices; // General Information about an assembly is controlled through the following @@ -37,3 +37,10 @@ // [assembly: AssemblyVersion("1.0.*")] //[assembly: AssemblyVersion("1.0.0.0")] //[assembly: AssemblyFileVersion("1.0.0.0")] + +// Make all fixtures in the test assembly run in parallel +#if false // disable parallelism until all the problems are resolved +//#if NETCOREAPP || NET // net461 runs x86 by default so OOMs on current memory usage by tests when running multiple tests simultaneously +[assembly: Parallelizable(ParallelScope.Fixtures)] + //[assembly: LevelOfParallelism(4)] // For reduced parallelization of net461 if we reduce memory usage in tests +#endif diff --git a/cs/test/ReadAddressTests.cs b/cs/test/ReadAddressTests.cs index 30079549d..8cc222b56 100644 --- a/cs/test/ReadAddressTests.cs +++ b/cs/test/ReadAddressTests.cs @@ -12,7 +12,7 @@ namespace FASTER.test.readaddress { -#if false // TODO temporarily deactivated due to removal of addresses from single-writer callbacks (also add UpsertAsync where we do 
RMWAsync/Upsert) +#if false // TODO temporarily deactivated due to removal of addresses from single-writer callbacks (also add UpsertAsync where we do RMWAsync/Upsert); update to new test format [TestFixture] public class ReadAddressTests { @@ -133,7 +133,8 @@ private class TestStore : IDisposable internal TestStore(bool useReadCache, CopyReadsToTail copyReadsToTail, bool flush) { - this.testDir = $"{TestContext.CurrentContext.TestDirectory}/{TestContext.CurrentContext.Test.Name}"; + this.testDir = TestUtils.MethodTestDir; + TestUtils.DeleteDirectory(this.testDir, wait:true); this.logDevice = Devices.CreateLogDevice($"{testDir}/hlog.log"); this.flush = flush; @@ -237,12 +238,11 @@ internal void ProcessNoKeyRecord(Status status, ref Value actualOutput, int keyO public void Dispose() { - if (!(this.fkv is null)) - this.fkv.Dispose(); - if (!(this.logDevice is null)) - this.logDevice.Dispose(); - if (!string.IsNullOrEmpty(this.testDir)) - new DirectoryInfo(this.testDir).Delete(true); + this.fkv?.Dispose(); + this.fkv = null; + this.logDevice?.Dispose(); + this.logDevice = null; + TestUtils.DeleteDirectory(this.testDir); } } @@ -576,4 +576,97 @@ public async Task ReadNoKeyAsyncTests(bool useReadCache, CopyReadsToTail copyRea } } #endif + [TestFixture] + public class ReadMinAddressTests + { + const int numOps = 500; + + private IDevice log; + private FasterKV<long, long> fht; + private ClientSession<long, long, long, long, Empty, IFunctions<long, long, long, long, Empty>> session; + + [SetUp] + public void Setup() + { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/SimpleRecoveryTest1.log", deleteOnClose: true); + + fht = new FasterKV<long, long>(128, + logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 } + ); + + session = fht.NewSession(new SimpleFunctions<long, long>()); + } + + [TearDown] + public void TearDown() + { + session?.Dispose(); + session 
= null; + fht?.Dispose(); + fht = null; + log?.Dispose(); + log = null; + + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } + + [Test] + [Category("FasterKV"), Category("Read")] + public async ValueTask ReadMinAddressTest([Values] bool isAsync) + { + long minAddress = core.Constants.kInvalidAddress; + var pivotKey = numOps / 2; + long makeValue(long key) => key + numOps * 10; + for (int ii = 0; ii < numOps; ii++) + { + if (ii == pivotKey) + minAddress = fht.Log.TailAddress; + session.Upsert(ii, makeValue(ii)); + } + + // Verify the test set up correctly + Assert.AreNotEqual(core.Constants.kInvalidAddress, minAddress); + + long input = 0; + + async ValueTask ReadMin(long key, Status expectedStatus) + { + Status status; + long output = 0; + if (isAsync) + (status, output) = (await session.ReadAsync(ref key, ref input, minAddress, ReadFlags.MinAddress)).Complete(); + else + { + RecordInfo recordInfo = new() { PreviousAddress = minAddress }; + status = session.Read(ref key, ref input, ref output, ref recordInfo, ReadFlags.MinAddress); + if (status == Status.PENDING) + { + Assert.IsTrue(session.CompletePendingWithOutputs(out var completedOutputs, wait: true)); + (status, output) = TestUtils.GetSinglePendingResult(completedOutputs); + } + } + Assert.AreEqual(expectedStatus, status); + if (status != Status.NOTFOUND) + Assert.AreEqual(output, makeValue(key)); + } + + async ValueTask RunTests() + { + // First read at the pivot, to verify that and make sure the rest of the test works + await ReadMin(pivotKey, Status.OK); + + // Read a Key that is below the min address + await ReadMin(pivotKey - 1, Status.NOTFOUND); + + // Read a Key that is above the min address + await ReadMin(pivotKey + 1, Status.OK); + } + + await RunTests(); + fht.Log.FlushAndEvict(wait: true); + await RunTests(); + } + } } diff --git a/cs/test/RecoverContinueTests.cs b/cs/test/RecoverContinueTests.cs index ed3af26d9..d2c2f6832 100644 --- a/cs/test/RecoverContinueTests.cs +++ 
b/cs/test/RecoverContinueTests.cs @@ -1,4 +1,5 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. + +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. using System.Threading; @@ -17,31 +18,33 @@ internal class RecoverContinueTests private FasterKV<AdId, NumClicks> fht3; private IDevice log; private int numOps; - private string checkpointDir = TestContext.CurrentContext.TestDirectory + "/checkpoints3"; + private string checkpointDir; [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/RecoverContinueTests.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/RecoverContinueTests.log", deleteOnClose: true); + checkpointDir = TestUtils.MethodTestDir + "/checkpoints3"; Directory.CreateDirectory(checkpointDir); fht1 = new FasterKV <AdId, NumClicks> - (128, + (16, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 }, checkpointSettings: new CheckpointSettings { CheckpointDir = checkpointDir, CheckPointType = CheckpointType.Snapshot } ); fht2 = new FasterKV <AdId, NumClicks> - (128, + (16, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 }, checkpointSettings: new CheckpointSettings { CheckpointDir = checkpointDir, CheckPointType = CheckpointType.Snapshot } ); fht3 = new FasterKV <AdId, NumClicks> - (128, + (16, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 }, checkpointSettings: new CheckpointSettings { CheckpointDir = checkpointDir, CheckPointType = CheckpointType.Snapshot } ); @@ -52,18 +55,22 @@ public void Setup() [TearDown] public void TearDown() { - fht1.Dispose(); - fht2.Dispose(); - fht3.Dispose(); + fht1?.Dispose(); + fht2?.Dispose(); + fht3?.Dispose(); fht1 = null; fht2 = null; fht3 = null; - log.Dispose(); + log?.Dispose(); 
+ log = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); TestUtils.DeleteDirectory(checkpointDir); } [Test] [Category("FasterKV")] + [Category("CheckpointRestore")] + [Category("Smoke")] public async ValueTask RecoverContinueTest([Values]bool isAsync) { long sno = 0; @@ -71,7 +78,7 @@ public async ValueTask RecoverContinueTest([Values]bool isAsync) var firstsession = fht1.For(new AdSimpleFunctions()).NewSession<AdSimpleFunctions>("first"); IncrementAllValues(ref firstsession, ref sno); fht1.TakeFullCheckpoint(out _); - fht1.CompleteCheckpointAsync().GetAwaiter().GetResult(); + fht1.CompleteCheckpointAsync().AsTask().GetAwaiter().GetResult(); firstsession.Dispose(); // Check if values after checkpoint are correct @@ -91,10 +98,10 @@ public async ValueTask RecoverContinueTest([Values]bool isAsync) // Continue and increment values var continuesession = fht2.For(new AdSimpleFunctions()).ResumeSession<AdSimpleFunctions>("first", out CommitPoint cp); long newSno = cp.UntilSerialNo; - Assert.IsTrue(newSno == sno - 1); + Assert.AreEqual(sno - 1, newSno); IncrementAllValues(ref continuesession, ref sno); fht2.TakeFullCheckpoint(out _); - fht2.CompleteCheckpointAsync().GetAwaiter().GetResult(); + fht2.CompleteCheckpointAsync().AsTask().GetAwaiter().GetResult(); continuesession.Dispose(); // Check if values after continue checkpoint are correct @@ -110,7 +117,7 @@ public async ValueTask RecoverContinueTest([Values]bool isAsync) var nextsession = fht3.For(new AdSimpleFunctions()).ResumeSession<AdSimpleFunctions>("first", out cp); long newSno2 = cp.UntilSerialNo; - Assert.IsTrue(newSno2 == sno - 1); + Assert.AreEqual(sno - 1, newSno2); CheckAllValues(ref nextsession, 2); nextsession.Dispose(); } @@ -130,7 +137,7 @@ private void CheckAllValues( fht.CompletePending(true); else { - Assert.IsTrue(outputArg.value.numClicks == value); + Assert.AreEqual(value, outputArg.value.numClicks); } } @@ -158,8 +165,8 @@ public class AdSimpleFunctions : FunctionsBase<AdId, NumClicks, 
AdInput, Output, { public override void ReadCompletionCallback(ref AdId key, ref AdInput input, ref Output output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.value.numClicks == key.adId); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key.adId, output.value.numClicks); } // Read functions diff --git a/cs/test/RecoverReadOnlyTest.cs b/cs/test/RecoverReadOnlyTest.cs index 43f57bb35..ed81c90b4 100644 --- a/cs/test/RecoverReadOnlyTest.cs +++ b/cs/test/RecoverReadOnlyTest.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. + using System; using System.IO; using System.Threading; @@ -8,12 +9,10 @@ using NUnit.Framework; using System.Text; - //** Note - this test is based on FasterLogPubSub sample found in the samples directory. namespace FASTER.test { - [TestFixture] internal class BasicRecoverReadOnly { @@ -22,17 +21,17 @@ internal class BasicRecoverReadOnly private FasterLog logReadOnly; private IDevice deviceReadOnly; - private static string path = Path.GetTempPath() + "BasicRecoverAsyncReadOnly/"; + private static string path; const int commitPeriodMs = 2000; const int restorePeriodMs = 1000; [SetUp] public void Setup() { + path = TestUtils.MethodTestDir + "/"; // Clean up log files from previous test runs in case they weren't cleaned up - try { new DirectoryInfo(path).Delete(true); } - catch {} + TestUtils.DeleteDirectory(path, wait:true); // Create devices \ log for test device = Devices.CreateLogDevice(path + "Recover", deleteOnClose: true); @@ -44,19 +43,23 @@ public void Setup() [TearDown] public void TearDown() { - log.Dispose(); - device.Dispose(); - logReadOnly.Dispose(); - deviceReadOnly.Dispose(); + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; + logReadOnly?.Dispose(); + logReadOnly = null; + deviceReadOnly?.Dispose(); + deviceReadOnly = null; // Clean up log files - try { new DirectoryInfo(path).Delete(true); } - 
catch { } + TestUtils.DeleteDirectory(path); } [Test] [Category("FasterLog")] + [Category("Smoke")] public void RecoverReadOnlyAsyncBasicTest() { using var cts = new CancellationTokenSource(); @@ -67,7 +70,9 @@ public void RecoverReadOnlyAsyncBasicTest() // Run consumer on SEPARATE read-only FasterLog instance var consumer = SeparateConsumerAsync(cts.Token); - // Give it some time to run a bit - similar to waiting for things to run before hitting cancel + //** Give it some time to run a bit + //** Acceptable use of using sleep for this spot + //** Similar to waiting for things to run before manually hitting cancel from a command prompt Thread.Sleep(3000); cts.Cancel(); @@ -81,7 +86,6 @@ public void RecoverReadOnlyAsyncBasicTest() //**** Helper Functions - based off of FasterLogPubSub sample *** static async Task CommitterAsync(FasterLog log, CancellationToken cancellationToken) { - while (!cancellationToken.IsCancellationRequested) { await Task.Delay(TimeSpan.FromMilliseconds(commitPeriodMs), cancellationToken); @@ -107,7 +111,6 @@ static async Task ProducerAsync(FasterLog log, CancellationToken cancellationTok // to the primary FasterLog's commits. public async Task SeparateConsumerAsync(CancellationToken cancellationToken) { - var _ = BeginRecoverReadOnlyLoop(logReadOnly, cancellationToken); // This enumerator waits asynchronously when we have reached the committed tail of the duplicate FasterLog. 
When RecoverReadOnly @@ -119,7 +122,6 @@ public async Task SeparateConsumerAsync(CancellationToken cancellationToken) } } - static async Task BeginRecoverReadOnlyLoop(FasterLog log, CancellationToken cancellationToken) { while (!cancellationToken.IsCancellationRequested) @@ -129,7 +131,6 @@ static async Task BeginRecoverReadOnlyLoop(FasterLog log, CancellationToken canc await log.RecoverReadOnlyAsync(cancellationToken); } } - } } diff --git a/cs/test/RecoveryChecks.cs b/cs/test/RecoveryChecks.cs index 4c76ad0d4..5ce755078 100644 --- a/cs/test/RecoveryChecks.cs +++ b/cs/test/RecoveryChecks.cs @@ -3,7 +3,6 @@ using System.Threading.Tasks; using FASTER.core; -using System.IO; using NUnit.Framework; using FASTER.test.recovery.sumstore; using System; @@ -17,18 +16,14 @@ public enum DeviceMode Cloud } - [TestFixture] - public class RecoveryChecks + public class RecoveryCheckBase { - IDevice log; - const int numOps = 5000; - AdId[] inputArray; - string path; - public const string EMULATED_STORAGE_STRING = "UseDevelopmentStorage=true;"; - public const string TEST_CONTAINER = "recoverychecks"; + protected IDevice log; + protected const int numOps = 5000; + protected AdId[] inputArray; + protected string path; - [SetUp] - public void Setup() + protected void BaseSetup() { inputArray = new AdId[numOps]; for (int i = 0; i < numOps; i++) @@ -36,23 +31,24 @@ public void Setup() inputArray[i].adId = i; } - path = TestContext.CurrentContext.TestDirectory + "/RecoveryChecks/"; + path = TestUtils.MethodTestDir + "/"; log = Devices.CreateLogDevice(path + "hlog.log", deleteOnClose: true); - Directory.CreateDirectory(path); + TestUtils.RecreateDirectory(path); } - [TearDown] - public void TearDown() + protected void BaseTearDown() { - log.Dispose(); - new DirectoryInfo(path).Delete(true); + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(path); } public class MyFunctions : SimpleFunctions<long, long> { public override void ReadCompletionCallback(ref long key, ref long input, 
ref long output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK && output == key); + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); } } @@ -60,15 +56,34 @@ public class MyFunctions2 : SimpleFunctions<long, long> { public override void ReadCompletionCallback(ref long key, ref long input, ref long output, Empty ctx, Status status) { + Verify(status, key, output); + } + + internal static void Verify(Status status, long key, long output) + { + Assert.AreEqual(Status.OK, status); if (key < 950) - Assert.IsTrue(status == Status.OK && output == key); + Assert.AreEqual(key, output); else - Assert.IsTrue(status == Status.OK && output == key + 1); + Assert.AreEqual(key + 1, output); } } + } + + [TestFixture] + public class RecoveryCheck1Tests : RecoveryCheckBase + { + [SetUp] + public void Setup() => BaseSetup(); + + [TearDown] + public void TearDown() => BaseTearDown(); [Test] [Category("FasterKV")] + [Category("CheckpointRestore")] + [Category("Smoke")] + public async ValueTask RecoveryCheck1([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(128, 1<<10)]int size) { using var fht1 = new FasterKV<long, long> @@ -91,7 +106,10 @@ public async ValueTask RecoveryCheck1([Values] CheckpointType checkpointType, [V long output = default; var status = s1.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s1.CompletePending(true); } @@ -111,7 +129,7 @@ public async ValueTask RecoveryCheck1([Values] CheckpointType checkpointType, [V } else { - var (status, token) = task.GetAwaiter().GetResult(); + var (status, token) = task.AsTask().GetAwaiter().GetResult(); fht2.Recover(default, token); } @@ -125,13 +143,27 @@ public async ValueTask RecoveryCheck1([Values] CheckpointType 
checkpointType, [V long output = default; var status = s2.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s2.CompletePending(true); } + } + + [TestFixture] + public class RecoveryCheck2Tests : RecoveryCheckBase + { + [SetUp] + public void Setup() => BaseSetup(); + + [TearDown] + public void TearDown() => BaseTearDown(); + [Test] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public async ValueTask RecoveryCheck2([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(128, 1 << 10)] int size) { using var fht1 = new FasterKV<long, long> @@ -163,7 +195,10 @@ public async ValueTask RecoveryCheck2([Values] CheckpointType checkpointType, [V long output = default; var status = s1.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s1.CompletePending(true); } @@ -177,7 +212,7 @@ public async ValueTask RecoveryCheck2([Values] CheckpointType checkpointType, [V } else { - var(status, token) = task.GetAwaiter().GetResult(); + var(status, token) = task.AsTask().GetAwaiter().GetResult(); fht2.Recover(default, token); } @@ -191,14 +226,27 @@ public async ValueTask RecoveryCheck2([Values] CheckpointType checkpointType, [V long output = default; var status = s2.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s2.CompletePending(true); } } + } + + [TestFixture] + public class RecoveryCheck3Tests : RecoveryCheckBase + { + [SetUp] + public void Setup() => 
BaseSetup(); + + [TearDown] + public void TearDown() => BaseTearDown(); [Test] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public async ValueTask RecoveryCheck3([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(128, 1 << 10)] int size) { using var fht1 = new FasterKV<long, long> @@ -230,7 +278,10 @@ public async ValueTask RecoveryCheck3([Values] CheckpointType checkpointType, [V long output = default; var status = s1.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s1.CompletePending(true); } @@ -244,7 +295,7 @@ public async ValueTask RecoveryCheck3([Values] CheckpointType checkpointType, [V } else { - var (status, token) = task.GetAwaiter().GetResult(); + var (status, token) = task.AsTask().GetAwaiter().GetResult(); fht2.Recover(default, token); } @@ -258,14 +309,28 @@ public async ValueTask RecoveryCheck3([Values] CheckpointType checkpointType, [V long output = default; var status = s2.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s2.CompletePending(true); } } + } + + [TestFixture] + public class RecoveryCheck4Tests : RecoveryCheckBase + { + [SetUp] + public void Setup() => BaseSetup(); + + [TearDown] + public void TearDown() => BaseTearDown(); + [Test] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public async ValueTask RecoveryCheck4([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(128, 1 << 10)] int size) { using var fht1 = new FasterKV<long, long> @@ -297,13 +362,16 @@ public async ValueTask RecoveryCheck4([Values] 
CheckpointType checkpointType, [V long output = default; var status = s1.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s1.CompletePending(true); } if (i == 0) - fht1.TakeIndexCheckpointAsync().GetAwaiter().GetResult(); + fht1.TakeIndexCheckpointAsync().AsTask().GetAwaiter().GetResult(); var task = fht1.TakeHybridLogCheckpointAsync(checkpointType); @@ -314,7 +382,7 @@ public async ValueTask RecoveryCheck4([Values] CheckpointType checkpointType, [V } else { - var (status, token) = task.GetAwaiter().GetResult(); + var (status, token) = task.AsTask().GetAwaiter().GetResult(); fht2.Recover(default, token); } @@ -328,14 +396,29 @@ public async ValueTask RecoveryCheck4([Values] CheckpointType checkpointType, [V long output = default; var status = s2.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s2.CompletePending(true); } } + } + + [TestFixture] + public class RecoveryCheck5Tests : RecoveryCheckBase + { + [SetUp] + public void Setup() => BaseSetup(); + + [TearDown] + public void TearDown() => BaseTearDown(); + [Test] [Category("FasterKV")] + [Category("CheckpointRestore")] public async ValueTask RecoveryCheck5([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values] bool useReadCache, [Values(128, 1 << 10)] int size) { using var fht1 = new FasterKV<long, long> @@ -358,7 +441,10 @@ public async ValueTask RecoveryCheck5([Values] CheckpointType checkpointType, [V long output = default; var status = s1.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + 
Assert.AreEqual(key, output, $"output = {output}"); + } } s1.CompletePending(true); } @@ -370,7 +456,10 @@ public async ValueTask RecoveryCheck5([Values] CheckpointType checkpointType, [V long output = default; var status = s1.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s1.CompletePending(true); @@ -389,7 +478,7 @@ public async ValueTask RecoveryCheck5([Values] CheckpointType checkpointType, [V } else { - var (status, token) = task.GetAwaiter().GetResult(); + var (status, token) = task.AsTask().GetAwaiter().GetResult(); fht2.Recover(default, token); } @@ -403,13 +492,28 @@ public async ValueTask RecoveryCheck5([Values] CheckpointType checkpointType, [V long output = default; var status = s2.Read(ref key, ref output); if (status != Status.PENDING) - Assert.IsTrue(status == Status.OK && output == key); + { + Assert.AreEqual(Status.OK, status, $"status = {status}"); + Assert.AreEqual(key, output, $"output = {output}"); + } } s2.CompletePending(true); } + } + + [TestFixture] + public class RecoveryCheckSnapshotTests : RecoveryCheckBase + { + [SetUp] + public void Setup() => BaseSetup(); + [TearDown] + public void TearDown() => BaseTearDown(); [Test] + [Category("FasterKV")] + [Category("CheckpointRestore")] + [Category("Smoke")] public async ValueTask IncrSnapshotRecoveryCheck([Values] DeviceMode deviceMode) { ICheckpointManager checkpointManager; @@ -417,18 +521,14 @@ public async ValueTask IncrSnapshotRecoveryCheck([Values] DeviceMode deviceMode) { checkpointManager = new DeviceLogCommitCheckpointManager( new LocalStorageNamedDeviceFactory(), - new DefaultCheckpointNamingScheme(TestContext.CurrentContext.TestDirectory + $"/RecoveryChecks/IncrSnapshotRecoveryCheck")); + new DefaultCheckpointNamingScheme(TestUtils.MethodTestDir + "/checkpoints/")); // PurgeAll deletes this 
directory } else { - if ("yes".Equals(Environment.GetEnvironmentVariable("RunAzureTests"))) - { - checkpointManager = new DeviceLogCommitCheckpointManager( - new AzureStorageNamedDeviceFactory(EMULATED_STORAGE_STRING), - new DefaultCheckpointNamingScheme($"{TEST_CONTAINER}/IncrSnapshotRecoveryCheck")); - } - else - return; + TestUtils.IgnoreIfNotRunningAzureTests(); + checkpointManager = new DeviceLogCommitCheckpointManager( + new AzureStorageNamedDeviceFactory(TestUtils.AzureEmulatedStorageString), + new DefaultCheckpointNamingScheme($"{TestUtils.AzureTestContainer}/{TestUtils.AzureTestDirectory}")); } await IncrSnapshotRecoveryCheck(checkpointManager); @@ -436,7 +536,7 @@ public async ValueTask IncrSnapshotRecoveryCheck([Values] DeviceMode deviceMode) checkpointManager.Dispose(); } - public async ValueTask IncrSnapshotRecoveryCheck(ICheckpointManager checkpointManager) + private async ValueTask IncrSnapshotRecoveryCheck(ICheckpointManager checkpointManager) { using var fht1 = new FasterKV<long, long> (1 << 10, @@ -451,31 +551,33 @@ public async ValueTask IncrSnapshotRecoveryCheck(ICheckpointManager checkpointMa } var task = fht1.TakeHybridLogCheckpointAsync(CheckpointType.Snapshot); - var result = await task; + var (success, token) = await task; for (long key = 950; key < 1000; key++) { s1.Upsert(key, key+1); } + var version1 = fht1.CurrentVersion; var _result1 = fht1.TakeHybridLogCheckpoint(out var _token1, CheckpointType.Snapshot, true); await fht1.CompleteCheckpointAsync(); Assert.IsTrue(_result1); - Assert.IsTrue(_token1 == result.token); + Assert.AreEqual(token, _token1); for (long key = 1000; key < 2000; key++) { s1.Upsert(key, key + 1); } + var version2 = fht1.CurrentVersion; var _result2 = fht1.TakeHybridLogCheckpoint(out var _token2, CheckpointType.Snapshot, true); await fht1.CompleteCheckpointAsync(); Assert.IsTrue(_result2); - Assert.IsTrue(_token2 == result.token); - + Assert.AreEqual(token, _token2); + // Test that we can recover to latest version 
using var fht2 = new FasterKV<long, long> (1 << 10, logSettings: new LogSettings { LogDevice = log, MutableFraction = 1, PageSizeBits = 10, MemorySizeBits = 14, ReadCacheSettings = null }, @@ -484,7 +586,7 @@ public async ValueTask IncrSnapshotRecoveryCheck(ICheckpointManager checkpointMa await fht2.RecoverAsync(default, _token2); - Assert.IsTrue(fht1.Log.TailAddress == fht2.Log.TailAddress, $"fht1 tail = {fht1.Log.TailAddress}; fht2 tail = {fht2.Log.TailAddress}"); + Assert.AreEqual(fht2.Log.TailAddress, fht1.Log.TailAddress); using var s2 = fht2.NewSession(new MyFunctions2()); for (long key = 0; key < 2000; key++) @@ -493,13 +595,32 @@ public async ValueTask IncrSnapshotRecoveryCheck(ICheckpointManager checkpointMa var status = s2.Read(ref key, ref output); if (status != Status.PENDING) { - if (key < 950) - Assert.IsTrue(status == Status.OK && output == key); - else - Assert.IsTrue(status == Status.OK && output == key + 1); + MyFunctions2.Verify(status, key, output); } } s2.CompletePending(true); + + // Test that we can recover to earlier version + using var fht3 = new FasterKV<long, long> + (1 << 10, + logSettings: new LogSettings { LogDevice = log, MutableFraction = 1, PageSizeBits = 10, MemorySizeBits = 14, ReadCacheSettings = null }, + checkpointSettings: new CheckpointSettings { CheckpointManager = checkpointManager } + ); + + await fht3.RecoverAsync(recoverTo: version1); + + Assert.IsTrue(fht3.EntryCount == 1000); + using var s3 = fht3.NewSession(new MyFunctions2()); + for (long key = 0; key < 1000; key++) + { + long output = default; + var status = s3.Read(ref key, ref output); + if (status != Status.PENDING) + { + MyFunctions2.Verify(status, key, output); + } + } + s3.CompletePending(true); } } } diff --git a/cs/test/RecoveryTestTypes.cs b/cs/test/RecoveryTestTypes.cs index 2bbf26796..a0f0bf74c 100644 --- a/cs/test/RecoveryTestTypes.cs +++ b/cs/test/RecoveryTestTypes.cs @@ -1,16 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT license. -using System; -using System.Text; using System.Threading; -using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.Runtime.CompilerServices; -using System.IO; -using System.Diagnostics; namespace FASTER.test.recovery.sumstore { @@ -18,30 +10,31 @@ public struct AdId : IFasterEqualityComparer<AdId> { public long adId; - public long GetHashCode64(ref AdId key) - { - return Utility.GetHashCode(key.adId); - } - public bool Equals(ref AdId k1, ref AdId k2) - { - return k1.adId == k2.adId; - } + public long GetHashCode64(ref AdId key) => Utility.GetHashCode(key.adId); + + public bool Equals(ref AdId k1, ref AdId k2) => k1.adId == k2.adId; } public struct AdInput { public AdId adId; public NumClicks numClicks; + + public override string ToString() => $"id = {adId.adId}, clicks = {numClicks.numClicks}"; } public struct NumClicks { public long numClicks; + + public override string ToString() => numClicks.ToString(); } public struct Output { public NumClicks value; + + public override string ToString() => value.ToString(); } public class Functions : FunctionsBase<AdId, NumClicks, AdInput, Output, Empty> diff --git a/cs/test/RecoveryTests.cs b/cs/test/RecoveryTests.cs index 9255a100e..fec01d5f3 100644 --- a/cs/test/RecoveryTests.cs +++ b/cs/test/RecoveryTests.cs @@ -11,127 +11,128 @@ namespace FASTER.test.recovery.sumstore { [TestFixture] - internal class RecoveryTests + internal class DeviceTypeRecoveryTests { - const long numUniqueKeys = (1 << 14); - const long keySpace = (1L << 14); - const long numOps = (1L << 19); - const long completePendingInterval = (1L << 10); - const long checkpointInterval = (1L << 16); + internal const long numUniqueKeys = (1 << 12); + internal const long keySpace = (1L << 14); + internal const long numOps = (1L << 17); + internal const long completePendingInterval = (1L << 10); + internal const long checkpointInterval = (1L << 14); private 
FasterKV<AdId, NumClicks> fht; - private string test_path; - private List<Guid> logTokens, indexTokens; + private string path; + private readonly List<Guid> logTokens = new(); + private readonly List<Guid> indexTokens = new(); private IDevice log; [SetUp] public void Setup() { - if (test_path == null) - { - test_path = TestContext.CurrentContext.TestDirectory + "/" + Path.GetRandomFileName(); - if (!Directory.Exists(test_path)) - Directory.CreateDirectory(test_path); - } + path = TestUtils.MethodTestDir + "/"; - log = Devices.CreateLogDevice(test_path + "/FullRecoveryTests.log"); + // Only clean these in the initial Setup, as tests use the other Setup() overload to recover + logTokens.Clear(); + indexTokens.Clear(); + TestUtils.DeleteDirectory(path, true); + } - fht = new FasterKV<AdId, NumClicks> - (keySpace, - new LogSettings {LogDevice = log}, - new CheckpointSettings {CheckpointDir = test_path, CheckPointType = CheckpointType.Snapshot} + private void Setup(TestUtils.DeviceType deviceType) + { + log = TestUtils.CreateTestDevice(deviceType, path + "Test.log"); + fht = new FasterKV<AdId, NumClicks>(keySpace, + //new LogSettings { LogDevice = log, MemorySizeBits = 14, PageSizeBits = 9 }, // locks ups at session.RMW line in Populate() for Local Memory + new LogSettings { LogDevice = log, SegmentSizeBits = 25 }, + new CheckpointSettings { CheckpointDir = path, CheckPointType = CheckpointType.Snapshot } ); } [TearDown] - public void TearDown() + public void TearDown() => TearDown(deleteDir: true); + + private void TearDown(bool deleteDir) { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); - TestUtils.DeleteDirectory(test_path); + log?.Dispose(); + log = null; + + // Do NOT clean up here unless specified, as tests use this TearDown() to prepare for recovery + if (deleteDir) + TestUtils.DeleteDirectory(path); + } + + private void PrepareToRecover(TestUtils.DeviceType deviceType) + { + TearDown(deleteDir: false); + Setup(deviceType); } [Test] 
[Category("FasterKV")] - public async ValueTask RecoveryTestSeparateCheckpoint([Values]bool isAsync) + [Category("CheckpointRestore")] + public async ValueTask RecoveryTestSeparateCheckpoint([Values]bool isAsync, [Values] TestUtils.DeviceType deviceType) { + Setup(deviceType); Populate(SeparateCheckpointAction); for (var i = 0; i < logTokens.Count; i++) { if (i >= indexTokens.Count) break; - fht.Dispose(); - fht = null; - log.Dispose(); - Setup(); - await RecoverAndTestAsync(logTokens[i], indexTokens[i], isAsync); + PrepareToRecover(deviceType); + await RecoverAndTestAsync(i, isAsync); } } [Test] [Category("FasterKV")] - public async ValueTask RecoveryTestFullCheckpoint([Values] bool isAsync) + [Category("CheckpointRestore")] + [Category("Smoke")] + public async ValueTask RecoveryTestFullCheckpoint([Values] bool isAsync, [Values] TestUtils.DeviceType deviceType) { + Setup(deviceType); Populate(FullCheckpointAction); - foreach (var token in logTokens) + for (var i = 0; i < logTokens.Count; i++) { - fht.Dispose(); - fht = null; - log.Dispose(); - Setup(); - await RecoverAndTestAsync(token, token, isAsync); + PrepareToRecover(deviceType); + await RecoverAndTestAsync(i, isAsync); } } - public void FullCheckpointAction(int opNum) + private void FullCheckpointAction(int opNum) { if ((opNum + 1) % checkpointInterval == 0) { Guid token; - while (!fht.TakeFullCheckpoint(out token)) - { - } - + while (!fht.TakeFullCheckpoint(out token)) { } logTokens.Add(token); indexTokens.Add(token); - - fht.CompleteCheckpointAsync().GetAwaiter().GetResult(); + fht.CompleteCheckpointAsync().AsTask().GetAwaiter().GetResult(); } } - public void SeparateCheckpointAction(int opNum) + private void SeparateCheckpointAction(int opNum) { - if ((opNum + 1) % checkpointInterval != 0) return; + if ((opNum + 1) % checkpointInterval != 0) + return; var checkpointNum = (opNum + 1) / checkpointInterval; + Guid token; if (checkpointNum % 2 == 1) { - Guid token; - while (!fht.TakeHybridLogCheckpoint(out 
token)) - { - } - + while (!fht.TakeHybridLogCheckpoint(out token)) { } logTokens.Add(token); - fht.CompleteCheckpointAsync().GetAwaiter().GetResult(); } else { - Guid token; - while (!fht.TakeIndexCheckpoint(out token)) - { - } - + while (!fht.TakeIndexCheckpoint(out token)) { } indexTokens.Add(token); - fht.CompleteCheckpointAsync().GetAwaiter().GetResult(); } + fht.CompleteCheckpointAsync().AsTask().GetAwaiter().GetResult(); } - public void Populate(Action<int> checkpointAction) + private void Populate(Action<int> checkpointAction) { - logTokens = new List<Guid>(); - indexTokens = new List<Guid>(); // Prepare the dataset var inputArray = new AdInput[numOps]; for (int i = 0; i < numOps; i++) @@ -151,25 +152,23 @@ public void Populate(Action<int> checkpointAction) checkpointAction(i); if (i % completePendingInterval == 0) - { session.CompletePending(false); - } } // Make sure operations are completed session.CompletePending(true); - - // Deregister thread from FASTER - session.Dispose(); } - public async ValueTask RecoverAndTestAsync(Guid cprVersion, Guid indexVersion, bool isAsync) + private async ValueTask RecoverAndTestAsync(int tokenIndex, bool isAsync) { + var logToken = logTokens[tokenIndex]; + var indexToken = indexTokens[tokenIndex]; + // Recover if (isAsync) - await fht.RecoverAsync(indexVersion, cprVersion); + await fht.RecoverAsync(indexToken, logToken); else - fht.Recover(indexVersion, cprVersion); + fht.Recover(indexToken, logToken); // Create array for reading var inputArray = new AdInput[numUniqueKeys]; @@ -189,7 +188,7 @@ public async ValueTask RecoverAndTestAsync(Guid cprVersion, Guid indexVersion, b for (var i = 0; i < numUniqueKeys; i++) { var status = session.Read(ref inputArray[i].adId, ref input, ref output, Empty.Default, i); - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status, $"At tokenIndex {tokenIndex}, keyIndex {i}, AdId {inputArray[i].adId.adId}"); inputArray[i].numClicks = output.value; } @@ -201,11 +200,11 @@ 
public async ValueTask RecoverAndTestAsync(Guid cprVersion, Guid indexVersion, b // Test outputs var checkpointInfo = default(HybridLogRecoveryInfo); - checkpointInfo.Recover(cprVersion, + checkpointInfo.Recover(logToken, new DeviceLogCommitCheckpointManager( new LocalStorageNamedDeviceFactory(), new DefaultCheckpointNamingScheme( - new DirectoryInfo(test_path).FullName))); + new DirectoryInfo(path).FullName))); // Compute expected array long[] expected = new long[numUniqueKeys]; @@ -234,11 +233,299 @@ public async ValueTask RecoverAndTestAsync(Guid cprVersion, Guid indexVersion, b // Assert if expected is same as found for (long i = 0; i < numUniqueKeys; i++) { - Assert.IsTrue( - expected[i] == inputArray[i].numClicks.numClicks, - "Debug error for AdId {0}: Expected ({1}), Found({2})", inputArray[i].adId.adId, expected[i], - inputArray[i].numClicks.numClicks); + Assert.AreEqual(expected[i], inputArray[i].numClicks.numClicks, $"At keyIndex {i}, AdId {inputArray[i].adId.adId}"); } } } -} \ No newline at end of file + + [TestFixture] + internal class AllocatorTypeRecoveryTests + { + // VarLenMax is the variable-length portion; 2 is for the fixed fields + const int VarLenMax = 10; + const int StackAllocMax = VarLenMax + 2; + const int RandSeed = 101; + const long expectedValueBase = DeviceTypeRecoveryTests.numUniqueKeys * (DeviceTypeRecoveryTests.numOps / DeviceTypeRecoveryTests.numUniqueKeys - 1); + private static long ExpectedValue(int key) => expectedValueBase + key; + + private IDisposable fhtDisp; + private string path; + private Guid logToken; + private Guid indexToken; + private IDevice log; + private IDevice objlog; + private bool smallSector; + + // 'object' to avoid generic args + private object varLenStructObj; + private object serializerSettingsObj; + + [SetUp] + public void Setup() + { + smallSector = false; + varLenStructObj = null; + serializerSettingsObj = null; + + path = TestUtils.MethodTestDir + "/"; + + // Only clean these in the initial Setup, as 
tests use the other Setup() overload to recover + logToken = Guid.Empty; + indexToken = Guid.Empty; + TestUtils.DeleteDirectory(path, true); + } + + private FasterKV<TData, TData> Setup<TData>() + { + log = new LocalMemoryDevice(1L << 26, 1L << 22, 2, sector_size: smallSector ? 64 : (uint)512, fileName: $"{path}{typeof(TData).Name}.log"); + objlog = serializerSettingsObj is null + ? null + : new LocalMemoryDevice(1L << 26, 1L << 22, 2, fileName: $"{path}{typeof(TData).Name}.obj.log"); + + var varLenStruct = this.varLenStructObj as IVariableLengthStruct<TData>; + Assert.AreEqual(this.varLenStructObj is null, varLenStruct is null, "varLenStructSettings"); + VariableLengthStructSettings<TData, TData> varLenStructSettings = varLenStruct is null + ? null + : new VariableLengthStructSettings<TData, TData> { keyLength = varLenStruct, valueLength = varLenStruct }; + + var result = new FasterKV<TData, TData>(DeviceTypeRecoveryTests.keySpace, + new LogSettings { LogDevice = log, ObjectLogDevice = objlog, SegmentSizeBits = 25 }, + new CheckpointSettings { CheckpointDir = path, CheckPointType = CheckpointType.Snapshot }, + this.serializerSettingsObj as SerializerSettings<TData, TData>, + variableLengthStructSettings: varLenStructSettings + ); + + fhtDisp = result; + return result; + } + + [TearDown] + public void TearDown() => TearDown(deleteDir: true); + + private void TearDown(bool deleteDir) + { + fhtDisp?.Dispose(); + fhtDisp = null; + log?.Dispose(); + log = null; + objlog?.Dispose(); + objlog = null; + + // Do NOT clean up here unless specified, as tests use this TearDown() to prepare for recovery + if (deleteDir) + TestUtils.DeleteDirectory(path); + } + + private FasterKV<TData, TData> PrepareToRecover<TData>() + { + TearDown(deleteDir: false); + return Setup<TData>(); + } + + [Test] + [Category("FasterKV")] + [Category("CheckpointRestore")] + public async ValueTask RecoveryTestByAllocatorType([Values] TestUtils.AllocatorType allocatorType, [Values] bool isAsync) + { + 
await TestDriver(allocatorType, isAsync); + } + + [Test] + [Category("FasterKV")] + [Category("CheckpointRestore")] + public async ValueTask RecoveryTestFailOnSectorSize([Values] TestUtils.AllocatorType allocatorType, [Values] bool isAsync) + { + this.smallSector = true; + await TestDriver(allocatorType, isAsync); + } + + private async ValueTask TestDriver(TestUtils.AllocatorType allocatorType, [Values] bool isAsync) + { + ValueTask task; + switch (allocatorType) + { + case TestUtils.AllocatorType.FixedBlittable: + task = RunTest<long>(Populate, Read, Recover, isAsync); + break; + case TestUtils.AllocatorType.VarLenBlittable: + this.varLenStructObj = new VLValue(); + task = RunTest<VLValue>(Populate, Read, Recover, isAsync); + break; + case TestUtils.AllocatorType.Generic: + this.serializerSettingsObj = new MyValueSerializer(); + task = RunTest<MyValue>(Populate, Read, Recover, isAsync); + break; + default: + throw new ApplicationException("Unknown allocator type"); + }; + await task; + } + + private async ValueTask RunTest<TData>(Action<FasterKV<TData, TData>> populateAction, Action<FasterKV<TData, TData>> readAction, Func<FasterKV<TData, TData>, bool, ValueTask> recoverFunc, bool isAsync) + { + var fht = Setup<TData>(); + populateAction(fht); + readAction(fht); + if (smallSector) + { + Assert.ThrowsAsync<FasterException>(async () => await Checkpoint(fht, isAsync)); + Assert.Pass("Verified expected exception; the test cannot continue, so exiting early with success"); + } + else + await Checkpoint(fht, isAsync); + + Assert.AreNotEqual(Guid.Empty, this.logToken); + Assert.AreNotEqual(Guid.Empty, this.indexToken); + readAction(fht); + + fht = PrepareToRecover<TData>(); + await recoverFunc(fht, isAsync); + readAction(fht); + } + + private void Populate(FasterKV<long, long> fht) + { + using var session = fht.NewSession(new SimpleFunctions<long, long>()); + + for (int i = 0; i < DeviceTypeRecoveryTests.numOps; i++) + session.Upsert(i % 
DeviceTypeRecoveryTests.numUniqueKeys, i); + session.CompletePending(true); + } + + static int GetVarLen(Random r) => r.Next(VarLenMax) + 2; + + private unsafe void Populate(FasterKV<VLValue, VLValue> fht) + { + using var session = fht.NewSession(new VLFunctions2()); + Random rng = new(RandSeed); + + // Single alloc outside the loop, to the max length we'll need. + int* keyval = stackalloc int[StackAllocMax]; + int* val = stackalloc int[StackAllocMax]; + + for (int i = 0; i < DeviceTypeRecoveryTests.numOps; i++) + { + // We must be consistent on length across iterations of each key value + var key0 = i % (int)DeviceTypeRecoveryTests.numUniqueKeys; + if (key0 == 0) + rng = new (RandSeed); + + ref VLValue key1 = ref *(VLValue*)keyval; + key1.length = 2; + key1.field1 = key0; + + var len = GetVarLen(rng); + ref VLValue value = ref *(VLValue*)val; + value.length = len; + for (int j = 1; j < len; j++) + *(val + j) = i; + + session.Upsert(ref key1, ref value, Empty.Default, 0); + } + session.CompletePending(true); + } + + private unsafe void Populate(FasterKV<MyValue, MyValue> fht) + { + using var session = fht.NewSession(new MyFunctions2()); + + for (int i = 0; i < DeviceTypeRecoveryTests.numOps; i++) + { + var key = new MyValue { value = i % (int)DeviceTypeRecoveryTests.numUniqueKeys }; + var value = new MyValue { value = i }; + session.Upsert(key, value); + } + session.CompletePending(true); + } + + private async ValueTask Checkpoint<TData>(FasterKV<TData, TData> fht, bool isAsync) + { + if (isAsync) + { + var (success, token) = await fht.TakeFullCheckpointAsync(CheckpointType.Snapshot); + Assert.IsTrue(success); + this.logToken = token; + } + else + { + while (!fht.TakeFullCheckpoint(out this.logToken)) { } + fht.CompleteCheckpointAsync().AsTask().GetAwaiter().GetResult(); + } + this.indexToken = this.logToken; + } + + private async ValueTask RecoverAndReadTest(FasterKV<long, long> fht, bool isAsync) + { + await Recover(fht, isAsync); + Read(fht); + } + + private 
static void Read(FasterKV<long, long> fht) + { + using var session = fht.NewSession(new SimpleFunctions<long, long>()); + + for (var i = 0; i < DeviceTypeRecoveryTests.numUniqueKeys; i++) + { + var status = session.Read(i % DeviceTypeRecoveryTests.numUniqueKeys, default, out long output); + Assert.AreEqual(Status.OK, status, $"keyIndex {i}"); + Assert.AreEqual(ExpectedValue(i), output); + } + } + + private async ValueTask RecoverAndReadTest(FasterKV<VLValue, VLValue> fht, bool isAsync) + { + await Recover(fht, isAsync); + Read(fht); + } + + private static void Read(FasterKV<VLValue, VLValue> fht) + { + using var session = fht.NewSession(new VLFunctions2()); + + Random rng = new (RandSeed); + Input input = default; + + for (var i = 0; i < DeviceTypeRecoveryTests.numUniqueKeys; i++) + { + var key1 = new VLValue { length = 2, field1 = i }; + var len = GetVarLen(rng); + + int[] output = null; + var status = session.Read(ref key1, ref input, ref output, Empty.Default, 0); + + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(len, output[0], "Length"); + Assert.AreEqual(ExpectedValue(i), output[1], "field1"); + for (int j = 2; j < len; j++) + Assert.AreEqual(ExpectedValue(i), output[j], "extra data at position {j}"); + } + } + + private async ValueTask RecoverAndReadTest(FasterKV<MyValue, MyValue> fht, bool isAsync) + { + await Recover(fht, isAsync); + Read(fht); + } + + private static void Read(FasterKV<MyValue, MyValue> fht) + { + using var session = fht.NewSession(new MyFunctions2()); + + for (var i = 0; i < DeviceTypeRecoveryTests.numUniqueKeys; i++) + { + var key = new MyValue { value = i }; + var status = session.Read(key, default, out MyOutput output); + Assert.AreEqual(Status.OK, status, $"keyIndex {i}"); + Assert.AreEqual(ExpectedValue(i), output.value.value); + } + } + + private async ValueTask Recover<TData>(FasterKV<TData, TData> fht, bool isAsync = false) + { + if (isAsync) + await fht.RecoverAsync(this.indexToken, this.logToken); + else + 
fht.Recover(this.indexToken, this.logToken); + } + } +} diff --git a/cs/test/ReproReadCacheTest.cs b/cs/test/ReproReadCacheTest.cs index 89367e6be..164281882 100644 --- a/cs/test/ReproReadCacheTest.cs +++ b/cs/test/ReproReadCacheTest.cs @@ -33,9 +33,11 @@ public override void ReadCompletionCallback(ref SpanByte key, ref long input, re } [Test] + [Category("FasterKV")] public unsafe void RandomReadCacheTest1() { - var log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/BasicFasterTests.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + var log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/BasicFasterTests.log", deleteOnClose: true); var fht = new FasterKV<SpanByte, long>( size: 1L << 20, new LogSettings @@ -104,6 +106,10 @@ void Read(int i) { Read(r.Next(num)); } + + fht.Dispose(); + log.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } } } diff --git a/cs/test/SessionFASTERTests.cs b/cs/test/SessionFASTERTests.cs index 4db2efe4b..74c55d7c2 100644 --- a/cs/test/SessionFASTERTests.cs +++ b/cs/test/SessionFASTERTests.cs @@ -1,19 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. 
-using System; -using System.Text; -using System.Threading; using System.Threading.Tasks; -using System.Collections.Generic; -using System.Linq; using FASTER.core; -using System.IO; using NUnit.Framework; namespace FASTER.test.async { - [TestFixture] internal class SessionFASTERTests { @@ -23,7 +16,8 @@ internal class SessionFASTERTests [SetUp] public void Setup() { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/hlog1.log", deleteOnClose: true); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/hlog1.log", deleteOnClose: true); fht = new FasterKV<KeyStruct, ValueStruct> (128, new LogSettings { LogDevice = log, MemorySizeBits = 29 }); } @@ -31,15 +25,16 @@ public void Setup() [TearDown] public void TearDown() { - fht.Dispose(); + fht?.Dispose(); fht = null; - log.Dispose(); + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } - - [Test] [Category("FasterKV")] + [Category("Smoke")] public void SessionTest1() { using var session = fht.NewSession(new Functions()); @@ -58,14 +53,13 @@ public void SessionTest1() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); } - [Test] [Category("FasterKV")] public void SessionTest2() @@ -91,11 +85,11 @@ public void SessionTest2() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(output.value.vfield1 == value1.vfield1); - Assert.IsTrue(output.value.vfield2 == value1.vfield2); + Assert.AreEqual(value1.vfield1, output.value.vfield1); + Assert.AreEqual(value1.vfield2, output.value.vfield2); status = session2.Read(ref key2, ref input, ref output, Empty.Default, 0); @@ -105,11 +99,11 @@ public 
void SessionTest2() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(output.value.vfield1 == value2.vfield1); - Assert.IsTrue(output.value.vfield2 == value2.vfield2); + Assert.AreEqual(value2.vfield1, output.value.vfield1); + Assert.AreEqual(value2.vfield2, output.value.vfield2); } [Test] @@ -134,11 +128,11 @@ public void SessionTest3() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(output.value.vfield1 == value.vfield1); - Assert.IsTrue(output.value.vfield2 == value.vfield2); + Assert.AreEqual(value.vfield1, output.value.vfield1); + Assert.AreEqual(value.vfield2, output.value.vfield2); }).Wait(); } @@ -165,11 +159,11 @@ public void SessionTest4() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(output.value.vfield1 == value1.vfield1); - Assert.IsTrue(output.value.vfield2 == value1.vfield2); + Assert.AreEqual(value1.vfield1, output.value.vfield1); + Assert.AreEqual(value1.vfield2, output.value.vfield2); }); var t2 = Task.CompletedTask.ContinueWith((t) => @@ -190,11 +184,11 @@ public void SessionTest4() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(output.value.vfield1 == value2.vfield1); - Assert.IsTrue(output.value.vfield2 == value2.vfield2); + Assert.AreEqual(value2.vfield1, output.value.vfield1); + Assert.AreEqual(value2.vfield2, output.value.vfield2); }); t1.Wait(); @@ -222,11 +216,11 @@ public void SessionTest5() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(output.value.vfield1 == value1.vfield1); - Assert.IsTrue(output.value.vfield2 == value1.vfield2); + Assert.AreEqual(value1.vfield1, output.value.vfield1); + Assert.AreEqual(value1.vfield2, output.value.vfield2); session.Dispose(); @@ -245,7 +239,7 @@ public void SessionTest5() } else { - Assert.IsTrue(status == Status.OK); + 
Assert.AreEqual(Status.OK, status); } status = session.Read(ref key2, ref input, ref output, Empty.Default, 0); @@ -256,11 +250,11 @@ public void SessionTest5() } else { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } - Assert.IsTrue(output.value.vfield1 == value2.vfield1); - Assert.IsTrue(output.value.vfield2 == value2.vfield2); + Assert.AreEqual(value2.vfield1, output.value.vfield1); + Assert.AreEqual(value2.vfield2, output.value.vfield2); session.Dispose(); } diff --git a/cs/test/SharedDirectoryTests.cs b/cs/test/SharedDirectoryTests.cs index 8716d9202..1bdd1d049 100644 --- a/cs/test/SharedDirectoryTests.cs +++ b/cs/test/SharedDirectoryTests.cs @@ -17,9 +17,9 @@ namespace FASTER.test.recovery.sumstore [TestFixture] internal class SharedDirectoryTests { - const long numUniqueKeys = (1 << 14); - const long keySpace = (1L << 14); - const long numOps = (1L << 19); + const long numUniqueKeys = (1 << 5); + const long keySpace = (1L << 5); + const long numOps = (1L << 10); const long completePendingInterval = (1L << 10); private string rootPath; private string sharedLogDirectory; @@ -29,8 +29,8 @@ internal class SharedDirectoryTests [SetUp] public void Setup() { - this.rootPath = $"{TestContext.CurrentContext.TestDirectory}/{Path.GetRandomFileName()}"; - Directory.CreateDirectory(this.rootPath); + this.rootPath = TestUtils.MethodTestDir; + TestUtils.RecreateDirectory(this.rootPath); this.sharedLogDirectory = $"{this.rootPath}/SharedLogs"; Directory.CreateDirectory(this.sharedLogDirectory); @@ -43,17 +43,13 @@ public void TearDown() { this.original.TearDown(); this.clone.TearDown(); - try - { - TestUtils.DeleteDirectory(this.rootPath); - } - catch - { - } + TestUtils.DeleteDirectory(rootPath); } [Test] [Category("FasterKV")] + [Category("CheckpointRestore")] + [Category("Smoke")] public async ValueTask SharedLogDirectory([Values]bool isAsync) { this.original.Initialize($"{this.rootPath}/OriginalCheckpoint", this.sharedLogDirectory); @@ 
-87,7 +83,7 @@ public async ValueTask SharedLogDirectory([Values]bool isAsync) // Dispose original, files should not be deleted on Windows this.original.TearDown(); -#if NETCOREAPP +#if NETCOREAPP || NET if (RuntimeInformation.IsOSPlatform(System.Runtime.InteropServices.OSPlatform.Windows)) #endif { @@ -111,7 +107,7 @@ private struct FasterTestInstance public void Initialize(string checkpointDirectory, string logDirectory, bool populateLogHandles = false) { -#if NETCOREAPP +#if NETCOREAPP || NET if (!RuntimeInformation.IsOSPlatform(System.Runtime.InteropServices.OSPlatform.Windows)) populateLogHandles = false; #endif @@ -140,7 +136,7 @@ public void Initialize(string checkpointDirectory, string logDirectory, bool pop } } -#if NETCOREAPP +#if NETCOREAPP || NET if (!RuntimeInformation.IsOSPlatform(System.Runtime.InteropServices.OSPlatform.Windows)) { this.LogDevice = new ManagedLocalStorageDevice(deviceFileName, deleteOnClose: true); @@ -218,7 +214,7 @@ private void Test(FasterTestInstance fasterInstance, Guid checkpointToken) for (var i = 0; i < numUniqueKeys; i++) { var status = session.Read(ref inputArray[i].adId, ref input, ref output, Empty.Default, i); - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); inputArray[i].numClicks = output.value; } @@ -253,9 +249,7 @@ private void Test(FasterTestInstance fasterInstance, Guid checkpointToken) // Assert that expected is same as found for (long i = 0; i < numUniqueKeys; i++) { - Assert.IsTrue( - expected[i] == inputArray[i].numClicks.numClicks, - "Debug error for AdId {0}: Expected ({1}), Found({2})", inputArray[i].adId.adId, expected[i], inputArray[i].numClicks.numClicks); + Assert.AreEqual(expected[i], inputArray[i].numClicks.numClicks, $"AdId {inputArray[i].adId.adId}"); } } diff --git a/cs/test/SimpleAsyncTests.cs b/cs/test/SimpleAsyncTests.cs index 16b4135c9..d08d1d4d5 100644 --- a/cs/test/SimpleAsyncTests.cs +++ b/cs/test/SimpleAsyncTests.cs @@ -2,7 +2,6 @@ // Licensed under the MIT 
license. using FASTER.core; -using System.IO; using NUnit.Framework; using FASTER.test.recovery.sumstore; using System.Threading.Tasks; @@ -28,9 +27,9 @@ public void Setup() inputArray[i].adId = i; } - path = TestContext.CurrentContext.TestDirectory + "/SimpleAsyncTests/"; + path = TestUtils.MethodTestDir + "/"; + TestUtils.RecreateDirectory(path); log = Devices.CreateLogDevice(path + "Async.log", deleteOnClose: true); - Directory.CreateDirectory(path); fht1 = new FasterKV<long, long> (1L << 10, logSettings: new LogSettings { LogDevice = log, MutableFraction = 1, PageSizeBits = 10, MemorySizeBits = 15 }, @@ -41,14 +40,17 @@ public void Setup() [TearDown] public void TearDown() { - fht1.Dispose(); - log.Dispose(); - new DirectoryInfo(path).Delete(true); + fht1?.Dispose(); + fht1 = null; + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(path); } // Test that does .ReadAsync with minimum parameters (ref key) [Test] [Category("FasterKV")] + [Category("Smoke")] public async Task ReadAsyncMinParamTest() { using var s1 = fht1.NewSession(new SimpleFunctions<long, long>()); @@ -62,7 +64,8 @@ public async Task ReadAsyncMinParamTest() for (long key = 0; key < numOps; key++) { var (status, output) = (await s1.ReadAsync(ref key)).Complete(); - Assert.IsTrue(status == Status.OK && output == key); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key, output); } } @@ -83,13 +86,15 @@ public async Task ReadAsyncMinParamTestNoDefaultTest() for (long key = 0; key < numOps; key++) { var (status, output) = (await s1.ReadAsync(ref key, Empty.Default, 99, cancellationToken)).Complete(); - Assert.IsTrue(status == Status.OK && output == key); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key, output); } } // Test that does .ReadAsync no ref key (key) [Test] [Category("FasterKV")] + [Category("Smoke")] public async Task ReadAsyncNoRefKeyTest() { using var s1 = fht1.NewSession(new SimpleFunctions<long, long>()); @@ -102,7 +107,8 @@ public async Task 
ReadAsyncNoRefKeyTest() for (long key = 0; key < numOps; key++) { var (status, output) = (await s1.ReadAsync(key,Empty.Default, 99)).Complete(); - Assert.IsTrue(status == Status.OK && output == key); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key, output); } } @@ -180,6 +186,7 @@ public async Task ReadAsyncNoRefKeyNoRefInputTest() // Test that does .UpsertAsync, .ReadAsync, .DeleteAsync, .ReadAsync with minimum parameters passed by reference (ref key) [Test] [Category("FasterKV")] + [Category("Smoke")] public async Task UpsertReadDeleteReadAsyncMinParamByRefTest() { using var s1 = fht1.NewSession(new SimpleFunctions<long, long>()); @@ -190,12 +197,13 @@ public async Task UpsertReadDeleteReadAsyncMinParamByRefTest() r = await r.CompleteAsync(); // test async version of Upsert completion } - Assert.IsTrue(numOps > 100); + Assert.Greater(numOps, 100); for (long key = 0; key < numOps; key++) { var (status, output) = (await s1.ReadAsync(ref key)).Complete(); - Assert.IsTrue(status == Status.OK && output == key); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key, output); } { // Scope for variables @@ -205,13 +213,14 @@ public async Task UpsertReadDeleteReadAsyncMinParamByRefTest() r = await r.CompleteAsync(); // test async version of Delete completion var (status, _) = (await s1.ReadAsync(ref deleteKey)).Complete(); - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } // Test that does .UpsertAsync, .ReadAsync, .DeleteAsync, .ReadAsync with minimum parameters passed by value (key) [Test] [Category("FasterKV")] + [Category("Smoke")] public async Task UpsertReadDeleteReadAsyncMinParamByValueTest() { using var s1 = fht1.NewSession(new SimpleFunctions<long, long>()); @@ -221,12 +230,13 @@ public async Task UpsertReadDeleteReadAsyncMinParamByValueTest() Assert.AreNotEqual(Status.PENDING, status); } - Assert.IsTrue(numOps > 100); + Assert.Greater(numOps, 100); for (long key = 0; key < numOps; key++) { var (status, 
output) = (await s1.ReadAsync(key)).Complete(); - Assert.IsTrue(status == Status.OK && output == key); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key, output); } { // Scope for variables @@ -235,7 +245,7 @@ public async Task UpsertReadDeleteReadAsyncMinParamByValueTest() Assert.AreNotEqual(Status.PENDING, status); (status, _) = (await s1.ReadAsync(deleteKey)).Complete(); - Assert.IsTrue(status == Status.NOTFOUND); + Assert.AreEqual(Status.NOTFOUND, status); } } @@ -364,6 +374,7 @@ public async Task ReadyToCompletePendingAsyncTest() // Test that does both UpsertAsync and RMWAsync to populate the FasterKV and update it, possibly after flushing it from memory. [Test] [Category("FasterKV")] + [Category("Smoke")] public async Task UpsertAsyncAndRMWAsyncTest([Values] bool useRMW, [Values] bool doFlush, [Values] bool completeAsync) { using var s1 = fht1.NewSession(new SimpleFunctions<long, long>()); diff --git a/cs/test/SimpleRecoveryTest.cs b/cs/test/SimpleRecoveryTest.cs index 7f7289107..b0c39997a 100644 --- a/cs/test/SimpleRecoveryTest.cs +++ b/cs/test/SimpleRecoveryTest.cs @@ -2,91 +2,111 @@ // Licensed under the MIT license. 
using System; +using System.Linq; using System.Threading; using System.Threading.Tasks; using FASTER.core; -using System.IO; using NUnit.Framework; using FASTER.devices; namespace FASTER.test.recovery.sumstore.simple { - [TestFixture] public class RecoveryTests { + const int numOps = 5000; + AdId[] inputArray; + + private byte[] commitCookie; + string checkpointDir; + ICheckpointManager checkpointManager; + private FasterKV<AdId, NumClicks> fht1; private FasterKV<AdId, NumClicks> fht2; private IDevice log; - public const string EMULATED_STORAGE_STRING = "UseDevelopmentStorage=true;"; - public const string TEST_CONTAINER = "checkpoints4444"; + + + [SetUp] + public void Setup() + { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + checkpointManager = default; + checkpointDir = default; + inputArray = new AdId[numOps]; + for (int i = 0; i < numOps; i++) + inputArray[i].adId = i; + } + + [TearDown] + public void TearDown() + { + fht1?.Dispose(); + fht1 = null; + fht2?.Dispose(); + fht2 = null; + log?.Dispose(); + log = null; + + checkpointManager?.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } [Test] - [Category("FasterKV")] - public async ValueTask PageBlobSimpleRecoveryTest([Values]CheckpointType checkpointType, [Values]bool isAsync) + [Category("FasterKV"), Category("CheckpointRestore")] + public async ValueTask PageBlobSimpleRecoveryTest([Values]CheckpointType checkpointType, [Values]bool isAsync, [Values]bool testCommitCookie) { - if ("yes".Equals(Environment.GetEnvironmentVariable("RunAzureTests"))) - { - ICheckpointManager checkpointManager = new DeviceLogCommitCheckpointManager( - new AzureStorageNamedDeviceFactory(EMULATED_STORAGE_STRING), - new DefaultCheckpointNamingScheme($"{TEST_CONTAINER}/PageBlobSimpleRecoveryTest")); - await SimpleRecoveryTest1_Worker(checkpointType, checkpointManager, isAsync); - checkpointManager.PurgeAll(); - checkpointManager.Dispose(); - } + TestUtils.IgnoreIfNotRunningAzureTests(); + 
checkpointManager = new DeviceLogCommitCheckpointManager( + new AzureStorageNamedDeviceFactory(TestUtils.AzureEmulatedStorageString), + new DefaultCheckpointNamingScheme($"{TestUtils.AzureTestContainer}/{TestUtils.AzureTestDirectory}")); + await SimpleRecoveryTest1_Worker(checkpointType, isAsync, testCommitCookie); + checkpointManager.PurgeAll(); } [Test] [Category("FasterKV")] - public async ValueTask LocalDeviceSimpleRecoveryTest([Values] CheckpointType checkpointType, [Values] bool isAsync) + [Category("CheckpointRestore")] + [Category("Smoke")] + + public async ValueTask LocalDeviceSimpleRecoveryTest([Values] CheckpointType checkpointType, [Values] bool isAsync, [Values]bool testCommitCookie) { - ICheckpointManager checkpointManager = new DeviceLogCommitCheckpointManager( + checkpointManager = new DeviceLogCommitCheckpointManager( new LocalStorageNamedDeviceFactory(), - new DefaultCheckpointNamingScheme($"{TEST_CONTAINER}/PageBlobSimpleRecoveryTest")); - await SimpleRecoveryTest1_Worker(checkpointType, checkpointManager, isAsync); + new DefaultCheckpointNamingScheme($"{TestUtils.MethodTestDir}/{TestUtils.AzureTestDirectory}")); + await SimpleRecoveryTest1_Worker(checkpointType, isAsync, testCommitCookie); checkpointManager.PurgeAll(); - checkpointManager.Dispose(); } - [Test] - [Category("FasterKV")] - public async ValueTask SimpleRecoveryTest1([Values]CheckpointType checkpointType, [Values]bool isAsync) + [Category("FasterKV"), Category("CheckpointRestore")] + public async ValueTask SimpleRecoveryTest1([Values]CheckpointType checkpointType, [Values]bool isAsync, [Values]bool testCommitCookie) { - await SimpleRecoveryTest1_Worker(checkpointType, null, isAsync); + await SimpleRecoveryTest1_Worker(checkpointType, isAsync, testCommitCookie); } - private async ValueTask SimpleRecoveryTest1_Worker(CheckpointType checkpointType, ICheckpointManager checkpointManager, bool isAsync) + private async ValueTask SimpleRecoveryTest1_Worker(CheckpointType checkpointType, bool 
isAsync, bool testCommitCookie) { - string checkpointDir = TestContext.CurrentContext.TestDirectory + $"/{TEST_CONTAINER}"; + if (testCommitCookie) + { + // Generate a new unique byte sequence for test + commitCookie = Guid.NewGuid().ToByteArray(); + } - if (checkpointManager != null) - checkpointDir = null; + if (checkpointManager is null) + checkpointDir = TestUtils.MethodTestDir + $"/checkpoints"; - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/SimpleRecoveryTest1.log", deleteOnClose: true); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/SimpleRecoveryTest1.log", deleteOnClose: true); - fht1 = new FasterKV - <AdId, NumClicks> - (128, + fht1 = new FasterKV<AdId, NumClicks>(128, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 }, checkpointSettings: new CheckpointSettings { CheckpointDir = checkpointDir, CheckpointManager = checkpointManager, CheckPointType = checkpointType } ); - fht2 = new FasterKV - <AdId, NumClicks> - (128, + fht2 = new FasterKV<AdId, NumClicks>(128, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 }, checkpointSettings: new CheckpointSettings { CheckpointDir = checkpointDir, CheckpointManager = checkpointManager, CheckPointType = checkpointType } ); - - int numOps = 5000; - var inputArray = new AdId[numOps]; - for (int i = 0; i < numOps; i++) - { - inputArray[i].adId = i; - } - NumClicks value; AdInput inputArg = default; Output output = default; @@ -97,6 +117,9 @@ private async ValueTask SimpleRecoveryTest1_Worker(CheckpointType checkpointType value.numClicks = key; session1.Upsert(ref inputArray[key], ref value, Empty.Default, 0); } + + if (testCommitCookie) + fht1.CommitCookie = commitCookie; fht1.TakeFullCheckpoint(out Guid token); fht1.CompleteCheckpointAsync().GetAwaiter().GetResult(); session1.Dispose(); @@ -106,6 +129,11 @@ private async ValueTask SimpleRecoveryTest1_Worker(CheckpointType checkpointType else 
fht2.Recover(token); + if (testCommitCookie) + Assert.IsTrue(fht2.RecoveredCommitCookie.SequenceEqual(commitCookie)); + else + Assert.Null(fht2.RecoveredCommitCookie); + var session2 = fht2.NewSession(new AdSimpleFunctions()); for (int key = 0; key < numOps; key++) { @@ -114,52 +142,29 @@ private async ValueTask SimpleRecoveryTest1_Worker(CheckpointType checkpointType if (status == Status.PENDING) session2.CompletePending(true); else - { Assert.IsTrue(output.value.numClicks == key); - } } session2.Dispose(); - - log.Dispose(); - fht1.Dispose(); - fht2.Dispose(); - - if (checkpointManager == null) - new DirectoryInfo(checkpointDir).Delete(true); } [Test] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public async ValueTask SimpleRecoveryTest2([Values]CheckpointType checkpointType, [Values]bool isAsync) { - var checkpointManager = new DeviceLogCommitCheckpointManager(new LocalStorageNamedDeviceFactory(), new DefaultCheckpointNamingScheme(TestContext.CurrentContext.TestDirectory + "/checkpoints4"), false); + checkpointManager = new DeviceLogCommitCheckpointManager(new LocalStorageNamedDeviceFactory(), new DefaultCheckpointNamingScheme(TestUtils.MethodTestDir + "/checkpoints4"), false); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/SimpleRecoveryTest2.log", deleteOnClose: true); - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/SimpleRecoveryTest2.log", deleteOnClose: true); - - // Directory.CreateDirectory(TestContext.CurrentContext.TestDirectory + "/checkpoints4"); - - fht1 = new FasterKV - <AdId, NumClicks> - (128, + fht1 = new FasterKV<AdId, NumClicks>(128, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 }, checkpointSettings: new CheckpointSettings { CheckpointManager = checkpointManager, CheckPointType = checkpointType } ); - fht2 = new FasterKV - <AdId, NumClicks> - (128, + fht2 = new FasterKV<AdId, NumClicks>(128, logSettings: new LogSettings { 
LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 }, checkpointSettings: new CheckpointSettings { CheckpointManager = checkpointManager, CheckPointType = checkpointType } ); - int numOps = 5000; - var inputArray = new AdId[numOps]; - for (int i = 0; i < numOps; i++) - { - inputArray[i].adId = i; - } - NumClicks value; AdInput inputArg = default; Output output = default; @@ -192,45 +197,25 @@ public async ValueTask SimpleRecoveryTest2([Values]CheckpointType checkpointType } } session2.Dispose(); - - log.Dispose(); - fht1.Dispose(); - fht2.Dispose(); - checkpointManager.Dispose(); - - new DirectoryInfo(TestContext.CurrentContext.TestDirectory + "/checkpoints4").Delete(true); } [Test] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public async ValueTask ShouldRecoverBeginAddress([Values]bool isAsync) { - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/SimpleRecoveryTest2.log", deleteOnClose: true); - - Directory.CreateDirectory(TestContext.CurrentContext.TestDirectory + "/checkpoints6"); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/SimpleRecoveryTest2.log", deleteOnClose: true); + checkpointDir = TestUtils.MethodTestDir + "/checkpoints6"; - fht1 = new FasterKV - <AdId, NumClicks> - (128, + fht1 = new FasterKV<AdId, NumClicks>(128, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 }, - checkpointSettings: new CheckpointSettings { CheckpointDir = TestContext.CurrentContext.TestDirectory + "/checkpoints6", CheckPointType = CheckpointType.FoldOver } + checkpointSettings: new CheckpointSettings { CheckpointDir = checkpointDir, CheckPointType = CheckpointType.FoldOver } ); - fht2 = new FasterKV - <AdId, NumClicks> - (128, + fht2 = new FasterKV<AdId, NumClicks>(128, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, MemorySizeBits = 29 }, - checkpointSettings: new CheckpointSettings { CheckpointDir = 
TestContext.CurrentContext.TestDirectory + "/checkpoints6", CheckPointType = CheckpointType.FoldOver } + checkpointSettings: new CheckpointSettings { CheckpointDir = checkpointDir, CheckPointType = CheckpointType.FoldOver } ); - - int numOps = 5000; - var inputArray = new AdId[numOps]; - for (int i = 0; i < numOps; i++) - { - inputArray[i].adId = i; - } - NumClicks value; var session1 = fht1.NewSession(new AdSimpleFunctions()); @@ -256,11 +241,6 @@ public async ValueTask ShouldRecoverBeginAddress([Values]bool isAsync) fht2.Recover(token); Assert.AreEqual(address, fht2.Log.BeginAddress); - - log.Dispose(); - fht1.Dispose(); - fht2.Dispose(); - new DirectoryInfo(TestContext.CurrentContext.TestDirectory + "/checkpoints6").Delete(true); } } diff --git a/cs/test/SimpleTests.cs b/cs/test/SimpleTests.cs index 6faafc7b8..91fb7418d 100644 --- a/cs/test/SimpleTests.cs +++ b/cs/test/SimpleTests.cs @@ -17,38 +17,38 @@ public unsafe void AddressInfoTest() AddressInfo info; AddressInfo.WriteInfo(&info, 44, 55); - Assert.IsTrue(info.Address == 44); - Assert.IsTrue(info.Size == 512); + Assert.AreEqual(44, info.Address); + Assert.AreEqual(512, info.Size); AddressInfo.WriteInfo(&info, 44, 512); - Assert.IsTrue(info.Address == 44); - Assert.IsTrue(info.Size == 512); + Assert.AreEqual(44, info.Address); + Assert.AreEqual(512, info.Size); AddressInfo.WriteInfo(&info, 44, 513); - Assert.IsTrue(info.Address == 44); - Assert.IsTrue(info.Size == 1024); + Assert.AreEqual(44, info.Address); + Assert.AreEqual(1024, info.Size); if (sizeof(IntPtr) > 4) { AddressInfo.WriteInfo(&info, 44, 1L << 20); - Assert.IsTrue(info.Address == 44); - Assert.IsTrue(info.Size == 1L << 20); + Assert.AreEqual(44, info.Address); + Assert.AreEqual(1L << 20, info.Size); AddressInfo.WriteInfo(&info, 44, 511 * (1L << 20)); - Assert.IsTrue(info.Address == 44); - Assert.IsTrue(info.Size == 511 * (1L << 20)); + Assert.AreEqual(44, info.Address); + Assert.AreEqual(511 * (1L << 20), info.Size); AddressInfo.WriteInfo(&info, 
44, 512 * (1L << 20)); - Assert.IsTrue(info.Address == 44); - Assert.IsTrue(info.Size == 512 * (1L << 20)); + Assert.AreEqual(44, info.Address); + Assert.AreEqual(512 * (1L << 20), info.Size); AddressInfo.WriteInfo(&info, 44, 555555555L); - Assert.IsTrue(info.Address == 44); - Assert.IsTrue(info.Size == (1 + (555555555L / 512)) * 512); + Assert.AreEqual(44, info.Address); + Assert.AreEqual((1 + (555555555L / 512)) * 512, info.Size); AddressInfo.WriteInfo(&info, 44, 2 * 555555555L); - Assert.IsTrue(info.Address == 44); - Assert.IsTrue(info.Size == (1 + (2 * 555555555L / 1048576)) * 1048576); + Assert.AreEqual(44, info.Address); + Assert.AreEqual((1 + (2 * 555555555L / 1048576)) * 1048576, info.Size); } } } diff --git a/cs/test/SpanByteTests.cs b/cs/test/SpanByteTests.cs index 59216de6a..219cc4ac7 100644 --- a/cs/test/SpanByteTests.cs +++ b/cs/test/SpanByteTests.cs @@ -9,99 +9,170 @@ namespace FASTER.test { - [TestFixture] internal class SpanByteTests { [Test] [Category("FasterKV")] + [Category("Smoke")] public unsafe void SpanByteTest1() { Span<byte> output = stackalloc byte[20]; SpanByte input = default; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait:true); - using var log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/hlog1.log", deleteOnClose: true); - using var fht = new FasterKV<SpanByte, SpanByte> - (128, new LogSettings { LogDevice = log, MemorySizeBits = 17, PageSizeBits = 12 }); - using var s = fht.NewSession(new SpanByteFunctions<Empty>()); - - var key1 = MemoryMarshal.Cast<char, byte>("key1".AsSpan()); - var value1 = MemoryMarshal.Cast<char, byte>("value1".AsSpan()); - var output1 = SpanByteAndMemory.FromFixedSpan(output); - - s.Upsert(key1, value1); + try + { + using var log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/hlog1.log", deleteOnClose: true); + using var fht = new FasterKV<SpanByte, SpanByte> + (128, new LogSettings { LogDevice = log, MemorySizeBits = 17, PageSizeBits = 12 }); + using var s = 
fht.NewSession(new SpanByteFunctions<Empty>()); - s.Read(key1, ref input, ref output1); + var key1 = MemoryMarshal.Cast<char, byte>("key1".AsSpan()); + var value1 = MemoryMarshal.Cast<char, byte>("value1".AsSpan()); + var output1 = SpanByteAndMemory.FromFixedSpan(output); - Assert.IsTrue(output1.IsSpanByte); - Assert.IsTrue(output1.SpanByte.AsReadOnlySpan().SequenceEqual(value1)); + s.Upsert(key1, value1); + + s.Read(key1, ref input, ref output1); - var key2 = MemoryMarshal.Cast<char, byte>("key2".AsSpan()); - var value2 = MemoryMarshal.Cast<char, byte>("value2value2value2".AsSpan()); - var output2 = SpanByteAndMemory.FromFixedSpan(output); + Assert.IsTrue(output1.IsSpanByte); + Assert.IsTrue(output1.SpanByte.AsReadOnlySpan().SequenceEqual(value1)); - s.Upsert(key2, value2); + var key2 = MemoryMarshal.Cast<char, byte>("key2".AsSpan()); + var value2 = MemoryMarshal.Cast<char, byte>("value2value2value2".AsSpan()); + var output2 = SpanByteAndMemory.FromFixedSpan(output); - s.Read(key2, ref input, ref output2); + s.Upsert(key2, value2); + s.Read(key2, ref input, ref output2); - Assert.IsTrue(!output2.IsSpanByte); - Assert.IsTrue(output2.Memory.Memory.Span.Slice(0, output2.Length).SequenceEqual(value2)); + Assert.IsTrue(!output2.IsSpanByte); + Assert.IsTrue(output2.Memory.Memory.Span.Slice(0, output2.Length).SequenceEqual(value2)); + } + finally + { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } } [Test] [Category("FasterKV")] + [Category("Smoke")] public unsafe void MultiReadSpanByteKeyTest() { - using var log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/MultiReadSpanByteKeyTest.log", deleteOnClose: true); - using var fht = new FasterKV<SpanByte, long>( - size: 1L << 20, - new LogSettings { LogDevice = log, MemorySizeBits = 15, PageSizeBits = 12 }); - using var session = fht.For(new MultiReadSpanByteKeyTestFunctions()).NewSession<MultiReadSpanByteKeyTestFunctions>(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); 
- for (int i = 0; i < 3000; i++) + try { - var key = MemoryMarshal.Cast<char, byte>($"{i}".AsSpan()); - fixed (byte* _ = key) - session.Upsert(SpanByte.FromFixedSpan(key), i); - } + using var log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/MultiReadSpanByteKeyTest.log", deleteOnClose: true); + using var fht = new FasterKV<SpanByte, long>( + size: 1L << 10, + new LogSettings { LogDevice = log, MemorySizeBits = 15, PageSizeBits = 12 }); + using var session = fht.For(new MultiReadSpanByteKeyTestFunctions()).NewSession<MultiReadSpanByteKeyTestFunctions>(); - // Evict all records to disk - fht.Log.FlushAndEvict(true); + for (int i = 0; i < 200; i++) + { + var key = MemoryMarshal.Cast<char, byte>($"{i}".AsSpan()); + fixed (byte* _ = key) + session.Upsert(SpanByte.FromFixedSpan(key), i); + } - for (long key = 0; key < 50; key++) - { - // read each key multiple times - for (int i = 0; i < 10; i++) - Assert.AreEqual(key, ReadKey($"{key}")); - } + // Evict all records to disk + fht.Log.FlushAndEvict(true); - long ReadKey(string keyString) - { - Status status; + for (long key = 0; key < 50; key++) + { + // read each key multiple times + for (int i = 0; i < 10; i++) + Assert.AreEqual(key, ReadKey($"{key}")); + } - var key = MemoryMarshal.Cast<char, byte>(keyString.AsSpan()); - fixed (byte* _ = key) - status = session.Read(key: SpanByte.FromFixedSpan(key), out var unused); + long ReadKey(string keyString) + { + Status status; - // All keys need to be fetched from disk - Assert.AreEqual(Status.PENDING, status); + var key = MemoryMarshal.Cast<char, byte>(keyString.AsSpan()); + fixed (byte* _ = key) + status = session.Read(key: SpanByte.FromFixedSpan(key), out var unused); + + // All keys need to be fetched from disk + Assert.AreEqual(Status.PENDING, status); - session.CompletePendingWithOutputs(out var completedOutputs, wait: true); + session.CompletePendingWithOutputs(out var completedOutputs, wait: true); - var count = 0; - var value = 0L; - using (completedOutputs) - 
{ - while (completedOutputs.Next()) + var count = 0; + var value = 0L; + using (completedOutputs) { - count++; - Assert.AreEqual(Status.OK, completedOutputs.Current.Status); - value = completedOutputs.Current.Output; + while (completedOutputs.Next()) + { + count++; + Assert.AreEqual(Status.OK, completedOutputs.Current.Status); + value = completedOutputs.Current.Output; + } } + Assert.AreEqual(1, count); + return value; } - Assert.AreEqual(1, count); - return value; } + finally + { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } + } + + [Test] + [Category("FasterKV")] + [Category("Smoke")] + public unsafe void SpanByteUnitTest1() + { + Span<byte> payload = stackalloc byte[20]; + Span<byte> serialized = stackalloc byte[24]; + + SpanByte sb = SpanByte.FromFixedSpan(payload); + Assert.IsFalse(sb.Serialized); + Assert.AreEqual(20, sb.Length); + Assert.AreEqual(24, sb.TotalSize); + Assert.AreEqual(20, sb.AsSpan().Length); + Assert.AreEqual(20, sb.AsReadOnlySpan().Length); + + fixed (byte* ptr = serialized) + sb.CopyTo(ptr); + ref SpanByte ssb = ref SpanByte.ReinterpretWithoutLength(serialized); + Assert.IsTrue(ssb.Serialized); + Assert.AreEqual(0, ssb.MetadataSize); + Assert.AreEqual(20, ssb.Length); + Assert.AreEqual(24, ssb.TotalSize); + Assert.AreEqual(20, ssb.AsSpan().Length); + Assert.AreEqual(20, ssb.AsReadOnlySpan().Length); + + ssb.MarkExtraMetadata(); + Assert.IsTrue(ssb.Serialized); + Assert.AreEqual(8, ssb.MetadataSize); + Assert.AreEqual(20, ssb.Length); + Assert.AreEqual(24, ssb.TotalSize); + Assert.AreEqual(20 - 8, ssb.AsSpan().Length); + Assert.AreEqual(20 - 8, ssb.AsReadOnlySpan().Length); + ssb.ExtraMetadata = 31337; + Assert.AreEqual(31337, ssb.ExtraMetadata); + + sb.MarkExtraMetadata(); + Assert.AreEqual(20, sb.Length); + Assert.AreEqual(24, sb.TotalSize); + Assert.AreEqual(20 - 8, sb.AsSpan().Length); + Assert.AreEqual(20 - 8, sb.AsReadOnlySpan().Length); + sb.ExtraMetadata = 31337; + Assert.AreEqual(31337, sb.ExtraMetadata); + + fixed 
(byte* ptr = serialized) + sb.CopyTo(ptr); + Assert.IsTrue(ssb.Serialized); + Assert.AreEqual(8, ssb.MetadataSize); + Assert.AreEqual(20, ssb.Length); + Assert.AreEqual(24, ssb.TotalSize); + Assert.AreEqual(20 - 8, ssb.AsSpan().Length); + Assert.AreEqual(20 - 8, ssb.AsReadOnlySpan().Length); + Assert.AreEqual(31337, ssb.ExtraMetadata); } class MultiReadSpanByteKeyTestFunctions : FunctionsBase<SpanByte, long, long, long, Empty> diff --git a/cs/test/StateMachineTests.cs b/cs/test/StateMachineTests.cs index 6358eafaf..63315cec1 100644 --- a/cs/test/StateMachineTests.cs +++ b/cs/test/StateMachineTests.cs @@ -1,21 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -using System; -using System.Text; -using System.Threading; -using System.Threading.Tasks; using System.Collections.Generic; -using System.Linq; +using System.Threading; using FASTER.core; using System.IO; using NUnit.Framework; using FASTER.test.recovery.sumstore; -using System.Diagnostics; +using NUnit.Framework.Interfaces; namespace FASTER.test.statemachine { - [TestFixture] public class StateMachineTests { @@ -33,26 +28,30 @@ public void Setup() inputArray[i].adId = i; } - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/StateMachineTest1.log", deleteOnClose: true); - Directory.CreateDirectory(TestContext.CurrentContext.TestDirectory + "/statemachinetest"); + log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/StateMachineTest1.log", deleteOnClose: true); + string checkpointDir = TestUtils.MethodTestDir + "/statemachinetest"; + Directory.CreateDirectory(checkpointDir); fht1 = new FasterKV<AdId, NumClicks> (128, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, PageSizeBits = 10, MemorySizeBits = 13 }, - checkpointSettings: new CheckpointSettings { CheckpointDir = TestContext.CurrentContext.TestDirectory + "/statemachinetest", CheckPointType = CheckpointType.FoldOver } + checkpointSettings: new 
CheckpointSettings { CheckpointDir = checkpointDir, CheckPointType = CheckpointType.FoldOver } ); } [TearDown] public void TearDown() { - fht1.Dispose(); - log.Dispose(); - new DirectoryInfo(TestContext.CurrentContext.TestDirectory + "/statemachinetest").Delete(true); + fht1?.Dispose(); + fht1 = null; + log?.Dispose(); + log = null; + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } - [TestCase] [Category("FasterKV")] + [Category("CheckpointRestore")] + [Category("Smoke")] public void StateMachineTest1() { Prepare(out var f, out var s1, out var s2); @@ -94,7 +93,7 @@ public void StateMachineTest1() s1.Refresh(); // Completion callback should have been called once - Assert.IsTrue(f.checkpointCallbackExpectation == 0); + Assert.AreEqual(0, f.checkpointCallbackExpectation); // We should be in REST, 2 Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.REST, 2), fht1.SystemState)); @@ -108,7 +107,7 @@ public void StateMachineTest1() [TestCase] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public void StateMachineTest2() { Prepare(out var f, out var s1, out var s2); @@ -142,7 +141,7 @@ public void StateMachineTest2() s1.Refresh(); // Completion callback should have been called once - Assert.IsTrue(f.checkpointCallbackExpectation == 0); + Assert.AreEqual(0, f.checkpointCallbackExpectation); // We should be in REST, 2 Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.REST, 2), fht1.SystemState)); @@ -153,7 +152,7 @@ public void StateMachineTest2() } [TestCase] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public void StateMachineTest3() { Prepare(out var f, out var s1, out var s2); @@ -183,7 +182,7 @@ public void StateMachineTest3() s1.UnsafeResumeThread(); // Completion callback should have been called once - Assert.IsTrue(f.checkpointCallbackExpectation == 0); + Assert.AreEqual(0, f.checkpointCallbackExpectation); s2.Dispose(); s1.Dispose(); @@ -192,7 +191,7 @@ public void 
StateMachineTest3() } [TestCase] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public void StateMachineTest4() { Prepare(out var f, out var s1, out var s2); @@ -231,7 +230,7 @@ public void StateMachineTest4() s1.UnsafeResumeThread(); // Completion callback should have been called once - Assert.IsTrue(f.checkpointCallbackExpectation == 0); + Assert.AreEqual(0, f.checkpointCallbackExpectation); s2.Dispose(); s1.Dispose(); @@ -240,7 +239,7 @@ public void StateMachineTest4() } [TestCase] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public void StateMachineTest5() { Prepare(out var f, out var s1, out var s2); @@ -271,7 +270,7 @@ public void StateMachineTest5() s1.Refresh(); // Completion callback should have been called once - Assert.IsTrue(f.checkpointCallbackExpectation == 0); + Assert.AreEqual(0, f.checkpointCallbackExpectation); // We should be in PERSISTENCE_CALLBACK, 2 Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.PERSISTENCE_CALLBACK, 2), fht1.SystemState)); @@ -295,7 +294,7 @@ public void StateMachineTest5() s1.UnsafeResumeThread(); // Completion callback should have been called once - Assert.IsTrue(f.checkpointCallbackExpectation == 0); + Assert.AreEqual(0, f.checkpointCallbackExpectation); s2.Dispose(); s1.Dispose(); @@ -305,7 +304,7 @@ public void StateMachineTest5() [TestCase] - [Category("FasterKV")] + [Category("FasterKV"), Category("CheckpointRestore")] public void StateMachineTest6() { Prepare(out var f, out var s1, out var s2); @@ -329,7 +328,7 @@ public void StateMachineTest6() s2.Dispose(); fht1.TakeHybridLogCheckpoint(out _); - fht1.CompleteCheckpointAsync().GetAwaiter().GetResult(); + fht1.CompleteCheckpointAsync().AsTask().GetAwaiter().GetResult(); // We should be in REST, 3 Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.REST, 3), fht1.SystemState)); @@ -339,15 +338,127 @@ public void StateMachineTest6() s1.UnsafeResumeThread(); + // Completion callback should 
have been called once + Assert.AreEqual(0, f.checkpointCallbackExpectation); + + s1.Dispose(); + + RecoverAndTest(log); + } + + [TestCase] + [Category("FasterKV")] + [Category("CheckpointRestore")] + [Category("Smoke")] + public void StateMachineCallbackTest1() + { + var callback = new TestCallback(); + fht1.UnsafeRegisterCallback(callback); + Prepare(out var f, out var s1, out var s2); + + // We should be in PREPARE, 1 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.PREPARE, 1), fht1.SystemState)); + callback.CheckInvoked(fht1.SystemState); + + // Refresh session s2 + s2.Refresh(); + s1.Refresh(); + + // We should now be in IN_PROGRESS, 2 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.IN_PROGRESS, 2), fht1.SystemState)); + callback.CheckInvoked(fht1.SystemState); + + s2.Refresh(); + + // We should be in WAIT_PENDING, 2 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.WAIT_PENDING, 2), fht1.SystemState)); + callback.CheckInvoked(fht1.SystemState); + + s1.Refresh(); + + // We should be in WAIT_FLUSH, 2 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.WAIT_FLUSH, 2), fht1.SystemState)); + callback.CheckInvoked(fht1.SystemState); + + s2.Refresh(); + + // We should be in PERSISTENCE_CALLBACK, 2 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.PERSISTENCE_CALLBACK, 2), fht1.SystemState)); + callback.CheckInvoked(fht1.SystemState); + + // Expect checkpoint completion callback + f.checkpointCallbackExpectation = 1; + + s1.Refresh(); + // Completion callback should have been called once Assert.IsTrue(f.checkpointCallbackExpectation == 0); + // We should be in REST, 2 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.REST, 2), fht1.SystemState)); + callback.CheckInvoked(fht1.SystemState); + + // Dispose session s2; does not move state machine forward + s2.Dispose(); s1.Dispose(); RecoverAndTest(log); } + + + [TestCase] + [Category("FasterKV")] + [Category("CheckpointRestore")] + public void VersionChangeRollOverTest() + { + 
var toVersion = 1 + (1 << 14); + Prepare(out var f, out var s1, out var s2, toVersion); + + // We should be in PREPARE, 1 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.PREPARE, 1), fht1.SystemState)); + + // Refresh session s2 + s2.Refresh(); + s1.Refresh(); + + // We should now be in IN_PROGRESS, toVersion + 1 (because of rollover of 13 bit short version) + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.IN_PROGRESS, toVersion + 1), fht1.SystemState)); + + s2.Refresh(); + + // We should be in WAIT_PENDING, 2 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.WAIT_PENDING, toVersion + 1), fht1.SystemState)); - void Prepare(out SimpleFunctions f, out ClientSession<AdId, NumClicks, NumClicks, NumClicks, Empty, SimpleFunctions> s1, out ThreadSession<AdId, NumClicks, NumClicks, NumClicks, Empty, SimpleFunctions> s2) + s1.Refresh(); + + // We should be in WAIT_FLUSH, 2 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.WAIT_FLUSH, toVersion + 1), fht1.SystemState)); + + s2.Refresh(); + + // We should be in PERSISTENCE_CALLBACK, 2 + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.PERSISTENCE_CALLBACK, toVersion + 1), fht1.SystemState)); + + // Expect checkpoint completion callback + f.checkpointCallbackExpectation = 1; + + s1.Refresh(); + + Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.REST, toVersion + 1), fht1.SystemState)); + + + // Dispose session s2; does not move state machine forward + s2.Dispose(); + s1.Dispose(); + + RecoverAndTest(log); + } + + + void Prepare(out SimpleFunctions f, + out ClientSession<AdId, NumClicks, NumClicks, NumClicks, Empty, SimpleFunctions> s1, + out ThreadSession<AdId, NumClicks, NumClicks, NumClicks, Empty, SimpleFunctions> s2, + long toVersion = -1) { f = new SimpleFunctions(); @@ -356,7 +467,7 @@ void Prepare(out SimpleFunctions f, out ClientSession<AdId, NumClicks, NumClicks // Take index checkpoint for recovery purposes fht1.TakeIndexCheckpoint(out _); - 
fht1.CompleteCheckpointAsync().GetAwaiter().GetResult(); + fht1.CompleteCheckpointAsync().AsTask().GetAwaiter().GetResult(); // Index checkpoint does not update version, so // we should still be in REST, 1 @@ -381,7 +492,7 @@ void Prepare(out SimpleFunctions f, out ClientSession<AdId, NumClicks, NumClicks // We should be in REST, 1 Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.REST, 1), fht1.SystemState)); - fht1.TakeHybridLogCheckpoint(out _); + fht1.TakeHybridLogCheckpoint(out _, toVersion); // We should be in PREPARE, 1 Assert.IsTrue(SystemState.Equal(SystemState.Make(Phase.PREPARE, 1), fht1.SystemState)); @@ -397,14 +508,14 @@ void RecoverAndTest(IDevice log) <AdId, NumClicks> (128, logSettings: new LogSettings { LogDevice = log, MutableFraction = 0.1, PageSizeBits = 10, MemorySizeBits = 13 }, - checkpointSettings: new CheckpointSettings { CheckpointDir = TestContext.CurrentContext.TestDirectory + "/statemachinetest", CheckPointType = CheckpointType.FoldOver } + checkpointSettings: new CheckpointSettings { CheckpointDir = TestUtils.MethodTestDir + "/statemachinetest", CheckPointType = CheckpointType.FoldOver } ); fht2.Recover(); // sync, does not require session using (var s3 = fht2.ResumeSession(f, "foo", out CommitPoint lsn)) { - Assert.IsTrue(lsn.UntilSerialNo == numOps - 1); + Assert.AreEqual(numOps - 1, lsn.UntilSerialNo); // Expect checkpoint completion callback f.checkpointCallbackExpectation = 1; @@ -412,7 +523,7 @@ void RecoverAndTest(IDevice log) s3.Refresh(); // Completion callback should have been called once - Assert.IsTrue(f.checkpointCallbackExpectation == 0); + Assert.AreEqual(0, f.checkpointCallbackExpectation); for (var key = 0; key < numOps; key++) { @@ -422,7 +533,7 @@ void RecoverAndTest(IDevice log) s3.CompletePending(true); else { - Assert.IsTrue(output.numClicks == key); + Assert.AreEqual(key, output.numClicks); } } } @@ -440,7 +551,7 @@ public override void CheckpointCompletionCallback(string sessionId, CommitPoint switch 
(checkpointCallbackExpectation) { case 0: - Assert.IsTrue(false, "Unexpected checkpoint callback"); + Assert.Fail("Unexpected checkpoint callback"); break; default: Interlocked.Decrement(ref checkpointCallbackExpectation); @@ -450,8 +561,25 @@ public override void CheckpointCompletionCallback(string sessionId, CommitPoint public override void ReadCompletionCallback(ref AdId key, ref NumClicks input, ref NumClicks output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.numClicks == key.adId); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(key.adId, output.numClicks); + } + } + + public class TestCallback : IStateMachineCallback + { + private readonly HashSet<SystemState> invokedStates = new(); + + + public void BeforeEnteringState<Key1, Value>(SystemState next, FasterKV<Key1, Value> faster) + { + Assert.IsFalse(invokedStates.Contains(next)); + invokedStates.Add(next); + } + + public void CheckInvoked(SystemState state) + { + Assert.IsTrue(invokedStates.Contains(state)); } } } diff --git a/cs/test/TestUtils.cs b/cs/test/TestUtils.cs index b0800d0c2..1e4200c84 100644 --- a/cs/test/TestUtils.cs +++ b/cs/test/TestUtils.cs @@ -1,34 +1,171 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +using NUnit.Framework; using System; +using System.Diagnostics; using System.IO; +using FASTER.core; +using FASTER.devices; +using System.Threading; +using System.Runtime.InteropServices; namespace FASTER.test { internal static class TestUtils { - internal static void DeleteDirectory(string path) + /// <summary> + /// Delete a directory recursively + /// </summary> + /// <param name="path">The folder to delete</param> + /// <param name="wait">If true, loop on exceptions that are retryable, and verify the directory no longer exists. 
Generally true on SetUp, false on TearDown</param> + internal static void DeleteDirectory(string path, bool wait = false) { + if (!Directory.Exists(path)) + return; + foreach (string directory in Directory.GetDirectories(path)) - { - DeleteDirectory(directory); - } + DeleteDirectory(directory, wait); - // Exceptions may happen due to a handle briefly remaining held after Dispose(). - try - { - Directory.Delete(path, true); - } - catch (Exception ex) when (ex is IOException || - ex is UnauthorizedAccessException) + bool retry = true; + while (retry) { + // Exceptions may happen due to a handle briefly remaining held after Dispose(). + retry = false; try { Directory.Delete(path, true); } - catch { } + catch (Exception ex) when (ex is IOException || + ex is UnauthorizedAccessException) + { + if (!wait) + { + try { Directory.Delete(path, true); } + catch { } + return; + } + retry = true; + } + } + + if (!wait) + return; + + while (Directory.Exists(path)) + Thread.Yield(); + } + + /// <summary> + /// Create a clean new directory, removing a previous one if needed. 
+ /// </summary> + /// <param name="path"></param> + internal static void RecreateDirectory(string path) + { + if (Directory.Exists(path)) + DeleteDirectory(path); + + // Don't catch; if this fails, so should the test + Directory.CreateDirectory(path); + } + + internal static bool IsRunningAzureTests => "yes".Equals(Environment.GetEnvironmentVariable("RunAzureTests")); + + internal static void IgnoreIfNotRunningAzureTests() + { + // Need this environment variable set AND Azure Storage Emulator running + if (!IsRunningAzureTests) + Assert.Ignore("Environment variable RunAzureTests is not defined"); + } + + // Used to test the various devices by using the same test with VALUES parameter + // Cannot use LocalStorageDevice from non-Windows OS platform + public enum DeviceType + { +#if WINDOWS + LSD, + EmulatedAzure, +#endif + MLSD, + LocalMemory + } + + internal static IDevice CreateTestDevice(DeviceType testDeviceType, string filename, int latencyMs = 20) // latencyMs works only for DeviceType = LocalMemory + { + IDevice device = null; + bool preallocateFile = false; + long capacity = -1; // Capacity unspecified + bool recoverDevice = false; + bool useIoCompletionPort = false; + bool disableFileBuffering = true; + + bool deleteOnClose = false; + + switch (testDeviceType) + { +#if WINDOWS + case DeviceType.LSD: +#if NETSTANDARD || NET + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) // avoids CA1416 // Validate platform compatibility +#endif + device = new LocalStorageDevice(filename, preallocateFile, deleteOnClose, disableFileBuffering, capacity, recoverDevice, useIoCompletionPort); + break; + case DeviceType.EmulatedAzure: + IgnoreIfNotRunningAzureTests(); + device = new AzureStorageDevice(AzureEmulatedStorageString, AzureTestContainer, AzureTestDirectory, Path.GetFileName(filename), deleteOnClose: false); + break; +#endif + case DeviceType.MLSD: + device = new ManagedLocalStorageDevice(filename, preallocateFile, deleteOnClose, capacity, recoverDevice); + 
break; + // Emulated higher latency storage device - takes a disk latency arg (latencyMs) and emulates an IDevice using main memory, serving data at specified latency + case DeviceType.LocalMemory: + device = new LocalMemoryDevice(1L << 26, 1L << 22, 2, sector_size: 512, latencyMs: latencyMs, fileName: filename); // 64 MB (1L << 26) is enough for our test cases + break; + } + + return device; + } + + private static string ConvertedClassName(bool forAzure = false) + { + // Make this all under one root folder named {prefix}, which is the base namespace name. All UT namespaces using this must start with this prefix. + const string prefix = "FASTER.test"; + Assert.IsTrue(TestContext.CurrentContext.Test.ClassName.StartsWith($"{prefix}."), $"Expected {prefix} prefix was not found"); + var suffix = TestContext.CurrentContext.Test.ClassName.Substring(prefix.Length + 1); + return forAzure ? suffix : $"{prefix}/{suffix}"; + } + + internal static string MethodTestDir => Path.Combine(TestContext.CurrentContext.TestDirectory, $"{ConvertedClassName()}_{TestContext.CurrentContext.Test.MethodName}"); + + internal static string AzureTestContainer + { + get + { + var container = ConvertedClassName(forAzure: true).Replace('.', '-').ToLower(); + Microsoft.Azure.Storage.NameValidator.ValidateContainerName(container); + return container; } } + + internal static string AzureTestDirectory => TestContext.CurrentContext.Test.MethodName; + + internal const string AzureEmulatedStorageString = "UseDevelopmentStorage=true;"; + + internal enum AllocatorType + { + FixedBlittable, + VarLenBlittable, + Generic + } + + internal static (Status status, TOutput output) GetSinglePendingResult<TKey, TValue, TInput, TOutput, TContext>(CompletedOutputIterator<TKey, TValue, TInput, TOutput, TContext> completedOutputs) + { + Assert.IsTrue(completedOutputs.Next()); + var result = (completedOutputs.Current.Status, completedOutputs.Current.Output); + Assert.IsFalse(completedOutputs.Next()); + 
completedOutputs.Dispose(); + return result; + } } } diff --git a/cs/test/TryEnqueueBasicTests.cs b/cs/test/TryEnqueueBasicTests.cs index 65e7d42d4..ec3b15c18 100644 --- a/cs/test/TryEnqueueBasicTests.cs +++ b/cs/test/TryEnqueueBasicTests.cs @@ -1,19 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. + using System; -using System.Buffers; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; using FASTER.core; using NUnit.Framework; - namespace FASTER.test { - //** Fundamental basic test for TryEnqueue that covers all the parameters in TryEnqueue //** Other tests in FasterLog.cs provide more coverage for TryEnqueue @@ -22,7 +15,7 @@ internal class TryEnqueueTests { private FasterLog log; private IDevice device; - private string path = Path.GetTempPath() + "TryEnqueueTests/"; + private string path; static readonly byte[] entry = new byte[100]; public enum TryEnqueueIteratorType @@ -43,36 +36,45 @@ private struct ReadOnlySpanBatch : IReadOnlySpanBatch [SetUp] public void Setup() { - // Clean up log files from previous test runs in case they weren't cleaned up - try { new DirectoryInfo(path).Delete(true); } - catch {} + path = TestUtils.MethodTestDir + "/"; - // Create devices \ log for test - device = Devices.CreateLogDevice(path + "TryEnqueue", deleteOnClose: true); - log = new FasterLog(new FasterLogSettings { LogDevice = device }); + // Clean up log files from previous test runs in case they weren't cleaned up + TestUtils.DeleteDirectory(path, wait: true); } [TearDown] public void TearDown() { - log.Dispose(); - device.Dispose(); + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; // Clean up log files - try { new DirectoryInfo(path).Delete(true); } - catch { } + TestUtils.DeleteDirectory(path); } [Test] [Category("FasterLog")] [Category("Smoke")] - public void TryEnqueueBasicTest([Values] TryEnqueueIteratorType 
iteratorType) + public void TryEnqueueBasicTest([Values] TryEnqueueIteratorType iteratorType, [Values] TestUtils.DeviceType deviceType) { int entryLength = 50; int numEntries = 10000; int entryFlag = 9999; + // Create devices \ log for test + string filename = path + "TryEnqueue" + deviceType.ToString() + ".log"; + device = TestUtils.CreateTestDevice(deviceType, filename); + log = new FasterLog(new FasterLogSettings { LogDevice = device, SegmentSizeBits = 22, LogCommitDir = path }); + +#if WINDOWS + // Issue with Non Async Commit and Emulated Azure so don't run it - at least put after device creation to see if crashes doing that simple thing + if (deviceType == TestUtils.DeviceType.EmulatedAzure) + return; +#endif + // Reduce SpanBatch to make sure entry fits on page if (iteratorType == TryEnqueueIteratorType.SpanBatch) { @@ -88,7 +90,6 @@ public void TryEnqueueBasicTest([Values] TryEnqueueIteratorType iteratorType) ReadOnlySpanBatch spanBatch = new ReadOnlySpanBatch(numEntries); - // TryEnqueue but set each Entry in a way that can differentiate between entries for (int i = 0; i < numEntries; i++) { @@ -125,7 +126,7 @@ public void TryEnqueueBasicTest([Values] TryEnqueueIteratorType iteratorType) } // Verify each Enqueue worked - Assert.IsTrue(appendResult == true, "Fail - TryEnqueue failed with a 'false' result for entry:" + i.ToString()); + Assert.IsTrue(appendResult, "Fail - TryEnqueue failed with a 'false' result for entry: " + i.ToString()); // logical address has new entry every x bytes which is one entry less than the TailAddress if (iteratorType == TryEnqueueIteratorType.SpanBatch) @@ -133,15 +134,12 @@ public void TryEnqueueBasicTest([Values] TryEnqueueIteratorType iteratorType) else ExpectedOutAddress = log.TailAddress - 104; - Assert.IsTrue(logicalAddress == ExpectedOutAddress, "Fail - returned LogicalAddr: " + logicalAddress.ToString() + " is not equal to Expected LogicalAddr: " + ExpectedOutAddress.ToString()); + Assert.AreEqual(ExpectedOutAddress, 
logicalAddress); } // Commit to the log log.Commit(true); - // flag to make sure data has been checked - bool datacheckrun = false; - // Read the log - Look for the flag so know each entry is unique int currentEntry = 0; using (var iter = log.Scan(0, 100_000_000)) @@ -150,23 +148,19 @@ public void TryEnqueueBasicTest([Values] TryEnqueueIteratorType iteratorType) { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - // Span Batch only added first entry several times so have separate verification if (iteratorType == TryEnqueueIteratorType.SpanBatch) - Assert.IsTrue(result[0] == (byte)entryFlag, "Fail - Result[0]:"+result[0].ToString()+" entryFlag:"+entryFlag); + Assert.AreEqual((byte)entryFlag, result[0]); else - Assert.IsTrue(result[currentEntry] == (byte)entryFlag, "Fail - Result["+ currentEntry.ToString() + "]:" + result[0].ToString() + " entryFlag:" + entryFlag); + Assert.AreEqual((byte)entryFlag, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. 
"); + // Make sure expected length is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(entryLength, currentEntry); } } diff --git a/cs/test/VLTestTypes.cs b/cs/test/VLTestTypes.cs index c27e266f9..cbe32c7d6 100644 --- a/cs/test/VLTestTypes.cs +++ b/cs/test/VLTestTypes.cs @@ -13,30 +13,21 @@ public struct Key : IFasterEqualityComparer<Key>, IVariableLengthStruct<Key> { public long key; - public long GetHashCode64(ref Key key) - { - return Utility.GetHashCode(key.key); - } - public bool Equals(ref Key k1, ref Key k2) - { - return k1.key == k2.key; - } + public long GetHashCode64(ref Key key) => Utility.GetHashCode(key.key); + + public bool Equals(ref Key k1, ref Key k2) => k1.key == k2.key; - public int GetLength(ref Key t) - { - return sizeof(long); - } + public int GetLength(ref Key t) => sizeof(long); - public int GetInitialLength() - { - return sizeof(long); - } + public int GetInitialLength() => sizeof(long); public unsafe void Serialize(ref Key source, void* destination) => Buffer.MemoryCopy(Unsafe.AsPointer(ref source), destination, GetLength(ref source), GetLength(ref source)); public unsafe ref Key AsRef(void* source) => ref Unsafe.AsRef<Key>(source); public unsafe void Initialize(void* source, void* dest) { } + + public override string ToString() => this.key.ToString(); } [StructLayout(LayoutKind.Explicit)] @@ -48,20 +39,15 @@ public unsafe struct VLValue : IFasterEqualityComparer<VLValue>, IVariableLength [FieldOffset(4)] public int field1; - public int GetInitialLength() - { - return 2 * sizeof(int); - } + public int GetInitialLength() => 2 * sizeof(int); - public int GetLength(ref VLValue t) - { - return sizeof(int) * t.length; - } + public int GetLength(ref VLValue t) => sizeof(int) * t.length; public unsafe void Serialize(ref VLValue source, void* destination) => Buffer.MemoryCopy(Unsafe.AsPointer(ref source), destination, GetLength(ref source), GetLength(ref source)); public unsafe ref VLValue AsRef(void* 
source) => ref Unsafe.AsRef<VLValue>(source); + public unsafe void Initialize(void* source, void* dest) { } public void ToIntArray(ref int[] dst) @@ -78,14 +64,10 @@ public void ToIntArray(ref int[] dst) public void CopyTo(ref VLValue dst) { var fulllength = GetLength(ref this); - Buffer.MemoryCopy(Unsafe.AsPointer(ref this), - Unsafe.AsPointer(ref dst), fulllength, fulllength); + Buffer.MemoryCopy(Unsafe.AsPointer(ref this), Unsafe.AsPointer(ref dst), fulllength, fulllength); } - public long GetHashCode64(ref VLValue k) - { - return Utility.GetHashCode(k.length) ^ Utility.GetHashCode(k.field1); - } + public long GetHashCode64(ref VLValue k) => Utility.GetHashCode(k.length) ^ Utility.GetHashCode(k.field1); public bool Equals(ref VLValue k1, ref VLValue k2) { @@ -98,26 +80,30 @@ public bool Equals(ref VLValue k1, ref VLValue k2) return false; return true; } + + public override string ToString() => $"len = {this.length}, field1 = {this.field1}"; } public struct Input { public long input; + + public override string ToString() => this.input.ToString(); } public class VLFunctions : FunctionsBase<Key, VLValue, Input, int[], Empty> { public override void RMWCompletionCallback(ref Key key, ref Input input, ref int[] output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } public override void ReadCompletionCallback(ref Key key, ref Input input, ref int[] output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); for (int i = 0; i < output.Length; i++) { - Assert.IsTrue(output[i] == output.Length); + Assert.AreEqual(output.Length, output[i]); } } @@ -152,15 +138,15 @@ public class VLFunctions2 : FunctionsBase<VLValue, VLValue, Input, int[], Empty> { public override void RMWCompletionCallback(ref VLValue key, ref Input input, ref int[] output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); } public override 
void ReadCompletionCallback(ref VLValue key, ref Input input, ref int[] output, Empty ctx, Status status) { - Assert.IsTrue(status == Status.OK); + Assert.AreEqual(Status.OK, status); for (int i = 0; i < output.Length; i++) { - Assert.IsTrue(output[i] == output.Length); + Assert.AreEqual(output.Length, output[i]); } } diff --git a/cs/test/VariableLengthIteratorTests.cs b/cs/test/VariableLengthIteratorTests.cs index 825d13d48..a3ae9e651 100644 --- a/cs/test/VariableLengthIteratorTests.cs +++ b/cs/test/VariableLengthIteratorTests.cs @@ -1,9 +1,9 @@ -using System; +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +using System; using System.Collections.Generic; -using System.Linq; using System.Runtime.CompilerServices; -using System.Text; -using System.Threading.Tasks; using FASTER.core; using NUnit.Framework; @@ -16,8 +16,10 @@ public class IteratorTests [Category("FasterKV")] public void ShouldSkipEmptySpaceAtEndOfPage() { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + var vlLength = new VLValue(); - var log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/hlog-vl-iter.log", deleteOnClose: true); + var log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/hlog-vl-iter.log", deleteOnClose: true); var fht = new FasterKV<Key, VLValue> (128, new LogSettings { LogDevice = log, MemorySizeBits = 17, PageSizeBits = 10 }, // 1KB page @@ -70,6 +72,7 @@ public void ShouldSkipEmptySpaceAtEndOfPage() fht.Dispose(); log.Dispose(); } + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); void Set(long keyValue, int length, int tag) { diff --git a/cs/test/VariableLengthStructFASTERTests.cs b/cs/test/VariableLengthStructFASTERTests.cs index 9f28b83d5..78348b109 100644 --- a/cs/test/VariableLengthStructFASTERTests.cs +++ b/cs/test/VariableLengthStructFASTERTests.cs @@ -7,32 +7,31 @@ namespace FASTER.test { - [TestFixture] internal class VariableLengthStructFASTERTests { // VarLenMax is the 
variable-length portion; 2 is for the fixed fields const int VarLenMax = 10; const int StackAllocMax = VarLenMax + 2; - int GetVarLen(Random r) => r.Next(VarLenMax) + 2; + static int GetVarLen(Random r) => r.Next(VarLenMax) + 2; [Test] [Category("FasterKV")] + [Category("Smoke")] public unsafe void VariableLengthTest1() { - FasterKV<Key, VLValue> fht; - IDevice log; - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/hlog1.log", deleteOnClose: true); - fht = new FasterKV<Key, VLValue> + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + + var log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/hlog1.log", deleteOnClose: true); + var fht = new FasterKV<Key, VLValue> (128, new LogSettings { LogDevice = log, MemorySizeBits = 17, PageSizeBits = 12 }, null, null, null, new VariableLengthStructSettings<Key, VLValue> { valueLength = new VLValue() } ); - var s = fht.NewSession(new VLFunctions()); Input input = default; - Random r = new Random(100); + Random r = new(100); // Single alloc outside the loop, to the max length we'll need. 
int* val = stackalloc int[StackAllocMax]; @@ -64,38 +63,36 @@ public unsafe void VariableLengthTest1() } else { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.Length == len); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(len, output.Length); for (int j = 0; j < len; j++) { - Assert.IsTrue(output[j] == len); + Assert.AreEqual(len, output[j]); } } } s.Dispose(); fht.Dispose(); - fht = null; log.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } [Test] [Category("FasterKV")] public unsafe void VariableLengthTest2() { - FasterKV<VLValue, VLValue> fht; - IDevice log; - log = Devices.CreateLogDevice(TestContext.CurrentContext.TestDirectory + "/hlog1.log", deleteOnClose: true); - fht = new FasterKV<VLValue, VLValue> + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + + var log = Devices.CreateLogDevice(TestUtils.MethodTestDir + "/hlog1.log", deleteOnClose: true); + var fht = new FasterKV<VLValue, VLValue> (128, new LogSettings { LogDevice = log, MemorySizeBits = 17, PageSizeBits = 12 }, null, null, null, new VariableLengthStructSettings<VLValue, VLValue> { keyLength = new VLValue(), valueLength = new VLValue() } ); - - var s = fht.NewSession(new VLFunctions2()); Input input = default; - Random r = new Random(100); + Random r = new(100); // Single alloc outside the loop, to the max length we'll need. 
int* keyval = stackalloc int[StackAllocMax]; @@ -137,20 +134,19 @@ public unsafe void VariableLengthTest2() } else { - Assert.IsTrue(status == Status.OK); - Assert.IsTrue(output.Length == len); + Assert.AreEqual(Status.OK, status); + Assert.AreEqual(len, output.Length); for (int j = 0; j < len; j++) { - Assert.IsTrue(output[j] == len); + Assert.AreEqual(len, output[j]); } } } s.Dispose(); fht.Dispose(); - fht = null; log.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); } - } } diff --git a/cs/test/WaitForCommit.cs b/cs/test/WaitForCommit.cs index ead71f196..1d507823d 100644 --- a/cs/test/WaitForCommit.cs +++ b/cs/test/WaitForCommit.cs @@ -1,33 +1,29 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -using System; -using System.Buffers; -using System.Collections.Generic; -using System.IO; -using System.Linq; + using System.Threading; -using System.Threading.Tasks; using FASTER.core; using NUnit.Framework; - namespace FASTER.test { - [TestFixture] internal class WaitForCommitTests { - public FasterLog log; + static FasterLog log; public IDevice device; - private string path = Path.GetTempPath() + "WaitForCommitTests/"; + private string path; static readonly byte[] entry = new byte[10]; + static readonly AutoResetEvent ev = new(false); + static readonly AutoResetEvent done = new(false); [SetUp] public void Setup() { + path = TestUtils.MethodTestDir + "/"; + // Clean up log files from previous test runs in case they weren't cleaned up - try { new DirectoryInfo(path).Delete(true); } - catch { } + TestUtils.DeleteDirectory(path, wait:true); // Create devices \ log for test device = Devices.CreateLogDevice(path + "WaitForCommit", deleteOnClose: true); @@ -37,26 +33,28 @@ public void Setup() [TearDown] public void TearDown() { - log.Dispose(); - device.Dispose(); + log?.Dispose(); + log = null; + device?.Dispose(); + device = null; // Clean up log files - try { new DirectoryInfo(path).Delete(true); } - catch 
{ } + TestUtils.DeleteDirectory(path); } [TestCase("Sync")] // use string here instead of Bool so shows up in Test Explorer with more descriptive name [TestCase("Async")] [Test] [Category("FasterLog")] + [Category("Smoke")] public void WaitForCommitBasicTest(string SyncTest) { - - CancellationTokenSource cts = new CancellationTokenSource(); + CancellationTokenSource cts = new(); CancellationToken token = cts.Token; // make it small since launching each on separate threads int entryLength = 10; + int expectedEntries = 3; // Not entry length because this is number of enqueues called // Set Default entry data for (int i = 0; i < entryLength; i++) @@ -64,31 +62,18 @@ public void WaitForCommitBasicTest(string SyncTest) entry[i] = (byte)i; } - Task currentTask; - - // Enqueue and Commit in a separate thread (wait there until commit is done though). + // Enqueue / WaitForCommit on a task (that will be waited) until the Commit on the separate thread is done if (SyncTest == "Sync") { - currentTask = Task.Run(() => LogWriter(log, entry), token); + new Thread(new ThreadStart(LogWriter)).Start(); } else { - currentTask = Task.Run(() => LogWriterAsync(log, entry), token); + new Thread(new ThreadStart(LogWriterAsync)).Start(); } - // Give all a second or so to queue up and to help with timing issues - shouldn't need but timing issues - Thread.Sleep(2000); - - // Commit to the log + ev.WaitOne(); log.Commit(true); - currentTask.Wait(4000, token); - - // double check to make sure finished - seen cases where timing kept running even after commit done - if (currentTask.Status != TaskStatus.RanToCompletion) - cts.Cancel(); - - // flag to make sure data has been checked - bool datacheckrun = false; // Read the log to make sure all entries are put in int currentEntry = 0; @@ -98,48 +83,40 @@ public void WaitForCommitBasicTest(string SyncTest) { if (currentEntry < entryLength) { - // set check flag to show got in here - datacheckrun = true; - - Assert.IsTrue(result[currentEntry] == 
(byte)currentEntry, "Fail - Result[" + currentEntry.ToString() + "]:" + result[0].ToString() + " not match expected:" + currentEntry); - + Assert.AreEqual((byte)currentEntry, result[currentEntry]); currentEntry++; } } } - // if data verification was skipped, then pop a fail - if (datacheckrun == false) - Assert.Fail("Failure -- data loop after log.Scan never entered so wasn't verified. "); + // Make sure expected entries is same as current - also makes sure that data verification was not skipped + Assert.AreEqual(expectedEntries, currentEntry,$"expectedEntries:{expectedEntries} does not equal currentEntry:{currentEntry}"); - // NOTE: seeing issues where task is not running to completion on Release builds - // This is a final check to make sure task finished. If didn't then assert - // One note - if made it this far, know that data was Enqueue and read properly, so just - // case of task not stopping - if (currentTask.Status != TaskStatus.RanToCompletion) - { - Assert.Fail("Final Status check Failure -- Task should be 'RanToCompletion' but current Status is:" + currentTask.Status); - } + done.WaitOne(); } - static void LogWriter(FasterLog log, byte[] entry) + static void LogWriter() { // Enter in some entries then wait on this separate thread log.Enqueue(entry); log.Enqueue(entry); log.Enqueue(entry); + ev.Set(); log.WaitForCommit(log.TailAddress); + done.Set(); } - static async Task LogWriterAsync(FasterLog log, byte[] entry) + static void LogWriterAsync() { + // Using "await" here will kick out of the calling thread once the first await is finished // Enter in some entries then wait on this separate thread - await log.EnqueueAsync(entry); - await log.EnqueueAsync(entry); - await log.EnqueueAsync(entry); - await log.WaitForCommitAsync(log.TailAddress); + log.EnqueueAsync(entry).AsTask().GetAwaiter().GetResult(); + log.EnqueueAsync(entry).AsTask().GetAwaiter().GetResult(); + log.EnqueueAsync(entry).AsTask().GetAwaiter().GetResult(); + ev.Set(); + 
log.WaitForCommitAsync(log.TailAddress).AsTask().GetAwaiter().GetResult(); + done.Set(); } - } } diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index 3b7d27a6e..7151b3696 100644 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -44,6 +44,8 @@ docs: children: - title: "Basics" url: /docs/remote-basics/ + - title: "Pub-sub" + url: /docs/remote-pubsub/ - title: Developer Guide children: diff --git a/docs/_docs/20-fasterkv-basics.md b/docs/_docs/20-fasterkv-basics.md index 7602a5dc2..7dc42e9d7 100644 --- a/docs/_docs/20-fasterkv-basics.md +++ b/docs/_docs/20-fasterkv-basics.md @@ -273,7 +273,7 @@ public static void Test() { using var log = Devices.CreateLogDevice("C:\\Temp\\hlog.log"); using var store = new FasterKV<long, long>(1L << 20, new LogSettings { LogDevice = log }); - using var s = store.NewSession(new SimpleFunctions<long, long>()); + using var s = store.NewSession(new SimpleFunctions<long, long>((a, b) => a + b)); long key = 1, value = 1, input = 10, output = 0; s.Upsert(ref key, ref value); s.Read(ref key, ref output); diff --git a/docs/_docs/23-fasterkv-tuning.md b/docs/_docs/23-fasterkv-tuning.md index d063dc53c..ebbd308af 100644 --- a/docs/_docs/23-fasterkv-tuning.md +++ b/docs/_docs/23-fasterkv-tuning.md @@ -35,21 +35,22 @@ the device via `IDisposable`. * `PageSizeBits`: This field (P) is used to indicate the size of each page. It is provided in terms of the number of bits. You get the actual size of the page by a simple computation of two raised to the power of the number specified (2<sup>P</sup>). For example, when `PageSizeBits` P is set to 12, it represents pages of size 4KB -(since 2<sup>P</sup> = 2<sup>10</sup> = 4096 = 4KB). Generally you should not need to adjust page size from its -default value of 25 (= 32MB). You cannot use a page size smaller than the device sector size (512 bytes by -default), which translates to a `PageSizeBits` value of at least 9. +(since 2<sup>P</sup> = 2<sup>12</sup> = 4096 = 4KB). 
Generally you should not need to adjust `PageSizeBits` from its +default value of 25 (2<sup>25</sup> = 32MB page size). You cannot use a page size smaller than the device sector +size (512 bytes by default), which translates to a `PageSizeBits` value of at least 9. -* `MemorySizeBits`: This field (M) indicates the total size of memory used by the log. As before, for a setting +* `MemorySizeBits`: This field (M) is used to indicate the total size of memory used by the log. As before, for a setting of M, 2<sup>M</sup> is the number of bytes used totally by the log. Since each page is of size 2<sup>P</sup>, the -number of pages in memory is simply 2<sup>M-P</sup>. FASTER requires at least 2 pages in memory, so M should be -set to at least P+1. +number of pages in memory is simply 2<sup>M-P</sup>. FASTER requires at least 1 page in memory, so `MemorySizeBits` should be +set to at least P. * `MutableFraction`: This field (F) indicates the fraction of the memory that will be treated as _mutable_, i.e., updates are performed in-place instead of creating a new version on the tail of the log. A value of 0.9 (the default) indicates that 90% of memory will be mutable. -* `SegmentSize`: On disk, the data is written to files in coarse-grained chunks called _segments_. We can size -each chunk independently of pages, as one segment typically consists of many pages. For instance, if we want +* `SegmentSizeBits`: On disk, the data is written to files in coarse-grained chunks called _segments_. We can size +each chunk independently of pages, as one segment typically consists of many pages. `SegmentSizeBits` is used to +indicate the size of each segment. As before, it is specified in bits. For instance, if we want each file on disk to be 1GB (the default), we can set `SegmentSizeBits` S to 30, since 2<sup>30</sup> = 1GB. * `CopyReadsToTail`: This enum setting indicates whether reads should be copied to the tail of the log. 
This diff --git a/docs/_docs/43-fasterlog-tuning.md b/docs/_docs/43-fasterlog-tuning.md index 8bcec7567..d19e5857d 100644 --- a/docs/_docs/43-fasterlog-tuning.md +++ b/docs/_docs/43-fasterlog-tuning.md @@ -28,16 +28,18 @@ You can use our extension method to easily create an instance of a local device: * `PageSizeBits`: This field (P) is used to indicate the size of each page. It is provided in terms of the number of bits. You get the actual size of the page by a simple computation of two raised to the power of the number specified (2<sup>P</sup>). For example, when `PageSizeBits` P is set to 12, it represents pages of size 4KB -(since 2<sup>P</sup> = 2<sup>10</sup> = 4096 = 4KB). Generally you should not need to adjust page size from its -default value of 25 (= 32MB). +(since 2<sup>P</sup> = 2<sup>12</sup> = 4096 = 4KB). Generally you should not need to adjust `PageSizeBits` from its +default value of 25 (2<sup>25</sup> = 32MB page size). You cannot use a page size smaller than the device sector +size (512 bytes by default), which translates to a `PageSizeBits` value of at least 9. -* `MemorySizeBits`: This field (M) indicates the total size of memory used by the log. As before, for a setting +* `MemorySizeBits`: This field (M) is used to indicate the total size of memory used by the log. As before, for a setting of M, 2<sup>M</sup> is the number of bytes used totally by the log. Since each page is of size 2<sup>P</sup>, the -number of pages in memory is simply 2<sup>M-P</sup>. FASTER requires at least 2 pages in memory, so M should be -set to at least P+1. +number of pages in memory is simply 2<sup>M-P</sup>. FASTER requires at least 1 page in memory, so `MemorySizeBits` should be +set to at least P. -* `SegmentSize`: On disk, the data is written to files in coarse-grained chunks called _segments_. We can size -each chunk independently of pages, as one segment typically consists of many pages. 
For instance, if we want +* `SegmentSizeBits`: On disk, the data is written to files in coarse-grained chunks called _segments_. We can size +each chunk independently of pages, as one segment typically consists of many pages. `SegmentSizeBits` is used to +indicate the size of each segment. As before, it is specified in bits. For instance, if we want each file on disk to be 1GB (the default), we can set `SegmentSizeBits` S to 30, since 2<sup>30</sup> = 1GB. * `MutableFraction`: This field (F) indicates the fraction of the memory that is marked as _mutable_, i.e., diff --git a/docs/_docs/51-remote-pubsub.md b/docs/_docs/51-remote-pubsub.md new file mode 100644 index 000000000..3dc1f3f96 --- /dev/null +++ b/docs/_docs/51-remote-pubsub.md @@ -0,0 +1,61 @@ +--- +title: "Remote FASTER - Publish/Subscribe" +permalink: /docs/remote-pubsub/ +excerpt: "Remote FASTER - Pub/Sub" +last_modified_at: 2021-10-21 +toc: false +classes: wide +--- + +FASTER now supports Publish/Subscribe with remote clients. Clients can now subscribe to keys or prefixes of keys and get all updates to the values of +the keys made by any other client. Following are the options for Publish/Subscribe: + +1. **Publish/Subscribe with KV**: +One or multiple clients can subscribe to the updates of a key or pattern stored in FASTER. Whenever there is a call for `Upsert()` or `RMW()` for the +subscribed key or pattern, all the subscribers for the key/pattern are notified of the change. + +2. **Publish/Subscribe without KV**: +One or multiple clients can subscribe to a key or pattern that is not stored in the FASTER key-value store. Whenever there is a call to `Publish()` +a key or pattern, all the subscribers for the key/pattern are notified of the change. 
+ +The basic approach in order to use Publish/Subscribe (with and without KV) is as follows: + +## Creating Subscribe(KV)Broker + +You can create a Subscribe(KV)Broker with either fixed-size (blittable struct) Key and Value types or variable-sized (varlen) Key and Value types +similar to byte arrays. The Subscribe(KV)Broker must be created along with the `FASTERServer`, and passed to the provider. +`FasterLog` is used by the broker for storing the keys and values until they are forwarded to the subscribers. + +The method of creating a SubscribeBroker is as follows: + +```cs +var kvBroker = new SubscribeKVBroker<SpanByte, SpanByte, SpanByte, IKeyInputSerializer<SpanByte, SpanByte>>(new SpanByteKeySerializer(), null, true); +var broker = new SubscribeBroker<SpanByte, SpanByte, IKeySerializer<SpanByte>>(new SpanByteKeySerializer(), null, true); +``` + +The first argument is an `IKeySerializer` used for serializing/deserializing keys for pattern-based subscriptions. The second argument is the location +of the log directory used for `FasterLog`. The last argument is a boolean, whether the `FasterLog` should start fresh, or should recover from the +previous state. + +## Subscribing to a Key / Pattern from clients: + +A `FASTERClient` can subscribe to a key or glob-pattern with the following command: + +```cs +clientSession.Subscribe(key); // Used for subscribing to a key that is not stored in FasterKV +clientSession.PSubscribe(pattern); // Used for subscribing to a glob-style pattern that is not stored in FasterKV +clientSession.SubscribeKV(key); // Used for subscribing to a key that is stored in FasterKV +clientSession.PSubscribeKV(pattern); // Used for subscribing to a glob-style pattern that is stored in FasterKV +``` + +The clientSession can be used to subscribe to multiple keys or patterns in the same session. 
Once a key or pattern is subscribed, +the clientSession cannot accept other commands (such as `Upsert()`, `RMW()`, etc.) until all the keys or patterns are unsubscribed. + +## Publishing a key from a client: + +A `FASTERClient` can publish a key and value, for pushing the updated value for the key to all its subscribers either synchronously or asynchronously. +```cs +clientSession.Publish(key, value); // Used for publishing a key and value that is not stored in FasterKV, asynchronously +clientSession.PublishNow(key, value); // Used for publishing a key and value that is not stored in FasterKV, synchronously +``` +For the case of (P)SubscribeKV, the key and value are automatically pushed to the subscribers on `Upsert()` or `RMW()`. diff --git a/docs/_docs/80-build-and-test.md b/docs/_docs/80-build-and-test.md index 6f30ae4a4..aa68d4441 100644 --- a/docs/_docs/80-build-and-test.md +++ b/docs/_docs/80-build-and-test.md @@ -9,4 +9,4 @@ classes: wide For C#, clone the Git repo, open cs/FASTER.sln in Visual Studio, and build. -For C++, click [here](/docs/fasterkv-cpp/). +For C++, click [here](/FASTER/docs/fasterkv-cpp/). 
diff --git a/docs/_pages/home.md b/docs/_pages/home.md index 3f4dd197c..0188146f6 100644 --- a/docs/_pages/home.md +++ b/docs/_pages/home.md @@ -10,7 +10,7 @@ header: url: "/docs/quick-start-guide/" excerpt: > A fast concurrent persistent key-value store and log, in C# and C++.<br /> - <small><a href="https://github.com/microsoft/FASTER/releases/tag/v1.9.5">Latest release v1.9.5</a></small> + <small><a href="https://github.com/microsoft/FASTER/releases/tag/v1.9.6">Latest release v1.9.6</a></small> features: - image_path: /assets/images/faster-feature-1.png alt: "feature1" @@ -99,7 +99,8 @@ public static void Main() ); // Create a session per sequence of interactions with FASTER - using var s = store.NewSession(new SimpleFunctions<long, long>()); + // We use default callback functions with a custom merger: RMW merges input by adding it to value + using var s = store.NewSession(new SimpleFunctions<long, long>((a, b) => a + b)); long key = 1, value = 1, input = 10, output = 0; // Upsert and Read