diff --git a/.github/workflows/kvrocks.yaml b/.github/workflows/kvrocks.yaml index 45d1916aeed..1d8c24eae0d 100644 --- a/.github/workflows/kvrocks.yaml +++ b/.github/workflows/kvrocks.yaml @@ -77,7 +77,7 @@ jobs: run: | sudo apt update sudo apt install -y clang-format-14 clang-tidy-14 - - uses: apache/skywalking-eyes/header@v0.5.0 + - uses: apache/skywalking-eyes/header@v0.6.0 with: config: .github/config/licenserc.yml - name: Check with clang-format @@ -110,21 +110,25 @@ jobs: fail-fast: false matrix: include: - - name: Darwin Clang - os: macos-11 + # FIXME: update macos-11 to macos-12/13 + # - name: Darwin Clang + # os: macos-11 + # compiler: auto + - name: Darwin Clang arm64 + os: macos-14 compiler: auto - - name: Darwin Clang without Jemalloc - os: macos-11 - compiler: auto - without_jemalloc: -DDISABLE_JEMALLOC=ON - - name: Darwin Clang with OpenSSL - os: macos-11 - compiler: auto - with_openssl: -DENABLE_OPENSSL=ON - - name: Darwin Clang without luaJIT - os: macos-11 - compiler: auto - without_luajit: -DENABLE_LUAJIT=OFF + # - name: Darwin Clang without Jemalloc + # os: macos-11 + # compiler: auto + # without_jemalloc: -DDISABLE_JEMALLOC=ON + # - name: Darwin Clang with OpenSSL + # os: macos-11 + # compiler: auto + # with_openssl: -DENABLE_OPENSSL=ON + # - name: Darwin Clang without luaJIT + # os: macos-11 + # compiler: auto + # without_luajit: -DENABLE_LUAJIT=OFF - name: Ubuntu GCC os: ubuntu-20.04 compiler: gcc @@ -163,6 +167,11 @@ jobs: without_jemalloc: -DDISABLE_JEMALLOC=ON compiler: clang ignore_when_tsan: -tags="ignore_when_tsan" + - name: Ubuntu Clang UBSAN + os: ubuntu-20.04 + with_sanitizer: -DENABLE_UBSAN=ON + without_jemalloc: -DDISABLE_JEMALLOC=ON + compiler: clang - name: Ubuntu GCC Ninja os: ubuntu-20.04 with_ninja: --ninja @@ -172,7 +181,7 @@ jobs: compiler: gcc with_openssl: -DENABLE_OPENSSL=ON - name: Ubuntu Clang with OpenSSL - os: ubuntu-20.04 + os: ubuntu-22.04 compiler: clang with_openssl: -DENABLE_OPENSSL=ON - name: Ubuntu GCC without luaJIT @@ -183,14 +192,14 @@ jobs: os: ubuntu-20.04 without_luajit: -DENABLE_LUAJIT=OFF compiler: clang - - name: Ubuntu GCC with new encoding + - name: Ubuntu GCC with old encoding os: ubuntu-20.04 compiler: gcc - new_encoding: -DENABLE_NEW_ENCODING=TRUE - - name: Ubuntu Clang with new encoding - os: ubuntu-20.04 + new_encoding: -DENABLE_NEW_ENCODING=FALSE + - name: Ubuntu Clang with old encoding + os: ubuntu-22.04 compiler: clang - new_encoding: -DENABLE_NEW_ENCODING=TRUE + new_encoding: -DENABLE_NEW_ENCODING=FALSE - name: Ubuntu GCC with speedb enabled os: ubuntu-20.04 compiler: gcc @@ -219,14 +228,22 @@ jobs: with: path: | ~/local/bin/redis-cli - key: ${{ runner.os }}-redis-cli + key: ${{ runner.os }}-${{ runner.arch }}-redis-cli + - name: Cache redis server + id: cache-redis-server + uses: actions/cache@v4 + with: + path: | + ~/local/bin/redis-server + key: ${{ runner.os }}-${{ runner.arch }}-redis-server - name: Install redis - if: steps.cache-redis.outputs.cache-hit != 'true' + if: ${{ steps.cache-redis.outputs.cache-hit != 'true' || steps.cache-redis-server.outputs.cache-hit != 'true' }} run: | - curl -O https://download.redis.io/releases/redis-6.2.7.tar.gz - tar -xzvf redis-6.2.7.tar.gz + curl -O https://download.redis.io/releases/redis-6.2.14.tar.gz + tar -xzvf redis-6.2.14.tar.gz mkdir -p $HOME/local/bin - pushd redis-6.2.7 && BUILD_TLS=yes make -j$NPROC redis-cli && mv src/redis-cli $HOME/local/bin/ && popd + pushd redis-6.2.14 && BUILD_TLS=yes make -j$NPROC redis-cli && mv src/redis-cli $HOME/local/bin/ && popd + 
pushd redis-6.2.14 && BUILD_TLS=yes make -j$NPROC redis-server && mv src/redis-server $HOME/local/bin/ && popd
      - uses: actions/checkout@v4
        with:
@@ -257,6 +274,8 @@
      - name: Build Kvrocks (SonarCloud)
        if: ${{ matrix.sonarcloud }}
        run: |
+          build-wrapper-linux-x86-64 --out-dir ${{ env.SONARCLOUD_OUTPUT_DIR }} ./x.py build -j$NPROC --compiler ${{ matrix.compiler }} --skip-build
+          cp -r build _build
           build-wrapper-linux-x86-64 --out-dir ${{ env.SONARCLOUD_OUTPUT_DIR }} ./x.py build -j$NPROC --unittest --compiler ${{ matrix.compiler }} ${{ matrix.sonarcloud }}
      - name: Setup Coredump
@@ -281,7 +300,7 @@
          GOCASE_RUN_ARGS=""
          if [[ -n "${{ matrix.with_openssl }}" ]] && [[ "${{ matrix.os }}" == ubuntu* ]]; then
            git clone https://github.com/jsha/minica
-            cd minica && go build && cd ..
+            cd minica && git checkout 96a5c93723cf3d34b50b3e723a9f05cd3765bc67 && go build && cd ..
            ./minica/minica --domains localhost
            cp localhost/cert.pem tests/gocase/tls/cert/server.crt
            cp localhost/key.pem tests/gocase/tls/cert/server.key
@@ -290,6 +309,29 @@
          fi
          ./x.py test go build $GOCASE_RUN_ARGS ${{ matrix.ignore_when_tsan}}
+      - name: Install redis-py
+        run: pip3 install redis==4.3.6
+
+      - name: Run kvrocks2redis Test
+        # Currently, when Tsan/Asan is enabled or when running on macOS 11/14, values in the destination redis server can mismatch.
+        # See https://github.com/apache/kvrocks/issues/2195.
+        if: ${{ !contains(matrix.name, 'Tsan') && !contains(matrix.name, 'Asan') && !startsWith(matrix.os, 'macos') }}
+        run: |
+          ulimit -c unlimited
+          export LSAN_OPTIONS="suppressions=$(realpath ./tests/lsan-suppressions)"
+          export TSAN_OPTIONS="suppressions=$(realpath ./tests/tsan-suppressions)"
+          $HOME/local/bin/redis-server --daemonize yes
+          mkdir -p kvrocks2redis-ci-data
+          ./build/kvrocks --dir `pwd`/kvrocks2redis-ci-data --pidfile `pwd`/kvrocks.pid --daemonize yes
+          sleep 10s
+          echo -en "data-dir `pwd`/kvrocks2redis-ci-data\ndaemonize yes\noutput-dir ./\nnamespace.__namespace 127.0.0.1 6379\n" >> ./kvrocks2redis-ci.conf
+          cat ./kvrocks2redis-ci.conf
+          ./build/kvrocks2redis -c ./kvrocks2redis-ci.conf
+          sleep 10s
+          python3 utils/kvrocks2redis/tests/populate-kvrocks.py --password="" --flushdb=true
+          sleep 10s
+          python3 utils/kvrocks2redis/tests/check_consistency.py --src_password=""
+
      - name: Find reports and crashes
        if: always()
        run: |
@@ -335,13 +377,15 @@
        uses: actions/upload-artifact@v4
        with:
          name: sonarcloud-data
-          path: ${{ env.SONARCLOUD_OUTPUT_DIR }}
+          path: |
+            ${{ env.SONARCLOUD_OUTPUT_DIR }}
+            _build

  check-docker:
    name: Check Docker image
    needs: [precondition, check-and-lint, check-typos]
    if: ${{ needs.precondition.outputs.docs_only != 'true' }}
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v4
      - name: Get core numbers
@@ -394,7 +438,7 @@
        if: ${{ startsWith(matrix.image, 'centos') }}
        run: |
          yum install -y centos-release-scl-rh
-          yum install -y devtoolset-11 python3 autoconf automake wget git gcc gcc-c++
+          yum install -y devtoolset-11 python3 python3-pip autoconf automake wget git gcc gcc-c++
          echo "NPROC=$(nproc)" >> $GITHUB_ENV
          mv /usr/bin/gcc /usr/bin/gcc-4.8.5
          ln -s /opt/rh/devtoolset-11/root/bin/gcc /usr/bin/gcc
@@ -405,13 +449,13 @@
        if: ${{ startsWith(matrix.image, 'archlinux') }}
        run: |
          pacman -Syu --noconfirm
-          pacman -Sy --noconfirm autoconf automake python3 git wget which cmake make gcc
+          pacman -Sy --noconfirm autoconf automake python3 python-redis git wget which cmake make gcc
          echo "NPROC=$(nproc)" >> $GITHUB_ENV
      - name: Setup openSUSE
        if: ${{
startsWith(matrix.image, 'opensuse') }}
        run: |
-          zypper install -y gcc11 gcc11-c++ make wget git autoconf automake python3 curl tar gzip cmake go
+          zypper install -y gcc11 gcc11-c++ make wget git autoconf automake python3 python3-pip curl tar gzip cmake go
          update-alternatives --install /usr/bin/cc cc /usr/bin/gcc-11 100
          update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-11 100
          update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100
@@ -426,19 +470,29 @@
            ~/local/bin/redis-cli
          key: ${{ matrix.image }}-redis-cli

+      - name: Cache redis server
+        id: cache-redis-server
+        uses: actions/cache@v3
+        with:
+          path: |
+            ~/local/bin/redis-server
+          key: ${{ matrix.image }}-redis-server
+
      - name: Install redis
-        if: steps.cache-redis.outputs.cache-hit != 'true'
+        if: ${{ steps.cache-redis.outputs.cache-hit != 'true' || steps.cache-redis-server.outputs.cache-hit != 'true' }}
        run: |
-          curl -O https://download.redis.io/releases/redis-6.2.7.tar.gz
-          tar -xzvf redis-6.2.7.tar.gz
+          curl -O https://download.redis.io/releases/redis-6.2.14.tar.gz
+          tar -xzvf redis-6.2.14.tar.gz
          mkdir -p $HOME/local/bin
-          pushd redis-6.2.7 && USE_JEMALLOC=no make -j$NPROC redis-cli && mv src/redis-cli $HOME/local/bin/ && popd
+          pushd redis-6.2.14 && USE_JEMALLOC=no make -j$NPROC redis-cli && mv src/redis-cli $HOME/local/bin/ && popd
+          pushd redis-6.2.14 && USE_JEMALLOC=no make -j$NPROC redis-server && mv src/redis-server $HOME/local/bin/ && popd

      - name: Install cmake
        if: ${{ startsWith(matrix.image, 'centos') }}
        run: |
-          wget https://github.com/Kitware/CMake/releases/download/v3.26.4/cmake-3.26.4-linux-x86_64.sh
-          bash cmake-3.26.4-linux-x86_64.sh --skip-license --prefix=/usr
+          VERSION=3.26.4
+          wget https://github.com/Kitware/CMake/releases/download/v$VERSION/cmake-$VERSION-linux-x86_64.sh
+          bash cmake-$VERSION-linux-x86_64.sh --skip-license --prefix=/usr

      - uses: actions/checkout@v3 # v4 uses Node 20 and does not work on CentOS 7
      - uses: actions/setup-go@v4 # v5 uses Node 20 too
@@ -461,6 +515,24 @@
          GOCASE_RUN_ARGS=""
          ./x.py test go build $GOCASE_RUN_ARGS
+      - name: Install redis-py
+        if: ${{ !startsWith(matrix.image, 'archlinux') }} # already installed
+        run: pip3 install redis==4.3.6
+
+      - name: Run kvrocks2redis Test
+        run: |
+          $HOME/local/bin/redis-server --daemonize yes
+          mkdir -p kvrocks2redis-ci-data
+          ./build/kvrocks --dir `pwd`/kvrocks2redis-ci-data --pidfile `pwd`/kvrocks.pid --daemonize yes
+          sleep 10s
+          echo -en "data-dir `pwd`/kvrocks2redis-ci-data\ndaemonize yes\noutput-dir ./\nnamespace.__namespace 127.0.0.1 6379\n" >> ./kvrocks2redis-ci.conf
+          cat ./kvrocks2redis-ci.conf
+          ./build/kvrocks2redis -c ./kvrocks2redis-ci.conf
+          sleep 10s
+          python3 utils/kvrocks2redis/tests/populate-kvrocks.py --password="" --flushdb=true
+          sleep 10s
+          python3 utils/kvrocks2redis/tests/check_consistency.py --src_password=""
+
  required:
    if: always()
    name: Required
diff --git a/.github/workflows/pr-lint.yaml b/.github/workflows/pr-lint.yaml
new file mode 100644
index 00000000000..a67683cc19e
--- /dev/null
+++ b/.github/workflows/pr-lint.yaml
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: "Lint PR" + +on: + pull_request_target: + types: + - opened + - edited + - synchronize + +permissions: + pull-requests: read + +jobs: + main: + name: Validate PR title + runs-on: ubuntu-latest + steps: + - uses: amannn/action-semantic-pull-request@v5 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + # Configure which types are allowed (newline-delimited). + # Default: https://github.com/commitizen/conventional-commit-types + types: | + fix + feat + build + chore + ci + docs + perf + refactor + revert + style + test + # Configure which scopes are allowed (newline-delimited). + # These are regex patterns auto-wrapped in `^ $`. + scopes: | + \S+ + # Configure that a scope must always be provided. + requireScope: false + # Configure which scopes are disallowed in PR titles (newline-delimited). + # For instance by setting the value below, `chore(release): ...` (lowercase) + # and `ci(e2e,release): ...` (unknown scope) will be rejected. + # These are regex patterns auto-wrapped in `^ $`. + # disallowScopes: | + # release + # [A-Z]+ + # Configure additional validation for the subject based on a regex. + # This example ensures the subject doesn't start with an uppercase character. + # subjectPattern: ^(?![A-Z]).+$ + # If `subjectPattern` is configured, you can use this property to override + # the default error message that is shown when the pattern doesn't match. + # The variables `subject` and `title` can be used within the message. + # subjectPatternError: | + # The subject "{subject}" found in the pull request title "{title}" + # didn't match the configured pattern. Please ensure that the subject + # doesn't start with an uppercase character. + # The GitHub base URL will be automatically set to the correct value from the GitHub context variable. + # If you want to override this, you can do so here (not recommended). + # githubBaseUrl: https://github.myorg.com/api/v3 + # If the PR contains one of these newline-delimited labels, the + # validation is skipped. If you want to rerun the validation when + # labels change, you might want to use the `labeled` and `unlabeled` + # event triggers in your workflow. + ignoreLabels: | + disable-pr-lint + # If you're using a format for the PR title that differs from the traditional Conventional + # Commits spec, you can use these options to customize the parsing of the type, scope and + # subject. The `headerPattern` should contain a regex where the capturing groups in parentheses + # correspond to the parts listed in `headerPatternCorrespondence`. 
+ # See: https://github.com/conventional-changelog/conventional-changelog/tree/master/packages/conventional-commits-parser#headerpattern + # headerPattern: '^(\w*)(?:\(([\w$.\-*/ ]*)\))?: (.*)$' + # headerPatternCorrespondence: type, scope, subject diff --git a/.github/workflows/sonar.yaml b/.github/workflows/sonar.yaml index a106823cd6f..e47564eebec 100644 --- a/.github/workflows/sonar.yaml +++ b/.github/workflows/sonar.yaml @@ -57,15 +57,14 @@ jobs: fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/sonarcloud-data.zip`, Buffer.from(download.data)); - name: 'Unzip code coverage' run: | - unzip sonarcloud-data.zip -d sonarcloud-data - ls -a sonarcloud-data + unzip sonarcloud-data.zip + mv _build build + mkdir -p build/CMakeFiles/CMakeTmp + ls -a sonarcloud-data build - uses: actions/setup-python@v5 with: python-version: 3.x - - name: Configure Kvrocks - run: | - ./x.py build -j$(nproc) --compiler gcc --skip-build - name: Run sonar-scanner env: diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e384e03665..1719aa0c38c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,16 +23,17 @@ project(kvrocks option(DISABLE_JEMALLOC "disable use of the jemalloc library" OFF) option(ENABLE_ASAN "enable address sanitizer" OFF) option(ENABLE_TSAN "enable thread sanitizer" OFF) +option(ENABLE_UBSAN "enable undefined behavior sanitizer" OFF) option(ASAN_WITH_LSAN "enable leak sanitizer while address sanitizer is enabled" ON) option(ENABLE_STATIC_LIBSTDCXX "link kvrocks with static library of libstd++ instead of shared library" ON) option(ENABLE_LUAJIT "enable use of luaJIT instead of lua" ON) option(ENABLE_OPENSSL "enable openssl to support tls connection" OFF) option(ENABLE_IPO "enable interprocedural optimization" ON) -option(ENABLE_UNWIND "enable libunwind in glog" ON) option(ENABLE_SPEEDB "enable speedb instead of rocksdb" OFF) +set(SYMBOLIZE_BACKEND "" CACHE STRING "symbolization backend library for cpptrace (libbacktrace, libdwarf, or empty)") set(PORTABLE 0 CACHE STRING "build a portable binary (disable arch-specific optimizations)") # TODO: set ENABLE_NEW_ENCODING to ON when we are ready -option(ENABLE_NEW_ENCODING "enable new encoding (#1033) for storing 64bit size and expire time in milliseconds" OFF) +option(ENABLE_NEW_ENCODING "enable new encoding (#1033) for storing 64bit size and expire time in milliseconds" ON) if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") cmake_policy(SET CMP0135 NEW) @@ -91,6 +92,33 @@ if(ENABLE_ASAN) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address") endif() + +# Copied from https://github.com/apache/arrow/blob/main/cpp/cmake_modules/san-config.cmake +# +# Flag to enable clang undefined behavior sanitizer +# We explicitly don't enable all of the sanitizer flags: +# - disable 'vptr' because of RTTI issues across shared libraries (?) +# - disable 'alignment' because unaligned access is really OK on Nehalem and we do it +# all over the place. +# - disable 'function' because it appears to give a false positive +# (https://github.com/google/sanitizers/issues/911) +# - disable 'float-divide-by-zero' on clang, which considers it UB +# (https://bugs.llvm.org/show_bug.cgi?id=17000#c1) +# Note: GCC does not support the 'function' flag. 
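To make the trade-off in the comment block above concrete, here is a minimal, self-contained C++ snippet showing the class of bug a `-DENABLE_UBSAN=ON` build is meant to catch at runtime. The file name and program are invented for illustration and are not part of this patch:

```cpp
// ubsan_demo.cc -- hypothetical demo, not part of the patch.
// Build roughly as the ENABLE_UBSAN branch below configures Clang:
//   clang++ -fsanitize=undefined -fno-sanitize=alignment,vptr,function,float-divide-by-zero \
//           -fno-sanitize-recover=all ubsan_demo.cc -o ubsan_demo
#include <climits>
#include <cstdio>

int main() {
  int x = INT_MAX;
  // Signed integer overflow is undefined behavior in C++. Under UBSAN this
  // line produces a "signed integer overflow" report at runtime, and since
  // -fno-sanitize-recover=all is also set below, the report is fatal.
  x += 1;
  std::printf("%d\n", x);
  return 0;
}
```

Because the report aborts the process, the new `Ubuntu Clang UBSAN` job added to the matrix above fails loudly instead of merely logging a diagnostic.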
+if(ENABLE_UBSAN)
+  if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr,function,float-divide-by-zero")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr,function,float-divide-by-zero")
+  elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "5.1")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr")
+  else()
+    message(FATAL_ERROR "Cannot use UBSAN without clang or gcc >= 5.1")
+  endif()
+  set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=undefined")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-sanitize-recover=all")
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-sanitize-recover=all")
+endif()

if(ENABLE_TSAN)
  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
@@ -139,6 +167,9 @@ include(cmake/jsoncons.cmake)
include(cmake/xxhash.cmake)
include(cmake/span.cmake)
include(cmake/trie.cmake)
+include(cmake/pegtl.cmake)
+include(cmake/rangev3.cmake)
+include(cmake/cpptrace.cmake)

if (ENABLE_LUAJIT)
  include(cmake/luajit.cmake)
@@ -171,6 +202,9 @@ list(APPEND EXTERNAL_LIBS ${Backtrace_LIBRARY})
list(APPEND EXTERNAL_LIBS xxhash)
list(APPEND EXTERNAL_LIBS span-lite)
list(APPEND EXTERNAL_LIBS tsl_hat_trie)
+list(APPEND EXTERNAL_LIBS pegtl)
+list(APPEND EXTERNAL_LIBS range-v3)
+list(APPEND EXTERNAL_LIBS cpptrace::cpptrace)

# Add git sha to version.h
find_package(Git REQUIRED)
@@ -210,7 +244,16 @@ add_library(kvrocks_objs OBJECT ${KVROCKS_SRCS})
target_include_directories(kvrocks_objs PUBLIC src src/common src/vendor ${PROJECT_BINARY_DIR} ${Backtrace_INCLUDE_DIR})
target_compile_features(kvrocks_objs PUBLIC cxx_std_17)
target_compile_options(kvrocks_objs PUBLIC -Wall -Wpedantic -Wsign-compare -Wreturn-type -fno-omit-frame-pointer)
-target_compile_options(kvrocks_objs PUBLIC -Werror=unused-result -Werror=unused-variable)
+target_compile_options(kvrocks_objs PUBLIC -Werror=unused-result)
+
+# disable the unused-variable check on GCC < 8 due to a false positive on structured bindings
+# https://gcc.gnu.org/bugzilla/show_bug.cgi?format=multiple&id=81767
+if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8)
+  target_compile_options(kvrocks_objs PUBLIC -Wno-error=unused-variable)
+else()
+  target_compile_options(kvrocks_objs PUBLIC -Werror=unused-variable)
+endif()
+
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
  target_compile_options(kvrocks_objs PUBLIC -Wno-pedantic)
elseif((CMAKE_CXX_COMPILER_ID STREQUAL "Clang") OR (CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"))
diff --git a/Dockerfile b/Dockerfile
index 4ea544dee5e..0ade7dca7b6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -15,28 +15,29 @@
# specific language governing permissions and limitations
# under the License.

-FROM alpine:3.16 as build
+FROM debian:bookworm-slim AS build

ARG MORE_BUILD_ARGS

-RUN apk update && apk upgrade && apk add git gcc g++ make cmake ninja autoconf automake libtool python3 linux-headers curl openssl-dev libexecinfo-dev redis
+RUN DEBIAN_FRONTEND=noninteractive && apt-get update && apt-get upgrade -y && apt-get -y --no-install-recommends install git build-essential autoconf cmake libtool python3 libssl-dev && apt-get autoremove && apt-get clean
+
WORKDIR /kvrocks

COPY . .
RUN ./x.py build -DENABLE_OPENSSL=ON -DPORTABLE=1 -DCMAKE_BUILD_TYPE=Release -j $(nproc) $MORE_BUILD_ARGS

-FROM alpine:3.16
+FROM debian:bookworm-slim
+
+RUN DEBIAN_FRONTEND=noninteractive && apt-get update && apt-get upgrade -y && apt-get -y install openssl ca-certificates redis-tools && apt-get clean

-RUN apk update && apk upgrade && apk add libexecinfo
RUN mkdir /var/run/kvrocks

VOLUME /var/lib/kvrocks

COPY --from=build /kvrocks/build/kvrocks /bin/
-COPY --from=build /usr/bin/redis-cli /bin/

HEALTHCHECK --interval=10s --timeout=1s --start-period=30s --retries=3 \
-    CMD ./bin/redis-cli -p 6666 PING | grep -E '(PONG|NOAUTH)' || exit 1
+    CMD redis-cli -p 6666 PING | grep -E '(PONG|NOAUTH)' || exit 1

COPY ./LICENSE ./NOTICE ./licenses /kvrocks/
COPY ./kvrocks.conf /var/lib/kvrocks/
diff --git a/NOTICE b/NOTICE
index a25a60d1e70..3b2a470bf3c 100644
--- a/NOTICE
+++ b/NOTICE
@@ -6,6 +6,10 @@ The Apache Software Foundation (http://www.apache.org/).

================================================================

+Thanks to designers Lingyu Tian and Shili Fan for contributing the logo of Kvrocks.
+
+================================================================
+
This product includes a number of Dependencies with separate copyright
notices and license terms. Your use of these submodules is subject to the
terms and conditions of the following licenses.
@@ -66,6 +70,8 @@ The text of each license is also included in licenses/LICENSE-[project].txt

* LuaJIT(https://github.com/KvrocksLabs/LuaJIT)
* lua(https://github.com/KvrocksLabs/lua, alternative to LuaJIT)
* hat-trie(https://github.com/Tessil/hat-trie)
+* pegtl(https://github.com/taocpp/PEGTL, NOTE: changed to Boost Software License Version 1.0 in main branch)
+* cpptrace(https://github.com/jeremy-rifkin/cpptrace)

================================================================
Boost Software License Version 1.0
@@ -75,6 +81,7 @@ The text of each license is also included in licenses/LICENSE-[project].txt

* jsoncons(https://github.com/danielaparker/jsoncons)
* span-lite(https://github.com/martinmoene/span-lite)
+* range-v3(https://github.com/ericniebler/range-v3)

================================================================
zlib/libpng licenses
diff --git a/README.md b/README.md
index ffe4614862d..e33cdabb29a 100644
--- a/README.md
+++ b/README.md
@@ -38,8 +38,6 @@ Kvrocks has the following key features:

* High Availability: Support Redis Sentinel to failover when the master or a slave fails.
* Cluster: Centralized management but accessible via any Redis cluster client.

-Thanks to designers [Lingyu Tian](https://github.com/tianlingyu1997) and Shili Fan for contributing the logo of Kvrocks.
-
## Who uses Kvrocks

You can find Kvrocks users at [the Users page](https://kvrocks.apache.org/users/).
@@ -113,6 +111,8 @@ $ docker run -it -p 6666:6666 apache/kvrocks --bind 0.0.0.0
$ docker run -it -p 6666:6666 apache/kvrocks:nightly
```

+Please visit [Apache Kvrocks on DockerHub](https://hub.docker.com/r/apache/kvrocks) for additional details about images.
+
### Connect Kvrocks service

```sh
@@ -183,41 +183,6 @@ Documents are hosted at the [official website](https://kvrocks.apache.org/docs/g

Kvrocks community welcomes all forms of contribution and you can find out how to get involved on the [Community](https://kvrocks.apache.org/community/) and [How to Contribute](https://kvrocks.apache.org/community/contributing) pages.
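A note on the Docker HEALTHCHECK retained above: since Kvrocks speaks the Redis protocol, `redis-cli -p 6666 PING` is one request/response round-trip, and the `grep -E '(PONG|NOAUTH)'` deliberately treats an authentication error as proof of liveness. A rough sketch of that probe in C++ over plain POSIX sockets, purely illustrative (hypothetical file, minimal error handling, and it assumes a server on 127.0.0.1:6666 that accepts the inline command form):

```cpp
// ping_probe.cc -- hypothetical sketch of the health probe, not project code.
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cstring>

int main() {
  int fd = socket(AF_INET, SOCK_STREAM, 0);
  if (fd < 0) return 1;

  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_port = htons(6666);  // the port probed by the HEALTHCHECK above
  inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);
  if (connect(fd, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)) != 0) return 1;

  // Inline form of the command; full clients send a RESP array instead.
  const char req[] = "PING\r\n";
  write(fd, req, sizeof(req) - 1);

  // A healthy reply is "+PONG\r\n"; with requirepass set it is an error
  // beginning "-NOAUTH", which still proves the server is up, hence the
  // (PONG|NOAUTH) pattern in the Dockerfile.
  char buf[64] = {};
  read(fd, buf, sizeof(buf) - 1);
  close(fd);
  return (strstr(buf, "PONG") != nullptr || strstr(buf, "NOAUTH") != nullptr) ? 0 : 1;
}
```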
-## Performance - -### Hardware - -* CPU: 48 cores Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz -* Memory: 32 GiB -* NET: Intel Corporation I350 Gigabit Network Connection -* DISK: 2TB NVMe Intel SSD DC P4600 - -> Benchmark Client: multi-thread redis-benchmark(unstable branch) - -### 1. Commands QPS - -> kvrocks: workers = 16, benchmark: 8 threads/ 512 conns / 128 payload - -latency: 99.9% < 10ms - -![image](assets/chart-commands.png) - -### 2. QPS on different payloads - -> kvrocks: workers = 16, benchmark: 8 threads/ 512 conns - -latency: 99.9% < 10ms - -![image](assets/chart-values.png) - -### 3. QPS on different workers - -> kvrocks: workers = 16, benchmark: 8 threads/ 512 conns / 128 payload - -latency: 99.9% < 10ms - -![image](assets/chart-threads.png) - ## License Apache Kvrocks is licensed under the Apache License Version 2.0. See the [LICENSE](LICENSE) file for details. diff --git a/assets/KQIR.png b/assets/KQIR.png new file mode 100644 index 00000000000..aeef0c4ec70 Binary files /dev/null and b/assets/KQIR.png differ diff --git a/assets/chart-commands.png b/assets/chart-commands.png deleted file mode 100644 index d9a250a633f..00000000000 Binary files a/assets/chart-commands.png and /dev/null differ diff --git a/assets/chart-threads.png b/assets/chart-threads.png deleted file mode 100644 index 591263664ed..00000000000 Binary files a/assets/chart-threads.png and /dev/null differ diff --git a/assets/chart-values.png b/assets/chart-values.png deleted file mode 100644 index f7e0d708f20..00000000000 Binary files a/assets/chart-values.png and /dev/null differ diff --git a/cmake/cpptrace.cmake b/cmake/cpptrace.cmake new file mode 100644 index 00000000000..370198bef81 --- /dev/null +++ b/cmake/cpptrace.cmake @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +include_guard() + +include(cmake/utils.cmake) + +FetchContent_DeclareGitHubWithMirror(cpptrace + jeremy-rifkin/cpptrace v0.6.2 + MD5=b13786adcc1785cb900746ea96c50bee +) + +if (SYMBOLIZE_BACKEND STREQUAL "libbacktrace") + set(CPPTRACE_BACKEND_OPTION "CPPTRACE_GET_SYMBOLS_WITH_LIBBACKTRACE=ON") +elseif (SYMBOLIZE_BACKEND STREQUAL "libdwarf") + set(CPPTRACE_BACKEND_OPTION "CPPTRACE_GET_SYMBOLS_WITH_LIBDWARF=ON") +else () + set(CPPTRACE_BACKEND_OPTION "CPPTRACE_GET_SYMBOLS_WITH_ADDR2LINE=ON") +endif () + +FetchContent_MakeAvailableWithArgs(cpptrace + ${CPPTRACE_BACKEND_OPTION} +) diff --git a/cmake/glog.cmake b/cmake/glog.cmake index 4a29a589755..5e9c12687e3 100644 --- a/cmake/glog.cmake +++ b/cmake/glog.cmake @@ -28,5 +28,4 @@ FetchContent_MakeAvailableWithArgs(glog WITH_GFLAGS=OFF WITH_GTEST=OFF BUILD_SHARED_LIBS=OFF - WITH_UNWIND=${ENABLE_UNWIND} ) diff --git a/cmake/jsoncons.cmake b/cmake/jsoncons.cmake index a81bd8388a0..fdad4153454 100644 --- a/cmake/jsoncons.cmake +++ b/cmake/jsoncons.cmake @@ -20,8 +20,8 @@ include_guard() include(cmake/utils.cmake) FetchContent_DeclareGitHubWithMirror(jsoncons - danielaparker/jsoncons v0.173.4 - MD5=947254529a8629d001322a78454a23d2 + danielaparker/jsoncons v0.176.0 + MD5=5d0343fe48cbc640bdb42d89a5b87182 ) FetchContent_MakeAvailableWithArgs(jsoncons diff --git a/cmake/pegtl.cmake b/cmake/pegtl.cmake new file mode 100644 index 00000000000..d1638f365c1 --- /dev/null +++ b/cmake/pegtl.cmake @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include_guard() + +include(cmake/utils.cmake) + +FetchContent_DeclareGitHubTarWithMirror(pegtl + taocpp/PEGTL 3.2.7 + MD5=31b14660c883bc0489ddcdfbd29199c9 +) + +FetchContent_MakeAvailableWithArgs(pegtl) diff --git a/cmake/rangev3.cmake b/cmake/rangev3.cmake new file mode 100644 index 00000000000..c30a1468e60 --- /dev/null +++ b/cmake/rangev3.cmake @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +include_guard() + +include(cmake/utils.cmake) + +FetchContent_DeclareGitHubWithMirror(rangev3 + ericniebler/range-v3 0.12.0 + MD5=e220e3f545fdf46241b4f139822d73a1 +) + +if (CMAKE_BUILD_TYPE STREQUAL "Debug") + set(WITH_DEBUG_INFO ON) +elseif(CMAKE_BUILD_TYPE STREQUAL "Release") + set(WITH_DEBUG_INFO OFF) +elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo") + set(WITH_DEBUG_INFO ON) +elseif (CMAKE_BUILD_TYPE STREQUAL "MinSizeRel") + set(WITH_DEBUG_INFO OFF) +endif() + +if (PORTABLE STREQUAL 0) + set(ARG_RANGES_NATIVE ON) +else() + set(ARG_RANGES_NATIVE OFF) +endif() + +FetchContent_MakeAvailableWithArgs(rangev3 + RANGES_CXX_STD=17 + RANGES_BUILD_CALENDAR_EXAMPLE=OFF + RANGES_DEBUG_INFO=${WITH_DEBUG_INFO} + RANGES_NATIVE=${ARG_RANGES_NATIVE} +) diff --git a/cmake/rocksdb.cmake b/cmake/rocksdb.cmake index 6a150fc5ea1..102bda3b548 100644 --- a/cmake/rocksdb.cmake +++ b/cmake/rocksdb.cmake @@ -26,8 +26,8 @@ endif() include(cmake/utils.cmake) FetchContent_DeclareGitHubWithMirror(rocksdb - facebook/rocksdb v8.11.3 - MD5=6b1faf94b3880913725ddb09ab60d7a6 + facebook/rocksdb v8.11.4 + MD5=8190ea7769705aabc928311762d5aafe ) FetchContent_GetProperties(jemalloc) diff --git a/cmake/snappy.cmake b/cmake/snappy.cmake index b7cc862167d..d62a506df7a 100644 --- a/cmake/snappy.cmake +++ b/cmake/snappy.cmake @@ -20,8 +20,8 @@ include_guard() include(cmake/utils.cmake) FetchContent_DeclareGitHubWithMirror(snappy - google/snappy f725f6766bfc62418c6491b504c8e5865ec99412 - MD5=17a982c9b0c667b3744e1fecba0046f7 + google/snappy 1.2.1 + MD5=b120895e012e097b86bf49e0ef9ca67c ) FetchContent_MakeAvailableWithArgs(snappy diff --git a/cmake/tbb.cmake b/cmake/tbb.cmake index 790b2e7e7c9..a408ac8b9dd 100644 --- a/cmake/tbb.cmake +++ b/cmake/tbb.cmake @@ -20,8 +20,8 @@ include_guard() include(cmake/utils.cmake) FetchContent_DeclareGitHubWithMirror(tbb - oneapi-src/oneTBB v2021.11.0 - MD5=eea2bdc5ae0a51389da27480617ccff9 + oneapi-src/oneTBB v2021.12.0 + MD5=0919a8eda74333e1aafa8d602bb9cc90 ) FetchContent_MakeAvailableWithArgs(tbb diff --git a/cmake/utils.cmake b/cmake/utils.cmake index c1a8594b89e..cce10cbf982 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -34,6 +34,7 @@ function(FetchContent_MakeAvailableWithArgs dep) parse_var(${arg} key value) set(${key}_OLD ${${key}}) set(${key} ${value} CACHE INTERNAL "") + message("In ${dep}: ${key} set to ${value}") endforeach() add_subdirectory(${${dep}_SOURCE_DIR} ${${dep}_BINARY_DIR} EXCLUDE_FROM_ALL) @@ -58,3 +59,10 @@ function(FetchContent_DeclareGitHubWithMirror dep repo tag hash) ${hash} ) endfunction() + +function(FetchContent_DeclareGitHubTarWithMirror dep repo tag hash) + FetchContent_DeclareWithMirror(${dep} + https://github.com/${repo}/archive/${tag}.tar.gz + ${hash} + ) +endfunction() diff --git a/cmake/zstd.cmake b/cmake/zstd.cmake index 4c0032e40f9..6a322220c36 100644 --- a/cmake/zstd.cmake +++ b/cmake/zstd.cmake @@ -20,8 +20,8 @@ include_guard() include(cmake/utils.cmake) FetchContent_DeclareGitHubWithMirror(zstd - facebook/zstd v1.5.5 - MD5=f336cde1961ee7e5d3a7f8c0c0f96987 + facebook/zstd v1.5.6 + MD5=cfb58a03ae01a39d5fff731ecaaa2657 ) FetchContent_GetProperties(zstd) diff --git a/dev/hooks/pre-push b/dev/hooks/pre-push new file mode 100755 index 00000000000..f7a0f5239a5 --- /dev/null +++ b/dev/hooks/pre-push @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Check 'format' and 'golangci-lint' before 'git push'.
+# Copy this script to .git/hooks to activate,
+# and remove it from .git/hooks to deactivate.
+
+set -Euo pipefail
+
+unset GIT_DIR
+ROOT_DIR="$(git rev-parse --show-toplevel)"
+cd "$ROOT_DIR"
+
+run_check() {
+  local check_name=$1
+  echo "Running pre-push script $ROOT_DIR/x.py $check_name"
+  ./x.py check "$check_name"
+
+  if [ $? -ne 0 ]; then
+    echo 'You may use `git push --no-verify` to skip this check.'
+    exit 1
+  fi
+}
+
+run_check format
+run_check golangci-lint
diff --git a/kvrocks.conf b/kvrocks.conf
index e6b17dae2ef..66715028493 100644
--- a/kvrocks.conf
+++ b/kvrocks.conf
@@ -492,15 +492,23 @@ profiling-sample-record-threshold-ms 100

################################## CRON ###################################

# Compact Scheduler, auto compact at the scheduled time
-# time expression format is the same as crontab(currently only support * and int)
-# e.g. compact-cron 0 3 * * * 0 4 * * *
+# Time expression format is the same as crontab (supported cron syntax: *, n, */n, `1,3-6,9,11`)
+# e.g. compact-cron 0 3,4 * * *
# would compact the db at 3am and 4am every day
# compact-cron 0 3 * * *

# The hour range that compaction checker would be active
# e.g. compaction-checker-range 0-7 means compaction checker would be working between
# 0-7am every day.
-compaction-checker-range 0-7
+# WARNING: this config option is deprecated and will be removed,
+# please use compaction-checker-cron instead
+# compaction-checker-range 0-7
+
+# The time pattern that compaction checker would be active
+# Time expression format is the same as crontab (supported cron syntax: *, n, */n, `1,3-6,9,11`)
+# e.g. compaction-checker-cron * 0-7 * * * means compaction checker would be working between
+# 0-7am every day.
+compaction-checker-cron * 0-7 * * *

# When the compaction checker is triggered, the db will periodically pick the SST file
# with the highest "deleted percentage" (i.e. the percentage of deleted keys in the SST
@@ -515,10 +523,17 @@ compaction-checker-range 0-7
# force-compact-file-min-deleted-percentage 10

# Bgsave scheduler, auto bgsave at the scheduled time
-# time expression format is the same as crontab(currently only support * and int)
-# e.g. bgsave-cron 0 3 * * * 0 4 * * *
+# Time expression format is the same as crontab (supported cron syntax: *, n, */n, `1,3-6,9,11`)
+# e.g. bgsave-cron 0 3,4 * * *
# would bgsave the db at 3am and 4am every day

+# Kvrocks doesn't store the number of keys directly. It needs to scan the DB and
+# then retrieve the number of keys by using the dbsize scan command.
+# The dbsize scan scheduler auto-recalculates the estimated number of keys at the scheduled time.
+# Time expression format is the same as crontab (supported cron syntax: *, n, */n, `1,3-6,9,11`)
+# e.g. dbsize-scan-cron 0 * * * *
+# would recalculate the keyspace info of the db every hour.
+
# Command renaming.
# # It is possible to change the name of dangerous commands in a shared @@ -727,6 +742,29 @@ rocksdb.cache_index_and_filter_blocks yes # default snappy rocksdb.compression snappy +# Specify the compression level to use. It trades compression speed +# and ratio, might be useful when tuning for disk space. +# See details: https://github.com/facebook/rocksdb/wiki/Space-Tuning +# For zstd: valid range is from 1 (fastest) to 19 (best ratio), +# For zlib: valid range is from 1 (fastest) to 9 (best ratio), +# For lz4: adjusting the level influences the 'acceleration'. +# RocksDB sets a negative level to indicate acceleration directly, +# with more negative values indicating higher speed and less compression. +# Note: This setting is ignored for compression algorithms like Snappy that +# do not support variable compression levels. +# +# RocksDB Default: +# - zstd: 3 +# - zlib: Z_DEFAULT_COMPRESSION (currently -1) +# - kLZ4: -1 (i.e., `acceleration=1`; see `CompressionOptions::level` doc) +# For all others, RocksDB does not specify a compression level. +# If the compression type doesn't support the setting, it will be a no-op. +# +# Default: 32767 (RocksDB's generic default compression level. Internally +# it'll be translated to the default compression level specific to the +# compression library as mentioned above) +rocksdb.compression_level 32767 + # If non-zero, we perform bigger reads when doing compaction. If you're # running RocksDB on spinning disks, you should set this to at least 2MB. # That way RocksDB's compaction is doing sequential instead of random reads. @@ -851,8 +889,8 @@ rocksdb.max_bytes_for_level_multiplier 10 # In iterators, it will prefetch data asynchronously in the background for each file being iterated on. # In MultiGet, it will read the necessary data blocks from those files in parallel as much as possible. -# Default no -rocksdb.read_options.async_io no +# Default yes +rocksdb.read_options.async_io yes # If yes, the write will be flushed from the operating system # buffer cache before the write is considered complete. diff --git a/licenses/LICENSE-cpptrace.txt b/licenses/LICENSE-cpptrace.txt new file mode 100644 index 00000000000..299bf1fac8c --- /dev/null +++ b/licenses/LICENSE-cpptrace.txt @@ -0,0 +1,18 @@ +The MIT License (MIT) + +Copyright (c) 2023-2024 Jeremy Rifkin + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, +including without limitation the rights to use, copy, modify, merge, publish, distribute, +sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial +portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES +OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
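Relating the new `rocksdb.compression_level` option above to RocksDB's C++ API: the value is carried in the per-column-family `CompressionOptions::level`, and 32767 is `rocksdb::CompressionOptions::kDefaultCompressionLevel`, the sentinel RocksDB translates into the per-library defaults listed in the config comment. A short illustrative sketch (assumes RocksDB 8.x headers are available; this is not Kvrocks source):

```cpp
// compression_level.cc -- illustrative only, not Kvrocks source.
#include <rocksdb/options.h>

#include <cstdio>

int main() {
  rocksdb::Options opts;
  opts.compression = rocksdb::kZSTD;

  // 32767: let RocksDB pick the library default (3 for zstd, -1 for zlib/lz4),
  // mirroring the `rocksdb.compression_level 32767` default above.
  opts.compression_opts.level = rocksdb::CompressionOptions::kDefaultCompressionLevel;

  // Trading speed for disk space instead: an explicit zstd level in 1..19.
  opts.compression_opts.level = 19;
  std::printf("zstd level = %d\n", opts.compression_opts.level);
  return 0;
}
```

For algorithms without tunable levels, such as the default snappy, the setting is simply a no-op, as the config comment above notes.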
diff --git a/licenses/LICENSE-pegtl.txt b/licenses/LICENSE-pegtl.txt new file mode 100644 index 00000000000..97fa831443f --- /dev/null +++ b/licenses/LICENSE-pegtl.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2007-2022 Dr. Colin Hirsch and Daniel Frey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/licenses/LICENSE-range-v3.txt b/licenses/LICENSE-range-v3.txt new file mode 100644 index 00000000000..698193e974e --- /dev/null +++ b/licenses/LICENSE-range-v3.txt @@ -0,0 +1,151 @@ +======================================================== +Boost Software License - Version 1.0 - August 17th, 2003 +======================================================== + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + +============================================================================== +libc++ License +============================================================================== + +The libc++ library is dual licensed under both the University of Illinois +"BSD-Like" license and the MIT license. As a user of this code you may choose +to use it under either license. As a contributor, you agree to allow your code +to be used under both. + +Full text of the relevant licenses is included below. 
+ +============================================================================== + +University of Illinois/NCSA +Open Source License + +Copyright (c) 2009-2014 by the contributors listed in CREDITS.TXT +http://llvm.org/svn/llvm-project/libcxx/trunk/CREDITS.TXT + +All rights reserved. + +Developed by: + + LLVM Team + + University of Illinois at Urbana-Champaign + + http://llvm.org + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal with +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimers. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimers in the + documentation and/or other materials provided with the distribution. + + * Neither the names of the LLVM Team, University of Illinois at + Urbana-Champaign, nor the names of its contributors may be used to + endorse or promote products derived from this Software without specific + prior written permission. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE +SOFTWARE. + +============================================================================== + +Copyright (c) 2009-2014 by the contributors listed in CREDITS.TXT + http://llvm.org/svn/llvm-project/libcxx/trunk/CREDITS.TXT + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+ +============================================================================== +Stepanov and McJones, "Elements of Programming" license +============================================================================== + +// Copyright (c) 2009 Alexander Stepanov and Paul McJones +// +// Permission to use, copy, modify, distribute and sell this software +// and its documentation for any purpose is hereby granted without +// fee, provided that the above copyright notice appear in all copies +// and that both that copyright notice and this permission notice +// appear in supporting documentation. The authors make no +// representations about the suitability of this software for any +// purpose. It is provided "as is" without express or implied +// warranty. +// +// Algorithms from +// Elements of Programming +// by Alexander Stepanov and Paul McJones +// Addison-Wesley Professional, 2009 + +============================================================================== +SGI C++ Standard Template Library license +============================================================================== + +// Copyright (c) 1994 +// Hewlett-Packard Company +// +// Permission to use, copy, modify, distribute and sell this software +// and its documentation for any purpose is hereby granted without fee, +// provided that the above copyright notice appear in all copies and +// that both that copyright notice and this permission notice appear +// in supporting documentation. Hewlett-Packard Company makes no +// representations about the suitability of this software for any +// purpose. It is provided "as is" without express or implied warranty. +// +// Copyright (c) 1996 +// Silicon Graphics Computer Systems, Inc. +// +// Permission to use, copy, modify, distribute and sell this software +// and its documentation for any purpose is hereby granted without fee, +// provided that the above copyright notice appear in all copies and +// that both that copyright notice and this permission notice appear +// in supporting documentation. Silicon Graphics makes no +// representations about the suitability of this software for any +// purpose. It is provided "as is" without express or implied warranty. 
+//
diff --git a/src/VERSION.txt b/src/VERSION.txt
index 76cff7f1f82..f3ac133c547 100644
--- a/src/VERSION.txt
+++ b/src/VERSION.txt
@@ -1 +1 @@
-unstable
\ No newline at end of file
+2.9.0
\ No newline at end of file
diff --git a/src/cli/main.cc b/src/cli/main.cc
index 6e9b0b14a55..32c957a4ed5 100644
--- a/src/cli/main.cc
+++ b/src/cli/main.cc
@@ -154,12 +154,12 @@ int main(int argc, char *argv[]) {
  }
  bool is_supervised = IsSupervisedMode(config.supervised_mode);
  if (config.daemonize && !is_supervised) Daemonize();
-  s = CreatePidFile(config.GetPidFile());
+  s = CreatePidFile(config.pidfile);
  if (!s.IsOK()) {
    LOG(ERROR) << "Failed to create pidfile: " << s.Msg();
    return 1;
  }
-  auto pidfile_exit = MakeScopeExit([&config] { RemovePidFile(config.GetPidFile()); });
+  auto pidfile_exit = MakeScopeExit([&config] { RemovePidFile(config.pidfile); });

#ifdef ENABLE_OPENSSL
  // initialize OpenSSL
diff --git a/src/cli/signal_util.h b/src/cli/signal_util.h
index cb5abc67371..d6f02e614aa 100644
--- a/src/cli/signal_util.h
+++ b/src/cli/signal_util.h
@@ -24,39 +24,19 @@
#include
#include
+#include <cpptrace/cpptrace.hpp>
#include
#include

#include "version_util.h"

-namespace google {
-bool Symbolize(void *pc, char *out, size_t out_size);
-}  // namespace google
-
extern "C" inline void SegvHandler(int sig, siginfo_t *info, void *secret) {
-  void *trace[100];
-
-  LOG(ERROR) << "======= Ooops! kvrocks " << PrintVersion << " got signal: " << strsignal(sig) << " (" << sig << ") =======";
-  int trace_size = backtrace(trace, sizeof(trace) / sizeof(void *));
-  char **messages = backtrace_symbols(trace, trace_size);
-
-  size_t max_msg_len = 0;
-  for (int i = 1; i < trace_size; ++i) {
-    auto msg_len = strlen(messages[i]);
-    if (msg_len > max_msg_len) {
-      max_msg_len = msg_len;
-    }
-  }
-
-  for (int i = 1; i < trace_size; ++i) {
-    char func_info[1024] = {};
-    if (google::Symbolize(trace[i], func_info, sizeof(func_info) - 1)) {
-      LOG(ERROR) << std::left << std::setw(static_cast<int>(max_msg_len)) << messages[i] << " " << func_info;
-    } else {
-      LOG(ERROR) << messages[i];
-    }
-  }
+  LOG(ERROR) << "Ooops! Apache Kvrocks " << PrintVersion << " got signal: " << strsignal(sig) << " (" << sig << ")";
+  auto trace = cpptrace::generate_trace();
+  trace.print(LOG(ERROR));
+  LOG(ERROR)
+      << "It would be greatly appreciated if you could submit this crash to https://github.com/apache/kvrocks/issues " "along with the stacktrace above, logs and any relevant information.";

  struct sigaction act;
  /* Make sure we exit with the right signal at the end.
So for instance diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc index ff98da876b1..8d373e0dfc1 100644 --- a/src/cluster/cluster.cc +++ b/src/cluster/cluster.cc @@ -22,9 +22,11 @@ #include +#include #include #include #include +#include #include "cluster/cluster_defs.h" #include "commands/commander.h" @@ -37,11 +39,11 @@ #include "time_util.h" ClusterNode::ClusterNode(std::string id, std::string host, int port, int role, std::string master_id, - std::bitset slots) + const std::bitset &slots) : id(std::move(id)), host(std::move(host)), port(port), role(role), master_id(std::move(master_id)), slots(slots) {} Cluster::Cluster(Server *srv, std::vector binds, int port) - : srv_(srv), binds_(std::move(binds)), port_(port), size_(0), version_(-1), myself_(nullptr) { + : srv_(srv), binds_(std::move(binds)), port_(port) { for (auto &slots_node : slots_nodes_) { slots_node = nullptr; } @@ -53,10 +55,10 @@ Cluster::Cluster(Server *srv, std::vector binds, int port) // cluster data, so these commands should be executed exclusively, and ReadWriteLock // also can guarantee accessing data is safe. bool Cluster::SubCommandIsExecExclusive(const std::string &subcommand) { - for (auto v : {"setnodes", "setnodeid", "setslot", "import"}) { - if (util::EqualICase(v, subcommand)) return true; - } - return false; + std::array subcommands = {"setnodes", "setnodeid", "setslot", "import", "reset"}; + + return std::any_of(std::begin(subcommands), std::end(subcommands), + [&subcommand](const std::string &val) { return util::EqualICase(val, subcommand); }); } Status Cluster::SetNodeId(const std::string &node_id) { @@ -170,26 +172,26 @@ Status Cluster::SetClusterNodes(const std::string &nodes_str, int64_t version, b size_ = 0; // Update slots to nodes - for (const auto &n : slots_nodes) { - slots_nodes_[n.first] = nodes_[n.second]; + for (const auto &[slot, node_id] : slots_nodes) { + slots_nodes_[slot] = nodes_[node_id]; } // Update replicas info and size - for (auto &n : nodes_) { - if (n.second->role == kClusterSlave) { - if (nodes_.find(n.second->master_id) != nodes_.end()) { - nodes_[n.second->master_id]->replicas.push_back(n.first); + for (const auto &[node_id, node] : nodes_) { + if (node->role == kClusterSlave) { + if (nodes_.find(node->master_id) != nodes_.end()) { + nodes_[node->master_id]->replicas.push_back(node_id); } } - if (n.second->role == kClusterMaster && n.second->slots.count() > 0) { + if (node->role == kClusterMaster && node->slots.count() > 0) { size_++; } } if (myid_.empty() || force) { - for (auto &n : nodes_) { - if (n.second->port == port_ && util::MatchListeningIP(binds_, n.second->host)) { - myid_ = n.first; + for (const auto &[node_id, node] : nodes_) { + if (node->port == port_ && util::MatchListeningIP(binds_, node->host)) { + myid_ = node_id; break; } } @@ -210,9 +212,9 @@ Status Cluster::SetClusterNodes(const std::string &nodes_str, int64_t version, b // Clear data of migrated slots if (!migrated_slots_.empty()) { - for (auto &it : migrated_slots_) { - if (slots_nodes_[it.first] != myself_) { - auto s = srv_->slot_migrator->ClearKeysOfSlot(kDefaultNamespace, it.first); + for (const auto &[slot, _] : migrated_slots_) { + if (slots_nodes_[slot] != myself_) { + auto s = srv_->slot_migrator->ClearKeysOfSlot(kDefaultNamespace, slot); if (!s.ok()) { LOG(ERROR) << "failed to clear data of migrated slots: " << s.ToString(); } @@ -318,41 +320,39 @@ Status Cluster::ImportSlot(redis::Connection *conn, int slot, int state) { if (!IsValidSlot(slot)) { return {Status::NotOK, 
errSlotOutOfRange}; } + auto source_node = srv_->cluster->slots_nodes_[slot]; + if (source_node && source_node->id == myid_) { + return {Status::NotOK, "Can't import slot which belongs to me"}; + } + Status s; switch (state) { case kImportStart: - if (!srv_->slot_import->Start(conn->GetFD(), slot)) { - return {Status::NotOK, fmt::format("Can't start importing slot {}", slot)}; - } + s = srv_->slot_import->Start(slot); + if (!s.IsOK()) return s; // Set link importing conn->SetImporting(); myself_->importing_slot = slot; // Set link error callback - conn->close_cb = [object_ptr = srv_->slot_import.get(), capture_fd = conn->GetFD()](int fd) { - object_ptr->StopForLinkError(capture_fd); - }; - // Stop forbidding writing slot to accept write commands + conn->close_cb = [object_ptr = srv_->slot_import.get(), slot](int fd) { + auto s = object_ptr->StopForLinkError(); + if (!s.IsOK()) { + LOG(ERROR) << fmt::format("[import] Failed to stop importing slot {}: {}", slot, s.Msg()); + } + }; // Stop forbidding writing slot to accept write commands if (slot == srv_->slot_migrator->GetForbiddenSlot()) srv_->slot_migrator->ReleaseForbiddenSlot(); - LOG(INFO) << "[import] Start importing slot " << slot; + LOG(INFO) << fmt::format("[import] Start importing slot {}", slot); break; case kImportSuccess: - if (!srv_->slot_import->Success(slot)) { - LOG(ERROR) << "[import] Failed to set slot importing success, maybe slot is wrong" - << ", received slot: " << slot << ", current slot: " << srv_->slot_import->GetSlot(); - return {Status::NotOK, fmt::format("Failed to set slot {} importing success", slot)}; - } - - LOG(INFO) << "[import] Succeed to import slot " << slot; + s = srv_->slot_import->Success(slot); + if (!s.IsOK()) return s; + LOG(INFO) << fmt::format("[import] Mark the importing slot {} as succeed", slot); break; case kImportFailed: - if (!srv_->slot_import->Fail(slot)) { - LOG(ERROR) << "[import] Failed to set slot importing error, maybe slot is wrong" - << ", received slot: " << slot << ", current slot: " << srv_->slot_import->GetSlot(); - return {Status::NotOK, fmt::format("Failed to set slot {} importing error", slot)}; - } - - LOG(INFO) << "[import] Failed to import slot " << slot; + s = srv_->slot_import->Fail(slot); + if (!s.IsOK()) return s; + LOG(INFO) << fmt::format("[import] Mark the importing slot {} as failed", slot); break; default: return {Status::NotOK, errInvalidImportState}; @@ -363,7 +363,7 @@ Status Cluster::ImportSlot(redis::Connection *conn, int slot, int state) { Status Cluster::GetClusterInfo(std::string *cluster_infos) { if (version_ < 0) { - return {Status::ClusterDown, errClusterNoInitialized}; + return {Status::RedisClusterDown, errClusterNoInitialized}; } cluster_infos->clear(); @@ -421,7 +421,7 @@ Status Cluster::GetClusterInfo(std::string *cluster_infos) { // ... 
Status Cluster::GetSlotsInfo(std::vector *slots_infos) { if (version_ < 0) { - return {Status::ClusterDown, errClusterNoInitialized}; + return {Status::RedisClusterDown, errClusterNoInitialized}; } slots_infos->clear(); @@ -464,53 +464,113 @@ SlotInfo Cluster::genSlotNodeInfo(int start, int end, const std::shared_ptr Cluster::GetReplicas(const std::string &node_id) { + if (version_ < 0) { + return {Status::RedisClusterDown, errClusterNoInitialized}; + } + + auto item = nodes_.find(node_id); + if (item == nodes_.end()) { + return {Status::InvalidArgument, errInvalidNodeID}; + } + + auto node = item->second; + if (node->role != kClusterMaster) { + return {Status::InvalidArgument, errNoMasterNode}; + } + + auto now = util::GetTimeStampMS(); + std::string replicas_desc; + for (const auto &replica_id : node->replicas) { + auto n = nodes_.find(replica_id); + if (n == nodes_.end()) { + continue; + } + + auto replica = n->second; + + std::string node_str; + // ID, host, port + node_str.append( + fmt::format("{} {}:{}@{} ", replica_id, replica->host, replica->port, replica->port + kClusterPortIncr)); + + // Flags + node_str.append(fmt::format("slave {} ", node_id)); + + // Ping sent, pong received, config epoch, link status + node_str.append(fmt::format("{} {} {} connected", now - 1, now, version_)); + + replicas_desc.append(node_str + "\n"); + } + + return replicas_desc; +} + +std::string Cluster::getNodeIDBySlot(int slot) const { + if (slot < 0 || slot >= kClusterSlots || !slots_nodes_[slot]) return ""; + return slots_nodes_[slot]->id; +} + std::string Cluster::genNodesDescription() { auto slots_infos = getClusterNodeSlots(); auto now = util::GetTimeStampMS(); std::string nodes_desc; - for (const auto &item : nodes_) { - const std::shared_ptr n = item.second; - + for (const auto &[_, node] : nodes_) { std::string node_str; // ID, host, port - node_str.append(n->id + " "); - node_str.append(fmt::format("{}:{}@{} ", n->host, n->port, n->port + kClusterPortIncr)); + node_str.append(node->id + " "); + node_str.append(fmt::format("{}:{}@{} ", node->host, node->port, node->port + kClusterPortIncr)); // Flags - if (n->id == myid_) node_str.append("myself,"); - if (n->role == kClusterMaster) { + if (node->id == myid_) node_str.append("myself,"); + if (node->role == kClusterMaster) { node_str.append("master - "); } else { - node_str.append("slave " + n->master_id + " "); + node_str.append("slave " + node->master_id + " "); } // Ping sent, pong received, config epoch, link status node_str.append(fmt::format("{} {} {} connected", now - 1, now, version_)); - if (n->role == kClusterMaster) { - auto iter = slots_infos.find(n->id); - if (iter != slots_infos.end() && iter->second.size() > 0) { + if (node->role == kClusterMaster) { + auto iter = slots_infos.find(node->id); + if (iter != slots_infos.end() && !iter->second.empty()) { node_str.append(" " + iter->second); } } + // Just for MYSELF node to show the importing/migrating slot + if (node->id == myid_) { + if (srv_->slot_migrator) { + auto migrating_slot = srv_->slot_migrator->GetMigratingSlot(); + if (migrating_slot != -1) { + node_str.append(fmt::format(" [{}->-{}]", migrating_slot, srv_->slot_migrator->GetDstNode())); + } + } + if (srv_->slot_import) { + auto importing_slot = srv_->slot_import->GetSlot(); + if (importing_slot != -1) { + node_str.append(fmt::format(" [{}-<-{}]", importing_slot, getNodeIDBySlot(importing_slot))); + } + } + } nodes_desc.append(node_str + "\n"); } return nodes_desc; } -std::map 
Cluster::getClusterNodeSlots() const { +std::map> Cluster::getClusterNodeSlots() const { int start = -1; // node id => slots info string - std::map slots_infos; + std::map> slots_infos; std::shared_ptr n = nullptr; for (int i = 0; i <= kClusterSlots; i++) { @@ -540,30 +600,29 @@ std::map Cluster::getClusterNodeSlots() const { return slots_infos; } -std::string Cluster::genNodesInfo() { +std::string Cluster::genNodesInfo() const { auto slots_infos = getClusterNodeSlots(); std::string nodes_info; - for (const auto &item : nodes_) { - const std::shared_ptr &n = item.second; + for (const auto &[_, node] : nodes_) { std::string node_str; node_str.append("node "); // ID - node_str.append(n->id + " "); + node_str.append(node->id + " "); // Host + Port - node_str.append(fmt::format("{} {} ", n->host, n->port)); + node_str.append(fmt::format("{} {} ", node->host, node->port)); // Role - if (n->role == kClusterMaster) { + if (node->role == kClusterMaster) { node_str.append("master - "); } else { - node_str.append("slave " + n->master_id + " "); + node_str.append("slave " + node->master_id + " "); } // Slots - if (n->role == kClusterMaster) { - auto iter = slots_infos.find(n->id); - if (iter != slots_infos.end() && iter->second.size() > 0) { + if (node->role == kClusterMaster) { + auto iter = slots_infos.find(node->id); + if (iter != slots_infos.end() && !iter->second.empty()) { node_str.append(" " + iter->second); } } @@ -634,7 +693,7 @@ Status Cluster::LoadClusterNodes(const std::string &file_path) { Status Cluster::parseClusterNodes(const std::string &nodes_str, ClusterNodes *nodes, std::unordered_map *slots_nodes) { std::vector nodes_info = util::Split(nodes_str, "\n"); - if (nodes_info.size() == 0) { + if (nodes_info.empty()) { return {Status::ClusterInvalidInfo, errInvalidClusterNodeInfo}; } @@ -743,16 +802,17 @@ Status Cluster::parseClusterNodes(const std::string &nodes_str, ClusterNodes *no return Status::OK(); } -bool Cluster::IsWriteForbiddenSlot(int slot) { return srv_->slot_migrator->GetForbiddenSlot() == slot; } +bool Cluster::IsWriteForbiddenSlot(int slot) const { return srv_->slot_migrator->GetForbiddenSlot() == slot; } Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector &cmd_tokens, redis::Connection *conn) { std::vector keys_indexes; - auto s = redis::CommandTable::GetKeysFromCommand(attributes, cmd_tokens, &keys_indexes); + // No keys - if (!s.IsOK()) return Status::OK(); + if (auto s = redis::CommandTable::GetKeysFromCommand(attributes, cmd_tokens, &keys_indexes); !s.IsOK()) + return Status::OK(); - if (keys_indexes.size() == 0) return Status::OK(); + if (keys_indexes.empty()) return Status::OK(); int slot = -1; for (auto i : keys_indexes) { @@ -761,13 +821,13 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons int cur_slot = GetSlotIdFromKey(cmd_tokens[i]); if (slot == -1) slot = cur_slot; if (slot != cur_slot) { - return {Status::RedisExecErr, "CROSSSLOT Attempted to access keys that don't hash to the same slot"}; + return {Status::RedisCrossSlot, "Attempted to access keys that don't hash to the same slot"}; } } if (slot == -1) return Status::OK(); if (slots_nodes_[slot] == nullptr) { - return {Status::ClusterDown, "CLUSTERDOWN Hash slot not served"}; + return {Status::RedisClusterDown, "Hash slot not served"}; } if (myself_ && myself_ == slots_nodes_[slot]) { @@ -775,23 +835,24 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons // Server can't change the topology directly, 
so we record the migrated slots // to move the requests of the migrated slots to the destination node. if (migrated_slots_.count(slot) > 0) { // I'm not serving the migrated slot - return {Status::RedisExecErr, fmt::format("MOVED {} {}", slot, migrated_slots_[slot])}; + return {Status::RedisMoved, fmt::format("{} {}", slot, migrated_slots_[slot])}; } // To keep data consistency, slot will be forbidden write while sending the last incremental data. // During this phase, the requests of the migrating slot has to be rejected. if ((attributes->flags & redis::kCmdWrite) && IsWriteForbiddenSlot(slot)) { - return {Status::RedisExecErr, "TRYAGAIN Can't write to slot being migrated which is in write forbidden phase"}; + return {Status::RedisTryAgain, "Can't write to slot being migrated which is in write forbidden phase"}; } return Status::OK(); // I'm serving this slot } - if (myself_ && myself_->importing_slot == slot && conn->IsImporting()) { + if (myself_ && myself_->importing_slot == slot && + (conn->IsImporting() || conn->IsFlagEnabled(redis::Connection::kAsking))) { // While data migrating, the topology of the destination node has not been changed. // The destination node has to serve the requests from the migrating slot, // although the slot is not belong to itself. Therefore, we record the importing slot // and mark the importing connection to accept the importing data. - return Status::OK(); // I'm serving the importing connection + return Status::OK(); // I'm serving the importing connection or asking connection } if (myself_ && imported_slots_.count(slot)) { @@ -802,10 +863,44 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons } if (myself_ && myself_->role == kClusterSlave && !(attributes->flags & redis::kCmdWrite) && - nodes_.find(myself_->master_id) != nodes_.end() && nodes_[myself_->master_id] == slots_nodes_[slot]) { + nodes_.find(myself_->master_id) != nodes_.end() && nodes_[myself_->master_id] == slots_nodes_[slot] && + conn->IsFlagEnabled(redis::Connection::kReadOnly)) { return Status::OK(); // My master is serving this slot } - return {Status::RedisExecErr, - fmt::format("MOVED {} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; + return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; +} + +// Only HARD mode is meaningful to the Kvrocks cluster, +// so it will force clearing all information after resetting. 
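Worth highlighting in CanExecByMySelf() above: redirections are now typed statuses (Status::RedisMoved, RedisTryAgain, RedisClusterDown, RedisCrossSlot) carrying bare messages, instead of strings with the "MOVED ..."/"TRYAGAIN ..." prefixes baked in at every call site, so the reply layer can attach the RESP prefix in one place. A rough sketch of that idea; the enum and table here are illustrative, not kvrocks' actual definitions:

    #include <string>
    #include <unordered_map>

    enum class StatusCode { NotOK, RedisMoved, RedisTryAgain, RedisClusterDown, RedisCrossSlot };

    // Hypothetical prefix table; the real mapping lives in kvrocks' reply helpers.
    std::string ToRespError(StatusCode code, const std::string &msg) {
      static const std::unordered_map<StatusCode, std::string> kPrefixes = {
          {StatusCode::RedisMoved, "MOVED"},
          {StatusCode::RedisTryAgain, "TRYAGAIN"},
          {StatusCode::RedisClusterDown, "CLUSTERDOWN"},
          {StatusCode::RedisCrossSlot, "CROSSSLOT"},
      };
      auto it = kPrefixes.find(code);
      const std::string prefix = it == kPrefixes.end() ? "ERR" : it->second;
      return "-" + prefix + " " + msg + "\r\n";
    }

The new CLUSTER RESET implementation below follows the same convention and only returns plain statuses: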
+Status Cluster::Reset() { + if (srv_->slot_migrator && srv_->slot_migrator->GetMigratingSlot() != -1) { + return {Status::NotOK, "Can't reset cluster while migrating slot"}; + } + if (srv_->slot_import && srv_->slot_import->GetSlot() != -1) { + return {Status::NotOK, "Can't reset cluster while importing slot"}; + } + if (!srv_->storage->IsEmptyDB()) { + return {Status::NotOK, "Can't reset cluster while database is not empty"}; + } + if (srv_->IsSlave()) { + auto s = srv_->RemoveMaster(); + if (!s.IsOK()) return s; + } + + version_ = -1; + size_ = 0; + myid_.clear(); + myself_.reset(); + + nodes_.clear(); + for (auto &n : slots_nodes_) { + n = nullptr; + } + migrated_slots_.clear(); + imported_slots_.clear(); + + // unlink the cluster nodes file if exists + unlink(srv_->GetConfig()->NodesFilePath().data()); + return Status::OK(); } diff --git a/src/cluster/cluster.h b/src/cluster/cluster.h index 8b8132ab0db..335d5ef16b1 100644 --- a/src/cluster/cluster.h +++ b/src/cluster/cluster.h @@ -39,7 +39,7 @@ class ClusterNode { public: explicit ClusterNode(std::string id, std::string host, int port, int role, std::string master_id, - std::bitset slots); + const std::bitset &slots); std::string id; std::string host; int port; @@ -71,6 +71,7 @@ class Cluster { explicit Cluster(Server *srv, std::vector binds, int port); Status SetClusterNodes(const std::string &nodes_str, int64_t version, bool force); Status GetClusterNodes(std::string *nodes_str); + StatusOr GetReplicas(const std::string &node_id); Status SetNodeId(const std::string &node_id); Status SetSlotRanges(const std::vector &slot_ranges, const std::string &node_id, int64_t version); Status SetSlotMigrated(int slot, const std::string &ip_port); @@ -80,7 +81,7 @@ class Cluster { int64_t GetVersion() const { return version_; } static bool IsValidSlot(int slot) { return slot >= 0 && slot < kClusterSlots; } bool IsNotMaster(); - bool IsWriteForbiddenSlot(int slot); + bool IsWriteForbiddenSlot(int slot) const; Status CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector &cmd_tokens, redis::Connection *conn); Status SetMasterSlaveRepl(); @@ -89,21 +90,23 @@ class Cluster { std::string GetMyId() const { return myid_; } Status DumpClusterNodes(const std::string &file); Status LoadClusterNodes(const std::string &file_path); + Status Reset(); static bool SubCommandIsExecExclusive(const std::string &subcommand); private: + std::string getNodeIDBySlot(int slot) const; std::string genNodesDescription(); - std::string genNodesInfo(); - std::map getClusterNodeSlots() const; + std::string genNodesInfo() const; + std::map> getClusterNodeSlots() const; SlotInfo genSlotNodeInfo(int start, int end, const std::shared_ptr &n); static Status parseClusterNodes(const std::string &nodes_str, ClusterNodes *nodes, std::unordered_map *slots_nodes); Server *srv_; std::vector binds_; int port_; - int size_; - int64_t version_; + int size_ = 0; + int64_t version_ = -1; std::string myid_; std::shared_ptr myself_; ClusterNodes nodes_; diff --git a/src/cluster/cluster_defs.h b/src/cluster/cluster_defs.h index 6638db8807e..0178ac75884 100644 --- a/src/cluster/cluster_defs.h +++ b/src/cluster/cluster_defs.h @@ -36,7 +36,7 @@ inline constexpr const char *errSlotOutOfRange = "Slot is out of range"; inline constexpr const char *errInvalidClusterVersion = "Invalid cluster version"; inline constexpr const char *errSlotOverlapped = "Slot distribution is overlapped"; inline constexpr const char *errNoMasterNode = "The node isn't a master"; -inline constexpr const 
char *errClusterNoInitialized = "CLUSTERDOWN The cluster is not initialized"; +inline constexpr const char *errClusterNoInitialized = "The cluster is not initialized"; inline constexpr const char *errInvalidClusterNodeInfo = "Invalid cluster nodes info"; inline constexpr const char *errInvalidImportState = "Invalid import state"; diff --git a/src/cluster/redis_slot.cc b/src/cluster/redis_slot.cc index 5934fd2d601..991b5d863e7 100644 --- a/src/cluster/redis_slot.cc +++ b/src/cluster/redis_slot.cc @@ -20,8 +20,6 @@ #include "redis_slot.h" -#include - #include #include #include diff --git a/src/cluster/replication.cc b/src/cluster/replication.cc index 4df05a47683..3de51a94047 100644 --- a/src/cluster/replication.cc +++ b/src/cluster/replication.cc @@ -33,6 +33,7 @@ #include #include +#include "commands/error_constants.h" #include "event_util.h" #include "fmt/format.h" #include "io_util.h" @@ -201,7 +202,7 @@ void ReplicationThread::CallbacksStateMachine::ReadWriteCB(bufferevent *bev) { assert(handler_idx_ <= handlers_.size()); DLOG(INFO) << "[replication] Execute handler[" << getHandlerName(handler_idx_) << "]"; auto st = getHandlerFunc(handler_idx_)(repl_, bev); - repl_->last_io_time_.store(util::GetTimeStamp(), std::memory_order_relaxed); + repl_->last_io_time_secs_.store(util::GetTimeStamp(), std::memory_order_relaxed); switch (st) { case CBState::NEXT: ++handler_idx_; @@ -402,13 +403,13 @@ ReplicationThread::CBState ReplicationThread::authWriteCB(bufferevent *bev) { return CBState::NEXT; } -inline bool ResponseLineIsOK(const char *line) { return strncmp(line, "+OK", 3) == 0; } +inline bool ResponseLineIsOK(std::string_view line) { return line == RESP_PREFIX_SIMPLE_STRING "OK"; } ReplicationThread::CBState ReplicationThread::authReadCB(bufferevent *bev) { // NOLINT auto input = bufferevent_get_input(bev); UniqueEvbufReadln line(input, EVBUFFER_EOL_CRLF_STRICT); if (!line) return CBState::AGAIN; - if (!ResponseLineIsOK(line.get())) { + if (!ResponseLineIsOK(line.View())) { // Auth failed LOG(ERROR) << "[replication] Auth failed: " << line.get(); return CBState::RESTART; @@ -430,7 +431,7 @@ ReplicationThread::CBState ReplicationThread::checkDBNameReadCB(bufferevent *bev if (!line) return CBState::AGAIN; if (line[0] == '-') { - if (isRestoringError(line.get())) { + if (isRestoringError(line.View())) { LOG(WARNING) << "The master was restoring the db, retry later"; } else { LOG(ERROR) << "Failed to get the db name, " << line.get(); @@ -468,18 +469,18 @@ ReplicationThread::CBState ReplicationThread::replConfReadCB(bufferevent *bev) { if (!line) return CBState::AGAIN; // on unknown option: first try without announce ip, if it fails again - do nothing (to prevent infinite loop) - if (isUnknownOption(line.get()) && !next_try_without_announce_ip_address_) { + if (isUnknownOption(line.View()) && !next_try_without_announce_ip_address_) { next_try_without_announce_ip_address_ = true; LOG(WARNING) << "The old version master, can't handle ip-address, " << "try without it again"; // Retry previous state, i.e. 
send replconf again return CBState::PREV; } - if (line[0] == '-' && isRestoringError(line.get())) { + if (line[0] == '-' && isRestoringError(line.View())) { LOG(WARNING) << "The master was restoring the db, retry later"; return CBState::RESTART; } - if (!ResponseLineIsOK(line.get())) { + if (!ResponseLineIsOK(line.View())) { LOG(WARNING) << "[replication] Failed to replconf: " << line.get() + 1; // backward compatible with old version that doesn't support replconf cmd return CBState::NEXT; @@ -530,12 +531,12 @@ ReplicationThread::CBState ReplicationThread::tryPSyncReadCB(bufferevent *bev) { UniqueEvbufReadln line(input, EVBUFFER_EOL_CRLF_STRICT); if (!line) return CBState::AGAIN; - if (line[0] == '-' && isRestoringError(line.get())) { + if (line[0] == '-' && isRestoringError(line.View())) { LOG(WARNING) << "The master was restoring the db, retry later"; return CBState::RESTART; } - if (line[0] == '-' && isWrongPsyncNum(line.get())) { + if (line[0] == '-' && isWrongPsyncNum(line.View())) { next_try_old_psync_ = true; LOG(WARNING) << "The old version master, can't handle new PSYNC, " << "try old PSYNC again"; @@ -543,7 +544,7 @@ ReplicationThread::CBState ReplicationThread::tryPSyncReadCB(bufferevent *bev) { return CBState::PREV; } - if (!ResponseLineIsOK(line.get())) { + if (!ResponseLineIsOK(line.View())) { // PSYNC isn't OK, we should use FullSync // Switch to fullsync state machine fullsync_steps_.Start(); @@ -844,7 +845,7 @@ Status ReplicationThread::sendAuth(int sock_fd, ssl_st *ssl) { } UniqueEvbufReadln line(evbuf.get(), EVBUFFER_EOL_CRLF_STRICT); if (!line) continue; - if (!ResponseLineIsOK(line.get())) { + if (!ResponseLineIsOK(line.View())) { return {Status::NotOK, "auth got invalid response"}; } break; @@ -998,30 +999,36 @@ Status ReplicationThread::parseWriteBatch(const std::string &batch_string) { return Status::OK(); } -bool ReplicationThread::isRestoringError(const char *err) { - return std::string(err) == "-ERR restoring the db from backup"; +bool ReplicationThread::isRestoringError(std::string_view err) { + // err doesn't contain the CRLF, so cannot use redis::Error here. + return err == RESP_PREFIX_ERROR + redis::StatusToRedisErrorMsg({Status::RedisLoading, redis::errRestoringBackup}); } -bool ReplicationThread::isWrongPsyncNum(const char *err) { - return std::string(err) == "-ERR wrong number of arguments"; +bool ReplicationThread::isWrongPsyncNum(std::string_view err) { + // err doesn't contain the CRLF, so cannot use redis::Error here. + return err == RESP_PREFIX_ERROR + redis::StatusToRedisErrorMsg({Status::NotOK, redis::errWrongNumArguments}); } -bool ReplicationThread::isUnknownOption(const char *err) { return std::string(err) == "-ERR unknown option"; } +bool ReplicationThread::isUnknownOption(std::string_view err) { + // err doesn't contain the CRLF, so cannot use redis::Error here. 
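These replication helpers can use exact equality because UniqueEvbufReadln reads with EVBUFFER_EOL_CRLF_STRICT and strips the CRLF, so a whole-line std::string_view comparison replaces the old strncmp/std::string prefix checks. A self-contained sketch using the literal strings visible in the removed lines:

    #include <string_view>

    // Lines arrive without the trailing CRLF, so exact equality suffices.
    inline bool ResponseLineIsOK(std::string_view line) { return line == "+OK"; }

    inline bool IsRestoringError(std::string_view line) { return line == "-ERR restoring the db from backup"; }

    inline bool IsWrongPsyncNum(std::string_view line) { return line == "-ERR wrong number of arguments"; }

The patch goes a step further and builds those expected strings from the shared error constants, as the return statement below shows: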
+ return err == RESP_PREFIX_ERROR + redis::StatusToRedisErrorMsg({Status::NotOK, redis::errUnknownOption}); +} rocksdb::Status WriteBatchHandler::PutCF(uint32_t column_family_id, const rocksdb::Slice &key, const rocksdb::Slice &value) { type_ = kBatchTypeNone; - if (column_family_id == kColumnFamilyIDPubSub) { + if (column_family_id == static_cast(ColumnFamilyID::PubSub)) { type_ = kBatchTypePublish; kv_ = std::make_pair(key.ToString(), value.ToString()); return rocksdb::Status::OK(); - } else if (column_family_id == kColumnFamilyIDPropagate) { + } else if (column_family_id == static_cast(ColumnFamilyID::Propagate)) { type_ = kBatchTypePropagate; kv_ = std::make_pair(key.ToString(), value.ToString()); return rocksdb::Status::OK(); - } else if (column_family_id == kColumnFamilyIDStream) { + } else if (column_family_id == static_cast(ColumnFamilyID::Stream)) { type_ = kBatchTypeStream; kv_ = std::make_pair(key.ToString(), value.ToString()); + return rocksdb::Status::OK(); } return rocksdb::Status::OK(); } diff --git a/src/cluster/replication.h b/src/cluster/replication.h index b7f49717cc1..8da25713920 100644 --- a/src/cluster/replication.h +++ b/src/cluster/replication.h @@ -98,7 +98,7 @@ class ReplicationThread : private EventCallbackBase { Status Start(std::function &&pre_fullsync_cb, std::function &&post_fullsync_cb); void Stop(); ReplState State() { return repl_state_.load(std::memory_order_relaxed); } - time_t LastIOTime() { return last_io_time_.load(std::memory_order_relaxed); } + int64_t LastIOTimeSecs() const { return last_io_time_secs_.load(std::memory_order_relaxed); } void TimerCB(int, int16_t); @@ -155,7 +155,7 @@ class ReplicationThread : private EventCallbackBase { Server *srv_ = nullptr; engine::Storage *storage_ = nullptr; std::atomic repl_state_; - std::atomic last_io_time_ = 0; + std::atomic last_io_time_secs_ = 0; bool next_try_old_psync_ = false; bool next_try_without_announce_ip_address_ = false; @@ -204,9 +204,9 @@ class ReplicationThread : private EventCallbackBase { Status fetchFiles(int sock_fd, const std::string &dir, const std::vector &files, const std::vector &crcs, const FetchFileCallback &fn, ssl_st *ssl); Status parallelFetchFile(const std::string &dir, const std::vector> &files); - static bool isRestoringError(const char *err); - static bool isWrongPsyncNum(const char *err); - static bool isUnknownOption(const char *err); + static bool isRestoringError(std::string_view err); + static bool isWrongPsyncNum(std::string_view err); + static bool isUnknownOption(std::string_view err); Status parseWriteBatch(const std::string &batch_string); }; diff --git a/src/cluster/slot_import.cc b/src/cluster/slot_import.cc index 6696361171d..4306e336757 100644 --- a/src/cluster/slot_import.cc +++ b/src/cluster/slot_import.cc @@ -21,82 +21,69 @@ #include "slot_import.h" SlotImport::SlotImport(Server *srv) - : Database(srv->storage, kDefaultNamespace), - srv_(srv), - import_slot_(-1), - import_status_(kImportNone), - import_fd_(-1) { + : Database(srv->storage, kDefaultNamespace), srv_(srv), import_slot_(-1), import_status_(kImportNone) { std::lock_guard guard(mutex_); // Let metadata_cf_handle_ be nullptr, then get them in real time while use them. // See comments in SlotMigrator::SlotMigrator for detailed reason. 
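Stepping back to WriteBatchHandler::PutCF above: the kColumnFamilyID* integer constants give way to the scoped ColumnFamilyID enum, at the cost of an explicit cast per comparison. The same dispatch reads naturally as a switch; the enumerator names follow the patch, but the numeric values here are placeholders:

    #include <cstdint>

    // Names as in the patch; values are placeholders, not kvrocks' real IDs.
    enum class ColumnFamilyID : uint32_t { PubSub = 2, Propagate = 3, Stream = 4 };

    enum class BatchType { None, Publish, Propagate, Stream };

    BatchType ClassifyPut(uint32_t column_family_id) {
      switch (static_cast<ColumnFamilyID>(column_family_id)) {
        case ColumnFamilyID::PubSub:
          return BatchType::Publish;
        case ColumnFamilyID::Propagate:
          return BatchType::Propagate;
        case ColumnFamilyID::Stream:
          return BatchType::Stream;
        default:
          return BatchType::None;
      }
    }

Back in slot_import.cc, the constructor deliberately leaves the metadata handle unset: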
metadata_cf_handle_ = nullptr; } -bool SlotImport::Start(int fd, int slot) { +Status SlotImport::Start(int slot) { std::lock_guard guard(mutex_); if (import_status_ == kImportStart) { - LOG(ERROR) << "[import] Only one slot importing is allowed" - << ", current slot is " << import_slot_ << ", cannot import slot " << slot; - return false; + // return ok if the same slot is importing + if (import_slot_ == slot) { + return Status::OK(); + } + return {Status::NotOK, fmt::format("only one importing slot is allowed, current slot is: {}", import_slot_)}; } // Clean slot data first auto s = ClearKeysOfSlot(namespace_, slot); if (!s.ok()) { - LOG(INFO) << "[import] Failed to clear keys of slot " << slot << "current status is importing 'START'" - << ", Err: " << s.ToString(); - return false; + return {Status::NotOK, fmt::format("clear keys of slot error: {}", s.ToString())}; } import_status_ = kImportStart; import_slot_ = slot; - import_fd_ = fd; - - return true; + return Status::OK(); } -bool SlotImport::Success(int slot) { +Status SlotImport::Success(int slot) { std::lock_guard guard(mutex_); if (import_slot_ != slot) { - LOG(ERROR) << "[import] Wrong slot, importing slot: " << import_slot_ << ", but got slot: " << slot; - return false; + return {Status::NotOK, fmt::format("mismatch slot, importing slot: {}, but got: {}", import_slot_, slot)}; } Status s = srv_->cluster->SetSlotImported(import_slot_); if (!s.IsOK()) { - LOG(ERROR) << "[import] Failed to set slot, Err: " << s.Msg(); - return false; + return {Status::NotOK, fmt::format("unable to set imported status: {}", slot)}; } import_status_ = kImportSuccess; - import_fd_ = -1; - - return true; + return Status::OK(); } -bool SlotImport::Fail(int slot) { +Status SlotImport::Fail(int slot) { std::lock_guard guard(mutex_); if (import_slot_ != slot) { - LOG(ERROR) << "[import] Wrong slot, importing slot: " << import_slot_ << ", but got slot: " << slot; - return false; + return {Status::NotOK, fmt::format("mismatch slot, importing slot: {}, but got: {}", import_slot_, slot)}; } // Clean imported slot data auto s = ClearKeysOfSlot(namespace_, slot); if (!s.ok()) { - LOG(INFO) << "[import] Failed to clear keys of slot " << slot << ", current importing status is importing 'FAIL'" - << ", Err: " << s.ToString(); + return {Status::NotOK, fmt::format("clear keys of slot error: {}", s.ToString())}; } import_status_ = kImportFailed; - import_fd_ = -1; - - return true; + return Status::OK(); } -void SlotImport::StopForLinkError(int fd) { +Status SlotImport::StopForLinkError() { std::lock_guard guard(mutex_); - if (import_status_ != kImportStart) return; + // We don't need to do anything if the importer is not started yet. 
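The SlotImport methods above trade bool returns for Status, so a caller such as Cluster::ImportSlot can propagate the importer's message instead of logging and inventing its own. Note the idempotence: starting the slot that is already being imported is not an error. A compact sketch of that contract, with a minimal Status stand-in (not kvrocks' real class):

    #include <string>
    #include <utility>

    // Minimal stand-in for kvrocks' Status, for illustration only.
    struct Status {
      bool ok = true;
      std::string msg;
      static Status OK() { return {}; }
      static Status Error(std::string m) { return {false, std::move(m)}; }
    };

    struct Importer {
      int importing_slot = -1;
      Status Start(int slot) {
        if (importing_slot == slot) return Status::OK();  // restarting the same slot is fine
        if (importing_slot != -1) return Status::Error("only one importing slot is allowed");
        importing_slot = slot;
        return Status::OK();
      }
    };

StopForLinkError applies the same leniency below: stopping an importer that never started is simply OK: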
+ if (import_status_ != kImportStart) return Status::OK(); // Maybe server has failovered // Situation: @@ -111,18 +98,20 @@ void SlotImport::StopForLinkError(int fd) { // Clean imported slot data auto s = ClearKeysOfSlot(namespace_, import_slot_); if (!s.ok()) { - LOG(WARNING) << "[import] Failed to clear keys of slot " << import_slot_ << " Current status is link error" - << ", Err: " << s.ToString(); + return {Status::NotOK, fmt::format("clear keys of slot error: {}", s.ToString())}; } } - LOG(INFO) << "[import] Stop importing for link error, slot: " << import_slot_; import_status_ = kImportFailed; - import_fd_ = -1; + return Status::OK(); } int SlotImport::GetSlot() { std::lock_guard guard(mutex_); + // import_slot_ only be set when import_status_ is kImportStart + if (import_status_ != kImportStart) { + return -1; + } return import_slot_; } diff --git a/src/cluster/slot_import.h b/src/cluster/slot_import.h index 385aca0f6d1..9f4bb72d3de 100644 --- a/src/cluster/slot_import.h +++ b/src/cluster/slot_import.h @@ -42,10 +42,10 @@ class SlotImport : public redis::Database { explicit SlotImport(Server *srv); ~SlotImport() = default; - bool Start(int fd, int slot); - bool Success(int slot); - bool Fail(int slot); - void StopForLinkError(int fd); + Status Start(int slot); + Status Success(int slot); + Status Fail(int slot); + Status StopForLinkError(); int GetSlot(); int GetStatus(); void GetImportInfo(std::string *info); @@ -55,5 +55,4 @@ class SlotImport : public redis::Database { std::mutex mutex_; int import_slot_; int import_status_; - int import_fd_; }; diff --git a/src/cluster/slot_migrate.cc b/src/cluster/slot_migrate.cc index 884c8ace7aa..a074d536b74 100644 --- a/src/cluster/slot_migrate.cc +++ b/src/cluster/slot_migrate.cc @@ -334,15 +334,17 @@ Status SlotMigrator::sendSnapshotByCmd() { LOG(INFO) << "[migrate] Start migrating snapshot of slot " << slot; - rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - read_options.snapshot = slot_snapshot_; - rocksdb::ColumnFamilyHandle *cf_handle = storage_->GetCFHandle(engine::kMetadataColumnFamilyName); - auto iter = util::UniqueIterator(storage_->GetDB()->NewIterator(read_options, cf_handle)); - // Construct key prefix to iterate the keys belong to the target slot std::string prefix = ComposeSlotKeyPrefix(namespace_, slot); LOG(INFO) << "[migrate] Iterate keys of slot, key's prefix: " << prefix; + rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); + read_options.snapshot = slot_snapshot_; + Slice prefix_slice(prefix); + read_options.iterate_lower_bound = &prefix_slice; + rocksdb::ColumnFamilyHandle *cf_handle = storage_->GetCFHandle(ColumnFamilyID::Metadata); + auto iter = util::UniqueIterator(storage_->GetDB()->NewIterator(read_options, cf_handle)); + // Seek to the beginning of keys start with 'prefix' and iterate all these keys for (iter->Seek(prefix); iter->Valid(); iter->Next()) { // The migrating task has to be stopped, if server role is changed from master to slave @@ -738,14 +740,16 @@ Status SlotMigrator::migrateComplexKey(const rocksdb::Slice &key, const Metadata cmd = type_to_cmd[metadata.Type()]; std::vector user_cmd = {cmd, key.ToString()}; + // Construct key prefix to iterate values of the complex type user key + std::string slot_key = AppendNamespacePrefix(key); + std::string prefix_subkey = InternalKey(slot_key, "", metadata.version, true).Encode(); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); read_options.snapshot = slot_snapshot_; + Slice prefix_slice(prefix_subkey); 
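Setting iterate_lower_bound, as the migration scans in this patch now do, tells RocksDB that the iterator will never be positioned before the given key, which prunes work during Seek. The one trap is lifetime: ReadOptions stores only a pointer, so the Slice must outlive the iterator. A minimal standalone usage sketch against the stock RocksDB API:

    #include <memory>
    #include <string>

    #include <rocksdb/db.h>
    #include <rocksdb/options.h>
    #include <rocksdb/slice.h>

    // Scan every key starting with `prefix`.
    void ScanPrefix(rocksdb::DB *db, const std::string &prefix) {
      rocksdb::ReadOptions read_options;
      rocksdb::Slice lower_bound(prefix);
      read_options.iterate_lower_bound = &lower_bound;  // must outlive the iterator
      std::unique_ptr<rocksdb::Iterator> iter(db->NewIterator(read_options));
      for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix); iter->Next()) {
        // process iter->key() / iter->value()
      }
    }

The assignment just below wires the bound into the migrator's ReadOptions in exactly this way: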
+ read_options.iterate_lower_bound = &prefix_slice; // Should use th raw db iterator to avoid reading uncommitted writes in transaction mode auto iter = util::UniqueIterator(storage_->GetDB()->NewIterator(read_options)); - // Construct key prefix to iterate values of the complex type user key - std::string slot_key = AppendNamespacePrefix(key); - std::string prefix_subkey = InternalKey(slot_key, "", metadata.version, true).Encode(); int item_count = 0; for (iter->Seek(prefix_subkey); iter->Valid(); iter->Next()) { @@ -840,13 +844,15 @@ Status SlotMigrator::migrateComplexKey(const rocksdb::Slice &key, const Metadata Status SlotMigrator::migrateStream(const Slice &key, const StreamMetadata &metadata, std::string *restore_cmds) { rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); read_options.snapshot = slot_snapshot_; - // Should use th raw db iterator to avoid reading uncommitted writes in transaction mode - auto iter = util::UniqueIterator( - storage_->GetDB()->NewIterator(read_options, storage_->GetCFHandle(engine::kStreamColumnFamilyName))); - std::string ns_key = AppendNamespacePrefix(key); // Construct key prefix to iterate values of the stream std::string prefix_key = InternalKey(ns_key, "", metadata.version, true).Encode(); + rocksdb::Slice prefix_key_slice(prefix_key); + read_options.iterate_lower_bound = &prefix_key_slice; + + // Should use th raw db iterator to avoid reading uncommitted writes in transaction mode + auto iter = + util::UniqueIterator(storage_->GetDB()->NewIterator(read_options, storage_->GetCFHandle(ColumnFamilyID::Stream))); std::vector user_cmd = {type_to_cmd[metadata.Type()], key.ToString()}; @@ -1197,10 +1203,12 @@ Status SlotMigrator::sendSnapshotByRawKV() { uint64_t start_ts = util::GetTimeStampMS(); LOG(INFO) << "[migrate] Migrating snapshot of slot " << migrating_slot_ << " by raw key value"; + auto prefix = ComposeSlotKeyPrefix(namespace_, migrating_slot_); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); read_options.snapshot = slot_snapshot_; + rocksdb::Slice prefix_slice(prefix); + read_options.iterate_lower_bound = &prefix_slice; engine::DBIterator iter(storage_, read_options); - auto prefix = ComposeSlotKeyPrefix(namespace_, migrating_slot_); BatchSender batch_sender(*dst_fd_, migrate_batch_size_bytes_, migrate_batch_bytes_per_sec_); @@ -1216,7 +1224,7 @@ Status SlotMigrator::sendSnapshotByRawKV() { } batch_sender.SetPrefixLogData(log_data); - GET_OR_RET(batch_sender.Put(storage_->GetCFHandle(engine::kMetadataColumnFamilyName), iter.Key(), iter.Value())); + GET_OR_RET(batch_sender.Put(storage_->GetCFHandle(ColumnFamilyID::Metadata), iter.Key(), iter.Value())); auto subkey_iter = iter.GetSubKeyIterator(); if (!subkey_iter) { @@ -1232,7 +1240,7 @@ Status SlotMigrator::sendSnapshotByRawKV() { score_key.append(subkey_iter->UserKey().ToString()); auto score_key_bytes = InternalKey(iter.Key(), score_key, internal_key.GetVersion(), storage_->IsSlotIdEncoded()).Encode(); - GET_OR_RET(batch_sender.Put(storage_->GetCFHandle(kColumnFamilyIDZSetScore), score_key_bytes, Slice())); + GET_OR_RET(batch_sender.Put(storage_->GetCFHandle(ColumnFamilyID::SecondarySubkey), score_key_bytes, Slice())); } if (batch_sender.IsFull()) { @@ -1326,11 +1334,19 @@ Status SlotMigrator::migrateIncrementalDataByRawKV(uint64_t end_seq, BatchSender break; } case engine::WALItem::Type::kTypePut: { + if (item.column_family_id > kMaxColumnFamilyID) { + LOG(INFO) << fmt::format("[migrate] Invalid put column family id: {}", item.column_family_id); + 
continue; + } GET_OR_RET(batch_sender->Put(storage_->GetCFHandle(static_cast(item.column_family_id)), item.key, item.value)); break; } case engine::WALItem::Type::kTypeDelete: { + if (item.column_family_id > kMaxColumnFamilyID) { + LOG(INFO) << fmt::format("[migrate] Invalid delete column family id: {}", item.column_family_id); + continue; + } GET_OR_RET( batch_sender->Delete(storage_->GetCFHandle(static_cast(item.column_family_id)), item.key)); break; diff --git a/src/cluster/slot_migrate.h b/src/cluster/slot_migrate.h index e1faf404682..e22ba47d621 100644 --- a/src/cluster/slot_migrate.h +++ b/src/cluster/slot_migrate.h @@ -103,6 +103,7 @@ class SlotMigrator : public redis::Database { SlotMigrationStage GetCurrentSlotMigrationStage() const { return current_stage_; } int16_t GetForbiddenSlot() const { return forbidden_slot_; } int16_t GetMigratingSlot() const { return migrating_slot_; } + std::string GetDstNode() const { return dst_node_; } void GetMigrationInfo(std::string *info) const; void CancelSyncCtx(); diff --git a/src/cluster/sync_migrate_context.cc b/src/cluster/sync_migrate_context.cc index 3ba2806ca45..f1de89b6e26 100644 --- a/src/cluster/sync_migrate_context.cc +++ b/src/cluster/sync_migrate_context.cc @@ -68,7 +68,7 @@ void SyncMigrateContext::OnWrite(bufferevent *bev) { if (migrate_result_) { conn_->Reply(redis::SimpleString("OK")); } else { - conn_->Reply(redis::Error("ERR " + migrate_result_.Msg())); + conn_->Reply(redis::Error(migrate_result_)); } timer_.reset(); diff --git a/src/commands/cmd_bit.cc b/src/commands/cmd_bit.cc index 088e0add853..541bc38f691 100644 --- a/src/commands/cmd_bit.cc +++ b/src/commands/cmd_bit.cc @@ -22,6 +22,7 @@ #include "commands/command_parser.h" #include "error_constants.h" #include "server/server.h" +#include "status.h" #include "types/redis_bitmap.h" namespace redis { @@ -171,6 +172,10 @@ class CommandBitPos : public Commander { stop_ = *parse_stop; } + if (args.size() >= 6 && util::EqualICase(args[5], "BIT")) { + is_bit_index_ = true; + } + auto parse_arg = ParseInt(args[2], 10); if (!parse_arg) { return {Status::RedisParseErr, errValueNotInteger}; @@ -189,7 +194,7 @@ class CommandBitPos : public Commander { Status Execute(Server *srv, Connection *conn, std::string *output) override { int64_t pos = 0; redis::Bitmap bitmap_db(srv->storage, conn->GetNamespace()); - auto s = bitmap_db.BitPos(args_[1], bit_, start_, stop_, stop_given_, &pos); + auto s = bitmap_db.BitPos(args_[1], bit_, start_, stop_, stop_given_, &pos, is_bit_index_); if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; *output = redis::Integer(pos); @@ -201,6 +206,7 @@ class CommandBitPos : public Commander { int64_t stop_ = -1; bool bit_ = false; bool stop_given_ = false; + bool is_bit_index_ = false; }; class CommandBitOp : public Commander { diff --git a/src/commands/cmd_bloom_filter.cc b/src/commands/cmd_bloom_filter.cc index f33979e3d50..bb2df94e2f5 100644 --- a/src/commands/cmd_bloom_filter.cc +++ b/src/commands/cmd_bloom_filter.cc @@ -33,7 +33,7 @@ constexpr const char *errInvalidErrorRate = "error rate should be between 0 and constexpr const char *errInvalidCapacity = "capacity should be larger than 0"; constexpr const char *errInvalidExpansion = "expansion should be greater or equal to 1"; constexpr const char *errNonscalingButExpand = "nonscaling filters cannot expand"; -constexpr const char *errFilterFull = "ERR nonscaling filter is full"; +constexpr const char *errFilterFull = "nonscaling filter is full"; } // namespace namespace redis { @@ -119,7 +119,7 @@ 
class CommandBFAdd : public Commander { *output = redis::Integer(0); break; case BloomFilterAddResult::kFull: - *output = redis::Error(errFilterFull); + *output = redis::Error({Status::NotOK, errFilterFull}); break; } return Status::OK(); @@ -152,7 +152,7 @@ class CommandBFMAdd : public Commander { *output += redis::Integer(0); break; case BloomFilterAddResult::kFull: - *output += redis::Error(errFilterFull); + *output += redis::Error({Status::NotOK, errFilterFull}); break; } } @@ -248,7 +248,7 @@ class CommandBFInsert : public Commander { *output += redis::Integer(0); break; case BloomFilterAddResult::kFull: - *output += redis::Error(errFilterFull); + *output += redis::Error({Status::NotOK, errFilterFull}); break; } } diff --git a/src/commands/cmd_cluster.cc b/src/commands/cmd_cluster.cc index 96ecfbde28b..0f9b17603fc 100644 --- a/src/commands/cmd_cluster.cc +++ b/src/commands/cmd_cluster.cc @@ -23,6 +23,7 @@ #include "cluster/sync_migrate_context.h" #include "commander.h" #include "error_constants.h" +#include "status.h" namespace redis { @@ -34,6 +35,14 @@ class CommandCluster : public Commander { if (args.size() == 2 && (subcommand_ == "nodes" || subcommand_ == "slots" || subcommand_ == "info")) return Status::OK(); + // CLUSTER RESET [HARD|SOFT] + if (subcommand_ == "reset" && (args_.size() == 2 || args_.size() == 3)) { + if (args_.size() == 3 && !util::EqualICase(args_[2], "hard") && !util::EqualICase(args_[2], "soft")) { + return {Status::RedisParseErr, errInvalidSyntax}; + } + return Status::OK(); + } + if (subcommand_ == "keyslot" && args_.size() == 3) return Status::OK(); if (subcommand_ == "import") { @@ -47,7 +56,9 @@ class CommandCluster : public Commander { return Status::OK(); } - return {Status::RedisParseErr, "CLUSTER command, CLUSTER INFO|NODES|SLOTS|KEYSLOT"}; + if (subcommand_ == "replicas" && args_.size() == 3) return Status::OK(); + + return {Status::RedisParseErr, "CLUSTER command, CLUSTER INFO|NODES|SLOTS|KEYSLOT|RESET|REPLICAS"}; } Status Execute(Server *srv, Connection *conn, std::string *output) override { @@ -79,7 +90,7 @@ class CommandCluster : public Commander { } } } else { - return {Status::RedisExecErr, s.Msg()}; + return s; } } else if (subcommand_ == "nodes") { std::string nodes_desc; @@ -87,7 +98,7 @@ class CommandCluster : public Commander { if (s.IsOK()) { *output = conn->VerbatimString("txt", nodes_desc); } else { - return {Status::RedisExecErr, s.Msg()}; + return s; } } else if (subcommand_ == "info") { std::string cluster_info; @@ -95,14 +106,29 @@ class CommandCluster : public Commander { if (s.IsOK()) { *output = conn->VerbatimString("txt", cluster_info); } else { - return {Status::RedisExecErr, s.Msg()}; + return s; } } else if (subcommand_ == "import") { Status s = srv->cluster->ImportSlot(conn, static_cast(slot_), state_); if (s.IsOK()) { *output = redis::SimpleString("OK"); } else { - return {Status::RedisExecErr, s.Msg()}; + return s; + } + } else if (subcommand_ == "reset") { + Status s = srv->cluster->Reset(); + if (s.IsOK()) { + *output = redis::SimpleString("OK"); + } else { + return s; + } + } else if (subcommand_ == "replicas") { + auto node_id = args_[2]; + StatusOr s = srv->cluster->GetReplicas(node_id); + if (s.IsOK()) { + *output = conn->VerbatimString("txt", s.GetValue()); + } else { + return s; } } else { return {Status::RedisExecErr, "Invalid cluster command options"}; @@ -121,7 +147,7 @@ class CommandClusterX : public Commander { Status Parse(const std::vector &args) override { subcommand_ = util::ToLower(args[1]); - if 
(args.size() == 2 && (subcommand_ == "version")) return Status::OK(); + if (args.size() == 2 && (subcommand_ == "version" || subcommand_ == "myid")) return Status::OK(); if (subcommand_ == "setnodeid" && args_.size() == 3 && args_[2].size() == kClusterNodeIdLen) return Status::OK(); @@ -207,7 +233,7 @@ class CommandClusterX : public Commander { return Status::OK(); } - return {Status::RedisParseErr, "CLUSTERX command, CLUSTERX VERSION|SETNODEID|SETNODES|SETSLOT|MIGRATE"}; + return {Status::RedisParseErr, "CLUSTERX command, CLUSTERX VERSION|MYID|SETNODEID|SETNODES|SETSLOT|MIGRATE"}; } Status Execute(Server *srv, Connection *conn, std::string *output) override { @@ -226,7 +252,7 @@ class CommandClusterX : public Commander { need_persist_nodes_info = true; *output = redis::SimpleString("OK"); } else { - return {Status::RedisExecErr, s.Msg()}; + return s; } } else if (subcommand_ == "setnodeid") { Status s = srv->cluster->SetNodeId(args_[2]); @@ -234,7 +260,7 @@ class CommandClusterX : public Commander { need_persist_nodes_info = true; *output = redis::SimpleString("OK"); } else { - return {Status::RedisExecErr, s.Msg()}; + return s; } } else if (subcommand_ == "setslot") { Status s = srv->cluster->SetSlotRanges(slot_ranges_, args_[4], set_version_); @@ -242,11 +268,13 @@ class CommandClusterX : public Commander { need_persist_nodes_info = true; *output = redis::SimpleString("OK"); } else { - return {Status::RedisExecErr, s.Msg()}; + return s; } } else if (subcommand_ == "version") { int64_t v = srv->cluster->GetVersion(); *output = redis::BulkString(std::to_string(v)); + } else if (subcommand_ == "myid") { + *output = redis::BulkString(srv->cluster->GetMyId()); } else if (subcommand_ == "migrate") { if (sync_migrate_) { sync_migrate_ctx_ = std::make_unique(srv, conn, sync_migrate_timeout_); @@ -259,7 +287,7 @@ class CommandClusterX : public Commander { } *output = redis::SimpleString("OK"); } else { - return {Status::RedisExecErr, s.Msg()}; + return s; } } else { return {Status::RedisExecErr, "Invalid cluster command options"}; @@ -284,16 +312,45 @@ class CommandClusterX : public Commander { std::unique_ptr sync_migrate_ctx_ = nullptr; }; -static uint64_t GenerateClusterFlag(const std::vector &args) { +static uint64_t GenerateClusterFlag(uint64_t flags, const std::vector &args) { if (args.size() >= 2 && Cluster::SubCommandIsExecExclusive(args[1])) { - return kCmdExclusive; + return flags | kCmdExclusive; } - return 0; + return flags; } +class CommandReadOnly : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + *output = redis::SimpleString("OK"); + conn->EnableFlag(redis::Connection::kReadOnly); + return Status::OK(); + } +}; + +class CommandReadWrite : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + *output = redis::SimpleString("OK"); + conn->DisableFlag(redis::Connection::kReadOnly); + return Status::OK(); + } +}; + +class CommandAsking : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + conn->EnableFlag(redis::Connection::kAsking); + *output = redis::SimpleString("OK"); + return Status::OK(); + } +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("cluster", -2, "cluster no-script", 0, 0, 0, GenerateClusterFlag), - MakeCmdAttr("clusterx", -2, "cluster no-script", 0, 0, 0, - GenerateClusterFlag), ) + MakeCmdAttr("clusterx", -2, "cluster no-script", 0, 0, 0, GenerateClusterFlag), + MakeCmdAttr("readonly", 1, "cluster 
no-multi", 0, 0, 0), + MakeCmdAttr("readwrite", 1, "cluster no-multi", 0, 0, 0), + MakeCmdAttr("asking", 1, "cluster", 0, 0, 0), ) } // namespace redis diff --git a/src/commands/cmd_function.cc b/src/commands/cmd_function.cc index 2d7ce193e49..1afab33d61e 100644 --- a/src/commands/cmd_function.cc +++ b/src/commands/cmd_function.cc @@ -99,7 +99,16 @@ struct CommandFCall : Commander { CommandKeyRange GetScriptEvalKeyRange(const std::vector &args); -REDIS_REGISTER_COMMANDS(MakeCmdAttr("function", -2, "exclusive no-script", 0, 0, 0), +uint64_t GenerateFunctionFlags(uint64_t flags, const std::vector &args) { + if (util::EqualICase(args[1], "load") || util::EqualICase(args[1], "delete")) { + return flags | kCmdWrite; + } + + return flags; +} + +REDIS_REGISTER_COMMANDS(MakeCmdAttr("function", -2, "exclusive no-script", 0, 0, 0, + GenerateFunctionFlags), MakeCmdAttr>("fcall", -3, "exclusive write no-script", GetScriptEvalKeyRange), MakeCmdAttr>("fcall_ro", -3, "read-only ro-script no-script", GetScriptEvalKeyRange)); diff --git a/src/commands/cmd_hash.cc b/src/commands/cmd_hash.cc index 6db97f89025..677f131eb98 100644 --- a/src/commands/cmd_hash.cc +++ b/src/commands/cmd_hash.cc @@ -326,16 +326,11 @@ class CommandHRangeByLex : public Commander { return parser.InvalidSyntax(); } } - Status s; if (spec_.reversed) { - s = ParseRangeLexSpec(args[3], args[2], &spec_); + return ParseRangeLexSpec(args[3], args[2], &spec_); } else { - s = ParseRangeLexSpec(args[2], args[3], &spec_); + return ParseRangeLexSpec(args[2], args[3], &spec_); } - if (!s.IsOK()) { - return {Status::RedisParseErr, s.Msg()}; - } - return Status::OK(); } Status Execute(Server *srv, Connection *conn, std::string *output) override { diff --git a/src/commands/cmd_json.cc b/src/commands/cmd_json.cc index 54a28271eae..7708d2f8785 100644 --- a/src/commands/cmd_json.cc +++ b/src/commands/cmd_json.cc @@ -45,6 +45,14 @@ std::string OptionalsToString(const Connection *conn, Optionals &opts) { return str; } +std::string SizeToString(const std::vector &elems) { + std::string result = MultiLen(elems.size()); + for (const auto &elem : elems) { + result += redis::Integer(elem); + } + return result; +} + class CommandJsonSet : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { @@ -383,8 +391,12 @@ class CommandJsonObjLen : public Commander { Optionals results; auto s = json.ObjLen(args_[1], path, &results); if (s.IsNotFound()) { - *output = conn->NilString(); - return Status::OK(); + if (args_.size() == 2) { + *output = conn->NilString(); + return Status::OK(); + } else { + return {Status::RedisExecErr, "Path '" + args_[2] + "' does not exist or not an object"}; + } } if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; @@ -411,8 +423,7 @@ class CommandJsonArrTrim : public Commander { auto s = json.ArrTrim(args_[1], path_, start_, stop_, &results); if (s.IsNotFound()) { - *output = conn->NilString(); - return Status::OK(); + return {Status::RedisExecErr, "could not perform this operation on a key that doesn't exist"}; } if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; @@ -559,6 +570,15 @@ class CommandJsonStrLen : public Commander { Optionals results; auto s = json.StrLen(args_[1], path, &results); + if (s.IsNotFound()) { + if (args_.size() == 2) { + *output = conn->NilString(); + return Status::OK(); + } else { + return {Status::RedisExecErr, "could not perform this operation on a key that doesn't exist"}; + } + } + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; *output 
= OptionalsToString(conn, results); @@ -594,6 +614,65 @@ class CommandJsonMGet : public Commander { } }; +class CommandJsonMSet : public Commander { + public: + Status Execute(Server *svr, Connection *conn, std::string *output) override { + if ((args_.size() - 1) % 3 != 0) { + return {Status::RedisExecErr, errWrongNumOfArguments}; + } + + redis::Json json(svr->storage, conn->GetNamespace()); + + std::vector user_keys; + std::vector paths; + std::vector values; + for (size_t i = 0; i < (args_.size() - 1) / 3; i++) { + user_keys.emplace_back(args_[i * 3 + 1]); + paths.emplace_back(args_[i * 3 + 2]); + values.emplace_back(args_[i * 3 + 3]); + } + + if (auto s = json.MSet(user_keys, paths, values); !s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = redis::SimpleString("OK"); + return Status::OK(); + } +}; + +class CommandJsonDebug : public Commander { + public: + Status Execute(Server *svr, Connection *conn, std::string *output) override { + redis::Json json(svr->storage, conn->GetNamespace()); + + std::string path = "$"; + + if (!util::EqualICase(args_[1], "memory")) { + return {Status::RedisExecErr, "ERR wrong number of arguments for 'json.debug' command"}; + } + + if (args_.size() == 4) { + path = args_[3]; + } else if (args_.size() > 4) { + return {Status::RedisExecErr, "The number of arguments is more than expected"}; + } + + std::vector results; + auto s = json.DebugMemory(args_[2], path, &results); + + if (s.IsNotFound()) { + if (args_.size() == 3) { + *output = redis::Integer(0); + } else { + *output = SizeToString(results); + } + return Status::OK(); + } + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = SizeToString(results); + return Status::OK(); + } +}; REDIS_REGISTER_COMMANDS(MakeCmdAttr("json.set", 4, "write", 1, 1, 1), MakeCmdAttr("json.get", -2, "read-only", 1, 1, 1), MakeCmdAttr("json.info", 2, "read-only", 1, 1, 1), @@ -616,6 +695,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("json.set", 4, "write", 1, 1 MakeCmdAttr("json.objlen", -2, "read-only", 1, 1, 1), MakeCmdAttr("json.strappend", -3, "write", 1, 1, 1), MakeCmdAttr("json.strlen", -2, "read-only", 1, 1, 1), - MakeCmdAttr("json.mget", -3, "read-only", 1, 1, 1), ); + MakeCmdAttr("json.mget", -3, "read-only", 1, -2, 1), + MakeCmdAttr("json.mset", -4, "write", 1, -3, 3), + MakeCmdAttr("json.debug", -3, "read-only", 2, 2, 1)); } // namespace redis diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index 2eacdd1ef74..d7219841a80 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -64,6 +64,40 @@ class CommandMove : public Commander { } }; +class CommandMoveX : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + std::string &key = args_[1], &token = args_[2]; + + redis::Database redis(srv->storage, conn->GetNamespace()); + + std::string ns; + AuthResult auth_result = srv->AuthenticateUser(token, &ns); + switch (auth_result) { + case AuthResult::NO_REQUIRE_PASS: + return {Status::NotOK, "Forbidden to move key when requirepass is empty"}; + case AuthResult::INVALID_PASSWORD: + return {Status::NotOK, "Invalid password"}; + case AuthResult::IS_USER: + case AuthResult::IS_ADMIN: + break; + } + + Database::CopyResult res = Database::CopyResult::DONE; + std::string ns_key = redis.AppendNamespacePrefix(key); + std::string new_ns_key = ComposeNamespaceKey(ns, key, srv->storage->IsSlotIdEncoded()); + auto s = redis.Copy(ns_key, new_ns_key, true, true, &res); + if (!s.ok()) return {Status::RedisExecErr, 
s.ToString()}; + + if (res == Database::CopyResult::DONE) { + *output = redis::Integer(1); + } else { + *output = redis::Integer(0); + } + return Status::OK(); + } +}; + class CommandObject : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { @@ -187,7 +221,7 @@ class CommandExpireAt : public Commander { timestamp_ = *parse_result; - return Commander::Parse(args); + return Status::OK(); } Status Execute(Server *srv, Connection *conn, std::string *output) override { @@ -311,11 +345,12 @@ class CommandRename : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::Database redis(srv->storage, conn->GetNamespace()); - bool ret = true; - - auto s = redis.Rename(args_[1], args_[2], false, &ret); + Database::CopyResult res = Database::CopyResult::DONE; + std::string ns_key = redis.AppendNamespacePrefix(args_[1]); + std::string new_ns_key = redis.AppendNamespacePrefix(args_[2]); + auto s = redis.Copy(ns_key, new_ns_key, false, true, &res); if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; - + if (res == Database::CopyResult::KEY_NOT_EXIST) return {Status::RedisExecErr, "no such key"}; *output = redis::SimpleString("OK"); return Status::OK(); } @@ -325,22 +360,185 @@ class CommandRenameNX : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::Database redis(srv->storage, conn->GetNamespace()); - bool ret = true; - auto s = redis.Rename(args_[1], args_[2], true, &ret); + Database::CopyResult res = Database::CopyResult::DONE; + std::string ns_key = redis.AppendNamespacePrefix(args_[1]); + std::string new_ns_key = redis.AppendNamespacePrefix(args_[2]); + auto s = redis.Copy(ns_key, new_ns_key, true, true, &res); if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; - if (ret) { - *output = redis::Integer(1); - } else { - *output = redis::Integer(0); + switch (res) { + case Database::CopyResult::KEY_NOT_EXIST: + return {Status::RedisExecErr, "no such key"}; + case Database::CopyResult::DONE: + *output = redis::Integer(1); + break; + case Database::CopyResult::KEY_ALREADY_EXIST: + *output = redis::Integer(0); + break; + } + return Status::OK(); + } +}; + +class CommandCopy : public Commander { + public: + Status Parse(const std::vector &args) override { + CommandParser parser(args, 3); + while (parser.Good()) { + if (parser.EatEqICase("db")) { + auto db_num = GET_OR_RET(parser.TakeInt()); + // There's only one database in Kvrocks, so the DB must be 0 here. 
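RENAME, RENAMENX, and the new COPY above all funnel into a single Database::Copy(ns_key, new_ns_key, nx, delete_old, &res) primitive; the commands differ only in the nx/delete_old flags and in how a CopyResult maps to a reply. A compact sketch of RENAMENX's mapping, with stand-in types (the enum names match the patch):

    #include <optional>
    #include <string>

    enum class CopyResult { DONE, KEY_NOT_EXIST, KEY_ALREADY_EXIST };

    // RENAMENX replies :1/:0; a missing source key becomes "no such key".
    std::optional<std::string> RenameNxReply(CopyResult res) {
      switch (res) {
        case CopyResult::DONE:
          return ":1\r\n";
        case CopyResult::KEY_ALREADY_EXIST:
          return ":0\r\n";
        case CopyResult::KEY_NOT_EXIST:
          return std::nullopt;  // caller raises the error
      }
      return std::nullopt;
    }

COPY also accepts a DB option purely for Redis compatibility, which is why the parser below rejects anything but database 0: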
+ if (db_num != 0) { + return {Status::RedisParseErr, errInvalidSyntax}; + } + } else if (parser.EatEqICase("replace")) { + replace_ = true; + } else { + return parser.InvalidSyntax(); + } + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::Database redis(srv->storage, conn->GetNamespace()); + Database::CopyResult res = Database::CopyResult::DONE; + std::string ns_key = redis.AppendNamespacePrefix(args_[1]); + std::string new_ns_key = redis.AppendNamespacePrefix(args_[2]); + auto s = redis.Copy(ns_key, new_ns_key, !replace_, false, &res); + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + switch (res) { + case Database::CopyResult::KEY_NOT_EXIST: + return {Status::RedisExecErr, "no such key"}; + case Database::CopyResult::DONE: + *output = redis::Integer(1); + break; + case Database::CopyResult::KEY_ALREADY_EXIST: + *output = redis::Integer(0); + break; + } + return Status::OK(); + } + + private: + bool replace_ = false; +}; + +template +class CommandSort : public Commander { + public: + Status Parse(const std::vector &args) override { + CommandParser parser(args, 2); + while (parser.Good()) { + if (parser.EatEqICase("BY")) { + if (!sort_argument_.sortby.empty()) return {Status::InvalidArgument, "don't use multiple BY parameters"}; + sort_argument_.sortby = GET_OR_RET(parser.TakeStr()); + + if (sort_argument_.sortby.find('*') == std::string::npos) { + sort_argument_.dontsort = true; + } else { + /* TODO: + * If BY is specified with a real pattern, we can't accept it in cluster mode, + * unless we can make sure the keys formed by the pattern are in the same slot + * as the key to sort. + * If BY is specified with a real pattern, we can't accept + * it if no full ACL key access is applied for this command. */ + } + } else if (parser.EatEqICase("LIMIT")) { + sort_argument_.offset = GET_OR_RET(parser.template TakeInt()); + sort_argument_.count = GET_OR_RET(parser.template TakeInt()); + } else if (parser.EatEqICase("GET")) { + /* TODO: + * If GET is specified with a real pattern, we can't accept it in cluster mode, + * unless we can make sure the keys formed by the pattern are in the same slot + * as the key to sort. */ + sort_argument_.getpatterns.push_back(GET_OR_RET(parser.TakeStr())); + } else if (parser.EatEqICase("ASC")) { + sort_argument_.desc = false; + } else if (parser.EatEqICase("DESC")) { + sort_argument_.desc = true; + } else if (parser.EatEqICase("ALPHA")) { + sort_argument_.alpha = true; + } else if (parser.EatEqICase("STORE")) { + if constexpr (ReadOnly) { + return {Status::RedisParseErr, "SORT_RO is read-only and does not support the STORE parameter"}; + } + sort_argument_.storekey = GET_OR_RET(parser.TakeStr()); + } else { + return parser.InvalidSyntax(); + } + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::Database redis(srv->storage, conn->GetNamespace()); + RedisType type = kRedisNone; + if (auto s = redis.Type(args_[1], &type); !s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + if (type != RedisType::kRedisList && type != RedisType::kRedisSet && type != RedisType::kRedisZSet) { + return {Status::RedisWrongType, "Operation against a key holding the wrong kind of value"}; + } + + /* When sorting a set with no sort specified, we must sort the output + * so the result is consistent across scripting and replication. 
+ * + * The other types (list, sorted set) will retain their native order + * even if no sort order is requested, so they remain stable across + * scripting and replication. + * + * TODO: support CLIENT_SCRIPT flag, (!storekey_.empty() || c->flags & CLIENT_SCRIPT)) */ + if (sort_argument_.dontsort && type == RedisType::kRedisSet && (!sort_argument_.storekey.empty())) { + /* Force ALPHA sorting */ + sort_argument_.dontsort = false; + sort_argument_.alpha = true; + sort_argument_.sortby = ""; } + + std::vector> sorted_elems; + Database::SortResult res = Database::SortResult::DONE; + + if (auto s = redis.Sort(type, args_[1], sort_argument_, &sorted_elems, &res); !s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + switch (res) { + case Database::SortResult::UNKNOWN_TYPE: + return {Status::RedisErrorNoPrefix, "Unknown Type"}; + case Database::SortResult::DOUBLE_CONVERT_ERROR: + return {Status::RedisErrorNoPrefix, "One or more scores can't be converted into double"}; + case Database::SortResult::LIMIT_EXCEEDED: + return {Status::RedisErrorNoPrefix, + "The number of elements to be sorted exceeds SORT_LENGTH_LIMIT = " + std::to_string(SORT_LENGTH_LIMIT)}; + case Database::SortResult::DONE: + if (sort_argument_.storekey.empty()) { + std::vector output_vec; + output_vec.reserve(sorted_elems.size()); + for (const auto &elem : sorted_elems) { + output_vec.emplace_back(elem.has_value() ? redis::BulkString(elem.value()) : conn->NilString()); + } + *output = redis::Array(output_vec); + } else { + *output = Integer(sorted_elems.size()); + } + break; + } + return Status::OK(); } + + private: + SortArgument sort_argument_; }; REDIS_REGISTER_COMMANDS(MakeCmdAttr("ttl", 2, "read-only", 1, 1, 1), MakeCmdAttr("pttl", 2, "read-only", 1, 1, 1), MakeCmdAttr("type", 2, "read-only", 1, 1, 1), MakeCmdAttr("move", 3, "write", 1, 1, 1), + MakeCmdAttr("movex", 3, "write", 1, 1, 1), MakeCmdAttr("object", 3, "read-only", 2, 2, 1), MakeCmdAttr("exists", -2, "read-only", 1, -1, 1), MakeCmdAttr("persist", 2, "write", 1, 1, 1), @@ -353,6 +551,9 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("ttl", 2, "read-only", 1, 1, 1), MakeCmdAttr("del", -2, "write no-dbsize-check", 1, -1, 1), MakeCmdAttr("unlink", -2, "write no-dbsize-check", 1, -1, 1), MakeCmdAttr("rename", 3, "write", 1, 2, 1), - MakeCmdAttr("renamenx", 3, "write", 1, 2, 1), ) + MakeCmdAttr("renamenx", 3, "write", 1, 2, 1), + MakeCmdAttr("copy", -3, "write", 1, 2, 1), + MakeCmdAttr>("sort", -2, "write", 1, 1, 1), + MakeCmdAttr>("sort_ro", -2, "read-only", 1, 1, 1)) } // namespace redis diff --git a/src/commands/cmd_list.cc b/src/commands/cmd_list.cc index f354d64cc4b..e9e17266625 100644 --- a/src/commands/cmd_list.cc +++ b/src/commands/cmd_list.cc @@ -304,7 +304,7 @@ class CommandBPop : public BlockingCommander { conn_->Reply(conn_->MultiBulkString({*last_key_ptr, std::move(elem)})); } } else if (!s.IsNotFound()) { - conn_->Reply(redis::Error("ERR " + s.ToString())); + conn_->Reply(redis::Error({Status::NotOK, s.ToString()})); } return s; @@ -414,7 +414,7 @@ class CommandBLMPop : public BlockingCommander { conn_->Reply(redis::Array({redis::BulkString(chosen_key), std::move(elems_bulk)})); } } else if (!s.IsNotFound()) { - conn_->Reply(redis::Error("ERR " + s.ToString())); + conn_->Reply(redis::Error({Status::NotOK, s.ToString()})); } return s; @@ -757,7 +757,7 @@ class CommandBLMove : public BlockingCommander { std::string elem; auto s = list_db.LMove(args_[1], args_[2], src_left_, dst_left_, &elem); if (!s.ok() && !s.IsNotFound()) { - 
conn_->Reply(redis::Error("ERR " + s.ToString())); + conn_->Reply(redis::Error({Status::NotOK, s.ToString()})); return true; } diff --git a/src/commands/cmd_pubsub.cc b/src/commands/cmd_pubsub.cc index 8f0ddfbd9c0..6923fb89fd0 100644 --- a/src/commands/cmd_pubsub.cc +++ b/src/commands/cmd_pubsub.cc @@ -75,7 +75,7 @@ class CommandMPublish : public Commander { void SubscribeCommandReply(const Connection *conn, std::string *output, const std::string &name, const std::string &sub_name, int num) { - output->append(redis::MultiLen(3)); + output->append(conn->HeaderOfPush(3)); output->append(redis::BulkString(name)); output->append(sub_name.empty() ? conn->NilString() : BulkString(sub_name)); output->append(redis::Integer(num)); diff --git a/src/commands/cmd_replication.cc b/src/commands/cmd_replication.cc index 0a86a9cc619..d3f3c0f25c5 100644 --- a/src/commands/cmd_replication.cc +++ b/src/commands/cmd_replication.cc @@ -22,6 +22,7 @@ #include "error_constants.h" #include "io_util.h" #include "scope_exit.h" +#include "server/redis_reply.h" #include "server/server.h" #include "thread_util.h" #include "time_util.h" @@ -101,7 +102,7 @@ class CommandPSync : public Commander { srv->stats.IncrPSyncOKCount(); s = srv->AddSlave(conn, next_repl_seq_); if (!s.IsOK()) { - std::string err = "-ERR " + s.Msg() + "\r\n"; + std::string err = redis::Error(s); s = util::SockSend(conn->GetFD(), err, conn->GetBufferEvent()); if (!s.IsOK()) { LOG(WARNING) << "failed to send error message to the replica: " << s.Msg(); @@ -229,7 +230,7 @@ class CommandFetchMeta : public Commander { std::string files; auto s = engine::Storage::ReplDataManager::GetFullReplDataInfo(srv->storage, &files); if (!s.IsOK()) { - s = util::SockSend(repl_fd, "-ERR can't create db checkpoint", bev); + s = util::SockSend(repl_fd, redis::Error({Status::RedisErrorNoPrefix, "can't create db checkpoint"}), bev); if (!s.IsOK()) { LOG(WARNING) << "[replication] Failed to send error response: " << s.Msg(); } @@ -242,8 +243,8 @@ class CommandFetchMeta : public Commander { } else { LOG(WARNING) << "[replication] Fail to send full data file info " << ip << ", error: " << strerror(errno); } - auto now = static_cast(util::GetTimeStamp()); - srv->storage->SetCheckpointAccessTime(now); + auto now_secs = static_cast(util::GetTimeStamp()); + srv->storage->SetCheckpointAccessTimeSecs(now_secs); })); if (auto s = util::ThreadDetach(t); !s) { @@ -283,7 +284,7 @@ class CommandFetchFile : public Commander { if (srv->IsStopped()) break; uint64_t file_size = 0, max_replication_bytes = 0; - if (srv->GetConfig()->max_replication_mb > 0) { + if (srv->GetConfig()->max_replication_mb > 0 && srv->GetFetchFileThreadNum() != 0) { max_replication_bytes = (srv->GetConfig()->max_replication_mb * MiB) / srv->GetFetchFileThreadNum(); } auto start = std::chrono::high_resolution_clock::now(); @@ -303,16 +304,18 @@ class CommandFetchFile : public Commander { // Sleep if the speed of sending file is more than replication speed limit auto end = std::chrono::high_resolution_clock::now(); uint64_t duration = std::chrono::duration_cast(end - start).count(); - auto shortest = static_cast(static_cast(file_size) / - static_cast(max_replication_bytes) * (1000 * 1000)); - if (max_replication_bytes > 0 && duration < shortest) { - LOG(INFO) << "[replication] Need to sleep " << (shortest - duration) / 1000 - << " ms since of sending files too quickly"; - usleep(shortest - duration); + if (max_replication_bytes > 0) { + auto shortest = static_cast(static_cast(file_size) / + 
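+ // Worked example of the throttle (illustrative): with a 100 MiB file and a
+ // 10 MiB/s per-thread budget, `shortest` evaluates to 10,000,000 us; if the
+ // send finished faster than that, the code below sleeps the difference so
+ // the effective rate stays within max-replication-mb.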
static_cast(max_replication_bytes) * (1000 * 1000)); + if (duration < shortest) { + LOG(INFO) << "[replication] Need to sleep " << (shortest - duration) / 1000 + << " ms since of sending files too quickly"; + usleep(shortest - duration); + } } } - auto now = static_cast(util::GetTimeStamp()); - srv->storage->SetCheckpointAccessTime(now); + auto now_secs = util::GetTimeStamp(); + srv->storage->SetCheckpointAccessTimeSecs(now_secs); srv->DecrFetchFileThread(); })); diff --git a/src/commands/cmd_script.cc b/src/commands/cmd_script.cc index c7576a6aa28..2547c289fcf 100644 --- a/src/commands/cmd_script.cc +++ b/src/commands/cmd_script.cc @@ -31,8 +31,7 @@ class CommandEvalImpl : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { if (evalsha && args_[1].size() != 40) { - *output = redis::Error(errNoMatchingScript); - return Status::OK(); + return {Status::RedisNoScript, errNoMatchingScript}; } int64_t numkeys = GET_OR_RET(ParseInt(args_[2], 10)); @@ -116,6 +115,14 @@ CommandKeyRange GetScriptEvalKeyRange(const std::vector &args) { return {3, 2 + numkeys, 1}; } +uint64_t GenerateScriptFlags(uint64_t flags, const std::vector &args) { + if (util::EqualICase(args[1], "load") || util::EqualICase(args[1], "flush")) { + return flags | kCmdWrite; + } + + return flags; +} + REDIS_REGISTER_COMMANDS(MakeCmdAttr("eval", -3, "exclusive write no-script", GetScriptEvalKeyRange), MakeCmdAttr("evalsha", -3, "exclusive write no-script", GetScriptEvalKeyRange), MakeCmdAttr("eval_ro", -3, "read-only no-script ro-script", diff --git a/src/commands/cmd_search.cc b/src/commands/cmd_search.cc new file mode 100644 index 00000000000..ea42f26977c --- /dev/null +++ b/src/commands/cmd_search.cc @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include +#include +#include + +#include "commander.h" +#include "commands/command_parser.h" +#include "search/index_info.h" +#include "search/ir.h" +#include "search/ir_dot_dumper.h" +#include "search/plan_executor.h" +#include "search/redis_query_transformer.h" +#include "search/search_encoding.h" +#include "search/sql_transformer.h" +#include "server/redis_reply.h" +#include "server/server.h" +#include "string_util.h" +#include "tao/pegtl/string_input.hpp" + +namespace redis { + +class CommandFTCreate : public Commander { + Status Parse(const std::vector &args) override { + CommandParser parser(args, 1); + + auto index_name = GET_OR_RET(parser.TakeStr()); + if (index_name.empty()) { + return {Status::RedisParseErr, "index name cannot be empty"}; + } + + index_info_ = std::make_unique(index_name, redis::IndexMetadata{}, ""); + auto data_type = IndexOnDataType(0); + + while (parser.Good()) { + if (parser.EatEqICase("ON")) { + if (parser.EatEqICase("HASH")) { + data_type = IndexOnDataType::HASH; + } else if (parser.EatEqICase("JSON")) { + data_type = IndexOnDataType::JSON; + } else { + return {Status::RedisParseErr, "expect HASH or JSON after ON"}; + } + } else if (parser.EatEqICase("PREFIX")) { + size_t count = GET_OR_RET(parser.TakeInt()); + + for (size_t i = 0; i < count; ++i) { + index_info_->prefixes.prefixes.push_back(GET_OR_RET(parser.TakeStr())); + } + } else { + break; + } + } + + if (int(data_type) == 0) { + return {Status::RedisParseErr, "expect ON HASH | JSON"}; + } else { + index_info_->metadata.on_data_type = data_type; + } + + if (parser.EatEqICase("SCHEMA")) { + while (parser.Good()) { + auto field_name = GET_OR_RET(parser.TakeStr()); + if (field_name.empty()) { + return {Status::RedisParseErr, "field name cannot be empty"}; + } + + std::unique_ptr field_meta; + if (parser.EatEqICase("TAG")) { + field_meta = std::make_unique(); + } else if (parser.EatEqICase("NUMERIC")) { + field_meta = std::make_unique(); + } else { + return {Status::RedisParseErr, "expect field type TAG or NUMERIC"}; + } + + while (parser.Good()) { + if (parser.EatEqICase("NOINDEX")) { + field_meta->noindex = true; + } else if (auto tag = dynamic_cast(field_meta.get())) { + if (parser.EatEqICase("CASESENSITIVE")) { + tag->case_sensitive = true; + } else if (parser.EatEqICase("SEPARATOR")) { + auto sep = GET_OR_RET(parser.TakeStr()); + + if (sep.size() != 1) { + return {Status::NotOK, "only one character separator is supported"}; + } + + tag->separator = sep[0]; + } else { + break; + } + } else { + break; + } + } + + kqir::FieldInfo field_info(field_name, std::move(field_meta)); + + index_info_->Add(std::move(field_info)); + } + } else { + return {Status::RedisParseErr, "expect SCHEMA section for this index"}; + } + + if (parser.Good()) { + return {Status::RedisParseErr, "more token than expected in command arguments"}; + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + index_info_->ns = conn->GetNamespace(); + + GET_OR_RET(srv->index_mgr.Create(std::move(index_info_))); + + output->append(redis::SimpleString("OK")); + return Status::OK(); + }; + + private: + std::unique_ptr index_info_; +}; + +static void DumpQueryResult(const std::vector &rows, std::string *output) { + output->append(MultiLen(rows.size() * 2 + 1)); + output->append(Integer(rows.size())); + for (const auto &[key, fields, _] : rows) { + output->append(redis::BulkString(key)); + output->append(MultiLen(fields.size() * 2)); + for (const auto &[info, field] : 
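+ // Reply shape sketch (assumed, mirroring RediSearch's FT.SEARCH): the outer
+ // array is [count, key1, fields1, key2, fields2, ...], and each fieldsN is a
+ // flat [name, value, name, value, ...] array built by the loop below.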
fields) { + output->append(redis::BulkString(info->name)); + output->append(redis::BulkString(field.ToString(info->metadata.get()))); + } + } +} + +class CommandFTExplainSQL : public Commander { + Status Parse(const std::vector &args) override { + if (args.size() == 3) { + if (util::EqualICase(args[2], "simple")) { + format_ = SIMPLE; + } else if (util::EqualICase(args[2], "dot")) { + format_ = DOT_GRAPH; + } else { + return {Status::NotOK, "output format should be SIMPLE or DOT"}; + } + } + + if (args.size() > 3) { + return {Status::NotOK, "more arguments than expected"}; + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + const auto &sql = args_[1]; + + auto ir = GET_OR_RET(kqir::sql::ParseToIR(kqir::peg::string_input(sql, "ft.explainsql"))); + + auto plan = GET_OR_RET(srv->index_mgr.GeneratePlan(std::move(ir), conn->GetNamespace())); + + if (format_ == SIMPLE) { + output->append(BulkString(plan->Dump())); + } else if (format_ == DOT_GRAPH) { + std::ostringstream ss; + kqir::DotDumper dumper(ss); + + dumper.Dump(plan.get()); + output->append(BulkString(ss.str())); + } + + return Status::OK(); + }; + + enum OutputFormat { SIMPLE, DOT_GRAPH } format_ = SIMPLE; +}; + +class CommandFTSearchSQL : public Commander { + Status Execute(Server *srv, Connection *conn, std::string *output) override { + const auto &sql = args_[1]; + + auto ir = GET_OR_RET(kqir::sql::ParseToIR(kqir::peg::string_input(sql, "ft.searchsql"))); + + auto results = GET_OR_RET(srv->index_mgr.Search(std::move(ir), conn->GetNamespace())); + + DumpQueryResult(results, output); + + return Status::OK(); + }; +}; + +static StatusOr> ParseRediSearchQuery(const std::vector &args) { + CommandParser parser(args, 1); + + auto index_name = GET_OR_RET(parser.TakeStr()); + auto query_str = GET_OR_RET(parser.TakeStr()); + + auto index_ref = std::make_unique(index_name); + auto query = kqir::Node::MustAs( + GET_OR_RET(kqir::redis_query::ParseToIR(kqir::peg::string_input(query_str, "ft.search")))); + + auto select = std::make_unique(std::vector>{}); + std::unique_ptr sort_by; + std::unique_ptr limit; + while (parser.Good()) { + if (parser.EatEqICase("RETURNS")) { + auto count = GET_OR_RET(parser.TakeInt()); + + for (size_t i = 0; i < count; ++i) { + auto field = GET_OR_RET(parser.TakeStr()); + select->fields.push_back(std::make_unique(field)); + } + } else if (parser.EatEqICase("SORTBY")) { + auto field = GET_OR_RET(parser.TakeStr()); + auto order = kqir::SortByClause::ASC; + if (parser.EatEqICase("ASC")) { + // NOOP + } else if (parser.EatEqICase("DESC")) { + order = kqir::SortByClause::DESC; + } + + sort_by = std::make_unique(order, std::make_unique(field)); + } else if (parser.EatEqICase("LIMIT")) { + auto offset = GET_OR_RET(parser.TakeInt()); + auto count = GET_OR_RET(parser.TakeInt()); + + limit = std::make_unique(offset, count); + } else { + return parser.InvalidSyntax(); + } + } + + return std::make_unique(std::move(index_ref), std::move(query), std::move(limit), + std::move(sort_by), std::move(select)); +} + +class CommandFTExplain : public Commander { + Status Parse(const std::vector &args) override { + ir_ = GET_OR_RET(ParseRediSearchQuery(args)); + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + CHECK(ir_); + auto plan = GET_OR_RET(srv->index_mgr.GeneratePlan(std::move(ir_), conn->GetNamespace())); + + output->append(redis::BulkString(plan->Dump())); + + return Status::OK(); + }; + + private: + 
std::unique_ptr ir_; +}; + +class CommandFTSearch : public Commander { + Status Parse(const std::vector &args) override { + ir_ = GET_OR_RET(ParseRediSearchQuery(args)); + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + CHECK(ir_); + auto results = GET_OR_RET(srv->index_mgr.Search(std::move(ir_), conn->GetNamespace())); + + DumpQueryResult(results, output); + + return Status::OK(); + }; + + private: + std::unique_ptr ir_; +}; + +class CommandFTInfo : public Commander { + Status Execute(Server *srv, Connection *conn, std::string *output) override { + const auto &index_map = srv->index_mgr.index_map; + const auto &index_name = args_[1]; + + auto iter = index_map.Find(index_name, conn->GetNamespace()); + if (iter == index_map.end()) { + return {Status::RedisExecErr, "index not found"}; + } + + const auto &info = iter->second; + output->append(MultiLen(8)); + + output->append(redis::SimpleString("index_name")); + output->append(redis::BulkString(info->name)); + + output->append(redis::SimpleString("on_data_type")); + output->append(redis::BulkString(RedisTypeNames[(size_t)info->metadata.on_data_type])); + + output->append(redis::SimpleString("prefixes")); + output->append(redis::ArrayOfBulkStrings(info->prefixes.prefixes)); + + output->append(redis::SimpleString("fields")); + output->append(MultiLen(info->fields.size())); + for (const auto &[_, field] : info->fields) { + output->append(MultiLen(2)); + output->append(redis::BulkString(field.name)); + auto type = field.metadata->Type(); + output->append(redis::BulkString(std::string(type.begin(), type.end()))); + } + + return Status::OK(); + }; +}; + +class CommandFTList : public Commander { + Status Execute(Server *srv, Connection *conn, std::string *output) override { + const auto &index_map = srv->index_mgr.index_map; + + std::vector results; + for (const auto &[_, index] : index_map) { + if (index->ns == conn->GetNamespace()) { + results.push_back(index->name); + } + } + + output->append(ArrayOfBulkStrings(results)); + + return Status::OK(); + }; +}; + +class CommandFTDrop : public Commander { + Status Execute(Server *srv, Connection *conn, std::string *output) override { + const auto &index_name = args_[1]; + + GET_OR_RET(srv->index_mgr.Drop(index_name, conn->GetNamespace())); + + output->append(SimpleString("OK")); + + return Status::OK(); + }; +}; + +// REDIS_REGISTER_COMMANDS(MakeCmdAttr("ft.create", -2, "write exclusive no-multi no-script", 0, 0, 0), +// MakeCmdAttr("ft.searchsql", 2, "read-only", 0, 0, 0), +// MakeCmdAttr("ft.search", -3, "read-only", 0, 0, 0), +// MakeCmdAttr("ft.explainsql", -2, "read-only", 0, 0, 0), +// MakeCmdAttr("ft.explain", -3, "read-only", 0, 0, 0), +// MakeCmdAttr("ft.info", 2, "read-only", 0, 0, 0), +// MakeCmdAttr("ft._list", 1, "read-only", 0, 0, 0), +// MakeCmdAttr("ft.dropindex", 2, "write exclusive no-multi no-script", 0, 0, +// 0)); + +} // namespace redis diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index ec284cd89f5..ecf3cfc59f3 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -29,6 +29,7 @@ #include "config/config.h" #include "error_constants.h" #include "server/redis_connection.h" +#include "server/redis_reply.h" #include "server/server.h" #include "stats/disk_stats.h" #include "storage/rdb.h" @@ -37,48 +38,26 @@ namespace redis { -enum class AuthResult { - OK, - INVALID_PASSWORD, - NO_REQUIRE_PASS, -}; - -AuthResult AuthenticateUser(Server *srv, Connection *conn, const std::string 
&user_password) { - auto ns = srv->GetNamespace()->GetByToken(user_password); - if (ns.IsOK() && user_password != ns.GetValue()) { - conn->SetNamespace(ns.GetValue()); - conn->BecomeUser(); - return AuthResult::OK; - } - - const auto &requirepass = srv->GetConfig()->requirepass; - if (!requirepass.empty() && user_password != requirepass) { - return AuthResult::INVALID_PASSWORD; - } - - conn->SetNamespace(kDefaultNamespace); - conn->BecomeAdmin(); - if (requirepass.empty()) { - return AuthResult::NO_REQUIRE_PASS; - } - - return AuthResult::OK; -} - class CommandAuth : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { auto &user_password = args_[1]; - AuthResult result = AuthenticateUser(srv, conn, user_password); + std::string ns; + AuthResult result = srv->AuthenticateUser(user_password, &ns); switch (result) { - case AuthResult::OK: - *output = redis::SimpleString("OK"); - break; - case AuthResult::INVALID_PASSWORD: - return {Status::RedisExecErr, "invalid password"}; case AuthResult::NO_REQUIRE_PASS: return {Status::RedisExecErr, "Client sent AUTH, but no password is set"}; + case AuthResult::INVALID_PASSWORD: + return {Status::RedisExecErr, "Invalid password"}; + case AuthResult::IS_USER: + conn->BecomeUser(); + break; + case AuthResult::IS_ADMIN: + conn->BecomeAdmin(); + break; } + conn->SetNamespace(ns); + *output = redis::SimpleString("OK"); return Status::OK(); } }; @@ -116,17 +95,17 @@ class CommandNamespace : public Commander { } } else if (args_.size() == 4 && sub_command == "set") { Status s = srv->GetNamespace()->Set(args_[2], args_[3]); - *output = s.IsOK() ? redis::SimpleString("OK") : redis::Error("ERR " + s.Msg()); + *output = s.IsOK() ? redis::SimpleString("OK") : redis::Error(s); LOG(WARNING) << "Updated namespace: " << args_[2] << " with token: " << args_[3] << ", addr: " << conn->GetAddr() << ", result: " << s.Msg(); } else if (args_.size() == 4 && sub_command == "add") { Status s = srv->GetNamespace()->Add(args_[2], args_[3]); - *output = s.IsOK() ? redis::SimpleString("OK") : redis::Error("ERR " + s.Msg()); + *output = s.IsOK() ? redis::SimpleString("OK") : redis::Error(s); LOG(WARNING) << "New namespace: " << args_[2] << " with token: " << args_[3] << ", addr: " << conn->GetAddr() << ", result: " << s.Msg(); } else if (args_.size() == 3 && sub_command == "del") { Status s = srv->GetNamespace()->Del(args_[2]); - *output = s.IsOK() ? redis::SimpleString("OK") : redis::Error("ERR " + s.Msg()); + *output = s.IsOK() ? 
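// Note: redis::Error(s) now derives the RESP error prefix from the Status
// code instead of hard-coding "ERR ", matching the same cleanup applied to
// the blocking list/zset and stream commands elsewhere in this patch.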
redis::SimpleString("OK") : redis::Error(s); LOG(WARNING) << "Deleted namespace: " << args_[2] << ", addr: " << conn->GetAddr() << ", result: " << s.Msg(); } else { return {Status::RedisExecErr, "NAMESPACE subcommand must be one of GET, SET, DEL, ADD"}; @@ -276,7 +255,7 @@ class CommandConfig : public Commander { if (args_.size() == 2 && sub_command == "rewrite") { Status s = config->Rewrite(srv->GetNamespace()->List()); - if (!s.IsOK()) return {Status::RedisExecErr, s.Msg()}; + if (!s.IsOK()) return s; *output = redis::SimpleString("OK"); LOG(INFO) << "# CONFIG REWRITE executed with success"; @@ -362,12 +341,12 @@ class CommandDBSize : public Commander { KeyNumStats stats; srv->GetLatestKeyNumStats(ns, &stats); *output = redis::Integer(stats.n_key); - } else if (args_.size() == 2 && args_[1] == "scan") { + } else if (args_.size() == 2 && util::EqualICase(args_[1], "scan")) { Status s = srv->AsyncScanDBSize(ns); if (s.IsOK()) { *output = redis::SimpleString("OK"); } else { - return {Status::RedisExecErr, s.Msg()}; + return s; } } else { return {Status::RedisExecErr, "DBSIZE subcommand only supports scan"}; @@ -682,9 +661,9 @@ class CommandDebug : public Commander { } else if (protocol_type_ == "verbatim") { // verbatim string *output = conn->VerbatimString("txt", "verbatim string"); } else { - *output = redis::Error( - "Wrong protocol type name. Please use one of the following: " - "string|integer|double|array|set|bignum|true|false|null|attrib|verbatim"); + return {Status::RedisErrorNoPrefix, + "Wrong protocol type name. Please use one of the following: " + "string|integer|double|array|set|bignum|true|false|null|attrib|verbatim"}; } } else if (subcommand_ == "dbsize-limit") { srv->storage->SetDBSizeLimit(dbsize_limit_); @@ -793,7 +772,7 @@ class CommandHello final : public Commander { // kvrocks only supports REPL2 by now, but for supporting some // `hello 3`, it will not report error when using 3. 
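// For example, `HELLO 4` is rejected here; the reply is still
// "-NOPROTO unsupported protocol version", with the prefix now supplied by
// Status::RedisNoProto rather than embedded in the message string.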
if (protocol < 2 || protocol > 3) { - return {Status::NotOK, "-NOPROTO unsupported protocol version"}; + return {Status::RedisNoProto, "unsupported protocol version"}; } } @@ -804,20 +783,26 @@ class CommandHello final : public Commander { if (util::ToLower(opt) == "auth" && more_args != 0) { if (more_args == 2 || more_args == 4) { if (args_[next_arg + 1] != "default") { - return {Status::NotOK, "invalid password"}; + return {Status::NotOK, "Invalid password"}; } next_arg++; } const auto &user_password = args_[next_arg + 1]; - auto auth_result = AuthenticateUser(srv, conn, user_password); + std::string ns; + AuthResult auth_result = srv->AuthenticateUser(user_password, &ns); switch (auth_result) { - case AuthResult::INVALID_PASSWORD: - return {Status::NotOK, "invalid password"}; case AuthResult::NO_REQUIRE_PASS: return {Status::NotOK, "Client sent AUTH, but no password is set"}; - case AuthResult::OK: + case AuthResult::INVALID_PASSWORD: + return {Status::NotOK, "Invalid password"}; + case AuthResult::IS_USER: + conn->BecomeUser(); + break; + case AuthResult::IS_ADMIN: + conn->BecomeAdmin(); break; } + conn->SetNamespace(ns); next_arg += 1; } else if (util::ToLower(opt) == "setname" && more_args != 0) { const std::string &name = args_[next_arg + 1]; @@ -866,28 +851,6 @@ class CommandScan : public CommandScanBase { public: CommandScan() : CommandScanBase() {} - Status Parse(const std::vector &args) override { - if (args.size() % 2 != 0) { - return {Status::RedisParseErr, errWrongNumOfArguments}; - } - - ParseCursor(args[1]); - if (args.size() >= 4) { - Status s = ParseMatchAndCountParam(util::ToLower(args[2]), args_[3]); - if (!s.IsOK()) { - return s; - } - } - - if (args.size() >= 6) { - Status s = ParseMatchAndCountParam(util::ToLower(args[4]), args_[5]); - if (!s.IsOK()) { - return s; - } - } - return Commander::Parse(args); - } - static std::string GenerateOutput(Server *srv, const Connection *conn, const std::vector &keys, const std::string &end_cursor) { std::vector list; @@ -909,7 +872,7 @@ class CommandScan : public CommandScanBase { std::vector keys; std::string end_key; - auto s = redis_db.Scan(key_name, limit_, prefix_, &keys, &end_key); + auto s = redis_db.Scan(key_name, limit_, prefix_, &keys, &end_key, type_); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -941,18 +904,8 @@ class CommandCompact : public Commander { auto ns = conn->GetNamespace(); if (ns != kDefaultNamespace) { - std::string prefix = ComposeNamespaceKey(ns, "", false); - - redis::Database redis_db(srv->storage, conn->GetNamespace()); - auto s = redis_db.FindKeyRangeWithPrefix(prefix, std::string(), &begin_key, &end_key); - if (!s.ok()) { - if (s.IsNotFound()) { - *output = redis::SimpleString("OK"); - return Status::OK(); - } - - return {Status::RedisExecErr, s.ToString()}; - } + begin_key = ComposeNamespaceKey(ns, "", false); + end_key = util::StringNext(begin_key); } Status s = srv->AsyncCompactDB(begin_key, end_key); @@ -1063,9 +1016,7 @@ class CommandSlaveOf : public Commander { } auto s = IsTryingToReplicateItself(srv, host_, port_); - if (!s.IsOK()) { - return {Status::RedisExecErr, s.Msg()}; - } + if (!s.IsOK()) return s; s = srv->AddMaster(host_, port_, false); if (s.IsOK()) { *output = redis::SimpleString("OK"); @@ -1097,12 +1048,12 @@ class CommandStats : public Commander { } }; -static uint64_t GenerateConfigFlag(const std::vector &args) { +static uint64_t GenerateConfigFlag(uint64_t flags, const std::vector &args) { if (args.size() >= 2 && util::EqualICase(args[1], "set")) { - 
return kCmdExclusive; + return flags | kCmdExclusive; } - return 0; + return flags; } class CommandLastSave : public Commander { @@ -1178,7 +1129,7 @@ class CommandRestore : public Commander { auto stream_ptr = std::make_unique(args_[3]); RDB rdb(srv->storage, conn->GetNamespace(), std::move(stream_ptr)); auto s = rdb.Restore(args_[1], args_[3], ttl_ms_); - if (!s.IsOK()) return {Status::RedisExecErr, s.Msg()}; + if (!s.IsOK()) return s; *output = redis::SimpleString("OK"); return Status::OK(); } @@ -1238,72 +1189,6 @@ class CommandRdb : public Commander { uint32_t db_index_ = 0; }; -class CommandAnalyze : public Commander { - public: - Status Parse(const std::vector &args) override { - if (args.size() <= 1) return {Status::RedisExecErr, errInvalidSyntax}; - for (unsigned int i = 1; i < args.size(); ++i) { - command_args_.push_back(args[i]); - } - return Status::OK(); - } - Status Execute(Server *srv, Connection *conn, std::string *output) override { - auto commands = redis::CommandTable::Get(); - auto cmd_iter = commands->find(util::ToLower(command_args_[0])); - if (cmd_iter == commands->end()) { - // unsupported redis command - return {Status::RedisExecErr, errInvalidSyntax}; - } - auto redis_cmd = cmd_iter->second; - auto cmd = redis_cmd->factory(); - cmd->SetAttributes(redis_cmd); - cmd->SetArgs(command_args_); - - int arity = cmd->GetAttributes()->arity; - if ((arity > 0 && static_cast(command_args_.size()) != arity) || - (arity < 0 && static_cast(command_args_.size()) < -arity)) { - *output = redis::Error("ERR wrong number of arguments"); - return {Status::RedisExecErr, errWrongNumOfArguments}; - } - - auto s = cmd->Parse(command_args_); - if (!s.IsOK()) { - return s; - } - - auto prev_perf_level = rocksdb::GetPerfLevel(); - rocksdb::SetPerfLevel(rocksdb::PerfLevel::kEnableTimeExceptForMutex); - rocksdb::get_perf_context()->Reset(); - rocksdb::get_iostats_context()->Reset(); - - std::string command_output; - s = cmd->Execute(srv, conn, &command_output); - if (!s.IsOK()) { - return s; - } - - if (command_output[0] == '-') { - *output = command_output; - return s; - } - - std::string perf_context = rocksdb::get_perf_context()->ToString(true); - std::string iostats_context = rocksdb::get_iostats_context()->ToString(true); - rocksdb::get_perf_context()->Reset(); - rocksdb::get_iostats_context()->Reset(); - rocksdb::SetPerfLevel(prev_perf_level); - - *output = redis::MultiLen(3); // command output + perf context + iostats context - *output += command_output; - *output += redis::BulkString(perf_context); - *output += redis::BulkString(iostats_context); - return Status::OK(); - } - - private: - std::vector command_args_; -}; - class CommandReset : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { @@ -1358,9 +1243,8 @@ class CommandApplyBatch : public Commander { auto options = svr->storage->DefaultWriteOptions(); options.low_pri = low_pri_; auto s = svr->storage->ApplyWriteBatch(options, std::move(raw_batch_)); - if (!s.IsOK()) { - return {Status::RedisExecErr, s.Msg()}; - } + if (!s.IsOK()) return s; + *output = redis::Integer(size); return Status::OK(); } @@ -1370,13 +1254,55 @@ class CommandApplyBatch : public Commander { bool low_pri_ = false; }; +class CommandDump : public Commander { + public: + Status Parse(const std::vector &args) override { + if (args.size() != 2) { + return {Status::RedisExecErr, errWrongNumOfArguments}; + } + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) 
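+ // Sketch of the flow below: probe existence, look up the type, then
+ // serialize the value through an in-memory RDB stream and return it as a
+ // bulk string; a missing key replies with nil rather than an error.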
override { + rocksdb::Status db_status; + std::string &key = args_[1]; + redis::Database redis(srv->storage, conn->GetNamespace()); + int count = 0; + db_status = redis.Exists({key}, &count); + if (!db_status.ok()) { + if (db_status.IsNotFound()) { + *output = conn->NilString(); + return Status::OK(); + } + return {Status::RedisExecErr, db_status.ToString()}; + } + if (count == 0) { + *output = conn->NilString(); + return Status::OK(); + } + + RedisType type = kRedisNone; + db_status = redis.Type(key, &type); + if (!db_status.ok()) return {Status::RedisExecErr, db_status.ToString()}; + + std::string result; + auto stream_ptr = std::make_unique(result); + RDB rdb(srv->storage, conn->GetNamespace(), std::move(stream_ptr)); + auto s = rdb.Dump(key, type); + if (!s.IsOK()) return s; + CHECK(dynamic_cast(rdb.GetStream().get()) != nullptr); + *output = redis::BulkString(static_cast(rdb.GetStream().get())->GetInput()); + return Status::OK(); + } +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("auth", 2, "read-only ok-loading", 0, 0, 0), MakeCmdAttr("ping", -1, "read-only", 0, 0, 0), MakeCmdAttr("select", 2, "read-only", 0, 0, 0), MakeCmdAttr("info", -1, "read-only ok-loading", 0, 0, 0), MakeCmdAttr("role", 1, "read-only ok-loading", 0, 0, 0), MakeCmdAttr("config", -2, "read-only", 0, 0, 0, GenerateConfigFlag), - MakeCmdAttr("namespace", -3, "read-only exclusive", 0, 0, 0), + MakeCmdAttr("namespace", -3, "read-only", 0, 0, 0), MakeCmdAttr("keys", 2, "read-only", 0, 0, 0), MakeCmdAttr("flushdb", 1, "write no-dbsize-check", 0, 0, 0), MakeCmdAttr("flushall", 1, "write no-dbsize-check", 0, 0, 0), @@ -1405,7 +1331,7 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("auth", 2, "read-only ok-loadin MakeCmdAttr("slaveof", 3, "read-only exclusive no-script", 0, 0, 0), MakeCmdAttr("stats", 1, "read-only", 0, 0, 0), MakeCmdAttr("rdb", -3, "write exclusive", 0, 0, 0), - MakeCmdAttr("analyze", -1, "", 0, 0, 0), MakeCmdAttr("reset", 1, "ok-loading multi no-script pub-sub", 0, 0, 0), - MakeCmdAttr("applybatch", -2, "write no-multi", 0, 0, 0), ) + MakeCmdAttr("applybatch", -2, "write no-multi", 0, 0, 0), + MakeCmdAttr("dump", 2, "read-only", 0, 0, 0), ) } // namespace redis diff --git a/src/commands/cmd_stream.cc b/src/commands/cmd_stream.cc index a90caece532..ee608be2158 100644 --- a/src/commands/cmd_stream.cc +++ b/src/commands/cmd_stream.cc @@ -18,6 +18,8 @@ * */ +#include +#include #include #include @@ -31,6 +33,39 @@ namespace redis { +class CommandXAck : public Commander { + public: + Status Parse(const std::vector &args) override { + stream_name_ = args[1]; + group_name_ = args[2]; + StreamEntryID tmp_id; + for (size_t i = 3; i < args.size(); ++i) { + auto s = ParseStreamEntryID(args[i], &tmp_id); + if (!s.IsOK()) return s; + entry_ids_.emplace_back(tmp_id); + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::Stream stream_db(srv->storage, conn->GetNamespace()); + uint64_t acknowledged = 0; + auto s = stream_db.DeletePelEntries(stream_name_, group_name_, entry_ids_, &acknowledged); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + *output = redis::Integer(acknowledged); + + return Status::OK(); + } + + private: + std::string stream_name_; + std::string group_name_; + std::vector entry_ids_; +}; + class CommandXAdd : public Commander { public: Status Parse(const std::vector &args) override { @@ -95,9 +130,7 @@ class CommandXAdd : public Commander { } auto s = ParseStreamEntryID(args[min_id_idx], &min_id_); - if (!s.IsOK()) { - return 
{Status::RedisParseErr, s.Msg()}; - } + if (!s.IsOK()) return s; with_min_id_ = true; i += eq_sign_found ? 3 : 2; @@ -110,9 +143,7 @@ class CommandXAdd : public Commander { if (!entry_id_found) { auto result = ParseNextStreamEntryIDStrategy(val); - if (!result.IsOK()) { - return {Status::RedisParseErr, result.Msg()}; - } + if (!result.IsOK()) return result; next_id_strategy_ = std::move(*result); @@ -207,6 +238,219 @@ class CommandXDel : public Commander { std::vector ids_; }; +class CommandXClaim : public Commander { + public: + Status Parse(const std::vector &args) override { + if (args.size() < 6) { + return {Status::RedisParseErr, errWrongNumOfArguments}; + } + + CommandParser parser(args, 1); + stream_name_ = GET_OR_RET(parser.TakeStr()); + group_name_ = GET_OR_RET(parser.TakeStr()); + consumer_name_ = GET_OR_RET(parser.TakeStr()); + auto parse_result = parser.TakeInt(); + if (!parse_result.IsOK()) { + return {Status::RedisParseErr, errValueNotInteger}; + } + min_idle_time_ms_ = parse_result.GetValue(); + if (min_idle_time_ms_ < 0) { + min_idle_time_ms_ = 0; + } + + while (parser.Good() && !isOption(parser.RawPeek())) { + auto raw_id = GET_OR_RET(parser.TakeStr()); + redis::StreamEntryID id; + auto s = ParseStreamEntryID(raw_id, &id); + if (!s.IsOK()) { + return s; + } + entry_ids_.emplace_back(id); + } + + while (parser.Good()) { + if (parser.EatEqICase("idle")) { + auto parse_result = parser.TakeInt(); + if (!parse_result.IsOK()) { + return {Status::RedisParseErr, errValueNotInteger}; + } + if (parse_result.GetValue() < 0) { + return {Status::RedisParseErr, "IDLE for XCLAIM must be non-negative"}; + } + stream_claim_options_.idle_time_ms = parse_result.GetValue(); + } else if (parser.EatEqICase("time")) { + auto parse_result = parser.TakeInt(); + if (!parse_result.IsOK()) { + return {Status::RedisParseErr, errValueNotInteger}; + } + if (parse_result.GetValue() < 0) { + return {Status::RedisParseErr, "TIME for XCLAIM must be non-negative"}; + } + stream_claim_options_.with_time = true; + stream_claim_options_.last_delivery_time_ms = parse_result.GetValue(); + } else if (parser.EatEqICase("retrycount")) { + auto parse_result = parser.TakeInt(); + if (!parse_result.IsOK()) { + return {Status::RedisParseErr, errValueNotInteger}; + } + if (parse_result.GetValue() < 0) { + return {Status::RedisParseErr, "RETRYCOUNT for XCLAIM must be non-negative"}; + } + stream_claim_options_.with_retry_count = true; + stream_claim_options_.last_delivery_count = parse_result.GetValue(); + } else if (parser.EatEqICase("force")) { + stream_claim_options_.force = true; + } else if (parser.EatEqICase("justid")) { + stream_claim_options_.just_id = true; + } else if (parser.EatEqICase("lastid")) { + auto last_id = GET_OR_RET(parser.TakeStr()); + auto s = ParseStreamEntryID(last_id, &stream_claim_options_.last_delivered_id); + if (!s.IsOK()) { + return s; + } + } else { + return parser.InvalidSyntax(); + } + } + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::Stream stream_db(srv->storage, conn->GetNamespace()); + StreamClaimResult result; + auto s = stream_db.ClaimPelEntries(stream_name_, group_name_, consumer_name_, min_idle_time_ms_, entry_ids_, + stream_claim_options_, &result); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + if (s.IsNotFound()) { + return {Status::RedisExecErr, errNoSuchKey}; + } + + if (!stream_claim_options_.just_id) { + output->append(redis::MultiLen(result.entries.size())); + + for (const auto 
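+ // Without JUSTID, each claimed entry is rendered as [id, [field, value, ...]]
+ // by the loop below; with JUSTID only the bare entry IDs are returned,
+ // matching the Redis XCLAIM reply format.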
&e : result.entries) { + output->append(redis::MultiLen(2)); + output->append(redis::BulkString(e.key)); + output->append(conn->MultiBulkString(e.values)); + } + } else { + output->append(redis::MultiLen(result.ids.size())); + for (const auto &id : result.ids) { + output->append(redis::BulkString(id)); + } + } + + return Status::OK(); + } + + private: + std::string stream_name_; + std::string group_name_; + std::string consumer_name_; + uint64_t min_idle_time_ms_; + std::vector entry_ids_; + StreamClaimOptions stream_claim_options_; + + bool static isOption(const std::string &arg) { + static const std::unordered_set options = {"idle", "time", "retrycount", "force", "justid", "lastid"}; + return options.find(util::ToLower(arg)) != options.end(); + } +}; + +class CommandAutoClaim : public Commander { + public: + Status Parse(const std::vector &args) override { + CommandParser parser(args, 1); + key_name_ = GET_OR_RET(parser.TakeStr()); + group_name_ = GET_OR_RET(parser.TakeStr()); + consumer_name_ = GET_OR_RET(parser.TakeStr()); + if (auto parse_status = parser.TakeInt(); !parse_status.IsOK()) { + return {Status::RedisParseErr, "Invalid min-idle-time argument for XAUTOCLAIM"}; + } else { + options_.min_idle_time_ms = parse_status.GetValue(); + } + + auto start_str = GET_OR_RET(parser.TakeStr()); + if (!start_str.empty() && start_str.front() == '(') { + options_.exclude_start = true; + start_str = start_str.substr(1); + } + if (!options_.exclude_start && start_str == "-") { + options_.start_id = StreamEntryID::Minimum(); + } else { + auto parse_status = ParseRangeStart(start_str, &options_.start_id); + if (!parse_status.IsOK()) { + return parse_status; + } + } + + if (parser.EatEqICase("count")) { + uint64_t count = GET_OR_RET(parser.TakeInt()); + constexpr uint64_t min_count = 1; + uint64_t max_count = std::numeric_limits::max() / + (std::max(static_cast(sizeof(StreamEntryID)), options_.attempts_factors)); + if (count < min_count || count > max_count) { + return {Status::RedisParseErr, "COUNT must be > 0"}; + } + options_.count = count; + } + + if (parser.Good() && parser.EatEqICase("justid")) { + options_.just_id = true; + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::Stream stream_db(srv->storage, conn->GetNamespace()); + StreamAutoClaimResult result; + auto s = stream_db.AutoClaim(key_name_, group_name_, consumer_name_, options_, &result); + if (!s.ok()) { + if (s.IsNotFound()) { + return {Status::RedisExecErr, + "NOGROUP No such key '" + key_name_ + "' or consumer group '" + group_name_ + "'"}; + } + return {Status::RedisExecErr, s.ToString()}; + } + return sendResults(conn, result, output); + } + + private: + Status sendResults(Connection *conn, const StreamAutoClaimResult &result, std::string *output) const { + output->append(redis::MultiLen(3)); + output->append(redis::BulkString(result.next_claim_id)); + output->append(redis::MultiLen(result.entries.size())); + for (const auto &item : result.entries) { + if (options_.just_id) { + output->append(redis::BulkString(item.key)); + } else { + output->append(redis::MultiLen(2)); + output->append(redis::BulkString(item.key)); + output->append(redis::MultiLen(item.values.size())); + for (const auto &value_item : item.values) { + output->append(redis::BulkString(value_item)); + } + } + } + + output->append(redis::MultiLen(result.deleted_ids.size())); + for (const auto &item : result.deleted_ids) { + output->append(redis::BulkString(item)); + } + + return Status::OK(); 
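+ // The reply built above mirrors Redis XAUTOCLAIM: a three-element array of
+ // the cursor to resume scanning from, the claimed entries, and the IDs that
+ // were deleted from the pending entries list.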
+ } + + std::string key_name_; + std::string group_name_; + std::string consumer_name_; + StreamAutoClaimOptions options_; +}; + class CommandXGroup : public Commander { public: Status Parse(const std::vector &args) override { @@ -411,10 +655,10 @@ class CommandXInfo : public Commander { count_ = *parse_result; } - } else if (val == "groups" && args.size() == 3) { - subcommand_ = "groups"; - } else if (val == "consumers" && args.size() == 4) { - subcommand_ = "consumers"; + // } else if (val == "groups" && args.size() == 3) { + // subcommand_ = "groups"; + // } else if (val == "consumers" && args.size() == 4) { + // subcommand_ = "consumers"; } else { return {Status::RedisParseErr, errUnknownSubcommandOrWrongArguments}; } @@ -553,7 +797,7 @@ class CommandXInfo : public Commander { } output->append(redis::MultiLen(result_vector.size())); - auto now = util::GetTimeStampMS(); + auto now_ms = util::GetTimeStampMS(); for (auto const &it : result_vector) { output->append(conn->HeaderOfMap(4)); output->append(redis::BulkString("name")); @@ -561,9 +805,9 @@ class CommandXInfo : public Commander { output->append(redis::BulkString("pending")); output->append(redis::Integer(it.second.pending_number)); output->append(redis::BulkString("idle")); - output->append(redis::Integer(now - it.second.last_idle)); + output->append(redis::Integer(now_ms - it.second.last_attempted_interaction_ms)); output->append(redis::BulkString("inactive")); - output->append(redis::Integer(now - it.second.last_active)); + output->append(redis::Integer(now_ms - it.second.last_successful_interaction_ms)); } return Status::OK(); @@ -973,7 +1217,7 @@ class CommandXRead : public Commander, std::vector result; auto s = stream_db.Range(streams_[i], options, &result); if (!s.ok() && !s.IsNotFound()) { - conn_->Reply(redis::Error("ERR " + s.ToString())); + conn_->Reply(redis::Error({Status::NotOK, s.ToString()})); return; } @@ -1266,7 +1510,7 @@ class CommandXReadGroup : public Commander, auto s = stream_db.RangeWithPending(streams_[i], options, &result, group_name_, consumer_name_, noack_, latest_marks_[i]); if (!s.ok() && !s.IsNotFound()) { - conn_->Reply(redis::Error("ERR " + s.ToString())); + conn_->Reply(redis::Error({Status::NotOK, s.ToString()})); return; } @@ -1442,9 +1686,7 @@ class CommandXSetId : public Commander { stream_name_ = args[1]; auto s = redis::ParseStreamEntryID(args[2], &last_id_); - if (!s.IsOK()) { - return {Status::RedisParseErr, s.Msg()}; - } + if (!s.IsOK()) return s; if (args.size() == 3) { return Status::OK(); @@ -1462,9 +1704,7 @@ class CommandXSetId : public Commander { } else if (util::EqualICase(args[i], "maxdeletedid") && i + 1 < args.size()) { StreamEntryID id; s = redis::ParseStreamEntryID(args[i + 1], &id); - if (!s.IsOK()) { - return {Status::RedisParseErr, s.Msg()}; - } + if (!s.IsOK()) return s; max_deleted_id_ = std::make_optional(id.ms, id.seq); i += 2; @@ -1496,16 +1736,20 @@ class CommandXSetId : public Commander { std::optional entries_added_; }; -REDIS_REGISTER_COMMANDS(MakeCmdAttr("xadd", -5, "write", 1, 1, 1), - MakeCmdAttr("xdel", -3, "write no-dbsize-check", 1, 1, 1), - MakeCmdAttr("xgroup", -4, "write", 2, 2, 1), - MakeCmdAttr("xlen", -2, "read-only", 1, 1, 1), - MakeCmdAttr("xinfo", -2, "read-only", 0, 0, 0), - MakeCmdAttr("xrange", -4, "read-only", 1, 1, 1), - MakeCmdAttr("xrevrange", -2, "read-only", 1, 1, 1), - MakeCmdAttr("xread", -4, "read-only", 0, 0, 0), - MakeCmdAttr("xreadgroup", -7, "write", 0, 0, 0), - MakeCmdAttr("xtrim", -4, "write no-dbsize-check", 1, 1, 1), - 
MakeCmdAttr("xsetid", -3, "write", 1, 1, 1)) +REDIS_REGISTER_COMMANDS( + // MakeCmdAttr("xack", -4, "write no-dbsize-check", 1, 1, 1), + MakeCmdAttr("xadd", -5, "write", 1, 1, 1), + MakeCmdAttr("xdel", -3, "write no-dbsize-check", 1, 1, 1), + // MakeCmdAttr("xclaim", -6, "write", 1, 1, 1), + // MakeCmdAttr("xautoclaim", -6, "write", 1, 1, 1), + // MakeCmdAttr("xgroup", -4, "write", 2, 2, 1), + MakeCmdAttr("xlen", -2, "read-only", 1, 1, 1), + MakeCmdAttr("xinfo", -2, "read-only", 0, 0, 0), + MakeCmdAttr("xrange", -4, "read-only", 1, 1, 1), + MakeCmdAttr("xrevrange", -2, "read-only", 1, 1, 1), + MakeCmdAttr("xread", -4, "read-only", 0, 0, 0), + // MakeCmdAttr("xreadgroup", -7, "write", 0, 0, 0), + MakeCmdAttr("xtrim", -4, "write no-dbsize-check", 1, 1, 1), + MakeCmdAttr("xsetid", -3, "write", 1, 1, 1)) } // namespace redis diff --git a/src/commands/cmd_string.cc b/src/commands/cmd_string.cc index ad5a6bf5b1c..2cfa2d86e8f 100644 --- a/src/commands/cmd_string.cc +++ b/src/commands/cmd_string.cc @@ -66,10 +66,10 @@ class CommandGetEx : public Commander { CommandParser parser(args, 2); std::string_view ttl_flag; while (parser.Good()) { - if (auto v = GET_OR_RET(ParseTTL(parser, ttl_flag))) { - ttl_ = *v; + if (auto v = GET_OR_RET(ParseExpireFlags(parser, ttl_flag))) { + expire_ = *v; } else if (parser.EatEqICaseFlag("PERSIST", ttl_flag)) { - persist_ = true; + expire_ = 0; } else { return parser.InvalidSyntax(); } @@ -80,7 +80,7 @@ class CommandGetEx : public Commander { Status Execute(Server *srv, Connection *conn, std::string *output) override { std::string value; redis::String string_db(srv->storage, conn->GetNamespace()); - auto s = string_db.GetEx(args_[1], &value, ttl_, persist_); + auto s = string_db.GetEx(args_[1], &value, expire_); // The IsInvalidArgument error means the key type maybe a bitmap // which we need to fall back to the bitmap's GetString according @@ -90,12 +90,8 @@ class CommandGetEx : public Commander { uint32_t max_btos_size = static_cast(config->max_bitmap_to_string_mb) * MiB; redis::Bitmap bitmap_db(srv->storage, conn->GetNamespace()); s = bitmap_db.GetString(args_[1], max_btos_size, &value); - if (s.ok()) { - if (ttl_ > 0) { - s = bitmap_db.Expire(args_[1], ttl_ + util::GetTimeStampMS()); - } else if (persist_) { - s = bitmap_db.Expire(args_[1], 0); - } + if (s.ok() && expire_) { + s = bitmap_db.Expire(args_[1], expire_.value()); } } if (!s.ok() && !s.IsNotFound()) { @@ -107,8 +103,7 @@ class CommandGetEx : public Commander { } private: - uint64_t ttl_ = 0; - bool persist_ = false; + std::optional expire_; }; class CommandStrlen : public Commander { @@ -282,8 +277,8 @@ class CommandSet : public Commander { CommandParser parser(args, 3); std::string_view ttl_flag, set_flag; while (parser.Good()) { - if (auto v = GET_OR_RET(ParseTTL(parser, ttl_flag))) { - ttl_ = *v; + if (auto v = GET_OR_RET(ParseExpireFlags(parser, ttl_flag))) { + expire_ = *v; } else if (parser.EatEqICaseFlag("KEEPTTL", ttl_flag)) { keep_ttl_ = true; } else if (parser.EatEqICaseFlag("NX", set_flag)) { @@ -304,17 +299,7 @@ class CommandSet : public Commander { std::optional ret; redis::String string_db(srv->storage, conn->GetNamespace()); - if (ttl_ < 0) { - auto s = string_db.Del(args_[1]); - if (!s.ok()) { - return {Status::RedisExecErr, s.ToString()}; - } - *output = redis::SimpleString("OK"); - return Status::OK(); - } - - rocksdb::Status s; - s = string_db.Set(args_[1], args_[2], {ttl_, set_flag_, get_, keep_ttl_}, ret); + rocksdb::Status s = string_db.Set(args_[1], args_[2], {expire_, set_flag_, 
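+ // Note: with ParseExpireFlags, expire_ now carries an absolute expire
+ // timestamp in milliseconds (0 meaning none) instead of a relative TTL,
+ // which is why the old negative-TTL delete branch above was dropped.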
get_, keep_ttl_}, ret); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; @@ -337,7 +322,7 @@ class CommandSet : public Commander { } private: - uint64_t ttl_ = 0; + uint64_t expire_ = 0; bool get_ = false; bool keep_ttl_ = false; StringSetType set_flag_ = StringSetType::NONE; @@ -353,20 +338,20 @@ class CommandSetEX : public Commander { if (*parse_result <= 0) return {Status::RedisParseErr, errInvalidExpireTime}; - ttl_ = *parse_result; + expire_ = *parse_result * 1000 + util::GetTimeStampMS(); return Commander::Parse(args); } Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::String string_db(srv->storage, conn->GetNamespace()); - auto s = string_db.SetEX(args_[1], args_[3], ttl_ * 1000); + auto s = string_db.SetEX(args_[1], args_[3], expire_); *output = redis::SimpleString("OK"); return Status::OK(); } private: - uint64_t ttl_ = 0; + uint64_t expire_ = 0; }; class CommandPSetEX : public Commander { @@ -379,20 +364,20 @@ class CommandPSetEX : public Commander { if (*ttl_ms <= 0) return {Status::RedisParseErr, errInvalidExpireTime}; - ttl_ = *ttl_ms; + expire_ = *ttl_ms + util::GetTimeStampMS(); return Commander::Parse(args); } Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::String string_db(srv->storage, conn->GetNamespace()); - auto s = string_db.SetEX(args_[1], args_[3], ttl_); + auto s = string_db.SetEX(args_[1], args_[3], expire_); *output = redis::SimpleString("OK"); return Status::OK(); } private: - int64_t ttl_ = 0; + uint64_t expire_ = 0; }; class CommandMSet : public Commander { @@ -412,7 +397,7 @@ class CommandMSet : public Commander { kvs.emplace_back(StringPair{args_[i], args_[i + 1]}); } - auto s = string_db.MSet(kvs); + auto s = string_db.MSet(kvs, 0); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -581,8 +566,8 @@ class CommandCAS : public Commander { CommandParser parser(args, 4); std::string_view flag; while (parser.Good()) { - if (auto v = GET_OR_RET(ParseTTL(parser, flag))) { - ttl_ = *v; + if (auto v = GET_OR_RET(ParseExpireFlags(parser, flag))) { + expire_ = *v; } else { return parser.InvalidSyntax(); } @@ -593,7 +578,7 @@ class CommandCAS : public Commander { Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::String string_db(srv->storage, conn->GetNamespace()); int ret = 0; - auto s = string_db.CAS(args_[1], args_[2], args_[3], ttl_, &ret); + auto s = string_db.CAS(args_[1], args_[2], args_[3], expire_, &ret); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -603,7 +588,7 @@ class CommandCAS : public Commander { } private: - uint64_t ttl_ = 0; + uint64_t expire_ = 0; }; class CommandCAD : public Commander { diff --git a/src/commands/cmd_txn.cc b/src/commands/cmd_txn.cc index fa1a47aadae..130533fbc9b 100644 --- a/src/commands/cmd_txn.cc +++ b/src/commands/cmd_txn.cc @@ -68,8 +68,7 @@ class CommandExec : public Commander { auto reset_multiexec = MakeScopeExit([conn] { conn->ResetMultiExec(); }); if (conn->IsMultiError()) { - *output = redis::Error("EXECABORT Transaction discarded"); - return Status::OK(); + return {Status::RedisExecAbort, "Transaction discarded"}; } if (srv->IsWatchedKeysModified(conn)) { diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc index c32c976a431..715747ab8a4 100644 --- a/src/commands/cmd_zset.cc +++ b/src/commands/cmd_zset.cc @@ -136,9 +136,7 @@ class CommandZCount : public Commander { public: Status Parse(const std::vector &args) override { Status s = 
ParseRangeScoreSpec(args[2], args[3], &spec_); - if (!s.IsOK()) { - return {Status::RedisParseErr, s.Msg()}; - } + if (!s.IsOK()) return s; return Commander::Parse(args); } @@ -204,9 +202,7 @@ class CommandZLexCount : public Commander { public: Status Parse(const std::vector &args) override { Status s = ParseRangeLexSpec(args[2], args[3], &spec_); - if (!s.IsOK()) { - return {Status::RedisParseErr, s.Msg()}; - } + if (!s.IsOK()) return s; return Commander::Parse(args); } @@ -369,7 +365,7 @@ class CommandBZPop : public BlockingCommander { redis::ZSet zset_db(srv_->storage, conn_->GetNamespace()); auto s = PopFromMultipleZsets(&zset_db, keys_, min_, 1, &user_key, &member_scores); if (!s.ok()) { - conn_->Reply(redis::Error("ERR " + s.ToString())); + conn_->Reply(redis::Error({Status::NotOK, s.ToString()})); return true; } @@ -548,7 +544,7 @@ class CommandBZMPop : public BlockingCommander { redis::ZSet zset_db(srv_->storage, conn_->GetNamespace()); auto s = PopFromMultipleZsets(&zset_db, keys_, flag_ == ZSET_MIN, count_, &user_key, &member_scores); if (!s.ok()) { - conn_->Reply(redis::Error("ERR " + s.ToString())); + conn_->Reply(redis::Error({Status::NotOK, s.ToString()})); return true; } @@ -985,9 +981,7 @@ class CommandZRemRangeByScore : public Commander { public: Status Parse(const std::vector &args) override { Status s = ParseRangeScoreSpec(args[2], args[3], &spec_); - if (!s.IsOK()) { - return {Status::RedisParseErr, s.Msg()}; - } + if (!s.IsOK()) return s; return Commander::Parse(args); } @@ -1014,9 +1008,7 @@ class CommandZRemRangeByLex : public Commander { public: Status Parse(const std::vector &args) override { Status s = ParseRangeLexSpec(args[2], args[3], &spec_); - if (!s.IsOK()) { - return {Status::RedisParseErr, s.Msg()}; - } + if (!s.IsOK()) return s; return Commander::Parse(args); } diff --git a/src/commands/commander.h b/src/commands/commander.h index d759bd2047e..1441581d36a 100644 --- a/src/commands/commander.h +++ b/src/commands/commander.h @@ -106,13 +106,20 @@ struct CommandKeyRange { // step length of key position // e.g. key step 2 means "key other key other ..." sequence int key_step; + + template + void ForEachKey(F &&f, const std::vector &args) const { + for (size_t i = first_key; last_key > 0 ? 
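+ // A negative last_key counts back from the end of the argument list: e.g.
+ // EXISTS registers the range (1, -1, 1), so for `EXISTS k1 k2` the loop
+ // visits args[1] and args[2] via the args.size() + last_key bound.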
i <= size_t(last_key) : i <= args.size() + last_key; i += key_step) { + std::forward(f)(args[i]); + } + } }; using CommandKeyRangeGen = std::function &)>; using CommandKeyRangeVecGen = std::function(const std::vector &)>; -using AdditionalFlagGen = std::function &)>; +using AdditionalFlagGen = std::function &)>; struct CommandAttributes { // command name @@ -146,13 +153,34 @@ struct CommandAttributes { auto GenerateFlags(const std::vector &args) const { uint64_t res = flags; - if (flag_gen) res |= flag_gen(args); + if (flag_gen) res = flag_gen(res, args); return res; } bool CheckArity(int cmd_size) const { return !((arity > 0 && cmd_size != arity) || (arity < 0 && cmd_size < -arity)); } + + template + void ForEachKeyRange(F &&f, const std::vector &args) const { + if (key_range.first_key > 0) { + std::forward(f)(args, key_range); + } else if (key_range.first_key == -1) { + redis::CommandKeyRange range = key_range_gen(args); + + if (range.first_key > 0) { + std::forward(f)(args, range); + } + } else if (key_range.first_key == -2) { + std::vector vec_range = key_range_vec_gen(args); + + for (const auto &range : vec_range) { + if (range.first_key > 0) { + std::forward(f)(args, range); + } + } + } + } }; using CommandMap = std::map; diff --git a/src/commands/error_constants.h b/src/commands/error_constants.h index 43c7440da09..ea2c38b72c0 100644 --- a/src/commands/error_constants.h +++ b/src/commands/error_constants.h @@ -40,8 +40,10 @@ inline constexpr const char *errLimitOptionNotAllowed = inline constexpr const char *errZSetLTGTNX = "GT, LT, and/or NX options at the same time are not compatible"; inline constexpr const char *errScoreIsNotValidFloat = "score is not a valid float"; inline constexpr const char *errValueIsNotFloat = "value is not a valid float"; -inline constexpr const char *errNoMatchingScript = "NOSCRIPT No matching script. Please use EVAL"; +inline constexpr const char *errNoMatchingScript = "No matching script. 
Please use EVAL"; inline constexpr const char *errUnknownOption = "unknown option"; inline constexpr const char *errUnknownSubcommandOrWrongArguments = "Unknown subcommand or wrong number of arguments"; +inline constexpr const char *errWrongNumArguments = "wrong number of arguments"; +inline constexpr const char *errRestoringBackup = "kvrocks is restoring the db from backup"; } // namespace redis diff --git a/src/commands/scan_base.h b/src/commands/scan_base.h index 2e11c989bd4..3a6438c2f09 100644 --- a/src/commands/scan_base.h +++ b/src/commands/scan_base.h @@ -21,6 +21,7 @@ #pragma once #include "commander.h" +#include "commands/command_parser.h" #include "error_constants.h" #include "parse_util.h" #include "server/server.h" @@ -31,31 +32,46 @@ inline constexpr const char *kCursorPrefix = "_"; class CommandScanBase : public Commander { public: - Status ParseMatchAndCountParam(const std::string &type, std::string value) { - if (type == "match") { - prefix_ = std::move(value); - if (!prefix_.empty() && prefix_[prefix_.size() - 1] == '*') { - prefix_ = prefix_.substr(0, prefix_.size() - 1); - return Status::OK(); - } + Status Parse(const std::vector &args) override { + CommandParser parser(args, 1); - return {Status::RedisParseErr, "only keys prefix match was supported"}; - } else if (type == "count") { - auto parse_result = ParseInt(value, 10); - if (!parse_result) { - return {Status::RedisParseErr, "count param should be type int"}; - } + PutCursor(GET_OR_RET(parser.TakeStr())); - limit_ = *parse_result; - if (limit_ <= 0) { - return {Status::RedisParseErr, errInvalidSyntax}; + return ParseAdditionalFlags(parser); + } + + template + Status ParseAdditionalFlags(Parser &parser) { + while (parser.Good()) { + if (parser.EatEqICase("match")) { + prefix_ = GET_OR_RET(parser.TakeStr()); + if (!prefix_.empty() && prefix_.back() == '*') { + prefix_ = prefix_.substr(0, prefix_.size() - 1); + } else { + return {Status::RedisParseErr, "currently only key prefix matching is supported"}; + } + } else if (parser.EatEqICase("count")) { + limit_ = GET_OR_RET(parser.TakeInt()); + if (limit_ <= 0) { + return {Status::RedisParseErr, "limit should be a positive integer"}; + } + } else if (IsScan && parser.EatEqICase("type")) { + std::string type_str = GET_OR_RET(parser.TakeStr()); + if (auto iter = std::find(RedisTypeNames.begin(), RedisTypeNames.end(), type_str); + iter != RedisTypeNames.end()) { + type_ = static_cast(iter - RedisTypeNames.begin()); + } else { + return {Status::RedisExecErr, "Invalid type"}; + } + } else { + return parser.InvalidSyntax(); } } return Status::OK(); } - void ParseCursor(const std::string ¶m) { + void PutCursor(const std::string ¶m) { cursor_ = param; if (cursor_ == "0") { cursor_ = std::string(); @@ -83,6 +99,7 @@ class CommandScanBase : public Commander { std::string cursor_; std::string prefix_; int limit_ = 20; + RedisType type_ = kRedisNone; }; class CommandSubkeyScanBase : public CommandScanBase { @@ -90,26 +107,13 @@ class CommandSubkeyScanBase : public CommandScanBase { CommandSubkeyScanBase() : CommandScanBase() {} Status Parse(const std::vector &args) override { - if (args.size() % 2 == 0) { - return {Status::RedisParseErr, errWrongNumOfArguments}; - } + CommandParser parser(args, 1); - key_ = args[1]; - ParseCursor(args[2]); - if (args.size() >= 5) { - Status s = ParseMatchAndCountParam(util::ToLower(args[3]), args_[4]); - if (!s.IsOK()) { - return s; - } - } + key_ = GET_OR_RET(parser.TakeStr()); - if (args.size() >= 7) { - Status s = 
diff --git a/src/commands/ttl_util.h b/src/commands/ttl_util.h
index 4885604b937..667cebf3554 100644
--- a/src/commands/ttl_util.h
+++ b/src/commands/ttl_util.h
@@ -31,15 +31,15 @@ template <typename T>
 constexpr auto TTL_RANGE = NumericRange<T>{1, std::numeric_limits<T>::max()};
 
 template <typename T>
-StatusOr<std::optional<int64_t>> ParseTTL(CommandParser<T> &parser, std::string_view &curr_flag) {
+StatusOr<std::optional<int64_t>> ParseExpireFlags(CommandParser<T> &parser, std::string_view &curr_flag) {
   if (parser.EatEqICaseFlag("EX", curr_flag)) {
-    return GET_OR_RET(parser.template TakeInt<int64_t>(TTL_RANGE<int64_t>)) * 1000;
+    return GET_OR_RET(parser.template TakeInt<int64_t>(TTL_RANGE<int64_t>)) * 1000 + util::GetTimeStampMS();
   } else if (parser.EatEqICaseFlag("EXAT", curr_flag)) {
-    return GET_OR_RET(parser.template TakeInt<int64_t>(TTL_RANGE<int64_t>)) * 1000 - util::GetTimeStampMS();
+    return GET_OR_RET(parser.template TakeInt<int64_t>(TTL_RANGE<int64_t>)) * 1000;
   } else if (parser.EatEqICaseFlag("PX", curr_flag)) {
-    return GET_OR_RET(parser.template TakeInt<int64_t>(TTL_RANGE<int64_t>));
+    return GET_OR_RET(parser.template TakeInt<int64_t>(TTL_RANGE<int64_t>)) + util::GetTimeStampMS();
   } else if (parser.EatEqICaseFlag("PXAT", curr_flag)) {
-    return GET_OR_RET(parser.template TakeInt<int64_t>(TTL_RANGE<int64_t>)) - util::GetTimeStampMS();
+    return GET_OR_RET(parser.template TakeInt<int64_t>(TTL_RANGE<int64_t>));
   } else {
     return std::nullopt;
   }
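The reversal of the four branches above is the point of this change: `ParseExpireFlags` now normalizes every flag to an absolute expire timestamp in milliseconds, where the old `ParseTTL` normalized to a TTL relative to now. A small sketch of the conversions, using `std::chrono` in place of `util::GetTimeStampMS` (function names here are illustrative):

```c++
#include <chrono>
#include <cstdint>
#include <iostream>

int64_t NowMS() {
  using namespace std::chrono;
  return duration_cast<milliseconds>(system_clock::now().time_since_epoch()).count();
}

// All four flags land on an absolute expire time in milliseconds.
int64_t FromEX(int64_t sec) { return sec * 1000 + NowMS(); }  // relative seconds
int64_t FromEXAT(int64_t ts_sec) { return ts_sec * 1000; }    // absolute seconds
int64_t FromPX(int64_t ms) { return ms + NowMS(); }           // relative milliseconds
int64_t FromPXAT(int64_t ts_ms) { return ts_ms; }             // absolute milliseconds

int main() {
  // "SET key val EX 10" and "SET key val PX 10000" should land on (roughly)
  // the same absolute expire time.
  std::cout << FromEX(10) - FromPX(10000) << "\n";  // ~0
}
```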
"*" : std::to_string(n); }; - return param2string(minute) + " " + param2string(hour) + " " + param2string(mday) + " " + param2string(month) + " " + - param2string(wday); +std::string CronScheduler::ToString() const { + return fmt::format("{} {} {} {} {}", minute.ToString(), hour.ToString(), mday.ToString(), month.ToString(), + wday.ToString()); } +bool CronScheduler::IsMatch(const tm *tm) const { + bool minute_match = minute.IsMatch(tm->tm_min); + bool hour_match = hour.IsMatch(tm->tm_hour); + bool mday_match = mday.IsMatch(tm->tm_mday, 1); + bool month_match = month.IsMatch(tm->tm_mon + 1, 1); + bool wday_match = wday.IsMatch(tm->tm_wday); + + return minute_match && hour_match && mday_match && month_match && wday_match; +} + +StatusOr CronScheduler::Parse(std::string_view minute, std::string_view hour, std::string_view mday, + std::string_view month, std::string_view wday) { + CronScheduler st; + + st.minute = GET_OR_RET(CronPattern::Parse(minute, {0, 59})); + st.hour = GET_OR_RET(CronPattern::Parse(hour, {0, 23})); + st.mday = GET_OR_RET(CronPattern::Parse(mday, {1, 31})); + st.month = GET_OR_RET(CronPattern::Parse(month, {1, 12})); + st.wday = GET_OR_RET(CronPattern::Parse(wday, {0, 6})); + + return st; +} + +void Cron::Clear() { schedulers_.clear(); } + Status Cron::SetScheduleTime(const std::vector &args) { if (args.empty()) { schedulers_.clear(); return Status::OK(); } if (args.size() % 5 != 0) { - return {Status::NotOK, "time expression format error,should only contain 5x fields"}; + return {Status::NotOK, "cron expression format error, should only contain 5x fields"}; } - std::vector new_schedulers; + std::vector new_schedulers; for (size_t i = 0; i < args.size(); i += 5) { - auto s = convertToScheduleTime(args[i], args[i + 1], args[i + 2], args[i + 3], args[i + 4]); + auto s = CronScheduler::Parse(args[i], args[i + 1], args[i + 2], args[i + 3], args[i + 4]); if (!s.IsOK()) { - return std::move(s).Prefixed("time expression format error"); + return std::move(s).Prefixed("cron expression format error"); } new_schedulers.push_back(*s); } @@ -52,15 +78,14 @@ Status Cron::SetScheduleTime(const std::vector &args) { return Status::OK(); } -bool Cron::IsTimeMatch(tm *tm) { +bool Cron::IsTimeMatch(const tm *tm) { if (tm->tm_min == last_tm_.tm_min && tm->tm_hour == last_tm_.tm_hour && tm->tm_mday == last_tm_.tm_mday && tm->tm_mon == last_tm_.tm_mon && tm->tm_wday == last_tm_.tm_wday) { return false; } + for (const auto &st : schedulers_) { - if ((st.minute == -1 || tm->tm_min == st.minute) && (st.hour == -1 || tm->tm_hour == st.hour) && - (st.mday == -1 || tm->tm_mday == st.mday) && (st.month == -1 || (tm->tm_mon + 1) == st.month) && - (st.wday == -1 || tm->tm_wday == st.wday)) { + if (st.IsMatch(tm)) { last_tm_ = *tm; return true; } @@ -78,30 +103,3 @@ std::string Cron::ToString() const { } return ret; } - -StatusOr Cron::convertToScheduleTime(const std::string &minute, const std::string &hour, - const std::string &mday, const std::string &month, - const std::string &wday) { - Scheduler st; - - st.minute = GET_OR_RET(convertParam(minute, 0, 59)); - st.hour = GET_OR_RET(convertParam(hour, 0, 23)); - st.mday = GET_OR_RET(convertParam(mday, 1, 31)); - st.month = GET_OR_RET(convertParam(month, 1, 12)); - st.wday = GET_OR_RET(convertParam(wday, 0, 6)); - - return st; -} - -StatusOr Cron::convertParam(const std::string ¶m, int lower_bound, int upper_bound) { - if (param == "*") { - return -1; - } - - auto s = ParseInt(param, {lower_bound, upper_bound}, 10); - if (!s) { - return 
diff --git a/src/common/cron.h b/src/common/cron.h
index cba6d275af3..ede6231beab 100644
--- a/src/common/cron.h
+++ b/src/common/cron.h
@@ -23,18 +23,135 @@
 #include 
 #include 
 #include 
+#include <variant>
 #include 
 
+#include "parse_util.h"
 #include "status.h"
+#include "string_util.h"
 
-struct Scheduler {
-  int minute;
-  int hour;
-  int mday;
-  int month;
-  int wday;
+struct CronPattern {
+  using Number = int;
+  using Range = std::pair<Number, Number>;
+
+  struct Interval {
+    int interval;
+  };  // */n
+  struct Any {};  // *
+  using Numbers = std::vector<std::variant<Number, Range>>;  // 1,2,3-6,7
+
+  std::variant<Numbers, Interval, Any> val;
+
+  static StatusOr<CronPattern> Parse(std::string_view str, std::tuple<int, int> minmax) {
+    if (str == "*") {
+      return CronPattern{Any{}};
+    } else if (str.rfind("*/", 0) == 0) {
+      auto num_str = str.substr(2);
+      auto interval = GET_OR_RET(ParseInt<int>(std::string(num_str.begin(), num_str.end()), minmax)
+                                     .Prefixed("an integer is expected after `*/` in a cron expression"));
+
+      if (interval == 0) {
+        return {Status::NotOK, "interval value after `*/` cannot be zero"};
+      }
+
+      return CronPattern{Interval{interval}};
+    } else {
+      auto num_strs = util::Split(str, ",");
+
+      Numbers results;
+      for (const auto &num_str : num_strs) {
+        if (auto pos = num_str.find('-'); pos != num_str.npos) {
+          auto l_str = num_str.substr(0, pos);
+          auto r_str = num_str.substr(pos + 1);
+          auto l = GET_OR_RET(
+              ParseInt<int>(l_str, minmax).Prefixed("an integer is expected before `-` in a cron expression"));
+          auto r = GET_OR_RET(
+              ParseInt<int>(r_str, minmax).Prefixed("an integer is expected after `-` in a cron expression"));
+
+          if (l >= r) {
+            return {Status::NotOK, "for pattern `l-r` in cron expression, r should be larger than l"};
+          }
+          results.push_back(Range(l, r));
+        } else {
+          auto n = GET_OR_RET(ParseInt<int>(std::string(num_str.begin(), num_str.end()), minmax)
+                                  .Prefixed("an integer is expected in a cron expression"));
+          results.push_back(n);
+        }
+      }
+
+      if (results.empty()) {
+        return {Status::NotOK, "invalid cron expression"};
+      }
+
+      return CronPattern{results};
+    }
+  }
+
+  std::string ToString() const {
+    if (std::holds_alternative<Numbers>(val)) {
+      std::string result;
+      bool first = true;
+
+      for (const auto &v : std::get<Numbers>(val)) {
+        if (first)
+          first = false;
+        else
+          result += ",";
+
+        if (std::holds_alternative<Number>(v)) {
+          result += std::to_string(std::get<Number>(v));
+        } else {
+          auto range = std::get<Range>(v);
+          result += std::to_string(range.first) + "-" + std::to_string(range.second);
+        }
+      }
+
+      return result;
+    } else if (std::holds_alternative<Interval>(val)) {
+      return "*/" + std::to_string(std::get<Interval>(val).interval);
+    } else if (std::holds_alternative<Any>(val)) {
+      return "*";
+    }
+
+    __builtin_unreachable();
+  }
+
+  bool IsMatch(int input, int interval_offset = 0) const {
+    if (std::holds_alternative<Numbers>(val)) {
+      bool result = false;
+      for (const auto &v : std::get<Numbers>(val)) {
+        if (std::holds_alternative<Number>(v)) {
+          result = result || input == std::get<Number>(v);
+        } else {
+          auto range = std::get<Range>(v);
+          result = result || (range.first <= input && input <= range.second);
+        }
+      }
+
+      return result;
+    } else if (std::holds_alternative<Interval>(val)) {
+      return (input - interval_offset) % std::get<Interval>(val).interval == 0;
+    } else if (std::holds_alternative<Any>(val)) {
+      return true;
+    }
+
+    __builtin_unreachable();
+  }
+};
+
+struct CronScheduler {
+  CronPattern minute;
+  CronPattern hour;
+  CronPattern mday;
+  CronPattern month;
+  CronPattern wday;
 
   std::string ToString() const;
+
+  static StatusOr<CronScheduler> Parse(std::string_view minute, std::string_view hour, std::string_view mday,
+                                       std::string_view month, std::string_view wday);
+
+  bool IsMatch(const tm *tm) const;
 };
 
 class Cron {
@@ -43,16 +160,13 @@ class Cron {
   ~Cron() = default;
 
   Status SetScheduleTime(const std::vector<std::string> &args);
-  bool IsTimeMatch(tm *tm);
+  void Clear();
+
+  bool IsTimeMatch(const tm *tm);
   std::string ToString() const;
   bool IsEnabled() const;
 
  private:
-  std::vector<Scheduler> schedulers_;
+  std::vector<CronScheduler> schedulers_;
   tm last_tm_ = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, nullptr};
-
-  static StatusOr<Scheduler> convertToScheduleTime(const std::string &minute, const std::string &hour,
-                                                   const std::string &mday, const std::string &month,
-                                                   const std::string &wday);
-  static StatusOr<int> convertParam(const std::string &param, int lower_bound, int upper_bound);
 };
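The `interval_offset` parameter in `IsMatch` exists because day-of-month and month are 1-based while minutes, hours, and weekdays are 0-based; `CronScheduler::IsMatch` passes an offset of 1 for the former so that `*/n` intervals start counting at the field's minimum. A worked example of just that arithmetic (the helper name is illustrative):

```c++
#include <iostream>

// The "*/n" case of CronPattern::IsMatch, isolated.
bool IntervalMatch(int input, int interval, int interval_offset = 0) {
  return (input - interval_offset) % interval == 0;
}

int main() {
  // "*/4" on the hour field (0-based): matches hours 0, 4, 8, ...
  std::cout << IntervalMatch(8, 4) << "\n";     // 1
  // "*/4" on the month field (1-based): matches months 1, 5, 9.
  std::cout << IntervalMatch(5, 4, 1) << "\n";  // 1
  std::cout << IntervalMatch(4, 4, 1) << "\n";  // 0
}
```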
diff --git a/src/common/db_util.h b/src/common/db_util.h
index 8df34daaa50..d29262f2b33 100644
--- a/src/common/db_util.h
+++ b/src/common/db_util.h
@@ -37,6 +37,8 @@ struct UniqueIterator : std::unique_ptr<rocksdb::Iterator> {
   UniqueIterator(engine::Storage* storage, const rocksdb::ReadOptions& options,
                  rocksdb::ColumnFamilyHandle* column_family)
       : BaseType(storage->NewIterator(options, column_family)) {}
+  UniqueIterator(engine::Storage* storage, const rocksdb::ReadOptions& options, ColumnFamilyID cf)
+      : BaseType(storage->NewIterator(options, storage->GetCFHandle(cf))) {}
   UniqueIterator(engine::Storage* storage, const rocksdb::ReadOptions& options)
       : BaseType(storage->NewIterator(options)) {}
 };
diff --git a/src/common/encoding.h b/src/common/encoding.h
index d497793068a..5d05e45614a 100644
--- a/src/common/encoding.h
+++ b/src/common/encoding.h
@@ -102,6 +102,21 @@ inline bool GetFixed16(rocksdb::Slice *input, uint16_t *value) { return GetFixed(input, value); }
 inline bool GetFixed32(rocksdb::Slice *input, uint32_t *value) { return GetFixed(input, value); }
 inline bool GetFixed64(rocksdb::Slice *input, uint64_t *value) { return GetFixed(input, value); }
 
+inline void PutSizedString(std::string *dst, rocksdb::Slice value) {
+  PutFixed32(dst, value.size());
+  dst->append(value.ToStringView());
+}
+
+inline bool GetSizedString(rocksdb::Slice *input, rocksdb::Slice *value) {
+  uint32_t size = 0;
+  if (!GetFixed32(input, &size)) return false;
+
+  if (input->size() < size) return false;
+  *value = rocksdb::Slice(input->data(), size);
+  input->remove_prefix(size);
+  return true;
+}
+
 char *EncodeDouble(char *buf, double value);
 void PutDouble(std::string *dst, double value);
 double DecodeDouble(const char *ptr);
diff --git a/src/common/event_util.h b/src/common/event_util.h
index fccdfd5336b..f3c83012430 100644
--- a/src/common/event_util.h
+++ b/src/common/event_util.h
@@ -22,6 +22,7 @@
 #include 
 #include 
+#include <string_view>
 #include 
 
 #include "event2/buffer.h"
@@ -44,6 +45,8 @@ struct UniqueEvbufReadln : UniqueFreePtr {
       : UniqueFreePtr(evbuffer_readln(buffer, &length, eol_style)) {}
 
   size_t length;
+
+  std::string_view View() { return {get(), length}; }
 };
 
 using StaticEvbufFree = StaticFunction;
diff --git a/src/common/rdb_stream.cc b/src/common/rdb_stream.cc
index 15896eebd90..7be1088be80 100644
--- a/src/common/rdb_stream.cc
+++ b/src/common/rdb_stream.cc
@@ -32,6 +32,11 @@ Status RdbStringStream::Read(char *buf, size_t n) {
   return Status::OK();
 }
 
+Status RdbStringStream::Write(const char *buf, size_t len) {
+  input_.append(buf, len);
+  return Status::OK();
+}
+
 StatusOr<uint64_t> RdbStringStream::GetCheckSum() const {
   if (input_.size() < 8) {
     return {Status::NotOK, "invalid payload length"};
diff --git a/src/common/rdb_stream.h b/src/common/rdb_stream.h
index c67a1b468dd..824808ae052 100644
--- 
a/src/common/rdb_stream.h +++ b/src/common/rdb_stream.h @@ -33,6 +33,7 @@ class RdbStream { virtual ~RdbStream() = default; virtual Status Read(char *buf, size_t len) = 0; + virtual Status Write(const char *buf, size_t len) = 0; virtual StatusOr GetCheckSum() const = 0; StatusOr ReadByte() { uint8_t value = 0; @@ -52,7 +53,9 @@ class RdbStringStream : public RdbStream { ~RdbStringStream() override = default; Status Read(char *buf, size_t len) override; + Status Write(const char *buf, size_t len) override; StatusOr GetCheckSum() const override; + std::string &GetInput() { return input_; } private: std::string input_; @@ -69,6 +72,7 @@ class RdbFileStream : public RdbStream { Status Open(); Status Read(char *buf, size_t len) override; + Status Write(const char *buf, size_t len) override { return {Status::NotOK, fmt::format("No implement")}; }; StatusOr GetCheckSum() const override { uint64_t crc = check_sum_; memrev64ifbe(&crc); diff --git a/src/common/status.h b/src/common/status.h index b425ea5b238..d06c2f06014 100644 --- a/src/common/status.h +++ b/src/common/status.h @@ -50,9 +50,21 @@ class [[nodiscard]] Status { RedisInvalidCmd, RedisParseErr, RedisExecErr, + RedisErrorNoPrefix, + RedisNoProto, + RedisLoading, + RedisMasterDown, + RedisNoScript, + RedisNoAuth, + RedisWrongType, + RedisReadOnly, + RedisExecAbort, + RedisMoved, + RedisCrossSlot, + RedisTryAgain, + RedisClusterDown, // Cluster - ClusterDown, ClusterInvalidInfo, // Blocking diff --git a/src/common/string_util.cc b/src/common/string_util.cc index ae41f918af9..4a4d79a8f07 100644 --- a/src/common/string_util.cc +++ b/src/common/string_util.cc @@ -357,6 +357,12 @@ std::string EscapeString(std::string_view s) { case '\b': str += "\\b"; break; + case '\v': + str += "\\v"; + break; + case '\f': + str += "\\f"; + break; default: if (isprint(ch)) { str += ch; @@ -371,4 +377,14 @@ std::string EscapeString(std::string_view s) { return str; } +std::string StringNext(std::string s) { + for (auto iter = s.rbegin(); iter != s.rend(); ++iter) { + if (*iter != char(0xff)) { + (*iter)++; + break; + } + } + return s; +} + } // namespace util diff --git a/src/common/string_util.h b/src/common/string_util.h index d23ebad7b90..ac3b2904fae 100644 --- a/src/common/string_util.h +++ b/src/common/string_util.h @@ -38,5 +38,22 @@ std::vector RegexMatch(const std::string &str, const std::string &r std::string StringToHex(std::string_view input); std::vector TokenizeRedisProtocol(const std::string &value); std::string EscapeString(std::string_view s); +std::string StringNext(std::string s); + +template +std::string StringJoin( + const T &con, F &&f = [](const auto &v) -> decltype(auto) { return v; }, std::string_view sep = ", ") { + std::string res; + bool is_first = true; + for (const auto &v : con) { + if (is_first) { + is_first = false; + } else { + res += sep; + } + res += std::forward(f)(v); + } + return res; +} } // namespace util diff --git a/src/common/time_util.h b/src/common/time_util.h index 1c8dc7b6272..9eb6daa4266 100644 --- a/src/common/time_util.h +++ b/src/common/time_util.h @@ -24,6 +24,7 @@ namespace util { +/// Get the system timestamp in seconds, milliseconds or microseconds. 
template auto GetTimeStamp() { return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); diff --git a/src/common/type_util.h b/src/common/type_util.h index d55019f77f7..98439243622 100644 --- a/src/common/type_util.h +++ b/src/common/type_util.h @@ -39,3 +39,11 @@ using RemoveCVRef = typename std::remove_cv_t constexpr bool AlwaysFalse = false; + +template +struct GetClassFromMember; + +template +struct GetClassFromMember { + using type = C; // NOLINT +}; diff --git a/src/config/config.cc b/src/config/config.cc index 6e3cff8c8f9..c1afeae3cbb 100644 --- a/src/config/config.cc +++ b/src/config/config.cc @@ -42,6 +42,9 @@ #include "status.h" #include "storage/redis_metadata.h" +constexpr const char *kDefaultDir = "/tmp/kvrocks"; +constexpr const char *kDefaultBackupDir = "/tmp/kvrocks/backup"; +constexpr const char *kDefaultPidfile = "/tmp/kvrocks/kvrocks.pid"; constexpr const char *kDefaultBindAddress = "127.0.0.1"; constexpr const char *errBlobDbNotEnabled = "Must set rocksdb.enable_blob_files to yes first."; @@ -134,18 +137,20 @@ Config::Config() { {"slaveof", true, new StringField(&slaveof_, "")}, {"compact-cron", false, new StringField(&compact_cron_str_, "")}, {"bgsave-cron", false, new StringField(&bgsave_cron_str_, "")}, + {"dbsize-scan-cron", false, new StringField(&dbsize_scan_cron_str_, "")}, {"replica-announce-ip", false, new StringField(&replica_announce_ip, "")}, {"replica-announce-port", false, new UInt32Field(&replica_announce_port, 0, 0, PORT_LIMIT)}, {"compaction-checker-range", false, new StringField(&compaction_checker_range_str_, "")}, + {"compaction-checker-cron", false, new StringField(&compaction_checker_cron_str_, "")}, {"force-compact-file-age", false, new Int64Field(&force_compact_file_age, 2 * 24 * 3600, 60, INT64_MAX)}, {"force-compact-file-min-deleted-percentage", false, new IntField(&force_compact_file_min_deleted_percentage, 10, 1, 100)}, {"db-name", true, new StringField(&db_name, "change.me.db")}, - {"dir", true, new StringField(&dir, "/tmp/kvrocks")}, - {"backup-dir", false, new StringField(&backup_dir_, "")}, + {"dir", true, new StringField(&dir, kDefaultDir)}, + {"backup-dir", false, new StringField(&backup_dir, kDefaultBackupDir)}, {"log-dir", true, new StringField(&log_dir, "")}, {"log-level", false, new EnumField(&log_level, log_levels, google::INFO)}, - {"pidfile", true, new StringField(&pidfile_, "")}, + {"pidfile", true, new StringField(&pidfile, kDefaultPidfile)}, {"max-io-mb", false, new IntField(&max_io_mb, 0, 0, INT_MAX)}, {"max-bitmap-to-string-mb", false, new IntField(&max_bitmap_to_string_mb, 16, 0, INT_MAX)}, {"max-db-size", false, new IntField(&max_db_size, 0, 0, INT_MAX)}, @@ -191,6 +196,7 @@ Config::Config() { {"rocksdb.compression", false, new EnumField(&rocks_db.compression, compression_types, rocksdb::CompressionType::kNoCompression)}, + {"rocksdb.compression_level", true, new IntField(&rocks_db.compression_level, 32767, INT_MIN, INT_MAX)}, {"rocksdb.block_size", true, new IntField(&rocks_db.block_size, 16384, 0, INT_MAX)}, {"rocksdb.max_open_files", false, new IntField(&rocks_db.max_open_files, 8096, -1, INT_MAX)}, {"rocksdb.write_buffer_size", false, new IntField(&rocks_db.write_buffer_size, 64, 0, 4096)}, @@ -247,7 +253,7 @@ Config::Config() { new YesNoField(&rocks_db.write_options.memtable_insert_hint_per_batch, false)}, /* rocksdb read options */ - {"rocksdb.read_options.async_io", false, new YesNoField(&rocks_db.read_options.async_io, false)}, + {"rocksdb.read_options.async_io", false, new 
YesNoField(&rocks_db.read_options.async_io, true)},
  };
  for (auto &wrapper : fields) {
    auto &field = wrapper.field;
@@ -289,23 +295,26 @@ void Config::initFieldValidator() {
         std::vector<std::string> args = util::Split(v, " \t");
         return bgsave_cron.SetScheduleTime(args);
       }},
+      {"dbsize-scan-cron",
+       [this](const std::string &k, const std::string &v) -> Status {
+         std::vector<std::string> args = util::Split(v, " \t");
+         return dbsize_scan_cron.SetScheduleTime(args);
+       }},
       {"compaction-checker-range",
        [this](const std::string &k, const std::string &v) -> Status {
+         if (!compaction_checker_cron_str_.empty()) {
+           return {Status::NotOK, "compaction-checker-range cannot be set while compaction-checker-cron is set"};
+         }
         if (v.empty()) {
-          compaction_checker_range.start = -1;
-          compaction_checker_range.stop = -1;
+          compaction_checker_cron.Clear();
           return Status::OK();
         }
-        std::vector<std::string> args = util::Split(v, "-");
-        if (args.size() != 2) {
-          return {Status::NotOK, "invalid range format, the range should be between 0 and 24"};
-        }
-        auto start = GET_OR_RET(ParseInt<int>(args[0], {0, 24}, 10)),
-             stop = GET_OR_RET(ParseInt<int>(args[1], {0, 24}, 10));
-        if (start > stop) return {Status::NotOK, "invalid range format, start should be smaller than stop"};
-        compaction_checker_range.start = start;
-        compaction_checker_range.stop = stop;
-        return Status::OK();
+        return compaction_checker_cron.SetScheduleTime({"*", v, "*", "*", "*"});
+      }},
+      {"compaction-checker-cron",
+       [this](const std::string &k, const std::string &v) -> Status {
+         std::vector<std::string> args = util::Split(v, " \t");
+         return compaction_checker_cron.SetScheduleTime(args);
       }},
       {"rename-command",
        [](const std::string &k, const std::string &v) -> Status {
@@ -403,6 +412,8 @@ void Config::initFieldCallback() {
           checkpoint_dir = dir + "/checkpoint";
           sync_checkpoint_dir = dir + "/sync_checkpoint";
           backup_sync_dir = dir + "/backup_for_sync";
+          if (backup_dir == kDefaultBackupDir) backup_dir = dir + "/backup";
+          if (pidfile == kDefaultPidfile) pidfile = dir + "/kvrocks.pid";
           return Status::OK();
         }},
        {"backup-dir",
@@ -412,8 +423,8 @@ void Config::initFieldCallback() {
           // Note: currently, backup_mu_ may block by backing up or purging,
           // the command may wait for seconds.
           std::lock_guard lg(this->backup_mu);
-          previous_backup = std::move(backup_dir_);
-          backup_dir_ = v;
+          previous_backup = std::move(backup_dir);
+          backup_dir = v;
         }
         if (!previous_backup.empty() && srv != nullptr && !srv->IsLoading()) {
          // LOG(INFO) should be called after log is initialized and server is loaded.
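The legacy `compaction-checker-range` value is no longer parsed into a start/stop pair; it is rewritten into a five-field cron expression and fed to the same `Cron` scheduler that backs the new `compaction-checker-cron` option. A sketch mirroring that translation (the helper name is hypothetical; the field layout matches `SetScheduleTime({"*", v, "*", "*", "*"})` above):

```c++
#include <iostream>
#include <string>
#include <vector>

// "0-7" (hours) becomes the cron fields: minute hour mday month wday.
std::vector<std::string> RangeToCron(const std::string &hour_range) {
  return {"*", hour_range, "*", "*", "*"};
}

int main() {
  for (const auto &field : RangeToCron("0-7")) std::cout << field << ' ';
  std::cout << "\n";  // prints: * 0-7 * * *
}
```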
@@ -727,7 +738,7 @@ void Config::ClearMaster() { Status Config::parseConfigFromPair(const std::pair &input, int line_number) { std::string field_key = util::ToLower(input.first); - const char ns_str[] = "namespace."; + constexpr const char ns_str[] = "namespace."; size_t ns_str_size = sizeof(ns_str) - 1; if (strncasecmp(input.first.data(), ns_str, ns_str_size) == 0) { // namespace should keep key case-sensitive @@ -778,9 +789,7 @@ Status Config::finish() { if (master_port != 0 && binds.size() == 0) { return {Status::NotOK, "replication doesn't support unix socket"}; } - if (backup_dir_.empty()) backup_dir_ = dir + "/backup"; if (db_dir.empty()) db_dir = dir + "/db"; - if (pidfile_.empty()) pidfile_ = dir + "/kvrocks.pid"; if (log_dir.empty()) log_dir = dir; std::vector create_dirs = {dir}; for (const auto &name : create_dirs) { @@ -876,11 +885,20 @@ Status Config::Set(Server *srv, std::string key, const std::string &value) { if (!s.IsOK()) return s.Prefixed("invalid value"); } + auto origin_value = field->ToString(); auto s = field->Set(value); if (!s.IsOK()) return s.Prefixed("failed to set new value"); if (field->callback) { - return field->callback(srv, key, value); + s = field->callback(srv, key, value); + if (!s.IsOK()) { + // rollback the value if the callback failed + auto set_status = field->Set(origin_value); + if (!set_status.IsOK()) { + return set_status.Prefixed("failed to rollback the value"); + } + } + return s; } return Status::OK(); diff --git a/src/config/config.h b/src/config/config.h index 2477a90bf3a..68a606449f4 100644 --- a/src/config/config.h +++ b/src/config/config.h @@ -57,14 +57,6 @@ constexpr const char *kDefaultNamespace = "__namespace"; enum class BlockCacheType { kCacheTypeLRU = 0, kCacheTypeHCC }; -struct CompactionCheckerRange { - public: - int start; - int stop; - - bool Enabled() const { return start != -1 || stop != -1; } -}; - struct CLIOptions { std::string conf_file; std::vector> cli_options; @@ -122,6 +114,8 @@ struct Config { std::vector binds; std::string dir; std::string db_dir; + std::string backup_dir; // GUARD_BY(backup_mu_) + std::string pidfile; std::string backup_sync_dir; std::string checkpoint_dir; std::string sync_checkpoint_dir; @@ -135,7 +129,8 @@ struct Config { uint32_t master_port = 0; Cron compact_cron; Cron bgsave_cron; - CompactionCheckerRange compaction_checker_range{-1, -1}; + Cron dbsize_scan_cron; + Cron compaction_checker_cron; int64_t force_compact_file_age; int force_compact_file_min_deleted_percentage; bool repl_namespace_enabled = false; @@ -200,6 +195,7 @@ struct Config { int level0_stop_writes_trigger; int level0_file_num_compaction_trigger; rocksdb::CompressionType compression; + int compression_level; bool disable_auto_compactions; bool enable_blob_files; int min_blob_size; @@ -237,18 +233,16 @@ struct Config { void ClearMaster(); bool IsSlave() const { return !master_host.empty(); } bool HasConfigFile() const { return !path_.empty(); } - std::string GetBackupDir() const { return backup_dir_.empty() ? dir + "/backup" : backup_dir_; } - std::string GetPidFile() const { return pidfile_.empty() ? 
dir + "/kvrocks.pid" : pidfile_; }
 
  private:
   std::string path_;
-  std::string backup_dir_;  // GUARD_BY(backup_mu_)
-  std::string pidfile_;
   std::string binds_str_;
   std::string slaveof_;
   std::string compact_cron_str_;
   std::string bgsave_cron_str_;
+  std::string dbsize_scan_cron_str_;
   std::string compaction_checker_range_str_;
+  std::string compaction_checker_cron_str_;
   std::string profiling_sample_commands_str_;
   std::map<std::string, std::unique_ptr<ConfigField>> fields_;
   std::vector<std::string> rename_command_;
 };
diff --git a/src/search/README.md b/src/search/README.md
new file mode 100644
index 00000000000..65c2baccb31
--- /dev/null
+++ b/src/search/README.md
@@ -0,0 +1,27 @@
+## KQIR: Kvrocks Query Intermediate Representation
+
+Here, *KQIR* refers to both
+- the multi-level *query intermediate representation* for Apache Kvrocks, and
+- the *architecture and toolset* for query optimization and execution.
+
+### Architecture
+
+![Architecture of KQIR](../../assets/KQIR.png)
+
+### Components
+
+- User Interface: both SQL and Redis query syntax are supported as frontend languages for KQIR
+  - SQL Parser: a parser that accepts an extended subset of MySQL syntax
+  - Redis Query Parser: a parser that accepts the [Redis query syntax](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/query_syntax/) (only DIALECT 2 or greater is planned to be supported)
+- KQIR: a multi-level query intermediate representation, currently comprising two levels (syntactical IR and planning IR)
+  - Syntactical IR: a high-level IR that syntactically represents the query language
+  - Planning IR: a low-level IR that represents plan operators for query execution
+- KQIR passes: analysis and transformation procedures over KQIR
+  - Semantic Checker: checks for semantic errors in the IR
+  - Expression Passes: passes for query expressions, especially logical expressions
+  - Numeric Passes: passes for numeric and arithmetic properties
+  - Plan Passes: passes over the plan operators
+  - Pass Manager: manages the pass execution sequence and order
+  - Cost Model: analyzes the cost of the current plan, used by some plan passes
+- Plan Executor: a component that executes queries via the iterator model
+- Indexer: performs indexing for various types of fields as data changes
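A minimal sketch of the iterator (Volcano) model the Plan Executor uses, with toy `Scan` and `Filter` operators standing in for the real executors added later in this diff (those return `StatusOr` of a row-or-end variant rather than `std::optional`, and pull rows by key from storage):

```c++
#include <iostream>
#include <memory>
#include <optional>
#include <vector>

// Each operator pulls rows from its child via Next() until exhaustion.
struct Operator {
  virtual std::optional<int> Next() = 0;
  virtual ~Operator() = default;
};

struct Scan : Operator {  // stands in for a full-index or field scan
  std::vector<int> rows{1, 2, 3, 4};
  size_t i = 0;
  std::optional<int> Next() override {
    if (i < rows.size()) return rows[i++];
    return std::nullopt;
  }
};

struct Filter : Operator {  // keeps rows satisfying a predicate
  std::unique_ptr<Operator> child;
  explicit Filter(std::unique_ptr<Operator> c) : child(std::move(c)) {}
  std::optional<int> Next() override {
    while (auto row = child->Next()) {
      if (*row % 2 == 0) return row;  // toy predicate
    }
    return std::nullopt;
  }
};

int main() {
  Filter plan(std::make_unique<Scan>());
  while (auto row = plan.Next()) std::cout << *row << "\n";  // 2, 4
}
```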
diff --git a/src/search/common_parser.h b/src/search/common_parser.h
new file mode 100644
index 00000000000..7d617dbdc3d
--- /dev/null
+++ b/src/search/common_parser.h
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <tao/pegtl.hpp>
+
+namespace kqir {
+
+namespace peg = tao::pegtl;
+
+struct True : peg::string<'t', 'r', 'u', 'e'> {};
+struct False : peg::string<'f', 'a', 'l', 's', 'e'> {};
+struct Boolean : peg::sor<True, False> {};
+
+struct Digits : peg::plus<peg::digit> {};
+struct NumberExp : peg::seq<peg::one<'e', 'E'>, peg::opt<peg::one<'+', '-'>>, Digits> {};
+struct NumberFrac : peg::seq<peg::one<'.'>, Digits> {};
+struct Number : peg::seq<peg::opt<peg::one<'-'>>, Digits, peg::opt<NumberFrac>, peg::opt<NumberExp>> {};
+
+struct UnicodeXDigit : peg::list<peg::seq<peg::one<'u'>, peg::rep<4, peg::xdigit>>, peg::one<'\\'>> {};
+struct EscapedSingleChar : peg::one<'"', '\\', 'b', 'f', 'n', 'r', 't'> {};
+struct EscapedChar : peg::sor<EscapedSingleChar, UnicodeXDigit> {};
+struct UnescapedChar : peg::utf8::range<0x20, 0x10FFFF> {};
+struct Char : peg::if_then_else<peg::one<'\\'>, EscapedChar, UnescapedChar> {};
+
+struct StringContent : peg::until<peg::at<peg::one<'"'>>, Char> {};
+struct StringL : peg::seq<peg::one<'"'>, StringContent, peg::any> {};
+
+struct Identifier : peg::identifier {};
+
+struct WhiteSpace : peg::one<' ', '\t', '\n', '\r'> {};
+template <typename T>
+struct WSPad : peg::pad<T, WhiteSpace> {};
+
+struct UnsignedInteger : Digits {};
+struct Integer : peg::seq<peg::opt<peg::one<'-'>>, Digits> {};
+
+}  // namespace kqir
diff --git a/src/search/common_transformer.h b/src/search/common_transformer.h
new file mode 100644
index 00000000000..8febbb4ce73
--- /dev/null
+++ b/src/search/common_transformer.h
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ * + */ + +#pragma once + +#include +#include +#include + +#include "common_parser.h" +#include "status.h" + +namespace kqir { + +struct TreeTransformer { + using TreeNode = std::unique_ptr; + + template + static bool Is(const TreeNode& node) { + return node->type == peg::demangle(); + } + + static bool IsRoot(const TreeNode& node) { return node->type.empty(); } + + static StatusOr UnescapeString(std::string_view str) { + str = str.substr(1, str.size() - 2); + + std::string result; + while (!str.empty()) { + if (str[0] == '\\') { + str.remove_prefix(1); + switch (str[0]) { + case '\\': + case '"': + result.push_back(str[0]); + break; + case 'b': + result.push_back('\b'); + break; + case 'f': + result.push_back('\f'); + break; + case 'n': + result.push_back('\n'); + break; + case 'r': + result.push_back('\r'); + break; + case 't': + result.push_back('\t'); + break; + case 'u': + if (!peg::unescape::utf8_append_utf32( + result, peg::unescape::unhex_string(str.data() + 1, str.data() + 5))) { + return {Status::NotOK, + fmt::format("invalid Unicode code point '{}' in string literal", std::string(str.data() + 1, 4))}; + } + str.remove_prefix(4); + break; + default: + __builtin_unreachable(); + }; + str.remove_prefix(1); + } else { + result.push_back(str[0]); + str.remove_prefix(1); + } + } + + return result; + } +}; + +} // namespace kqir diff --git a/src/search/executors/filter_executor.h b/src/search/executors/filter_executor.h new file mode 100644 index 00000000000..df14b29b80a --- /dev/null +++ b/src/search/executors/filter_executor.h @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#pragma once + +#include + +#include "parse_util.h" +#include "search/ir.h" +#include "search/plan_executor.h" +#include "search/search_encoding.h" +#include "string_util.h" + +namespace kqir { + +struct QueryExprEvaluator { + ExecutorContext *ctx; + ExecutorNode::RowType &row; + + StatusOr Transform(QueryExpr *e) const { + if (auto v = dynamic_cast(e)) { + return Visit(v); + } + if (auto v = dynamic_cast(e)) { + return Visit(v); + } + if (auto v = dynamic_cast(e)) { + return Visit(v); + } + if (auto v = dynamic_cast(e)) { + return Visit(v); + } + if (auto v = dynamic_cast(e)) { + return Visit(v); + } + + CHECK(false) << "unreachable"; + } + + StatusOr Visit(AndExpr *v) const { + for (const auto &n : v->inners) { + if (!GET_OR_RET(Transform(n.get()))) return false; + } + + return true; + } + + StatusOr Visit(OrExpr *v) const { + for (const auto &n : v->inners) { + if (GET_OR_RET(Transform(n.get()))) return true; + } + + return false; + } + + StatusOr Visit(NotExpr *v) const { return !GET_OR_RET(Transform(v->inner.get())); } + + StatusOr Visit(TagContainExpr *v) const { + auto val = GET_OR_RET(ctx->Retrieve(row, v->field->info)); + + CHECK(val.Is()); + auto tags = val.Get(); + + auto meta = v->field->info->MetadataAs(); + if (meta->case_sensitive) { + return std::find(tags.begin(), tags.end(), v->tag->val) != tags.end(); + } else { + return std::find_if(tags.begin(), tags.end(), + [v](const auto &tag) { return util::EqualICase(tag, v->tag->val); }) != tags.end(); + } + } + + StatusOr Visit(NumericCompareExpr *v) const { + auto l_val = GET_OR_RET(ctx->Retrieve(row, v->field->info)); + + CHECK(l_val.Is()); + auto l = l_val.Get(); + auto r = v->num->val; + + switch (v->op) { + case NumericCompareExpr::EQ: + return l == r; + case NumericCompareExpr::NE: + return l != r; + case NumericCompareExpr::LT: + return l < r; + case NumericCompareExpr::LET: + return l <= r; + case NumericCompareExpr::GT: + return l > r; + case NumericCompareExpr::GET: + return l >= r; + default: + CHECK(false) << "unreachable"; + __builtin_unreachable(); + } + } +}; + +struct FilterExecutor : ExecutorNode { + Filter *filter; + + FilterExecutor(ExecutorContext *ctx, Filter *filter) : ExecutorNode(ctx), filter(filter) {} + + StatusOr Next() override { + while (true) { + auto v = GET_OR_RET(ctx->Get(filter->source)->Next()); + + if (std::holds_alternative(v)) return end; + + QueryExprEvaluator eval{ctx, std::get(v)}; + + bool res = GET_OR_RET(eval.Transform(filter->filter_expr.get())); + + if (res) { + return v; + } + } + } +}; + +} // namespace kqir diff --git a/src/search/executors/full_index_scan_executor.h b/src/search/executors/full_index_scan_executor.h new file mode 100644 index 00000000000..3fde9ef5622 --- /dev/null +++ b/src/search/executors/full_index_scan_executor.h @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include "db_util.h" +#include "search/plan_executor.h" +#include "storage/redis_db.h" +#include "storage/redis_metadata.h" +#include "storage/storage.h" + +namespace kqir { + +struct FullIndexScanExecutor : ExecutorNode { + FullIndexScan *scan; + redis::LatestSnapShot ss; + util::UniqueIterator iter{nullptr}; + const std::string *prefix_iter; + + FullIndexScanExecutor(ExecutorContext *ctx, FullIndexScan *scan) + : ExecutorNode(ctx), scan(scan), ss(ctx->storage), prefix_iter(scan->index->info->prefixes.begin()) {} + + std::string NSKey(const std::string &user_key) { + return ComposeNamespaceKey(scan->index->info->ns, user_key, ctx->storage->IsSlotIdEncoded()); + } + + StatusOr Next() override { + if (prefix_iter == scan->index->info->prefixes.end()) { + return end; + } + + auto ns_key = NSKey(*prefix_iter); + if (!iter) { + rocksdb::ReadOptions read_options = ctx->storage->DefaultScanOptions(); + read_options.snapshot = ss.GetSnapShot(); + iter = util::UniqueIterator(ctx->storage, read_options, ctx->storage->GetCFHandle(ColumnFamilyID::Metadata)); + iter->Seek(ns_key); + } + + while (!iter->Valid() || !iter->key().starts_with(ns_key)) { + prefix_iter++; + if (prefix_iter == scan->index->info->prefixes.end()) { + return end; + } + + ns_key = NSKey(*prefix_iter); + iter->Seek(ns_key); + } + + auto [_, key] = ExtractNamespaceKey(iter->key(), ctx->storage->IsSlotIdEncoded()); + auto key_str = key.ToString(); + + iter->Next(); + return RowType{key_str, {}, scan->index->info}; + } +}; + +} // namespace kqir diff --git a/src/search/executors/limit_executor.h b/src/search/executors/limit_executor.h new file mode 100644 index 00000000000..8b1d4916c9f --- /dev/null +++ b/src/search/executors/limit_executor.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#pragma once + +#include "search/plan_executor.h" + +namespace kqir { + +struct LimitExecutor : ExecutorNode { + Limit *limit; + size_t step = 0; + + LimitExecutor(ExecutorContext *ctx, Limit *limit) : ExecutorNode(ctx), limit(limit) {} + + StatusOr Next() override { + auto offset = limit->limit->offset; + auto count = limit->limit->count; + + if (step == count) { + return end; + } + + if (step == 0) { + while (offset--) { + auto res = GET_OR_RET(ctx->Get(limit->op)->Next()); + + if (std::holds_alternative(res)) { + return end; + } + } + } + + auto res = GET_OR_RET(ctx->Get(limit->op)->Next()); + step++; + return res; + } +}; + +} // namespace kqir diff --git a/src/search/executors/merge_executor.h b/src/search/executors/merge_executor.h new file mode 100644 index 00000000000..66b7bb85650 --- /dev/null +++ b/src/search/executors/merge_executor.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "search/plan_executor.h" + +namespace kqir { + +struct MergeExecutor : ExecutorNode { + Merge *merge; + decltype(merge->ops)::iterator iter; + + MergeExecutor(ExecutorContext *ctx, Merge *merge) : ExecutorNode(ctx), merge(merge), iter(merge->ops.begin()) {} + + StatusOr Next() override { + if (iter == merge->ops.end()) { + return end; + } + + auto v = GET_OR_RET(ctx->Get(*iter)->Next()); + while (std::holds_alternative(v)) { + iter++; + if (iter == merge->ops.end()) { + return end; + } + + v = GET_OR_RET(ctx->Get(*iter)->Next()); + } + + return v; + } +}; + +} // namespace kqir diff --git a/src/search/executors/mock_executor.h b/src/search/executors/mock_executor.h new file mode 100644 index 00000000000..f9cdf57d131 --- /dev/null +++ b/src/search/executors/mock_executor.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#pragma once + +#include + +#include "search/ir_plan.h" +#include "search/plan_executor.h" + +namespace kqir { + +// this operator is only for executor-testing/debugging purpose +struct Mock : PlanOperator { + std::vector rows; + + explicit Mock(std::vector rows) : rows(std::move(rows)) {} + + std::string Dump() const override { return "mock"; } + std::string_view Name() const override { return "Mock"; } + + std::unique_ptr Clone() const override { return std::make_unique(rows); } +}; + +struct MockExecutor : ExecutorNode { + Mock *mock; + decltype(mock->rows)::iterator iter; + + MockExecutor(ExecutorContext *ctx, Mock *mock) : ExecutorNode(ctx), mock(mock), iter(mock->rows.begin()) {} + + StatusOr Next() override { + if (iter == mock->rows.end()) { + return end; + } + + return *(iter++); + } +}; + +} // namespace kqir diff --git a/src/search/executors/noop_executor.h b/src/search/executors/noop_executor.h new file mode 100644 index 00000000000..1e3685cac50 --- /dev/null +++ b/src/search/executors/noop_executor.h @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include "search/plan_executor.h" + +namespace kqir { + +struct NoopExecutor : ExecutorNode { + Noop *noop; + + NoopExecutor(ExecutorContext *ctx, Noop *noop) : ExecutorNode(ctx), noop(noop) {} + + StatusOr Next() override { return end; } +}; + +} // namespace kqir diff --git a/src/search/executors/numeric_field_scan_executor.h b/src/search/executors/numeric_field_scan_executor.h new file mode 100644 index 00000000000..5c997df0fc1 --- /dev/null +++ b/src/search/executors/numeric_field_scan_executor.h @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#pragma once + +#include + +#include "db_util.h" +#include "encoding.h" +#include "search/plan_executor.h" +#include "search/search_encoding.h" +#include "search/value.h" +#include "storage/redis_db.h" +#include "storage/redis_metadata.h" +#include "storage/storage.h" + +namespace kqir { + +struct NumericFieldScanExecutor : ExecutorNode { + NumericFieldScan *scan; + redis::LatestSnapShot ss; + util::UniqueIterator iter{nullptr}; + + IndexInfo *index; + redis::SearchKey search_key; + + NumericFieldScanExecutor(ExecutorContext *ctx, NumericFieldScan *scan) + : ExecutorNode(ctx), + scan(scan), + ss(ctx->storage), + index(scan->field->info->index), + search_key(index->ns, index->name, scan->field->name) {} + + std::string IndexKey(double num) const { return search_key.ConstructNumericFieldData(num, {}); } + + bool InRangeDecode(Slice key, double *curr, Slice *user_key) const { + uint8_t ns_size = 0; + if (!GetFixed8(&key, &ns_size)) return false; + if (ns_size != index->ns.size()) return false; + if (!key.starts_with(index->ns)) return false; + key.remove_prefix(ns_size); + + uint8_t subkey_type = 0; + if (!GetFixed8(&key, &subkey_type)) return false; + if (subkey_type != (uint8_t)redis::SearchSubkeyType::FIELD) return false; + + Slice value; + if (!GetSizedString(&key, &value)) return false; + if (value != index->name) return false; + + if (!GetSizedString(&key, &value)) return false; + if (value != scan->field->name) return false; + + if (!GetDouble(&key, curr)) return false; + + if (!GetSizedString(&key, user_key)) return false; + + return true; + } + + StatusOr Next() override { + if (!iter) { + rocksdb::ReadOptions read_options = ctx->storage->DefaultScanOptions(); + read_options.snapshot = ss.GetSnapShot(); + + iter = util::UniqueIterator(ctx->storage, read_options, ctx->storage->GetCFHandle(ColumnFamilyID::Search)); + if (scan->order == SortByClause::ASC) { + iter->Seek(IndexKey(scan->range.l)); + } else { + iter->SeekForPrev(IndexKey(IntervalSet::PrevNum(scan->range.r))); + } + } + + if (!iter->Valid()) { + return end; + } + + double curr = 0; + Slice user_key; + if (!InRangeDecode(iter->key(), &curr, &user_key)) { + return end; + } + + if (scan->order == SortByClause::ASC ? curr >= scan->range.r : curr < scan->range.l) { + return end; + } + + auto key_str = user_key.ToString(); + + if (scan->order == SortByClause::ASC) { + iter->Next(); + } else { + iter->Prev(); + } + return RowType{key_str, {{scan->field->info, kqir::MakeValue(curr)}}, scan->field->info->index}; + } +}; + +} // namespace kqir diff --git a/src/search/executors/projection_executor.h b/src/search/executors/projection_executor.h new file mode 100644 index 00000000000..fe167334500 --- /dev/null +++ b/src/search/executors/projection_executor.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "search/plan_executor.h" + +namespace kqir { + +struct ProjectionExecutor : ExecutorNode { + Projection *proj; + + ProjectionExecutor(ExecutorContext *ctx, Projection *proj) : ExecutorNode(ctx), proj(proj) {} + + StatusOr Next() override { + auto v = GET_OR_RET(ctx->Get(proj->source)->Next()); + + if (std::holds_alternative(v)) return end; + + auto &row = std::get(v); + if (proj->select->fields.empty()) { + for (const auto &field : row.index->fields) { + GET_OR_RET(ctx->Retrieve(row, &field.second)); + } + } else { + std::map res; + + for (const auto &field : proj->select->fields) { + auto r = GET_OR_RET(ctx->Retrieve(row, field->info)); + res.emplace(field->info, std::move(r)); + } + + return RowType{row.key, res, row.index}; + } + + return v; + } +}; + +} // namespace kqir diff --git a/src/search/executors/sort_executor.h b/src/search/executors/sort_executor.h new file mode 100644 index 00000000000..ed4b205db57 --- /dev/null +++ b/src/search/executors/sort_executor.h @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include "search/plan_executor.h" + +namespace kqir { + +struct SortExecutor : ExecutorNode { + Sort *sort; + + SortExecutor(ExecutorContext *ctx, Sort *sort) : ExecutorNode(ctx), sort(sort) {} + + StatusOr Next() override { + // most of the sort operator will be eliminated via the optimizer passes, + // so currently we don't support this operator since external sort is a little complicated + return {Status::NotSupported, "sort operator is currently not supported"}; + } +}; + +} // namespace kqir diff --git a/src/search/executors/tag_field_scan_executor.h b/src/search/executors/tag_field_scan_executor.h new file mode 100644 index 00000000000..946db767432 --- /dev/null +++ b/src/search/executors/tag_field_scan_executor.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "db_util.h" +#include "encoding.h" +#include "search/plan_executor.h" +#include "search/search_encoding.h" +#include "storage/redis_db.h" +#include "storage/redis_metadata.h" +#include "storage/storage.h" +#include "string_util.h" + +namespace kqir { + +struct TagFieldScanExecutor : ExecutorNode { + TagFieldScan *scan; + redis::LatestSnapShot ss; + util::UniqueIterator iter{nullptr}; + + IndexInfo *index; + std::string index_key; + bool case_sensitive; + + TagFieldScanExecutor(ExecutorContext *ctx, TagFieldScan *scan) + : ExecutorNode(ctx), + scan(scan), + ss(ctx->storage), + index(scan->field->info->index), + index_key(redis::SearchKey(index->ns, index->name, scan->field->name).ConstructTagFieldData(scan->tag, {})), + case_sensitive(scan->field->info->MetadataAs()->case_sensitive) {} + + bool InRangeDecode(Slice key, Slice *user_key) const { + uint8_t ns_size = 0; + if (!GetFixed8(&key, &ns_size)) return false; + if (ns_size != index->ns.size()) return false; + if (!key.starts_with(index->ns)) return false; + key.remove_prefix(ns_size); + + uint8_t subkey_type = 0; + if (!GetFixed8(&key, &subkey_type)) return false; + if (subkey_type != (uint8_t)redis::SearchSubkeyType::FIELD) return false; + + Slice value; + if (!GetSizedString(&key, &value)) return false; + if (value != index->name) return false; + + if (!GetSizedString(&key, &value)) return false; + if (value != scan->field->name) return false; + + if (!GetSizedString(&key, &value)) return false; + if (case_sensitive ? value != scan->tag : !util::EqualICase(value.ToStringView(), scan->tag)) return false; + + if (!GetSizedString(&key, user_key)) return false; + + return true; + } + + StatusOr Next() override { + if (!iter) { + rocksdb::ReadOptions read_options = ctx->storage->DefaultScanOptions(); + read_options.snapshot = ss.GetSnapShot(); + + iter = util::UniqueIterator(ctx->storage, read_options, ctx->storage->GetCFHandle(ColumnFamilyID::Search)); + iter->Seek(index_key); + } + + if (!iter->Valid()) { + return end; + } + + Slice user_key; + if (!InRangeDecode(iter->key(), &user_key)) { + return end; + } + + auto key_str = user_key.ToString(); + + iter->Next(); + return RowType{key_str, {}, scan->field->info->index}; + } +}; + +} // namespace kqir diff --git a/src/search/executors/topn_sort_executor.h b/src/search/executors/topn_sort_executor.h new file mode 100644 index 00000000000..163b1bc7f3c --- /dev/null +++ b/src/search/executors/topn_sort_executor.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#pragma once + +#include +#include + +#include "parse_util.h" +#include "search/plan_executor.h" + +namespace kqir { + +struct TopNSortExecutor : ExecutorNode { + TopNSort *sort; + + struct ComparedRow { + RowType row; + double val; + + ComparedRow(RowType row, double val) : row(std::move(row)), val(val) {} + + friend bool operator<(const ComparedRow &l, const ComparedRow &r) { return l.val < r.val; } + }; + + std::vector rows; + decltype(rows)::iterator rows_iter; + bool initialized = false; + + TopNSortExecutor(ExecutorContext *ctx, TopNSort *sort) : ExecutorNode(ctx), sort(sort) {} + + StatusOr Next() override { + if (!initialized) { + auto total = sort->limit->offset + sort->limit->count; + if (total == 0) return end; + + auto v = GET_OR_RET(ctx->Get(sort->op)->Next()); + + while (!std::holds_alternative(v)) { + auto &row = std::get(v); + + auto get_order = [this](RowType &row) -> StatusOr { + auto order_val = GET_OR_RET(ctx->Retrieve(row, sort->order->field->info)); + CHECK(order_val.Is()); + return order_val.Get(); + }; + + if (rows.size() == total) { + std::make_heap(rows.begin(), rows.end()); + } + + if (rows.size() < total) { + auto order = GET_OR_RET(get_order(row)); + rows.emplace_back(row, order); + } else { + auto order = GET_OR_RET(get_order(row)); + + if (order < rows[0].val) { + std::pop_heap(rows.begin(), rows.end()); + rows.back() = ComparedRow{row, order}; + std::push_heap(rows.begin(), rows.end()); + } + } + + v = GET_OR_RET(ctx->Get(sort->op)->Next()); + } + + if (rows.size() <= sort->limit->offset) { + return end; + } + + std::sort(rows.begin(), rows.end()); + rows_iter = rows.begin() + static_cast(sort->limit->offset); + initialized = true; + } + + if (rows_iter == rows.end()) { + return end; + } + + auto res = rows_iter->row; + rows_iter++; + return res; + } +}; + +} // namespace kqir diff --git a/src/search/index_info.h b/src/search/index_info.h new file mode 100644 index 00000000000..3badc372adc --- /dev/null +++ b/src/search/index_info.h @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "search_encoding.h" +#include "storage/redis_metadata.h" + +namespace kqir { + +struct IndexInfo; + +struct FieldInfo { + std::string name; + IndexInfo *index = nullptr; + std::unique_ptr metadata; + + FieldInfo(std::string name, std::unique_ptr &&metadata) + : name(std::move(name)), metadata(std::move(metadata)) {} + + bool IsSortable() const { return metadata->IsSortable(); } + bool HasIndex() const { return !metadata->noindex; } + + template + const T *MetadataAs() const { + return dynamic_cast(metadata.get()); + } +}; + +struct IndexInfo { + using FieldMap = std::map; + + std::string name; + redis::IndexMetadata metadata; + FieldMap fields; + redis::IndexPrefixes prefixes; + std::string ns; + + IndexInfo(std::string name, redis::IndexMetadata metadata, std::string ns) + : name(std::move(name)), metadata(std::move(metadata)), ns(std::move(ns)) {} + + void Add(FieldInfo &&field) { + const auto &name = field.name; + field.index = this; + fields.emplace(name, std::move(field)); + } +}; + +struct IndexMap : std::map> { + auto Insert(std::unique_ptr index_info) { + auto key = ComposeNamespaceKey(index_info->ns, index_info->name, false); + return emplace(key, std::move(index_info)); + } + + auto Find(std::string_view index, std::string_view ns) const { return find(ComposeNamespaceKey(ns, index, false)); } +}; + +} // namespace kqir diff --git a/src/search/index_manager.h b/src/search/index_manager.h new file mode 100644 index 00000000000..1d7447047ee --- /dev/null +++ b/src/search/index_manager.h @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
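To make the relationship between FieldInfo, IndexInfo and IndexMap concrete, here is a hypothetical registration snippet; the field metadata types come from search_encoding.h, and the IndexPrefixes member access is an assumption:

```cpp
// Hypothetical wiring of the structures above; not part of the patch.
auto info = std::make_unique<kqir::IndexInfo>("idx-users", redis::IndexMetadata{}, "ns");
info->prefixes.prefixes.emplace_back("user:");  // assumed field of redis::IndexPrefixes

info->Add(kqir::FieldInfo("age", std::make_unique<redis::NumericFieldMetadata>()));
info->Add(kqir::FieldInfo("city", std::make_unique<redis::TagFieldMetadata>()));

kqir::IndexMap indexes;
indexes.Insert(std::move(info));

// lookups hide the ComposeNamespaceKey(ns, name, false) composition:
auto it = indexes.Find("idx-users", "ns");
```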
+ * + */ + +#pragma once + +#include "db_util.h" +#include "encoding.h" +#include "search/index_info.h" +#include "search/indexer.h" +#include "search/ir.h" +#include "search/ir_sema_checker.h" +#include "search/passes/manager.h" +#include "search/plan_executor.h" +#include "search/search_encoding.h" +#include "status.h" +#include "storage/storage.h" +#include "string_util.h" + +namespace redis { + +struct IndexManager { + kqir::IndexMap index_map; + GlobalIndexer *indexer; + engine::Storage *storage; + + IndexManager(GlobalIndexer *indexer, engine::Storage *storage) : indexer(indexer), storage(storage) {} + + Status Load(const std::string &ns) { + // currently index cannot work in cluster mode + if (storage->GetConfig()->cluster_enabled) { + return Status::OK(); + } + + util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search); + auto begin = SearchKey{ns, ""}.ConstructIndexMeta(); + + for (iter->Seek(begin); iter->Valid(); iter->Next()) { + auto key = iter->key(); + + uint8_t ns_size = 0; + if (!GetFixed8(&key, &ns_size)) break; + if (ns_size != ns.size()) break; + if (!key.starts_with(ns)) break; + key.remove_prefix(ns_size); + + uint8_t subkey_type = 0; + if (!GetFixed8(&key, &subkey_type)) break; + if (subkey_type != (uint8_t)SearchSubkeyType::INDEX_META) break; + + Slice index_name; + if (!GetSizedString(&key, &index_name)) break; + + IndexMetadata metadata; + auto index_meta_value = iter->value(); + if (auto s = metadata.Decode(&index_meta_value); !s.ok()) { + return {Status::NotOK, fmt::format("fail to decode index metadata for index {}: {}", index_name, s.ToString())}; + } + + auto index_key = SearchKey(ns, index_name.ToStringView()); + std::string prefix_value; + if (auto s = storage->Get(storage->DefaultMultiGetOptions(), storage->GetCFHandle(ColumnFamilyID::Search), + index_key.ConstructIndexPrefixes(), &prefix_value); + !s.ok()) { + return {Status::NotOK, fmt::format("fail to find index prefixes for index {}: {}", index_name, s.ToString())}; + } + + IndexPrefixes prefixes; + Slice prefix_slice = prefix_value; + if (auto s = prefixes.Decode(&prefix_slice); !s.ok()) { + return {Status::NotOK, fmt::format("fail to decode index prefixes for index {}: {}", index_name, s.ToString())}; + } + + auto info = std::make_unique(index_name.ToString(), metadata, ns); + info->prefixes = prefixes; + + util::UniqueIterator field_iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search); + auto field_begin = index_key.ConstructFieldMeta(); + + for (field_iter->Seek(field_begin); field_iter->Valid(); field_iter->Next()) { + auto key = field_iter->key(); + + uint8_t ns_size = 0; + if (!GetFixed8(&key, &ns_size)) break; + if (ns_size != ns.size()) break; + if (!key.starts_with(ns)) break; + key.remove_prefix(ns_size); + + uint8_t subkey_type = 0; + if (!GetFixed8(&key, &subkey_type)) break; + if (subkey_type != (uint8_t)SearchSubkeyType::FIELD_META) break; + + Slice value; + if (!GetSizedString(&key, &value)) break; + if (value != index_name) break; + + if (!GetSizedString(&key, &value)) break; + + auto field_name = value; + auto field_value = field_iter->value(); + + std::unique_ptr field_meta; + if (auto s = IndexFieldMetadata::Decode(&field_value, field_meta); !s.ok()) { + return {Status::NotOK, fmt::format("fail to decode index field metadata for index {}, field {}: {}", + index_name, field_name, s.ToString())}; + } + + info->Add(kqir::FieldInfo(field_name.ToString(), std::move(field_meta))); + } + + IndexUpdater updater(info.get()); + 
indexer->Add(updater); + index_map.Insert(std::move(info)); + } + + return Status::OK(); + } + + Status Create(std::unique_ptr info) { + if (storage->GetConfig()->cluster_enabled) { + return {Status::NotOK, "currently index cannot work in cluster mode"}; + } + + if (auto iter = index_map.Find(info->name, info->ns); iter != index_map.end()) { + return {Status::NotOK, "index already exists"}; + } + + SearchKey index_key(info->ns, info->name); + auto cf = storage->GetCFHandle(ColumnFamilyID::Search); + + auto batch = storage->GetWriteBatchBase(); + + std::string meta_val; + info->metadata.Encode(&meta_val); + batch->Put(cf, index_key.ConstructIndexMeta(), meta_val); + + std::string prefix_val; + info->prefixes.Encode(&prefix_val); + batch->Put(cf, index_key.ConstructIndexPrefixes(), prefix_val); + + for (const auto &[_, field_info] : info->fields) { + SearchKey field_key(info->ns, info->name, field_info.name); + + std::string field_val; + field_info.metadata->Encode(&field_val); + + batch->Put(cf, field_key.ConstructFieldMeta(), field_val); + } + + if (auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch()); !s.ok()) { + return {Status::NotOK, fmt::format("failed to write index metadata: {}", s.ToString())}; + } + + IndexUpdater updater(info.get()); + indexer->Add(updater); + index_map.Insert(std::move(info)); + + for (auto updater : indexer->updater_list) { + GET_OR_RET(updater.Build()); + } + + return Status::OK(); + } + + StatusOr> GeneratePlan(std::unique_ptr ir, + const std::string &ns) const { + kqir::SemaChecker sema_checker(index_map); + sema_checker.ns = ns; + + GET_OR_RET(sema_checker.Check(ir.get())); + + auto plan_ir = kqir::PassManager::Execute(kqir::PassManager::Default(), std::move(ir)); + std::unique_ptr plan_op; + if (plan_op = kqir::Node::As(std::move(plan_ir)); !plan_op) { + return {Status::NotOK, "failed to convert the query to plan operators"}; + } + + return plan_op; + } + + StatusOr> Search(std::unique_ptr ir, + const std::string &ns) const { + auto plan_op = GET_OR_RET(GeneratePlan(std::move(ir), ns)); + + kqir::ExecutorContext executor_ctx(plan_op.get(), storage); + + std::vector results; + + auto iter_res = GET_OR_RET(executor_ctx.Next()); + while (!std::holds_alternative(iter_res)) { + results.push_back(std::get(iter_res)); + + iter_res = GET_OR_RET(executor_ctx.Next()); + } + + return results; + } + + Status Drop(std::string_view index_name, const std::string &ns) { + auto iter = index_map.Find(index_name, ns); + if (iter == index_map.end()) { + return {Status::NotOK, "index not found"}; + } + + auto info = iter->second.get(); + indexer->Remove(info); + + SearchKey index_key(info->ns, info->name); + auto cf = storage->GetCFHandle(ColumnFamilyID::Search); + + auto batch = storage->GetWriteBatchBase(); + + batch->Delete(cf, index_key.ConstructIndexMeta()); + batch->Delete(cf, index_key.ConstructIndexPrefixes()); + + auto begin = index_key.ConstructAllFieldMetaBegin(); + auto end = index_key.ConstructAllFieldMetaEnd(); + batch->DeleteRange(cf, begin, end); + + begin = index_key.ConstructAllFieldDataBegin(); + end = index_key.ConstructAllFieldDataEnd(); + batch->DeleteRange(cf, begin, end); + + if (auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch()); !s.ok()) { + return {Status::NotOK, fmt::format("failed to delete index metadata and data: {}", s.ToString())}; + } + + index_map.erase(iter); + + return Status::OK(); + } +}; + +} // namespace redis diff --git a/src/search/indexer.cc b/src/search/indexer.cc index 
c082133ec0b..7ce0b3d013b 100644 --- a/src/search/indexer.cc +++ b/src/search/indexer.cc @@ -23,8 +23,10 @@ #include #include +#include "db_util.h" #include "parse_util.h" #include "search/search_encoding.h" +#include "search/value.h" #include "storage/redis_metadata.h" #include "storage/storage.h" #include "string_util.h" @@ -32,16 +34,16 @@ namespace redis { -StatusOr FieldValueRetriever::Create(SearchOnDataType type, std::string_view key, +StatusOr FieldValueRetriever::Create(IndexOnDataType type, std::string_view key, engine::Storage *storage, const std::string &ns) { - if (type == SearchOnDataType::HASH) { + if (type == IndexOnDataType::HASH) { Hash db(storage, ns); std::string ns_key = db.AppendNamespacePrefix(key); HashMetadata metadata(false); - auto s = db.GetMetadata(ns_key, &metadata); + auto s = db.GetMetadata(Database::GetOptions{}, ns_key, &metadata); if (!s.ok()) return {Status::NotOK, s.ToString()}; return FieldValueRetriever(db, metadata, key); - } else if (type == SearchOnDataType::JSON) { + } else if (type == IndexOnDataType::JSON) { Json db(storage, ns); std::string ns_key = db.AppendNamespacePrefix(key); JsonMetadata metadata(false); @@ -50,34 +52,110 @@ StatusOr FieldValueRetriever::Create(SearchOnDataType type, if (!s.ok()) return {Status::NotOK, s.ToString()}; return FieldValueRetriever(value); } else { - assert(false && "unreachable code: unexpected SearchOnDataType"); + assert(false && "unreachable code: unexpected IndexOnDataType"); __builtin_unreachable(); } } -rocksdb::Status FieldValueRetriever::Retrieve(std::string_view field, std::string *output) { +// placeholders, remove them after vector indexing is implemented +static bool IsVectorType(const redis::IndexFieldMetadata *) { return false; } +static size_t GetVectorDim(const redis::IndexFieldMetadata *) { return 1; } + +StatusOr FieldValueRetriever::ParseFromJson(const jsoncons::json &val, + const redis::IndexFieldMetadata *type) { + if (auto numeric [[maybe_unused]] = dynamic_cast(type)) { + if (!val.is_number() || val.is_string()) return {Status::NotOK, "json value cannot be string for numeric fields"}; + return kqir::MakeValue(val.as_double()); + } else if (auto tag = dynamic_cast(type)) { + if (val.is_string()) { + const char delim[] = {tag->separator, '\0'}; + auto vec = util::Split(val.as_string(), delim); + return kqir::MakeValue(vec); + } else if (val.is_array()) { + std::vector strs; + for (size_t i = 0; i < val.size(); ++i) { + if (!val[i].is_string()) + return {Status::NotOK, "json value should be string or array of strings for tag fields"}; + strs.push_back(val[i].as_string()); + } + return kqir::MakeValue(strs); + } else { + return {Status::NotOK, "json value should be string or array of strings for tag fields"}; + } + } else if (IsVectorType(type)) { + size_t dim = GetVectorDim(type); + if (!val.is_array()) return {Status::NotOK, "json value should be array of numbers for vector fields"}; + if (dim != val.size()) return {Status::NotOK, "the size of the json array is not equal to the dim of the vector"}; + std::vector nums; + for (size_t i = 0; i < dim; ++i) { + if (!val[i].is_number() || val[i].is_string()) + return {Status::NotOK, "json value should be array of numbers for vector fields"}; + nums.push_back(val[i].as_double()); + } + return kqir::MakeValue(nums); + } else { + return {Status::NotOK, "unknown field type to retrieve"}; + } +} + +StatusOr FieldValueRetriever::ParseFromHash(const std::string &value, + const redis::IndexFieldMetadata *type) { + if (auto numeric [[maybe_unused]] = 
dynamic_cast(type)) { + auto num = GET_OR_RET(ParseFloat(value)); + return kqir::MakeValue(num); + } else if (auto tag = dynamic_cast(type)) { + const char delim[] = {tag->separator, '\0'}; + auto vec = util::Split(value, delim); + return kqir::MakeValue(vec); + } else if (IsVectorType(type)) { + const size_t dim = GetVectorDim(type); + if (value.size() != dim * sizeof(double)) { + return {Status::NotOK, "field value is too short or too long to be parsed as a vector"}; + } + std::vector vec; + for (size_t i = 0; i < dim; ++i) { + // TODO: care about endian later + // TODO: currently only support 64bit floating point + vec.push_back(*(reinterpret_cast(value.data()) + i)); + } + return kqir::MakeValue(vec); + } else { + return {Status::NotOK, "unknown field type to retrieve"}; + } +} + +StatusOr FieldValueRetriever::Retrieve(std::string_view field, const redis::IndexFieldMetadata *type) { if (std::holds_alternative(db)) { auto &[hash, metadata, key] = std::get(db); std::string ns_key = hash.AppendNamespacePrefix(key); + LatestSnapShot ss(hash.storage_); rocksdb::ReadOptions read_options; read_options.snapshot = ss.GetSnapShot(); std::string sub_key = InternalKey(ns_key, field, metadata.version, hash.storage_->IsSlotIdEncoded()).Encode(); - return hash.storage_->Get(read_options, sub_key, output); + std::string value; + auto s = hash.storage_->Get(read_options, sub_key, &value); + if (s.IsNotFound()) return {Status::NotFound, s.ToString()}; + if (!s.ok()) return {Status::NotOK, s.ToString()}; + + return ParseFromHash(value, type); } else if (std::holds_alternative(db)) { auto &value = std::get(db); - auto s = value.Get(field); - if (!s.IsOK()) return rocksdb::Status::Corruption(s.Msg()); + + auto s = value.Get(field.front() == '$' ? field : fmt::format("$.{}", field)); + if (!s.IsOK()) return {Status::NotOK, s.Msg()}; if (s->value.size() != 1) - return rocksdb::Status::NotFound("json value specified by the field (json path) should exist and be unique"); - *output = s->value[0].as_string(); - return rocksdb::Status::OK(); + return {Status::NotFound, "json value specified by the field (json path) should exist and be unique"}; + auto val = s->value[0]; + + return ParseFromJson(val, type); } else { - __builtin_unreachable(); + return {Status::NotOK, "unknown redis data type to retrieve"}; } } -StatusOr IndexUpdater::Record(std::string_view key, const std::string &ns) { +StatusOr IndexUpdater::Record(std::string_view key) const { + const auto &ns = info->ns; Database db(indexer->storage, ns); RedisType type = kRedisNone; @@ -87,114 +165,128 @@ StatusOr IndexUpdater::Record(std::string_view key, c // key not exist if (type == kRedisNone) return FieldValues(); - if (type != static_cast(metadata.on_data_type)) { + if (type != static_cast(info->metadata.on_data_type)) { // not the expected type, stop record return {Status::TypeMismatched}; } - auto retriever = GET_OR_RET(FieldValueRetriever::Create(metadata.on_data_type, key, indexer->storage, ns)); + auto retriever = GET_OR_RET(FieldValueRetriever::Create(info->metadata.on_data_type, key, indexer->storage, ns)); FieldValues values; - for (const auto &[field, info] : fields) { - std::string value; - auto s = retriever.Retrieve(field, &value); - if (s.IsNotFound()) continue; - if (!s.ok()) return {Status::NotOK, s.ToString()}; + for (const auto &[field, i] : info->fields) { + if (i.metadata->noindex) { + continue; + } - values.emplace(field, value); + auto s = retriever.Retrieve(field, i.metadata.get()); + if (s.Is()) continue; + if (!s) return s; + + 
values.emplace(field, *s); } return values; } -Status IndexUpdater::UpdateIndex(const std::string &field, std::string_view key, std::string_view original, - std::string_view current, const std::string &ns) { - if (original == current) { - // the value of this field is unchanged, no need to update - return Status::OK(); +Status IndexUpdater::UpdateTagIndex(std::string_view key, const kqir::Value &original, const kqir::Value ¤t, + const SearchKey &search_key, const TagFieldMetadata *tag) const { + CHECK(original.IsNull() || original.Is()); + CHECK(current.IsNull() || current.Is()); + auto original_tags = original.IsNull() ? std::vector() : original.Get(); + auto current_tags = current.IsNull() ? std::vector() : current.Get(); + + auto to_tag_set = [](const std::vector &tags, bool case_sensitive) -> std::set { + if (case_sensitive) { + return {tags.begin(), tags.end()}; + } else { + std::set res; + std::transform(tags.begin(), tags.end(), std::inserter(res, res.begin()), util::ToLower); + return res; + } + }; + + std::set tags_to_delete = to_tag_set(original_tags, tag->case_sensitive); + std::set tags_to_add = to_tag_set(current_tags, tag->case_sensitive); + + for (auto it = tags_to_delete.begin(); it != tags_to_delete.end();) { + if (auto jt = tags_to_add.find(*it); jt != tags_to_add.end()) { + it = tags_to_delete.erase(it); + tags_to_add.erase(jt); + } else { + ++it; + } } - auto iter = fields.find(field); - if (iter == fields.end()) { - return {Status::NotOK, "No such field to do index updating"}; + if (tags_to_add.empty() && tags_to_delete.empty()) { + // no change, skip index updating + return Status::OK(); } - auto *metadata = iter->second.get(); auto *storage = indexer->storage; - auto ns_key = ComposeNamespaceKey(ns, name, storage->IsSlotIdEncoded()); - if (auto tag = dynamic_cast(metadata)) { - const char delim[] = {tag->separator, '\0'}; - auto original_tags = util::Split(original, delim); - auto current_tags = util::Split(current, delim); - - auto to_tag_set = [](const std::vector &tags, bool case_sensitive) -> std::set { - if (case_sensitive) { - return {tags.begin(), tags.end()}; - } else { - std::set res; - std::transform(tags.begin(), tags.end(), std::inserter(res, res.begin()), util::ToLower); - return res; - } - }; + auto batch = storage->GetWriteBatchBase(); + auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search); - std::set tags_to_delete = to_tag_set(original_tags, tag->case_sensitive); - std::set tags_to_add = to_tag_set(current_tags, tag->case_sensitive); + for (const auto &tag : tags_to_delete) { + auto index_key = search_key.ConstructTagFieldData(tag, key); - for (auto it = tags_to_delete.begin(); it != tags_to_delete.end();) { - if (auto jt = tags_to_add.find(*it); jt != tags_to_add.end()) { - it = tags_to_delete.erase(it); - tags_to_add.erase(jt); - } else { - ++it; - } - } + batch->Delete(cf_handle, index_key); + } - if (tags_to_add.empty() && tags_to_delete.empty()) { - // no change, skip index updating - return Status::OK(); - } + for (const auto &tag : tags_to_add) { + auto index_key = search_key.ConstructTagFieldData(tag, key); + + batch->Put(cf_handle, index_key, Slice()); + } - auto batch = storage->GetWriteBatchBase(); - auto cf_handle = storage->GetCFHandle(engine::kSearchColumnFamilyName); + auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch()); + if (!s.ok()) return {Status::NotOK, s.ToString()}; + return Status::OK(); +} - for (const auto &tag : tags_to_delete) { - auto sub_key = ConstructTagFieldSubkey(field, tag, key); 
- auto index_key = InternalKey(ns_key, sub_key, this->metadata.version, storage->IsSlotIdEncoded()); +Status IndexUpdater::UpdateNumericIndex(std::string_view key, const kqir::Value &original, const kqir::Value &current, + const SearchKey &search_key, const NumericFieldMetadata *num) const { + CHECK(original.IsNull() || original.Is<kqir::Numeric>()); + CHECK(current.IsNull() || current.Is<kqir::Numeric>()); - batch->Delete(cf_handle, index_key.Encode()); - } + auto *storage = indexer->storage; + auto batch = storage->GetWriteBatchBase(); + auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search); - for (const auto &tag : tags_to_add) { - auto sub_key = ConstructTagFieldSubkey(field, tag, key); - auto index_key = InternalKey(ns_key, sub_key, this->metadata.version, storage->IsSlotIdEncoded()); + if (!original.IsNull()) { + auto index_key = search_key.ConstructNumericFieldData(original.Get<kqir::Numeric>(), key); - batch->Put(cf_handle, index_key.Encode(), Slice()); - } + batch->Delete(cf_handle, index_key); + } - auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch()); - if (!s.ok()) return {Status::NotOK, s.ToString()}; - } else if (auto numeric [[maybe_unused]] = dynamic_cast<NumericFieldMetadata *>(metadata)) { - auto batch = storage->GetWriteBatchBase(); - auto cf_handle = storage->GetCFHandle(engine::kSearchColumnFamilyName); + if (!current.IsNull()) { + auto index_key = search_key.ConstructNumericFieldData(current.Get<kqir::Numeric>(), key); - if (!original.empty()) { - auto original_num = GET_OR_RET(ParseFloat(std::string(original.begin(), original.end()))); - auto sub_key = ConstructNumericFieldSubkey(field, original_num, key); - auto index_key = InternalKey(ns_key, sub_key, this->metadata.version, storage->IsSlotIdEncoded()); + batch->Put(cf_handle, index_key, Slice()); + } - batch->Delete(cf_handle, index_key.Encode()); - } + auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch()); + if (!s.ok()) return {Status::NotOK, s.ToString()}; + return Status::OK(); +} - if (!current.empty()) { - auto current_num = GET_OR_RET(ParseFloat(std::string(current.begin(), current.end()))); - auto sub_key = ConstructNumericFieldSubkey(field, current_num, key); - auto index_key = InternalKey(ns_key, sub_key, this->metadata.version, storage->IsSlotIdEncoded()); +Status IndexUpdater::UpdateIndex(const std::string &field, std::string_view key, const kqir::Value &original, + const kqir::Value &current) const { + if (original == current) { + // the value of this field is unchanged, no need to update + return Status::OK(); + } - batch->Put(cf_handle, index_key.Encode(), Slice()); - } + auto iter = info->fields.find(field); + if (iter == info->fields.end()) { + return {Status::NotOK, "No such field to do index updating"}; + } - auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch()); - if (!s.ok()) return {Status::NotOK, s.ToString()}; + auto *metadata = iter->second.metadata.get(); + SearchKey search_key(info->ns, info->name, field); + if (auto tag = dynamic_cast<TagFieldMetadata *>(metadata)) { + GET_OR_RET(UpdateTagIndex(key, original, current, search_key, tag)); + } else if (auto numeric [[maybe_unused]] = dynamic_cast<NumericFieldMetadata *>(metadata)) { + GET_OR_RET(UpdateNumericIndex(key, original, current, search_key, numeric)); + } else { + return {Status::NotOK, "Unexpected field type"}; + } @@ -202,11 +294,15 @@ Status IndexUpdater::UpdateIndex(const std::string &field, std::string_view key, return Status::OK(); } -Status IndexUpdater::Update(const FieldValues &original, std::string_view key, const std::string &ns) { - auto current = GET_OR_RET(Record(key,
ns)); +Status IndexUpdater::Update(const FieldValues &original, std::string_view key) const { + auto current = GET_OR_RET(Record(key)); + + for (const auto &[field, i] : info->fields) { + if (i.metadata->noindex) { + continue; + } - for (const auto &[field, _] : fields) { - std::string_view original_val, current_val; + kqir::Value original_val, current_val; if (auto it = original.find(field); it != original.end()) { original_val = it->second; @@ -215,31 +311,72 @@ Status IndexUpdater::Update(const FieldValues &original, std::string_view key, c current_val = it->second; } - GET_OR_RET(UpdateIndex(field, key, original_val, current_val, ns)); + GET_OR_RET(UpdateIndex(field, key, original_val, current_val)); + } + + return Status::OK(); +} + +Status IndexUpdater::Build() const { + auto storage = indexer->storage; + util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Metadata); + + for (const auto &prefix : info->prefixes) { + auto ns_key = ComposeNamespaceKey(info->ns, prefix, storage->IsSlotIdEncoded()); + for (iter->Seek(ns_key); iter->Valid(); iter->Next()) { + if (!iter->key().starts_with(ns_key)) { + break; + } + + auto [_, key] = ExtractNamespaceKey(iter->key(), storage->IsSlotIdEncoded()); + + auto s = Update({}, key.ToStringView()); + if (s.Is()) continue; + if (!s.OK()) return s; + } } return Status::OK(); } void GlobalIndexer::Add(IndexUpdater updater) { - auto &up = updaters.emplace_back(std::move(updater)); - for (const auto &prefix : up.prefixes) { - prefix_map.insert(prefix, &up); + updater.indexer = this; + for (const auto &prefix : updater.info->prefixes) { + prefix_map.insert(ComposeNamespaceKey(updater.info->ns, prefix, false), updater); + } + updater_list.push_back(updater); +} + +void GlobalIndexer::Remove(const kqir::IndexInfo *index) { + for (auto iter = prefix_map.begin(); iter != prefix_map.end();) { + if (iter->info == index) { + iter = prefix_map.erase(iter); + } else { + ++iter; + } } + + updater_list.erase(std::remove_if(updater_list.begin(), updater_list.end(), + [index](IndexUpdater updater) { return updater.info == index; }), + updater_list.end()); } StatusOr GlobalIndexer::Record(std::string_view key, const std::string &ns) { - auto iter = prefix_map.longest_prefix(key); + if (updater_list.empty()) { + return Status::NoPrefixMatched; + } + + auto iter = prefix_map.longest_prefix(ComposeNamespaceKey(ns, key, false)); if (iter != prefix_map.end()) { auto updater = iter.value(); - return std::make_pair(updater, GET_OR_RET(updater->Record(key, ns))); + return RecordResult{updater, std::string(key.begin(), key.end()), GET_OR_RET(updater.Record(key))}; } return {Status::NoPrefixMatched}; } -Status GlobalIndexer::Update(const RecordResult &original, std::string_view key, const std::string &ns) { - return original.first->Update(original.second, key, ns); +Status GlobalIndexer::Update(const RecordResult &original) { + return original.updater.Update(original.fields, original.key); } } // namespace redis diff --git a/src/search/indexer.h b/src/search/indexer.h index c404ecbbc07..8ffd503b6ba 100644 --- a/src/search/indexer.h +++ b/src/search/indexer.h @@ -29,13 +29,14 @@ #include "commands/commander.h" #include "config/config.h" +#include "index_info.h" #include "indexer.h" #include "search/search_encoding.h" -#include "server/server.h" #include "storage/redis_metadata.h" #include "storage/storage.h" #include "types/redis_hash.h" #include "types/redis_json.h" +#include "value.h" namespace redis { @@ -55,7 +56,7 @@ struct FieldValueRetriever { 
using Variant = std::variant; Variant db; - static StatusOr Create(SearchOnDataType type, std::string_view key, engine::Storage *storage, + static StatusOr Create(IndexOnDataType type, std::string_view key, engine::Storage *storage, const std::string &ns); explicit FieldValueRetriever(Hash hash, HashMetadata metadata, std::string_view key) @@ -63,46 +64,53 @@ struct FieldValueRetriever { explicit FieldValueRetriever(JsonValue json) : db(std::in_place_type, std::move(json)) {} - rocksdb::Status Retrieve(std::string_view field, std::string *output); + StatusOr Retrieve(std::string_view field, const redis::IndexFieldMetadata *type); + + static StatusOr ParseFromJson(const jsoncons::json &value, const redis::IndexFieldMetadata *type); + static StatusOr ParseFromHash(const std::string &value, const redis::IndexFieldMetadata *type); }; struct IndexUpdater { - using FieldValues = std::map; + using FieldValues = std::map; - std::string name; - SearchMetadata metadata; - std::vector prefixes; - std::map> fields; + const kqir::IndexInfo *info = nullptr; GlobalIndexer *indexer = nullptr; - IndexUpdater(const IndexUpdater &) = delete; - IndexUpdater(IndexUpdater &&) = default; + explicit IndexUpdater(const kqir::IndexInfo *info) : info(info) {} - IndexUpdater &operator=(IndexUpdater &&) = default; - IndexUpdater &operator=(const IndexUpdater &) = delete; + StatusOr Record(std::string_view key) const; + Status UpdateIndex(const std::string &field, std::string_view key, const kqir::Value &original, + const kqir::Value ¤t) const; + Status Update(const FieldValues &original, std::string_view key) const; - ~IndexUpdater() = default; + Status Build() const; - StatusOr Record(std::string_view key, const std::string &ns); - Status UpdateIndex(const std::string &field, std::string_view key, std::string_view original, - std::string_view current, const std::string &ns); - Status Update(const FieldValues &original, std::string_view key, const std::string &ns); + Status UpdateTagIndex(std::string_view key, const kqir::Value &original, const kqir::Value ¤t, + const SearchKey &search_key, const TagFieldMetadata *tag) const; + Status UpdateNumericIndex(std::string_view key, const kqir::Value &original, const kqir::Value ¤t, + const SearchKey &search_key, const NumericFieldMetadata *num) const; }; struct GlobalIndexer { using FieldValues = IndexUpdater::FieldValues; - using RecordResult = std::pair; + struct RecordResult { + IndexUpdater updater; + std::string key; + FieldValues fields; + }; - std::deque updaters; - tsl::htrie_map prefix_map; + tsl::htrie_map prefix_map; + std::vector updater_list; engine::Storage *storage = nullptr; explicit GlobalIndexer(engine::Storage *storage) : storage(storage) {} void Add(IndexUpdater updater); + void Remove(const kqir::IndexInfo *index); + StatusOr Record(std::string_view key, const std::string &ns); - static Status Update(const RecordResult &original, std::string_view key, const std::string &ns); + static Status Update(const RecordResult &original); }; } // namespace redis diff --git a/src/search/interval.h b/src/search/interval.h new file mode 100644 index 00000000000..efe462b4074 --- /dev/null +++ b/src/search/interval.h @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
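The intended call pattern for GlobalIndexer is to bracket every write to a possibly indexed key: Record() snapshots the old field values via the longest matching prefix before the write, and Update() afterwards diffs that snapshot against the new state so that only changed index entries are touched. A simplified sketch (the actual command plumbing is elided, and error handling is reduced to the happy path):

```cpp
// Hypothetical write-path sketch; only Record()/Update() come from the header above.
Status WriteWithIndexing(redis::GlobalIndexer &indexer, std::string_view key, const std::string &ns) {
  auto rec = indexer.Record(key, ns);  // snapshot the old field values

  // ... perform the actual hash/JSON mutation for `key` here ...

  if (rec) {
    // diff the snapshot against the post-write state and patch the index
    GET_OR_RET(redis::GlobalIndexer::Update(*rec));
  }
  // a failed Record() (e.g. Status::NoPrefixMatched) simply skips indexing;
  // real code would distinguish that from genuine storage errors
  return Status::OK();
}
```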
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include +#include +#include +#include + +#include "fmt/format.h" +#include "search/ir.h" +#include "string_util.h" + +namespace kqir { + +struct Interval { + double l, r; // [l, r) + + static inline const double inf = std::numeric_limits::infinity(); + static inline const double minf = -inf; + + Interval(double l, double r) : l(l), r(r) {} + + bool IsEmpty() const { return l >= r; } + static Interval Full() { return {minf, inf}; } + + bool operator==(const Interval &other) const { return l == other.l && r == other.r; } + bool operator!=(const Interval &other) const { return !(*this == other); } + + std::string ToString() const { return fmt::format("[{}, {})", l, r); } +}; + +template +void ForEachMerged(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2, F &&f) { + while (first1 != last1) { + if (first2 == last2) { + std::for_each(first1, last1, std::forward(f)); + return; + } + + if (*first2 < *first1) { + std::forward(f)(*first2); + ++first2; + } else { + std::forward(f)(*first1); + ++first1; + } + } + std::for_each(first2, last2, std::forward(f)); +} + +struct IntervalSet { + // NOTE: element must be sorted in this vector + // but we don't need to use map here + using DataType = std::vector>; + DataType intervals; + + static inline const double inf = Interval::inf; + static inline const double minf = Interval::minf; + + static double NextNum(double val) { return std::nextafter(val, inf); } + + static double PrevNum(double val) { return std::nextafter(val, minf); } + + explicit IntervalSet() = default; + + struct Full {}; + static constexpr const Full full{}; + + explicit IntervalSet(Full) { intervals.emplace_back(minf, inf); } + + explicit IntervalSet(Interval range) { + if (!range.IsEmpty()) intervals.emplace_back(range.l, range.r); + } + + IntervalSet(NumericCompareExpr::Op op, double val) { + if (op == NumericCompareExpr::EQ) { + intervals.emplace_back(val, NextNum(val)); + } else if (op == NumericCompareExpr::NE) { + intervals.emplace_back(minf, val); + intervals.emplace_back(NextNum(val), inf); + } else if (op == NumericCompareExpr::LT) { + intervals.emplace_back(minf, val); + } else if (op == NumericCompareExpr::GT) { + intervals.emplace_back(NextNum(val), inf); + } else if (op == NumericCompareExpr::LET) { + intervals.emplace_back(minf, NextNum(val)); + } else if (op == NumericCompareExpr::GET) { + intervals.emplace_back(val, inf); + } + } + + bool operator==(const IntervalSet &other) const { return intervals == other.intervals; } + bool operator!=(const IntervalSet &other) const { return intervals != other.intervals; } + + std::string ToString() const { + if (IsEmpty()) return "empty set"; + return util::StringJoin( + intervals, [](const auto &i) { return Interval(i.first, i.second).ToString(); }, " or "); + } + + friend std::ostream &operator<<(std::ostream &os, const IntervalSet &is) { return os << is.ToString(); } + + bool IsEmpty() const { return intervals.empty(); 
} + bool IsFull() const { + if (intervals.size() != 1) return false; + + const auto &v = *intervals.begin(); + return std::isinf(v.first) && std::isinf(v.second) && v.first * v.second < 0; + } + + friend IntervalSet operator&(const IntervalSet &l, const IntervalSet &r) { + IntervalSet result; + + if (l.intervals.empty() || r.intervals.empty()) { + return result; + } + + auto it_l = l.intervals.begin(); + auto it_r = r.intervals.begin(); + + while (it_l != l.intervals.end() && it_r != r.intervals.end()) { + // Find overlap between current intervals + double start = std::max(it_l->first, it_r->first); + double end = std::min(it_l->second, it_r->second); + + if (start < end) { + result.intervals.emplace_back(start, end); + } + + if (it_l->second < it_r->second) { + ++it_l; + } else { + ++it_r; + } + } + + return result; + } + + friend IntervalSet operator|(const IntervalSet &l, const IntervalSet &r) { + if (l.IsEmpty()) { + return r; + } + + if (r.IsEmpty()) { + return l; + } + + IntervalSet result; + ForEachMerged(l.intervals.begin(), l.intervals.end(), r.intervals.begin(), r.intervals.end(), + [&result](const auto &v) { + if (result.IsEmpty() || result.intervals.rbegin()->second < v.first) { + result.intervals.emplace_back(v.first, v.second); + } else { + result.intervals.rbegin()->second = std::max(result.intervals.rbegin()->second, v.second); + } + }); + + return result; + } + + friend IntervalSet operator~(const IntervalSet &v) { + if (v.IsEmpty()) { + return IntervalSet(full); + } + + IntervalSet result; + + auto iter = v.intervals.begin(); + if (!std::isinf(iter->first)) { + result.intervals.emplace_back(minf, iter->first); + } + + double last = iter->second; + ++iter; + while (iter != v.intervals.end()) { + result.intervals.emplace_back(last, iter->first); + + last = iter->second; + ++iter; + } + + if (!std::isinf(last)) { + result.intervals.emplace_back(last, inf); + } + + return result; + } +}; + +} // namespace kqir diff --git a/src/search/ir.h b/src/search/ir.h new file mode 100644 index 00000000000..72cb7351e03 --- /dev/null +++ b/src/search/ir.h @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
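Taken together, IntervalSet gives the planner a small algebra over unions of half-open ranges: each numeric comparison maps to one or two intervals (std::nextafter turns the open/closed distinction into adjusted bounds), and predicates compose through &, | and ~. A usage sketch based only on the definitions above:

```cpp
// Usage sketch for kqir::IntervalSet; `next(x)` in the comments means
// std::nextafter(x, +inf).
using kqir::IntervalSet;
using Cmp = kqir::NumericCompareExpr;

IntervalSet ge3(Cmp::GET, 3.0);  // [3, +inf)
IntervalSet lt5(Cmp::LT, 5.0);   // (-inf, 5)
IntervalSet ne4(Cmp::NE, 4.0);   // (-inf, 4) or [next(4), +inf)

auto conj = ge3 & lt5;  // [3, 5)
auto pred = conj & ne4; // [3, 4) or [next(4), 5)
auto rest = ~pred;      // everything pred excludes

bool full = (pred | rest).IsFull();  // true: the union merges back to (-inf, +inf)
```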
+ * + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fmt/core.h" +#include "ir_iterator.h" +#include "search/index_info.h" +#include "string_util.h" +#include "type_util.h" + +// kqir stands for Kvrocks Query Intermediate Representation +namespace kqir { + +struct Node { + virtual std::string Dump() const = 0; + virtual std::string_view Name() const = 0; + virtual std::string Content() const { return {}; } + + virtual NodeIterator ChildBegin() { return {}; }; + virtual NodeIterator ChildEnd() { return {}; }; + + virtual std::unique_ptr Clone() const = 0; + + template + std::unique_ptr CloneAs() const { + return Node::MustAs(Clone()); + } + + virtual ~Node() = default; + + template + static std::unique_ptr Create(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); + } + + template + static std::unique_ptr MustAs(std::unique_ptr &&original) { + auto casted = As(std::move(original)); + CHECK(casted != nullptr); + return casted; + } + + template + static std::unique_ptr As(std::unique_ptr &&original) { + auto casted = dynamic_cast(original.get()); + if (casted) original.release(); + return std::unique_ptr(casted); + } + + template + static std::vector> List(std::unique_ptr... args) { + std::vector> result; + result.reserve(sizeof...(Args)); + (result.push_back(std::move(args)), ...); + return result; + } +}; + +struct Ref : Node {}; + +struct FieldRef : Ref { + std::string name; + const FieldInfo *info = nullptr; + + explicit FieldRef(std::string name) : name(std::move(name)) {} + FieldRef(std::string name, const FieldInfo *info) : name(std::move(name)), info(info) {} + + std::string_view Name() const override { return "FieldRef"; } + std::string Dump() const override { return name; } + std::string Content() const override { return Dump(); } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + +struct Literal : virtual Node {}; + +struct StringLiteral : Literal { + std::string val; + + explicit StringLiteral(std::string val) : val(std::move(val)) {} + + std::string_view Name() const override { return "StringLiteral"; } + std::string Dump() const override { return fmt::format("\"{}\"", util::EscapeString(val)); } + std::string Content() const override { return Dump(); } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + +struct QueryExpr : virtual Node {}; + +struct BoolAtomExpr : QueryExpr {}; + +struct TagContainExpr : BoolAtomExpr { + std::unique_ptr field; + std::unique_ptr tag; + + TagContainExpr(std::unique_ptr &&field, std::unique_ptr &&tag) + : field(std::move(field)), tag(std::move(tag)) {} + + std::string_view Name() const override { return "TagContainExpr"; } + std::string Dump() const override { return fmt::format("{} hastag {}", field->Dump(), tag->Dump()); } + + NodeIterator ChildBegin() override { return {field.get(), tag.get()}; }; + NodeIterator ChildEnd() override { return {}; }; + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(field->Clone()), + Node::MustAs(tag->Clone())); + } +}; + +struct NumericLiteral : Literal { + double val; + + explicit NumericLiteral(double val) : val(val) {} + + std::string_view Name() const override { return "NumericLiteral"; } + std::string Dump() const override { return fmt::format("{}", val); } + std::string Content() const override { return Dump(); } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + 
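The Node API above is already enough to build and print small IR trees by hand, which is the pattern the unit tests and the passes further down rely on. A sketch using only the classes defined so far:

```cpp
// Building and dumping a tiny query tree; a sketch, not part of the patch.
using namespace kqir;

auto expr = Node::Create<TagContainExpr>(Node::Create<FieldRef>("color"),
                                         Node::Create<StringLiteral>("blue"));

std::string text = expr->Dump();              // => color hastag "blue"
auto copy = expr->CloneAs<TagContainExpr>();  // deep copy that keeps the static type
```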
+// NOLINTNEXTLINE +#define KQIR_NUMERIC_COMPARE_OPS(X) \ + X(EQ, =, NE, EQ) X(NE, !=, EQ, NE) X(LT, <, GET, GT) X(LET, <=, GT, GET) X(GT, >, LET, LT) X(GET, >=, LT, LET) + +struct NumericCompareExpr : BoolAtomExpr { + enum Op { +#define X(n, s, o, f) n, // NOLINT + KQIR_NUMERIC_COMPARE_OPS(X) +#undef X + } op; + std::unique_ptr field; + std::unique_ptr num; + + NumericCompareExpr(Op op, std::unique_ptr &&field, std::unique_ptr &&num) + : op(op), field(std::move(field)), num(std::move(num)) {} + + static constexpr const char *ToOperator(Op op) { + switch (op) { +// NOLINTNEXTLINE +#define X(n, s, o, f) \ + case n: \ + return #s; + KQIR_NUMERIC_COMPARE_OPS(X) +#undef X + } + + return nullptr; + } + + static constexpr std::optional FromOperator(std::string_view op) { +// NOLINTNEXTLINE +#define X(n, s, o, f) \ + if (op == #s) return n; + KQIR_NUMERIC_COMPARE_OPS(X) +#undef X + + return std::nullopt; + } + + static constexpr Op Negative(Op op) { + switch (op) { +// NOLINTNEXTLINE +#define X(n, s, o, f) \ + case n: \ + return o; + KQIR_NUMERIC_COMPARE_OPS(X) +#undef X + } + + __builtin_unreachable(); + } + + static constexpr Op Flip(Op op) { + switch (op) { +// NOLINTNEXTLINE +#define X(n, s, o, f) \ + case n: \ + return f; + KQIR_NUMERIC_COMPARE_OPS(X) +#undef X + } + + __builtin_unreachable(); + } + + std::string_view Name() const override { return "NumericCompareExpr"; } + std::string Dump() const override { return fmt::format("{} {} {}", field->Dump(), ToOperator(op), num->Dump()); }; + std::string Content() const override { return ToOperator(op); } + + NodeIterator ChildBegin() override { return {field.get(), num.get()}; }; + NodeIterator ChildEnd() override { return {}; }; + + std::unique_ptr Clone() const override { + return std::make_unique(op, Node::MustAs(field->Clone()), + Node::MustAs(num->Clone())); + } +}; + +struct BoolLiteral : BoolAtomExpr, Literal { + bool val; + + explicit BoolLiteral(bool val) : val(val) {} + + std::string_view Name() const override { return "BoolLiteral"; } + std::string Dump() const override { return val ? 
"true" : "false"; } + std::string Content() const override { return Dump(); } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + +struct QueryExpr; + +struct NotExpr : QueryExpr { + std::unique_ptr inner; + + explicit NotExpr(std::unique_ptr &&inner) : inner(std::move(inner)) {} + + std::string_view Name() const override { return "NotExpr"; } + std::string Dump() const override { return fmt::format("not {}", inner->Dump()); } + + NodeIterator ChildBegin() override { return NodeIterator{inner.get()}; }; + NodeIterator ChildEnd() override { return {}; }; + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(inner->Clone())); + } +}; + +struct AndExpr : QueryExpr { + std::vector> inners; + + explicit AndExpr(std::vector> &&inners) : inners(std::move(inners)) {} + + static std::unique_ptr Create(std::vector> &&exprs) { + CHECK(!exprs.empty()); + + if (exprs.size() == 1) { + return std::move(exprs.front()); + } + + return std::make_unique(std::move(exprs)); + } + + std::string_view Name() const override { return "AndExpr"; } + std::string Dump() const override { + return fmt::format("(and {})", util::StringJoin(inners, [](const auto &v) { return v->Dump(); })); + } + + NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); }; + NodeIterator ChildEnd() override { return NodeIterator(inners.end()); }; + + std::unique_ptr Clone() const override { + std::vector> res; + res.reserve(inners.size()); + for (const auto &n : inners) { + res.push_back(Node::MustAs(n->Clone())); + } + return std::make_unique(std::move(res)); + } +}; + +struct OrExpr : QueryExpr { + std::vector> inners; + + explicit OrExpr(std::vector> &&inners) : inners(std::move(inners)) {} + + static std::unique_ptr Create(std::vector> &&exprs) { + CHECK(!exprs.empty()); + + if (exprs.size() == 1) { + return std::move(exprs.front()); + } + + return std::make_unique(std::move(exprs)); + } + + std::string_view Name() const override { return "OrExpr"; } + std::string Dump() const override { + return fmt::format("(or {})", util::StringJoin(inners, [](const auto &v) { return v->Dump(); })); + } + + NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); }; + NodeIterator ChildEnd() override { return NodeIterator(inners.end()); }; + + std::unique_ptr Clone() const override { + std::vector> res; + res.reserve(inners.size()); + for (const auto &n : inners) { + res.push_back(Node::MustAs(n->Clone())); + } + return std::make_unique(std::move(res)); + } +}; + +struct LimitClause : Node { + size_t offset = 0; + size_t count = std::numeric_limits::max(); + + LimitClause(size_t offset, size_t count) : offset(offset), count(count) {} + + std::string_view Name() const override { return "LimitClause"; } + std::string Dump() const override { return fmt::format("limit {}, {}", offset, count); } + std::string Content() const override { return fmt::format("{}, {}", offset, count); } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + +struct SortByClause : Node { + enum Order { ASC, DESC } order = ASC; + std::unique_ptr field; + + SortByClause(Order order, std::unique_ptr &&field) : order(order), field(std::move(field)) {} + + static constexpr const char *OrderToString(Order order) { return order == ASC ? 
"asc" : "desc"; } + + std::string_view Name() const override { return "SortByClause"; } + std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); } + std::string Content() const override { return OrderToString(order); } + + NodeIterator ChildBegin() override { return NodeIterator(field.get()); }; + NodeIterator ChildEnd() override { return {}; }; + + std::unique_ptr Clone() const override { + return std::make_unique(order, Node::MustAs(field->Clone())); + } +}; + +struct SelectClause : Node { + std::vector> fields; + + explicit SelectClause(std::vector> &&fields) : fields(std::move(fields)) {} + + std::string_view Name() const override { return "SelectClause"; } + std::string Dump() const override { + if (fields.empty()) return "select *"; + return fmt::format("select {}", util::StringJoin(fields, [](const auto &v) { return v->Dump(); })); + } + + NodeIterator ChildBegin() override { return NodeIterator(fields.begin()); }; + NodeIterator ChildEnd() override { return NodeIterator(fields.end()); }; + + std::unique_ptr Clone() const override { + std::vector> res; + res.reserve(fields.size()); + for (const auto &f : fields) { + res.push_back(Node::MustAs(f->Clone())); + } + return std::make_unique(std::move(res)); + } +}; + +struct IndexRef : Ref { + std::string name; + const IndexInfo *info = nullptr; + + explicit IndexRef(std::string name) : name(std::move(name)) {} + explicit IndexRef(std::string name, const IndexInfo *info) : name(std::move(name)), info(info) {} + + std::string_view Name() const override { return "IndexRef"; } + std::string Dump() const override { return name; } + std::string Content() const override { return Dump(); } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + +struct SearchExpr : Node { + std::unique_ptr select; + std::unique_ptr index; + std::unique_ptr query_expr; + std::unique_ptr limit; // optional + std::unique_ptr sort_by; // optional + + SearchExpr(std::unique_ptr &&index, std::unique_ptr &&query_expr, + std::unique_ptr &&limit, std::unique_ptr &&sort_by, + std::unique_ptr &&select) + : select(std::move(select)), + index(std::move(index)), + query_expr(std::move(query_expr)), + limit(std::move(limit)), + sort_by(std::move(sort_by)) {} + + std::string_view Name() const override { return "SearchExpr"; } + std::string Dump() const override { + std::string opt; + if (sort_by) opt += " " + sort_by->Dump(); + if (limit) opt += " " + limit->Dump(); + return fmt::format("{} from {} where {}{}", select->Dump(), index->Dump(), query_expr->Dump(), opt); + } + + static inline const std::vector> ChildMap = { + NodeIterator::MemFn<&SearchExpr::select>, NodeIterator::MemFn<&SearchExpr::index>, + NodeIterator::MemFn<&SearchExpr::query_expr>, NodeIterator::MemFn<&SearchExpr::limit>, + NodeIterator::MemFn<&SearchExpr::sort_by>, + }; + + NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); }; + NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); }; + + std::unique_ptr Clone() const override { + return std::make_unique( + Node::MustAs(index->Clone()), Node::MustAs(query_expr->Clone()), + Node::MustAs(limit->Clone()), Node::MustAs(sort_by->Clone()), + Node::MustAs(select->Clone())); + } +}; + +} // namespace kqir diff --git a/src/search/ir_dot_dumper.h b/src/search/ir_dot_dumper.h new file mode 100644 index 00000000000..a0bcfb724f4 --- /dev/null +++ b/src/search/ir_dot_dumper.h @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include "ir.h" +#include "string_util.h" + +namespace kqir { + +struct DotDumper { + std::ostream &os; + + explicit DotDumper(std::ostream &os) : os(os) {} + + void Dump(Node *node) { + os << "digraph {\n"; + dump(node); + os << "}\n"; + } + + private: + static std::string nodeId(Node *node) { return fmt::format("x{:x}", (uint64_t)node); } + + void dump(Node *node) { + os << " " << nodeId(node) << " [ label = \"" << node->Name(); + if (auto content = node->Content(); !content.empty()) { + os << " (" << util::EscapeString(content) << ")\" ];\n"; + } else { + os << "\" ];\n"; + } + for (auto i = node->ChildBegin(); i != node->ChildEnd(); ++i) { + os << " " << nodeId(node) << " -> " << nodeId(*i) << ";\n"; + dump(*i); + } + } +}; + +} // namespace kqir diff --git a/src/search/ir_iterator.h b/src/search/ir_iterator.h new file mode 100644 index 00000000000..0730e86ed50 --- /dev/null +++ b/src/search/ir_iterator.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
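DotDumper emits one label line per node plus an edge per child, i.e. standard Graphviz input. Usage is a few lines (a sketch; `tree` stands for any IR tree, e.g. the output of the parser):

```cpp
#include <sstream>

std::ostringstream out;
kqir::DotDumper dumper(out);
dumper.Dump(tree.get());  // `tree` is some std::unique_ptr<kqir::Node>
// out.str() now holds a DOT digraph; render it with e.g. `dot -Tsvg`
```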
+ * + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "type_util.h" + +namespace kqir { + +struct Node; + +struct NodeIterator { + std::variant, + std::pair>::const_iterator>, + std::vector>::iterator> + val; + + NodeIterator() : val(nullptr) {} + explicit NodeIterator(Node *node) : val(node) {} + NodeIterator(Node *n1, Node *n2) : val(std::array{n1, n2}) {} + explicit NodeIterator(Node *parent, std::vector>::const_iterator iter) + : val(std::make_pair(parent, iter)) {} + template , int> = 0> + explicit NodeIterator(Iterator iter) : val(*CastToNodeIter(&iter)) {} + + template + static auto CastToNodeIter(Iterator *iter) { + auto res __attribute__((__may_alias__)) = reinterpret_cast>::iterator *>(iter); + return res; + } + + template + static Node *MemFn(Node *parent) { + return (reinterpret_cast::type *>(parent)->*F).get(); + } + + friend bool operator==(NodeIterator l, NodeIterator r) { return l.val == r.val; } + + friend bool operator!=(NodeIterator l, NodeIterator r) { return l.val != r.val; } + + Node *operator*() { + if (val.index() == 0) { + return std::get<0>(val); + } else if (val.index() == 1) { + return std::get<1>(val)[0]; + } else if (val.index() == 2) { + auto &[parent, iter] = std::get<2>(val); + return (*iter)(parent); + } else { + return std::get<3>(val)->get(); + } + } + + NodeIterator &operator++() { + if (val.index() == 0) { + val = nullptr; + } else if (val.index() == 1) { + val = std::get<1>(val)[1]; + } else if (val.index() == 2) { + ++std::get<2>(val).second; + } else { + ++std::get<3>(val); + } + + return *this; + } +}; + +} // namespace kqir diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h new file mode 100644 index 00000000000..2068a45a4f4 --- /dev/null +++ b/src/search/ir_pass.h @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
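Whichever variant backs it (null, single child, inline pair, member-pointer table, or vector of children), NodeIterator exposes the same forward-iteration surface, so generic tree walks need no knowledge of the concrete node type. For example:

```cpp
// Generic traversal sketch over the child iterators above.
size_t CountNodes(kqir::Node *node) {
  size_t n = 1;  // count the node itself
  for (auto it = node->ChildBegin(); it != node->ChildEnd(); ++it) {
    n += CountNodes(*it);  // *it yields the child Node*, whatever backs the iterator
  }
  return n;
}
```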
+ * + */ + +#pragma once + +#include "ir.h" +#include "search/ir_plan.h" + +namespace kqir { + +struct Pass { + virtual std::unique_ptr Transform(std::unique_ptr node) = 0; + + virtual void Reset() {} + + virtual ~Pass() = default; +}; + +struct Visitor : Pass { + std::unique_ptr Transform(std::unique_ptr node) override { + if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } + + __builtin_unreachable(); + } + + template + std::unique_ptr VisitAs(std::unique_ptr n) { + return Node::MustAs(Visit(std::move(n))); + } + + template + std::unique_ptr TransformAs(std::unique_ptr n) { + return Node::MustAs(Transform(std::move(n))); + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->index = VisitAs(std::move(node->index)); + node->select = VisitAs(std::move(node->select)); + node->query_expr = TransformAs(std::move(node->query_expr)); + if (node->sort_by) node->sort_by = VisitAs(std::move(node->sort_by)); + if (node->limit) node->limit = VisitAs(std::move(node->limit)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + for (auto &n : node->fields) { + n = VisitAs(std::move(n)); + } + + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + node->num = 
VisitAs(std::move(node->num)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + node->tag = VisitAs(std::move(node->tag)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + for (auto &n : node->inners) { + n = TransformAs(std::move(n)); + } + + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + for (auto &n : node->inners) { + n = TransformAs(std::move(n)); + } + + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->inner = TransformAs(std::move(node->inner)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->source = TransformAs(std::move(node->source)); + node->filter_expr = TransformAs(std::move(node->filter_expr)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->op = TransformAs(std::move(node->op)); + node->limit = VisitAs(std::move(node->limit)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->op = TransformAs(std::move(node->op)); + node->order = VisitAs(std::move(node->order)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->op = TransformAs(std::move(node->op)); + node->limit = VisitAs(std::move(node->limit)); + node->order = VisitAs(std::move(node->order)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->source = TransformAs(std::move(node->source)); + node->select = VisitAs(std::move(node->select)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + for (auto &n : node->ops) { + n = TransformAs(std::move(n)); + } + + return node; + } +}; + +} // namespace kqir diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h new file mode 100644 index 00000000000..f93199b50c9 --- /dev/null +++ b/src/search/ir_plan.h @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#pragma once + +#include +#include + +#include "ir.h" +#include "search/interval.h" +#include "search/ir_sema_checker.h" +#include "string_util.h" + +namespace kqir { + +struct PlanOperator : Node {}; + +struct Noop : PlanOperator { + std::string_view Name() const override { return "Noop"; }; + std::string Dump() const override { return "noop"; } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + +struct FullIndexScan : PlanOperator { + std::unique_ptr index; + + explicit FullIndexScan(std::unique_ptr index) : index(std::move(index)) {} + + std::string_view Name() const override { return "FullIndexScan"; } + std::string Dump() const override { return fmt::format("full-scan {}", index->name); } + + NodeIterator ChildBegin() override { return NodeIterator{index.get()}; } + NodeIterator ChildEnd() override { return {}; } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(index->Clone())); + } +}; + +struct FieldScan : PlanOperator { + std::unique_ptr field; + + explicit FieldScan(std::unique_ptr field) : field(std::move(field)) {} + + NodeIterator ChildBegin() override { return NodeIterator{field.get()}; } + NodeIterator ChildEnd() override { return {}; } +}; + +struct NumericFieldScan : FieldScan { + Interval range; + SortByClause::Order order; + + NumericFieldScan(std::unique_ptr field, Interval range, SortByClause::Order order = SortByClause::ASC) + : FieldScan(std::move(field)), range(range), order(order) {} + + std::string_view Name() const override { return "NumericFieldScan"; }; + std::string Content() const override { + return fmt::format("{}, {}", range.ToString(), SortByClause::OrderToString(order)); + }; + std::string Dump() const override { return fmt::format("numeric-scan {}, {}", field->name, Content()); } + + std::unique_ptr Clone() const override { + return std::make_unique(field->CloneAs(), range, order); + } +}; + +struct TagFieldScan : FieldScan { + std::string tag; + + TagFieldScan(std::unique_ptr field, std::string tag) : FieldScan(std::move(field)), tag(std::move(tag)) {} + + std::string_view Name() const override { return "TagFieldScan"; }; + std::string Content() const override { return tag; }; + std::string Dump() const override { return fmt::format("tag-scan {}, {}", field->name, tag); } + + std::unique_ptr Clone() const override { + return std::make_unique(field->CloneAs(), tag); + } +}; + +struct Filter : PlanOperator { + std::unique_ptr source; + std::unique_ptr filter_expr; + + Filter(std::unique_ptr &&source, std::unique_ptr &&filter_expr) + : source(std::move(source)), filter_expr(std::move(filter_expr)) {} + + std::string_view Name() const override { return "Filter"; }; + std::string Dump() const override { return fmt::format("(filter {}: {})", filter_expr->Dump(), source->Dump()); } + + NodeIterator ChildBegin() override { return {source.get(), filter_expr.get()}; } + NodeIterator ChildEnd() override { return {}; } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(source->Clone()), + Node::MustAs(filter_expr->Clone())); + } +}; + +struct Merge : PlanOperator { + std::vector> ops; + + explicit Merge(std::vector> &&ops) : ops(std::move(ops)) {} + + static std::unique_ptr Create(std::vector> &&ops) { + CHECK(!ops.empty()); + + if (ops.size() == 1) { + return std::move(ops.front()); + } + + return std::make_unique(std::move(ops)); + } + + std::string_view Name() const override { return "Merge"; }; + std::string Dump() const override { + return 
fmt::format("(merge {})", util::StringJoin(ops, [](const auto &v) { return v->Dump(); })); + } + + NodeIterator ChildBegin() override { return NodeIterator(ops.begin()); } + NodeIterator ChildEnd() override { return NodeIterator(ops.end()); } + + std::unique_ptr Clone() const override { + std::vector> res; + res.reserve(ops.size()); + for (const auto &op : ops) { + res.push_back(Node::MustAs(op->Clone())); + } + return std::make_unique(std::move(res)); + } +}; + +struct Limit : PlanOperator { + std::unique_ptr op; + std::unique_ptr limit; + + Limit(std::unique_ptr &&op, std::unique_ptr &&limit) + : op(std::move(op)), limit(std::move(limit)) {} + + std::string_view Name() const override { return "Limit"; }; + std::string Dump() const override { + return fmt::format("(limit {}, {}: {})", limit->offset, limit->count, op->Dump()); + } + + NodeIterator ChildBegin() override { return NodeIterator{op.get(), limit.get()}; } + NodeIterator ChildEnd() override { return {}; } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(op->Clone()), Node::MustAs(limit->Clone())); + } +}; + +struct Sort : PlanOperator { + std::unique_ptr op; + std::unique_ptr order; + + Sort(std::unique_ptr &&op, std::unique_ptr &&order) + : op(std::move(op)), order(std::move(order)) {} + + std::string_view Name() const override { return "Sort"; }; + std::string Dump() const override { + return fmt::format("(sort {}, {}: {})", order->field->Dump(), order->OrderToString(order->order), op->Dump()); + } + + NodeIterator ChildBegin() override { return NodeIterator{op.get(), order.get()}; } + NodeIterator ChildEnd() override { return {}; } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(op->Clone()), Node::MustAs(order->Clone())); + } +}; + +// operator fusion: Sort + Limit +struct TopNSort : PlanOperator { + std::unique_ptr op; + std::unique_ptr order; + std::unique_ptr limit; + + TopNSort(std::unique_ptr &&op, std::unique_ptr &&order, + std::unique_ptr &&limit) + : op(std::move(op)), order(std::move(order)), limit(std::move(limit)) {} + + std::string_view Name() const override { return "TopNSort"; }; + std::string Dump() const override { + return fmt::format("(top-n sort {}, {}, {}, {}: {})", order->field->Dump(), order->OrderToString(order->order), + limit->offset, limit->count, op->Dump()); + } + + static inline const std::vector> ChildMap = { + NodeIterator::MemFn<&TopNSort::op>, NodeIterator::MemFn<&TopNSort::order>, NodeIterator::MemFn<&TopNSort::limit>}; + + NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); } + NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(op->Clone()), + Node::MustAs(order->Clone()), + Node::MustAs(limit->Clone())); + } +}; + +struct Projection : PlanOperator { + std::unique_ptr source; + std::unique_ptr select; + + Projection(std::unique_ptr &&source, std::unique_ptr &&select) + : source(std::move(source)), select(std::move(select)) {} + + std::string_view Name() const override { return "Projection"; }; + std::string Dump() const override { + auto select_str = + select->fields.empty() ? 
"*" : util::StringJoin(select->fields, [](const auto &v) { return v->Dump(); }); + return fmt::format("project {}: {}", select_str, source->Dump()); + } + + NodeIterator ChildBegin() override { return {source.get(), select.get()}; } + NodeIterator ChildEnd() override { return {}; } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(source->Clone()), + Node::MustAs(select->Clone())); + } +}; + +} // namespace kqir diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h new file mode 100644 index 00000000000..a7a7618173d --- /dev/null +++ b/src/search/ir_sema_checker.h @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include +#include + +#include "fmt/core.h" +#include "index_info.h" +#include "ir.h" +#include "search_encoding.h" +#include "storage/redis_metadata.h" + +namespace kqir { + +struct SemaChecker { + const IndexMap &index_map; + std::string ns; + + const IndexInfo *current_index = nullptr; + + explicit SemaChecker(const IndexMap &index_map) : index_map(index_map) {} + + Status Check(Node *node) { + if (auto v = dynamic_cast(node)) { + auto index_name = v->index->name; + if (auto iter = index_map.Find(index_name, ns); iter != index_map.end()) { + current_index = iter->second.get(); + v->index->info = current_index; + + GET_OR_RET(Check(v->select.get())); + GET_OR_RET(Check(v->query_expr.get())); + if (v->limit) GET_OR_RET(Check(v->limit.get())); + if (v->sort_by) GET_OR_RET(Check(v->sort_by.get())); + } else { + return {Status::NotOK, fmt::format("index `{}` not found", index_name)}; + } + } else if (auto v [[maybe_unused]] = dynamic_cast(node)) { + return Status::OK(); + } else if (auto v = dynamic_cast(node)) { + if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) { + return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)}; + } else if (!iter->second.IsSortable()) { + return {Status::NotOK, fmt::format("field `{}` is not sortable", v->field->name)}; + } else { + v->field->info = &iter->second; + } + } else if (auto v = dynamic_cast(node)) { + for (const auto &n : v->inners) { + GET_OR_RET(Check(n.get())); + } + } else if (auto v = dynamic_cast(node)) { + for (const auto &n : v->inners) { + GET_OR_RET(Check(n.get())); + } + } else if (auto v = dynamic_cast(node)) { + GET_OR_RET(Check(v->inner.get())); + } else if (auto v = dynamic_cast(node)) { + if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) { + return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name)}; + } else if (auto meta = iter->second.MetadataAs(); !meta) { + return {Status::NotOK, 
fmt::format("field `{}` is not a tag field", v->field->name)}; + } else { + v->field->info = &iter->second; + + if (v->tag->val.empty()) { + return {Status::NotOK, "tag cannot be an empty string"}; + } + + if (v->tag->val.find(meta->separator) != std::string::npos) { + return {Status::NotOK, fmt::format("tag cannot contain the separator `{}`", meta->separator)}; + } + } + } else if (auto v = dynamic_cast(node)) { + if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) { + return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)}; + } else if (!iter->second.MetadataAs()) { + return {Status::NotOK, fmt::format("field `{}` is not a numeric field", v->field->name)}; + } else { + v->field->info = &iter->second; + } + } else if (auto v = dynamic_cast(node)) { + for (const auto &n : v->fields) { + if (auto iter = current_index->fields.find(n->name); iter == current_index->fields.end()) { + return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", n->name, current_index->name)}; + } else { + n->info = &iter->second; + } + } + } else if (auto v [[maybe_unused]] = dynamic_cast(node)) { + return Status::OK(); + } else { + return {Status::NotOK, fmt::format("unexpected IR node type: {}", node->Name())}; + } + + return Status::OK(); + } +}; + +} // namespace kqir diff --git a/src/search/passes/cost_model.h b/src/search/passes/cost_model.h new file mode 100644 index 00000000000..86e0e3a58e5 --- /dev/null +++ b/src/search/passes/cost_model.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#pragma once + +#include +#include + +#include "search/interval.h" +#include "search/ir.h" +#include "search/ir_plan.h" + +namespace kqir { + +// TODO: collect statistical information of index in runtime +// to optimize the cost model +struct CostModel { + static size_t Transform(const PlanOperator *node) { + if (auto v = dynamic_cast(node)) { + return Visit(v); + } + if (auto v = dynamic_cast(node)) { + return Visit(v); + } + if (auto v = dynamic_cast(node)) { + return Visit(v); + } + if (auto v = dynamic_cast(node)) { + return Visit(v); + } + if (auto v = dynamic_cast(node)) { + return Visit(v); + } + + CHECK(false) << "plan operator type not supported"; + } + + static size_t Visit(const FullIndexScan *node) { return 100; } + + static size_t Visit(const NumericFieldScan *node) { + if (node->range.r == IntervalSet::NextNum(node->range.l)) { + return 5; + } + + size_t base = 10; + + if (std::isinf(node->range.l)) { + base += 20; + } + + if (std::isinf(node->range.r)) { + base += 20; + } + + return base; + } + + static size_t Visit(const TagFieldScan *node) { return 10; } + + static size_t Visit(const Filter *node) { return Transform(node->source.get()) + 1; } + + static size_t Visit(const Merge *node) { + return std::accumulate(node->ops.begin(), node->ops.end(), size_t(0), [](size_t res, const auto &v) { + if (dynamic_cast(v.get())) { + res += 9; + } + return res + Transform(v.get()); + }); + } +}; + +} // namespace kqir diff --git a/src/search/passes/index_selection.h b/src/search/passes/index_selection.h new file mode 100644 index 00000000000..e60287d4d01 --- /dev/null +++ b/src/search/passes/index_selection.h @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
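To make the constants above concrete, a worked example reading only the costs visible in this pass: an exact-match scan (the right bound is the floating-point successor of the left) costs 5; a range bounded on both sides costs 10; each unbounded side adds 20, so a half-open scan costs 30 and a fully unbounded one 50; a full index scan costs 100, a Filter adds 1 to its source, and a Merge sums its children's costs with a flat penalty added for some child shapes. Index selection below therefore prefers, say, merging a point scan with a half-open scan (roughly 5 + 30 = 35 plus merge penalties) over a filtered full scan at 100 + 1 = 101, but falls back to the full scan once enough expensive disjuncts pile up.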
+ * + */ + +#pragma once + +#include +#include +#include +#include + +#include "search/index_info.h" +#include "search/interval.h" +#include "search/ir.h" +#include "search/ir_pass.h" +#include "search/ir_plan.h" +#include "search/passes/cost_model.h" +#include "search/passes/interval_analysis.h" +#include "search/passes/push_down_not_expr.h" +#include "search/passes/simplify_and_or_expr.h" +#include "search/search_encoding.h" + +namespace kqir { + +struct IndexSelection : Visitor { + SortByClause *order = nullptr; + bool sort_removable = false; + IndexRef *index = nullptr; + IntervalAnalysis::Result intervals; + + void Reset() override { + order = nullptr; + sort_removable = false; + index = nullptr; + intervals.clear(); + } + + std::unique_ptr Visit(std::unique_ptr node) override { + IntervalAnalysis analysis(false); + node = Node::MustAs(analysis.Transform(std::move(node))); + intervals = std::move(analysis.result); + + return Visitor::Visit(std::move(node)); + } + + std::unique_ptr Visit(std::unique_ptr node) override { + order = node->order.get(); + + node = Node::MustAs(Visitor::Visit(std::move(node))); + + if (sort_removable) return std::move(node->op); + + return node; + } + + bool HasGoodOrder() const { return order && order->field->info->HasIndex(); } + + std::unique_ptr GenerateScanFromOrder() const { + if (order->field->info->MetadataAs()) { + return std::make_unique(order->field->CloneAs(), Interval::Full(), order->order); + } else { + CHECK(false) << "current only numeric field is supported for ordering"; + } + } + + // if there's no Filter node, enter this method + std::unique_ptr Visit(std::unique_ptr node) override { + if (HasGoodOrder()) { + sort_removable = true; + return GenerateScanFromOrder(); + } + + return node; + } + + std::unique_ptr Visit(std::unique_ptr node) override { + auto index_scan = Node::MustAs(std::move(node->source)); + + if (HasGoodOrder()) { + // TODO: optimize plan with sorting order via the cost model + sort_removable = true; + + auto scan = GenerateScanFromOrder(); + return std::make_unique(std::move(scan), std::move(node->filter_expr)); + } else { + index = index_scan->index.get(); + + return TransformExpr(node->filter_expr.get()); + } + } + + std::unique_ptr TransformExpr(QueryExpr *node) { + if (auto v = dynamic_cast(node)) { + return VisitExpr(v); + } + if (auto v = dynamic_cast(node)) { + return VisitExpr(v); + } + if (auto v = dynamic_cast(node)) { + return VisitExpr(v); + } + if (auto v = dynamic_cast(node)) { + return VisitExpr(v); + } + if (auto v = dynamic_cast(node)) { + return VisitExpr(v); + } + + CHECK(false) << "unreachable"; + } + + std::unique_ptr MakeFullIndexFilter(QueryExpr *node) const { + return std::make_unique(std::make_unique(index->CloneAs()), + node->CloneAs()); + } + + std::unique_ptr VisitExpr(NotExpr *node) const { + // after PushDownNotExpr, `node->inner` should be one of TagContainExpr and NumericCompareExpr + return MakeFullIndexFilter(node); + } + + std::unique_ptr VisitExpr(TagContainExpr *node) const { + if (node->field->info->HasIndex()) { + return std::make_unique(node->field->CloneAs(), node->tag->val); + } + + return MakeFullIndexFilter(node); + } + + // enter only if there's just a single NumericCompareExpr, without and/or expression + std::unique_ptr VisitExpr(NumericCompareExpr *node) const { + if (node->field->info->HasIndex() && node->op != NumericCompareExpr::NE) { + IntervalSet is(node->op, node->num->val); + return PlanFromInterval(is, node->field.get(), SortByClause::ASC); + } + + return 
MakeFullIndexFilter(node); + } + + template + std::unique_ptr VisitExprImpl(Expr *node) { + struct AggregatedNodes { + std::set nodes; + IntervalSet intervals; + }; + + std::map agg_nodes; + std::vector> rest_nodes; + + for (const auto &n : node->inners) { + IntervalSet is; + const FieldInfo *field = nullptr; + + if (auto iter = intervals.find(n.get()); iter != intervals.end()) { + field = iter->second.field_info; + is = iter->second.intervals; + } else if (auto expr = dynamic_cast(n.get()); expr && expr->op != NumericCompareExpr::NE) { + field = expr->field->info; + is = IntervalSet(expr->op, expr->num->val); + } else { + rest_nodes.emplace_back(n.get()); + continue; + } + + if (!field->HasIndex()) { + rest_nodes.emplace_back(n.get()); + continue; + } + + if (auto jter = agg_nodes.find(field); jter != agg_nodes.end()) { + jter->second.nodes.emplace(n.get()); + if constexpr (std::is_same_v) { + jter->second.intervals = jter->second.intervals & is; + } else { + jter->second.intervals = jter->second.intervals | is; + } + } else { + rest_nodes.emplace_back(field); + agg_nodes.emplace(field, AggregatedNodes{std::set{n.get()}, is}); + } + } + + if constexpr (std::is_same_v) { + struct SelectionInfo { + std::unique_ptr plan; + std::set selected_nodes; + size_t cost; + + SelectionInfo(std::unique_ptr &&plan, std::set nodes) + : plan(std::move(plan)), selected_nodes(std::move(nodes)), cost(CostModel::Transform(this->plan.get())) {} + }; + + std::vector available_plans; + + available_plans.emplace_back(std::make_unique(index->CloneAs()), std::set{}); + + for (auto v : rest_nodes) { + if (std::holds_alternative(v)) { + auto n = std::get(v); + auto op = TransformExpr(n); + + available_plans.emplace_back(std::move(op), std::set{n}); + } else { + auto n = std::get(v); + const auto &agg_info = agg_nodes.at(n); + auto field_ref = std::make_unique(n->name, n); + available_plans.emplace_back(PlanFromInterval(agg_info.intervals, field_ref.get(), SortByClause::ASC), + agg_info.nodes); + } + } + + auto &best_plan = *std::min_element(available_plans.begin(), available_plans.end(), + [](const auto &l, const auto &r) { return l.cost < r.cost; }); + + std::vector> filter_nodes; + for (const auto &n : node->inners) { + if (best_plan.selected_nodes.count(n.get()) == 0) filter_nodes.push_back(n->template CloneAs()); + } + + if (filter_nodes.empty()) { + return std::move(best_plan.plan); + } else if (filter_nodes.size() == 1) { + return std::make_unique(std::move(best_plan.plan), std::move(filter_nodes.front())); + } else { + return std::make_unique(std::move(best_plan.plan), std::make_unique(std::move(filter_nodes))); + } + } else { + auto full_scan_plan = MakeFullIndexFilter(node); + + std::vector> merged_elems; + std::vector> elem_filter; + + auto add_filter = [&elem_filter](std::unique_ptr op) { + if (!elem_filter.empty()) { + std::unique_ptr filter = std::make_unique(OrExpr::Create(CloneExprs(elem_filter))); + + PushDownNotExpr pdne; + filter = Node::MustAs(pdne.Transform(std::move(filter))); + SimplifyAndOrExpr saoe; + filter = Node::MustAs(saoe.Transform(std::move(filter))); + + op = std::make_unique(std::move(op), std::move(filter)); + } + + return op; + }; + + for (auto v : rest_nodes) { + if (std::holds_alternative(v)) { + auto n = std::get(v); + auto op = add_filter(TransformExpr(n)); + + merged_elems.push_back(std::move(op)); + elem_filter.push_back(n->CloneAs()); + } else { + auto n = std::get(v); + const auto &agg_info = agg_nodes.at(n); + auto field_ref = std::make_unique(n->name, n); + auto elem = 
PlanFromInterval(agg_info.intervals, field_ref.get(), SortByClause::ASC); + elem = add_filter(std::move(elem)); + + merged_elems.push_back(std::move(elem)); + for (auto nn : agg_info.nodes) { + elem_filter.push_back(nn->template CloneAs()); + } + } + } + + auto merge_plan = Merge::Create(std::move(merged_elems)); + auto &best_plan = const_cast &>(std::min( + full_scan_plan, merge_plan, + [](const auto &l, const auto &r) { return CostModel::Transform(l.get()) < CostModel::Transform(r.get()); })); + + return std::move(best_plan); + } + } + + static std::vector> CloneExprs(const std::vector> &exprs) { + std::vector> result; + result.reserve(exprs.size()); + + for (const auto &e : exprs) result.push_back(e->CloneAs()); + return result; + } + + std::unique_ptr VisitExpr(AndExpr *node) { return VisitExprImpl(node); } + + std::unique_ptr VisitExpr(OrExpr *node) { return VisitExprImpl(node); } + + static std::unique_ptr PlanFromInterval(const IntervalSet &intervals, FieldRef *field, + SortByClause::Order order) { + std::vector> result; + if (order == SortByClause::ASC) { + for (const auto &[l, r] : intervals.intervals) { + result.push_back(std::make_unique(field->CloneAs(), Interval(l, r), order)); + } + } else { + for (const auto &[l, r] : ranges::views::reverse(intervals.intervals)) { + result.push_back(std::make_unique(field->CloneAs(), Interval(l, r), order)); + } + } + + return Merge::Create(std::move(result)); + } +}; + +} // namespace kqir diff --git a/src/search/passes/interval_analysis.h b/src/search/passes/interval_analysis.h new file mode 100644 index 00000000000..5ed56125195 --- /dev/null +++ b/src/search/passes/interval_analysis.h @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
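To summarize the machinery above with a hypothetical query: for something like @a:[10 20] @b:{x} where both fields are indexed, the AND branch of VisitExprImpl gathers candidate plans (the full scan at cost 100, a numeric scan over a's interval, a tag scan on b), picks the cheapest with std::min_element, and wraps every predicate the chosen scan does not already guarantee back around it as a Filter. The OR branch instead builds one scan per disjunct and, via add_filter, guards each scan with the negation of the previously emitted disjuncts so the Merge does not produce duplicate rows; the merged plan is kept only if the cost model rates it below the filtered full scan.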
+ * + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "search/interval.h" +#include "search/ir.h" +#include "search/ir_pass.h" +#include "search/ir_plan.h" +#include "type_util.h" + +namespace kqir { + +struct IntervalAnalysis : Visitor { + struct IntervalInfo { + std::string field_name; + const FieldInfo *field_info; + IntervalSet intervals; + }; + + using Result = std::map; + + Result result; + const bool simplify_numeric_compare; + + explicit IntervalAnalysis(bool simplify_numeric_compare = false) + : simplify_numeric_compare(simplify_numeric_compare) {} + + void Reset() override { result.clear(); } + + template + std::unique_ptr VisitImpl(std::unique_ptr node) { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + struct LocalIntervalInfo { + IntervalSet intervals; + std::set nodes; + const FieldInfo *field; + }; + + std::map interval_map; + for (const auto &n : node->inners) { + IntervalSet new_interval; + const FieldInfo *new_field_info = nullptr; + std::string new_field; + + if (auto v = dynamic_cast(n.get())) { + new_interval = IntervalSet(v->op, v->num->val); + new_field = v->field->name; + new_field_info = v->field->info; + } else if (auto iter = result.find(n.get()); iter != result.end()) { + new_interval = iter->second.intervals; + new_field = iter->second.field_name; + new_field_info = iter->second.field_info; + } else { + continue; + } + + if (auto iter = interval_map.find(new_field); iter != interval_map.end()) { + if constexpr (std::is_same_v) { + iter->second.intervals = iter->second.intervals | new_interval; + } else if constexpr (std::is_same_v) { + iter->second.intervals = iter->second.intervals & new_interval; + } else { + static_assert(AlwaysFalse); + } + iter->second.nodes.emplace(n.get()); + iter->second.field = new_field_info; + } else { + interval_map.emplace(new_field, LocalIntervalInfo{new_interval, std::set{n.get()}, new_field_info}); + } + } + + if (interval_map.size() == 1) { + const auto &elem = *interval_map.begin(); + result.emplace(node.get(), IntervalInfo{elem.first, elem.second.field, elem.second.intervals}); + } + + if (simplify_numeric_compare) { + for (const auto &[field, info] : interval_map) { + auto iter = std::remove_if(node->inners.begin(), node->inners.end(), + [&info = info](const auto &n) { return info.nodes.count(n.get()) == 1; }); + node->inners.erase(iter, node->inners.end()); + for (const auto &n : info.nodes) { + if (auto iter = result.find(n); iter != result.end()) result.erase(iter); + } + + auto field_node = std::make_unique(field, info.field); + node->inners.emplace_back(GenerateFromInterval(info.intervals, field_node.get())); + } + } + + return node; + } + + static std::unique_ptr GenerateFromInterval(const IntervalSet &intervals, FieldRef *field) { + if (intervals.IsEmpty()) { + return std::make_unique(false); + } + + if (intervals.IsFull()) { + return std::make_unique(true); + } + + std::vector> exprs; + + if (intervals.intervals.size() > 1 && std::isinf(intervals.intervals.front().first) && + std::isinf(intervals.intervals.back().second)) { + bool is_all_ne = true; + auto iter = intervals.intervals.begin(); + auto last = iter->second; + ++iter; + while (iter != intervals.intervals.end()) { + if (iter->first != IntervalSet::NextNum(last)) { + is_all_ne = false; + break; + } + + last = iter->second; + ++iter; + } + + if (is_all_ne) { + for (auto i = intervals.intervals.begin(); i != intervals.intervals.end() && !std::isinf(i->second); ++i) { + 
exprs.emplace_back(std::make_unique(NumericCompareExpr::NE, field->CloneAs(), + std::make_unique(i->second))); + } + + return std::make_unique(std::move(exprs)); + } + } + + for (auto [l, r] : intervals.intervals) { + if (std::isinf(l)) { + exprs.emplace_back(std::make_unique(NumericCompareExpr::LT, field->CloneAs(), + std::make_unique(r))); + } else if (std::isinf(r)) { + exprs.emplace_back(std::make_unique(NumericCompareExpr::GET, field->CloneAs(), + std::make_unique(l))); + } else if (r == IntervalSet::NextNum(l)) { + exprs.emplace_back(std::make_unique(NumericCompareExpr::EQ, field->CloneAs(), + std::make_unique(l))); + } else { + std::vector> sub_expr; + sub_expr.emplace_back(std::make_unique(NumericCompareExpr::GET, field->CloneAs(), + std::make_unique(l))); + sub_expr.emplace_back(std::make_unique(NumericCompareExpr::LT, field->CloneAs(), + std::make_unique(r))); + + exprs.emplace_back(std::make_unique(std::move(sub_expr))); + } + } + + if (exprs.size() == 1) { + return std::move(exprs.front()); + } else { + return std::make_unique(std::move(exprs)); + } + } + + std::unique_ptr Visit(std::unique_ptr node) override { return VisitImpl(std::move(node)); } + + std::unique_ptr Visit(std::unique_ptr node) override { return VisitImpl(std::move(node)); } +}; + +} // namespace kqir diff --git a/src/search/passes/lower_to_plan.h b/src/search/passes/lower_to_plan.h new file mode 100644 index 00000000000..94828c22b2b --- /dev/null +++ b/src/search/passes/lower_to_plan.h @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
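IntervalAnalysis above folds all numeric predicates on the same field through IntervalSet algebra (union for or-expressions, intersection for and-expressions) and then re-materializes the result as the smallest equivalent expression. A minimal, self-contained sketch of that algebra, assuming half-open [l, r) ranges and a single range per set; the real IntervalSet keeps an ordered list of ranges and uses NextNum for the floating-point successor:

#include <algorithm>
#include <cassert>
#include <limits>
#include <optional>
#include <utility>

using Range = std::pair<double, double>;  // the half-open interval [l, r)
constexpr double kInf = std::numeric_limits<double>::infinity();

// AND of two single-range sets: clip to the overlap, empty if none.
std::optional<Range> Intersect(Range a, Range b) {
  Range r{std::max(a.first, b.first), std::min(a.second, b.second)};
  if (r.first >= r.second) return std::nullopt;
  return r;
}

int main() {
  // a >= 3 and a < 10  ->  [3, 10)
  auto r = Intersect({3, kInf}, {-kInf, 10});
  assert(r && r->first == 3 && r->second == 10);

  // a < 3 and a >= 10  ->  empty
  assert(!Intersect({-kInf, 3}, {10, kInf}));
}

An empty intersection is what lets the pass fold contradictory conjunctions such as a < 3 and a >= 10 straight to a false literal.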
+ *
+ */
+
+#pragma once
+
+#include <memory>
+
+#include "search/ir.h"
+#include "search/ir_pass.h"
+#include "search/ir_plan.h"
+
+namespace kqir {
+
+struct LowerToPlan : Visitor {
+  std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) override {
+    auto scan = std::make_unique<FullIndexScan>(node->index->CloneAs<IndexRef>());
+
+    std::unique_ptr<PlanOperator> op;
+    if (auto b = Node::As<BoolLiteral>(std::move(node->query_expr))) {
+      if (b->val) {
+        op = std::move(scan);
+      } else {
+        op = std::make_unique<Noop>();
+      }
+    } else {
+      op = std::make_unique<Filter>(std::move(scan), std::move(node->query_expr));
+    }
+
+    if (!dynamic_cast<Noop *>(op.get())) {
+      // order is important here, since limit(sort(op)) is different from sort(limit(op))
+      if (node->sort_by) {
+        op = std::make_unique<Sort>(std::move(op), std::move(node->sort_by));
+      }
+
+      if (node->limit) {
+        op = std::make_unique<Limit>(std::move(op), std::move(node->limit));
+      }
+    }
+
+    return std::make_unique<Projection>(std::move(op), std::move(node->select));
+  }
+};
+
+} // namespace kqir
diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h
new file mode 100644
index 00000000000..e94a12e940f
--- /dev/null
+++ b/src/search/passes/manager.h
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
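Concretely, a query along the lines of select a from idx where b > 1 sort by a limit 0, 10 (SQL-style syntax assumed here for illustration) lowers to Projection(Limit(Sort(Filter(FullIndexScan)))): the Filter is skipped when the predicate has already folded to a boolean literal, Sort is applied before Limit per the ordering comment above, and both wrappers are skipped entirely when the plan collapsed to a Noop.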
+ * + */ + +#pragma once + +#include +#include +#include +#include + +#include "search/ir.h" +#include "search/ir_pass.h" +#include "search/passes/index_selection.h" +#include "search/passes/interval_analysis.h" +#include "search/passes/lower_to_plan.h" +#include "search/passes/push_down_not_expr.h" +#include "search/passes/simplify_and_or_expr.h" +#include "search/passes/simplify_boolean.h" +#include "search/passes/sort_limit_fuse.h" +#include "type_util.h" + +namespace kqir { + +using PassSequence = std::vector>; + +struct PassManager { + static std::unique_ptr Execute(const PassSequence &seq, std::unique_ptr node) { + for (auto &pass : seq) { + pass->Reset(); + node = pass->Transform(std::move(node)); + } + return node; + } + + template + static PassSequence Create(Passes &&...passes) { + static_assert(std::conjunction_v>...>); + + PassSequence result; + result.reserve(sizeof...(passes)); + (result.push_back(std::make_unique(std::move(passes))), ...); + + return result; + } + + template + static PassSequence Merge(PassSeqs &&...seqs) { + static_assert(std::conjunction_v>...>); + static_assert(std::conjunction_v>...>); + + PassSequence result; + result.reserve((seqs.size() + ...)); + (result.insert(result.end(), std::make_move_iterator(seqs.begin()), std::make_move_iterator(seqs.end())), ...); + + return result; + } + + static PassSequence ExprPasses() { + return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, SimplifyAndOrExpr{}); + } + static PassSequence NumericPasses() { return Create(IntervalAnalysis{true}, SimplifyAndOrExpr{}, SimplifyBoolean{}); } + static PassSequence PlanPasses() { return Create(LowerToPlan{}, IndexSelection{}, SortLimitFuse{}); } + + static PassSequence Default() { return Merge(ExprPasses(), NumericPasses(), PlanPasses()); } +}; + +} // namespace kqir diff --git a/src/search/passes/push_down_not_expr.h b/src/search/passes/push_down_not_expr.h new file mode 100644 index 00000000000..3c286c0914a --- /dev/null +++ b/src/search/passes/push_down_not_expr.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
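PassManager::Create above moves a heterogeneous pack of passes into one owning sequence with a C++17 fold expression, and Execute is a left-to-right fold over Transform (with a Reset between queries, omitted below). A stripped-down, runnable sketch of the same machinery, using int-valued stand-in passes instead of IR transforms:

#include <iostream>
#include <memory>
#include <vector>

struct Pass {
  virtual int Run(int v) = 0;
  virtual ~Pass() = default;
};
struct AddOne : Pass {
  int Run(int v) override { return v + 1; }
};
struct Twice : Pass {
  int Run(int v) override { return v * 2; }
};

using Seq = std::vector<std::unique_ptr<Pass>>;

template <typename... Ps>
Seq Create(Ps &&...ps) {
  Seq out;
  out.reserve(sizeof...(ps));
  (out.push_back(std::make_unique<Ps>(std::move(ps))), ...);  // fold expression: one push per pass
  return out;
}

int Execute(const Seq &seq, int v) {
  for (const auto &p : seq) v = p->Run(v);  // left to right, like PassManager::Execute
  return v;
}

int main() {
  auto seq = Create(AddOne{}, Twice{});
  std::cout << Execute(seq, 3) << "\n";  // (3 + 1) * 2 = 8
}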
+ * + */ + +#pragma once + +#include + +#include "search/ir.h" +#include "search/ir_pass.h" + +namespace kqir { + +struct PushDownNotExpr : Visitor { + std::unique_ptr Visit(std::unique_ptr node) override { + std::unique_ptr res; + + if (auto v = Node::As(std::move(node->inner))) { + v->op = v->Negative(v->op); + return v; + } else if (auto v = Node::As(std::move(node->inner))) { + return std::make_unique(std::move(v)); + } else if (auto v = Node::As(std::move(node->inner))) { + std::vector> nodes; + for (auto& n : v->inners) { + nodes.push_back(std::make_unique(std::move(n))); + } + res = std::make_unique(std::move(nodes)); + } else if (auto v = Node::As(std::move(node->inner))) { + std::vector> nodes; + for (auto& n : v->inners) { + nodes.push_back(std::make_unique(std::move(n))); + } + res = std::make_unique(std::move(nodes)); + } else if (auto v = Node::As(std::move(node->inner))) { + res = std::move(v->inner); + } + + return Visitor::Transform(std::move(res)); + } +}; + +} // namespace kqir diff --git a/src/search/passes/simplify_and_or_expr.h b/src/search/passes/simplify_and_or_expr.h new file mode 100644 index 00000000000..22ac7b57fdd --- /dev/null +++ b/src/search/passes/simplify_and_or_expr.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
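A worked example of the rewrite above, in informal notation: not (a < 3 and tag-contains(b, x)) becomes a >= 3 or not tag-contains(b, x). Numeric comparisons are negated in place by flipping the operator (LT to GET, EQ to NE, and so on), tag predicates keep an explicit NotExpr for later stages to handle, De Morgan converts and to or (and vice versa) with the negation pushed onto each operand, a nested not cancels, and the And/Or results are re-transformed so freshly wrapped NotExpr children keep sinking toward the leaves.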
+ * + */ + +#pragma once + +#include + +#include "search/ir.h" +#include "search/ir_pass.h" + +namespace kqir { + +struct SimplifyAndOrExpr : Visitor { + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + std::vector> merged_nodes; + for (auto &n : node->inners) { + if (auto v = Node::As(std::move(n))) { + for (auto &m : v->inners) { + merged_nodes.push_back(std::move(m)); + } + } else { + merged_nodes.push_back(std::move(n)); + } + } + + if (merged_nodes.size() == 1) { + return std::move(merged_nodes.front()); + } else { + return std::make_unique(std::move(merged_nodes)); + } + } + + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + std::vector> merged_nodes; + for (auto &n : node->inners) { + if (auto v = Node::As(std::move(n))) { + for (auto &m : v->inners) { + merged_nodes.push_back(std::move(m)); + } + } else { + merged_nodes.push_back(std::move(n)); + } + } + + if (merged_nodes.size() == 1) { + return std::move(merged_nodes.front()); + } else { + return std::make_unique(std::move(merged_nodes)); + } + } +}; + +} // namespace kqir diff --git a/src/search/passes/simplify_boolean.h b/src/search/passes/simplify_boolean.h new file mode 100644 index 00000000000..79281e8ee02 --- /dev/null +++ b/src/search/passes/simplify_boolean.h @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
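Worked examples for this pair of passes: SimplifyAndOrExpr turns (a and (b and c)) into (a and b and c) and collapses a single-operand and/or into its operand; SimplifyBoolean, in the next file, folds literals (a false operand is dropped from an or and collapses an and, a true operand is dropped from an and and collapses an or, not over a literal flips it), which can re-expose flattening opportunities. That is why PassManager::ExprPasses earlier runs SimplifyAndOrExpr a second time after the boolean folding.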
+ * + */ + +#pragma once + +#include + +#include "search/ir.h" +#include "search/ir_pass.h" + +namespace kqir { + +struct SimplifyBoolean : Visitor { + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + for (auto iter = node->inners.begin(); iter != node->inners.end();) { + if (auto v = Node::As(std::move(*iter))) { + if (!v->val) { + iter = node->inners.erase(iter); + } else { + return v; + } + } else { + ++iter; + } + } + + if (node->inners.size() == 0) { + return std::make_unique(false); + } else if (node->inners.size() == 1) { + return std::move(node->inners[0]); + } + + return node; + } + + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + for (auto iter = node->inners.begin(); iter != node->inners.end();) { + if (auto v = Node::As(std::move(*iter))) { + if (v->val) { + iter = node->inners.erase(iter); + } else { + return v; + } + } else { + ++iter; + } + } + + if (node->inners.size() == 0) { + return std::make_unique(true); + } else if (node->inners.size() == 1) { + return std::move(node->inners[0]); + } + + return node; + } + + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + if (auto v = Node::As(std::move(node->inner))) { + v->val = !v->val; + return v; + } + + return node; + } +}; + +} // namespace kqir diff --git a/src/search/passes/sort_limit_fuse.h b/src/search/passes/sort_limit_fuse.h new file mode 100644 index 00000000000..0e857297175 --- /dev/null +++ b/src/search/passes/sort_limit_fuse.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "search/ir.h" +#include "search/ir_pass.h" +#include "search/ir_plan.h" + +namespace kqir { + +struct SortLimitFuse : Visitor { + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + if (auto sort = Node::As(std::move(node->op))) { + return std::make_unique(std::move(sort->op), std::move(sort->order), std::move(node->limit)); + } + + return node; + } +}; + +} // namespace kqir diff --git a/src/search/plan_executor.cc b/src/search/plan_executor.cc new file mode 100644 index 00000000000..9140587e1e9 --- /dev/null +++ b/src/search/plan_executor.cc @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
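plan_executor.cc below binds one executor node to each plan operator and then runs the classic pull-based (Volcano) model: the caller invokes the root's Next() until it yields the end marker. A self-contained analogue with stand-in types; note the real LimitExecutor also honors an offset, which is omitted here:

#include <iostream>
#include <memory>
#include <variant>
#include <vector>

struct End {};                        // stand-in for ExecutorNode::End
using Row = int;                      // stand-in for ExecutorNode::RowType
using Result = std::variant<Row, End>;

struct Exec {                         // stand-in for ExecutorNode
  virtual Result Next() = 0;
  virtual ~Exec() = default;
};

struct Scan : Exec {                  // cf. FullIndexScanExecutor
  std::vector<Row> rows{1, 2, 3, 4};
  size_t i = 0;
  Result Next() override { return i < rows.size() ? Result{rows[i++]} : Result{End{}}; }
};

struct Take : Exec {                  // cf. LimitExecutor (offset handling omitted)
  std::unique_ptr<Exec> src;
  size_t count, taken = 0;
  Take(std::unique_ptr<Exec> src, size_t count) : src(std::move(src)), count(count) {}
  Result Next() override {
    if (taken == count) return End{};
    ++taken;
    return src->Next();
  }
};

int main() {
  Take root(std::make_unique<Scan>(), 2);
  for (auto r = root.Next(); std::holds_alternative<Row>(r); r = root.Next())
    std::cout << std::get<Row>(r) << "\n";  // prints 1 and 2
}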
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "plan_executor.h" + +#include + +#include "search/executors/filter_executor.h" +#include "search/executors/full_index_scan_executor.h" +#include "search/executors/limit_executor.h" +#include "search/executors/merge_executor.h" +#include "search/executors/mock_executor.h" +#include "search/executors/noop_executor.h" +#include "search/executors/numeric_field_scan_executor.h" +#include "search/executors/projection_executor.h" +#include "search/executors/sort_executor.h" +#include "search/executors/tag_field_scan_executor.h" +#include "search/executors/topn_sort_executor.h" +#include "search/indexer.h" +#include "search/ir_plan.h" + +namespace kqir { + +namespace details { + +struct ExecutorContextVisitor { + ExecutorContext *ctx; + + void Transform(PlanOperator *op) { + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + if (auto v = dynamic_cast(op)) { + return Visit(v); + } + + CHECK(false) << "unreachable"; + } + + void Visit(Limit *op) { + ctx->nodes[op] = std::make_unique(ctx, op); + Transform(op->op.get()); + } + + void Visit(Sort *op) { + ctx->nodes[op] = std::make_unique(ctx, op); + Transform(op->op.get()); + } + + void Visit(Noop *op) { ctx->nodes[op] = std::make_unique(ctx, op); } + + void Visit(Merge *op) { + ctx->nodes[op] = std::make_unique(ctx, op); + for (const auto &child : op->ops) Transform(child.get()); + } + + void Visit(Filter *op) { + ctx->nodes[op] = std::make_unique(ctx, op); + Transform(op->source.get()); + } + + void Visit(Projection *op) { + ctx->nodes[op] = std::make_unique(ctx, op); + Transform(op->source.get()); + } + + void Visit(TopNSort *op) { + ctx->nodes[op] = std::make_unique(ctx, op); + Transform(op->op.get()); + } + + void Visit(FullIndexScan *op) { ctx->nodes[op] = std::make_unique(ctx, op); } + + void Visit(NumericFieldScan *op) { ctx->nodes[op] = std::make_unique(ctx, op); } + + void Visit(TagFieldScan *op) { ctx->nodes[op] = std::make_unique(ctx, op); } + + void Visit(Mock *op) { ctx->nodes[op] = std::make_unique(ctx, op); } +}; + +} // namespace details + +ExecutorContext::ExecutorContext(PlanOperator *op) : root(op) { + details::ExecutorContextVisitor visitor{this}; + visitor.Transform(root); +} + +ExecutorContext::ExecutorContext(PlanOperator *op, engine::Storage *storage) : root(op), storage(storage) { + details::ExecutorContextVisitor visitor{this}; + visitor.Transform(root); +} + +auto ExecutorContext::Retrieve(RowType &row, const FieldInfo *field) -> StatusOr { // NOLINT + if (auto iter = row.fields.find(field); iter != row.fields.end()) { + return iter->second; + } + + 
auto retriever = GET_OR_RET( + redis::FieldValueRetriever::Create(field->index->metadata.on_data_type, row.key, storage, field->index->ns)); + + auto s = retriever.Retrieve(field->name, field->metadata.get()); + if (!s) return s; + + row.fields.emplace(field, *s); + return *s; +} + +} // namespace kqir diff --git a/src/search/plan_executor.h b/src/search/plan_executor.h new file mode 100644 index 00000000000..0ead68701d2 --- /dev/null +++ b/src/search/plan_executor.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "ir_plan.h" +#include "search/index_info.h" +#include "search/value.h" +#include "storage/storage.h" +#include "string_util.h" + +namespace kqir { + +struct ExecutorContext; + +struct ExecutorNode { + using KeyType = std::string; + using ValueType = kqir::Value; + struct RowType { + KeyType key; + std::map fields; + const IndexInfo *index; + + bool operator==(const RowType &another) const { + return key == another.key && fields == another.fields && index == another.index; + } + + bool operator!=(const RowType &another) const { return !(*this == another); } + + // for debug purpose + friend std::ostream &operator<<(std::ostream &os, const RowType &row) { + if (row.index) { + os << row.key << "@" << row.index->name; + } else { + os << row.key; + } + return os << " {" << util::StringJoin(row.fields, [](const auto &v) { + return v.first->name + ": " + v.second.ToString(); + }) << "}"; + } + }; + + static constexpr inline const struct End { + } end{}; + friend constexpr bool operator==(End, End) noexcept { return true; } + friend constexpr bool operator!=(End, End) noexcept { return false; } + + using Result = std::variant; + + ExecutorContext *ctx; + explicit ExecutorNode(ExecutorContext *ctx) : ctx(ctx) {} + + virtual StatusOr Next() = 0; + virtual ~ExecutorNode() = default; +}; + +struct ExecutorContext { + std::map> nodes; + PlanOperator *root; + engine::Storage *storage; + + using Result = ExecutorNode::Result; + using RowType = ExecutorNode::RowType; + using KeyType = ExecutorNode::KeyType; + using ValueType = ExecutorNode::ValueType; + + explicit ExecutorContext(PlanOperator *op); + explicit ExecutorContext(PlanOperator *op, engine::Storage *storage); + + ExecutorNode *Get(PlanOperator *op) { + if (auto iter = nodes.find(op); iter != nodes.end()) { + return iter->second.get(); + } + + return nullptr; + } + + ExecutorNode *Get(const std::unique_ptr &op) { return Get(op.get()); } + + StatusOr Next() { return Get(root)->Next(); } + StatusOr Retrieve(RowType &row, const FieldInfo *field); +}; + +} // namespace kqir diff --git a/src/search/redis_query_parser.h b/src/search/redis_query_parser.h new file mode 100644 index 00000000000..d64d0bf0ed1 --- /dev/null +++ 
b/src/search/redis_query_parser.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "common_parser.h" + +namespace kqir { + +namespace redis_query { + +using namespace peg; + +struct Field : seq, Identifier> {}; + +struct Tag : sor {}; +struct TagList : seq, WSPad, star, WSPad>>, one<'}'>> {}; + +struct Inf : seq>, string<'i', 'n', 'f'>> {}; +struct ExclusiveNumber : seq, Number> {}; +struct NumericRangePart : sor {}; +struct NumericRange : seq, WSPad, WSPad, one<']'>> {}; + +struct FieldQuery : seq, one<':'>, WSPad>> {}; + +struct Wildcard : one<'*'> {}; + +struct QueryExpr; + +struct ParenExpr : WSPad, QueryExpr, one<')'>>> {}; + +struct NotExpr; + +struct BooleanExpr : sor> {}; + +struct NotExpr : seq>, BooleanExpr> {}; + +struct AndExpr : seq>> {}; +struct AndExprP : sor {}; + +struct OrExpr : seq, AndExprP>>> {}; +struct OrExprP : sor {}; + +struct QueryExpr : seq {}; + +} // namespace redis_query + +} // namespace kqir diff --git a/src/search/redis_query_transformer.h b/src/search/redis_query_transformer.h new file mode 100644 index 00000000000..0928ed592e6 --- /dev/null +++ b/src/search/redis_query_transformer.h @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
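Assuming the grammar above follows RediSearch conventions, it accepts queries such as @price:[10 (25] (numeric range with an exclusive upper bound), @price:[(5 +inf] (half-open range; +inf as a lower bound is rejected by the transformer in the next file), @color:{red | blue} (tag alternatives), -@flag:{on} (negation via the leading dash), and * (match everything), with whitespace-separated terms for conjunction, | for disjunction, and parentheses for grouping. These examples are illustrative rather than exhaustive.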
+ * + */ + +#pragma once + +#include + +#include "common_transformer.h" +#include "ir.h" +#include "parse_util.h" +#include "redis_query_parser.h" +#include "search/common_parser.h" + +namespace kqir { + +namespace redis_query { + +namespace ir = kqir; + +template +using TreeSelector = + parse_tree::selector, + parse_tree::remove_content::on>; + +template +StatusOr> ParseToTree(Input&& in) { + if (auto root = parse_tree::parse, TreeSelector>(std::forward(in))) { + return root; + } else { + // TODO: improve parse error message, with source location + return {Status::NotOK, "invalid syntax"}; + } +} + +struct Transformer : ir::TreeTransformer { + static auto Transform(const TreeNode& node) -> StatusOr> { + if (Is(node)) { + return Node::Create(*ParseFloat(node->string())); + } else if (Is(node)) { + return Node::Create(true); + } else if (Is(node)) { + CHECK(node->children.size() == 2); + + auto field = node->children[0]->string(); + const auto& query = node->children[1]; + + if (Is(query)) { + std::vector> exprs; + + for (const auto& tag : query->children) { + auto tag_str = Is(tag) ? tag->string() : GET_OR_RET(UnescapeString(tag->string())); + exprs.push_back(std::make_unique(std::make_unique(field), + std::make_unique(tag_str))); + } + + if (exprs.size() == 1) { + return std::move(exprs[0]); + } else { + return std::make_unique(std::move(exprs)); + } + } else { // NumericRange + std::vector> exprs; + + const auto& lhs = query->children[0]; + const auto& rhs = query->children[1]; + + if (Is(lhs)) { + exprs.push_back(std::make_unique( + NumericCompareExpr::GT, std::make_unique(field), + Node::MustAs(GET_OR_RET(Transform(lhs->children[0]))))); + } else if (Is(lhs)) { + exprs.push_back( + std::make_unique(NumericCompareExpr::GET, std::make_unique(field), + Node::MustAs(GET_OR_RET(Transform(lhs))))); + } else { // Inf + if (lhs->string_view() == "+inf") { + return {Status::NotOK, "it's not allowed to set the lower bound as positive infinity"}; + } + } + + if (Is(rhs)) { + exprs.push_back(std::make_unique( + NumericCompareExpr::LT, std::make_unique(field), + Node::MustAs(GET_OR_RET(Transform(rhs->children[0]))))); + } else if (Is(rhs)) { + exprs.push_back( + std::make_unique(NumericCompareExpr::LET, std::make_unique(field), + Node::MustAs(GET_OR_RET(Transform(rhs))))); + } else { // Inf + if (rhs->string_view() == "-inf") { + return {Status::NotOK, "it's not allowed to set the upper bound as negative infinity"}; + } + } + + if (exprs.empty()) { + return std::make_unique(true); + } else if (exprs.size() == 1) { + return std::move(exprs[0]); + } else { + return std::make_unique(std::move(exprs)); + } + } + } else if (Is(node)) { + CHECK(node->children.size() == 1); + + return Node::Create(Node::MustAs(GET_OR_RET(Transform(node->children[0])))); + } else if (Is(node)) { + std::vector> exprs; + + for (const auto& child : node->children) { + exprs.push_back(Node::MustAs(GET_OR_RET(Transform(child)))); + } + + return Node::Create(std::move(exprs)); + } else if (Is(node)) { + std::vector> exprs; + + for (const auto& child : node->children) { + exprs.push_back(Node::MustAs(GET_OR_RET(Transform(child)))); + } + + return Node::Create(std::move(exprs)); + } else if (IsRoot(node)) { + CHECK(node->children.size() == 1); + + return Transform(node->children[0]); + } else { + // UNREACHABLE CODE, just for debugging here + return {Status::NotOK, fmt::format("encountered invalid node type: {}", node->type)}; + } + } // NOLINT +}; + +template +StatusOr> ParseToIR(Input&& in) { + return 
Transformer::Transform(GET_OR_RET(ParseToTree(std::forward(in)))); +} + +} // namespace redis_query + +} // namespace kqir diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h index 1637a504c28..68e248bb40b 100644 --- a/src/search/search_encoding.h +++ b/src/search/search_encoding.h @@ -23,28 +23,175 @@ #include #include +#include + namespace redis { +enum class IndexOnDataType : uint8_t { + HASH = kRedisHash, + JSON = kRedisJson, +}; + inline constexpr auto kErrorInsufficientLength = "insufficient length while decoding metadata"; +class IndexMetadata { + public: + uint8_t flag = 0; // all reserved + IndexOnDataType on_data_type; + + void Encode(std::string *dst) const { + PutFixed8(dst, flag); + PutFixed8(dst, uint8_t(on_data_type)); + } + + rocksdb::Status Decode(Slice *input) { + if (!GetFixed8(input, &flag)) { + return rocksdb::Status::InvalidArgument(kErrorInsufficientLength); + } + + if (!GetFixed8(input, reinterpret_cast(&on_data_type))) { + return rocksdb::Status::InvalidArgument(kErrorInsufficientLength); + } + + return rocksdb::Status::OK(); + } +}; + enum class SearchSubkeyType : uint8_t { - // search global metadata + INDEX_META = 0, + PREFIXES = 1, - // field metadata for different types - TAG_FIELD_META = 64 + 1, - NUMERIC_FIELD_META = 64 + 2, + // field metadata + FIELD_META = 2, - // field indexing for different types - TAG_FIELD = 128 + 1, - NUMERIC_FIELD = 128 + 2, + // field indexing data + FIELD = 3, + + // field alias + FIELD_ALIAS = 4, }; -inline std::string ConstructSearchPrefixesSubkey() { return {(char)SearchSubkeyType::PREFIXES}; } +enum class IndexFieldType : uint8_t { + TAG = 1, + + NUMERIC = 2, +}; + +struct SearchKey { + std::string_view ns; + std::string_view index; + std::string_view field; + + SearchKey(std::string_view ns, std::string_view index) : ns(ns), index(index) {} + SearchKey(std::string_view ns, std::string_view index, std::string_view field) : ns(ns), index(index), field(field) {} + + void PutNamespace(std::string *dst) const { + PutFixed8(dst, ns.size()); + dst->append(ns); + } + + static void PutType(std::string *dst, SearchSubkeyType type) { PutFixed8(dst, uint8_t(type)); } + + void PutIndex(std::string *dst) const { PutSizedString(dst, index); } + + std::string ConstructIndexMeta() const { + std::string dst; + PutNamespace(&dst); + PutType(&dst, SearchSubkeyType::INDEX_META); + PutIndex(&dst); + return dst; + } + + std::string ConstructIndexPrefixes() const { + std::string dst; + PutNamespace(&dst); + PutType(&dst, SearchSubkeyType::PREFIXES); + PutIndex(&dst); + return dst; + } + + std::string ConstructFieldMeta() const { + std::string dst; + PutNamespace(&dst); + PutType(&dst, SearchSubkeyType::FIELD_META); + PutIndex(&dst); + PutSizedString(&dst, field); + return dst; + } + + std::string ConstructAllFieldMetaBegin() const { + std::string dst; + PutNamespace(&dst); + PutType(&dst, SearchSubkeyType::FIELD_META); + PutIndex(&dst); + PutFixed32(&dst, 0); + return dst; + } -struct SearchPrefixesMetadata { + std::string ConstructAllFieldMetaEnd() const { + std::string dst; + PutNamespace(&dst); + PutType(&dst, SearchSubkeyType::FIELD_META); + PutIndex(&dst); + PutFixed32(&dst, (uint32_t)(-1)); + return dst; + } + + std::string ConstructAllFieldDataBegin() const { + std::string dst; + PutNamespace(&dst); + PutType(&dst, SearchSubkeyType::FIELD); + PutIndex(&dst); + PutFixed32(&dst, 0); + return dst; + } + + std::string ConstructAllFieldDataEnd() const { + std::string dst; + PutNamespace(&dst); + PutType(&dst, 
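All of the Construct* helpers above build keys with the same prefix: a one-byte-length namespace, a subkey-type byte, then the size-prefixed index name, followed by type-specific components. A condensed sketch of that layout (PutFixed8/PutSizedString are reimplemented here for illustration only; the fixed 32-bit little-endian length prefix is an assumption based on the nearby PutFixed32 calls):

#include <cstdint>
#include <string>
#include <string_view>

void PutFixed8(std::string *dst, uint8_t v) { dst->push_back(static_cast<char>(v)); }

void PutSizedString(std::string *dst, std::string_view s) {
  // assumption: fixed 32-bit little-endian length prefix, then the raw bytes
  uint32_t n = static_cast<uint32_t>(s.size());
  for (int i = 0; i < 4; i++) dst->push_back(static_cast<char>((n >> (8 * i)) & 0xff));
  dst->append(s);
}

std::string MakeIndexMetaKey(std::string_view ns, std::string_view index) {
  std::string dst;
  PutFixed8(&dst, static_cast<uint8_t>(ns.size()));  // one-byte-length namespace
  dst.append(ns);
  PutFixed8(&dst, 0);  // SearchSubkeyType::INDEX_META
  PutSizedString(&dst, index);
  return dst;
}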
SearchSubkeyType::FIELD); + PutIndex(&dst); + PutFixed32(&dst, (uint32_t)(-1)); + return dst; + } + + std::string ConstructTagFieldData(std::string_view tag, std::string_view key) const { + std::string dst; + PutNamespace(&dst); + PutType(&dst, SearchSubkeyType::FIELD); + PutIndex(&dst); + PutSizedString(&dst, field); + PutSizedString(&dst, tag); + PutSizedString(&dst, key); + return dst; + } + + std::string ConstructNumericFieldData(double num, std::string_view key) const { + std::string dst; + PutNamespace(&dst); + PutType(&dst, SearchSubkeyType::FIELD); + PutIndex(&dst); + PutSizedString(&dst, field); + PutDouble(&dst, num); + PutSizedString(&dst, key); + return dst; + } +}; + +struct IndexPrefixes { std::vector prefixes; + static inline const std::string all[] = {""}; + + auto begin() const { // NOLINT + return prefixes.empty() ? std::begin(all) : prefixes.data(); + } + + auto end() const { // NOLINT + return prefixes.empty() ? std::end(all) : prefixes.data() + prefixes.size(); + } + void Encode(std::string *dst) const { for (const auto &prefix : prefixes) { PutFixed32(dst, prefix.size()); @@ -65,15 +212,34 @@ struct SearchPrefixesMetadata { } }; -struct SearchFieldMetadata { +struct IndexFieldMetadata { bool noindex = false; + IndexFieldType type; + + explicit IndexFieldMetadata(IndexFieldType type) : type(type) {} - // flag: - uint8_t MakeFlag() const { return noindex; } + // flag: + uint8_t MakeFlag() const { return noindex | (uint8_t)type << 1; } - void DecodeFlag(uint8_t flag) { noindex = flag & 1; } + void DecodeFlag(uint8_t flag) { + noindex = flag & 1; + type = DecodeType(flag); + } + + static IndexFieldType DecodeType(uint8_t flag) { return IndexFieldType(flag >> 1); } + + virtual ~IndexFieldMetadata() = default; - virtual ~SearchFieldMetadata() = default; + std::string_view Type() const { + switch (type) { + case IndexFieldType::TAG: + return "tag"; + case IndexFieldType::NUMERIC: + return "numeric"; + default: + return "unknown"; + } + } virtual void Encode(std::string *dst) const { PutFixed8(dst, MakeFlag()); } @@ -86,30 +252,30 @@ struct SearchFieldMetadata { DecodeFlag(flag); return rocksdb::Status::OK(); } -}; -inline std::string ConstructTagFieldMetadataSubkey(std::string_view field_name) { - std::string res = {(char)SearchSubkeyType::TAG_FIELD_META}; - res.append(field_name); - return res; -} + virtual bool IsSortable() const { return false; } + + static inline rocksdb::Status Decode(Slice *input, std::unique_ptr &ptr); +}; -struct SearchTagFieldMetadata : SearchFieldMetadata { +struct TagFieldMetadata : IndexFieldMetadata { char separator = ','; bool case_sensitive = false; + TagFieldMetadata() : IndexFieldMetadata(IndexFieldType::TAG) {} + void Encode(std::string *dst) const override { - SearchFieldMetadata::Encode(dst); + IndexFieldMetadata::Encode(dst); PutFixed8(dst, separator); PutFixed8(dst, case_sensitive); } rocksdb::Status Decode(Slice *input) override { - if (auto s = SearchFieldMetadata::Decode(input); !s.ok()) { + if (auto s = IndexFieldMetadata::Decode(input); !s.ok()) { return s; } - if (input->size() < 8 + 8) { + if (input->size() < 2) { return rocksdb::Status::Corruption(kErrorInsufficientLength); } @@ -119,33 +285,29 @@ struct SearchTagFieldMetadata : SearchFieldMetadata { } }; -inline std::string ConstructNumericFieldMetadataSubkey(std::string_view field_name) { - std::string res = {(char)SearchSubkeyType::NUMERIC_FIELD_META}; - res.append(field_name); - return res; -} +struct NumericFieldMetadata : IndexFieldMetadata { + NumericFieldMetadata() : 
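The field-metadata flag byte above packs `noindex` into bit 0 and the field type into the remaining bits, which is also what lets the static Decode below pick the right subclass by peeking at the first byte of the encoded value. The round trip in isolation:

#include <cassert>
#include <cstdint>

enum class FieldType : uint8_t { TAG = 1, NUMERIC = 2 };

uint8_t MakeFlag(bool noindex, FieldType type) {
  // bit 0: noindex; bits 1..7: field type
  return static_cast<uint8_t>(noindex) | static_cast<uint8_t>(type) << 1;
}

int main() {
  uint8_t flag = MakeFlag(true, FieldType::NUMERIC);
  bool noindex = flag & 1;
  auto type = static_cast<FieldType>(flag >> 1);
  assert(noindex && type == FieldType::NUMERIC);
}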
IndexFieldMetadata(IndexFieldType::NUMERIC) {} -struct SearchNumericFieldMetadata : SearchFieldMetadata {}; - -inline std::string ConstructTagFieldSubkey(std::string_view field_name, std::string_view tag, std::string_view key) { - std::string res = {(char)SearchSubkeyType::TAG_FIELD}; - PutFixed32(&res, field_name.size()); - res.append(field_name); - PutFixed32(&res, tag.size()); - res.append(tag); - PutFixed32(&res, key.size()); - res.append(key); - return res; -} + bool IsSortable() const override { return true; } +}; + +inline rocksdb::Status IndexFieldMetadata::Decode(Slice *input, std::unique_ptr &ptr) { + if (input->size() < 1) { + return rocksdb::Status::Corruption(kErrorInsufficientLength); + } + + switch (DecodeType((*input)[0])) { + case IndexFieldType::TAG: + ptr = std::make_unique(); + break; + case IndexFieldType::NUMERIC: + ptr = std::make_unique(); + break; + default: + return rocksdb::Status::Corruption("encountered unknown field type"); + } -inline std::string ConstructNumericFieldSubkey(std::string_view field_name, double number, std::string_view key) { - std::string res = {(char)SearchSubkeyType::NUMERIC_FIELD}; - PutFixed32(&res, field_name.size()); - res.append(field_name); - PutDouble(&res, number); - PutFixed32(&res, key.size()); - res.append(key); - return res; + return ptr->Decode(input); } } // namespace redis diff --git a/src/search/sql_parser.h b/src/search/sql_parser.h new file mode 100644 index 00000000000..26e3da6d5b4 --- /dev/null +++ b/src/search/sql_parser.h @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#pragma once + +#include + +#include "common_parser.h" + +namespace kqir { + +namespace sql { + +using namespace peg; + +struct HasTag : string<'h', 'a', 's', 't', 'a', 'g'> {}; +struct HasTagExpr : WSPad, StringL>> {}; + +struct NumericAtomExpr : WSPad> {}; +struct NumericCompareOp : sor', '='>, one<'=', '<', '>'>> {}; +struct NumericCompareExpr : seq {}; + +struct BooleanAtomExpr : sor> {}; + +struct QueryExpr; + +struct ParenExpr : WSPad, QueryExpr, one<')'>>> {}; + +struct NotExpr; + +struct BooleanExpr : sor {}; + +struct Not : string<'n', 'o', 't'> {}; +struct NotExpr : seq, BooleanExpr> {}; + +struct And : string<'a', 'n', 'd'> {}; +// left recursion elimination +// struct AndExpr : sor, BooleanExpr> {}; +struct AndExpr : seq>> {}; +struct AndExprP : sor {}; + +struct Or : string<'o', 'r'> {}; +// left recursion elimination +// struct OrExpr : sor, AndExpr> {}; +struct OrExpr : seq>> {}; +struct OrExprP : sor {}; + +struct QueryExpr : seq {}; + +struct Select : string<'s', 'e', 'l', 'e', 'c', 't'> {}; +struct From : string<'f', 'r', 'o', 'm'> {}; + +struct Wildcard : one<'*'> {}; +struct IdentifierList : seq>, Identifier>> {}; +struct SelectExpr : WSPad> {}; +struct FromExpr : WSPad {}; + +struct Where : string<'w', 'h', 'e', 'r', 'e'> {}; +struct OrderBy : seq, plus, string<'b', 'y'>> {}; +struct Asc : string<'a', 's', 'c'> {}; +struct Desc : string<'d', 'e', 's', 'c'> {}; +struct Limit : string<'l', 'i', 'm', 'i', 't'> {}; + +struct WhereClause : seq {}; +struct AscOrDesc : WSPad> {}; +struct OrderByClause : seq, opt> {}; +struct LimitClause : seq, one<','>>>, WSPad> {}; + +struct SearchStmt + : WSPad, opt, opt>> {}; + +} // namespace sql + +} // namespace kqir diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h new file mode 100644 index 00000000000..972ae894b53 --- /dev/null +++ b/src/search/sql_transformer.h @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
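As the commented-out rules in the grammar above note, PEGs cannot express left recursion, so a rule like `AndExpr <- AndExpr "and" BooleanExpr` is rewritten as one operand followed by one or more further operands, with the `*P` variants falling back to a single operand. The same shape in a self-contained sketch (rule names here are illustrative):

#include <tao/pegtl.hpp>

namespace pegtl = tao::pegtl;

struct Operand : pegtl::plus<pegtl::alpha> {};
struct AndTok : pegtl::pad<pegtl::string<'a', 'n', 'd'>, pegtl::space> {};
// instead of the left-recursive  AndExpr <- AndExpr "and" Operand,
// iterate: one operand, then one-or-more ("and" operand) tails
struct AndExpr : pegtl::seq<Operand, pegtl::plus<pegtl::seq<AndTok, Operand>>> {};
struct AndExprP : pegtl::sor<AndExpr, Operand> {};  // fall back to a lone operand
struct Grammar : pegtl::seq<AndExprP, pegtl::eof> {};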
+ * + */ + +#pragma once + +#include +#include +#include + +#include "common_transformer.h" +#include "ir.h" +#include "parse_util.h" +#include "sql_parser.h" + +namespace kqir { + +namespace sql { + +namespace ir = kqir; + +template +using TreeSelector = parse_tree::selector< + Rule, + parse_tree::store_content::on, + parse_tree::remove_content::on>; + +template +StatusOr> ParseToTree(Input&& in) { + if (auto root = parse_tree::parse, TreeSelector>(std::forward(in))) { + return root; + } else { + // TODO: improve parse error message, with source location + return {Status::NotOK, "invalid syntax"}; + } +} + +struct Transformer : ir::TreeTransformer { + static auto Transform(const TreeNode& node) -> StatusOr> { + if (Is(node)) { + return Node::Create(node->string_view() == "true"); + } else if (Is(node)) { + return Node::Create(*ParseFloat(node->string())); + } else if (Is(node)) { + return Node::Create(GET_OR_RET(UnescapeString(node->string_view()))); + } else if (Is(node)) { + CHECK(node->children.size() == 2); + + return Node::Create( + std::make_unique(node->children[0]->string()), + Node::MustAs(GET_OR_RET(Transform(node->children[1])))); + } else if (Is(node)) { + CHECK(node->children.size() == 3); + + const auto& lhs = node->children[0]; + const auto& rhs = node->children[2]; + + auto op = ir::NumericCompareExpr::FromOperator(node->children[1]->string_view()).value(); + if (Is(lhs) && Is(rhs)) { + return Node::Create(op, std::make_unique(lhs->string()), + Node::MustAs(GET_OR_RET(Transform(rhs)))); + } else if (Is(lhs) && Is(rhs)) { + return Node::Create(ir::NumericCompareExpr::Flip(op), + std::make_unique(rhs->string()), + Node::MustAs(GET_OR_RET(Transform(lhs)))); + } else { + return {Status::NotOK, "the left and right side of numeric comparison should be an identifier and a number"}; + } + } else if (Is(node)) { + CHECK(node->children.size() == 1); + + return Node::Create(Node::MustAs(GET_OR_RET(Transform(node->children[0])))); + } else if (Is(node)) { + std::vector> exprs; + + for (const auto& child : node->children) { + exprs.push_back(Node::MustAs(GET_OR_RET(Transform(child)))); + } + + return Node::Create(std::move(exprs)); + } else if (Is(node)) { + std::vector> exprs; + + for (const auto& child : node->children) { + exprs.push_back(Node::MustAs(GET_OR_RET(Transform(child)))); + } + + return Node::Create(std::move(exprs)); + } else if (Is(node)) { + std::vector> fields; + + if (node->children.size() == 1 && Is(node->children[0])) { + return Node::Create(std::move(fields)); + } + + for (const auto& child : node->children) { + fields.push_back(std::make_unique(child->string())); + } + + return Node::Create(std::move(fields)); + } else if (Is(node)) { + CHECK(node->children.size() == 1); + return Node::Create(node->children[0]->string()); + } else if (Is(node)) { + CHECK(node->children.size() == 1); + return Transform(node->children[0]); + } else if (Is(node)) { + CHECK(node->children.size() == 1 || node->children.size() == 2); + + size_t offset = 0, count = std::numeric_limits::max(); + if (node->children.size() == 1) { + count = *ParseInt(node->children[0]->string()); + } else { + offset = *ParseInt(node->children[0]->string()); + count = *ParseInt(node->children[1]->string()); + } + + return Node::Create(offset, count); + } + if (Is(node)) { + CHECK(node->children.size() == 1 || node->children.size() == 2); + + auto field = std::make_unique(node->children[0]->string()); + auto order = SortByClause::Order::ASC; + if (node->children.size() == 2 && 
node->children[1]->string_view() == "desc") { + order = SortByClause::Order::DESC; + } + + return Node::Create(order, std::move(field)); + } else if (Is(node)) { // root node + CHECK(node->children.size() >= 2 && node->children.size() <= 5); + + auto index = Node::MustAs(GET_OR_RET(Transform(node->children[1]))); + auto select = Node::MustAs(GET_OR_RET(Transform(node->children[0]))); + + std::unique_ptr query_expr; + std::unique_ptr limit; + std::unique_ptr sort_by; + + for (size_t i = 2; i < node->children.size(); ++i) { + if (Is(node->children[i])) { + query_expr = Node::MustAs(GET_OR_RET(Transform(node->children[i]))); + } else if (Is(node->children[i])) { + limit = Node::MustAs(GET_OR_RET(Transform(node->children[i]))); + } else if (Is(node->children[i])) { + sort_by = Node::MustAs(GET_OR_RET(Transform(node->children[i]))); + } + } + + if (!query_expr) { + query_expr = std::make_unique(true); + } + + return Node::Create(std::move(index), std::move(query_expr), std::move(limit), std::move(sort_by), + std::move(select)); + } else if (IsRoot(node)) { + CHECK(node->children.size() == 1); + + return Transform(node->children[0]); + } else { + // UNREACHABLE CODE, just for debugging here + return {Status::NotOK, fmt::format("encountered invalid node type: {}", node->type)}; + } + } +}; + +template +StatusOr> ParseToIR(Input&& in) { + return Transformer::Transform(GET_OR_RET(ParseToTree(std::forward(in)))); +} + +} // namespace sql + +} // namespace kqir diff --git a/src/search/value.h b/src/search/value.h new file mode 100644 index 00000000000..f339571788d --- /dev/null +++ b/src/search/value.h @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include +#include +#include +#include + +#include "fmt/core.h" +#include "search/search_encoding.h" +#include "string_util.h" + +namespace kqir { + +using Null = std::monostate; + +using Numeric = double; // used for numeric fields + +using String = std::string; // e.g. a single tag + +using NumericArray = std::vector; // used for vector fields +using StringArray = std::vector; // used for tag fields, e.g. 
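Worth calling out from the transform above: a single-argument LIMIT is a bare count, while the two-argument form is `offset, count`, with the offset defaulting to 0 and the count to "unlimited". Restated standalone:

#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

std::pair<size_t, size_t> ParseLimit(const std::vector<size_t> &args) {
  size_t offset = 0, count = std::numeric_limits<size_t>::max();
  if (args.size() == 1) {
    count = args[0];  // LIMIT <count>
  } else if (args.size() == 2) {
    offset = args[0];  // LIMIT <offset>, <count>
    count = args[1];
  }
  return {offset, count};
}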
a list for tags + +struct Value : std::variant { + using Base = std::variant; + + using Base::Base; + + bool IsNull() const { return Is(); } + + template + bool Is() const { + return std::holds_alternative(*this); + } + + template + const auto &Get() const { + CHECK(Is()); + return std::get(*this); + } + + template + auto &Get() { + CHECK(Is()); + return std::get(*this); + } + + std::string ToString(const std::string &sep = ",") const { + if (IsNull()) { + return ""; + } else if (Is()) { + return fmt::format("{}", Get()); + } else if (Is()) { + return util::StringJoin( + Get(), [](const auto &v) -> decltype(auto) { return v; }, sep); + } else if (Is()) { + return util::StringJoin( + Get(), [](const auto &v) -> decltype(auto) { return std::to_string(v); }, sep); + } + + __builtin_unreachable(); + } + + std::string ToString(redis::IndexFieldMetadata *meta) const { + if (IsNull()) { + return ""; + } else if (Is()) { + return fmt::format("{}", Get()); + } else if (Is()) { + auto tag = dynamic_cast(meta); + char sep = tag ? tag->separator : ','; + return util::StringJoin( + Get(), [](const auto &v) -> decltype(auto) { return v; }, std::string(1, sep)); + } else if (Is()) { + return util::StringJoin(Get(), [](const auto &v) -> decltype(auto) { return std::to_string(v); }); + } + + __builtin_unreachable(); + } +}; + +template +auto MakeValue(Args &&...args) { + return Value(std::in_place_type, std::forward(args)...); +} + +} // namespace kqir diff --git a/src/server/namespace.cc b/src/server/namespace.cc index 01a44dcfe40..31b169ed200 100644 --- a/src/server/namespace.cc +++ b/src/server/namespace.cc @@ -51,37 +51,53 @@ bool Namespace::IsAllowModify() const { return config->HasConfigFile() || config->repl_namespace_enabled; } -Status Namespace::LoadAndRewrite() { - auto config = storage_->GetConfig(); - // Load from the configuration file first - tokens_ = config->load_tokens; - - // We would like to load namespaces from db even if repl_namespace_enabled is false, - // this can avoid missing some namespaces when turn on/off repl_namespace_enabled. +Status Namespace::loadFromDB(std::map* db_tokens) const { std::string value; auto s = storage_->Get(rocksdb::ReadOptions(), cf_, kNamespaceDBKey, &value); - if (!s.ok() && !s.IsNotFound()) { + if (!s.ok()) { + if (s.IsNotFound()) return Status::OK(); return {Status::NotOK, s.ToString()}; } - if (s.ok()) { - // The namespace db key is existed, so it doesn't allow to switch off repl_namespace_enabled - if (!config->repl_namespace_enabled) { - return {Status::NotOK, "cannot switch off repl_namespace_enabled when namespaces exist in db"}; - } - jsoncons::json j = jsoncons::json::parse(value); - for (const auto& iter : j.object_range()) { - if (tokens_.find(iter.key()) == tokens_.end()) { - // merge the namespace from db - tokens_[iter.key()] = iter.value().as(); - } + jsoncons::json j = jsoncons::json::parse(value); + for (const auto& iter : j.object_range()) { + db_tokens->insert({iter.key(), iter.value().as_string()}); + } + return Status::OK(); +} + +Status Namespace::LoadAndRewrite() { + auto config = storage_->GetConfig(); + // Namespace is NOT allowed in the cluster mode, so we don't need to rewrite here. 
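A small usage sketch for the kqir::Value variant defined above (`Demo` and the include are hypothetical; MakeValue, Is, and ToString follow the declarations in value.h):

// #include "search/value.h"  // the header introduced above
#include <iostream>
#include <string>
#include <vector>

void Demo() {
  auto v = kqir::MakeValue<kqir::StringArray>(std::vector<std::string>{"red", "blue"});
  if (v.Is<kqir::StringArray>()) {
    std::cout << v.ToString() << std::endl;  // "red,blue" with the default separator
  }
}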
+ if (config->cluster_enabled) return Status::OK(); + + std::map db_tokens; + auto s = loadFromDB(&db_tokens); + if (!s.IsOK()) return s; + + if (!db_tokens.empty() && !config->repl_namespace_enabled) { + return {Status::NotOK, "cannot switch off repl_namespace_enabled when namespaces exist in db"}; + } + + std::unique_lock lock(tokens_mu_); + // Load from the configuration file first + tokens_ = config->load_tokens; + // Merge the tokens from the database if the token is not in the configuration file + for (const auto& iter : db_tokens) { + if (tokens_.find(iter.first) == tokens_.end()) { + tokens_[iter.first] = iter.second; } } - return Rewrite(); + // The following rewrite is to remove namespace/token pairs from the configuration if the namespace replication + // is enabled. So we don't need to do that if no tokens are loaded or the namespace replication is disabled. + if (config->load_tokens.empty() || !config->repl_namespace_enabled) return Status::OK(); + + return Rewrite(tokens_); } -StatusOr Namespace::Get(const std::string& ns) const { +StatusOr Namespace::Get(const std::string& ns) { + std::shared_lock lock(tokens_mu_); for (const auto& iter : tokens_) { if (iter.second == ns) { return iter.first; @@ -90,7 +106,8 @@ StatusOr Namespace::Get(const std::string& ns) const { return {Status::NotFound}; } -StatusOr Namespace::GetByToken(const std::string& token) const { +StatusOr Namespace::GetByToken(const std::string& token) { + std::shared_lock lock(tokens_mu_); auto iter = tokens_.find(token); if (iter == tokens_.end()) { return {Status::NotFound}; @@ -118,6 +135,7 @@ Status Namespace::Set(const std::string& ns, const std::string& token) { return {Status::NotOK, kErrInvalidToken}; } + std::unique_lock lock(tokens_mu_); for (const auto& iter : tokens_) { if (iter.second == ns) { // need to delete the old token first tokens_.erase(iter.first); @@ -126,7 +144,7 @@ Status Namespace::Set(const std::string& ns, const std::string& token) { } tokens_[token] = ns; - s = Rewrite(); + s = Rewrite(tokens_); if (!s.IsOK()) { tokens_.erase(token); return s; @@ -135,17 +153,22 @@ Status Namespace::Set(const std::string& ns, const std::string& token) { } Status Namespace::Add(const std::string& ns, const std::string& token) { - // duplicate namespace - for (const auto& iter : tokens_) { - if (iter.second == ns) { - if (iter.first == token) return Status::OK(); - return {Status::NotOK, kErrNamespaceExists}; + { + std::shared_lock lock(tokens_mu_); + // duplicate namespace + for (const auto& iter : tokens_) { + if (iter.second == ns) { + if (iter.first == token) return Status::OK(); + return {Status::NotOK, kErrNamespaceExists}; + } + } + // duplicate token + if (tokens_.find(token) != tokens_.end()) { + return {Status::NotOK, kErrTokenExists}; } } - // duplicate token - if (tokens_.find(token) != tokens_.end()) { - return {Status::NotOK, kErrTokenExists}; - } + + // we don't need to lock the mutex here because the Set method will lock it return Set(ns, token); } @@ -157,10 +180,11 @@ Status Namespace::Del(const std::string& ns) { return {Status::NotOK, kErrCantModifyNamespace}; } + std::unique_lock lock(tokens_mu_); for (const auto& iter : tokens_) { if (iter.second == ns) { tokens_.erase(iter.first); - auto s = Rewrite(); + auto s = Rewrite(tokens_); if (!s.IsOK()) { tokens_[iter.first] = iter.second; return s; @@ -171,11 +195,11 @@ Status Namespace::Del(const std::string& ns) { return {Status::NotOK, kErrNamespaceNotFound}; } -Status Namespace::Rewrite() { +Status Namespace::Rewrite(const 
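The concurrency pattern introduced for tokens_ above, in isolation: readers take a std::shared_lock, writers take a std::unique_lock, and because a shared lock cannot be upgraded in place, Add releases its read lock before delegating the write to Set. A generic sketch of the same pattern:

#include <map>
#include <mutex>
#include <shared_mutex>
#include <string>

class TokenMap {
 public:
  std::string Get(const std::string &token) {
    std::shared_lock lock(mu_);  // many concurrent readers
    auto it = map_.find(token);
    return it == map_.end() ? "" : it->second;
  }

  void Set(const std::string &token, const std::string &ns) {
    std::unique_lock lock(mu_);  // exclusive writer
    map_[token] = ns;
  }

 private:
  std::shared_mutex mu_;
  std::map<std::string, std::string> map_;
};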
std::map& tokens) const { auto config = storage_->GetConfig(); // Rewrite the configuration file only if it's running with the configuration file if (config->HasConfigFile()) { - auto s = config->Rewrite(tokens_); + auto s = config->Rewrite(tokens); if (!s.IsOK()) { return s; } @@ -192,7 +216,7 @@ Status Namespace::Rewrite() { return Status::OK(); } jsoncons::json json; - for (const auto& iter : tokens_) { + for (const auto& iter : tokens) { json[iter.first] = iter.second; } return storage_->WriteToPropagateCF(kNamespaceDBKey, json.to_string()); diff --git a/src/server/namespace.h b/src/server/namespace.h index 943a3ecead3..3e22d382419 100644 --- a/src/server/namespace.h +++ b/src/server/namespace.h @@ -26,26 +26,30 @@ constexpr const char *kNamespaceDBKey = "__namespace_keys__"; class Namespace { public: - explicit Namespace(engine::Storage *storage) : storage_(storage) { - cf_ = storage_->GetCFHandle(engine::kPropagateColumnFamilyName); - } + explicit Namespace(engine::Storage *storage) + : storage_(storage), cf_(storage_->GetCFHandle(ColumnFamilyID::Propagate)) {} ~Namespace() = default; Namespace(const Namespace &) = delete; Namespace &operator=(const Namespace &) = delete; Status LoadAndRewrite(); - StatusOr Get(const std::string &ns) const; - StatusOr GetByToken(const std::string &token) const; + StatusOr Get(const std::string &ns); + StatusOr GetByToken(const std::string &token); Status Set(const std::string &ns, const std::string &token); Status Add(const std::string &ns, const std::string &token); Status Del(const std::string &ns); const std::map &List() const { return tokens_; } - Status Rewrite(); + Status Rewrite(const std::map &tokens) const; bool IsAllowModify() const; private: engine::Storage *storage_; rocksdb::ColumnFamilyHandle *cf_ = nullptr; + + std::shared_mutex tokens_mu_; + // mapping from token to namespace name std::map tokens_; + + Status loadFromDB(std::map *db_tokens) const; }; diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index 19bcbcebb2d..0457c71d57e 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -26,7 +26,10 @@ #include #include "commands/commander.h" +#include "commands/error_constants.h" #include "fmt/format.h" +#include "search/indexer.h" +#include "server/redis_reply.h" #include "string_util.h" #ifdef ENABLE_OPENSSL #include @@ -84,7 +87,7 @@ void Connection::OnRead(struct bufferevent *bev) { auto s = req_.Tokenize(Input()); if (!s.IsOK()) { EnableFlag(redis::Connection::kCloseAfterReply); - Reply(redis::Error("ERR " + s.Msg())); + Reply(redis::Error(s)); LOG(INFO) << "[connection] Failed to tokenize the request. 
Error: " << s.Msg(); return; } @@ -138,7 +141,7 @@ std::string Connection::Bool(bool b) const { } std::string Connection::MultiBulkString(const std::vector &values) const { - std::string result = "*" + std::to_string(values.size()) + CRLF; + std::string result = MultiLen(values.size()); for (const auto &value : values) { if (value.empty()) { result += NilString(); @@ -151,7 +154,7 @@ std::string Connection::MultiBulkString(const std::vector &values) std::string Connection::MultiBulkString(const std::vector &values, const std::vector &statuses) const { - std::string result = "*" + std::to_string(values.size()) + CRLF; + std::string result = MultiLen(values.size()); for (size_t i = 0; i < values.size(); i++) { if (i < statuses.size() && !statuses[i].ok()) { result += NilString(); @@ -220,6 +223,7 @@ std::string Connection::GetFlags() const { if (IsFlagEnabled(kSlave)) flags.append("S"); if (IsFlagEnabled(kCloseAfterReply)) flags.append("c"); if (IsFlagEnabled(kMonitor)) flags.append("M"); + if (IsFlagEnabled(kAsking)) flags.append("A"); if (!subscribe_channels_.empty() || !subscribe_patterns_.empty()) flags.append("P"); if (flags.empty()) flags = "N"; return flags; @@ -408,34 +412,39 @@ Status Connection::ExecuteCommand(const std::string &cmd_name, const std::vector return s; } +static bool IsHashOrJsonCommand(const std::string &cmd) { + return util::HasPrefix(cmd, "h") || util::HasPrefix(cmd, "json."); +} + void Connection::ExecuteCommands(std::deque *to_process_cmds) { - Config *config = srv_->GetConfig(); - std::string reply, password = config->requirepass; + const Config *config = srv_->GetConfig(); + std::string reply; + std::string password = config->requirepass; while (!to_process_cmds->empty()) { - auto cmd_tokens = to_process_cmds->front(); + CommandTokens cmd_tokens = std::move(to_process_cmds->front()); to_process_cmds->pop_front(); if (cmd_tokens.empty()) continue; bool is_multi_exec = IsFlagEnabled(Connection::kMultiExec); if (IsFlagEnabled(redis::Connection::kCloseAfterReply) && !is_multi_exec) break; - auto cmd_s = srv_->LookupAndCreateCommand(cmd_tokens.front()); + auto cmd_s = Server::LookupAndCreateCommand(cmd_tokens.front()); if (!cmd_s.IsOK()) { if (is_multi_exec) multi_error_ = true; - Reply(redis::Error("ERR unknown command " + cmd_tokens.front())); + Reply(redis::Error({Status::NotOK, "unknown command " + cmd_tokens.front()})); continue; } auto current_cmd = std::move(*cmd_s); - const auto attributes = current_cmd->GetAttributes(); + const auto &attributes = current_cmd->GetAttributes(); auto cmd_name = attributes->name; auto cmd_flags = attributes->GenerateFlags(cmd_tokens); if (GetNamespace().empty()) { if (!password.empty()) { if (cmd_name != "auth" && cmd_name != "hello") { - Reply(redis::Error("NOAUTH Authentication required.")); + Reply(redis::Error({Status::RedisNoAuth, "Authentication required."})); continue; } } else { @@ -468,7 +477,7 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { } if (srv_->IsLoading() && !(cmd_flags & kCmdLoading)) { - Reply(redis::Error("LOADING kvrocks is restoring the db from backup")); + Reply(redis::Error({Status::RedisLoading, errRestoringBackup})); if (is_multi_exec) multi_error_ = true; continue; } @@ -476,7 +485,7 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { int tokens = static_cast(cmd_tokens.size()); if (!attributes->CheckArity(tokens)) { if (is_multi_exec) multi_error_ = true; - Reply(redis::Error("ERR wrong number of arguments")); + Reply(redis::Error({Status::NotOK, "wrong number of 
arguments"})); continue; } @@ -484,12 +493,12 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { auto s = current_cmd->Parse(); if (!s.IsOK()) { if (is_multi_exec) multi_error_ = true; - Reply(redis::Error("ERR " + s.Msg())); + Reply(redis::Error(s)); continue; } if (is_multi_exec && (cmd_flags & kCmdNoMulti)) { - Reply(redis::Error("ERR Can't execute " + cmd_name + " in MULTI")); + Reply(redis::Error({Status::NotOK, "Can't execute " + cmd_name + " in MULTI"})); multi_error_ = true; continue; } @@ -498,11 +507,16 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { s = srv_->cluster->CanExecByMySelf(attributes, cmd_tokens, this); if (!s.IsOK()) { if (is_multi_exec) multi_error_ = true; - Reply(redis::Error(s.Msg())); + Reply(redis::Error(s)); continue; } } + // reset the ASKING flag after executing the next query + if (IsFlagEnabled(kAsking)) { + DisableFlag(kAsking); + } + // We don't execute commands, but queue them, ant then execute in EXEC command if (is_multi_exec && !in_exec_ && !(cmd_flags & kCmdMulti)) { multi_cmds_.emplace_back(cmd_tokens); @@ -511,26 +525,53 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { } if (config->slave_readonly && srv_->IsSlave() && (cmd_flags & kCmdWrite)) { - Reply(redis::Error("READONLY You can't write against a read only slave.")); + Reply(redis::Error({Status::RedisReadOnly, "You can't write against a read only slave."})); continue; } if ((cmd_flags & kCmdWrite) && !(cmd_flags & kCmdNoDBSizeCheck) && srv_->storage->ReachedDBSizeLimit()) { - Reply(redis::Error("ERR write command not allowed when reached max-db-size.")); + Reply(redis::Error({Status::NotOK, "write command not allowed when reached max-db-size."})); continue; } if (!config->slave_serve_stale_data && srv_->IsSlave() && cmd_name != "info" && cmd_name != "slaveof" && srv_->GetReplicationState() != kReplConnected) { - Reply( - redis::Error("MASTERDOWN Link with MASTER is down " - "and slave-serve-stale-data is set to 'no'.")); + Reply(redis::Error({Status::RedisMasterDown, + "Link with MASTER is down " + "and slave-serve-stale-data is set to 'no'."})); continue; } + // TODO: transaction support for index recording + std::vector index_records; + if (IsHashOrJsonCommand(cmd_name) && (attributes->flags & redis::kCmdWrite) && !config->cluster_enabled) { + attributes->ForEachKeyRange( + [&, this](const std::vector &args, const CommandKeyRange &key_range) { + key_range.ForEachKey( + [&, this](const std::string &key) { + auto res = srv_->indexer.Record(key, ns_); + if (res.IsOK()) { + index_records.push_back(*res); + } else if (!res.Is() && !res.Is()) { + LOG(WARNING) << "index recording failed for key: " << key; + } + }, + args); + }, + cmd_tokens); + } + SetLastCmd(cmd_name); s = ExecuteCommand(cmd_name, cmd_tokens, current_cmd.get(), &reply); + // TODO: transaction support for index updating + for (const auto &record : index_records) { + auto s = GlobalIndexer::Update(record); + if (!s.IsOK() && !s.Is()) { + LOG(WARNING) << "index updating failed for key: " << record.key; + } + } + // Break the execution loop when occurring the blocking command like BLPOP or BRPOP, // it will suspend the connection and wait for the wakeup signal. 
if (s.Is()) { @@ -544,7 +585,7 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { // Reply for MULTI if (!s.IsOK()) { - Reply(redis::Error("ERR " + s.Msg())); + Reply(redis::Error(s)); continue; } diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h index 9faf0c2700a..f3015fcfd6c 100644 --- a/src/server/redis_connection.h +++ b/src/server/redis_connection.h @@ -45,6 +45,8 @@ class Connection : public EvbufCallbackBase { kCloseAfterReply = 1 << 6, kCloseAsync = 1 << 7, kMultiExec = 1 << 8, + kReadOnly = 1 << 9, + kAsking = 1 << 10, }; explicit Connection(bufferevent *bev, Worker *owner); @@ -101,6 +103,9 @@ class Connection : public EvbufCallbackBase { std::string HeaderOfAttribute(T len) const { return "|" + std::to_string(len) + CRLF; } + std::string HeaderOfPush(int64_t len) const { + return protocol_version_ == RESP::v3 ? ">" + std::to_string(len) + CRLF : MultiLen(len); + } using UnsubscribeCallback = std::function; void SubscribeChannel(const std::string &channel); diff --git a/src/server/redis_reply.cc b/src/server/redis_reply.cc index 95bbf9fde03..20e15e512f8 100644 --- a/src/server/redis_reply.cc +++ b/src/server/redis_reply.cc @@ -20,21 +20,43 @@ #include "redis_reply.h" +#include #include +const std::map redisErrorPrefixMapping = { + {Status::RedisErrorNoPrefix, ""}, {Status::RedisNoProto, "NOPROTO"}, + {Status::RedisLoading, "LOADING"}, {Status::RedisMasterDown, "MASTERDOWN"}, + {Status::RedisNoScript, "NOSCRIPT"}, {Status::RedisNoAuth, "NOAUTH"}, + {Status::RedisWrongType, "WRONGTYPE"}, {Status::RedisReadOnly, "READONLY"}, + {Status::RedisExecAbort, "EXECABORT"}, {Status::RedisMoved, "MOVED"}, + {Status::RedisCrossSlot, "CROSSSLOT"}, {Status::RedisTryAgain, "TRYAGAIN"}, + {Status::RedisClusterDown, "CLUSTERDOWN"}}; + namespace redis { void Reply(evbuffer *output, const std::string &data) { evbuffer_add(output, data.c_str(), data.length()); } std::string SimpleString(const std::string &data) { return "+" + data + CRLF; } -std::string Error(const std::string &err) { return "-" + err + CRLF; } +std::string Error(const Status &s) { return RESP_PREFIX_ERROR + StatusToRedisErrorMsg(s) + CRLF; } + +std::string StatusToRedisErrorMsg(const Status &s) { + CHECK(!s.IsOK()); + std::string prefix = "ERR"; + if (auto it = redisErrorPrefixMapping.find(s.GetCode()); it != redisErrorPrefixMapping.end()) { + prefix = it->second; + } + if (prefix.empty()) { + return s.Msg(); + } + return prefix + " " + s.Msg(); +} std::string BulkString(const std::string &data) { return "$" + std::to_string(data.length()) + CRLF + data + CRLF; } std::string Array(const std::vector &list) { size_t n = std::accumulate(list.begin(), list.end(), 0, [](size_t n, const std::string &s) { return n + s.size(); }); - std::string result = "*" + std::to_string(list.size()) + CRLF; + std::string result = MultiLen(list.size()); std::string::size_type final_size = result.size() + n; result.reserve(final_size); for (const auto &i : list) result += i; @@ -42,7 +64,7 @@ std::string Array(const std::vector &list) { } std::string ArrayOfBulkStrings(const std::vector &elems) { - std::string result = "*" + std::to_string(elems.size()) + CRLF; + std::string result = MultiLen(elems.size()); for (const auto &elem : elems) { result += BulkString(elem); } diff --git a/src/server/redis_reply.h b/src/server/redis_reply.h index 213a8bc0849..f50c559d704 100644 --- a/src/server/redis_reply.h +++ b/src/server/redis_reply.h @@ -25,7 +25,11 @@ #include #include -#define CRLF "\r\n" // NOLINT +#include 
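The mapping table above is what turns a Status code into a RESP error line: a known code contributes its prefix (LOADING, READONLY, and so on), anything else defaults to ERR, and RedisErrorNoPrefix emits the bare message. Condensed into a standalone function:

#include <map>
#include <string>

enum class Code { NoPrefix, Loading, ReadOnly, Other };

std::string ToRespError(Code code, const std::string &msg) {
  static const std::map<Code, std::string> prefixes = {
      {Code::NoPrefix, ""}, {Code::Loading, "LOADING"}, {Code::ReadOnly, "READONLY"}};
  std::string prefix = "ERR";  // default when the code has no dedicated prefix
  if (auto it = prefixes.find(code); it != prefixes.end()) prefix = it->second;
  std::string payload = prefix.empty() ? msg : prefix + " " + msg;
  return "-" + payload + "\r\n";  // e.g. "-LOADING kvrocks is restoring the db from backup\r\n"
}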
"status.h" + +#define CRLF "\r\n" // NOLINT +#define RESP_PREFIX_ERROR "-" // NOLINT +#define RESP_PREFIX_SIMPLE_STRING "+" // NOLINT namespace redis { @@ -33,7 +37,9 @@ enum class RESP { v2, v3 }; void Reply(evbuffer *output, const std::string &data); std::string SimpleString(const std::string &data); -std::string Error(const std::string &err); + +std::string Error(const Status &s); +std::string StatusToRedisErrorMsg(const Status &s); template , int> = 0> std::string Integer(T data) { diff --git a/src/server/server.cc b/src/server/server.cc index c44085f3123..30c9db69ee8 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -39,6 +39,7 @@ #include "commands/commander.h" #include "config.h" +#include "config/config.h" #include "fmt/format.h" #include "redis_connection.h" #include "storage/compaction_checker.h" @@ -52,7 +53,12 @@ #include "worker.h" Server::Server(engine::Storage *storage, Config *config) - : storage(storage), start_time_(util::GetTimeStamp()), config_(config), namespace_(storage) { + : storage(storage), + indexer(storage), + index_mgr(&indexer, storage), + start_time_secs_(util::GetTimeStamp()), + config_(config), + namespace_(storage) { // init commands stats here to prevent concurrent insert, and cause core auto commands = redis::CommandTable::GetOriginal(); for (const auto &iter : *commands) { @@ -150,6 +156,13 @@ Status Server::Start() { } } + if (!config_->cluster_enabled) { + GET_OR_RET(index_mgr.Load(kDefaultNamespace)); + for (auto [_, ns] : namespace_.List()) { + GET_OR_RET(index_mgr.Load(ns)); + } + } + if (config_->cluster_enabled) { if (config_->persist_cluster_nodes_enabled) { auto s = cluster->LoadClusterNodes(config_->NodesFilePath()); @@ -179,7 +192,7 @@ Status Server::Start() { compaction_checker_thread_ = GET_OR_RET(util::CreateThread("compact-check", [this] { uint64_t counter = 0; - time_t last_compact_date = 0; + int64_t last_compact_date = 0; CompactionChecker compaction_checker{this->storage}; while (!stop_) { @@ -191,21 +204,20 @@ Status Server::Start() { if (storage->IsClosing()) continue; if (!is_loading_ && ++counter % 600 == 0 // check every minute - && config_->compaction_checker_range.Enabled()) { - auto now = static_cast(util::GetTimeStamp()); - std::tm local_time{}; - localtime_r(&now, &local_time); - if (local_time.tm_hour >= config_->compaction_checker_range.start && - local_time.tm_hour <= config_->compaction_checker_range.stop) { - std::vector cf_names = {engine::kMetadataColumnFamilyName, engine::kSubkeyColumnFamilyName, - engine::kZSetScoreColumnFamilyName, engine::kStreamColumnFamilyName}; - for (const auto &cf_name : cf_names) { - compaction_checker.PickCompactionFiles(cf_name); + && config_->compaction_checker_cron.IsEnabled()) { + auto t_now = static_cast(util::GetTimeStamp()); + std::tm now{}; + localtime_r(&t_now, &now); + if (config_->compaction_checker_cron.IsTimeMatch(&now)) { + const auto &column_family_list = engine::ColumnFamilyConfigs::ListAllColumnFamilies(); + for (auto &column_family : column_family_list) { + compaction_checker.PickCompactionFilesForCf(column_family); } } // compact once per day - if (now != 0 && last_compact_date != now / 86400) { - last_compact_date = now / 86400; + auto now_hours = t_now / 3600; + if (now_hours != 0 && last_compact_date != now_hours / 24) { + last_compact_date = now_hours / 24; compaction_checker.CompactPropagateAndPubSubFiles(); } } @@ -344,9 +356,9 @@ void Server::CleanupExitedSlaves() { void Server::FeedMonitorConns(redis::Connection *conn, const std::vector &tokens) { if 
(monitor_clients_ <= 0) return; - auto now = util::GetTimeStampUS(); + auto now_us = util::GetTimeStampUS(); std::string output = - fmt::format("{}.{} [{} {}]", now / 1000000, now % 1000000, conn->GetNamespace(), conn->GetAddr()); + fmt::format("{}.{} [{} {}]", now_us / 1000000, now_us % 1000000, conn->GetNamespace(), conn->GetAddr()); for (const auto &token : tokens) { output += " \""; output += util::EscapeString(token); @@ -674,7 +686,7 @@ void Server::OnEntryAddedToStream(const std::string &ns, const std::string &key, } } -void Server::updateCachedTime() { unix_time.store(util::GetTimeStamp()); } +void Server::updateCachedTime() { unix_time_secs.store(util::GetTimeStamp()); } int Server::IncrClientNum() { total_clients_.fetch_add(1, std::memory_order_relaxed); @@ -746,7 +758,7 @@ void Server::cron() { std::tm now{}; localtime_r(&t, &now); // disable compaction cron when the compaction checker was enabled - if (!config_->compaction_checker_range.Enabled() && config_->compact_cron.IsEnabled() && + if (!config_->compaction_checker_cron.IsEnabled() && config_->compact_cron.IsEnabled() && config_->compact_cron.IsTimeMatch(&now)) { Status s = AsyncCompactDB(); LOG(INFO) << "[server] Schedule to compact the db, result: " << s.Msg(); @@ -755,6 +767,24 @@ void Server::cron() { Status s = AsyncBgSaveDB(); LOG(INFO) << "[server] Schedule to bgsave the db, result: " << s.Msg(); } + if (config_->dbsize_scan_cron.IsEnabled() && config_->dbsize_scan_cron.IsTimeMatch(&now)) { + auto tokens = namespace_.List(); + std::vector namespaces; + + // Number of namespaces (custom namespaces + default one) + namespaces.reserve(tokens.size() + 1); + for (auto &token : tokens) { + namespaces.emplace_back(token.second); // namespace + } + + // add default namespace as fallback + namespaces.emplace_back(kDefaultNamespace); + + for (auto &ns : namespaces) { + Status s = AsyncScanDBSize(ns); + LOG(INFO) << "[server] Schedule to recalculate the db size on namespace: " << ns << ", result: " << s.Msg(); + } + } } // check every 10s if (counter != 0 && counter % 100 == 0) { @@ -769,13 +799,14 @@ void Server::cron() { // No replica uses this checkpoint, we can remove it. if (counter != 0 && counter % 100 == 0) { - time_t create_time = storage->GetCheckpointCreateTime(); - time_t access_time = storage->GetCheckpointAccessTime(); + int64_t create_time_secs = storage->GetCheckpointCreateTimeSecs(); + int64_t access_time_secs = storage->GetCheckpointAccessTimeSecs(); if (storage->ExistCheckpoint()) { // TODO(shooterit): support to config the alive time of checkpoint - auto now = static_cast(util::GetTimeStamp()); - if ((GetFetchFileThreadNum() == 0 && now - access_time > 30) || (now - create_time > 24 * 60 * 60)) { + int64_t now_secs = util::GetTimeStamp(); + if ((GetFetchFileThreadNum() == 0 && now_secs - access_time_secs > 30) || + (now_secs - create_time_secs > 24 * 60 * 60)) { auto s = rocksdb::DestroyDB(config_->checkpoint_dir, rocksdb::Options()); if (!s.ok()) { LOG(WARNING) << "[server] Fail to clean checkpoint, error: " << s.ToString(); @@ -792,8 +823,12 @@ void Server::cron() { // In order to properly handle all possible situations on rocksdb, we manually resume here // when encountering no space error and disk quota exceeded error. 
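The checkpoint retention rule in the cron loop above, extracted into a single predicate: remove the checkpoint when no replica is fetching files and it has been idle for more than 30 seconds, or unconditionally once it is older than a day:

#include <cstdint>

bool ShouldRemoveCheckpoint(int64_t now_secs, int64_t create_secs, int64_t access_secs,
                            int fetching_threads) {
  return (fetching_threads == 0 && now_secs - access_secs > 30) ||  // idle and unused
         (now_secs - create_secs > 24 * 60 * 60);                   // older than one day
}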
if (counter != 0 && counter % 600 == 0 && storage->IsDBInRetryableIOError()) { - storage->GetDB()->Resume(); - LOG(INFO) << "[server] Schedule to resume DB after retryable IO error"; + auto s = storage->GetDB()->Resume(); + if (s.ok()) { + LOG(WARNING) << "[server] Successfully resumed DB after retryable IO error"; + } else { + LOG(ERROR) << "[server] Failed to resume DB after retryable IO error: " << s.ToString(); + } storage->SetDBInRetryableIOError(false); } @@ -828,16 +863,26 @@ void Server::GetRocksDBInfo(std::string *info) { db->GetAggregatedIntProperty("rocksdb.num-live-versions", &num_live_versions); string_stream << "# RocksDB\r\n"; + + { + // All column families share the same block cache, so it's good to count a single one. + uint64_t block_cache_usage = 0; + uint64_t block_cache_pinned_usage = 0; + auto subkey_cf_handle = storage->GetCFHandle(ColumnFamilyID::PrimarySubkey); + db->GetIntProperty(subkey_cf_handle, rocksdb::DB::Properties::kBlockCacheUsage, &block_cache_usage); + string_stream << "block_cache_usage:" << block_cache_usage << "\r\n"; + db->GetIntProperty(subkey_cf_handle, rocksdb::DB::Properties::kBlockCachePinnedUsage, &block_cache_pinned_usage); + string_stream << "block_cache_pinned_usage[" << subkey_cf_handle->GetName() << "]:" << block_cache_pinned_usage + << "\r\n"; + } + for (const auto &cf_handle : *storage->GetCFHandles()) { - uint64_t estimate_keys = 0, block_cache_usage = 0, block_cache_pinned_usage = 0, index_and_filter_cache_usage = 0; + uint64_t estimate_keys = 0; + uint64_t index_and_filter_cache_usage = 0; std::map cf_stats_map; - db->GetIntProperty(cf_handle, "rocksdb.estimate-num-keys", &estimate_keys); + db->GetIntProperty(cf_handle, rocksdb::DB::Properties::kEstimateNumKeys, &estimate_keys); string_stream << "estimate_keys[" << cf_handle->GetName() << "]:" << estimate_keys << "\r\n"; - db->GetIntProperty(cf_handle, "rocksdb.block-cache-usage", &block_cache_usage); - string_stream << "block_cache_usage[" << cf_handle->GetName() << "]:" << block_cache_usage << "\r\n"; - db->GetIntProperty(cf_handle, "rocksdb.block-cache-pinned-usage", &block_cache_pinned_usage); - string_stream << "block_cache_pinned_usage[" << cf_handle->GetName() << "]:" << block_cache_pinned_usage << "\r\n"; - db->GetIntProperty(cf_handle, "rocksdb.estimate-table-readers-mem", &index_and_filter_cache_usage); + db->GetIntProperty(cf_handle, rocksdb::DB::Properties::kEstimateTableReadersMem, &index_and_filter_cache_usage); string_stream << "index_and_filter_cache_usage[" << cf_handle->GetName() << "]:" << index_and_filter_cache_usage << "\r\n"; db->GetMapProperty(cf_handle, rocksdb::DB::Properties::kCFStats, &cf_stats_map); @@ -935,9 +980,9 @@ void Server::GetServerInfo(std::string *info) { string_stream << "arch_bits:" << sizeof(void *) * 8 << "\r\n"; string_stream << "process_id:" << getpid() << "\r\n"; string_stream << "tcp_port:" << config_->port << "\r\n"; - int64_t now = util::GetTimeStamp(); - string_stream << "uptime_in_seconds:" << now - start_time_ << "\r\n"; - string_stream << "uptime_in_days:" << (now - start_time_) / 86400 << "\r\n"; + int64_t now_secs = util::GetTimeStamp(); + string_stream << "uptime_in_seconds:" << now_secs - start_time_secs_ << "\r\n"; + string_stream << "uptime_in_days:" << (now_secs - start_time_secs_) / 86400 << "\r\n"; *info = string_stream.str(); } @@ -972,14 +1017,14 @@ void Server::GetReplicationInfo(std::string *info) { string_stream << "# Replication\r\n"; string_stream << "role:" << (IsSlave() ? 
"slave" : "master") << "\r\n"; if (IsSlave()) { - time_t now = util::GetTimeStamp(); + int64_t now_secs = util::GetTimeStamp(); string_stream << "master_host:" << master_host_ << "\r\n"; string_stream << "master_port:" << master_port_ << "\r\n"; ReplState state = GetReplicationState(); string_stream << "master_link_status:" << (state == kReplConnected ? "up" : "down") << "\r\n"; string_stream << "master_sync_unrecoverable_error:" << (state == kReplError ? "yes" : "no") << "\r\n"; string_stream << "master_sync_in_progress:" << (state == kReplFetchMeta || state == kReplFetchSST) << "\r\n"; - string_stream << "master_last_io_seconds_ago:" << now - replication_thread_->LastIOTime() << "\r\n"; + string_stream << "master_last_io_seconds_ago:" << now_secs - replication_thread_->LastIOTimeSecs() << "\r\n"; string_stream << "slave_repl_offset:" << storage->LatestSeqNumber() << "\r\n"; string_stream << "slave_priority:" << config_->slave_priority << "\r\n"; } @@ -1063,15 +1108,15 @@ void Server::SetLastRandomKeyCursor(const std::string &cursor) { } int64_t Server::GetCachedUnixTime() { - if (unix_time.load() == 0) { + if (unix_time_secs.load() == 0) { updateCachedTime(); } - return unix_time.load(); + return unix_time_secs.load(); } int64_t Server::GetLastBgsaveTime() { std::lock_guard lg(db_job_mu_); - return last_bgsave_time_ == -1 ? start_time_ : last_bgsave_time_; + return last_bgsave_timestamp_secs_ == -1 ? start_time_secs_ : last_bgsave_timestamp_secs_; } void Server::GetStatsInfo(std::string *info) { @@ -1113,7 +1158,7 @@ void Server::GetCommandsStatsInfo(std::string *info) { auto latency = cmd_stat.second.latency.load(); string_stream << "cmdstat_" << cmd_stat.first << ":calls=" << calls << ",usec=" << latency - << ",usec_per_call=" << ((calls == 0) ? 0 : static_cast(latency / calls)) << "\r\n"; + << ",usec_per_call=" << static_cast(latency / calls) << "\r\n"; } *info = string_stream.str(); @@ -1167,9 +1212,10 @@ void Server::GetInfo(const std::string &ns, const std::string §ion, std::str std::lock_guard lg(db_job_mu_); string_stream << "bgsave_in_progress:" << (is_bgsave_in_progress_ ? 1 : 0) << "\r\n"; - string_stream << "last_bgsave_time:" << (last_bgsave_time_ == -1 ? start_time_ : last_bgsave_time_) << "\r\n"; + string_stream << "last_bgsave_time:" + << (last_bgsave_timestamp_secs_ == -1 ? start_time_secs_ : last_bgsave_timestamp_secs_) << "\r\n"; string_stream << "last_bgsave_status:" << last_bgsave_status_ << "\r\n"; - string_stream << "last_bgsave_time_sec:" << last_bgsave_time_sec_ << "\r\n"; + string_stream << "last_bgsave_time_sec:" << last_bgsave_duration_secs_ << "\r\n"; } if (all || section == "stats") { @@ -1226,8 +1272,11 @@ void Server::GetInfo(const std::string &ns, const std::string §ion, std::str GetLatestKeyNumStats(ns, &stats); } - time_t last_scan_time = GetLastScanTime(ns); - tm last_scan_tm{}; + //time_t last_scan_time = GetLastScanTime(ns); + //tm last_scan_tm{}; + // FIXME(mwish): output still requires std::tm. 
+ auto last_scan_time = static_cast(GetLastScanTime(ns)); + std::tm last_scan_tm{}; localtime_r(&last_scan_time, &last_scan_tm); if (section_cnt++) string_stream << "\r\n"; @@ -1398,15 +1447,15 @@ Status Server::AsyncBgSaveDB() { is_bgsave_in_progress_ = true; return task_runner_.TryPublish([this] { - auto start_bgsave_time = util::GetTimeStamp(); + auto start_bgsave_time_secs = util::GetTimeStamp(); Status s = storage->CreateBackup(); - auto stop_bgsave_time = util::GetTimeStamp(); + auto stop_bgsave_time_secs = util::GetTimeStamp(); std::lock_guard lg(db_job_mu_); is_bgsave_in_progress_ = false; - last_bgsave_time_ = start_bgsave_time; + last_bgsave_timestamp_secs_ = start_bgsave_time_secs; last_bgsave_status_ = s.IsOK() ? "ok" : "err"; - last_bgsave_time_sec_ = stop_bgsave_time - start_bgsave_time; + last_bgsave_duration_secs_ = stop_bgsave_time_secs - start_bgsave_time_secs; }); } @@ -1441,7 +1490,7 @@ Status Server::AsyncScanDBSize(const std::string &ns) { std::lock_guard lg(db_job_mu_); db_scan_infos_[ns].key_num_stats = stats; - db_scan_infos_[ns].last_scan_time = util::GetTimeStamp(); + db_scan_infos_[ns].last_scan_time_secs = util::GetTimeStamp(); db_scan_infos_[ns].is_scanning = false; }); } @@ -1534,10 +1583,10 @@ void Server::GetLatestKeyNumStats(const std::string &ns, KeyNumStats *stats) { } } -time_t Server::GetLastScanTime(const std::string &ns) { +int64_t Server::GetLastScanTime(const std::string &ns) const { auto iter = db_scan_infos_.find(ns); if (iter != db_scan_infos_.end()) { - return iter->second.last_scan_time; + return iter->second.last_scan_time_secs; } return 0; } @@ -1641,7 +1690,7 @@ StatusOr> Server::LookupAndCreateCommand(const auto cmd = cmd_attr->factory(); cmd->SetAttributes(cmd_attr); - return cmd; + return std::move(cmd); } Status Server::ScriptExists(const std::string &sha) { @@ -1655,7 +1704,7 @@ Status Server::ScriptExists(const std::string &sha) { Status Server::ScriptGet(const std::string &sha, std::string *body) const { std::string func_name = engine::kLuaFuncSHAPrefix + sha; - auto cf = storage->GetCFHandle(engine::kPropagateColumnFamilyName); + auto cf = storage->GetCFHandle(ColumnFamilyID::Propagate); auto s = storage->Get(rocksdb::ReadOptions(), cf, func_name, body); if (!s.ok()) { return {s.IsNotFound() ? Status::NotFound : Status::NotOK, s.ToString()}; @@ -1670,7 +1719,7 @@ Status Server::ScriptSet(const std::string &sha, const std::string &body) const Status Server::FunctionGetCode(const std::string &lib, std::string *code) const { std::string func_name = engine::kLuaLibCodePrefix + lib; - auto cf = storage->GetCFHandle(engine::kPropagateColumnFamilyName); + auto cf = storage->GetCFHandle(ColumnFamilyID::Propagate); auto s = storage->Get(rocksdb::ReadOptions(), cf, func_name, code); if (!s.ok()) { return {s.IsNotFound() ? Status::NotFound : Status::NotOK, s.ToString()}; @@ -1680,7 +1729,7 @@ Status Server::FunctionGetCode(const std::string &lib, std::string *code) const Status Server::FunctionGetLib(const std::string &func, std::string *lib) const { std::string func_name = engine::kLuaFuncLibPrefix + func; - auto cf = storage->GetCFHandle(engine::kPropagateColumnFamilyName); + auto cf = storage->GetCFHandle(ColumnFamilyID::Propagate); auto s = storage->Get(rocksdb::ReadOptions(), cf, func_name, lib); if (!s.ok()) { return {s.IsNotFound() ? 
Status::NotFound : Status::NotOK, s.ToString()}; @@ -1704,7 +1753,7 @@ void Server::ScriptReset() { } Status Server::ScriptFlush() { - auto cf = storage->GetCFHandle(engine::kPropagateColumnFamilyName); + auto cf = storage->GetCFHandle(ColumnFamilyID::Propagate); auto s = storage->FlushScripts(storage->DefaultWriteOptions(), cf); if (!s.ok()) return {Status::NotOK, s.ToString()}; ScriptReset(); @@ -2052,3 +2101,22 @@ std::string Server::GetKeyNameFromCursor(const std::string &cursor, CursorType c return {}; } + +AuthResult Server::AuthenticateUser(const std::string &user_password, std::string *ns) { + const auto &requirepass = GetConfig()->requirepass; + if (requirepass.empty()) { + return AuthResult::NO_REQUIRE_PASS; + } + + auto get_ns = GetNamespace()->GetByToken(user_password); + if (get_ns.IsOK()) { + *ns = get_ns.GetValue(); + return AuthResult::IS_USER; + } + + if (user_password != requirepass) { + return AuthResult::INVALID_PASSWORD; + } + *ns = kDefaultNamespace; + return AuthResult::IS_ADMIN; +} diff --git a/src/server/server.h b/src/server/server.h index cf360659b6a..2b89b94afcf 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -44,6 +44,8 @@ #include "commands/commander.h" #include "lua.hpp" #include "namespace.h" +#include "search/index_manager.h" +#include "search/indexer.h" #include "server/redis_connection.h" #include "stats/log_collector.h" #include "stats/stats.h" @@ -56,7 +58,8 @@ constexpr const char *REDIS_VERSION = "4.0.0"; struct DBScanInfo { - time_t last_scan_time = 0; + // Last scan system clock in seconds + int64_t last_scan_time_secs = 0; KeyNumStats key_num_stats; bool is_scanning = false; }; @@ -139,6 +142,13 @@ enum ClientType { enum ServerLogType { kServerLogNone, kReplIdLog }; +enum class AuthResult { + IS_USER, + IS_ADMIN, + INVALID_PASSWORD, + NO_REQUIRE_PASS, +}; + class ServerLogData { public: // Redis::WriteBatchLogData always starts with digit ascii, we use alphabetic to @@ -242,7 +252,7 @@ class Server { Status AsyncPurgeOldBackups(uint32_t num_backups_to_keep, uint32_t backup_max_keep_hours); Status AsyncScanDBSize(const std::string &ns); void GetLatestKeyNumStats(const std::string &ns, KeyNumStats *stats); - time_t GetLastScanTime(const std::string &ns); + int64_t GetLastScanTime(const std::string &ns) const; std::string GenerateCursorFromKeyName(const std::string &key_name, CursorType cursor_type, const char *prefix = ""); std::string GetKeyNameFromCursor(const std::string &cursor, CursorType cursor_type); @@ -287,7 +297,7 @@ class Server { Stats stats; engine::Storage *storage; std::unique_ptr cluster; - static inline std::atomic unix_time = 0; + static inline std::atomic unix_time_secs = 0; std::unique_ptr slot_migrator; std::unique_ptr slot_import; @@ -299,10 +309,16 @@ class Server { std::list> GetSlaveHostAndPort(); Namespace *GetNamespace() { return &namespace_; } + AuthResult AuthenticateUser(const std::string &user_password, std::string *ns); + #ifdef ENABLE_OPENSSL UniqueSSLContext ssl_ctx; #endif + // search + redis::GlobalIndexer indexer; + redis::IndexManager index_mgr; + private: void cron(); void recordInstantaneousMetrics(); @@ -316,7 +332,7 @@ class Server { std::atomic stop_ = false; std::atomic is_loading_ = false; - int64_t start_time_; + int64_t start_time_secs_; std::mutex slaveof_mu_; std::string master_host_; uint32_t master_port_ = 0; @@ -346,9 +362,9 @@ class Server { std::mutex db_job_mu_; bool db_compacting_ = false; bool is_bgsave_in_progress_ = false; - int64_t last_bgsave_time_ = -1; + int64_t 
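One way a caller might consume AuthenticateUser above (the enum mirrors the diff; treating NO_REQUIRE_PASS as a failed AUTH follows Redis semantics when no password is configured, and is an assumption about the call site):

#include <optional>
#include <string>

enum class AuthResult { IS_USER, IS_ADMIN, INVALID_PASSWORD, NO_REQUIRE_PASS };

// Returns the namespace to bind on success, or nullopt when AUTH must fail.
std::optional<std::string> NamespaceForAuth(AuthResult res, std::string ns) {
  switch (res) {
    case AuthResult::IS_USER:
    case AuthResult::IS_ADMIN:
      return ns;  // token matched a namespace, or the admin requirepass matched
    case AuthResult::NO_REQUIRE_PASS:  // no requirepass configured: AUTH is rejected
    case AuthResult::INVALID_PASSWORD:
      return std::nullopt;
  }
  return std::nullopt;
}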
last_bgsave_timestamp_secs_ = -1; std::string last_bgsave_status_ = "ok"; - int64_t last_bgsave_time_sec_ = -1; + int64_t last_bgsave_duration_secs_ = -1; std::map db_scan_infos_; diff --git a/src/server/worker.cc b/src/server/worker.cc index 1e8fed37441..22054e1faf8 100644 --- a/src/server/worker.cc +++ b/src/server/worker.cc @@ -167,7 +167,7 @@ void Worker::newTCPConnection(evconnlistener *listener, evutil_socket_t fd, sock s = AddConnection(conn); if (!s.IsOK()) { - std::string err_msg = redis::Error("ERR " + s.Msg()); + std::string err_msg = redis::Error({Status::NotOK, s.Msg()}); s = util::SockSend(fd, err_msg, ssl); if (!s.IsOK()) { LOG(WARNING) << "Failed to send error response to socket: " << s.Msg(); @@ -200,8 +200,7 @@ void Worker::newUnixSocketConnection(evconnlistener *listener, evutil_socket_t f auto s = AddConnection(conn); if (!s.IsOK()) { - std::string err_msg = redis::Error("ERR " + s.Msg()); - s = util::SockSend(fd, err_msg); + s = util::SockSend(fd, redis::Error(s)); if (!s.IsOK()) { LOG(WARNING) << "Failed to send error response to socket: " << s.Msg(); } diff --git a/src/stats/disk_stats.cc b/src/stats/disk_stats.cc index 7c3c99982d2..8c7f98ddfa2 100644 --- a/src/stats/disk_stats.cc +++ b/src/stats/disk_stats.cc @@ -77,62 +77,62 @@ rocksdb::Status Disk::GetStringSize(const Slice &ns_key, uint64_t *key_size) { rocksdb::Status Disk::GetHashSize(const Slice &ns_key, uint64_t *key_size) { HashMetadata metadata(false); - rocksdb::Status s = Database::GetMetadata({kRedisHash}, ns_key, &metadata); + rocksdb::Status s = Database::GetMetadata(Database::GetOptions{}, {kRedisHash}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; - return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(engine::kSubkeyColumnFamilyName), key_size); + return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(ColumnFamilyID::PrimarySubkey), key_size); } rocksdb::Status Disk::GetSetSize(const Slice &ns_key, uint64_t *key_size) { SetMetadata metadata(false); - rocksdb::Status s = Database::GetMetadata({kRedisSet}, ns_key, &metadata); + rocksdb::Status s = Database::GetMetadata(Database::GetOptions{}, {kRedisSet}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; - return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(engine::kSubkeyColumnFamilyName), key_size); + return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(ColumnFamilyID::PrimarySubkey), key_size); } rocksdb::Status Disk::GetListSize(const Slice &ns_key, uint64_t *key_size) { ListMetadata metadata(false); - rocksdb::Status s = Database::GetMetadata({kRedisList}, ns_key, &metadata); + rocksdb::Status s = Database::GetMetadata(Database::GetOptions{}, {kRedisList}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; std::string buf; PutFixed64(&buf, metadata.head); - return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(engine::kSubkeyColumnFamilyName), key_size, buf); + return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(ColumnFamilyID::PrimarySubkey), key_size, buf); } rocksdb::Status Disk::GetZsetSize(const Slice &ns_key, uint64_t *key_size) { ZSetMetadata metadata(false); - rocksdb::Status s = Database::GetMetadata({kRedisZSet}, ns_key, &metadata); + rocksdb::Status s = Database::GetMetadata(Database::GetOptions{}, {kRedisZSet}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; std::string score_bytes; PutDouble(&score_bytes, kMinScore); - s = GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(engine::kZSetScoreColumnFamilyName), key_size, + s = GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(ColumnFamilyID::SecondarySubkey), key_size, score_bytes, score_bytes); if (!s.ok()) return s; - return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(engine::kSubkeyColumnFamilyName), key_size); + return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(ColumnFamilyID::PrimarySubkey), key_size); } rocksdb::Status Disk::GetBitmapSize(const Slice &ns_key, uint64_t *key_size) { BitmapMetadata metadata(false); - rocksdb::Status s = Database::GetMetadata({kRedisBitmap}, ns_key, &metadata); + rocksdb::Status s = Database::GetMetadata(Database::GetOptions{}, {kRedisBitmap}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; - return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(engine::kSubkeyColumnFamilyName), key_size, + return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(ColumnFamilyID::PrimarySubkey), key_size, std::to_string(0), std::to_string(0)); } rocksdb::Status Disk::GetSortedintSize(const Slice &ns_key, uint64_t *key_size) { SortedintMetadata metadata(false); - rocksdb::Status s = Database::GetMetadata({kRedisSortedint}, ns_key, &metadata); + rocksdb::Status s = Database::GetMetadata(Database::GetOptions{}, {kRedisSortedint}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; std::string start_buf; PutFixed64(&start_buf, 0); - return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(engine::kSubkeyColumnFamilyName), key_size, + return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(ColumnFamilyID::PrimarySubkey), key_size, start_buf, start_buf); } rocksdb::Status Disk::GetStreamSize(const Slice &ns_key, uint64_t *key_size) { StreamMetadata metadata(false); - rocksdb::Status s = Database::GetMetadata({kRedisStream}, ns_key, &metadata); + rocksdb::Status s = Database::GetMetadata(Database::GetOptions{}, {kRedisStream}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; - return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(engine::kStreamColumnFamilyName), key_size); + return GetApproximateSizes(metadata, ns_key, storage_->GetCFHandle(ColumnFamilyID::Stream), key_size); } } // namespace redis diff --git a/src/stats/stats.cc b/src/stats/stats.cc index 115fc4d9e13..ae18638b221 100644 --- a/src/stats/stats.cc +++ b/src/stats/stats.cc @@ -29,7 +29,7 @@ Stats::Stats() { for (int i = 0; i < STATS_METRIC_COUNT; i++) { InstMetric im; - im.last_sample_time = 0; + im.last_sample_time_ms = 0; im.last_sample_count = 0; im.idx = 0; for (uint64_t &sample : im.samples) { @@ -93,15 +93,15 @@ void Stats::IncrLatency(uint64_t latency, const std::string &command_name) { } void Stats::TrackInstantaneousMetric(int metric, uint64_t current_reading) { - uint64_t curr_time = util::GetTimeStampMS(); + uint64_t curr_time_ms = util::GetTimeStampMS(); std::unique_lock lock(inst_metrics_mutex); - uint64_t t = curr_time - inst_metrics[metric].last_sample_time; + uint64_t t = curr_time_ms - inst_metrics[metric].last_sample_time_ms; uint64_t ops = current_reading - inst_metrics[metric].last_sample_count; uint64_t ops_sec = t > 0 ? 
(ops * 1000 / t) : 0; inst_metrics[metric].samples[inst_metrics[metric].idx] = ops_sec; inst_metrics[metric].idx++; inst_metrics[metric].idx %= STATS_METRIC_SAMPLES; - inst_metrics[metric].last_sample_time = curr_time; + inst_metrics[metric].last_sample_time_ms = curr_time_ms; inst_metrics[metric].last_sample_count = current_reading; } diff --git a/src/stats/stats.h b/src/stats/stats.h index 88ab2108b09..6fdba09a194 100644 --- a/src/stats/stats.h +++ b/src/stats/stats.h @@ -49,8 +49,8 @@ struct CommandStat { }; struct InstMetric { - uint64_t last_sample_time; // Timestamp of the last sample in ms - uint64_t last_sample_count; // Count in the last sample + uint64_t last_sample_time_ms; // Timestamp of the last sample in ms + uint64_t last_sample_count; // Count in the last sample uint64_t samples[STATS_METRIC_SAMPLES]; int idx; }; diff --git a/src/storage/batch_extractor.cc b/src/storage/batch_extractor.cc index f60669ebb5e..c1559b1a71d 100644 --- a/src/storage/batch_extractor.cc +++ b/src/storage/batch_extractor.cc @@ -44,14 +44,14 @@ void WriteBatchExtractor::LogData(const rocksdb::Slice &blob) { } rocksdb::Status WriteBatchExtractor::PutCF(uint32_t column_family_id, const Slice &key, const Slice &value) { - if (column_family_id == kColumnFamilyIDZSetScore) { + if (column_family_id == static_cast(ColumnFamilyID::SecondarySubkey)) { return rocksdb::Status::OK(); } std::string ns, user_key; std::vector command_args; - if (column_family_id == kColumnFamilyIDMetadata) { + if (column_family_id == static_cast(ColumnFamilyID::Metadata)) { std::tie(ns, user_key) = ExtractNamespaceKey(key, is_slot_id_encoded_); if (slot_id_ >= 0 && static_cast(slot_id_) != GetSlotIdFromKey(user_key)) { return rocksdb::Status::OK(); @@ -109,7 +109,7 @@ rocksdb::Status WriteBatchExtractor::PutCF(uint32_t column_family_id, const Slic return rocksdb::Status::OK(); } - if (column_family_id == kColumnFamilyIDDefault) { + if (column_family_id == static_cast(ColumnFamilyID::PrimarySubkey)) { InternalKey ikey(key, is_slot_id_encoded_); user_key = ikey.GetKey().ToString(); if (slot_id_ >= 0 && static_cast(slot_id_) != GetSlotIdFromKey(user_key)) { @@ -252,7 +252,7 @@ rocksdb::Status WriteBatchExtractor::PutCF(uint32_t column_family_id, const Slic default: break; } - } else if (column_family_id == kColumnFamilyIDStream) { + } else if (column_family_id == static_cast(ColumnFamilyID::Stream)) { auto s = ExtractStreamAddCommand(is_slot_id_encoded_, key, value, &command_args); if (!s.IsOK()) { LOG(ERROR) << "Failed to parse write_batch in PutCF. 
Type=Stream: " << s.Msg(); @@ -268,14 +268,14 @@ rocksdb::Status WriteBatchExtractor::PutCF(uint32_t column_family_id, const Slic } rocksdb::Status WriteBatchExtractor::DeleteCF(uint32_t column_family_id, const Slice &key) { - if (column_family_id == kColumnFamilyIDZSetScore) { + if (column_family_id == static_cast(ColumnFamilyID::SecondarySubkey)) { return rocksdb::Status::OK(); } std::vector command_args; std::string ns; - if (column_family_id == kColumnFamilyIDMetadata) { + if (column_family_id == static_cast(ColumnFamilyID::Metadata)) { std::string user_key; std::tie(ns, user_key) = ExtractNamespaceKey(key, is_slot_id_encoded_); @@ -284,7 +284,7 @@ rocksdb::Status WriteBatchExtractor::DeleteCF(uint32_t column_family_id, const S } command_args = {"DEL", user_key}; - } else if (column_family_id == kColumnFamilyIDDefault) { + } else if (column_family_id == static_cast(ColumnFamilyID::PrimarySubkey)) { InternalKey ikey(key, is_slot_id_encoded_); std::string user_key = ikey.GetKey().ToString(); if (slot_id_ >= 0 && static_cast(slot_id_) != GetSlotIdFromKey(user_key)) { @@ -376,7 +376,7 @@ rocksdb::Status WriteBatchExtractor::DeleteCF(uint32_t column_family_id, const S default: break; } - } else if (column_family_id == kColumnFamilyIDStream) { + } else if (column_family_id == static_cast(ColumnFamilyID::Stream)) { InternalKey ikey(key, is_slot_id_encoded_); Slice encoded_id = ikey.GetSubKey(); redis::StreamEntryID entry_id; diff --git a/src/storage/compact_filter.h b/src/storage/compact_filter.h index 9788c849886..118bb8f6c7b 100644 --- a/src/storage/compact_filter.h +++ b/src/storage/compact_filter.h @@ -124,4 +124,23 @@ class PubSubFilterFactory : public rocksdb::CompactionFilterFactory { } }; +class SearchFilter : public rocksdb::CompactionFilter { + public: + const char *Name() const override { return "SearchFilter"; } + bool Filter(int level, const Slice &key, const Slice &value, std::string *new_value, bool *modified) const override { + // TODO: just a dummy one here + return false; + } +}; + +class SearchFilterFactory : public rocksdb::CompactionFilterFactory { + public: + SearchFilterFactory() = default; + const char *Name() const override { return "SearchFilterFactory"; } + std::unique_ptr CreateCompactionFilter( + const rocksdb::CompactionFilter::Context &context) override { + return std::unique_ptr(new SearchFilter()); + } +}; + } // namespace engine diff --git a/src/storage/compaction_checker.cc b/src/storage/compaction_checker.cc index 55502b2ab70..649f1f2316c 100644 --- a/src/storage/compaction_checker.cc +++ b/src/storage/compaction_checker.cc @@ -29,18 +29,19 @@ void CompactionChecker::CompactPropagateAndPubSubFiles() { rocksdb::CompactRangeOptions compact_opts; compact_opts.change_level = true; - std::vector cf_names = {engine::kPubSubColumnFamilyName, engine::kPropagateColumnFamilyName}; - for (const auto &cf_name : cf_names) { - LOG(INFO) << "[compaction checker] Start the compact the column family: " << cf_name; - auto cf_handle = storage_->GetCFHandle(cf_name); + for (const auto &cf : + {engine::ColumnFamilyConfigs::PubSubColumnFamily(), engine::ColumnFamilyConfigs::PropagateColumnFamily()}) { + LOG(INFO) << "[compaction checker] Start the compact the column family: " << cf.Name(); + auto cf_handle = storage_->GetCFHandle(cf.Id()); auto s = storage_->GetDB()->CompactRange(compact_opts, cf_handle, nullptr, nullptr); - LOG(INFO) << "[compaction checker] Compact the column family: " << cf_name << " finished, result: " << s.ToString(); + LOG(INFO) << "[compaction checker] 
Compact the column family: " << cf.Name() + << " finished, result: " << s.ToString(); } } -void CompactionChecker::PickCompactionFiles(const std::string &cf_name) { +void CompactionChecker::PickCompactionFilesForCf(const engine::ColumnFamilyConfig &column_family_config) { rocksdb::TablePropertiesCollection props; - rocksdb::ColumnFamilyHandle *cf = storage_->GetCFHandle(cf_name); + rocksdb::ColumnFamilyHandle *cf = storage_->GetCFHandle(column_family_config.Id()); auto s = storage_->GetDB()->GetPropertiesOfAllTables(cf, &props); if (!s.ok()) { LOG(WARNING) << "[compaction checker] Failed to get table properties, " << s.ToString(); diff --git a/src/storage/compaction_checker.h b/src/storage/compaction_checker.h index 750c32298f4..f7c260b97ba 100644 --- a/src/storage/compaction_checker.h +++ b/src/storage/compaction_checker.h @@ -30,7 +30,7 @@ class CompactionChecker { public: explicit CompactionChecker(engine::Storage *storage) : storage_(storage) {} ~CompactionChecker() = default; - void PickCompactionFiles(const std::string &cf_name); + void PickCompactionFilesForCf(const engine::ColumnFamilyConfig &cf_name); void CompactPropagateAndPubSubFiles(); private: diff --git a/src/storage/iterator.cc b/src/storage/iterator.cc index 6514207b35d..c65dd3ab7eb 100644 --- a/src/storage/iterator.cc +++ b/src/storage/iterator.cc @@ -26,8 +26,10 @@ namespace engine { DBIterator::DBIterator(Storage *storage, rocksdb::ReadOptions read_options, int slot) - : storage_(storage), read_options_(std::move(read_options)), slot_(slot) { - metadata_cf_handle_ = storage_->GetCFHandle(kMetadataColumnFamilyName); + : storage_(storage), + read_options_(std::move(read_options)), + slot_(slot), + metadata_cf_handle_(storage_->GetCFHandle(ColumnFamilyID::Metadata)) { metadata_iter_ = util::UniqueIterator(storage_->NewIterator(read_options_, metadata_cf_handle_)); } @@ -115,9 +117,9 @@ std::unique_ptr DBIterator::GetSubKeyIterator() const { SubKeyIterator::SubKeyIterator(Storage *storage, rocksdb::ReadOptions read_options, RedisType type, std::string prefix) : storage_(storage), read_options_(std::move(read_options)), type_(type), prefix_(std::move(prefix)) { if (type_ == kRedisStream) { - cf_handle_ = storage_->GetCFHandle(kStreamColumnFamilyName); + cf_handle_ = storage_->GetCFHandle(ColumnFamilyID::Stream); } else { - cf_handle_ = storage_->GetCFHandle(kSubkeyColumnFamilyName); + cf_handle_ = storage_->GetCFHandle(ColumnFamilyID::PrimarySubkey); } iter_ = util::UniqueIterator(storage_->NewIterator(read_options_, cf_handle_)); } diff --git a/src/storage/rdb.cc b/src/storage/rdb.cc index bd433a4d860..31c56f00159 100644 --- a/src/storage/rdb.cc +++ b/src/storage/rdb.cc @@ -64,7 +64,6 @@ constexpr const int RDBOpcodeSelectDB = 254; /* DB number of the following k constexpr const int RDBOpcodeEof = 255; /* End of the RDB file. */ constexpr const int SupportedRDBVersion = 10; // not been tested for version 11, so use this version with caution. -constexpr const int MaxRDBVersion = 11; // The current max rdb version supported by redis. 
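// A quick worked example of the little-endian version decode used below in VerifyPayloadChecksum:
// a restore footer beginning with bytes {0x0b, 0x00} decodes to (0x00 << 8) | 0x0b = 11, which
// passes the check against MaxRDBVersion = 12, while a footer beginning with {0x0d, 0x00}
// decodes to 13 and is rejected as unsupported.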
constexpr const int RDBCheckSumLen = 8; // rdb check sum length constexpr const int RestoreRdbVersionLen = 2; // rdb version len in restore string @@ -86,7 +85,7 @@ Status RDB::VerifyPayloadChecksum(const std::string_view &payload) { } auto footer = payload.substr(payload.size() - RestoreFooterLen); auto rdb_version = (footer[1] << 8) | footer[0]; - // For now, the max redis rdb version is 11 + // For now, the max redis rdb version is 12 if (rdb_version > MaxRDBVersion) { return {Status::NotOK, fmt::format("invalid or unsupported rdb version: {}", rdb_version)}; } @@ -460,7 +459,11 @@ Status RDB::saveRdbObject(int type, const std::string &key, const RedisObjValue if (type == RDBTypeString) { const auto &value = std::get<std::string>(obj); redis::String string_db(storage_, ns_); - db_status = string_db.SetEX(key, value, ttl_ms); + uint64_t expire_ms = 0; + if (ttl_ms > 0) { + expire_ms = ttl_ms + util::GetTimeStampMS(); + } + db_status = string_db.SetEX(key, value, expire_ms); } else if (type == RDBTypeSet || type == RDBTypeSetIntSet || type == RDBTypeSetListPack) { const auto &members = std::get<std::vector<std::string>>(obj); redis::Set set_db(storage_, ns_); @@ -564,21 +567,20 @@ Status RDB::LoadRdb(uint32_t db_index, bool overwrite_exist_key) { return {Status::NotOK, fmt::format("Can't handle RDB format version {}", rdb_ver)}; } - uint64_t expire_time = 0; + uint64_t expire_time_ms = 0; int64_t expire_keys = 0; int64_t load_keys = 0; int64_t empty_keys_skipped = 0; - auto now = util::GetTimeStampMS(); + auto now_ms = util::GetTimeStampMS(); uint32_t db_id = 0; uint64_t skip_exist_keys = 0; while (true) { auto type = GET_OR_RET(LogWhenError(loadRdbType())); if (type == RDBOpcodeExpireTime) { - expire_time = static_cast<uint64_t>(GET_OR_RET(LogWhenError(loadExpiredTimeSeconds()))); - expire_time *= 1000; + expire_time_ms = static_cast<uint64_t>(GET_OR_RET(LogWhenError(loadExpiredTimeSeconds()))) * 1000; continue; } else if (type == RDBOpcodeExpireTimeMs) { - expire_time = GET_OR_RET(LogWhenError(loadExpiredTimeMilliseconds(rdb_ver))); + expire_time_ms = GET_OR_RET(LogWhenError(loadExpiredTimeMilliseconds(rdb_ver))); continue; } else if (type == RDBOpcodeFreq) { // LFU frequency: not used in kvrocks GET_OR_RET(LogWhenError(stream_->ReadByte())); // discard the value @@ -634,8 +636,8 @@ Status RDB::LoadRdb(uint32_t db_index, bool overwrite_exist_key) { LOG(WARNING) << "skipping empty key: " << key; } continue; - } else if (expire_time != 0 && - expire_time < now) { // in redis this used to feed this deletion to any connected replicas + } else if (expire_time_ms != 0 && + expire_time_ms < now_ms) { // in redis this is used to feed the deletion to any connected replicas expire_keys++; continue; } @@ -652,7 +654,7 @@ Status RDB::LoadRdb(uint32_t db_index, bool overwrite_exist_key) { } } - auto ret = saveRdbObject(type, key, value, expire_time); + auto ret = saveRdbObject(type, key, value, expire_time_ms); if (!ret.IsOK()) { LOG(WARNING) << "save rdb object key " << key << " failed: " << ret.Msg(); } else { @@ -680,3 +682,291 @@ Status RDB::LoadRdb(uint32_t db_index, bool overwrite_exist_key) { return Status::OK(); } + +Status RDB::Dump(const std::string &key, const RedisType type) { + unsigned char buf[2]; + /* Serialize the object in an RDB-like format. It consists of an object type + * byte followed by the serialized object. This is understood by RESTORE.
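+ * For example, assuming the Redis encoding constants RDBTypeString = 0 and + * RDB6BitLen = 0, dumping the one-byte string "A" would yield the type byte + * 0x00, the payload bytes 0x01 0x41 (a 6-bit length of 1 followed by 'A'), + * the footer bytes 0x06 0x00 (MinRDBVersion in little endian), and finally + * the trailing 8-byte CRC64 of everything before it.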
*/ + auto s = SaveObjectType(type); + if (!s.IsOK()) return s; + s = SaveObject(key, type); + if (!s.IsOK()) return s; + + /* Write the footer; this is what it looks like: + * ----------------+---------------------+---------------+ + * ... RDB payload | 2 bytes RDB version | 8 bytes CRC64 | + * ----------------+---------------------+---------------+ + * RDB version and CRC are both in little endian. + */ + + // We should choose the minimum RDB version for compatibility considerations. + // The DUMP command has been supported since Redis 2.6, + // so we choose the RDB version of Redis 2.6 as the minimum version. + buf[0] = MinRDBVersion & 0xff; + buf[1] = (MinRDBVersion >> 8) & 0xff; + s = stream_->Write((const char *)buf, 2); + if (!s.IsOK()) return s; + + /* CRC64 */ + CHECK(dynamic_cast<RdbStringStream *>(stream_.get()) != nullptr); + std::string &output = static_cast<RdbStringStream *>(stream_.get())->GetInput(); + uint64_t crc = crc64(0, (unsigned char *)(output.c_str()), output.length()); + memrev64ifbe(&crc); + return stream_->Write((const char *)(&crc), 8); +} + +Status RDB::SaveObjectType(const RedisType type) { + int robj_type = -1; + if (type == kRedisString) { + robj_type = RDBTypeString; + } else if (type == kRedisHash) { + robj_type = RDBTypeHash; + } else if (type == kRedisList) { + robj_type = RDBTypeListQuickList; + } else if (type == kRedisSet) { + robj_type = RDBTypeSet; + } else if (type == kRedisZSet) { + robj_type = RDBTypeZSet2; + } else { + LOG(WARNING) << "Invalid or unsupported object type: " << type; + return {Status::NotOK, "Invalid or unsupported object type"}; + } + return stream_->Write((const char *)(&robj_type), 1); +} + +Status RDB::SaveObject(const std::string &key, const RedisType type) { + if (type == kRedisString) { + std::string value; + redis::String string_db(storage_, ns_); + auto s = string_db.Get(key, &value); + if (!s.ok() && !s.IsNotFound()) { + return {Status::RedisExecErr, s.ToString()}; + } + return SaveStringObject(value); + } else if (type == kRedisList) { + std::vector<std::string> elems; + redis::List list_db(storage_, ns_); + auto s = list_db.Range(key, 0, -1, &elems); + if (!s.ok() && !s.IsNotFound()) { + return {Status::RedisExecErr, s.ToString()}; + } + return SaveListObject(elems); + } else if (type == kRedisSet) { + redis::Set set_db(storage_, ns_); + std::vector<std::string> members; + auto s = set_db.Members(key, &members); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + return SaveSetObject(members); + } else if (type == kRedisZSet) { + redis::ZSet zset_db(storage_, ns_); + std::vector<MemberScore> member_scores; + RangeScoreSpec spec; + auto s = zset_db.RangeByScore(key, spec, &member_scores, nullptr); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + std::sort(member_scores.begin(), member_scores.end(), + [](const MemberScore &v1, const MemberScore &v2) { return v1.score > v2.score; }); + return SaveZSetObject(member_scores); + } else if (type == kRedisHash) { + redis::Hash hash_db(storage_, ns_); + std::vector<FieldValue> field_values; + auto s = hash_db.GetAll(key, &field_values); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + return SaveHashObject(field_values); + } else { + LOG(WARNING) << "Invalid or unsupported object type: " << type; + return {Status::NotOK, "Invalid or unsupported object type"}; + } +} + +Status RDB::RdbSaveLen(uint64_t len) { + unsigned char buf[2]; + if (len < (1 << 6)) { + /* Save a 6 bit len */ + buf[0] = (len & 0xFF) | (RDB6BitLen << 6); + return stream_->Write((const char *)buf, 1); + }
else if (len < (1 << 14)) { + /* Save a 14 bit len */ + buf[0] = ((len >> 8) & 0xFF) | (RDB14BitLen << 6); + buf[1] = len & 0xFF; + return stream_->Write((const char *)buf, 2); + } else if (len <= UINT32_MAX) { + /* Save a 32 bit len */ + buf[0] = RDB32BitLen; + auto status = stream_->Write((const char *)buf, 1); + if (!status.IsOK()) return status; + + uint32_t len32 = htonl(len); + return stream_->Write((const char *)(&len32), 4); + } else { + /* Save a 64 bit len */ + buf[0] = RDB64BitLen; + auto status = stream_->Write((const char *)buf, 1); + if (!status.IsOK()) return status; + + len = htonu64(len); + return stream_->Write((const char *)(&len), 8); + } +} + +Status RDB::SaveStringObject(const std::string &value) { + const size_t len = value.length(); + int enclen = 0; + + // When the length is 11 or less, the value may be an integer, + // so we try the special integer encoding first. + if (len <= 11) { + unsigned char buf[5]; + // convert string to long long + auto parse_result = ParseInt<long long>(value, 10); + if (parse_result) { + long long integer_value = *parse_result; + // encode integer + enclen = rdbEncodeInteger(integer_value, buf); + if (enclen > 0) { + return stream_->Write((const char *)buf, enclen); + } + } + } + + // Since we do not support rdb compression, + // the lzf encoding method has not been implemented yet. + + /* Store verbatim */ + auto status = RdbSaveLen(value.length()); + if (!status.IsOK()) return status; + if (value.length() > 0) { + return stream_->Write(value.c_str(), value.length()); + } + return Status::OK(); +} + +Status RDB::SaveListObject(const std::vector<std::string> &elems) { + if (elems.size() > 0) { + auto status = RdbSaveLen(elems.size()); + if (!status.IsOK()) return status; + + for (const auto &elem : elems) { + auto status = rdbSaveZipListObject(elem); + if (!status.IsOK()) return status; + } + } else { + LOG(WARNING) << "the size of elems is zero"; + return {Status::NotOK, "the size of elems is zero"}; + } + return Status::OK(); +} + +Status RDB::SaveSetObject(const std::vector<std::string> &members) { + if (members.size() > 0) { + auto status = RdbSaveLen(members.size()); + if (!status.IsOK()) return status; + + for (const auto &elem : members) { + status = SaveStringObject(elem); + if (!status.IsOK()) return status; + } + } else { + LOG(WARNING) << "the size of members is zero"; + return {Status::NotOK, "the size of members is zero"}; + } + return Status::OK(); +} + +Status RDB::SaveZSetObject(const std::vector<MemberScore> &member_scores) { + if (member_scores.size() > 0) { + auto status = RdbSaveLen(member_scores.size()); + if (!status.IsOK()) return status; + + for (const auto &elem : member_scores) { + status = SaveStringObject(elem.member); + if (!status.IsOK()) return status; + + status = rdbSaveBinaryDoubleValue(elem.score); + if (!status.IsOK()) return status; + } + } else { + LOG(WARNING) << "the size of member_scores is zero"; + return {Status::NotOK, "the size of ZSet is 0"}; + } + return Status::OK(); +} + +Status RDB::SaveHashObject(const std::vector<FieldValue> &field_values) { + if (field_values.size() > 0) { + auto status = RdbSaveLen(field_values.size()); + if (!status.IsOK()) return status; + + for (const auto &p : field_values) { + status = SaveStringObject(p.field); + if (!status.IsOK()) return status; + + status = SaveStringObject(p.value); + if (!status.IsOK()) return status; + } + } else { + LOG(WARNING) << "the size of field_values is zero"; + return {Status::NotOK, "the size of Hash is 0"}; + } + return Status::OK(); +} + +int RDB::rdbEncodeInteger(const long long value, unsigned char
*enc) { + if (value >= -(1 << 7) && value <= (1 << 7) - 1) { + enc[0] = (RDBEncVal << 6) | RDBEncInt8; + enc[1] = value & 0xFF; + return 2; + } else if (value >= -(1 << 15) && value <= (1 << 15) - 1) { + enc[0] = (RDBEncVal << 6) | RDBEncInt16; + enc[1] = value & 0xFF; + enc[2] = (value >> 8) & 0xFF; + return 3; + } else if (value >= -((long long)1 << 31) && value <= ((long long)1 << 31) - 1) { + enc[0] = (RDBEncVal << 6) | RDBEncInt32; + enc[1] = value & 0xFF; + enc[2] = (value >> 8) & 0xFF; + enc[3] = (value >> 16) & 0xFF; + enc[4] = (value >> 24) & 0xFF; + return 5; + } else { + return 0; + } +} + +Status RDB::rdbSaveBinaryDoubleValue(double val) { + memrev64ifbe(&val); + return stream_->Write((const char *)(&val), sizeof(val)); +} + +Status RDB::rdbSaveZipListObject(const std::string &elem) { + // calc total ziplist size + uint prevlen = 0; + const size_t ziplist_size = zlHeaderSize + zlEndSize + elem.length() + + ZipList::ZipStorePrevEntryLength(nullptr, 0, prevlen) + + ZipList::ZipStoreEntryEncoding(nullptr, 0, elem.length()); + auto zl_string = std::string(ziplist_size, '\0'); + auto zl_ptr = reinterpret_cast(&zl_string[0]); + + // set ziplist header + ZipList::SetZipListBytes(zl_ptr, ziplist_size, (static_cast(ziplist_size))); + ZipList::SetZipListTailOffset(zl_ptr, ziplist_size, intrev32ifbe(zlHeaderSize)); + + // set ziplist entry + auto pos = ZipList::GetZipListEntryHead(zl_ptr, ziplist_size); + pos += ZipList::ZipStorePrevEntryLength(pos, ziplist_size, prevlen); + pos += ZipList::ZipStoreEntryEncoding(pos, ziplist_size, elem.length()); + assert(pos + elem.length() <= zl_ptr + ziplist_size); + memcpy(pos, elem.c_str(), elem.length()); + + // set ziplist end + ZipList::SetZipListLength(zl_ptr, ziplist_size, 1); + zl_ptr[ziplist_size - 1] = zlEnd; + + return SaveStringObject(zl_string); +} diff --git a/src/storage/rdb.h b/src/storage/rdb.h index 7b4cce31f98..3a08df84a40 100644 --- a/src/storage/rdb.h +++ b/src/storage/rdb.h @@ -28,6 +28,7 @@ #include #include "status.h" +#include "types/redis_hash.h" #include "types/redis_zset.h" // Redis object type @@ -52,12 +53,18 @@ constexpr const int RDBTypeZSetListPack = 17; constexpr const int RDBTypeListQuickList2 = 18; constexpr const int RDBTypeStreamListPack2 = 19; constexpr const int RDBTypeSetListPack = 20; +constexpr const int RDBTypeStreamListPack3 = 21; // NOTE: when adding new Redis object encoding type, update isObjectType. // Quick list node encoding constexpr const int QuickListNodeContainerPlain = 1; constexpr const int QuickListNodeContainerPacked = 2; +constexpr const int MaxRDBVersion = 12; // The current max rdb version supported by redis. +// Min Redis RDB version supported by Kvrocks, we choose 6 because it's the first version +// that supports the DUMP command. 
+constexpr int MinRDBVersion = 6; + class RdbStream; using RedisObjValue = @@ -100,6 +107,29 @@ class RDB { // Load rdb Status LoadRdb(uint32_t db_index, bool overwrite_exist_key = true); + std::unique_ptr<RdbStream> &GetStream() { return stream_; } + + Status Dump(const std::string &key, RedisType type); + + Status SaveObjectType(RedisType type); + Status SaveObject(const std::string &key, RedisType type); + Status RdbSaveLen(uint64_t len); + + // String + Status SaveStringObject(const std::string &value); + + // List + Status SaveListObject(const std::vector<std::string> &elems); + + // Set + Status SaveSetObject(const std::vector<std::string> &members); + + // Sorted Set + Status SaveZSetObject(const std::vector<MemberScore> &member_scores); + + // Hash + Status SaveHashObject(const std::vector<FieldValue> &field_values); + private: engine::Storage *storage_; std::string ns_; @@ -121,4 +151,7 @@ Redis allows basic types 0-7, and 6/7 are the module types which we don't support here.*/ static bool isObjectType(int type) { return (type >= 0 && type <= 5) || (type >= 9 && type <= 21); }; static bool isEmptyRedisObject(const RedisObjValue &value); + static int rdbEncodeInteger(long long value, unsigned char *enc); + Status rdbSaveBinaryDoubleValue(double val); + Status rdbSaveZipListObject(const std::string &elem); }; diff --git a/src/storage/rdb_ziplist.cc b/src/storage/rdb_ziplist.cc index 772226eaa47..b51dc8ddadc 100644 --- a/src/storage/rdb_ziplist.cc +++ b/src/storage/rdb_ziplist.cc @@ -20,11 +20,9 @@ #include "rdb_ziplist.h" -#include "vendor/endianconv.h" +#include -constexpr const int zlHeaderSize = 10; -constexpr const uint8_t ZipListBigLen = 0xFE; -constexpr const uint8_t zlEnd = 0xFF; +#include "vendor/endianconv.h" constexpr const uint8_t ZIP_STR_MASK = 0xC0; constexpr const uint8_t ZIP_STR_06B = (0 << 6); @@ -52,7 +50,7 @@ StatusOr<std::string> ZipList::Next() { std::string value; if ((encoding) < ZIP_STR_MASK) { // For integer type, needs to convert to uint8_t* to avoid signed extension - auto data = reinterpret_cast(input_.data()); + auto data = reinterpret_cast(input_.data()); if ((encoding) == ZIP_STR_06B) { len_bytes = 1; len = data[pos_] & 0x3F; @@ -91,7 +89,7 @@ StatusOr<std::string> ZipList::Next() { } else if ((encoding) == ZIP_INT_24B) { GET_OR_RET(peekOK(3)); int32_t i32 = 0; - memcpy(reinterpret_cast(&i32) + 1, input_.data() + pos_, sizeof(int32_t) - 1); + memcpy(reinterpret_cast(&i32) + 1, input_.data() + pos_, sizeof(int32_t) - 1); memrev32ifbe(&i32); i32 >>= 8; setPreEntryLen(4); // 3byte for encoding and 1byte for the prev entry length @@ -126,7 +124,7 @@ StatusOr<std::string> ZipList::Next() { StatusOr<std::vector<std::string>> ZipList::Entries() { GET_OR_RET(peekOK(zlHeaderSize)); // ignore 8 bytes of total bytes and tail of zip list - auto zl_len = intrev16ifbe(*reinterpret_cast(input_.data() + 8)); + auto zl_len = intrev16ifbe(*reinterpret_cast(input_.data() + 8)); pos_ += zlHeaderSize; std::vector<std::string> entries; @@ -152,3 +150,72 @@ Status ZipList::peekOK(size_t n) { } uint32_t ZipList::getEncodedLengthSize(uint32_t len) { return len < ZipListBigLen ? 1 : 5; } + +uint32_t ZipList::ZipStorePrevEntryLengthLarge(unsigned char *p, size_t zl_size, unsigned int len) { + uint32_t u32 = 0; + if (p != nullptr) { + p[0] = ZipListBigLen; + u32 = len; + assert(zl_size >= 1 + sizeof(uint32_t) + zlHeaderSize); + memcpy(p + 1, &u32, sizeof(u32)); + memrev32ifbe(p + 1); + } + return 1 + sizeof(uint32_t); +} + +uint32_t ZipList::ZipStorePrevEntryLength(unsigned char *p, size_t zl_size, unsigned int len) { + if (p == nullptr) { + return (len < ZipListBigLen) ?
1 : sizeof(uint32_t) + 1; + } + if (len < ZipListBigLen) { + p[0] = len; + return 1; + } + return ZipStorePrevEntryLengthLarge(p, zl_size, len); +} + +uint32_t ZipList::ZipStoreEntryEncoding(unsigned char *p, size_t zl_size, unsigned int rawlen) { + unsigned char len = 1, buf[5]; + + /* Although encoding is given it may not be set for strings, + * so we determine it here using the raw length. */ + if (rawlen <= 0x3f) { + if (!p) return len; + buf[0] = ZIP_STR_06B | rawlen; + } else if (rawlen <= 0x3fff) { + len += 1; + if (!p) return len; + buf[0] = ZIP_STR_14B | ((rawlen >> 8) & 0x3f); + buf[1] = rawlen & 0xff; + } else { + len += 4; + if (!p) return len; + buf[0] = ZIP_STR_32B; + buf[1] = (rawlen >> 24) & 0xff; + buf[2] = (rawlen >> 16) & 0xff; + buf[3] = (rawlen >> 8) & 0xff; + buf[4] = rawlen & 0xff; + } + assert(zl_size >= static_cast(zlHeaderSize) + len); + /* Store this length at p. */ + memcpy(p, buf, len); + return len; +} + +void ZipList::SetZipListBytes(unsigned char *zl, size_t zl_size, uint32_t value) { + assert(zl_size >= sizeof(uint32_t)); + memcpy(zl, &value, sizeof(uint32_t)); +} +void ZipList::SetZipListTailOffset(unsigned char *zl, size_t zl_size, uint32_t value) { + assert(zl_size >= sizeof(uint32_t) * 2); + memcpy(zl + sizeof(uint32_t), &value, sizeof(uint32_t)); +} +void ZipList::SetZipListLength(unsigned char *zl, size_t zl_size, uint16_t value) { + assert(zl_size >= sizeof(uint32_t) * 2 + sizeof(uint16_t)); + memcpy(zl + sizeof(uint32_t) * 2, &value, sizeof(uint16_t)); +} + +unsigned char *ZipList::GetZipListEntryHead(unsigned char *zl, size_t zl_size) { + assert(zl_size >= zlHeaderSize); + return ((zl) + zlHeaderSize); +} diff --git a/src/storage/rdb_ziplist.h b/src/storage/rdb_ziplist.h index e9d05fde716..8f0d99c6693 100644 --- a/src/storage/rdb_ziplist.h +++ b/src/storage/rdb_ziplist.h @@ -25,6 +25,11 @@ #include "common/status.h" +constexpr const int zlHeaderSize = 10; +constexpr const int zlEndSize = 1; +constexpr const uint8_t ZipListBigLen = 0xFE; +constexpr const uint8_t zlEnd = 0xFF; + class ZipList { public: explicit ZipList(std::string_view input) : input_(input){}; @@ -32,6 +37,13 @@ class ZipList { StatusOr Next(); StatusOr> Entries(); + static uint32_t ZipStorePrevEntryLengthLarge(unsigned char *p, size_t zl_size, unsigned int len); + static uint32_t ZipStorePrevEntryLength(unsigned char *p, size_t zl_size, unsigned int len); + static uint32_t ZipStoreEntryEncoding(unsigned char *p, size_t zl_size, unsigned int rawlen); + static void SetZipListBytes(unsigned char *zl, size_t zl_size, uint32_t value); + static void SetZipListTailOffset(unsigned char *zl, size_t zl_size, uint32_t value); + static void SetZipListLength(unsigned char *zl, size_t zl_size, uint16_t value); + static unsigned char *GetZipListEntryHead(unsigned char *zl, size_t zl_size); private: std::string_view input_; diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index 63e8ef571cc..4f08490bd66 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -35,12 +35,18 @@ #include "storage/redis_metadata.h" #include "storage/storage.h" #include "time_util.h" +#include "types/redis_hash.h" +#include "types/redis_list.h" +#include "types/redis_set.h" +#include "types/redis_string.h" +#include "types/redis_zset.h" namespace redis { -Database::Database(engine::Storage *storage, std::string ns) : storage_(storage), namespace_(std::move(ns)) { - metadata_cf_handle_ = storage->GetCFHandle("metadata"); -} +Database::Database(engine::Storage *storage, std::string ns) + : 
storage_(storage), + metadata_cf_handle_(storage->GetCFHandle(ColumnFamilyID::Metadata)), + namespace_(std::move(ns)) {} // Some data types may support reading multiple types of metadata. // For example, bitmap supports reading string metadata and bitmap metadata. @@ -84,25 +90,25 @@ rocksdb::Status Database::ParseMetadata(RedisTypes types, Slice *bytes, Metadata return s; } -rocksdb::Status Database::GetMetadata(RedisTypes types, const Slice &ns_key, Metadata *metadata) { +rocksdb::Status Database::GetMetadata(GetOptions options, RedisTypes types, const Slice &ns_key, Metadata *metadata) { std::string raw_value; Slice rest; - return GetMetadata(types, ns_key, &raw_value, metadata, &rest); + return GetMetadata(options, types, ns_key, &raw_value, metadata, &rest); } -rocksdb::Status Database::GetMetadata(RedisTypes types, const Slice &ns_key, std::string *raw_value, Metadata *metadata, - Slice *rest) { - auto s = GetRawMetadata(ns_key, raw_value); +rocksdb::Status Database::GetMetadata(GetOptions options, RedisTypes types, const Slice &ns_key, std::string *raw_value, + Metadata *metadata, Slice *rest) { + auto s = GetRawMetadata(options, ns_key, raw_value); *rest = *raw_value; if (!s.ok()) return s; return ParseMetadata(types, rest, metadata); } -rocksdb::Status Database::GetRawMetadata(const Slice &ns_key, std::string *bytes) { - LatestSnapShot ss(storage_); - rocksdb::ReadOptions read_options; - read_options.snapshot = ss.GetSnapShot(); - return storage_->Get(read_options, metadata_cf_handle_, ns_key, bytes); +rocksdb::Status Database::GetRawMetadata(GetOptions options, const Slice &ns_key, std::string *bytes) { + rocksdb::ReadOptions opts; + // If options.snapshot == nullptr, we can avoid allocating a snapshot here. + opts.snapshot = options.snapshot; + return storage_->Get(opts, metadata_cf_handle_, ns_key, bytes); } rocksdb::Status Database::Expire(const Slice &user_key, uint64_t timestamp) { @@ -204,25 +210,12 @@ rocksdb::Status Database::MDel(const std::vector &keys, uint64_t *deleted } rocksdb::Status Database::Exists(const std::vector &keys, int *ret) { - *ret = 0; - LatestSnapShot ss(storage_); - rocksdb::ReadOptions read_options; - read_options.snapshot = ss.GetSnapShot(); - - rocksdb::Status s; - std::string value; + std::vector ns_keys; + ns_keys.reserve(keys.size()); for (const auto &key : keys) { - std::string ns_key = AppendNamespacePrefix(key); - s = storage_->Get(read_options, metadata_cf_handle_, ns_key, &value); - if (!s.ok() && !s.IsNotFound()) return s; - if (s.ok()) { - Metadata metadata(kRedisNone, false); - s = metadata.Decode(value); - if (!s.ok()) return s; - if (!metadata.Expired()) *ret += 1; - } + ns_keys.emplace_back(AppendNamespacePrefix(key)); } - return rocksdb::Status::OK(); + return existsInternal(ns_keys, ret); } rocksdb::Status Database::TTL(const Slice &user_key, int64_t *ttl) { @@ -247,7 +240,7 @@ rocksdb::Status Database::TTL(const Slice &user_key, int64_t *ttl) { rocksdb::Status Database::GetExpireTime(const Slice &user_key, uint64_t *timestamp) { std::string ns_key = AppendNamespacePrefix(user_key); Metadata metadata(kRedisNone, false); - auto s = GetMetadata(RedisTypes::All(), ns_key, &metadata); + auto s = GetMetadata(GetOptions{}, RedisTypes::All(), ns_key, &metadata); if (!s.ok()) return s; *timestamp = metadata.expire; @@ -323,7 +316,7 @@ rocksdb::Status Database::Keys(const std::string &prefix, std::vector *keys, std::string *end_cursor) { + std::vector *keys, std::string *end_cursor, RedisType type) { end_cursor->clear(); uint64_t cnt = 0; 
uint16_t slot_start = 0; @@ -368,6 +361,8 @@ rocksdb::Status Database::Scan(const std::string &cursor, uint64_t limit, const auto s = metadata.Decode(iter->value()); if (!s.ok()) continue; + if (type != kRedisNone && type != metadata.Type()) continue; + if (metadata.Expired()) continue; std::tie(std::ignore, user_key) = ExtractNamespaceKey(iter->key(), storage_->IsSlotIdEncoded()); keys->emplace_back(user_key); @@ -438,18 +433,10 @@ rocksdb::Status Database::RandomKey(const std::string &cursor, std::string *key) } rocksdb::Status Database::FlushDB() { - std::string begin_key, end_key; - std::string prefix = ComposeNamespaceKey(namespace_, "", false); - auto s = FindKeyRangeWithPrefix(prefix, std::string(), &begin_key, &end_key); - if (!s.ok()) { - return rocksdb::Status::OK(); - } - s = storage_->DeleteRange(begin_key, end_key); - if (!s.ok()) { - return s; - } + auto begin_key = ComposeNamespaceKey(namespace_, "", false); + auto end_key = util::StringNext(begin_key); - return rocksdb::Status::OK(); + return storage_->DeleteRange(begin_key, end_key); } rocksdb::Status Database::FlushAll() { @@ -466,12 +453,8 @@ rocksdb::Status Database::FlushAll() { if (!iter->Valid()) { return rocksdb::Status::OK(); } - auto last_key = iter->key().ToString(); - auto s = storage_->DeleteRange(first_key, last_key); - if (!s.ok()) { - return s; - } - return rocksdb::Status::OK(); + auto last_key = util::StringNext(iter->key().ToString()); + return storage_->DeleteRange(first_key, last_key); } rocksdb::Status Database::Dump(const Slice &user_key, std::vector *infos) { @@ -517,7 +500,7 @@ rocksdb::Status Database::Dump(const Slice &user_key, std::vector * if (metadata.Type() == kRedisList) { ListMetadata list_metadata(false); - s = GetMetadata({kRedisList}, ns_key, &list_metadata); + s = GetMetadata(GetOptions{}, {kRedisList}, ns_key, &list_metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; infos->emplace_back("head"); infos->emplace_back(std::to_string(list_metadata.head)); @@ -528,80 +511,15 @@ rocksdb::Status Database::Dump(const Slice &user_key, std::vector * return rocksdb::Status::OK(); } -rocksdb::Status Database::Type(const Slice &user_key, RedisType *type) { - std::string ns_key = AppendNamespacePrefix(user_key); - - *type = kRedisNone; - LatestSnapShot ss(storage_); - rocksdb::ReadOptions read_options; - read_options.snapshot = ss.GetSnapShot(); - std::string value; - rocksdb::Status s = storage_->Get(read_options, metadata_cf_handle_, ns_key, &value); - if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; - - Metadata metadata(kRedisNone, false); - s = metadata.Decode(value); - if (!s.ok()) return s; - if (metadata.Expired()) { - *type = kRedisNone; - } else { - *type = metadata.Type(); - } - return rocksdb::Status::OK(); +rocksdb::Status Database::Type(const Slice &key, RedisType *type) { + std::string ns_key = AppendNamespacePrefix(key); + return typeInternal(ns_key, type); } std::string Database::AppendNamespacePrefix(const Slice &user_key) { return ComposeNamespaceKey(namespace_, user_key, storage_->IsSlotIdEncoded()); } -rocksdb::Status Database::FindKeyRangeWithPrefix(const std::string &prefix, const std::string &prefix_end, - std::string *begin, std::string *end, - rocksdb::ColumnFamilyHandle *cf_handle) { - if (cf_handle == nullptr) { - cf_handle = metadata_cf_handle_; - } - if (prefix.empty()) { - return rocksdb::Status::NotFound(); - } - begin->clear(); - end->clear(); - - LatestSnapShot ss(storage_); - rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - read_options.snapshot = ss.GetSnapShot(); - auto iter = util::UniqueIterator(storage_, read_options, cf_handle); - iter->Seek(prefix); - if (!iter->Valid() || !iter->key().starts_with(prefix)) { - return rocksdb::Status::NotFound(); - } - *begin = iter->key().ToString(); - - // it's ok to increase the last char in prefix as the boundary of the prefix - // while we limit the namespace last char shouldn't be larger than 128. - std::string next_prefix; - if (!prefix_end.empty()) { - next_prefix = prefix_end; - } else { - next_prefix = prefix; - char last_char = next_prefix.back(); - last_char++; - next_prefix.pop_back(); - next_prefix.push_back(last_char); - } - iter->SeekForPrev(next_prefix); - int max_prev_limit = 128; // prevent unpredicted long while loop - int i = 0; - // reversed seek the key til with prefix or end of the iterator - while (i++ < max_prev_limit && iter->Valid() && !iter->key().starts_with(prefix)) { - iter->Prev(); - } - if (!iter->Valid() || !iter->key().starts_with(prefix)) { - return rocksdb::Status::NotFound(); - } - *end = iter->key().ToString(); - return rocksdb::Status::OK(); -} - rocksdb::Status Database::ClearKeysOfSlot(const rocksdb::Slice &ns, int slot) { if (!storage_->IsSlotIdEncoded()) { return rocksdb::Status::Aborted("It is not in cluster mode"); @@ -618,8 +536,7 @@ rocksdb::Status Database::ClearKeysOfSlot(const rocksdb::Slice &ns, int slot) { rocksdb::Status Database::KeyExist(const std::string &key) { int cnt = 0; - std::vector keys; - keys.emplace_back(key); + std::vector keys{key}; auto s = Exists(keys, &cnt); if (!s.ok()) { return s; @@ -636,12 +553,11 @@ rocksdb::Status SubKeyScanner::Scan(RedisType type, const Slice &user_key, const uint64_t cnt = 0; std::string ns_key = AppendNamespacePrefix(user_key); Metadata metadata(type, false); - rocksdb::Status s = GetMetadata({type}, ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, {type}, ns_key, &metadata); if (!s.ok()) return s; - LatestSnapShot ss(storage_); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - read_options.snapshot = ss.GetSnapShot(); auto iter = util::UniqueIterator(storage_, read_options); std::string match_prefix_key = InternalKey(ns_key, subkey_prefix, metadata.version, storage_->IsSlotIdEncoded()).Encode(); @@ -699,28 +615,71 @@ Status WriteBatchLogData::Decode(const rocksdb::Slice &blob) { return Status::OK(); } -rocksdb::Status Database::Rename(const std::string &key, const std::string 
&new_key, bool nx, bool *ret) { - *ret = true; - std::string ns_key = AppendNamespacePrefix(key); - std::string new_ns_key = AppendNamespacePrefix(new_key); +rocksdb::Status Database::existsInternal(const std::vector &keys, int *ret) { + *ret = 0; + LatestSnapShot ss(storage_); + rocksdb::ReadOptions read_options; + read_options.snapshot = ss.GetSnapShot(); + + rocksdb::Status s; + std::string value; + for (const auto &key : keys) { + s = storage_->Get(read_options, metadata_cf_handle_, key, &value); + if (!s.ok() && !s.IsNotFound()) return s; + if (s.ok()) { + Metadata metadata(kRedisNone, false); + s = metadata.Decode(value); + if (!s.ok()) return s; + if (!metadata.Expired()) *ret += 1; + } + } + return rocksdb::Status::OK(); +} + +rocksdb::Status Database::typeInternal(const Slice &key, RedisType *type) { + *type = kRedisNone; + LatestSnapShot ss(storage_); + rocksdb::ReadOptions read_options; + read_options.snapshot = ss.GetSnapShot(); + std::string value; + rocksdb::Status s = storage_->Get(read_options, metadata_cf_handle_, key, &value); + if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; - std::vector lock_keys = {ns_key, new_ns_key}; + Metadata metadata(kRedisNone, false); + s = metadata.Decode(value); + if (!s.ok()) return s; + if (metadata.Expired()) { + *type = kRedisNone; + } else { + *type = metadata.Type(); + } + return rocksdb::Status::OK(); +} + +rocksdb::Status Database::Copy(const std::string &key, const std::string &new_key, bool nx, bool delete_old, + CopyResult *res) { + std::vector lock_keys = {key, new_key}; MultiLockGuard guard(storage_->GetLockManager(), lock_keys); RedisType type = kRedisNone; - auto s = Type(key, &type); + auto s = typeInternal(key, &type); if (!s.ok()) return s; - if (type == kRedisNone) return rocksdb::Status::InvalidArgument("ERR no such key"); + if (type == kRedisNone) { + *res = CopyResult::KEY_NOT_EXIST; + return rocksdb::Status::OK(); + } if (nx) { int exist = 0; - if (s = Exists({new_key}, &exist), !s.ok()) return s; + if (s = existsInternal({new_key}, &exist), !s.ok()) return s; if (exist > 0) { - *ret = false; + *res = CopyResult::KEY_ALREADY_EXIST; return rocksdb::Status::OK(); } } + *res = CopyResult::DONE; + if (key == new_key) return rocksdb::Status::OK(); auto batch = storage_->GetWriteBatchBase(); @@ -728,21 +687,23 @@ rocksdb::Status Database::Rename(const std::string &key, const std::string &new_ batch->PutLogData(log_data.Encode()); engine::DBIterator iter(storage_, rocksdb::ReadOptions()); - iter.Seek(ns_key); + iter.Seek(key); + if (delete_old) { + batch->Delete(metadata_cf_handle_, key); + } // copy metadata - batch->Delete(metadata_cf_handle_, ns_key); - batch->Put(metadata_cf_handle_, new_ns_key, iter.Value()); + batch->Put(metadata_cf_handle_, new_key, iter.Value()); auto subkey_iter = iter.GetSubKeyIterator(); if (subkey_iter != nullptr) { - auto zset_score_cf = type == kRedisZSet ? storage_->GetCFHandle(engine::kZSetScoreColumnFamilyName) : nullptr; + auto zset_score_cf = type == kRedisZSet ? 
storage_->GetCFHandle(ColumnFamilyID::SecondarySubkey) : nullptr; for (subkey_iter->Seek(); subkey_iter->Valid(); subkey_iter->Next()) { InternalKey from_ikey(subkey_iter->Key(), storage_->IsSlotIdEncoded()); std::string to_ikey = - InternalKey(new_ns_key, from_ikey.GetSubKey(), from_ikey.GetVersion(), storage_->IsSlotIdEncoded()).Encode(); + InternalKey(new_key, from_ikey.GetSubKey(), from_ikey.GetVersion(), storage_->IsSlotIdEncoded()).Encode(); // copy sub key batch->Put(subkey_iter->ColumnFamilyHandle(), to_ikey, subkey_iter->Value()); @@ -753,7 +714,7 @@ rocksdb::Status Database::Rename(const std::string &key, const std::string &new_ score_bytes.append(from_ikey.GetSubKey().ToString()); // copy score key std::string score_key = - InternalKey(new_ns_key, score_bytes, from_ikey.GetVersion(), storage_->IsSlotIdEncoded()).Encode(); + InternalKey(new_key, score_bytes, from_ikey.GetVersion(), storage_->IsSlotIdEncoded()).Encode(); batch->Put(zset_score_cf, score_key, Slice()); } } @@ -761,4 +722,212 @@ rocksdb::Status Database::Rename(const std::string &key, const std::string &new_ return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); } + +std::optional Database::lookupKeyByPattern(const std::string &pattern, const std::string &subst) { + if (pattern == "#") { + return subst; + } + + auto match_pos = pattern.find('*'); + if (match_pos == std::string::npos) { + return std::nullopt; + } + + // hash field + std::string field; + auto arrow_pos = pattern.find("->", match_pos + 1); + if (arrow_pos != std::string::npos && arrow_pos + 2 < pattern.size()) { + field = pattern.substr(arrow_pos + 2); + } + + std::string key = pattern.substr(0, match_pos + 1); + key.replace(match_pos, 1, subst); + + std::string value; + RedisType type = RedisType::kRedisNone; + if (!field.empty()) { + auto hash_db = redis::Hash(storage_, namespace_); + if (auto s = hash_db.Type(key, &type); !s.ok() || type != RedisType::kRedisHash) { + return std::nullopt; + } + + if (auto s = hash_db.Get(key, field, &value); !s.ok()) { + return std::nullopt; + } + } else { + auto string_db = redis::String(storage_, namespace_); + if (auto s = string_db.Type(key, &type); !s.ok() || type != RedisType::kRedisString) { + return std::nullopt; + } + if (auto s = string_db.Get(key, &value); !s.ok()) { + return std::nullopt; + } + } + return value; +} + +rocksdb::Status Database::Sort(RedisType type, const std::string &key, const SortArgument &args, + std::vector> *elems, SortResult *res) { + // Obtain the length of the object to sort. + const std::string ns_key = AppendNamespacePrefix(key); + Metadata metadata(type, false); + auto s = GetMetadata(GetOptions{}, {type}, ns_key, &metadata); + if (!s.ok()) return s; + + if (metadata.size > SORT_LENGTH_LIMIT) { + *res = SortResult::LIMIT_EXCEEDED; + return rocksdb::Status::OK(); + } + auto vectorlen = static_cast(metadata.size); + + // Adjust the offset and count of the limit + int offset = args.offset >= vectorlen ? 0 : std::clamp(args.offset, 0, vectorlen - 1); + int count = args.offset >= vectorlen ? 
0 : std::clamp(args.count, -1, vectorlen - offset); + if (count == -1) count = vectorlen - offset; + + // Get the elements that need to be sorted + std::vector str_vec; + if (count != 0) { + if (type == RedisType::kRedisList) { + auto list_db = redis::List(storage_, namespace_); + + if (args.dontsort) { + if (args.desc) { + s = list_db.Range(key, -count - offset, -1 - offset, &str_vec); + if (!s.ok()) return s; + std::reverse(str_vec.begin(), str_vec.end()); + } else { + s = list_db.Range(key, offset, offset + count - 1, &str_vec); + if (!s.ok()) return s; + } + } else { + s = list_db.Range(key, 0, -1, &str_vec); + if (!s.ok()) return s; + } + } else if (type == RedisType::kRedisSet) { + auto set_db = redis::Set(storage_, namespace_); + s = set_db.Members(key, &str_vec); + if (!s.ok()) return s; + + if (args.dontsort) { + str_vec = std::vector(std::make_move_iterator(str_vec.begin() + offset), + std::make_move_iterator(str_vec.begin() + offset + count)); + } + } else if (type == RedisType::kRedisZSet) { + auto zset_db = redis::ZSet(storage_, namespace_); + std::vector member_scores; + + if (args.dontsort) { + RangeRankSpec spec; + spec.start = offset; + spec.stop = offset + count - 1; + spec.reversed = args.desc; + s = zset_db.RangeByRank(key, spec, &member_scores, nullptr); + if (!s.ok()) return s; + + for (auto &member_score : member_scores) { + str_vec.emplace_back(std::move(member_score.member)); + } + } else { + s = zset_db.GetAllMemberScores(key, &member_scores); + if (!s.ok()) return s; + + for (auto &member_score : member_scores) { + str_vec.emplace_back(std::move(member_score.member)); + } + } + } else { + *res = SortResult::UNKNOWN_TYPE; + return s; + } + } + + std::vector sort_vec(str_vec.size()); + for (size_t i = 0; i < str_vec.size(); ++i) { + sort_vec[i].obj = str_vec[i]; + } + + // Sort by BY, ALPHA, ASC/DESC + if (!args.dontsort) { + for (size_t i = 0; i < sort_vec.size(); ++i) { + std::string byval; + if (!args.sortby.empty()) { + auto lookup = lookupKeyByPattern(args.sortby, str_vec[i]); + if (!lookup.has_value()) continue; + byval = std::move(lookup.value()); + } else { + byval = str_vec[i]; + } + + if (args.alpha && !args.sortby.empty()) { + sort_vec[i].v = byval; + } else if (!args.alpha && !byval.empty()) { + auto double_byval = ParseFloat(byval); + if (!double_byval) { + *res = SortResult::DOUBLE_CONVERT_ERROR; + return rocksdb::Status::OK(); + } + sort_vec[i].v = *double_byval; + } + } + + std::sort(sort_vec.begin(), sort_vec.end(), [&args](const RedisSortObject &a, const RedisSortObject &b) { + return RedisSortObject::SortCompare(a, b, args); + }); + + // Gets the element specified by Limit + if (offset != 0 || count != vectorlen) { + sort_vec = std::vector(std::make_move_iterator(sort_vec.begin() + offset), + std::make_move_iterator(sort_vec.begin() + offset + count)); + } + } + + // Perform storage + for (auto &elem : sort_vec) { + if (args.getpatterns.empty()) { + elems->emplace_back(elem.obj); + } + for (const std::string &pattern : args.getpatterns) { + std::optional val = lookupKeyByPattern(pattern, elem.obj); + if (val.has_value()) { + elems->emplace_back(val.value()); + } else { + elems->emplace_back(std::nullopt); + } + } + } + + if (!args.storekey.empty()) { + std::vector store_elems; + store_elems.reserve(elems->size()); + for (const auto &e : *elems) { + store_elems.emplace_back(e.value_or("")); + } + redis::List list_db(storage_, namespace_); + s = list_db.Trim(args.storekey, -1, 0); + if (!s.ok()) return s; + uint64_t new_size = 0; + s = 
list_db.Push(args.storekey, std::vector<std::string>(store_elems.cbegin(), store_elems.cend()), false, &new_size); + if (!s.ok()) return s; + } + + return rocksdb::Status::OK(); +} + +bool RedisSortObject::SortCompare(const RedisSortObject &a, const RedisSortObject &b, const SortArgument &args) { + if (!args.alpha) { + double score_a = std::get<double>(a.v); + double score_b = std::get<double>(b.v); + return !args.desc ? score_a < score_b : score_a > score_b; + } else { + if (!args.sortby.empty()) { + std::string cmp_a = std::get<std::string>(a.v); + std::string cmp_b = std::get<std::string>(b.v); + return !args.desc ? cmp_a < cmp_b : cmp_a > cmp_b; + } else { + return !args.desc ? a.obj < b.obj : a.obj > b.obj; + } + } +} + } // namespace redis diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 0627b651702..12d9896b5ee 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -21,31 +21,108 @@ #pragma once #include +#include #include #include +#include #include #include "redis_metadata.h" +#include "server/redis_reply.h" #include "storage.h" namespace redis { + +/// SORT_LENGTH_LIMIT limits the number of elements to be sorted +/// to avoid using too much memory and causing system crashes. +/// TODO: expand or eliminate SORT_LENGTH_LIMIT +/// through better mechanisms such as memory restriction logic. +constexpr uint64_t SORT_LENGTH_LIMIT = 512; + +struct SortArgument { + std::string sortby; // BY + bool dontsort = false; // DONT SORT + int offset = 0; // LIMIT OFFSET + int count = -1; // LIMIT COUNT + std::vector<std::string> getpatterns; // GET + bool desc = false; // ASC/DESC + bool alpha = false; // ALPHA + std::string storekey; // STORE +}; + +struct RedisSortObject { + std::string obj; + std::variant<double, std::string> v; + + /// SortCompare is a helper function that enables `RedisSortObject` to be sorted based on `SortArgument`. + /// + /// It can assist in implementing the third parameter `Compare comp` required by `std::sort` + /// + /// \param args The basis used to compare two RedisSortObjects. + /// If `args.alpha` is false, `RedisSortObject.v` will be taken as a double for comparison + /// If `args.alpha` is true and `args.sortby` is not empty, `RedisSortObject.v` will be taken as a string for comparison + /// If `args.alpha` is true and `args.sortby` is empty, the comparison is by `RedisSortObject.obj`. + /// + /// \return If `desc` is false, returns true when `a < b`, otherwise returns true when `a > b` + static bool SortCompare(const RedisSortObject &a, const RedisSortObject &b, const SortArgument &args); +}; + +/// Database is a wrapper of the underlying storage engine; it provides +/// some common operations for redis commands. class Database { public: static constexpr uint64_t RANDOM_KEY_SCAN_LIMIT = 60; + struct GetOptions { + // If snapshot is not nullptr, read from the specified snapshot, + // otherwise read from the "latest" snapshot. + const rocksdb::Snapshot *snapshot = nullptr; + + GetOptions() = default; + explicit GetOptions(const rocksdb::Snapshot *ss) : snapshot(ss) {} + }; + explicit Database(engine::Storage *storage, std::string ns = ""); + /// Parse metadata matching one of the given `types` from bytes; `Metadata` is the base class of all metadata. + /// The bytes are consumed while parsing.
[[nodiscard]] rocksdb::Status ParseMetadata(RedisTypes types, Slice *bytes, Metadata *metadata); - [[nodiscard]] rocksdb::Status GetMetadata(RedisTypes types, const Slice &ns_key, Metadata *metadata); - [[nodiscard]] rocksdb::Status GetMetadata(RedisTypes types, const Slice &ns_key, std::string *raw_value, - Metadata *metadata, Slice *rest); - [[nodiscard]] rocksdb::Status GetRawMetadata(const Slice &ns_key, std::string *bytes); + /// GetMetadata is a helper function to get metadata from the database. It will read the "raw metadata" + /// from the underlying storage, and then parse the raw metadata into the specified metadata type. + /// + /// \param options The read options, including whether to use a snapshot when reading the metadata. + /// \param types The candidate types of the metadata. + /// \param ns_key The key with namespace of the metadata. + /// \param metadata The output metadata. + [[nodiscard]] rocksdb::Status GetMetadata(GetOptions options, RedisTypes types, const Slice &ns_key, + Metadata *metadata); + /// GetMetadata is a helper function to get metadata from the database. It will read the "raw metadata" + /// from the underlying storage, and then parse the raw metadata into the specified metadata type. + /// + /// Compared with the above function, this function will also return the rest of the bytes + /// after parsing the metadata. + /// + /// \param options The read options, including whether to use a snapshot when reading the metadata. + /// \param types The candidate types of the metadata. + /// \param ns_key The key with namespace of the metadata. + /// \param raw_value Holds the raw metadata. + /// \param metadata The output metadata. + /// \param rest The rest of the bytes after parsing the metadata. + [[nodiscard]] rocksdb::Status GetMetadata(GetOptions options, RedisTypes types, const Slice &ns_key, + std::string *raw_value, Metadata *metadata, Slice *rest); + /// GetRawMetadata is a helper function to get the "raw metadata" from the database without parsing + /// it into the specified metadata type. + /// + /// \param options The read options, including whether to use a snapshot when reading the metadata. + /// \param ns_key The key with namespace of the metadata. + /// \param bytes The output raw metadata.
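+ ///
+ /// For example (a sketch only; reading under a snapshot the way the bitmap callers below do):
+ ///
+ ///   LatestSnapShot ss(storage_);
+ ///   std::string bytes;
+ ///   auto s = GetRawMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &bytes);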
+ [[nodiscard]] rocksdb::Status GetRawMetadata(GetOptions options, const Slice &ns_key, std::string *bytes); [[nodiscard]] rocksdb::Status Expire(const Slice &user_key, uint64_t timestamp); [[nodiscard]] rocksdb::Status Del(const Slice &user_key); [[nodiscard]] rocksdb::Status MDel(const std::vector<Slice> &keys, uint64_t *deleted_cnt); [[nodiscard]] rocksdb::Status Exists(const std::vector<Slice> &keys, int *ret); [[nodiscard]] rocksdb::Status TTL(const Slice &user_key, int64_t *ttl); [[nodiscard]] rocksdb::Status GetExpireTime(const Slice &user_key, uint64_t *timestamp); - [[nodiscard]] rocksdb::Status Type(const Slice &user_key, RedisType *type); + [[nodiscard]] rocksdb::Status Type(const Slice &key, RedisType *type); [[nodiscard]] rocksdb::Status Dump(const Slice &user_key, std::vector<std::string> *infos); [[nodiscard]] rocksdb::Status FlushDB(); [[nodiscard]] rocksdb::Status FlushAll(); @@ -53,15 +130,28 @@ class Database { [[nodiscard]] rocksdb::Status Keys(const std::string &prefix, std::vector<std::string> *keys = nullptr, KeyNumStats *stats = nullptr); [[nodiscard]] rocksdb::Status Scan(const std::string &cursor, uint64_t limit, const std::string &prefix, - std::vector<std::string> *keys, std::string *end_cursor = nullptr); + std::vector<std::string> *keys, std::string *end_cursor = nullptr, + RedisType type = kRedisNone); [[nodiscard]] rocksdb::Status RandomKey(const std::string &cursor, std::string *key); std::string AppendNamespacePrefix(const Slice &user_key); - [[nodiscard]] rocksdb::Status FindKeyRangeWithPrefix(const std::string &prefix, const std::string &prefix_end, - std::string *begin, std::string *end, - rocksdb::ColumnFamilyHandle *cf_handle = nullptr); [[nodiscard]] rocksdb::Status ClearKeysOfSlot(const rocksdb::Slice &ns, int slot); [[nodiscard]] rocksdb::Status KeyExist(const std::string &key); - [[nodiscard]] rocksdb::Status Rename(const std::string &key, const std::string &new_key, bool nx, bool *ret); + + // Copy `key` to `new_key` (both are already internal keys) + enum class CopyResult { KEY_NOT_EXIST, KEY_ALREADY_EXIST, DONE }; + [[nodiscard]] rocksdb::Status Copy(const std::string &key, const std::string &new_key, bool nx, bool delete_old, + CopyResult *res); + enum class SortResult { UNKNOWN_TYPE, DOUBLE_CONVERT_ERROR, LIMIT_EXCEEDED, DONE }; + /// Sort sorts keys of the specified type according to `SortArgument` + /// + /// \param type is the type of the key to sort, which must be LIST, SET or ZSET + /// \param key is the key to be sorted + /// \param args provides the parameters to sort by + /// \param elems contains the sorted results + /// \param res represents the result type of the sort. + /// When the returned status is not ok, `res` should not be checked; otherwise, check whether `res` is `DONE` + [[nodiscard]] rocksdb::Status Sort(RedisType type, const std::string &key, const SortArgument &args, + std::vector<std::optional<std::string>> *elems, SortResult *res); protected: engine::Storage *storage_; @@ -69,8 +159,27 @@ class Database { std::string namespace_; friend class LatestSnapShot; -}; + private: + // The keys passed to these helpers are already internal keys + [[nodiscard]] rocksdb::Status existsInternal(const std::vector<std::string> &keys, int *ret); + [[nodiscard]] rocksdb::Status typeInternal(const Slice &key, RedisType *type); + + /// lookupKeyByPattern is a helper function of `Sort` to support the `GET` and `BY` fields. + /// + /// \param pattern can be the value of a `BY` or `GET` field + /// \param subst is used to replace the "*" or "#" matched in the pattern string.
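+ /// For example, the pattern "weight_*" with subst "foo" reads the string key "weight_foo",
+ /// while "weight_*->rank" reads the field "rank" of the hash stored at "weight_foo".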
+ /// \return Returns the value associated to the key with a name obtained using the following rules: + /// 1) The first occurrence of '*' in 'pattern' is substituted with 'subst'. + /// 2) If 'pattern' matches the "->" string, everything on the right of + /// the arrow is treated as the name of a hash field, and the part on the + /// left as the key name containing a hash. The value of the specified + /// field is returned. + /// 3) If 'pattern' equals "#", the function simply returns 'subst' itself so + /// that the SORT command can be used like: SORT key GET # to retrieve + /// the Set/List elements directly. + std::optional<std::string> lookupKeyByPattern(const std::string &pattern, const std::string &subst); +}; class LatestSnapShot { public: explicit LatestSnapShot(engine::Storage *storage) : storage_(storage), snapshot_(storage_->GetDB()->GetSnapshot()) {} diff --git a/src/storage/redis_metadata.cc b/src/storage/redis_metadata.cc index 5e872af0eda..e44b39cad7c 100644 --- a/src/storage/redis_metadata.cc +++ b/src/storage/redis_metadata.cc @@ -96,6 +96,7 @@ std::string InternalKey::Encode() const { } bool InternalKey::operator==(const InternalKey &that) const { + if (namespace_ != that.namespace_) return false; if (key_ != that.key_) return false; if (sub_key_ != that.sub_key_) return false; return version_ == that.version_; @@ -471,21 +472,3 @@ rocksdb::Status JsonMetadata::Decode(Slice *input) { return rocksdb::Status::OK(); } - -void SearchMetadata::Encode(std::string *dst) const { - Metadata::Encode(dst); - - PutFixed8(dst, uint8_t(on_data_type)); -} - -rocksdb::Status SearchMetadata::Decode(Slice *input) { - if (auto s = Metadata::Decode(input); !s.ok()) { - return s; - } - - if (!GetFixed8(input, reinterpret_cast<uint8_t *>(&on_data_type))) { - return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); - } - - return rocksdb::Status::OK(); -} diff --git a/src/storage/redis_metadata.h b/src/storage/redis_metadata.h index dac2d0e16bb..68f36b2c994 100644 --- a/src/storage/redis_metadata.h +++ b/src/storage/redis_metadata.h @@ -49,7 +49,6 @@ enum RedisType : uint8_t { kRedisStream = 8, kRedisBloomFilter = 9, kRedisJson = 10, - kRedisSearch = 11, }; struct RedisTypes { @@ -65,7 +64,7 @@ struct RedisTypes { return RedisTypes(types); } - bool Contains(RedisType type) { return types_[type]; } + bool Contains(RedisType type) const { return types_[type]; } private: using UnderlyingType = std::bitset<128>; @@ -314,18 +313,3 @@ class JsonMetadata : public Metadata { void Encode(std::string *dst) const override; rocksdb::Status Decode(Slice *input) override; }; - -enum class SearchOnDataType : uint8_t { - HASH = kRedisHash, - JSON = kRedisJson, -}; - -class SearchMetadata : public Metadata { - public: - SearchOnDataType on_data_type; - - explicit SearchMetadata(bool generate_version = true) : Metadata(kRedisSearch, generate_version) {} - - void Encode(std::string *dst) const override; - rocksdb::Status Decode(Slice *input) override; -}; diff --git a/src/storage/redis_pubsub.cc b/src/storage/redis_pubsub.cc index 52264ff9203..6ca153b70be 100644 --- a/src/storage/redis_pubsub.cc +++ b/src/storage/redis_pubsub.cc @@ -23,6 +23,9 @@ namespace redis { rocksdb::Status PubSub::Publish(const Slice &channel, const Slice &value) { + if (storage_->GetConfig()->IsSlave()) { + return rocksdb::Status::NotSupported("can't publish to db in slave mode"); + } auto batch = storage_->GetWriteBatchBase(); batch->Put(pubsub_cf_handle_, channel, value); return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); 
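A small illustration of the guard added above (a sketch against this patch, not part of the diff; `storage` is a hypothetical engine::Storage belonging to a replica instance):

// On a replica, Publish() now fails before anything reaches the pubsub column family.
redis::PubSub pubsub(storage);  // storage->GetConfig()->IsSlave() == true
auto s = pubsub.Publish("news", "hello");
// s.IsNotSupported() == true, with the message "can't publish to db in slave mode"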
diff --git a/src/storage/redis_pubsub.h b/src/storage/redis_pubsub.h index 995560aaa1f..d6e3d392f40 100644 --- a/src/storage/redis_pubsub.h +++ b/src/storage/redis_pubsub.h @@ -29,7 +29,8 @@ namespace redis { class PubSub : public Database { public: - explicit PubSub(engine::Storage *storage) : Database(storage), pubsub_cf_handle_(storage->GetCFHandle("pubsub")) {} + explicit PubSub(engine::Storage *storage) + : Database(storage), pubsub_cf_handle_(storage->GetCFHandle(ColumnFamilyID::PubSub)) {} rocksdb::Status Publish(const Slice &channel, const Slice &value); private: diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 65d8179d500..5197eb5c1b1 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -30,6 +30,7 @@ #include #include "commands/commander.h" +#include "commands/error_constants.h" #include "db_util.h" #include "fmt/format.h" #include "lua.h" @@ -442,7 +443,7 @@ Status FunctionList(Server *srv, const redis::Connection *conn, const std::strin rocksdb::Slice upper_bound(end_key); read_options.iterate_upper_bound = &upper_bound; - auto *cf = srv->storage->GetCFHandle(engine::kPropagateColumnFamilyName); + auto *cf = srv->storage->GetCFHandle(ColumnFamilyID::Propagate); auto iter = util::UniqueIterator(srv->storage, read_options, cf); std::vector> result; for (iter->Seek(start_key); iter->Valid(); iter->Next()) { @@ -478,7 +479,7 @@ Status FunctionListFunc(Server *srv, const redis::Connection *conn, const std::s rocksdb::Slice upper_bound(end_key); read_options.iterate_upper_bound = &upper_bound; - auto *cf = srv->storage->GetCFHandle(engine::kPropagateColumnFamilyName); + auto *cf = srv->storage->GetCFHandle(ColumnFamilyID::Propagate); auto iter = util::UniqueIterator(srv->storage, read_options, cf); std::vector> result; for (iter->Seek(start_key); iter->Valid(); iter->Next()) { @@ -556,7 +557,7 @@ Status FunctionDelete(Server *srv, const std::string &name) { } auto storage = srv->storage; - auto cf = storage->GetCFHandle(engine::kPropagateColumnFamilyName); + auto cf = storage->GetCFHandle(ColumnFamilyID::Propagate); for (size_t i = 1; i <= lua_objlen(lua, -1); ++i) { lua_rawgeti(lua, -1, static_cast(i)); @@ -610,7 +611,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh auto s = srv->ScriptGet(funcname + 2, &body); if (!s.IsOK()) { lua_pop(lua, 1); /* remove the error handler from the stack. */ - return {Status::NotOK, "NOSCRIPT No matching script. Please use EVAL"}; + return {Status::RedisNoScript, redis::errNoMatchingScript}; } } else { body = body_or_sha; @@ -638,8 +639,8 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh SetGlobalArray(lua, "ARGV", argv); if (lua_pcall(lua, 0, 1, -2)) { - auto msg = fmt::format("ERR running script (call to {}): {}", funcname, lua_tostring(lua, -1)); - *output = redis::Error(msg); + auto msg = fmt::format("running script (call to {}): {}", funcname, lua_tostring(lua, -1)); + *output = redis::Error({Status::NotOK, msg}); lua_pop(lua, 2); } else { *output = ReplyToRedisReply(conn, lua); @@ -753,7 +754,7 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { if (config->cluster_enabled) { auto s = srv->cluster->CanExecByMySelf(attributes, args, conn); if (!s.IsOK()) { - PushError(lua, s.Msg().c_str()); + PushError(lua, redis::StatusToRedisErrorMsg(s).c_str()); return raise_error ? 
RaiseError(lua) : 1; } } @@ -1190,7 +1191,7 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) { lua_rawget(lua, -2); t = lua_type(lua, -1); if (t == LUA_TSTRING) { - output = redis::Error(lua_tostring(lua, -1)); + output = redis::Error({Status::RedisErrorNoPrefix, lua_tostring(lua, -1)}); lua_pop(lua, 1); return output; } diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 94a6e64da8d..8cdb63bf16f 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -67,7 +67,7 @@ constexpr double kRocksdbLRUBlockCacheHighPriPoolRatio = 0.75; constexpr double kRocksdbLRURowCacheHighPriPoolRatio = 0.5; // used in creating rocksdb::HyperClockCache, set`estimated_entry_charge` to 0 means let rocksdb dynamically and -// automacally adjust the table size for the cache. +// automatically adjust the table size for the cache. constexpr size_t kRockdbHCCAutoAdjustCharge = 0; const int64_t kIORateLimitMaxMb = 1024000; @@ -75,7 +75,7 @@ const int64_t kIORateLimitMaxMb = 1024000; using rocksdb::Slice; Storage::Storage(Config *config) - : backup_creating_time_(util::GetTimeStamp()), + : backup_creating_time_secs_(util::GetTimeStamp()), env_(rocksdb::Env::Default()), config_(config), lock_mgr_(16), @@ -101,11 +101,11 @@ void Storage::CloseDB() { } void Storage::SetWriteOptions(const Config::RocksDB::WriteOptions &config) { - write_opts_.sync = config.sync; - write_opts_.disableWAL = config.disable_wal; - write_opts_.no_slowdown = config.no_slowdown; - write_opts_.low_pri = config.low_pri; - write_opts_.memtable_insert_hint_per_batch = config.memtable_insert_hint_per_batch; + default_write_opts_.sync = config.sync; + default_write_opts_.disableWAL = config.disable_wal; + default_write_opts_.no_slowdown = config.no_slowdown; + default_write_opts_.low_pri = config.low_pri; + default_write_opts_.memtable_insert_hint_per_batch = config.memtable_insert_hint_per_batch; } rocksdb::ReadOptions Storage::DefaultScanOptions() const { @@ -165,6 +165,7 @@ rocksdb::Options Storage::InitRocksDBOptions() { options.min_write_buffer_number_to_merge = 2; options.write_buffer_size = config_->rocks_db.write_buffer_size * MiB; options.num_levels = 7; + options.compression_opts.level = config_->rocks_db.compression_level; options.compression_per_level.resize(options.num_levels); // only compress levels >= 2 for (int i = 0; i < options.num_levels; ++i) { @@ -235,11 +236,12 @@ Status Storage::CreateColumnFamilies(const rocksdb::Options &options) { rocksdb::ColumnFamilyOptions cf_options(options); auto res = util::DBOpen(options, config_->db_dir); if (res) { - std::vector cf_names = {kMetadataColumnFamilyName, kZSetScoreColumnFamilyName, - kPubSubColumnFamilyName, kPropagateColumnFamilyName, - kStreamColumnFamilyName, kSearchColumnFamilyName}; + std::vector cf_names_except_default; + for (const auto &cf : ColumnFamilyConfigs::ListColumnFamiliesWithoutDefault()) { + cf_names_except_default.emplace_back(cf.Name()); + } std::vector cf_handles; - auto s = (*res)->CreateColumnFamilies(cf_options, cf_names, &cf_handles); + auto s = (*res)->CreateColumnFamilies(cf_options, cf_names_except_default, &cf_handles); if (!s.ok()) { return {Status::DBOpenErr, s.ToString()}; } @@ -306,7 +308,7 @@ Status Storage::Open(DBOpenMode mode) { metadata_opts.memtable_whole_key_filtering = true; metadata_opts.memtable_prefix_bloom_size_ratio = 0.1; metadata_opts.table_properties_collector_factories.emplace_back( - NewCompactOnExpiredTableCollectorFactory(kMetadataColumnFamilyName, 0.3)); + 
NewCompactOnExpiredTableCollectorFactory(std::string(kMetadataColumnFamilyName), 0.3)); SetBlobDB(&metadata_opts); rocksdb::BlockBasedTableOptions subkey_table_opts = InitTableOptions(); @@ -319,7 +321,7 @@ Status Storage::Open(DBOpenMode mode) { subkey_opts.compaction_filter_factory = std::make_shared(this); subkey_opts.disable_auto_compactions = config_->rocks_db.disable_auto_compactions; subkey_opts.table_properties_collector_factories.emplace_back( - NewCompactOnExpiredTableCollectorFactory(kSubkeyColumnFamilyName, 0.3)); + NewCompactOnExpiredTableCollectorFactory(std::string(kPrimarySubkeyColumnFamilyName), 0.3)); SetBlobDB(&subkey_opts); rocksdb::BlockBasedTableOptions pubsub_table_opts = InitTableOptions(); @@ -336,15 +338,22 @@ Status Storage::Open(DBOpenMode mode) { propagate_opts.disable_auto_compactions = config_->rocks_db.disable_auto_compactions; SetBlobDB(&propagate_opts); + rocksdb::BlockBasedTableOptions search_table_opts = InitTableOptions(); + rocksdb::ColumnFamilyOptions search_opts(options); + search_opts.table_factory.reset(rocksdb::NewBlockBasedTableFactory(search_table_opts)); + search_opts.compaction_filter_factory = std::make_shared(); + search_opts.disable_auto_compactions = config_->rocks_db.disable_auto_compactions; + SetBlobDB(&search_opts); + std::vector column_families; // Caution: don't change the order of column family, or the handle will be mismatched column_families.emplace_back(rocksdb::kDefaultColumnFamilyName, subkey_opts); - column_families.emplace_back(kMetadataColumnFamilyName, metadata_opts); - column_families.emplace_back(kZSetScoreColumnFamilyName, subkey_opts); - column_families.emplace_back(kPubSubColumnFamilyName, pubsub_opts); - column_families.emplace_back(kPropagateColumnFamilyName, propagate_opts); - column_families.emplace_back(kStreamColumnFamilyName, subkey_opts); - column_families.emplace_back(kSearchColumnFamilyName, subkey_opts); + column_families.emplace_back(std::string(kMetadataColumnFamilyName), metadata_opts); + column_families.emplace_back(std::string(kSecondarySubkeyColumnFamilyName), subkey_opts); + column_families.emplace_back(std::string(kPubSubColumnFamilyName), pubsub_opts); + column_families.emplace_back(std::string(kPropagateColumnFamilyName), propagate_opts); + column_families.emplace_back(std::string(kStreamColumnFamilyName), subkey_opts); + column_families.emplace_back(std::string(kSearchColumnFamilyName), search_opts); std::vector old_column_families; auto s = rocksdb::DB::ListColumnFamilies(options, config_->db_dir, &old_column_families); @@ -383,7 +392,7 @@ Status Storage::Open(DBOpenMode mode) { Status Storage::CreateBackup(uint64_t *sequence_number) { LOG(INFO) << "[storage] Start to create new backup"; std::lock_guard lg(config_->backup_mu); - std::string task_backup_dir = config_->GetBackupDir(); + std::string task_backup_dir = config_->backup_dir; std::string tmpdir = task_backup_dir + ".tmp"; // Maybe there is a dirty tmp checkpoint, try to clean it @@ -420,8 +429,8 @@ Status Storage::CreateBackup(uint64_t *sequence_number) { return {Status::NotOK, s.ToString()}; } - // 'backup_mu_' can guarantee 'backup_creating_time_' is thread-safe - backup_creating_time_ = static_cast(util::GetTimeStamp()); + // 'backup_mu_' can guarantee 'backup_creating_time_secs_' is thread-safe + backup_creating_time_secs_ = util::GetTimeStamp(); LOG(INFO) << "[storage] Success to create new backup"; return Status::OK(); @@ -520,6 +529,19 @@ Status Storage::RestoreFromCheckpoint() { return Status::OK(); } +bool Storage::IsEmptyDB() { 
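+  // Only the metadata column family needs to be scanned here: any key whose
+  // metadata decodes successfully and has not yet expired makes the DB
+  // non-empty, so the subkey column families are never touched.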
+ std::unique_ptr<rocksdb::Iterator> iter( + db_->NewIterator(DefaultScanOptions(), GetCFHandle(ColumnFamilyID::Metadata))); + for (iter->SeekToFirst(); iter->Valid(); iter->Next()) { + Metadata metadata(kRedisNone, false); + // If we cannot decode the metadata, we treat the key as alive, so the db is not empty + if (!metadata.Decode(iter->value()).ok() || !metadata.Expired()) { + return false; + } + } + return true; +} + void Storage::EmptyDB() { // Clean old backups and checkpoints PurgeOldBackups(0, 0); @@ -532,22 +554,23 @@ } void Storage::PurgeOldBackups(uint32_t num_backups_to_keep, uint32_t backup_max_keep_hours) { - time_t now = util::GetTimeStamp(); + auto now_secs = util::GetTimeStamp(); std::lock_guard<std::mutex> lg(config_->backup_mu); - std::string task_backup_dir = config_->GetBackupDir(); + std::string task_backup_dir = config_->backup_dir; // Return if there is no backup auto s = env_->FileExists(task_backup_dir); if (!s.ok()) return; // No backup is needed to keep or the backup is expired, we will clean it. - bool backup_expired = (backup_max_keep_hours != 0 && backup_creating_time_ + backup_max_keep_hours * 3600 < now); + bool backup_expired = + (backup_max_keep_hours != 0 && backup_creating_time_secs_ + backup_max_keep_hours * 3600 < now_secs); if (num_backups_to_keep == 0 || backup_expired) { s = rocksdb::DestroyDB(task_backup_dir, rocksdb::Options()); if (s.ok()) { - LOG(INFO) << "[storage] Succeeded cleaning old backup that was created at " << backup_creating_time_; + LOG(INFO) << "[storage] Succeeded cleaning old backup that was created at " << backup_creating_time_secs_; } else { - LOG(INFO) << "[storage] Failed cleaning old backup that was created at " << backup_creating_time_ + LOG(INFO) << "[storage] Failed cleaning old backup that was created at " << backup_creating_time_secs_ + << ". 
Error: " << s.ToString(); } } @@ -661,20 +684,19 @@ rocksdb::Status Storage::Delete(const rocksdb::WriteOptions &options, rocksdb::C return Write(options, batch->GetWriteBatch()); } -rocksdb::Status Storage::DeleteRange(const std::string &first_key, const std::string &last_key) { +rocksdb::Status Storage::DeleteRange(const rocksdb::WriteOptions &options, rocksdb::ColumnFamilyHandle *cf_handle, + Slice begin, Slice end) { auto batch = GetWriteBatchBase(); - rocksdb::ColumnFamilyHandle *cf_handle = GetCFHandle(kMetadataColumnFamilyName); - auto s = batch->DeleteRange(cf_handle, first_key, last_key); + auto s = batch->DeleteRange(cf_handle, begin, end); if (!s.ok()) { return s; } - s = batch->Delete(cf_handle, last_key); - if (!s.ok()) { - return s; - } + return Write(options, batch->GetWriteBatch()); +} - return Write(write_opts_, batch->GetWriteBatch()); +rocksdb::Status Storage::DeleteRange(Slice begin, Slice end) { + return DeleteRange(default_write_opts_, GetCFHandle(ColumnFamilyID::Metadata), begin, end); } rocksdb::Status Storage::FlushScripts(const rocksdb::WriteOptions &options, rocksdb::ColumnFamilyHandle *cf_handle) { @@ -693,7 +715,7 @@ rocksdb::Status Storage::FlushScripts(const rocksdb::WriteOptions &options, rock } Status Storage::ReplicaApplyWriteBatch(std::string &&raw_batch) { - return ApplyWriteBatch(write_opts_, std::move(raw_batch)); + return ApplyWriteBatch(default_write_opts_, std::move(raw_batch)); } Status Storage::ApplyWriteBatch(const rocksdb::WriteOptions &options, std::string &&raw_batch) { @@ -725,23 +747,6 @@ void Storage::RecordStat(StatType type, uint64_t v) { } } -rocksdb::ColumnFamilyHandle *Storage::GetCFHandle(const std::string &name) { - if (name == kMetadataColumnFamilyName) { - return cf_handles_[1]; - } else if (name == kZSetScoreColumnFamilyName) { - return cf_handles_[2]; - } else if (name == kPubSubColumnFamilyName) { - return cf_handles_[3]; - } else if (name == kPropagateColumnFamilyName) { - return cf_handles_[4]; - } else if (name == kStreamColumnFamilyName) { - return cf_handles_[5]; - } else if (name == kSearchColumnFamilyName) { - return cf_handles_[6]; - } - return cf_handles_[0]; -} - rocksdb::ColumnFamilyHandle *Storage::GetCFHandle(ColumnFamilyID id) { return cf_handles_[static_cast(id)]; } rocksdb::Status Storage::Compact(rocksdb::ColumnFamilyHandle *cf, const Slice *begin, const Slice *end) { @@ -763,8 +768,8 @@ uint64_t Storage::GetTotalSize(const std::string &ns) { return sst_file_manager_->GetTotalSize(); } - std::string begin_key, end_key; - std::string prefix = ComposeNamespaceKey(ns, "", false); + auto begin_key = ComposeNamespaceKey(ns, "", false); + auto end_key = util::StringNext(begin_key); redis::Database db(this, ns); uint64_t size = 0, total_size = 0; @@ -772,13 +777,10 @@ uint64_t Storage::GetTotalSize(const std::string &ns) { rocksdb::DB::SizeApproximationFlags::INCLUDE_FILES | rocksdb::DB::SizeApproximationFlags::INCLUDE_MEMTABLES; for (auto cf_handle : cf_handles_) { - if (cf_handle == GetCFHandle(kPubSubColumnFamilyName) || cf_handle == GetCFHandle(kPropagateColumnFamilyName)) { + if (cf_handle == GetCFHandle(ColumnFamilyID::PubSub) || cf_handle == GetCFHandle(ColumnFamilyID::Propagate)) { continue; } - auto s = db.FindKeyRangeWithPrefix(prefix, std::string(), &begin_key, &end_key, cf_handle); - if (!s.ok()) continue; - rocksdb::Range r(begin_key, end_key); db_->GetApproximateSizes(cf_handle, &r, 1, &size, include_both); total_size += size; @@ -831,7 +833,7 @@ Status Storage::CommitTxn() { return Status{Status::NotOK, 
"cannot commit while not in transaction mode"}; } - auto s = writeToDB(write_opts_, txn_write_batch_->GetWriteBatch()); + auto s = writeToDB(default_write_opts_, txn_write_batch_->GetWriteBatch()); is_txn_mode_ = false; txn_write_batch_ = nullptr; @@ -849,10 +851,13 @@ ObserverOrUniquePtr Storage::GetWriteBatchBase() { } Status Storage::WriteToPropagateCF(const std::string &key, const std::string &value) { + if (config_->IsSlave()) { + return {Status::NotOK, "cannot write to propagate column family in slave mode"}; + } auto batch = GetWriteBatchBase(); - auto cf = GetCFHandle(kPropagateColumnFamilyName); + auto cf = GetCFHandle(ColumnFamilyID::Propagate); batch->Put(cf, key, value); - auto s = Write(write_opts_, batch->GetWriteBatch()); + auto s = Write(default_write_opts_, batch->GetWriteBatch()); if (!s.ok()) { return {Status::NotOK, s.ToString()}; } @@ -927,7 +932,7 @@ std::string Storage::GetReplIdFromWalBySeq(rocksdb::SequenceNumber seq) { std::string Storage::GetReplIdFromDbEngine() { std::string replid_in_db; - auto cf = GetCFHandle(kPropagateColumnFamilyName); + auto cf = GetCFHandle(ColumnFamilyID::Propagate); auto s = db_->Get(rocksdb::ReadOptions(), cf, kReplicationIdKey, &replid_in_db); return replid_in_db; } @@ -958,9 +963,9 @@ Status Storage::ReplDataManager::GetFullReplDataInfo(Storage *storage, std::stri uint64_t checkpoint_latest_seq = 0; s = checkpoint->CreateCheckpoint(data_files_dir, storage->config_->rocks_db.write_buffer_size * MiB, &checkpoint_latest_seq); - auto now = static_cast(util::GetTimeStamp()); - storage->checkpoint_info_.create_time = now; - storage->checkpoint_info_.access_time = now; + auto now_secs = util::GetTimeStamp(); + storage->checkpoint_info_.create_time_secs = now_secs; + storage->checkpoint_info_.access_time_secs = now_secs; storage->checkpoint_info_.latest_seq = checkpoint_latest_seq; if (!s.ok()) { LOG(WARNING) << "[storage] Failed to create checkpoint (snapshot). Error: " << s.ToString(); @@ -970,12 +975,12 @@ Status Storage::ReplDataManager::GetFullReplDataInfo(Storage *storage, std::stri LOG(INFO) << "[storage] Create checkpoint successfully"; } else { // Replicas can share checkpoint to replication if the checkpoint existing time is less than a half of WAL TTL. 
- int64_t can_shared_time = storage->config_->rocks_db.wal_ttl_seconds / 2; - if (can_shared_time > 60 * 60) can_shared_time = 60 * 60; - if (can_shared_time < 10 * 60) can_shared_time = 10 * 60; + int64_t can_shared_time_secs = storage->config_->rocks_db.wal_ttl_seconds / 2; + if (can_shared_time_secs > 60 * 60) can_shared_time_secs = 60 * 60; + if (can_shared_time_secs < 10 * 60) can_shared_time_secs = 10 * 60; - auto now = static_cast<time_t>(util::GetTimeStamp()); - if (now - storage->GetCheckpointCreateTime() > can_shared_time) { + auto now_secs = util::GetTimeStamp(); + if (now_secs - storage->GetCheckpointCreateTimeSecs() > can_shared_time_secs) { LOG(WARNING) << "[storage] Can't use current checkpoint, waiting next checkpoint"; return {Status::NotOK, "Can't use current checkpoint, waiting for next checkpoint"}; } diff --git a/src/storage/storage.h b/src/storage/storage.h index 3627ad6441e..7f31fc451a2 100644 --- a/src/storage/storage.h +++ b/src/storage/storage.h @@ -51,31 +51,25 @@ inline constexpr StorageEngineType STORAGE_ENGINE_TYPE = StorageEngineType::KVROCKS; const int kReplIdLength = 16; -enum ColumnFamilyID { - kColumnFamilyIDDefault = 0, - kColumnFamilyIDMetadata, - kColumnFamilyIDZSetScore, - kColumnFamilyIDPubSub, - kColumnFamilyIDPropagate, - kColumnFamilyIDStream, - kColumnFamilyIDSearch, -}; - enum DBOpenMode { kDBOpenModeDefault, kDBOpenModeForReadOnly, kDBOpenModeAsSecondaryInstance, }; -namespace engine { +enum class ColumnFamilyID : uint32_t { + PrimarySubkey = 0, + Metadata, + SecondarySubkey, + PubSub, + Propagate, + Stream, + Search, +}; -constexpr const char *kPubSubColumnFamilyName = "pubsub"; -constexpr const char *kZSetScoreColumnFamilyName = "zset_score"; -constexpr const char *kMetadataColumnFamilyName = "metadata"; -constexpr const char *kSubkeyColumnFamilyName = "default"; -constexpr const char *kPropagateColumnFamilyName = "propagate"; -constexpr const char *kStreamColumnFamilyName = "stream"; -constexpr const char *kSearchColumnFamilyName = "search"; +constexpr uint32_t kMaxColumnFamilyID = static_cast<uint32_t>(ColumnFamilyID::Search); + +namespace engine { constexpr const char *kPropagateScriptCommand = "script"; @@ -122,6 +116,84 @@ struct DBStats { alignas(CACHE_LINE_SIZE) std::atomic<uint64_t> keyspace_misses = 0; }; +class ColumnFamilyConfig { + public: + ColumnFamilyConfig(ColumnFamilyID id, std::string_view name, bool is_minor) + : id_(id), name_(name), is_minor_(is_minor) {} + ColumnFamilyID Id() const { return id_; } + std::string_view Name() const { return name_; } + bool IsMinor() const { return is_minor_; } + + private: + ColumnFamilyID id_; + std::string_view name_; + bool is_minor_; +}; + +constexpr const std::string_view kPrimarySubkeyColumnFamilyName = "default"; +constexpr const std::string_view kMetadataColumnFamilyName = "metadata"; +constexpr const std::string_view kSecondarySubkeyColumnFamilyName = "zset_score"; +constexpr const std::string_view kPubSubColumnFamilyName = "pubsub"; +constexpr const std::string_view kPropagateColumnFamilyName = "propagate"; +constexpr const std::string_view kStreamColumnFamilyName = "stream"; +constexpr const std::string_view kSearchColumnFamilyName = "search"; + +class ColumnFamilyConfigs { + public: + /// PrimarySubkeyColumnFamily is the default column family in rocksdb. + /// In kvrocks, we use it to store the subkey data when the metadata alone is not enough.
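+ /// For example, a hash keeps its encoded metadata in the metadata column family,
+ /// while each of its fields is stored here as a separate subkey entry.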
+ static ColumnFamilyConfig PrimarySubkeyColumnFamily() { + return {ColumnFamilyID::PrimarySubkey, kPrimarySubkeyColumnFamilyName, /*is_minor=*/false}; + } + + /// MetadataColumnFamily stores the metadata of data-structures. + static ColumnFamilyConfig MetadataColumnFamily() { + return {ColumnFamilyID::Metadata, kMetadataColumnFamilyName, /*is_minor=*/false}; + } + + /// SecondarySubkeyColumnFamily stores the score of zset or other secondary subkey. + /// See https://kvrocks.apache.org/community/data-structure-on-rocksdb#zset for more details. + static ColumnFamilyConfig SecondarySubkeyColumnFamily() { + return {ColumnFamilyID::SecondarySubkey, kSecondarySubkeyColumnFamilyName, + /*is_minor=*/true}; + } + + /// PubSubColumnFamily stores the pubsub data. + static ColumnFamilyConfig PubSubColumnFamily() { + return {ColumnFamilyID::PubSub, kPubSubColumnFamilyName, /*is_minor=*/true}; + } + + static ColumnFamilyConfig PropagateColumnFamily() { + return {ColumnFamilyID::Propagate, kPropagateColumnFamilyName, /*is_minor=*/true}; + } + + static ColumnFamilyConfig StreamColumnFamily() { + return {ColumnFamilyID::Stream, kStreamColumnFamilyName, /*is_minor=*/true}; + } + + static ColumnFamilyConfig SearchColumnFamily() { + return {ColumnFamilyID::Search, kSearchColumnFamilyName, /*is_minor=*/true}; + } + + /// ListAllColumnFamilies returns all column families in kvrocks. + static const std::vector &ListAllColumnFamilies() { return AllCfs; } + + static const std::vector &ListColumnFamiliesWithoutDefault() { return AllCfsWithoutDefault; } + + static const ColumnFamilyConfig &GetColumnFamily(ColumnFamilyID id) { return AllCfs[static_cast(id)]; } + + private: + // Caution: don't change the order of column family, or the handle will be mismatched + inline const static std::vector AllCfs = { + PrimarySubkeyColumnFamily(), MetadataColumnFamily(), SecondarySubkeyColumnFamily(), PubSubColumnFamily(), + PropagateColumnFamily(), StreamColumnFamily(), SearchColumnFamily(), + }; + inline const static std::vector AllCfsWithoutDefault = { + MetadataColumnFamily(), SecondarySubkeyColumnFamily(), PubSubColumnFamily(), + PropagateColumnFamily(), StreamColumnFamily(), SearchColumnFamily(), + }; +}; + class Storage { public: explicit Storage(Config *config); @@ -130,6 +202,7 @@ class Storage { void SetWriteOptions(const Config::RocksDB::WriteOptions &config); Status Open(DBOpenMode mode = kDBOpenModeDefault); void CloseDB(); + bool IsEmptyDB(); void EmptyDB(); rocksdb::BlockBasedTableOptions InitTableOptions(); void SetBlobDB(rocksdb::ColumnFamilyOptions *cf_options); @@ -162,12 +235,14 @@ class Storage { rocksdb::Iterator *NewIterator(const rocksdb::ReadOptions &options); [[nodiscard]] rocksdb::Status Write(const rocksdb::WriteOptions &options, rocksdb::WriteBatch *updates); - const rocksdb::WriteOptions &DefaultWriteOptions() { return write_opts_; } + const rocksdb::WriteOptions &DefaultWriteOptions() { return default_write_opts_; } rocksdb::ReadOptions DefaultScanOptions() const; rocksdb::ReadOptions DefaultMultiGetOptions() const; [[nodiscard]] rocksdb::Status Delete(const rocksdb::WriteOptions &options, rocksdb::ColumnFamilyHandle *cf_handle, const rocksdb::Slice &key); - [[nodiscard]] rocksdb::Status DeleteRange(const std::string &first_key, const std::string &last_key); + [[nodiscard]] rocksdb::Status DeleteRange(const rocksdb::WriteOptions &options, + rocksdb::ColumnFamilyHandle *cf_handle, Slice begin, Slice end); + [[nodiscard]] rocksdb::Status DeleteRange(Slice begin, Slice end); [[nodiscard]] 
rocksdb::Status FlushScripts(const rocksdb::WriteOptions &options, rocksdb::ColumnFamilyHandle *cf_handle); bool WALHasNewData(rocksdb::SequenceNumber seq) { return seq <= LatestSeqNumber(); } @@ -179,7 +254,7 @@ class Storage { rocksdb::DB *GetDB(); bool IsClosing() const { return db_closing_; } std::string GetName() const { return config_->db_name; } - rocksdb::ColumnFamilyHandle *GetCFHandle(const std::string &name); + /// Get the column family handle by the column family id. rocksdb::ColumnFamilyHandle *GetCFHandle(ColumnFamilyID id); std::vector *GetCFHandles() { return &cf_handles_; } LockManager *GetLockManager() { return &lock_mgr_; } @@ -214,8 +289,10 @@ class Storage { static int OpenDataFile(Storage *storage, const std::string &rel_file, uint64_t *file_size); static Status CleanInvalidFiles(Storage *storage, const std::string &dir, std::vector valid_files); struct CheckpointInfo { - std::atomic create_time = 0; - std::atomic access_time = 0; + // System clock time when the checkpoint was created. + std::atomic create_time_secs = 0; + // System clock time when the checkpoint was last accessed. + std::atomic access_time_secs = 0; uint64_t latest_seq = 0; }; @@ -237,9 +314,9 @@ class Storage { bool ExistCheckpoint(); bool ExistSyncCheckpoint(); - time_t GetCheckpointCreateTime() const { return checkpoint_info_.create_time; } - void SetCheckpointAccessTime(time_t t) { checkpoint_info_.access_time = t; } - time_t GetCheckpointAccessTime() const { return checkpoint_info_.access_time; } + int64_t GetCheckpointCreateTimeSecs() const { return checkpoint_info_.create_time_secs; } + void SetCheckpointAccessTimeSecs(int64_t t) { checkpoint_info_.access_time_secs = t; } + int64_t GetCheckpointAccessTimeSecs() const { return checkpoint_info_.access_time_secs; } void SetDBInRetryableIOError(bool yes_or_no) { db_in_retryable_io_error_ = yes_or_no; } bool IsDBInRetryableIOError() const { return db_in_retryable_io_error_; } @@ -250,7 +327,8 @@ class Storage { private: std::unique_ptr db_ = nullptr; std::string replid_; - time_t backup_creating_time_; + // The system clock time when the backup was created. + int64_t backup_creating_time_secs_; std::unique_ptr backup_ = nullptr; rocksdb::Env *env_; std::shared_ptr sst_file_manager_; @@ -278,7 +356,7 @@ class Storage { // command, so it won't have multi transactions to be executed at the same time. 
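 // A MULTI/EXEC round-trip maps onto it roughly like this (sketch, assuming the
 // existing BeginTxn()/CommitTxn() pair):
 //   storage->BeginTxn();   // subsequent writes are buffered in txn_write_batch_
 //   ... queued commands run and fill the batch ...
 //   storage->CommitTxn();  // flushes the batch through writeToDB()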
std::unique_ptr txn_write_batch_; - rocksdb::WriteOptions write_opts_ = rocksdb::WriteOptions(); + rocksdb::WriteOptions default_write_opts_ = rocksdb::WriteOptions(); rocksdb::Status writeToDB(const rocksdb::WriteOptions &options, rocksdb::WriteBatch *updates); void recordKeyspaceStat(const rocksdb::ColumnFamilyHandle *column_family, const rocksdb::Status &s); diff --git a/src/types/json.h b/src/types/json.h index 4b786834a39..13abb5cc954 100644 --- a/src/types/json.h +++ b/src/types/json.h @@ -21,6 +21,7 @@ #pragma once #include +#include #include #include #include @@ -31,8 +32,6 @@ #include #include #include -#include -#include #include #include #include @@ -40,6 +39,7 @@ #include "common/string_util.h" #include "jsoncons_ext/jsonpath/jsonpath_error.hpp" #include "status.h" +#include "storage/redis_metadata.h" template using Optionals = std::vector>; @@ -151,9 +151,21 @@ struct JsonValue { Status Set(std::string_view path, JsonValue &&new_value) { try { - jsoncons::jsonpath::json_replace(value, path, [&new_value](const std::string & /*path*/, jsoncons::json &origin) { - origin = new_value.value; - }); + bool is_set = false; + jsoncons::jsonpath::json_replace(value, path, + [&new_value, &is_set](const std::string & /*path*/, jsoncons::json &origin) { + origin = new_value.value; + is_set = true; + }); + + if (!is_set) { + // NOTE: this is a workaround since jsonpath doesn't support replace for nonexistent paths in jsoncons + // and in this workaround we can only accept normalized path + // refer to https://github.com/danielaparker/jsoncons/issues/496 + jsoncons::jsonpath::json_location location = jsoncons::jsonpath::json_location::parse(path); + + jsoncons::jsonpath::replace(value, location, new_value.value, true); + } } catch (const jsoncons::jsonpath::jsonpath_error &e) { return {Status::NotOK, e.what()}; } @@ -207,6 +219,30 @@ struct JsonValue { return results; } + StatusOr> GetBytes(std::string_view path, JsonStorageFormat format, + int max_nesting_depth = std::numeric_limits::max()) const { + std::vector results; + Status s; + try { + jsoncons::jsonpath::json_query(value, path, [&](const std::string & /*path*/, const jsoncons::json &origin) { + if (!s) return; + std::string buffer; + JsonValue query_value(origin); + if (format == JsonStorageFormat::JSON) { + s = query_value.Dump(&buffer, max_nesting_depth); + } else if (format == JsonStorageFormat::CBOR) { + s = query_value.DumpCBOR(&buffer, max_nesting_depth); + } + results.emplace_back(buffer.size()); + }); + } catch (const jsoncons::jsonpath::jsonpath_error &e) { + return {Status::NotOK, e.what()}; + } + if (!s) return s; + + return results; + } + StatusOr Get(std::string_view path) const { try { return jsoncons::jsonpath::json_query(value, path); @@ -413,17 +449,12 @@ struct JsonValue { bool not_exists = jsoncons::jsonpath::json_query(value, path).empty(); if (not_exists) { + // NOTE: this is a workaround since jsonpath doesn't support replace for nonexistent paths in jsoncons + // and in this workaround we can only accept normalized path + // refer to https://github.com/danielaparker/jsoncons/issues/496 jsoncons::jsonpath::json_location location = jsoncons::jsonpath::json_location::parse(path); - jsoncons::jsonpointer::json_pointer ptr{}; - - for (const auto &element : location) { - if (element.has_name()) - ptr /= element.name(); - else { - ptr /= element.index(); - } - } - jsoncons::jsonpointer::replace(value, ptr, patch_value, true); + + jsoncons::jsonpath::replace(value, location, patch_value, true); is_updated = true; } 
else if (path == json_root_path) { @@ -442,8 +473,6 @@ struct JsonValue { jsoncons::jsonpath::remove(value, path); is_updated = true; } - } catch (const jsoncons::jsonpointer::jsonpointer_error &e) { - return {Status::NotOK, e.what()}; } catch (const jsoncons::jsonpath::jsonpath_error &e) { return {Status::NotOK, e.what()}; } catch (const jsoncons::ser_error &e) { @@ -568,28 +597,29 @@ struct JsonValue { if (!status.IsOK()) { return; } - if (!origin.is_number()) { + // is_number() also returns true when the value is actually a string + // that can be converted to a number, + // so we should exclude that case here + if (!origin.is_number() || origin.is_string()) { result->value.push_back(jsoncons::json::null()); return; } - if (number.value.is_double() || origin.is_double()) { - double v = 0; - if (op == NumOpEnum::Incr) { - v = origin.as_double() + number.value.as_double(); - } else if (op == NumOpEnum::Mul) { - v = origin.as_double() * number.value.as_double(); - } - if (std::isinf(v)) { - status = {Status::RedisExecErr, "result is an infinite number"}; - return; - } - origin = v; + double v = 0; + if (op == NumOpEnum::Incr) { + v = origin.as_double() + number.value.as_double(); + } else if (op == NumOpEnum::Mul) { + v = origin.as_double() * number.value.as_double(); + } + if (std::isinf(v)) { + status = {Status::RedisExecErr, "the result is an infinite number"}; + return; + } + double v_int = 0; + if (std::modf(v, &v_int) == 0 && double(std::numeric_limits<int64_t>::min()) < v && + v < double(std::numeric_limits<int64_t>::max())) { + origin = int64_t(v); } else { - if (op == NumOpEnum::Incr) { - origin = origin.as_integer() + number.value.as_integer(); - } else if (op == NumOpEnum::Mul) { - origin = origin.as_integer() * number.value.as_integer(); - } + origin = v; } result->value.push_back(origin); }); diff --git a/src/types/redis_bitmap.cc b/src/types/redis_bitmap.cc index ac3d8768f8e..9a08c1fe5fd 100644 --- a/src/types/redis_bitmap.cc +++ b/src/types/redis_bitmap.cc @@ -21,7 +21,9 @@ #include "redis_bitmap.h" #include +#include #include +#include #include #include "common/bit_util.h" @@ -93,8 +95,9 @@ uint32_t SegmentSubKeyIndexForBit(uint32_t bit_offset) { return (bit_offset / kBitmapSegmentBits) * kBitmapSegmentBytes; } -rocksdb::Status Bitmap::GetMetadata(const Slice &ns_key, BitmapMetadata *metadata, std::string *raw_value) { - auto s = GetRawMetadata(ns_key, raw_value); +rocksdb::Status Bitmap::GetMetadata(Database::GetOptions get_options, const Slice &ns_key, BitmapMetadata *metadata, + std::string *raw_value) { + auto s = GetRawMetadata(get_options, ns_key, raw_value); if (!s.ok()) return s; Slice slice = *raw_value; @@ -107,7 +110,8 @@ rocksdb::Status Bitmap::GetBit(const Slice &user_key, uint32_t bit_offset, bool std::string ns_key = AppendNamespacePrefix(user_key); BitmapMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata, &raw_value); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata, &raw_value); if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; if (metadata.Type() == kRedisString) { @@ -115,7 +119,6 @@ rocksdb::Status Bitmap::GetBit(const Slice &user_key, uint32_t bit_offset, bool return bitmap_string_db.GetBit(raw_value, bit_offset, bit); } - LatestSnapShot ss(storage_); rocksdb::ReadOptions read_options; read_options.snapshot = ss.GetSnapShot(); rocksdb::PinnableSlice value; @@ -142,7 +145,8 @@ rocksdb::Status Bitmap::GetString(const Slice &user_key, const uint32_t max_btos std::string ns_key = AppendNamespacePrefix(user_key); BitmapMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata, &raw_value); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata, &raw_value); if (!s.ok()) return s; if (metadata.size > max_btos_size) { return rocksdb::Status::Aborted(kErrBitmapStringOutOfRange); @@ -152,8 +156,9 @@ rocksdb::Status Bitmap::GetString(const Slice &user_key, const uint32_t max_btos std::string prefix_key = InternalKey(ns_key, "", metadata.version, storage_->IsSlotIdEncoded()).Encode(); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); + Slice prefix_key_slice(prefix_key); + read_options.iterate_lower_bound = &prefix_key_slice; auto iter = util::UniqueIterator(storage_, read_options); for (iter->Seek(prefix_key); iter->Valid() && iter->key().starts_with(prefix_key); iter->Next()) { @@ -184,7 +189,7 @@ rocksdb::Status Bitmap::SetBit(const Slice &user_key, uint32_t bit_offset, bool LockGuard guard(storage_->GetLockManager(), ns_key); BitmapMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata, &raw_value); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata, &raw_value); if (!s.ok() && !s.IsNotFound()) return s; if (metadata.Type() == kRedisString) { @@ -228,7 +233,8 @@ rocksdb::Status Bitmap::BitCount(const Slice &user_key, int64_t start, int64_t s std::string ns_key = AppendNamespacePrefix(user_key); BitmapMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata, &raw_value); + std::optional ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss->GetSnapShot()}, ns_key, &metadata, &raw_value); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; /* Convert negative indexes */ @@ -237,6 +243,9 @@ rocksdb::Status Bitmap::BitCount(const Slice &user_key, int64_t start, int64_t s } if (metadata.Type() == kRedisString) { + // Release snapshot ahead for performance, this requires + // `bitmap_string_db` doesn't get anything. 
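+ // (BitmapString::BitCount below only parses the already-fetched `raw_value`,
+ // so dropping the snapshot here is safe.)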
+ ss = std::nullopt; redis::BitmapString bitmap_string_db(storage_, namespace_); return bitmap_string_db.BitCount(raw_value, start, stop, is_bit_index, cnt); } @@ -257,9 +266,8 @@ rocksdb::Status Bitmap::BitCount(const Slice &user_key, int64_t start, int64_t s auto u_start = static_cast(start_byte); auto u_stop = static_cast(stop_byte); - LatestSnapShot ss(storage_); rocksdb::ReadOptions read_options; - read_options.snapshot = ss.GetSnapShot(); + read_options.snapshot = ss->GetSnapShot(); uint32_t start_index = u_start / kBitmapSegmentBytes; uint32_t stop_index = u_stop / kBitmapSegmentBytes; // Don't use multi get to prevent large range query, and take too much memory @@ -303,12 +311,15 @@ rocksdb::Status Bitmap::BitCount(const Slice &user_key, int64_t start, int64_t s } rocksdb::Status Bitmap::BitPos(const Slice &user_key, bool bit, int64_t start, int64_t stop, bool stop_given, - int64_t *pos) { + int64_t *pos, bool is_bit_index) { + if (is_bit_index) DCHECK(stop_given); + std::string raw_value; std::string ns_key = AppendNamespacePrefix(user_key); BitmapMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata, &raw_value); + std::optional ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss->GetSnapShot()}, ns_key, &metadata, &raw_value); if (!s.ok() && !s.IsNotFound()) return s; if (s.IsNotFound()) { *pos = bit ? -1 : 0; @@ -316,12 +327,17 @@ rocksdb::Status Bitmap::BitPos(const Slice &user_key, bool bit, int64_t start, i } if (metadata.Type() == kRedisString) { + ss = std::nullopt; redis::BitmapString bitmap_string_db(storage_, namespace_); - return bitmap_string_db.BitPos(raw_value, bit, start, stop, stop_given, pos); + return bitmap_string_db.BitPos(raw_value, bit, start, stop, stop_given, pos, is_bit_index); } - std::tie(start, stop) = BitmapString::NormalizeRange(start, stop, static_cast(metadata.size)); - auto u_start = static_cast(start); - auto u_stop = static_cast(stop); + + uint32_t to_bit_factor = is_bit_index ? 
8 : 1; + auto size = static_cast(metadata.size) * static_cast(to_bit_factor); + + std::tie(start, stop) = BitmapString::NormalizeRange(start, stop, size); + auto u_start = static_cast(start); + auto u_stop = static_cast(stop); if (u_start > u_stop) { *pos = -1; return rocksdb::Status::OK(); @@ -335,14 +351,40 @@ rocksdb::Status Bitmap::BitPos(const Slice &user_key, bool bit, int64_t start, i return -1; }; - LatestSnapShot ss(storage_); + auto bit_pos_in_byte_startstop = [](char byte, bool bit, uint32_t start, uint32_t stop) -> int { + for (uint32_t i = start; i <= stop; i++) { + if (bit && (byte & (1 << i)) != 0) return (int)i; // typecast to int since the value ranges from 0 to 7 + if (!bit && (byte & (1 << i)) == 0) return (int)i; + } + return -1; + }; + rocksdb::ReadOptions read_options; - read_options.snapshot = ss.GetSnapShot(); - uint32_t start_index = u_start / kBitmapSegmentBytes; - uint32_t stop_index = u_stop / kBitmapSegmentBytes; + read_options.snapshot = ss->GetSnapShot(); + // if bit index, (Eg start = 1, stop = 35), then + // u_start = 1/8 = 0, u_stop = 35/8 = 4 (in bytes) + uint32_t start_segment_index = (u_start / to_bit_factor) / kBitmapSegmentBytes; + uint32_t stop_segment_index = (u_stop / to_bit_factor) / kBitmapSegmentBytes; + uint32_t start_bit_pos_in_byte = 0; + uint32_t stop_bit_pos_in_byte = 0; + + if (is_bit_index) { + start_bit_pos_in_byte = u_start % 8; + stop_bit_pos_in_byte = u_stop % 8; + } + + auto range_in_byte = [start_bit_pos_in_byte, stop_bit_pos_in_byte]( + uint32_t byte_start, uint32_t byte_end, + uint32_t curr_byte) -> std::pair { + if (curr_byte == byte_start && curr_byte == byte_end) return {start_bit_pos_in_byte, stop_bit_pos_in_byte}; + if (curr_byte == byte_start) return {start_bit_pos_in_byte, 7}; + if (curr_byte == byte_end) return {0, stop_bit_pos_in_byte}; + return {0, 7}; + }; + // Don't use multi get to prevent large range query, and take too much memory // Searching bits in segments [start_index, stop_index]. - for (uint32_t i = start_index; i <= stop_index; i++) { + for (uint32_t i = start_segment_index; i <= stop_segment_index; i++) { rocksdb::PinnableSlice pin_value; std::string sub_key = InternalKey(ns_key, std::to_string(i * kBitmapSegmentBytes), metadata.version, storage_->IsSlotIdEncoded()) @@ -359,17 +401,33 @@ rocksdb::Status Bitmap::BitPos(const Slice &user_key, bool bit, int64_t start, i continue; } size_t byte_pos_in_segment = 0; - if (i == start_index) byte_pos_in_segment = u_start % kBitmapSegmentBytes; + size_t byte_with_bit_start = -1; + size_t byte_with_bit_stop = -2; + // if bit index, (Eg start = 1, stop = 35), then + // byte_pos_in_segment should be calculated in bytes, hence divide by 8 + if (i == start_segment_index) { + byte_pos_in_segment = (u_start / to_bit_factor) % kBitmapSegmentBytes; + byte_with_bit_start = byte_pos_in_segment; + } size_t stop_byte_in_segment = pin_value.size(); - if (i == stop_index) { - DCHECK_LE(u_stop % kBitmapSegmentBytes + 1, pin_value.size()); - stop_byte_in_segment = u_stop % kBitmapSegmentBytes + 1; + if (i == stop_segment_index) { + DCHECK_LE((u_stop / to_bit_factor) % kBitmapSegmentBytes + 1, pin_value.size()); + stop_byte_in_segment = (u_stop / to_bit_factor) % kBitmapSegmentBytes + 1; + byte_with_bit_stop = stop_byte_in_segment; } // Invariant: // 1. pin_value.size() <= kBitmapSegmentBytes. // 2. If it's the last segment, metadata.size % kBitmapSegmentBytes <= pin_value.size(). 
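 // Worked example with BIT indexes: start = 1, stop = 35 gives to_bit_factor = 8,
 // so bytes 0..4 are scanned, probing byte 0 from bit 1 and byte 4 only up to
 // bit 3 (35 % 8).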
for (; byte_pos_in_segment < stop_byte_in_segment; byte_pos_in_segment++) { - int bit_pos_in_byte_value = bit_pos_in_byte(pin_value[byte_pos_in_segment], bit); + int bit_pos_in_byte_value = -1; + if (is_bit_index) { + uint32_t start_bit = 0, stop_bit = 7; + std::tie(start_bit, stop_bit) = range_in_byte(byte_with_bit_start, byte_with_bit_stop, byte_pos_in_segment); + bit_pos_in_byte_value = bit_pos_in_byte_startstop(pin_value[byte_pos_in_segment], bit, start_bit, stop_bit); + } else { + bit_pos_in_byte_value = bit_pos_in_byte(pin_value[byte_pos_in_segment], bit); + } + if (bit_pos_in_byte_value != -1) { *pos = static_cast(i * kBitmapSegmentBits + byte_pos_in_segment * 8 + bit_pos_in_byte_value); return rocksdb::Status::OK(); @@ -382,7 +440,7 @@ rocksdb::Status Bitmap::BitPos(const Slice &user_key, bool bit, int64_t start, i // 1. If it's the last segment, we've done searching in the above loop. // 2. If it's not the last segment, we can check if the segment is all 0. if (pin_value.size() < kBitmapSegmentBytes) { - if (i == stop_index) { + if (i == stop_segment_index) { continue; } *pos = static_cast(i * kBitmapSegmentBits + pin_value.size() * 8); @@ -417,7 +475,7 @@ rocksdb::Status Bitmap::BitOp(BitOpFlags op_flag, const std::string &op_name, co for (const auto &op_key : op_keys) { BitmapMetadata metadata(false); std::string ns_op_key = AppendNamespacePrefix(op_key); - auto s = GetMetadata(ns_op_key, &metadata, &raw_value); + auto s = GetMetadata(GetOptions{}, ns_op_key, &metadata, &raw_value); if (!s.ok()) { if (s.IsNotFound()) { continue; @@ -769,7 +827,8 @@ rocksdb::Status Bitmap::bitfield(const Slice &user_key, const std::vector &op_keys, int64_t *len); rocksdb::Status Bitfield(const Slice &user_key, const std::vector &ops, @@ -72,7 +73,8 @@ class Bitmap : public Database { std::vector> *rets); static bool bitfieldWriteAheadLog(const ObserverOrUniquePtr &batch, const std::vector &ops); - rocksdb::Status GetMetadata(const Slice &ns_key, BitmapMetadata *metadata, std::string *raw_value); + rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &ns_key, BitmapMetadata *metadata, + std::string *raw_value); template static rocksdb::Status runBitfieldOperationsWithCache(SegmentCacheStore &cache, diff --git a/src/types/redis_bitmap_string.cc b/src/types/redis_bitmap_string.cc index b226d9c2f7a..b10a5d45d49 100644 --- a/src/types/redis_bitmap_string.cc +++ b/src/types/redis_bitmap_string.cc @@ -100,31 +100,80 @@ rocksdb::Status BitmapString::BitCount(const std::string &raw_value, int64_t sta } rocksdb::Status BitmapString::BitPos(const std::string &raw_value, bool bit, int64_t start, int64_t stop, - bool stop_given, int64_t *pos) { + bool stop_given, int64_t *pos, bool is_bit_index) { std::string_view string_value = std::string_view{raw_value}.substr(Metadata::GetOffsetAfterExpire(raw_value[0])); auto strlen = static_cast(string_value.size()); /* Convert negative and out-of-bound indexes */ - std::tie(start, stop) = NormalizeRange(start, stop, strlen); + + int64_t length = is_bit_index ? strlen * 8 : strlen; + std::tie(start, stop) = NormalizeRange(start, stop, length); if (start > stop) { *pos = -1; - } else { - int64_t bytes = stop - start + 1; - *pos = util::msb::RawBitpos(reinterpret_cast(string_value.data()) + start, bytes, bit); - - /* If we are looking for clear bits, and the user specified an exact - * range with start-end, we can't consider the right of the range as - * zero padded (as we do when no explicit end is given). 
- * - * So if redisBitpos() returns the first bit outside the range, - * we return -1 to the caller, to mean, in the specified range there - * is not a single "0" bit. */ - if (stop_given && bit == 0 && *pos == bytes * 8) { + return rocksdb::Status::OK(); + } + + int64_t byte_start = is_bit_index ? start / 8 : start; + int64_t byte_stop = is_bit_index ? stop / 8 : stop; + int64_t bit_in_start_byte = is_bit_index ? start % 8 : 0; + int64_t bit_in_stop_byte = is_bit_index ? stop % 8 : 7; + int64_t bytes_cnt = byte_stop - byte_start + 1; + + auto bit_pos_in_byte_startstop = [](char byte, bool bit, uint32_t start, uint32_t stop) -> int { + for (uint32_t i = start; i <= stop; i++) { + if (util::msb::GetBitFromByte(byte, i) == bit) { + return (int)i; + } + } + return -1; + }; + + // if the start bit and the stop bit fall in the same byte, we can probe that byte directly + if (is_bit_index && byte_start == byte_stop) { + int res = bit_pos_in_byte_startstop(string_value[byte_start], bit, bit_in_start_byte, bit_in_stop_byte); + if (res != -1) { + *pos = res + byte_start * 8; + return rocksdb::Status::OK(); + } + *pos = -1; + return rocksdb::Status::OK(); + } + + if (is_bit_index && bit_in_start_byte != 0) { + // probe the first (partial) byte + int res = bit_pos_in_byte_startstop(string_value[byte_start], bit, bit_in_start_byte, 7); + if (res != -1) { + *pos = res + byte_start * 8; + return rocksdb::Status::OK(); + } + + byte_start++; + bytes_cnt--; + } + + *pos = util::msb::RawBitpos(reinterpret_cast<const uint8_t *>(string_value.data()) + byte_start, bytes_cnt, bit); + + if (is_bit_index && *pos != -1 && *pos != bytes_cnt * 8) { + // if the position is beyond the stop bit, it is not in the range + if (*pos > stop) { + *pos = -1; + return rocksdb::Status::OK(); + } - if (*pos != -1) *pos += start * 8; /* Adjust for the bytes we skipped. */ } + + /* If we are looking for clear bits, and the user specified an exact + * range with start-end, we can't consider the right of the range as + * zero padded (as we do when no explicit end is given). + * + * So if redisBitpos() returns the first bit outside the range, + * we return -1 to the caller, to mean, in the specified range there + * is not a single "0" bit. */ + if (stop_given && bit == 0 && *pos == bytes_cnt * 8) { + *pos = -1; + return rocksdb::Status::OK(); + } + if (*pos != -1) *pos += byte_start * 8; /* Adjust for the bytes we skipped. 
+ return rocksdb::Status::OK(); } diff --git a/src/types/redis_bitmap_string.h b/src/types/redis_bitmap_string.h index 7997165afa3..030415c3492 100644 --- a/src/types/redis_bitmap_string.h +++ b/src/types/redis_bitmap_string.h @@ -39,7 +39,7 @@ class BitmapString : public Database { static rocksdb::Status BitCount(const std::string &raw_value, int64_t start, int64_t stop, bool is_bit_index, uint32_t *cnt); static rocksdb::Status BitPos(const std::string &raw_value, bool bit, int64_t start, int64_t stop, bool stop_given, - int64_t *pos); + int64_t *pos, bool is_bit_index); rocksdb::Status Bitfield(const Slice &ns_key, std::string *raw_value, const std::vector<BitfieldOperation> &ops, std::vector<std::optional<BitfieldValue>> *rets); static rocksdb::Status BitfieldReadOnly(const Slice &ns_key, const std::string &raw_value, diff --git a/src/types/redis_bloom_chain.cc b/src/types/redis_bloom_chain.cc index 90a43f056e6..6a4f3860464 100644 --- a/src/types/redis_bloom_chain.cc +++ b/src/types/redis_bloom_chain.cc @@ -24,8 +24,9 @@ namespace redis { -rocksdb::Status BloomChain::getBloomChainMetadata(const Slice &ns_key, BloomChainMetadata *metadata) { - return Database::GetMetadata({kRedisBloomFilter}, ns_key, metadata); +rocksdb::Status BloomChain::getBloomChainMetadata(Database::GetOptions get_options, const Slice &ns_key, + BloomChainMetadata *metadata) { + return Database::GetMetadata(get_options, {kRedisBloomFilter}, ns_key, metadata); } std::string BloomChain::getBFKey(const Slice &ns_key, const BloomChainMetadata &metadata, uint16_t filters_index) { @@ -124,7 +125,7 @@ rocksdb::Status BloomChain::Reserve(const Slice &user_key, uint32_t capacity, do LockGuard guard(storage_->GetLockManager(), ns_key); BloomChainMetadata bloom_chain_metadata; - rocksdb::Status s = getBloomChainMetadata(ns_key, &bloom_chain_metadata); + rocksdb::Status s = getBloomChainMetadata(GetOptions{}, ns_key, &bloom_chain_metadata); if (!s.ok() && !s.IsNotFound()) return s; if (!s.IsNotFound()) { return rocksdb::Status::InvalidArgument("the key already exists"); @@ -153,7 +154,7 @@ rocksdb::Status BloomChain::InsertCommon(const Slice &user_key, const std::vecto LockGuard guard(storage_->GetLockManager(), ns_key); BloomChainMetadata metadata; - rocksdb::Status s = getBloomChainMetadata(ns_key, &metadata); + rocksdb::Status s = getBloomChainMetadata(GetOptions{}, ns_key, &metadata); if (s.IsNotFound() && insert_options.auto_create) { s = createBloomChain(ns_key, insert_options.error_rate, insert_options.capacity, insert_options.expansion, @@ -235,7 +236,7 @@ rocksdb::Status BloomChain::MExists(const Slice &user_key, const std::vectorbegin(), exists->end(), false); return rocksdb::Status::OK(); @@ -269,7 +270,7 @@ rocksdb::Status BloomChain::Info(const Slice &user_key, BloomFilterInfo *info) { std::string ns_key = AppendNamespacePrefix(user_key); BloomChainMetadata metadata; - rocksdb::Status s = getBloomChainMetadata(ns_key, &metadata); + rocksdb::Status s = getBloomChainMetadata(Database::GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; info->capacity = metadata.GetCapacity(); diff --git a/src/types/redis_bloom_chain.h b/src/types/redis_bloom_chain.h index 59f4b5b7d35..6f5f76da63f 100644 --- a/src/types/redis_bloom_chain.h +++ b/src/types/redis_bloom_chain.h @@ -74,7 +74,8 @@ class BloomChain : public Database { rocksdb::Status Info(const Slice &user_key, BloomFilterInfo *info); private: - rocksdb::Status getBloomChainMetadata(const Slice &ns_key, BloomChainMetadata *metadata); + rocksdb::Status getBloomChainMetadata(Database::GetOptions get_options, const
Slice &ns_key, + BloomChainMetadata *metadata); std::string getBFKey(const Slice &ns_key, const BloomChainMetadata &metadata, uint16_t filters_index); void getBFKeyList(const Slice &ns_key, const BloomChainMetadata &metadata, std::vector<std::string> *bf_key_list); rocksdb::Status getBFDataList(const std::vector<std::string> &bf_key_list, diff --git a/src/types/redis_geo.cc b/src/types/redis_geo.cc index 37adc35d374..966e082f729 100644 --- a/src/types/redis_geo.cc +++ b/src/types/redis_geo.cc @@ -128,7 +128,7 @@ rocksdb::Status Geo::SearchStore(const Slice &user_key, GeoShape geo_shape, Orig std::string ns_key = AppendNamespacePrefix(user_key); ZSetMetadata metadata(false); - rocksdb::Status s = ZSet::GetMetadata(ns_key, &metadata); + rocksdb::Status s = ZSet::GetMetadata(GetOptions{}, ns_key, &metadata); // store key is not empty, try to remove it before returning. if (!s.ok() && s.IsNotFound() && !store_key.empty()) { auto del_s = ZSet::Del(store_key); diff --git a/src/types/redis_hash.cc b/src/types/redis_hash.cc index dcb1978e599..c4d60685934 100644 --- a/src/types/redis_hash.cc +++ b/src/types/redis_hash.cc @@ -34,8 +34,8 @@ namespace redis { -rocksdb::Status Hash::GetMetadata(const Slice &ns_key, HashMetadata *metadata) { - return Database::GetMetadata({kRedisHash}, ns_key, metadata); +rocksdb::Status Hash::GetMetadata(Database::GetOptions get_options, const Slice &ns_key, HashMetadata *metadata) { + return Database::GetMetadata(get_options, {kRedisHash}, ns_key, metadata); } rocksdb::Status Hash::Size(const Slice &user_key, uint64_t *size) { @@ -43,7 +43,7 @@ rocksdb::Status Hash::Size(const Slice &user_key, uint64_t *size) { std::string ns_key = AppendNamespacePrefix(user_key); HashMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(Database::GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; *size = metadata.size; return rocksdb::Status::OK(); @@ -52,9 +52,9 @@ rocksdb::Status Hash::Get(const Slice &user_key, const Slice &field, std::string *value) { std::string ns_key = AppendNamespacePrefix(user_key); HashMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); - if (!s.ok()) return s; LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(Database::GetOptions{ss.GetSnapShot()}, ns_key, &metadata); + if (!s.ok()) return s; rocksdb::ReadOptions read_options; read_options.snapshot = ss.GetSnapShot(); std::string sub_key = InternalKey(ns_key, field, metadata.version, storage_->IsSlotIdEncoded()).Encode(); @@ -69,7 +69,7 @@ rocksdb::Status Hash::IncrBy(const Slice &user_key, const Slice &field, int64_t LockGuard guard(storage_->GetLockManager(), ns_key); HashMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) return s; std::string sub_key = InternalKey(ns_key, field, metadata.version, storage_->IsSlotIdEncoded()).Encode(); @@ -116,7 +116,7 @@ rocksdb::Status Hash::IncrByFloat(const Slice &user_key, const Slice &field, dou LockGuard guard(storage_->GetLockManager(), ns_key); HashMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) return s; std::string sub_key = InternalKey(ns_key, field, metadata.version, storage_->IsSlotIdEncoded()).Encode();
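
The change repeated throughout these hunks is to construct LatestSnapShot before GetMetadata and pass the snapshot down through Database::GetOptions, so the metadata read and the following sub-key reads observe a single point in time. A sketch of the idea in terms of the plain RocksDB API (column families and the kvrocks wrapper types are left out):

#include <string>

#include "rocksdb/db.h"

// Read a metadata key and a dependent sub-key from one snapshot, so a
// concurrent writer cannot slip a new key version in between the two reads.
rocksdb::Status ConsistentRead(rocksdb::DB *db, const std::string &meta_key,
                               const std::string &sub_key, std::string *meta, std::string *value) {
  const rocksdb::Snapshot *snapshot = db->GetSnapshot();
  rocksdb::ReadOptions opts;
  opts.snapshot = snapshot;  // both reads see the same sequence number
  rocksdb::Status s = db->Get(opts, meta_key, meta);
  if (s.ok()) s = db->Get(opts, sub_key, value);
  db->ReleaseSnapshot(snapshot);
  return s;
}

Before this change the snapshot was taken after GetMetadata had already read the latest state, so a write landing between the two reads could let the sub-key scan observe a different version of the key than the metadata it was paired with.
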
@@ -159,12 +159,12 @@ rocksdb::Status Hash::MGet(const Slice &user_key, const std::vector &fiel std::string ns_key = AppendNamespacePrefix(user_key); HashMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) { return s; } - LatestSnapShot ss(storage_); rocksdb::ReadOptions read_options = storage_->DefaultMultiGetOptions(); read_options.snapshot = ss.GetSnapShot(); std::vector keys; @@ -205,7 +205,7 @@ rocksdb::Status Hash::Delete(const Slice &user_key, const std::vector &fi WriteBatchLogData log_data(kRedisHash); batch->PutLogData(log_data.Encode()); LockGuard guard(storage_->GetLockManager(), ns_key); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; std::string value; @@ -238,7 +238,7 @@ rocksdb::Status Hash::MSet(const Slice &user_key, const std::vector LockGuard guard(storage_->GetLockManager(), ns_key); HashMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) return s; int added = 0; @@ -290,7 +290,8 @@ rocksdb::Status Hash::RangeByLex(const Slice &user_key, const RangeLexSpec &spec } std::string ns_key = AppendNamespacePrefix(user_key); HashMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; std::string start_member = spec.reversed ? spec.max : spec.min; @@ -299,7 +300,6 @@ rocksdb::Status Hash::RangeByLex(const Slice &user_key, const RangeLexSpec &spec std::string next_version_prefix_key = InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode(); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); rocksdb::Slice upper_bound(next_version_prefix_key); read_options.iterate_upper_bound = &upper_bound; @@ -346,7 +346,8 @@ rocksdb::Status Hash::GetAll(const Slice &user_key, std::vector *fie std::string ns_key = AppendNamespacePrefix(user_key); HashMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ?
rocksdb::Status::OK() : s; std::string prefix_key = InternalKey(ns_key, "", metadata.version, storage_->IsSlotIdEncoded()).Encode(); @@ -354,7 +355,6 @@ rocksdb::Status Hash::GetAll(const Slice &user_key, std::vector *fie InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode(); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); rocksdb::Slice upper_bound(next_version_prefix_key); read_options.iterate_upper_bound = &upper_bound; @@ -387,7 +387,7 @@ rocksdb::Status Hash::RandField(const Slice &user_key, int64_t command_count, st std::string ns_key = AppendNamespacePrefix(user_key); HashMetadata metadata(/*generate_version=*/false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; std::vector samples; diff --git a/src/types/redis_hash.h b/src/types/redis_hash.h index 8ae0a066cce..10fc7d54502 100644 --- a/src/types/redis_hash.h +++ b/src/types/redis_hash.h @@ -65,7 +65,7 @@ class Hash : public SubKeyScanner { HashFetchType type = HashFetchType::kOnlyKey); private: - rocksdb::Status GetMetadata(const Slice &ns_key, HashMetadata *metadata); + rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &ns_key, HashMetadata *metadata); friend struct FieldValueRetriever; }; diff --git a/src/types/redis_json.cc b/src/types/redis_json.cc index fb3add5feb0..5f14b024078 100644 --- a/src/types/redis_json.cc +++ b/src/types/redis_json.cc @@ -74,7 +74,7 @@ rocksdb::Status Json::read(const Slice &ns_key, JsonMetadata *metadata, JsonValu std::string bytes; Slice rest; - auto s = GetMetadata({kRedisJson}, ns_key, &bytes, metadata, &rest); + auto s = GetMetadata(GetOptions{}, {kRedisJson}, ns_key, &bytes, metadata, &rest); if (!s.ok()) return s; return parse(*metadata, rest, value); @@ -105,7 +105,7 @@ rocksdb::Status Json::Info(const std::string &user_key, JsonStorageFormat *stora Slice rest; JsonMetadata metadata; - auto s = GetMetadata({kRedisJson}, ns_key, &bytes, &metadata, &rest); + auto s = GetMetadata(GetOptions{}, {kRedisJson}, ns_key, &bytes, &metadata, &rest); if (!s.ok()) return s; *storage_format = metadata.format; @@ -447,12 +447,11 @@ rocksdb::Status Json::NumMultBy(const std::string &user_key, const std::string & rocksdb::Status Json::numop(JsonValue::NumOpEnum op, const std::string &user_key, const std::string &path, const std::string &value, JsonValue *result) { - JsonValue number; auto number_res = JsonValue::FromString(value); - if (!number_res || !number_res.GetValue().value.is_number()) { - return rocksdb::Status::InvalidArgument("should be a number"); + if (!number_res || !number_res.GetValue().value.is_number() || number_res.GetValue().value.is_string()) { + return rocksdb::Status::InvalidArgument("the input value should be a number"); } - number = std::move(number_res.GetValue()); + JsonValue number = std::move(number_res.GetValue()); auto ns_key = AppendNamespacePrefix(user_key); JsonMetadata metadata; @@ -548,6 +547,63 @@ std::vector Json::MGet(const std::vector &user_key return statuses; } +rocksdb::Status Json::MSet(const std::vector &user_keys, const std::vector &paths, + const std::vector &values) { + std::vector ns_keys; + ns_keys.reserve(user_keys.size()); + for (const auto &user_key : user_keys) { + std::string ns_key = AppendNamespacePrefix(user_key); + ns_keys.emplace_back(std::move(ns_key)); + } + MultiLockGuard guard(storage_->GetLockManager(), 
ns_keys); + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisJson); + batch->PutLogData(log_data.Encode()); + + for (size_t i = 0; i < user_keys.size(); i++) { + auto json_res = JsonValue::FromString(values[i], storage_->GetConfig()->json_max_nesting_depth); + if (!json_res) return rocksdb::Status::InvalidArgument(json_res.Msg()); + + JsonMetadata metadata; + JsonValue value; + + if (auto s = read(ns_keys[i], &metadata, &value); s.IsNotFound()) { + if (paths[i] != "$") return rocksdb::Status::InvalidArgument("new objects must be created at the root"); + + value = *std::move(json_res); + } else { + if (!s.ok()) return s; + + JsonValue new_val = *std::move(json_res); + auto set_res = value.Set(paths[i], std::move(new_val)); + if (!set_res) return rocksdb::Status::InvalidArgument(set_res.Msg()); + } + + auto format = storage_->GetConfig()->json_storage_format; + metadata.format = format; + + std::string val; + metadata.Encode(&val); + + Status res; + if (format == JsonStorageFormat::JSON) { + res = value.Dump(&val, storage_->GetConfig()->json_max_nesting_depth); + } else if (format == JsonStorageFormat::CBOR) { + res = value.DumpCBOR(&val, storage_->GetConfig()->json_max_nesting_depth); + } else { + return rocksdb::Status::InvalidArgument("JSON storage format not supported"); + } + if (!res) { + return rocksdb::Status::InvalidArgument("Failed to encode JSON into storage: " + res.Msg()); + } + + batch->Put(metadata_cf_handle_, ns_keys[i], val); + } + + return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + std::vector Json::readMulti(const std::vector &ns_keys, std::vector &values) { rocksdb::ReadOptions read_options = storage_->DefaultMultiGetOptions(); LatestSnapShot ss(storage_); @@ -570,4 +626,24 @@ std::vector Json::readMulti(const std::vector &ns_keys, return statuses; } +rocksdb::Status Json::DebugMemory(const std::string &user_key, const std::string &path, std::vector *results) { + auto ns_key = AppendNamespacePrefix(user_key); + JsonMetadata metadata; + if (path == "$") { + std::string bytes; + Slice rest; + auto s = GetMetadata(GetOptions{}, {kRedisJson}, ns_key, &bytes, &metadata, &rest); + if (!s.ok()) return s; + results->emplace_back(rest.size()); + } else { + JsonValue json_val; + auto s = read(ns_key, &metadata, &json_val); + if (!s.ok()) return s; + auto str_bytes = json_val.GetBytes(path, metadata.format, storage_->GetConfig()->json_max_nesting_depth); + if (!str_bytes) return rocksdb::Status::InvalidArgument(str_bytes.Msg()); + *results = std::move(*str_bytes); + } + return rocksdb::Status::OK(); +} + } // namespace redis diff --git a/src/types/redis_json.h b/src/types/redis_json.h index 8d0f15cb6dc..8dc212356e8 100644 --- a/src/types/redis_json.h +++ b/src/types/redis_json.h @@ -66,6 +66,9 @@ class Json : public Database { std::vector MGet(const std::vector &user_keys, const std::string &path, std::vector &results); + rocksdb::Status MSet(const std::vector &user_keys, const std::vector &paths, + const std::vector &values); + rocksdb::Status DebugMemory(const std::string &user_key, const std::string &path, std::vector *results); private: rocksdb::Status write(Slice ns_key, JsonMetadata *metadata, const JsonValue &json_val); diff --git a/src/types/redis_list.cc b/src/types/redis_list.cc index ed67007e905..21edda041e4 100644 --- a/src/types/redis_list.cc +++ b/src/types/redis_list.cc @@ -27,8 +27,8 @@ namespace redis { -rocksdb::Status List::GetMetadata(const Slice &ns_key, ListMetadata *metadata) { - return 
Database::GetMetadata({kRedisList}, ns_key, metadata); +rocksdb::Status List::GetMetadata(Database::GetOptions get_options, const Slice &ns_key, ListMetadata *metadata) { + return Database::GetMetadata(get_options, {kRedisList}, ns_key, metadata); } rocksdb::Status List::Size(const Slice &user_key, uint64_t *size) { @@ -36,7 +36,7 @@ rocksdb::Status List::Size(const Slice &user_key, uint64_t *size) { std::string ns_key = AppendNamespacePrefix(user_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; *size = metadata.size; return rocksdb::Status::OK(); @@ -61,7 +61,7 @@ rocksdb::Status List::push(const Slice &user_key, const std::vector &elem WriteBatchLogData log_data(kRedisList, {std::to_string(cmd)}); batch->PutLogData(log_data.Encode()); LockGuard guard(storage_->GetLockManager(), ns_key); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !(create_if_missing && s.IsNotFound())) { return s.IsNotFound() ? rocksdb::Status::OK() : s; } @@ -105,7 +105,7 @@ rocksdb::Status List::PopMulti(const rocksdb::Slice &user_key, bool left, uint32 LockGuard guard(storage_->GetLockManager(), ns_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; auto batch = storage_->GetWriteBatchBase(); @@ -169,7 +169,7 @@ rocksdb::Status List::Rem(const Slice &user_key, int count, const Slice &elem, u LockGuard guard(storage_->GetLockManager(), ns_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; uint64_t index = count >= 0 ? metadata.head : metadata.tail - 1; @@ -259,7 +259,7 @@ rocksdb::Status List::Insert(const Slice &user_key, const Slice &pivot, const Sl LockGuard guard(storage_->GetLockManager(), ns_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; std::string buf; @@ -334,14 +334,14 @@ rocksdb::Status List::Index(const Slice &user_key, int index, std::string *elem) std::string ns_key = AppendNamespacePrefix(user_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s; if (index < 0) index += static_cast(metadata.size); if (index < 0 || index >= static_cast(metadata.size)) return rocksdb::Status::NotFound(); rocksdb::ReadOptions read_options; - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); std::string buf; PutFixed64(&buf, metadata.head + index); @@ -359,7 +359,8 @@ rocksdb::Status List::Range(const Slice &user_key, int start, int stop, std::vec std::string ns_key = AppendNamespacePrefix(user_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; if (start < 0) start = static_cast(metadata.size) + start; @@ -374,7 +375,6 @@ rocksdb::Status List::Range(const Slice &user_key, int start, int stop, std::vec std::string next_version_prefix = InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode(); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); rocksdb::Slice upper_bound(next_version_prefix); read_options.iterate_upper_bound = &upper_bound; @@ -398,7 +398,7 @@ rocksdb::Status List::Pos(const Slice &user_key, const Slice &elem, const PosSpe std::string ns_key = AppendNamespacePrefix(user_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; // A negative rank means start from the tail. @@ -454,7 +454,7 @@ rocksdb::Status List::Set(const Slice &user_key, int index, Slice elem) { LockGuard guard(storage_->GetLockManager(), ns_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; if (index < 0) index += static_cast(metadata.size); if (index < 0 || index >= static_cast(metadata.size)) { @@ -490,7 +490,7 @@ rocksdb::Status List::lmoveOnSingleList(const rocksdb::Slice &src, bool src_left LockGuard guard(storage_->GetLockManager(), ns_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) { return s; } @@ -553,13 +553,13 @@ rocksdb::Status List::lmoveOnTwoLists(const rocksdb::Slice &src, const rocksdb:: std::vector lock_keys{src_ns_key, dst_ns_key}; MultiLockGuard guard(storage_->GetLockManager(), lock_keys); ListMetadata src_metadata(false); - auto s = GetMetadata(src_ns_key, &src_metadata); + auto s = GetMetadata(GetOptions{}, src_ns_key, &src_metadata); if (!s.ok()) { return s; } ListMetadata dst_metadata(false); - s = GetMetadata(dst_ns_key, &dst_metadata); + s = GetMetadata(GetOptions{}, dst_ns_key, &dst_metadata); if (!s.ok() && !s.IsNotFound()) { return s; } @@ -615,7 +615,7 @@ rocksdb::Status List::Trim(const Slice &user_key, int start, int stop) { LockGuard guard(storage_->GetLockManager(), ns_key); ListMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; if (start < 0) start += static_cast(metadata.size); diff --git a/src/types/redis_list.h b/src/types/redis_list.h index cf0effd8cc9..d5e6c7a1fae 100644 --- a/src/types/redis_list.h +++ b/src/types/redis_list.h @@ -56,7 +56,7 @@ class List : public Database { rocksdb::Status Pos(const Slice &user_key, const Slice &elem, const PosSpec &spec, std::vector *indexes); private: - rocksdb::Status GetMetadata(const Slice &ns_key, ListMetadata *metadata); + rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &ns_key, ListMetadata *metadata); rocksdb::Status push(const Slice &user_key, const std::vector &elems, bool create_if_missing, bool left, uint64_t *new_size); rocksdb::Status lmoveOnSingleList(const Slice &src, bool src_left, bool dst_left, std::string *elem); diff --git a/src/types/redis_set.cc b/src/types/redis_set.cc index 35403b88ac8..0c7e4009b28 100644 --- a/src/types/redis_set.cc +++ b/src/types/redis_set.cc @@ -29,8 +29,8 @@ namespace redis { -rocksdb::Status Set::GetMetadata(const Slice &ns_key, SetMetadata *metadata) { - return Database::GetMetadata({kRedisSet}, ns_key, metadata); +rocksdb::Status Set::GetMetadata(Database::GetOptions get_options, const Slice &ns_key, SetMetadata *metadata) { + return Database::GetMetadata(get_options, {kRedisSet}, ns_key, metadata); } // Make sure members are uniq before use Overwrite @@ -60,7 +60,7 @@ rocksdb::Status Set::Add(const Slice &user_key, const std::vector &member LockGuard guard(storage_->GetLockManager(), ns_key); SetMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(Database::GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) return s; std::string value; @@ -94,7 +94,7 @@ rocksdb::Status Set::Remove(const Slice &user_key, const std::vector &mem LockGuard guard(storage_->GetLockManager(), ns_key); SetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(Database::GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; std::string value; @@ -130,7 +130,7 @@ rocksdb::Status Set::Card(const Slice &user_key, uint64_t *size) { std::string ns_key = AppendNamespacePrefix(user_key); SetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(Database::GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; *size = metadata.size; return rocksdb::Status::OK(); @@ -142,14 +142,15 @@ rocksdb::Status Set::Members(const Slice &user_key, std::vector *me std::string ns_key = AppendNamespacePrefix(user_key); SetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + + rocksdb::Status s = GetMetadata(Database::GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; std::string prefix = InternalKey(ns_key, "", metadata.version, storage_->IsSlotIdEncoded()).Encode(); std::string next_version_prefix = InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode(); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); rocksdb::Slice upper_bound(next_version_prefix); read_options.iterate_upper_bound = &upper_bound; @@ -176,11 +177,12 @@ rocksdb::Status Set::MIsMember(const Slice &user_key, const std::vector & std::string ns_key = AppendNamespacePrefix(user_key); SetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(Database::GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s; rocksdb::ReadOptions read_options; - LatestSnapShot ss(storage_); + read_options.snapshot = ss.GetSnapShot(); std::string value; for (const auto &member : members) { @@ -212,7 +214,7 @@ rocksdb::Status Set::Take(const Slice &user_key, std::vector *membe if (pop) lock_guard.emplace(storage_->GetLockManager(), ns_key); SetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(Database::GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; ObserverOrUniquePtr batch = storage_->GetWriteBatchBase(); diff --git a/src/types/redis_set.h b/src/types/redis_set.h index 739c3077c35..bfd8bd3c7aa 100644 --- a/src/types/redis_set.h +++ b/src/types/redis_set.h @@ -52,7 +52,7 @@ class Set : public SubKeyScanner { const std::string &member_prefix, std::vector *members); private: - rocksdb::Status GetMetadata(const Slice &ns_key, SetMetadata *metadata); + rocksdb::Status GetMetadata(Database::GetOptions options, const Slice &ns_key, SetMetadata *metadata); }; } // namespace redis diff --git a/src/types/redis_sortedint.cc b/src/types/redis_sortedint.cc index 6670eb7474f..7de9e70a0e0 100644 --- a/src/types/redis_sortedint.cc +++ b/src/types/redis_sortedint.cc @@ -28,8 +28,9 @@ namespace redis { -rocksdb::Status Sortedint::GetMetadata(const Slice &ns_key, SortedintMetadata *metadata) { - return Database::GetMetadata({kRedisSortedint}, ns_key, metadata); +rocksdb::Status Sortedint::GetMetadata(Database::GetOptions get_options, const Slice &ns_key, + SortedintMetadata *metadata) { + return Database::GetMetadata(get_options, {kRedisSortedint}, ns_key, metadata); } rocksdb::Status Sortedint::Add(const Slice &user_key, const std::vector &ids, uint64_t *added_cnt) { @@ -39,7 +40,7 @@ rocksdb::Status Sortedint::Add(const Slice &user_key, const std::vectorGetLockManager(), ns_key); SortedintMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) return s; std::string value; @@ -72,7 +73,7 @@ rocksdb::Status Sortedint::Remove(const Slice &user_key, const std::vectorGetLockManager(), ns_key); SortedintMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; std::string value; @@ -101,7 +102,7 @@ rocksdb::Status Sortedint::Card(const Slice &user_key, uint64_t *size) { std::string ns_key = AppendNamespacePrefix(user_key); SortedintMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; *size = metadata.size; return rocksdb::Status::OK(); @@ -114,7 +115,8 @@ rocksdb::Status Sortedint::Range(const Slice &user_key, uint64_t cursor_id, uint std::string ns_key = AppendNamespacePrefix(user_key); SortedintMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; std::string start_buf; @@ -128,7 +130,6 @@ rocksdb::Status Sortedint::Range(const Slice &user_key, uint64_t cursor_id, uint std::string next_version_prefix = InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode(); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); rocksdb::Slice upper_bound(next_version_prefix); read_options.iterate_upper_bound = &upper_bound; @@ -157,7 +158,8 @@ rocksdb::Status Sortedint::RangeByValue(const Slice &user_key, SortedintRangeSpe std::string ns_key = AppendNamespacePrefix(user_key); SortedintMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; std::string start_buf; @@ -168,7 +170,6 @@ rocksdb::Status Sortedint::RangeByValue(const Slice &user_key, SortedintRangeSpe InternalKey(ns_key, "", metadata.version + 1, storage_->IsSlotIdEncoded()).Encode(); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); rocksdb::Slice upper_bound(next_version_prefix_key); read_options.iterate_upper_bound = &upper_bound; @@ -207,10 +208,10 @@ rocksdb::Status Sortedint::MExist(const Slice &user_key, const std::vectorGetLockManager(), ns_key); StreamMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) return s; if (s.IsNotFound() && options.nomkstream) { @@ -236,8 +237,8 @@ std::string Stream::consumerNameFromInternalKey(rocksdb::Slice key) const { std::string Stream::encodeStreamConsumerMetadataValue(const StreamConsumerMetadata &consumer_metadata) { std::string dst; PutFixed64(&dst, consumer_metadata.pending_number); - PutFixed64(&dst, consumer_metadata.last_idle); - PutFixed64(&dst, consumer_metadata.last_active); + PutFixed64(&dst, consumer_metadata.last_attempted_interaction_ms); + PutFixed64(&dst, consumer_metadata.last_successful_interaction_ms); return dst; } @@ -245,8 +246,8 @@ StreamConsumerMetadata Stream::decodeStreamConsumerMetadataValue(const std::stri StreamConsumerMetadata consumer_metadata; rocksdb::Slice input(value); GetFixed64(&input, &consumer_metadata.pending_number); - GetFixed64(&input, &consumer_metadata.last_idle); - GetFixed64(&input, &consumer_metadata.last_active); + GetFixed64(&input, &consumer_metadata.last_attempted_interaction_ms); + GetFixed64(&input, 
&consumer_metadata.last_successful_interaction_ms); return consumer_metadata; } @@ -276,7 +277,7 @@ StreamEntryID Stream::groupAndEntryIdFromPelInternalKey(rocksdb::Slice key, std: std::string Stream::encodeStreamPelEntryValue(const StreamPelEntry &pel_entry) { std::string dst; - PutFixed64(&dst, pel_entry.last_delivery_time); + PutFixed64(&dst, pel_entry.last_delivery_time_ms); PutFixed64(&dst, pel_entry.last_delivery_count); PutFixed64(&dst, pel_entry.consumer_name.size()); dst += pel_entry.consumer_name; @@ -286,7 +287,7 @@ std::string Stream::encodeStreamPelEntryValue(const StreamPelEntry &pel_entry) { StreamPelEntry Stream::decodeStreamPelEntryValue(const std::string &value) { StreamPelEntry pel_entry; rocksdb::Slice input(value); - GetFixed64(&input, &pel_entry.last_delivery_time); + GetFixed64(&input, &pel_entry.last_delivery_time_ms); GetFixed64(&input, &pel_entry.last_delivery_count); uint64_t consumer_name_len = 0; GetFixed64(&input, &consumer_name_len); @@ -314,6 +315,362 @@ StreamSubkeyType Stream::identifySubkeyType(const rocksdb::Slice &key) const { return StreamSubkeyType::StreamConsumerMetadata; } +rocksdb::Status Stream::DeletePelEntries(const Slice &stream_name, const std::string &group_name, + const std::vector &entry_ids, uint64_t *acknowledged) { + *acknowledged = 0; + + std::string ns_key = AppendNamespacePrefix(stream_name); + + LockGuard guard(storage_->GetLockManager(), ns_key); + StreamMetadata metadata(false); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); + if (!s.ok()) { + return s.IsNotFound() ? rocksdb::Status::OK() : s; + } + + std::string group_key = internalKeyFromGroupName(ns_key, metadata, group_name); + std::string get_group_value; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, group_key, &get_group_value); + if (!s.ok()) { + return s.IsNotFound() ? 
rocksdb::Status::OK() : s; + } + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisStream); + batch->PutLogData(log_data.Encode()); + + std::map consumer_acknowledges; + for (const auto &id : entry_ids) { + std::string entry_key = internalPelKeyFromGroupAndEntryId(ns_key, metadata, group_name, id); + std::string value; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, entry_key, &value); + if (!s.ok() && !s.IsNotFound()) { + return s; + } + if (s.ok()) { + *acknowledged += 1; + batch->Delete(stream_cf_handle_, entry_key); + + // increment ack for each related consumer + auto pel_entry = decodeStreamPelEntryValue(value); + consumer_acknowledges[pel_entry.consumer_name]++; + } + } + if (*acknowledged > 0) { + StreamConsumerGroupMetadata group_metadata = decodeStreamConsumerGroupMetadataValue(get_group_value); + group_metadata.pending_number -= *acknowledged; + std::string group_value = encodeStreamConsumerGroupMetadataValue(group_metadata); + batch->Put(stream_cf_handle_, group_key, group_value); + + for (const auto &[consumer_name, ack_count] : consumer_acknowledges) { + auto consumer_meta_key = internalKeyFromConsumerName(ns_key, metadata, group_name, consumer_name); + std::string consumer_meta_original; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, consumer_meta_key, &consumer_meta_original); + if (!s.ok() && !s.IsNotFound()) { + return s; + } + if (s.ok()) { + auto consumer_metadata = decodeStreamConsumerMetadataValue(consumer_meta_original); + consumer_metadata.pending_number -= ack_count; + batch->Put(stream_cf_handle_, consumer_meta_key, encodeStreamConsumerMetadataValue(consumer_metadata)); + } + } + } + return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + +rocksdb::Status Stream::ClaimPelEntries(const Slice &stream_name, const std::string &group_name, + const std::string &consumer_name, const uint64_t min_idle_time_ms, + const std::vector &entry_ids, const StreamClaimOptions &options, + StreamClaimResult *result) { + std::string ns_key = AppendNamespacePrefix(stream_name); + LockGuard guard(storage_->GetLockManager(), ns_key); + StreamMetadata metadata(false); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); + if (!s.ok()) return s; + + std::string group_key = internalKeyFromGroupName(ns_key, metadata, group_name); + std::string get_group_value; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, group_key, &get_group_value); + if (!s.ok() && !s.IsNotFound()) { + return s; + } + if (s.IsNotFound()) { + return rocksdb::Status::InvalidArgument("NOGROUP No such consumer group " + group_name + " for key name " + + stream_name.ToString()); + } + StreamConsumerGroupMetadata group_metadata = decodeStreamConsumerGroupMetadataValue(get_group_value); + + std::string consumer_key = internalKeyFromConsumerName(ns_key, metadata, group_name, consumer_name); + std::string get_consumer_value; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, consumer_key, &get_consumer_value); + if (!s.ok() && !s.IsNotFound()) { + return s; + } + if (s.IsNotFound()) { + int created_number = 0; + s = createConsumerWithoutLock(stream_name, group_name, consumer_name, &created_number); + if (!s.ok()) { + return s; + } + group_metadata.consumer_number += created_number; + } + StreamConsumerMetadata consumer_metadata; + if (!s.IsNotFound()) { + consumer_metadata = decodeStreamConsumerMetadataValue(get_consumer_value); + } + auto now = util::GetTimeStampMS(); + 
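
DeletePelEntries above is the storage half of XACK: every acknowledged id is deleted from the PEL, and the acknowledgement is charged against both the group-level pending counter and the counter of the consumer that owned the entry. A condensed sketch of that bookkeeping, using simplified stand-ins for the kvrocks structs:

#include <cstdint>
#include <map>
#include <string>
#include <vector>

struct PendingEntry {
  std::string consumer;  // owner recorded in the PEL value
};

// Returns how many entries were acknowledged and applies the counter updates.
uint64_t AckEntries(const std::vector<PendingEntry> &acked, uint64_t *group_pending,
                    std::map<std::string, uint64_t> *consumer_pending) {
  std::map<std::string, uint64_t> per_consumer;
  for (const auto &e : acked) per_consumer[e.consumer]++;  // group acks by owner
  *group_pending -= acked.size();
  for (const auto &[name, cnt] : per_consumer) (*consumer_pending)[name] -= cnt;
  return acked.size();
}

Collecting the per-consumer counts in a map first means each consumer's metadata value is rewritten once per batch rather than once per acknowledged id, which is why the code above builds consumer_acknowledges before touching the consumer keys.
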
consumer_metadata.last_attempted_interaction_ms = now; + consumer_metadata.last_successful_interaction_ms = now; + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisStream); + batch->PutLogData(log_data.Encode()); + + for (const auto &id : entry_ids) { + std::string raw_value; + rocksdb::Status s = getEntryRawValue(ns_key, metadata, id, &raw_value); + if (!s.ok() && !s.IsNotFound()) { + return s; + } + if (s.IsNotFound()) continue; + + std::string entry_key = internalPelKeyFromGroupAndEntryId(ns_key, metadata, group_name, id); + std::string value; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, entry_key, &value); + StreamPelEntry pel_entry; + + if (!s.ok() && s.IsNotFound() && options.force) { + pel_entry = {0, 0, ""}; + group_metadata.pending_number += 1; + } + + if (s.ok()) { + pel_entry = decodeStreamPelEntryValue(value); + } + + if (s.ok() || (s.IsNotFound() && options.force)) { + if (now - pel_entry.last_delivery_time_ms < min_idle_time_ms) continue; + + std::vector values; + if (options.just_id) { + result->ids.emplace_back(id.ToString()); + } else { + auto rv = DecodeRawStreamEntryValue(raw_value, &values); + if (!rv.IsOK()) { + return rocksdb::Status::InvalidArgument(rv.Msg()); + } + result->entries.emplace_back(id.ToString(), std::move(values)); + } + + if (pel_entry.consumer_name != "") { + std::string original_consumer_key = + internalKeyFromConsumerName(ns_key, metadata, group_name, pel_entry.consumer_name); + std::string get_original_consumer_value; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, original_consumer_key, + &get_original_consumer_value); + if (!s.ok()) { + return s; + } + StreamConsumerMetadata original_consumer_metadata = + decodeStreamConsumerMetadataValue(get_original_consumer_value); + original_consumer_metadata.pending_number -= 1; + batch->Put(stream_cf_handle_, original_consumer_key, + encodeStreamConsumerMetadataValue(original_consumer_metadata)); + } + + pel_entry.consumer_name = consumer_name; + consumer_metadata.pending_number += 1; + if (options.with_time) { + pel_entry.last_delivery_time_ms = options.last_delivery_time_ms; + } else { + pel_entry.last_delivery_time_ms = now - options.idle_time_ms; + } + + if (pel_entry.last_delivery_time_ms < 0 || pel_entry.last_delivery_time_ms > now) { + pel_entry.last_delivery_time_ms = now; + } + + if (options.with_retry_count) { + pel_entry.last_delivery_count = options.last_delivery_count; + } else if (!options.just_id) { + pel_entry.last_delivery_count += 1; + } + + std::string pel_value = encodeStreamPelEntryValue(pel_entry); + batch->Put(stream_cf_handle_, entry_key, pel_value); + } + } + + if (options.with_last_id && options.last_delivered_id > group_metadata.last_delivered_id) { + group_metadata.last_delivered_id = options.last_delivered_id; + } + + batch->Put(stream_cf_handle_, consumer_key, encodeStreamConsumerMetadataValue(consumer_metadata)); + batch->Put(stream_cf_handle_, group_key, encodeStreamConsumerGroupMetadataValue(group_metadata)); + return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + +rocksdb::Status Stream::AutoClaim(const Slice &stream_name, const std::string &group_name, + const std::string &consumer_name, const StreamAutoClaimOptions &options, + StreamAutoClaimResult *result) { + if (options.exclude_start && options.start_id.IsMaximum()) { + return rocksdb::Status::InvalidArgument("invalid start ID for the interval"); + } + + std::string ns_key = AppendNamespacePrefix(stream_name); + 
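
ClaimPelEntries above implements the XCLAIM rules for the delivery timestamp: TIME sets it directly, IDLE derives it as now minus the requested idle time, and any result outside [0, now] is clamped back to now. The same rule restated as a small function (the parameter names are illustrative, not the kvrocks options struct):

#include <cstdint>

uint64_t NewDeliveryTimeMS(bool with_time, uint64_t time_ms, uint64_t idle_ms, uint64_t now_ms) {
  // An IDLE larger than "now" would underflow the unsigned subtraction, so it
  // degenerates to "just delivered", matching the clamp in the code above.
  uint64_t t = with_time ? time_ms : (idle_ms <= now_ms ? now_ms - idle_ms : now_ms);
  if (t > now_ms) t = now_ms;  // a delivery time in the future is meaningless
  return t;
}
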
StreamMetadata metadata(false); + + LockGuard guard(storage_->GetLockManager(), ns_key); + auto s = GetMetadata(GetOptions{}, ns_key, &metadata); + if (!s.ok()) { // not found will be caught by outside with no such key or consumer group + return s; + } + + std::string consumer_key = internalKeyFromConsumerName(ns_key, metadata, group_name, consumer_name); + std::string get_consumer_value; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, consumer_key, &get_consumer_value); + if (!s.ok() && !s.IsNotFound()) { + return s; + } + if (s.IsNotFound()) { + int created_number = 0; + s = createConsumerWithoutLock(stream_name, group_name, consumer_name, &created_number); + if (!s.ok()) { + return s; + } + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, consumer_key, &get_consumer_value); + if (!s.ok()) { + return s; + } + } + + StreamConsumerMetadata current_consumer_metadata = decodeStreamConsumerMetadataValue(get_consumer_value); + std::map claimed_consumer_entity_count; + std::string prefix_key = internalPelKeyFromGroupAndEntryId(ns_key, metadata, group_name, options.start_id); + std::string end_key = internalPelKeyFromGroupAndEntryId(ns_key, metadata, group_name, StreamEntryID::Maximum()); + + LatestSnapShot ss{storage_}; + rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); + read_options.snapshot = ss.GetSnapShot(); + rocksdb::Slice lower_bound(prefix_key); + rocksdb::Slice upper_bound(end_key); + read_options.iterate_lower_bound = &lower_bound; + read_options.iterate_upper_bound = &upper_bound; + + auto count = options.count; + uint64_t attempts = options.attempts_factors * count; + auto now_ms = util::GetTimeStampMS(); + std::vector deleted_entries; + std::vector pending_entries; + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisStream); + batch->PutLogData(log_data.Encode()); + + auto iter = util::UniqueIterator(storage_, read_options, stream_cf_handle_); + uint64_t total_claimed_count = 0; + for (iter->SeekToFirst(); iter->Valid() && count > 0 && attempts > 0; iter->Next()) { + if (identifySubkeyType(iter->key()) == StreamSubkeyType::StreamPelEntry) { + std::string tmp_group_name; + StreamEntryID entry_id = groupAndEntryIdFromPelInternalKey(iter->key(), tmp_group_name); + if (tmp_group_name != group_name) { + continue; + } + + if (options.exclude_start && entry_id == options.start_id) { + continue; + } + + attempts--; + + StreamPelEntry penl_entry = decodeStreamPelEntryValue(iter->value().ToString()); + if ((now_ms - penl_entry.last_delivery_time_ms) < options.min_idle_time_ms) { + continue; + } + + auto entry_key = internalKeyFromEntryID(ns_key, metadata, entry_id); + std::string entry_value; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, entry_key, &entry_value); + if (!s.ok()) { + if (s.IsNotFound()) { + deleted_entries.push_back(entry_id); + batch->Delete(stream_cf_handle_, iter->key()); + --count; + continue; + } + return s; + } + + StreamEntry entry(entry_id.ToString(), {}); + if (!options.just_id) { + auto rv_status = DecodeRawStreamEntryValue(entry_value, &entry.values); + if (!rv_status.OK()) { + return rocksdb::Status::InvalidArgument(rv_status.Msg()); + } + } + + pending_entries.emplace_back(std::move(entry)); + --count; + + if (penl_entry.consumer_name != consumer_name) { + ++total_claimed_count; + claimed_consumer_entity_count[penl_entry.consumer_name] += 1; + penl_entry.consumer_name = consumer_name; + penl_entry.last_delivery_time_ms = now_ms; + // Increment the delivery attempts 
counter unless JUSTID option provided + if (!options.just_id) { + penl_entry.last_delivery_count += 1; + } + batch->Put(stream_cf_handle_, iter->key(), encodeStreamPelEntryValue(penl_entry)); + } + } + } + + if (total_claimed_count > 0 && !pending_entries.empty()) { + current_consumer_metadata.pending_number += total_claimed_count; + current_consumer_metadata.last_attempted_interaction_ms = now_ms; + + batch->Put(stream_cf_handle_, consumer_key, encodeStreamConsumerMetadataValue(current_consumer_metadata)); + + for (const auto &[consumer, count] : claimed_consumer_entity_count) { + std::string tmp_consumer_key = internalKeyFromConsumerName(ns_key, metadata, group_name, consumer); + std::string tmp_consumer_value; + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, tmp_consumer_key, &tmp_consumer_value); + if (!s.ok()) { + return s; + } + StreamConsumerMetadata tmp_consumer_metadata = decodeStreamConsumerMetadataValue(tmp_consumer_value); + tmp_consumer_metadata.pending_number -= count; + batch->Put(stream_cf_handle_, tmp_consumer_key, encodeStreamConsumerMetadataValue(tmp_consumer_metadata)); + } + } + + bool has_next_entry = false; + for (; iter->Valid(); iter->Next()) { + if (identifySubkeyType(iter->key()) == StreamSubkeyType::StreamPelEntry) { + has_next_entry = true; + break; + } + } + + if (has_next_entry) { + std::string tmp_group_name; + StreamEntryID entry_id = groupAndEntryIdFromPelInternalKey(iter->key(), tmp_group_name); + result->next_claim_id = entry_id.ToString(); + } else { + result->next_claim_id = StreamEntryID::Minimum().ToString(); + } + + result->entries = std::move(pending_entries); + result->deleted_ids.clear(); + result->deleted_ids.reserve(deleted_entries.size()); + std::transform(deleted_entries.cbegin(), deleted_entries.cend(), std::back_inserter(result->deleted_ids), + [](const StreamEntryID &id) { return id.ToString(); }); + + return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + rocksdb::Status Stream::CreateGroup(const Slice &stream_name, const StreamXGroupCreateOptions &options, const std::string &group_name) { if (std::isdigit(group_name[0])) { @@ -323,7 +680,7 @@ rocksdb::Status Stream::CreateGroup(const Slice &stream_name, const StreamXGroup LockGuard guard(storage_->GetLockManager(), ns_key); StreamMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) { return s; } @@ -372,7 +729,7 @@ rocksdb::Status Stream::DestroyGroup(const Slice &stream_name, const std::string LockGuard guard(storage_->GetLockManager(), ns_key); StreamMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) { return s; } @@ -423,7 +780,7 @@ rocksdb::Status Stream::createConsumerWithoutLock(const Slice &stream_name, cons } std::string ns_key = AppendNamespacePrefix(stream_name); StreamMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) { return s; } @@ -444,8 +801,8 @@ rocksdb::Status Stream::createConsumerWithoutLock(const Slice &stream_name, cons StreamConsumerMetadata consumer_metadata; auto now = util::GetTimeStampMS(); - consumer_metadata.last_idle = now; - consumer_metadata.last_active = now; + consumer_metadata.last_attempted_interaction_ms = now; + 
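
The tail of AutoClaim above computes the cursor that XAUTOCLAIM hands back to the client: the id of the first pending entry the scan did not visit, or the minimum id ("0-0") once the PEL has been exhausted, so callers can keep calling until the cursor wraps around. The rule in isolation, with a simplified stand-in for StreamEntryID:

#include <cstdint>
#include <optional>
#include <string>

struct EntryID {
  uint64_t ms = 0, seq = 0;
  std::string ToString() const { return std::to_string(ms) + "-" + std::to_string(seq); }
};

std::string NextClaimCursor(const std::optional<EntryID> &first_unvisited) {
  if (first_unvisited) return first_unvisited->ToString();
  return EntryID{}.ToString();  // "0-0", i.e. StreamEntryID::Minimum()
}
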
consumer_metadata.last_successful_interaction_ms = now; std::string consumer_key = internalKeyFromConsumerName(ns_key, metadata, group_name, consumer_name); std::string consumer_value = encodeStreamConsumerMetadataValue(consumer_metadata); std::string get_consumer_value; @@ -480,7 +837,7 @@ rocksdb::Status Stream::DestroyConsumer(const Slice &stream_name, const std::str std::string ns_key = AppendNamespacePrefix(stream_name); LockGuard guard(storage_->GetLockManager(), ns_key); StreamMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) { return s; } @@ -551,7 +908,7 @@ rocksdb::Status Stream::GroupSetId(const Slice &stream_name, const std::string & std::string ns_key = AppendNamespacePrefix(stream_name); LockGuard guard(storage_->GetLockManager(), ns_key); StreamMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) { return s; } @@ -597,7 +954,7 @@ rocksdb::Status Stream::DeleteEntries(const Slice &stream_name, const std::vecto LockGuard guard(storage_->GetLockManager(), ns_key); StreamMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) { return s.IsNotFound() ? rocksdb::Status::OK() : s; } @@ -686,7 +1043,7 @@ rocksdb::Status Stream::Len(const Slice &stream_name, const StreamLenOptions &op std::string ns_key = AppendNamespacePrefix(stream_name); StreamMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) { return s.IsNotFound() ? rocksdb::Status::OK() : s; } @@ -835,7 +1192,7 @@ rocksdb::Status Stream::GetStreamInfo(const rocksdb::Slice &stream_name, bool fu LockGuard guard(storage_->GetLockManager(), ns_key); StreamMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; info->size = metadata.size; @@ -959,7 +1316,7 @@ rocksdb::Status Stream::GetGroupInfo(const Slice &stream_name, std::vector> &group_metadata) { std::string ns_key = AppendNamespacePrefix(stream_name); StreamMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; std::string next_version_prefix_key = @@ -992,7 +1349,7 @@ rocksdb::Status Stream::GetConsumerInfo( std::vector> &consumer_metadata) { std::string ns_key = AppendNamespacePrefix(stream_name); StreamMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; std::string next_version_prefix_key = @@ -1040,7 +1397,7 @@ rocksdb::Status Stream::Range(const Slice &stream_name, const StreamRangeOptions std::string ns_key = AppendNamespacePrefix(stream_name); StreamMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) { return s.IsNotFound() ? 
rocksdb::Status::OK() : s; } @@ -1069,7 +1426,7 @@ rocksdb::Status Stream::RangeWithPending(const Slice &stream_name, StreamRangeOp LockGuard guard(storage_->GetLockManager(), ns_key); StreamMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) { return s.IsNotFound() ? rocksdb::Status::OK() : s; } @@ -1097,6 +1454,7 @@ rocksdb::Status Stream::RangeWithPending(const Slice &stream_name, StreamRangeOp if (!s.ok()) { return s; } + s = storage_->Get(rocksdb::ReadOptions(), stream_cf_handle_, group_key, &get_group_value); } auto batch = storage_->GetWriteBatchBase(); @@ -1109,9 +1467,9 @@ rocksdb::Status Stream::RangeWithPending(const Slice &stream_name, StreamRangeOp return s; } StreamConsumerMetadata consumer_metadata = decodeStreamConsumerMetadataValue(get_consumer_value); - auto now = util::GetTimeStampMS(); - consumer_metadata.last_idle = now; - consumer_metadata.last_active = now; + auto now_ms = util::GetTimeStampMS(); + consumer_metadata.last_attempted_interaction_ms = now_ms; + consumer_metadata.last_successful_interaction_ms = now_ms; if (latest) { options.start = consumergroup_metadata.last_delivered_id; @@ -1175,7 +1533,7 @@ rocksdb::Status Stream::RangeWithPending(const Slice &stream_name, StreamRangeOp } entries->emplace_back(entry_id.ToString(), std::move(values)); pel_entry.last_delivery_count += 1; - pel_entry.last_delivery_time = now; + pel_entry.last_delivery_time_ms = now_ms; batch->Put(stream_cf_handle_, iter->key(), encodeStreamPelEntryValue(pel_entry)); ++count; if (count >= options.count) break; @@ -1199,7 +1557,7 @@ rocksdb::Status Stream::Trim(const Slice &stream_name, const StreamTrimOptions & LockGuard guard(storage_->GetLockManager(), ns_key); StreamMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) { return s.IsNotFound() ? 
rocksdb::Status::OK() : s; } @@ -1304,7 +1662,7 @@ rocksdb::Status Stream::SetId(const Slice &stream_name, const StreamEntryID &las LockGuard guard(storage_->GetLockManager(), ns_key); StreamMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) { return s; } diff --git a/src/types/redis_stream.h b/src/types/redis_stream.h index 8fa7bb706cf..510cbb66058 100644 --- a/src/types/redis_stream.h +++ b/src/types/redis_stream.h @@ -36,7 +36,7 @@ namespace redis { class Stream : public SubKeyScanner { public: explicit Stream(engine::Storage *storage, const std::string &ns) - : SubKeyScanner(storage, ns), stream_cf_handle_(storage->GetCFHandle("stream")) {} + : SubKeyScanner(storage, ns), stream_cf_handle_(storage->GetCFHandle(ColumnFamilyID::Stream)) {} rocksdb::Status Add(const Slice &stream_name, const StreamAddOptions &options, const std::vector &values, StreamEntryID *id); rocksdb::Status CreateGroup(const Slice &stream_name, const StreamXGroupCreateOptions &options, @@ -49,6 +49,14 @@ class Stream : public SubKeyScanner { rocksdb::Status GroupSetId(const Slice &stream_name, const std::string &group_name, const StreamXGroupCreateOptions &options); rocksdb::Status DeleteEntries(const Slice &stream_name, const std::vector &ids, uint64_t *deleted_cnt); + rocksdb::Status DeletePelEntries(const Slice &stream_name, const std::string &group_name, + const std::vector &entry_ids, uint64_t *acknowledged); + rocksdb::Status ClaimPelEntries(const Slice &stream_name, const std::string &group_name, + const std::string &consumer_name, uint64_t min_idle_time_ms, + const std::vector &entry_ids, const StreamClaimOptions &options, + StreamClaimResult *result); + rocksdb::Status AutoClaim(const Slice &stream_name, const std::string &group_name, const std::string &consumer_name, + const StreamAutoClaimOptions &options, StreamAutoClaimResult *result); rocksdb::Status Len(const Slice &stream_name, const StreamLenOptions &options, uint64_t *size); rocksdb::Status GetStreamInfo(const Slice &stream_name, bool full, uint64_t count, StreamInfo *info); rocksdb::Status GetGroupInfo(const Slice &stream_name, @@ -60,7 +68,7 @@ class Stream : public SubKeyScanner { std::vector *entries, std::string &group_name, std::string &consumer_name, bool noack, bool latest); rocksdb::Status Trim(const Slice &stream_name, const StreamTrimOptions &options, uint64_t *delete_cnt); - rocksdb::Status GetMetadata(const Slice &stream_name, StreamMetadata *metadata); + rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &stream_name, StreamMetadata *metadata); rocksdb::Status GetLastGeneratedID(const Slice &stream_name, StreamEntryID *id); rocksdb::Status SetId(const Slice &stream_name, const StreamEntryID &last_generated_id, std::optional entries_added, std::optional max_deleted_id); diff --git a/src/types/redis_stream_base.h b/src/types/redis_stream_base.h index 60d54d231a2..82c8f945b39 100644 --- a/src/types/redis_stream_base.h +++ b/src/types/redis_stream_base.h @@ -161,6 +161,27 @@ struct StreamXGroupCreateOptions { std::string last_id; }; +struct StreamClaimOptions { + uint64_t idle_time_ms = 0; + bool with_time = false; + bool with_retry_count = false; + bool force = false; + bool just_id = false; + bool with_last_id = false; + uint64_t last_delivery_time_ms; + uint64_t last_delivery_count; + StreamEntryID last_delivered_id; +}; + +struct StreamAutoClaimOptions { + uint64_t min_idle_time_ms; + 
uint64_t count = 100; + uint64_t attempts_factors = 10; + StreamEntryID start_id; + bool just_id = false; + bool exclude_start = false; +}; + struct StreamConsumerGroupMetadata { uint64_t consumer_number = 0; uint64_t pending_number = 0; @@ -171,8 +192,8 @@ struct StreamConsumerGroupMetadata { struct StreamConsumerMetadata { uint64_t pending_number = 0; - uint64_t last_idle; - uint64_t last_active; + uint64_t last_attempted_interaction_ms; + uint64_t last_successful_interaction_ms; }; enum class StreamSubkeyType { @@ -183,7 +204,7 @@ enum class StreamSubkeyType { }; struct StreamPelEntry { - uint64_t last_delivery_time; + uint64_t last_delivery_time_ms; uint64_t last_delivery_count; std::string consumer_name; }; @@ -207,6 +228,17 @@ struct StreamReadResult { : name(std::move(name)), entries(std::move(result)) {} }; +struct StreamClaimResult { + std::vector<std::string> ids; + std::vector<StreamEntry> entries; +}; + +struct StreamAutoClaimResult { + std::string next_claim_id; + std::vector<StreamEntry> entries; + std::vector<std::string> deleted_ids; +}; + Status IncrementStreamEntryID(StreamEntryID *id); Status ParseStreamEntryID(const std::string &input, StreamEntryID *id); StatusOr> ParseNextStreamEntryIDStrategy(const std::string &input); diff --git a/src/types/redis_string.cc b/src/types/redis_string.cc index 38df40b5bd0..b934f3b8d5d 100644 --- a/src/types/redis_string.cc +++ b/src/types/redis_string.cc @@ -62,7 +62,7 @@ std::vector<rocksdb::Status> String::getRawValues(const std::vector<Slice> &keys rocksdb::Status String::getRawValue(const std::string &ns_key, std::string *raw_value) { raw_value->clear(); - auto s = GetRawMetadata(ns_key, raw_value); + auto s = GetRawMetadata(GetOptions{}, ns_key, raw_value); if (!s.ok()) return s; Metadata metadata(kRedisNone, false); @@ -148,12 +148,7 @@ rocksdb::Status String::Get(const std::string &user_key, std::string *value) { return getValue(ns_key, value); } -rocksdb::Status String::GetEx(const std::string &user_key, std::string *value, uint64_t ttl, bool persist) { - uint64_t expire = 0; - if (ttl > 0) { - uint64_t now = util::GetTimeStampMS(); - expire = now + ttl; - } +rocksdb::Status String::GetEx(const std::string &user_key, std::string *value, std::optional<uint64_t> expire) { std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); @@ -162,8 +157,8 @@ rocksdb::Status String::GetEx(const std::string &user_key, std::string *value, u std::string raw_data; Metadata metadata(kRedisString, false); - if (ttl > 0 || persist) { - metadata.expire = expire; + if (expire.has_value()) { + metadata.expire = expire.value(); } else { // If there is no ttl or persist is false, then skip the following updates.
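    // (Hedged note on the new contract, inferred from this patch rather than stated in it:
    // callers are now expected to resolve the GETEX options up front -- hypothetically,
    // EX/PX become expire = util::GetTimeStampMS() + ttl_ms, PERSIST becomes expire = 0 to
    // clear the TTL, and no option at all becomes std::nullopt, which turns this call into
    // a read-only path.)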
return rocksdb::Status::OK(); @@ -181,7 +176,7 @@ rocksdb::Status String::GetEx(const std::string &user_key, std::string *value, u rocksdb::Status String::GetSet(const std::string &user_key, const std::string &new_value, std::optional<std::string> &old_value) { - auto s = Set(user_key, new_value, {/*ttl=*/0, StringSetType::NONE, /*get=*/true, /*keep_ttl=*/false}, old_value); + auto s = Set(user_key, new_value, {/*expire=*/0, StringSetType::NONE, /*get=*/true, /*keep_ttl=*/false}, old_value); return s; } rocksdb::Status String::GetDel(const std::string &user_key, std::string *value) { @@ -196,7 +191,7 @@ rocksdb::Status String::GetDel(const std::string &user_key, std::string *value) rocksdb::Status String::Set(const std::string &user_key, const std::string &value) { std::vector<StringPair> pairs{StringPair{user_key, value}}; - return MSet(pairs, /*ttl=*/0, /*lock=*/true); + return MSet(pairs, /*expire=*/0, /*lock=*/true); } rocksdb::Status String::Set(const std::string &user_key, const std::string &value, StringSetArgs args, @@ -247,9 +242,8 @@ rocksdb::Status String::Set(const std::string &user_key, const std::string &valu } // Handle expire time - if (args.ttl > 0) { - uint64_t now = util::GetTimeStampMS(); - expire = now + args.ttl; + if (!args.keep_ttl) { + expire = args.expire; } // Create new value @@ -261,21 +255,21 @@ rocksdb::Status String::Set(const std::string &user_key, const std::string &valu return updateRawValue(ns_key, new_raw_value); } -rocksdb::Status String::SetEX(const std::string &user_key, const std::string &value, uint64_t ttl) { +rocksdb::Status String::SetEX(const std::string &user_key, const std::string &value, uint64_t expire_ms) { std::optional<std::string> ret; - return Set(user_key, value, {ttl, StringSetType::NONE, /*get=*/false, /*keep_ttl=*/false}, ret); + return Set(user_key, value, {expire_ms, StringSetType::NONE, /*get=*/false, /*keep_ttl=*/false}, ret); } -rocksdb::Status String::SetNX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag) { +rocksdb::Status String::SetNX(const std::string &user_key, const std::string &value, uint64_t expire_ms, bool *flag) { std::optional<std::string> ret; - auto s = Set(user_key, value, {ttl, StringSetType::NX, /*get=*/false, /*keep_ttl=*/false}, ret); + auto s = Set(user_key, value, {expire_ms, StringSetType::NX, /*get=*/false, /*keep_ttl=*/false}, ret); *flag = ret.has_value(); return s; } -rocksdb::Status String::SetXX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag) { +rocksdb::Status String::SetXX(const std::string &user_key, const std::string &value, uint64_t expire_ms, bool *flag) { std::optional<std::string> ret; - auto s = Set(user_key, value, {ttl, StringSetType::XX, /*get=*/false, /*keep_ttl=*/false}, ret); + auto s = Set(user_key, value, {expire_ms, StringSetType::XX, /*get=*/false, /*keep_ttl=*/false}, ret); *flag = ret.has_value(); return s; } @@ -390,13 +384,7 @@ rocksdb::Status String::IncrByFloat(const std::string &user_key, double incremen return updateRawValue(ns_key, raw_value); } -rocksdb::Status String::MSet(const std::vector<StringPair> &pairs, uint64_t ttl, bool lock) { - uint64_t expire = 0; - if (ttl > 0) { - uint64_t now = util::GetTimeStampMS(); - expire = now + ttl; - } - +rocksdb::Status String::MSet(const std::vector<StringPair> &pairs, uint64_t expire_ms, bool lock) { // Data race, key string maybe overwrite by other key while didn't lock the keys here, // to improve the set performance std::optional<MultiLockGuard> guard; @@ -416,7 +404,7 @@ rocksdb::Status String::MSet(const std::vector<StringPair> &pairs, uint64_t ttl, for (const auto &pair
: pairs) { std::string bytes; Metadata metadata(kRedisString, false); - metadata.expire = expire; + metadata.expire = expire_ms; metadata.Encode(&bytes); bytes.append(pair.value.data(), pair.value.size()); std::string ns_key = AppendNamespacePrefix(pair.key); @@ -425,7 +413,7 @@ rocksdb::Status String::MSet(const std::vector<StringPair> &pairs, uint64_t ttl, return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); } -rocksdb::Status String::MSetNX(const std::vector<StringPair> &pairs, uint64_t ttl, bool *flag) { +rocksdb::Status String::MSetNX(const std::vector<StringPair> &pairs, uint64_t expire_ms, bool *flag) { *flag = false; int exists = 0; @@ -447,7 +435,7 @@ rocksdb::Status String::MSetNX(const std::vector<StringPair> &pairs, uint64_t tt return rocksdb::Status::OK(); } - rocksdb::Status s = MSet(pairs, /*ttl=*/ttl, /*lock=*/false); + rocksdb::Status s = MSet(pairs, /*expire_ms=*/expire_ms, /*lock=*/false); if (!s.ok()) return s; *flag = true; @@ -460,7 +448,7 @@ rocksdb::Status String::MSetNX(const std::vector<StringPair> &pairs, uint64_t tt // -1 if the user_key does not exist // 0 if the operation fails rocksdb::Status String::CAS(const std::string &user_key, const std::string &old_value, const std::string &new_value, - uint64_t ttl, int *flag) { + uint64_t expire, int *flag) { *flag = 0; std::string current_value; @@ -480,12 +468,7 @@ rocksdb::Status String::CAS(const std::string &user_key, const std::string &old_ if (old_value == current_value) { std::string raw_value; - uint64_t expire = 0; Metadata metadata(kRedisString, false); - if (ttl > 0) { - uint64_t now = util::GetTimeStampMS(); - expire = now + ttl; - } metadata.expire = expire; metadata.Encode(&raw_value); raw_value.append(new_value); @@ -520,8 +503,8 @@ rocksdb::Status String::CAD(const std::string &user_key, const std::string &valu } if (value == current_value) { - auto delete_status = storage_->Delete(storage_->DefaultWriteOptions(), - storage_->GetCFHandle(engine::kMetadataColumnFamilyName), ns_key); + auto delete_status = + storage_->Delete(storage_->DefaultWriteOptions(), storage_->GetCFHandle(ColumnFamilyID::Metadata), ns_key); if (!delete_status.ok()) { return delete_status; } diff --git a/src/types/redis_string.h b/src/types/redis_string.h index 166acc63a5d..34afb0bd95b 100644 --- a/src/types/redis_string.h +++ b/src/types/redis_string.h @@ -37,7 +37,8 @@ struct StringPair { enum class StringSetType { NONE, NX, XX }; struct StringSetArgs { - uint64_t ttl; + // Expire time in milliseconds.
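  // (That is, an absolute unix-time timestamp rather than a relative TTL: with this
  // patch the `now + ttl` arithmetic moves out of String and into its callers, so a
  // hypothetical "SET key val PX 10000" would set args.expire = util::GetTimeStampMS() + 10000.)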
+ uint64_t expire; StringSetType type; bool get; bool keep_ttl; @@ -78,31 +79,31 @@ class String : public Database { explicit String(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {} rocksdb::Status Append(const std::string &user_key, const std::string &value, uint64_t *new_size); rocksdb::Status Get(const std::string &user_key, std::string *value); - rocksdb::Status GetEx(const std::string &user_key, std::string *value, uint64_t ttl, bool persist); + rocksdb::Status GetEx(const std::string &user_key, std::string *value, std::optional<uint64_t> expire); rocksdb::Status GetSet(const std::string &user_key, const std::string &new_value, std::optional<std::string> &old_value); rocksdb::Status GetDel(const std::string &user_key, std::string *value); rocksdb::Status Set(const std::string &user_key, const std::string &value); rocksdb::Status Set(const std::string &user_key, const std::string &value, StringSetArgs args, std::optional<std::string> &ret); - rocksdb::Status SetEX(const std::string &user_key, const std::string &value, uint64_t ttl); - rocksdb::Status SetNX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag); - rocksdb::Status SetXX(const std::string &user_key, const std::string &value, uint64_t ttl, bool *flag); + rocksdb::Status SetEX(const std::string &user_key, const std::string &value, uint64_t expire_ms); + rocksdb::Status SetNX(const std::string &user_key, const std::string &value, uint64_t expire_ms, bool *flag); + rocksdb::Status SetXX(const std::string &user_key, const std::string &value, uint64_t expire_ms, bool *flag); rocksdb::Status SetRange(const std::string &user_key, size_t offset, const std::string &value, uint64_t *new_size); rocksdb::Status IncrBy(const std::string &user_key, int64_t increment, int64_t *new_value); rocksdb::Status IncrByFloat(const std::string &user_key, double increment, double *new_value); std::vector<rocksdb::Status> MGet(const std::vector<Slice> &keys, std::vector<std::string> *values); - rocksdb::Status MSet(const std::vector<StringPair> &pairs, uint64_t ttl = 0, bool lock = true); - rocksdb::Status MSetNX(const std::vector<StringPair> &pairs, uint64_t ttl, bool *flag); + rocksdb::Status MSet(const std::vector<StringPair> &pairs, uint64_t expire_ms, bool lock = true); + rocksdb::Status MSetNX(const std::vector<StringPair> &pairs, uint64_t expire_ms, bool *flag); rocksdb::Status CAS(const std::string &user_key, const std::string &old_value, const std::string &new_value, - uint64_t ttl, int *flag); + uint64_t expire_ms, int *flag); rocksdb::Status CAD(const std::string &user_key, const std::string &value, int *flag); rocksdb::Status LCS(const std::string &user_key1, const std::string &user_key2, StringLCSArgs args, StringLCSResult *rst); private: rocksdb::Status getValue(const std::string &ns_key, std::string *value); - rocksdb::Status getValueAndExpire(const std::string &ns_key, std::string *value, uint64_t *expire); + rocksdb::Status getValueAndExpire(const std::string &ns_key, std::string *value, uint64_t *expire_ms); std::vector<rocksdb::Status> getValues(const std::vector<Slice> &ns_keys, std::vector<std::string> *values); rocksdb::Status getRawValue(const std::string &ns_key, std::string *raw_value); std::vector<rocksdb::Status> getRawValues(const std::vector<Slice> &keys, std::vector<std::string> *raw_values); diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc index 1dd2feb5f51..4d2d92d1604 100644 --- a/src/types/redis_zset.cc +++ b/src/types/redis_zset.cc @@ -32,8 +32,8 @@ namespace redis { -rocksdb::Status ZSet::GetMetadata(const Slice &ns_key, ZSetMetadata *metadata) { - return Database::GetMetadata({kRedisZSet}, ns_key, metadata); +rocksdb::Status
ZSet::GetMetadata(Database::GetOptions get_options, const Slice &ns_key, ZSetMetadata *metadata) { + return Database::GetMetadata(get_options, {kRedisZSet}, ns_key, metadata); } rocksdb::Status ZSet::Add(const Slice &user_key, ZAddFlags flags, MemberScores *mscores, uint64_t *added_cnt) { @@ -43,7 +43,7 @@ rocksdb::Status ZSet::Add(const Slice &user_key, ZAddFlags flags, MemberScores * LockGuard guard(storage_->GetLockManager(), ns_key); ZSetMetadata metadata; - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(Database::GetOptions{}, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) return s; int added = 0; @@ -52,7 +52,7 @@ rocksdb::Status ZSet::Add(const Slice &user_key, ZAddFlags flags, MemberScores * WriteBatchLogData log_data(kRedisZSet); batch->PutLogData(log_data.Encode()); std::unordered_set added_member_keys; - for (auto it = mscores->rbegin(); it != mscores->rend(); it++) { + for (auto it = mscores->rbegin(); it != mscores->rend(); ++it) { if (!added_member_keys.insert(it->member).second) { continue; } @@ -125,7 +125,7 @@ rocksdb::Status ZSet::Card(const Slice &user_key, uint64_t *size) { std::string ns_key = AppendNamespacePrefix(user_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; *size = metadata.size; return rocksdb::Status::OK(); @@ -152,7 +152,7 @@ rocksdb::Status ZSet::Pop(const Slice &user_key, int count, bool min, MemberScor LockGuard guard(storage_->GetLockManager(), ns_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; if (count <= 0) return rocksdb::Status::OK(); if (count > static_cast(metadata.size)) count = static_cast(metadata.size); @@ -216,7 +216,8 @@ rocksdb::Status ZSet::RangeByRank(const Slice &user_key, const RangeRankSpec &sp std::optional lock_guard; if (spec.with_deletion) lock_guard.emplace(storage_->GetLockManager(), ns_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; int start = spec.start; @@ -239,7 +240,6 @@ rocksdb::Status ZSet::RangeByRank(const Slice &user_key, const RangeRankSpec &sp int removed_subkey = 0; rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); rocksdb::Slice upper_bound(next_version_prefix_key); read_options.iterate_upper_bound = &upper_bound; @@ -296,7 +296,7 @@ rocksdb::Status ZSet::RangeByScore(const Slice &user_key, const RangeScoreSpec & std::optional lock_guard; if (spec.with_deletion) lock_guard.emplace(storage_->GetLockManager(), ns_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; // let's get familiar with score first: @@ -419,7 +419,7 @@ rocksdb::Status ZSet::RangeByLex(const Slice &user_key, const RangeLexSpec &spec lock_guard.emplace(storage_->GetLockManager(), ns_key); } ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; std::string start_member = spec.reversed ? spec.max : spec.min; @@ -493,7 +493,7 @@ rocksdb::Status ZSet::RangeByLex(const Slice &user_key, const RangeLexSpec &spec rocksdb::Status ZSet::Score(const Slice &user_key, const Slice &member, double *score) { std::string ns_key = AppendNamespacePrefix(user_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s; rocksdb::ReadOptions read_options; @@ -514,7 +514,7 @@ rocksdb::Status ZSet::Remove(const Slice &user_key, const std::vector &me LockGuard guard(storage_->GetLockManager(), ns_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; auto batch = storage_->GetWriteBatchBase(); @@ -554,11 +554,11 @@ rocksdb::Status ZSet::Rank(const Slice &user_key, const Slice &member, bool reve std::string ns_key = AppendNamespacePrefix(user_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); std::string score_bytes; std::string member_key = InternalKey(ns_key, member, metadata.version, storage_->IsSlotIdEncoded()).Encode(); @@ -831,11 +831,11 @@ rocksdb::Status ZSet::MGet(const Slice &user_key, const std::vector &memb std::string ns_key = AppendNamespacePrefix(user_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s; rocksdb::ReadOptions read_options; - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); std::string score_bytes; for (const auto &member : members) { @@ -856,7 +856,8 @@ rocksdb::Status ZSet::GetAllMemberScores(const Slice &user_key, std::vectorclear(); std::string ns_key = AppendNamespacePrefix(user_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + LatestSnapShot ss(storage_); + rocksdb::Status s = GetMetadata(GetOptions{ss.GetSnapShot()}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; std::string prefix_key = InternalKey(ns_key, "", metadata.version, storage_->IsSlotIdEncoded()).Encode(); @@ -864,7 +865,6 @@ rocksdb::Status ZSet::GetAllMemberScores(const Slice &user_key, std::vectorIsSlotIdEncoded()).Encode(); rocksdb::ReadOptions read_options = storage_->DefaultScanOptions(); - LatestSnapShot ss(storage_); read_options.snapshot = ss.GetSnapShot(); rocksdb::Slice upper_bound(next_version_prefix_key); @@ -896,7 +896,7 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count, std::string ns_key = AppendNamespacePrefix(user_key); ZSetMetadata metadata(false); - rocksdb::Status s = GetMetadata(ns_key, &metadata); + rocksdb::Status s = GetMetadata(GetOptions{}, ns_key, &metadata); if (!s.ok()) return s.IsNotFound() ? rocksdb::Status::OK() : s; if (metadata.size == 0) return rocksdb::Status::OK(); diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h index d806d57e3cf..a5ea754af5c 100644 --- a/src/types/redis_zset.h +++ b/src/types/redis_zset.h @@ -90,7 +90,7 @@ namespace redis { class ZSet : public SubKeyScanner { public: explicit ZSet(engine::Storage *storage, const std::string &ns) - : SubKeyScanner(storage, ns), score_cf_handle_(storage->GetCFHandle("zset_score")) {} + : SubKeyScanner(storage, ns), score_cf_handle_(storage->GetCFHandle(ColumnFamilyID::SecondarySubkey)) {} using Members = std::vector; using MemberScores = std::vector; @@ -119,7 +119,7 @@ class ZSet : public SubKeyScanner { rocksdb::Status Diff(const std::vector &keys, MemberScores *members); rocksdb::Status DiffStore(const Slice &dst, const std::vector &keys, uint64_t *stored_count); rocksdb::Status MGet(const Slice &user_key, const std::vector &members, std::map *scores); - rocksdb::Status GetMetadata(const Slice &ns_key, ZSetMetadata *metadata); + rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &ns_key, ZSetMetadata *metadata); rocksdb::Status Count(const Slice &user_key, const RangeScoreSpec &spec, uint64_t *size); rocksdb::Status RangeByRank(const Slice &user_key, const RangeRankSpec &spec, MemberScores *mscores, diff --git a/tests/cppunit/cluster_test.cc b/tests/cppunit/cluster_test.cc index a19d05e03b1..e70c3136f62 100644 --- a/tests/cppunit/cluster_test.cc +++ b/tests/cppunit/cluster_test.cc @@ -30,8 +30,15 @@ #include "cluster/cluster_defs.h" #include "commands/commander.h" #include "server/server.h" +#include "test_base.h" -TEST(Cluster, CluseterSetNodes) { +class ClusterTest : public TestBase { + protected: + explicit ClusterTest() = default; + ~ClusterTest() override = default; +}; + +TEST_F(ClusterTest, CluseterSetNodes) { Status s; Cluster cluster(nullptr, {"127.0.0.1"}, 3002); @@ -101,13 +108,21 @@ TEST(Cluster, CluseterSetNodes) { ASSERT_TRUE(cluster.GetVersion() == 1); } -TEST(Cluster, CluseterGetNodes) { +TEST_F(ClusterTest, CluseterGetNodes) { const std::string nodes = "07c37dfeb235213a872192d90877d0cd55635b91 127.0.0.1 30004 " "slave e7d1eecce10fd6bb5eb35b9f99a514335d9ba9ca\n" "67ed2db8d677e59ec4a4cefb06858cf2a1a89fa1 127.0.0.1 30002 " "master - 5461-10922"; - Cluster cluster(nullptr, {"127.0.0.1"}, 30002); + auto config = storage_->GetConfig(); + // don't start workers + config->workers = 0; + Server server(storage_.get(), config); + // we don't need the server resource, so just stop it once it's started + server.Stop(); + server.Join(); + + Cluster cluster(&server, {"127.0.0.1"}, 30002); Status s = cluster.SetClusterNodes(nodes, 1, false); ASSERT_TRUE(s.IsOK()); @@ -139,7 +154,7 @@ TEST(Cluster, 
CluseterGetNodes) { } } -TEST(Cluster, CluseterGetSlotInfo) { +TEST_F(ClusterTest, CluseterGetSlotInfo) { const std::string nodes = "07c37dfeb235213a872192d90877d0cd55635b91 127.0.0.1 30004 " "slave 67ed2db8d677e59ec4a4cefb06858cf2a1a89fa1\n" @@ -161,7 +176,7 @@ TEST(Cluster, CluseterGetSlotInfo) { ASSERT_TRUE(info.nodes[1].id == "07c37dfeb235213a872192d90877d0cd55635b91"); } -TEST(Cluster, TestDumpAndLoadClusterNodesInfo) { +TEST_F(ClusterTest, TestDumpAndLoadClusterNodesInfo) { int64_t version = 2; const std::string nodes = "07c37dfeb235213a872192d90877d0cd55635b91 127.0.0.1 30004 " @@ -200,7 +215,7 @@ TEST(Cluster, TestDumpAndLoadClusterNodesInfo) { unlink(nodes_filename.c_str()); } -TEST(Cluster, ClusterParseSlotRanges) { +TEST_F(ClusterTest, ClusterParseSlotRanges) { Status s; Cluster cluster(nullptr, {"127.0.0.1"}, 3002); const std::string node_id = "67ed2db8d677e59ec4a4cefb06858cf2a1a89fa1"; @@ -325,3 +340,49 @@ TEST(Cluster, ClusterParseSlotRanges) { slots.clear(); } } + +TEST_F(ClusterTest, GetReplicas) { + auto config = storage_->GetConfig(); + // don't start workers + config->workers = 0; + Server server(storage_.get(), config); + // we don't need the server resource, so just stop it once it's started + server.Stop(); + server.Join(); + + const std::string nodes = + "7dbee3d628f04cc5d763b36e92b10533e627a1d0 127.0.0.1 6480 slave 159dde1194ebf5bfc5a293dff839c3d1476f2a49\n" + "159dde1194ebf5bfc5a293dff839c3d1476f2a49 127.0.0.1 6479 master - 8192-16383\n" + "bb2e5b3c5282086df51eff6b3e35519aede96fa6 127.0.0.1 6379 master - 0-8191"; + + Cluster cluster(&server, {"127.0.0.1"}, 6379); + Status s = cluster.SetClusterNodes(nodes, 2, false); + ASSERT_TRUE(s.IsOK()); + + auto with_replica = cluster.GetReplicas("159dde1194ebf5bfc5a293dff839c3d1476f2a49"); + ASSERT_TRUE(s.IsOK()); + + std::vector replicas = util::Split(with_replica.GetValue(), "\n"); + for (const auto &replica : replicas) { + std::vector replica_fields = util::Split(replica, " "); + + ASSERT_TRUE(replica_fields.size() == 8); + ASSERT_TRUE(replica_fields[0] == "7dbee3d628f04cc5d763b36e92b10533e627a1d0"); + ASSERT_TRUE(replica_fields[1] == "127.0.0.1:6480@16480"); + ASSERT_TRUE(replica_fields[2] == "slave"); + ASSERT_TRUE(replica_fields[3] == "159dde1194ebf5bfc5a293dff839c3d1476f2a49"); + ASSERT_TRUE(replica_fields[7] == "connected"); + } + + auto without_replica = cluster.GetReplicas("bb2e5b3c5282086df51eff6b3e35519aede96fa6"); + ASSERT_TRUE(without_replica.IsOK()); + ASSERT_EQ(without_replica.GetValue(), ""); + + auto replica_node = cluster.GetReplicas("7dbee3d628f04cc5d763b36e92b10533e627a1d0"); + ASSERT_FALSE(replica_node.IsOK()); + ASSERT_EQ(replica_node.Msg(), "The node isn't a master"); + + auto unknown_node = cluster.GetReplicas("1234567890"); + ASSERT_FALSE(unknown_node.IsOK()); + ASSERT_EQ(unknown_node.Msg(), "Invalid cluster node id"); +} diff --git a/tests/cppunit/compact_test.cc b/tests/cppunit/compact_test.cc index c56163047cc..fad4cb13894 100644 --- a/tests/cppunit/compact_test.cc +++ b/tests/cppunit/compact_test.cc @@ -60,17 +60,17 @@ TEST(Compact, Filter) { read_options.snapshot = db->GetSnapshot(); read_options.fill_cache = false; - auto new_iterator = [db, read_options, &storage](const std::string& name) { - return std::unique_ptr(db->NewIterator(read_options, storage->GetCFHandle(name))); + auto new_iterator = [db, read_options, &storage](ColumnFamilyID column_family_id) { + return std::unique_ptr(db->NewIterator(read_options, storage->GetCFHandle(column_family_id))); }; - auto iter = 
new_iterator("metadata"); + auto iter = new_iterator(ColumnFamilyID::Metadata); for (iter->SeekToFirst(); iter->Valid(); iter->Next()) { auto [user_ns, user_key] = ExtractNamespaceKey(iter->key(), storage->IsSlotIdEncoded()); EXPECT_EQ(user_key.ToString(), live_hash_key); } - iter = new_iterator("subkey"); + iter = new_iterator(ColumnFamilyID::PrimarySubkey); for (iter->SeekToFirst(); iter->Valid(); iter->Next()) { InternalKey ikey(iter->key(), storage->IsSlotIdEncoded()); EXPECT_EQ(ikey.GetKey().ToString(), live_hash_key); @@ -85,17 +85,17 @@ TEST(Compact, Filter) { // Same as the above compact, need to compact twice here status = storage->Compact(nullptr, nullptr, nullptr); - assert(status.ok()); + EXPECT_TRUE(status.ok()); status = storage->Compact(nullptr, nullptr, nullptr); - assert(status.ok()); + EXPECT_TRUE(status.ok()); - iter = new_iterator("default"); + iter = new_iterator(ColumnFamilyID::PrimarySubkey); for (iter->SeekToFirst(); iter->Valid(); iter->Next()) { InternalKey ikey(iter->key(), storage->IsSlotIdEncoded()); EXPECT_EQ(ikey.GetKey().ToString(), live_hash_key); } - iter = new_iterator("zset_score"); + iter = new_iterator(ColumnFamilyID::SecondarySubkey); for (iter->SeekToFirst(); iter->Valid(); iter->Next()) { EXPECT_TRUE(false); // never reach here } @@ -107,7 +107,7 @@ TEST(Compact, Filter) { int retry = 2; while (retry-- > 0) { status = storage->Compact(nullptr, nullptr, nullptr); - assert(status.ok()); + ASSERT_TRUE(status.ok()); std::vector fieldvalues; auto get_res = hash->GetAll(mk_with_ttl, &fieldvalues); auto s_expire = hash->Expire(mk_with_ttl, 1); // expired immediately.. diff --git a/tests/cppunit/config_test.cc b/tests/cppunit/config_test.cc index 49e7623bc1a..f9252127d22 100644 --- a/tests/cppunit/config_test.cc +++ b/tests/cppunit/config_test.cc @@ -46,6 +46,7 @@ TEST(Config, GetAndSet) { {"masterauth", "mytest_masterauth"}, {"compact-cron", "1 2 3 4 5"}, {"bgsave-cron", "5 4 3 2 1"}, + {"dbsize-scan-cron", "1 2 3 2 1"}, {"max-io-mb", "5000"}, {"max-db-size", "6000"}, {"max-replication-mb", "7000"}, @@ -126,6 +127,7 @@ TEST(Config, GetAndSet) { {"rocksdb.subkey_block_cache_size", "100"}, {"rocksdb.row_cache_size", "100"}, {"rocksdb.rate_limiter_auto_tuned", "yes"}, + {"rocksdb.compression_level", "32767"}, }; for (const auto &iter : immutable_cases) { s = config.Set(nullptr, iter.first, iter.second); @@ -174,6 +176,8 @@ TEST(Config, Rewrite) { redis::CommandTable::Reset(); Config config; ASSERT_TRUE(config.Load(CLIOptions(path)).IsOK()); + ASSERT_EQ(config.dir + "/backup", config.backup_dir); + ASSERT_EQ(config.dir + "/kvrocks.pid", config.pidfile); ASSERT_TRUE(config.Rewrite({}).IsOK()); // Need to re-populate the command table since it has renamed by the previous redis::CommandTable::Reset(); diff --git a/tests/cppunit/cron_test.cc b/tests/cppunit/cron_test.cc index 9322050cec2..bccb5ee24e1 100644 --- a/tests/cppunit/cron_test.cc +++ b/tests/cppunit/cron_test.cc @@ -24,20 +24,51 @@ #include -class CronTest : public testing::Test { +// At minute 10 +class CronTestMin : public testing::Test { protected: - explicit CronTest() { + explicit CronTestMin() { + cron_ = std::make_unique(); + std::vector schedule{"10", "*", "*", "*", "*"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestMin() override = default; + + std::unique_ptr cron_; +}; + +TEST_F(CronTestMin, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_min = 10; + now->tm_hour = 3; + 
ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_min = 15; + now->tm_hour = 4; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} + +TEST_F(CronTestMin, ToString) { + std::string got = cron_->ToString(); + ASSERT_EQ("10 * * * *", got); +} + +// At every minute past hour 3 +class CronTestHour : public testing::Test { + protected: + explicit CronTestHour() { cron_ = std::make_unique<Cron>(); std::vector<std::string> schedule{"*", "3", "*", "*", "*"}; auto s = cron_->SetScheduleTime(schedule); EXPECT_TRUE(s.IsOK()); } - ~CronTest() override = default; + ~CronTestHour() override = default; std::unique_ptr<Cron> cron_; }; -TEST_F(CronTest, IsTimeMatch) { +TEST_F(CronTestHour, IsTimeMatch) { std::time_t t = std::time(nullptr); std::tm *now = std::localtime(&t); now->tm_hour = 3; @@ -46,7 +77,338 @@ TEST_F(CronTest, IsTimeMatch) { ASSERT_FALSE(cron_->IsTimeMatch(now)); } -TEST_F(CronTest, ToString) { +TEST_F(CronTestHour, ToString) { std::string got = cron_->ToString(); ASSERT_EQ("* 3 * * *", got); } + +// At 03:00 on day-of-month 5 +class CronTestMonthDay : public testing::Test { + protected: + explicit CronTestMonthDay() { + cron_ = std::make_unique<Cron>(); + std::vector<std::string> schedule{"0", "3", "5", "*", "*"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestMonthDay() override = default; + + std::unique_ptr<Cron> cron_; +}; + +TEST_F(CronTestMonthDay, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_min = 0; + now->tm_hour = 3; + now->tm_mday = 5; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 0; + now->tm_hour = 3; + now->tm_hour = 6; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} + +TEST_F(CronTestMonthDay, ToString) { + std::string got = cron_->ToString(); + ASSERT_EQ("0 3 5 * *", got); +} + +// At 03:00 on day-of-month 5 in September +class CronTestMonth : public testing::Test { + protected: + explicit CronTestMonth() { + cron_ = std::make_unique<Cron>(); + std::vector<std::string> schedule{"0", "3", "5", "9", "*"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestMonth() override = default; + + std::unique_ptr<Cron> cron_; +}; + +TEST_F(CronTestMonth, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_min = 0; + now->tm_hour = 3; + now->tm_mday = 5; + now->tm_mon = 8; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_min = 0; + now->tm_hour = 3; + now->tm_mday = 5; + now->tm_mon = 5; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} + +TEST_F(CronTestMonth, ToString) { + std::string got = cron_->ToString(); + ASSERT_EQ("0 3 5 9 *", got); +} + +// At 03:00 on Sunday in September +class CronTestWeekDay : public testing::Test { + protected: + explicit CronTestWeekDay() { + cron_ = std::make_unique<Cron>(); + std::vector<std::string> schedule{"0", "3", "*", "9", "0"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestWeekDay() override = default; + + std::unique_ptr<Cron> cron_; +}; + +TEST_F(CronTestWeekDay, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_min = 0; + now->tm_hour = 3; + now->tm_mon = 8; + now->tm_wday = 0; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_min = 0; + now->tm_hour = 3; + now->tm_mon = 8; + now->tm_wday = 0; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} + +TEST_F(CronTestWeekDay, ToString) { + std::string got = cron_->ToString(); + ASSERT_EQ("0 3 * 9 0", got); +} + +// At every 4th minute +class CronTestMinInterval : public testing::Test { + protected: + explicit CronTestMinInterval() { + cron_ =
std::make_unique<Cron>(); + std::vector<std::string> schedule{"*/4", "*", "*", "*", "*"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestMinInterval() override = default; + + std::unique_ptr<Cron> cron_; +}; + +TEST_F(CronTestMinInterval, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_hour = 0; + now->tm_min = 0; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_min = 4; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_min = 8; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_min = 12; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_min = 3; + ASSERT_FALSE(cron_->IsTimeMatch(now)); + now->tm_min = 99; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} + +TEST_F(CronTestMinInterval, ToString) { + std::string got = cron_->ToString(); + ASSERT_EQ("*/4 * * * *", got); +} + +// At minute 0 past every 4th hour +class CronTestHourInterval : public testing::Test { + protected: + explicit CronTestHourInterval() { + cron_ = std::make_unique<Cron>(); + std::vector<std::string> schedule{"0", "*/4", "*", "*", "*"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestHourInterval() override = default; + + std::unique_ptr<Cron> cron_; +}; + +TEST_F(CronTestHourInterval, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_hour = 0; + now->tm_min = 0; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 4; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 8; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 12; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 3; + ASSERT_FALSE(cron_->IsTimeMatch(now)); + now->tm_hour = 55; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} + +TEST_F(CronTestHourInterval, ToString) { + std::string got = cron_->ToString(); + ASSERT_EQ("0 */4 * * *", got); +} + +// At minute 0 on every 4th day-of-month +// https://crontab.guru/#0_0_*/4_*_* (click on next) +class CronTestMonthDayInterval : public testing::Test { + protected: + explicit CronTestMonthDayInterval() { + cron_ = std::make_unique<Cron>(); + std::vector<std::string> schedule{"0", "*", "*/4", "*", "*"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestMonthDayInterval() override = default; + + std::unique_ptr<Cron> cron_; +}; + +TEST_F(CronTestMonthDayInterval, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_min = 0; + now->tm_hour = 3; + now->tm_mday = 17; + now->tm_mon = 6; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 5; + now->tm_mday = 21; + now->tm_mon = 6; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 6; + now->tm_mday = 25; + now->tm_mon = 6; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 1; + now->tm_mday = 2; + now->tm_mon = 7; + ASSERT_FALSE(cron_->IsTimeMatch(now)); + now->tm_hour = 1; + now->tm_mday = 99; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} + +TEST_F(CronTestMonthDayInterval, ToString) { + std::string got = cron_->ToString(); + ASSERT_EQ("0 * */4 * *", got); +} + +// At minute 0 in every 4th month +class CronTestMonthInterval : public testing::Test { + protected: + explicit CronTestMonthInterval() { + cron_ = std::make_unique<Cron>(); + std::vector<std::string> schedule{"0", "*", "*", "*/4", "*"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestMonthInterval() override = default; + + std::unique_ptr<Cron> cron_; +}; + +TEST_F(CronTestMonthInterval, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_hour = 0;
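  // (Reminder for the month assertions below: std::tm::tm_mon is zero-based, 0 = January,
  // so tm_mon = 4 is May and tm_mon = 8 is September -- a property of the C time API,
  // not something introduced by this patch.)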
+ now->tm_min = 0; + now->tm_mon = 4; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 5; + now->tm_mon = 8; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 1; + now->tm_mon = 3; + ASSERT_FALSE(cron_->IsTimeMatch(now)); + now->tm_hour = 1; + now->tm_mon = 99; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} + +TEST_F(CronTestMonthInterval, ToString) { + std::string got = cron_->ToString(); + ASSERT_EQ("0 * * */4 *", got); +} + +// At minute 0 on every 4th day-of-week +class CronTestWeekDayInterval : public testing::Test { + protected: + explicit CronTestWeekDayInterval() { + cron_ = std::make_unique<Cron>(); + std::vector<std::string> schedule{"0", "*", "*", "*", "*/4"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestWeekDayInterval() override = default; + + std::unique_ptr<Cron> cron_; +}; + +TEST_F(CronTestWeekDayInterval, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_hour = 0; + now->tm_min = 0; + now->tm_hour = 3; + now->tm_wday = 4; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 5; + now->tm_wday = 3; + ASSERT_FALSE(cron_->IsTimeMatch(now)); + now->tm_hour = 1; + now->tm_wday = 99; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} + +TEST_F(CronTestWeekDayInterval, ToString) { + std::string got = cron_->ToString(); + ASSERT_EQ("0 * * * */4", got); +} + +class CronTestNumberAndRange : public testing::Test { + protected: + explicit CronTestNumberAndRange() { + cron_ = std::make_unique<Cron>(); + std::vector<std::string> schedule{"*", "1,3,6-10,20", "*", "*", "*"}; + auto s = cron_->SetScheduleTime(schedule); + EXPECT_TRUE(s.IsOK()); + } + ~CronTestNumberAndRange() override = default; + + std::unique_ptr<Cron> cron_; +}; + +TEST_F(CronTestNumberAndRange, IsTimeMatch) { + std::time_t t = std::time(nullptr); + std::tm *now = std::localtime(&t); + now->tm_hour = 1; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 3; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 6; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 8; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 10; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 20; + ASSERT_TRUE(cron_->IsTimeMatch(now)); + now->tm_hour = 0; + ASSERT_FALSE(cron_->IsTimeMatch(now)); + now->tm_hour = 2; + ASSERT_FALSE(cron_->IsTimeMatch(now)); + now->tm_hour = 5; + ASSERT_FALSE(cron_->IsTimeMatch(now)); + now->tm_hour = 14; + ASSERT_FALSE(cron_->IsTimeMatch(now)); + now->tm_hour = 22; + ASSERT_FALSE(cron_->IsTimeMatch(now)); +} diff --git a/tests/cppunit/indexer_test.cc b/tests/cppunit/indexer_test.cc index f30b45cad25..c3a7769e5bf 100644 --- a/tests/cppunit/indexer_test.cc +++ b/tests/cppunit/indexer_test.cc @@ -25,32 +25,42 @@ #include +#include "search/index_info.h" #include "search/search_encoding.h" #include "storage/redis_metadata.h" #include "types/redis_hash.h" +static auto T(const std::string& v) { return kqir::MakeValue(util::Split(v, ",")); } + struct IndexerTest : TestBase { redis::GlobalIndexer indexer; + kqir::IndexMap map; std::string ns = "index_test"; IndexerTest() : indexer(storage_.get()) { - SearchMetadata hash_field_meta(false); - hash_field_meta.on_data_type = SearchOnDataType::HASH; + redis::IndexMetadata hash_field_meta; + hash_field_meta.on_data_type = redis::IndexOnDataType::HASH; + + auto hash_info = std::make_unique("hashtest", hash_field_meta, ns); + hash_info->Add(kqir::FieldInfo("x", std::make_unique())); + hash_info->Add(kqir::FieldInfo("y", std::make_unique())); + hash_info->prefixes.prefixes.emplace_back("idxtesthash"); -
std::map> hash_fields; - hash_fields.emplace("x", std::make_unique()); - hash_fields.emplace("y", std::make_unique()); + map.emplace("hashtest", std::move(hash_info)); - redis::IndexUpdater hash_updater{"hashtest", hash_field_meta, {"idxtesthash"}, std::move(hash_fields), &indexer}; + redis::IndexUpdater hash_updater{map.at("hashtest").get()}; - SearchMetadata json_field_meta(false); - json_field_meta.on_data_type = SearchOnDataType::JSON; + redis::IndexMetadata json_field_meta; + json_field_meta.on_data_type = redis::IndexOnDataType::JSON; - std::map> json_fields; - json_fields.emplace("$.x", std::make_unique()); - json_fields.emplace("$.y", std::make_unique()); + auto json_info = std::make_unique("jsontest", json_field_meta, ns); + json_info->Add(kqir::FieldInfo("$.x", std::make_unique())); + json_info->Add(kqir::FieldInfo("$.y", std::make_unique())); + json_info->prefixes.prefixes.emplace_back("idxtestjson"); - redis::IndexUpdater json_updater{"jsontest", json_field_meta, {"idxtestjson"}, std::move(json_fields), &indexer}; + map.emplace("jsontest", std::move(json_info)); + + redis::IndexUpdater json_updater{map.at("jsontest").get()}; indexer.Add(std::move(hash_updater)); indexer.Add(std::move(json_updater)); @@ -59,7 +69,7 @@ struct IndexerTest : TestBase { TEST_F(IndexerTest, HashTag) { redis::Hash db(storage_.get(), ns); - auto cfhandler = storage_->GetCFHandle("search"); + auto cfhandler = storage_->GetCFHandle(ColumnFamilyID::Search); { auto s = indexer.Record("no_exist", ns); @@ -71,39 +81,33 @@ TEST_F(IndexerTest, HashTag) { { auto s = indexer.Record(key1, ns); - ASSERT_TRUE(s); - ASSERT_EQ(s->first->name, idxname); - ASSERT_TRUE(s->second.empty()); + ASSERT_EQ(s.Msg(), Status::ok_msg); + ASSERT_EQ(s->updater.info->name, idxname); + ASSERT_TRUE(s->fields.empty()); uint64_t cnt = 0; db.Set(key1, "x", "food,kitChen,Beauty", &cnt); ASSERT_EQ(cnt, 1); - auto s2 = indexer.Update(*s, key1, ns); + auto s2 = indexer.Update(*s); ASSERT_TRUE(s2); - auto subkey = redis::ConstructTagFieldSubkey("x", "food", key1); - auto nskey = ComposeNamespaceKey(ns, idxname, false); - auto key = InternalKey(nskey, subkey, 0, false); + auto key = redis::SearchKey(ns, idxname, "x").ConstructTagFieldData("food", key1); std::string val; - auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("x", "kitchen", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "x").ConstructTagFieldData("kitchen", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("x", "beauty", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "x").ConstructTagFieldData("beauty", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); } @@ -111,62 +115,52 @@ TEST_F(IndexerTest, HashTag) { { auto s = indexer.Record(key1, ns); ASSERT_TRUE(s); - ASSERT_EQ(s->first->name, idxname); - 
ASSERT_EQ(s->second.size(), 1); - ASSERT_EQ(s->second["x"], "food,kitChen,Beauty"); + ASSERT_EQ(s->updater.info->name, idxname); + ASSERT_EQ(s->fields.size(), 1); + ASSERT_EQ(s->fields["x"], T("food,kitChen,Beauty")); uint64_t cnt = 0; auto s_set = db.Set(key1, "x", "Clothing,FOOD,sport", &cnt); ASSERT_EQ(cnt, 0); ASSERT_TRUE(s_set.ok()); - auto s2 = indexer.Update(*s, key1, ns); + auto s2 = indexer.Update(*s); ASSERT_TRUE(s2); - auto subkey = redis::ConstructTagFieldSubkey("x", "food", key1); - auto nskey = ComposeNamespaceKey(ns, idxname, false); - auto key = InternalKey(nskey, subkey, 0, false); + auto key = redis::SearchKey(ns, idxname, "x").ConstructTagFieldData("food", key1); std::string val; - auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("x", "clothing", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "x").ConstructTagFieldData("clothing", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("x", "sport", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "x").ConstructTagFieldData("sport", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("x", "kitchen", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "x").ConstructTagFieldData("kitchen", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.IsNotFound()); - subkey = redis::ConstructTagFieldSubkey("x", "beauty", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "x").ConstructTagFieldData("beauty", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.IsNotFound()); } } TEST_F(IndexerTest, JsonTag) { redis::Json db(storage_.get(), ns); - auto cfhandler = storage_->GetCFHandle("search"); + auto cfhandler = storage_->GetCFHandle(ColumnFamilyID::Search); { auto s = indexer.Record("no_exist", ns); @@ -179,37 +173,31 @@ TEST_F(IndexerTest, JsonTag) { { auto s = indexer.Record(key1, ns); ASSERT_TRUE(s); - ASSERT_EQ(s->first->name, idxname); - ASSERT_TRUE(s->second.empty()); + ASSERT_EQ(s->updater.info->name, idxname); + ASSERT_TRUE(s->fields.empty()); auto s_set = db.Set(key1, "$", R"({"x": "food,kitChen,Beauty"})"); ASSERT_TRUE(s_set.ok()); - auto s2 = indexer.Update(*s, key1, ns); + auto s2 = indexer.Update(*s); ASSERT_TRUE(s2); - auto subkey = redis::ConstructTagFieldSubkey("$.x", "food", key1); - auto nskey = ComposeNamespaceKey(ns, idxname, false); - auto key = InternalKey(nskey, 
subkey, 0, false); + auto key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("food", key1); std::string val; - auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("$.x", "kitchen", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("kitchen", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("$.x", "beauty", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("beauty", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); } @@ -217,53 +205,78 @@ TEST_F(IndexerTest, JsonTag) { { auto s = indexer.Record(key1, ns); ASSERT_TRUE(s); - ASSERT_EQ(s->first->name, idxname); - ASSERT_EQ(s->second.size(), 1); - ASSERT_EQ(s->second["$.x"], "food,kitChen,Beauty"); + ASSERT_EQ(s->updater.info->name, idxname); + ASSERT_EQ(s->fields.size(), 1); + ASSERT_EQ(s->fields["$.x"], T("food,kitChen,Beauty")); auto s_set = db.Set(key1, "$.x", "\"Clothing,FOOD,sport\""); ASSERT_TRUE(s_set.ok()); - auto s2 = indexer.Update(*s, key1, ns); + auto s2 = indexer.Update(*s); ASSERT_TRUE(s2); - auto subkey = redis::ConstructTagFieldSubkey("$.x", "food", key1); - auto nskey = ComposeNamespaceKey(ns, idxname, false); - auto key = InternalKey(nskey, subkey, 0, false); + auto key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("food", key1); std::string val; - auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("$.x", "clothing", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("clothing", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("$.x", "sport", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("sport", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.ok()); ASSERT_EQ(val, ""); - subkey = redis::ConstructTagFieldSubkey("$.x", "kitchen", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("kitchen", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, 
key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.IsNotFound()); - subkey = redis::ConstructTagFieldSubkey("$.x", "beauty", key1); - nskey = ComposeNamespaceKey(ns, idxname, false); - key = InternalKey(nskey, subkey, 0, false); + key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("beauty", key1); - s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key.Encode(), &val); + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); ASSERT_TRUE(s3.IsNotFound()); } } + +TEST_F(IndexerTest, JsonTagBuildIndex) { + redis::Json db(storage_.get(), ns); + auto cfhandler = storage_->GetCFHandle(ColumnFamilyID::Search); + + auto key1 = "idxtestjson:k2"; + auto idxname = "jsontest"; + + { + auto s_set = db.Set(key1, "$", R"({"x": "food,kitChen,Beauty"})"); + ASSERT_TRUE(s_set.ok()); + + auto s2 = indexer.updater_list[1].Build(); + ASSERT_EQ(s2.Msg(), Status::ok_msg); + + auto key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("food", key1); + + std::string val; + auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); + ASSERT_TRUE(s3.ok()); + ASSERT_EQ(val, ""); + + key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("kitchen", key1); + + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); + ASSERT_TRUE(s3.ok()); + ASSERT_EQ(val, ""); + + key = redis::SearchKey(ns, idxname, "$.x").ConstructTagFieldData("beauty", key1); + + s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler, key, &val); + ASSERT_TRUE(s3.ok()); + ASSERT_EQ(val, ""); + } +} diff --git a/tests/cppunit/interval_test.cc b/tests/cppunit/interval_test.cc new file mode 100644 index 00000000000..bffd5d630c2 --- /dev/null +++ b/tests/cppunit/interval_test.cc @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include "search/interval.h" + +#include + +#include "search/ir.h" + +using namespace kqir; + +TEST(IntervalSet, Simple) { + ASSERT_TRUE(IntervalSet().IsEmpty()); + ASSERT_TRUE(!IntervalSet().IsFull()); + ASSERT_TRUE(IntervalSet(IntervalSet::full).IsFull()); + ASSERT_TRUE(!IntervalSet(IntervalSet::full).IsEmpty()); + ASSERT_TRUE((~IntervalSet()).IsFull()); + ASSERT_TRUE((~IntervalSet(IntervalSet::full)).IsEmpty()); + + ASSERT_EQ(IntervalSet(Interval(1, 2)) | IntervalSet(Interval(2, 4)), IntervalSet(Interval(1, 4))); + ASSERT_EQ((IntervalSet(Interval(1, 2)) | IntervalSet(Interval(2, 4))).intervals, (IntervalSet::DataType{{1, 4}})); + ASSERT_EQ((IntervalSet(Interval(1, 2)) | IntervalSet(Interval(3, 4))).intervals, + (IntervalSet::DataType{{1, 2}, {3, 4}})); + ASSERT_EQ((IntervalSet(Interval(1, 4)) | IntervalSet(Interval(2, 3))).intervals, (IntervalSet::DataType{{1, 4}})); + ASSERT_EQ((IntervalSet(Interval(2, 3)) | IntervalSet(Interval(1, 4))).intervals, (IntervalSet::DataType{{1, 4}})); + ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) | IntervalSet(NumericCompareExpr::LT, 4)).intervals, + (IntervalSet::DataType{{IntervalSet::minf, IntervalSet::inf}})); + ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) | IntervalSet(NumericCompareExpr::NE, 4)).intervals, + (IntervalSet::DataType{{IntervalSet::minf, IntervalSet::inf}})); + ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 4) | IntervalSet(NumericCompareExpr::LT, 1)).intervals, + (IntervalSet::DataType{{IntervalSet::minf, 1}, {4, IntervalSet::inf}})); + ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 4) | IntervalSet(NumericCompareExpr::NE, 1)).intervals, + (IntervalSet::DataType{{IntervalSet::minf, 1}, {IntervalSet::NextNum(1), IntervalSet::inf}})); + + ASSERT_TRUE((IntervalSet(Interval(1, 2)) & IntervalSet(Interval(3, 4))).IsEmpty()); + ASSERT_EQ((IntervalSet(Interval(1, 2)) & IntervalSet(Interval(2, 4))).intervals, (IntervalSet::DataType{{2, 2}})); + ASSERT_EQ((IntervalSet(Interval(1, 3)) & IntervalSet(Interval(2, 4))).intervals, (IntervalSet::DataType{{2, 3}})); + ASSERT_EQ((IntervalSet(Interval(3, 8)) & (IntervalSet(Interval(1, 4)) | IntervalSet(Interval(5, 7)))).intervals, + (IntervalSet::DataType{{3, 4}, {5, 7}})); + ASSERT_EQ((IntervalSet(Interval(3, 8)) & (IntervalSet(Interval(1, 4)) | IntervalSet(Interval(9, 11)))).intervals, + (IntervalSet::DataType{{3, 4}})); + ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) & IntervalSet(NumericCompareExpr::LT, 4)).intervals, + (IntervalSet::DataType{{1, 4}})); + ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) & IntervalSet(NumericCompareExpr::NE, 4)).intervals, + (IntervalSet::DataType{{1, 4}, {IntervalSet::NextNum(4), IntervalSet::inf}})); + + ASSERT_EQ(IntervalSet(IntervalSet::full) & IntervalSet(IntervalSet::full), IntervalSet(IntervalSet::full)); + ASSERT_EQ(IntervalSet(IntervalSet::full) | IntervalSet(IntervalSet::full), IntervalSet(IntervalSet::full)); + + ASSERT_EQ((IntervalSet({1, 5}) | IntervalSet({7, 10})) & IntervalSet({2, 8}), + IntervalSet({2, 5}) | IntervalSet({7, 8})); + ASSERT_EQ(~IntervalSet({2, 8}), IntervalSet({IntervalSet::minf, 2}) | IntervalSet({8, IntervalSet::inf})); + + for (auto i = 0; i < 2000; ++i) { + auto gen = [] { return static_cast(std::rand()) / 100; }; + auto geni = [&gen] { + auto r = std::rand() % 50; + if (r == 0) { + return IntervalSet(NumericCompareExpr::GET, gen()); + } else if (r == 1) { + return IntervalSet(NumericCompareExpr::LT, gen()); + } else if (r == 2) { + return IntervalSet(NumericCompareExpr::NE, gen()); + } else { + return 
IntervalSet({gen(), gen()}); + } + }; + + auto l = geni(), r = geni(); + for (int j = 0; j < i % 10; ++j) { + l = l | geni(); + } + for (int j = 0; j < i % 7; ++j) { + r = r | geni(); + } + ASSERT_EQ(~l | ~r, ~(l & r)); + ASSERT_EQ(~l & ~r, ~(l | r)); + } +} diff --git a/tests/cppunit/ir_dot_dumper_test.cc b/tests/cppunit/ir_dot_dumper_test.cc new file mode 100644 index 00000000000..d4ae72927ae --- /dev/null +++ b/tests/cppunit/ir_dot_dumper_test.cc @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "search/ir_dot_dumper.h" + +#include +#include + +#include "gtest/gtest.h" +#include "search/ir_plan.h" +#include "search/ir_sema_checker.h" +#include "search/passes/manager.h" +#include "search/search_encoding.h" +#include "search/sql_transformer.h" +#include "storage/redis_metadata.h" + +using namespace kqir; + +static auto Parse(const std::string& in) { return sql::ParseToIR(peg::string_input(in, "test")); } + +TEST(DotDumperTest, Simple) { + auto ir = *Parse("select a from b where c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10"); + + std::stringstream ss; + DotDumper dumper{ss}; + + dumper.Dump(ir.get()); + + std::string dot = ss.str(); + std::smatch matches; + + std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "SearchStmt)")); + auto search_stmt = matches[1].str(); + + std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "OrExpr)")); + auto or_expr = matches[1].str(); + + std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "AndExpr)")); + auto and_expr = matches[1].str(); + + ASSERT_NE(dot.find(fmt::format("{} -> {}", search_stmt, or_expr)), std::string::npos); + ASSERT_NE(dot.find(fmt::format("{} -> {}", or_expr, and_expr)), std::string::npos); +} + +static auto ParseS(SemaChecker& sc, const std::string& in) { + auto ir = *Parse(in); + EXPECT_EQ(sc.Check(ir.get()).Msg(), Status::ok_msg); + return ir; +} + +static IndexMap MakeIndexMap() { + auto f1 = FieldInfo("t1", std::make_unique()); + auto f2 = FieldInfo("t2", std::make_unique()); + f2.metadata->noindex = true; + auto f3 = FieldInfo("n1", std::make_unique()); + auto f4 = FieldInfo("n2", std::make_unique()); + auto f5 = FieldInfo("n3", std::make_unique()); + f5.metadata->noindex = true; + auto ia = std::make_unique("ia", redis::IndexMetadata(), ""); + ia->Add(std::move(f1)); + ia->Add(std::move(f2)); + ia->Add(std::move(f3)); + ia->Add(std::move(f4)); + ia->Add(std::move(f5)); + + IndexMap res; + res.Insert(std::move(ia)); + return res; +} + +TEST(DotDumperTest, Plan) { + auto index_map = MakeIndexMap(); + SemaChecker sc(index_map); + auto plan = PassManager::Execute( + PassManager::Default(), + ParseS( + sc, + "select * from ia where (n1 < 2 or n1 >= 3) and (n1 >= 1 and n1 < 4) and not n3 != 1 and 
t2 hastag \"a\"")); + + std::stringstream ss; + DotDumper dd(ss); + dd.Dump(plan.get()); + + std::string dot = ss.str(); + std::smatch matches; + + std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "Filter)")); + auto filter = matches[1].str(); + + std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "Merge)")); + auto merge = matches[1].str(); + + std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "NumericFieldScan)")); + auto scan = matches[1].str(); + + ASSERT_NE(dot.find(fmt::format("{} -> {}", filter, merge)), std::string::npos); + ASSERT_NE(dot.find(fmt::format("{} -> {}", merge, scan)), std::string::npos); +} diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc new file mode 100644 index 00000000000..81ed49e8b94 --- /dev/null +++ b/tests/cppunit/ir_pass_test.cc @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "search/ir_pass.h" + +#include "fmt/core.h" +#include "gtest/gtest.h" +#include "search/interval.h" +#include "search/ir_sema_checker.h" +#include "search/passes/interval_analysis.h" +#include "search/passes/lower_to_plan.h" +#include "search/passes/manager.h" +#include "search/passes/push_down_not_expr.h" +#include "search/passes/simplify_and_or_expr.h" +#include "search/passes/simplify_boolean.h" +#include "search/sql_transformer.h" + +using namespace kqir; + +static auto Parse(const std::string& in) { return sql::ParseToIR(peg::string_input(in, "test")); } + +TEST(IRPassTest, Simple) { + auto ir = *Parse("select a from b where not c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10"); + + auto original = ir->Dump(); + + Visitor visitor; + auto ir2 = visitor.Transform(std::move(ir)); + ASSERT_EQ(original, ir2->Dump()); +} + +TEST(IRPassTest, SimplifyBoolean) { + SimplifyBoolean sb; + ASSERT_EQ(sb.Transform(*Parse("select a from b where not false"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where not not false"))->Dump(), "select a from b where false"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false"))->Dump(), "select a from b where false"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where false and true"))->Dump(), "select a from b where false"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false and true"))->Dump(), + "select a from b where false"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true and true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and false"))->Dump(), "select a from b where false"); + 
ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true"))->Dump(), "select a from b where x > 1"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true and y < 10"))->Dump(), + "select a from b where (and x > 1, y < 10)"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where not (false and (not true))"))->Dump(), + "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true or true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where false or true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false or true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false or true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where not ((x < 1 or true) and (y > 2 and true))"))->Dump(), + "select a from b where not y > 2"); +} + +TEST(IRPassTest, SimplifyAndOrExpr) { + SimplifyAndOrExpr saoe; + + ASSERT_EQ(Parse("select a from b where true and (false and true)").GetValue()->Dump(), + "select a from b where (and true, (and false, true))"); + ASSERT_EQ(saoe.Transform(*Parse("select a from b where true and (false and true)"))->Dump(), + "select a from b where (and true, false, true)"); + ASSERT_EQ(saoe.Transform(*Parse("select a from b where true or (false or true)"))->Dump(), + "select a from b where (or true, false, true)"); + ASSERT_EQ(saoe.Transform(*Parse("select a from b where true and (false or true)"))->Dump(), + "select a from b where (and true, (or false, true))"); + ASSERT_EQ(saoe.Transform(*Parse("select a from b where true or (false and true)"))->Dump(), + "select a from b where (or true, (and false, true))"); + ASSERT_EQ(saoe.Transform(*Parse("select a from b where x > 1 or (y < 2 or z = 3)"))->Dump(), + "select a from b where (or x > 1, y < 2, z = 3)"); +} + +TEST(IRPassTest, PushDownNotExpr) { + PushDownNotExpr pdne; + + ASSERT_EQ(pdne.Transform(*Parse("select * from a where not a > 1"))->Dump(), "select * from a where a <= 1"); + ASSERT_EQ(pdne.Transform(*Parse("select * from a where not a hastag \"\""))->Dump(), + "select * from a where not a hastag \"\""); + ASSERT_EQ(pdne.Transform(*Parse("select * from a where not not a > 1"))->Dump(), "select * from a where a > 1"); + ASSERT_EQ(pdne.Transform(*Parse("select * from a where not (a > 1 and b <= 3)"))->Dump(), + "select * from a where (or a <= 1, b > 3)"); + ASSERT_EQ(pdne.Transform(*Parse("select * from a where not (a > 1 or b <= 3)"))->Dump(), + "select * from a where (and a <= 1, b > 3)"); + ASSERT_EQ(pdne.Transform(*Parse("select * from a where not (not a > 1 or (b < 3 and c hastag \"\"))"))->Dump(), + "select * from a where (and a > 1, (or b >= 3, not c hastag \"\"))"); +} + +TEST(IRPassTest, Manager) { + auto expr_passes = PassManager::ExprPasses(); + ASSERT_EQ(PassManager::Execute(expr_passes, + *Parse("select * from a where not (x > 1 or (y < 2 or z = 3)) and (true or x = 1)")) + ->Dump(), + "select * from a where (and x <= 1, y >= 2, z != 3)"); +} + +TEST(IRPassTest, LowerToPlan) { + LowerToPlan ltp; + + ASSERT_EQ(ltp.Transform(*Parse("select * from a"))->Dump(), "project *: full-scan a"); + ASSERT_EQ(ltp.Transform(*Parse("select * from a limit 1"))->Dump(), "project *: (limit 0, 1: full-scan a)"); + ASSERT_EQ(ltp.Transform(*Parse("select * from a where 
false"))->Dump(), "project *: noop"); + ASSERT_EQ(ltp.Transform(*Parse("select * from a where false limit 1"))->Dump(), "project *: noop"); + ASSERT_EQ(ltp.Transform(*Parse("select * from a where b > 1"))->Dump(), "project *: (filter b > 1: full-scan a)"); + ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d"))->Dump(), + "project a: (sort d, asc: (filter c = 1: full-scan b))"); + ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 limit 1"))->Dump(), + "project a: (limit 0, 1: (filter c = 1: full-scan b))"); + ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit 1"))->Dump(), + "project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan b)))"); +} + +TEST(IRPassTest, IntervalAnalysis) { + auto ia_passes = PassManager::Create(IntervalAnalysis{true}, SimplifyAndOrExpr{}, SimplifyBoolean{}); + + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a > 1 or a < 3"))->Dump(), + "select * from a where true"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 1 and a > 3"))->Dump(), + "select * from a where false"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a > 3 or a < 1) and a = 2"))->Dump(), + "select * from a where false"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where b = 1 and (a = 1 or a != 1)"))->Dump(), + "select * from a where b = 1"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 or b = 1 or a != 1"))->Dump(), + "select * from a where true"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a < 3 or a > 1) and b >= 1"))->Dump(), + "select * from a where b >= 1"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1 or a != 2"))->Dump(), + "select * from a where true"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 and a = 2"))->Dump(), + "select * from a where false"); + + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 1 and a < 3"))->Dump(), + "select * from a where a < 1"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 1 or a < 3"))->Dump(), + "select * from a where a < 3"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 and a < 3"))->Dump(), + "select * from a where a = 1"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 or a < 3"))->Dump(), + "select * from a where a < 3"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 or a = 3"))->Dump(), + "select * from a where (or a = 1, a = 3)"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1"))->Dump(), + "select * from a where a != 1"); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1 and a != 2"))->Dump(), + "select * from a where (and a != 1, a != 2)"); + ASSERT_EQ( + PassManager::Execute(ia_passes, *Parse("select * from a where a >= 0 and a >= 1 and a < 4 and a != 2"))->Dump(), + fmt::format("select * from a where (or (and a >= 1, a < 2), (and a >= {}, a < 4))", IntervalSet::NextNum(2))); + ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1 and b > 1 and b = 2"))->Dump(), + "select * from a where (and a != 1, b = 2)"); +} + +static IndexMap MakeIndexMap() { + auto f1 = FieldInfo("t1", std::make_unique()); + auto f2 = FieldInfo("t2", std::make_unique()); + f2.metadata->noindex = true; + auto f3 = 
FieldInfo("n1", std::make_unique()); + auto f4 = FieldInfo("n2", std::make_unique()); + auto f5 = FieldInfo("n3", std::make_unique()); + f5.metadata->noindex = true; + auto ia = std::make_unique("ia", redis::IndexMetadata(), ""); + ia->Add(std::move(f1)); + ia->Add(std::move(f2)); + ia->Add(std::move(f3)); + ia->Add(std::move(f4)); + ia->Add(std::move(f5)); + + IndexMap res; + res.Insert(std::move(ia)); + return res; +} + +std::unique_ptr ParseS(SemaChecker& sc, const std::string& in) { + auto res = *Parse(in); + EXPECT_EQ(sc.Check(res.get()).Msg(), Status::ok_msg); + return res; +} + +TEST(IRPassTest, IndexSelection) { + auto index_map = MakeIndexMap(); + auto sc = SemaChecker(index_map); + + auto passes = PassManager::Default(); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia"))->Dump(), "project *: full-scan ia"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by n1"))->Dump(), + "project *: numeric-scan n1, [-inf, inf), asc"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by n1 limit 1"))->Dump(), + "project *: (limit 0, 1: numeric-scan n1, [-inf, inf), asc)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by n3"))->Dump(), + "project *: (sort n3, asc: full-scan ia)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by n3 limit 1"))->Dump(), + "project *: (top-n sort n3, asc, 0, 1: full-scan ia)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n2 = 1 order by n1"))->Dump(), + "project *: (filter n2 = 1: numeric-scan n1, [-inf, inf), asc)"); + + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 = 1"))->Dump(), + fmt::format("project *: numeric-scan n1, [1, {}), asc", IntervalSet::NextNum(1))); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 != 1"))->Dump(), + "project *: (filter n1 != 1: full-scan ia)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1"))->Dump(), + "project *: numeric-scan n1, [1, inf), asc"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 and n1 < 2"))->Dump(), + "project *: numeric-scan n1, [1, 2), asc"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 and n2 >= 2"))->Dump(), + "project *: (filter n2 >= 2: numeric-scan n1, [1, inf), asc)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 and n2 = 2"))->Dump(), + fmt::format("project *: (filter n1 >= 1: numeric-scan n2, [2, {}), asc)", IntervalSet::NextNum(2))); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 and n3 = 2"))->Dump(), + "project *: (filter n3 = 2: numeric-scan n1, [1, inf), asc)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n3 = 1"))->Dump(), + "project *: (filter n3 = 1: full-scan ia)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 = 1 and n3 = 2"))->Dump(), + fmt::format("project *: (filter n3 = 2: numeric-scan n1, [1, {}), asc)", IntervalSet::NextNum(1))); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 = 1 and t1 hastag \"a\""))->Dump(), + fmt::format("project *: (filter t1 hastag \"a\": numeric-scan n1, [1, {}), asc)", IntervalSet::NextNum(1))); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where t1 hastag \"a\""))->Dump(), + "project *: tag-scan t1, a"); + ASSERT_EQ( + 
PassManager::Execute(passes, ParseS(sc, "select * from ia where t1 hastag \"a\" and t2 hastag \"a\""))->Dump(), + "project *: (filter t2 hastag \"a\": tag-scan t1, a)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where t2 hastag \"a\""))->Dump(), + "project *: (filter t2 hastag \"a\": full-scan ia)"); + + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 2 or n1 < 1"))->Dump(), + "project *: (merge numeric-scan n1, [-inf, 1), asc, numeric-scan n1, [2, inf), asc)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 or n2 >= 2"))->Dump(), + "project *: (merge numeric-scan n1, [1, inf), asc, (filter n1 < 1: numeric-scan n2, [2, inf), asc))"); + ASSERT_EQ( + PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 or n2 = 2"))->Dump(), + fmt::format("project *: (merge numeric-scan n1, [1, inf), asc, (filter n1 < 1: numeric-scan n2, [2, {}), asc))", + IntervalSet::NextNum(2))); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 >= 1 or n3 = 2"))->Dump(), + "project *: (filter (or n1 >= 1, n3 = 2): full-scan ia)"); + ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 = 1 or n3 = 2"))->Dump(), + "project *: (filter (or n1 = 1, n3 = 2): full-scan ia)"); + ASSERT_EQ( + PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 = 1 or t1 hastag \"a\""))->Dump(), + fmt::format("project *: (merge tag-scan t1, a, (filter not t1 hastag \"a\": numeric-scan n1, [1, {}), asc))", + IntervalSet::NextNum(1))); + ASSERT_EQ( + PassManager::Execute(passes, ParseS(sc, "select * from ia where t1 hastag \"a\" or t2 hastag \"a\""))->Dump(), + "project *: (filter (or t1 hastag \"a\", t2 hastag \"a\"): full-scan ia)"); + ASSERT_EQ( + PassManager::Execute(passes, ParseS(sc, "select * from ia where t1 hastag \"a\" or t1 hastag \"b\""))->Dump(), + "project *: (merge tag-scan t1, a, (filter not t1 hastag \"a\": tag-scan t1, b))"); + + ASSERT_EQ( + PassManager::Execute( + passes, ParseS(sc, "select * from ia where (n1 < 2 or n1 >= 3) and (n1 >= 1 and n1 < 4) and not n3 != 1")) + ->Dump(), + "project *: (filter n3 = 1: (merge numeric-scan n1, [1, 2), asc, numeric-scan n1, [3, 4), asc))"); +} diff --git a/tests/cppunit/ir_sema_checker_test.cc b/tests/cppunit/ir_sema_checker_test.cc new file mode 100644 index 00000000000..3a15dde725c --- /dev/null +++ b/tests/cppunit/ir_sema_checker_test.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include "search/ir_sema_checker.h" + +#include <memory> +#include <string> + +#include <gmock/gmock.h> + +#include "gtest/gtest.h" +#include "search/search_encoding.h" +#include "search/sql_transformer.h" +#include "storage/redis_metadata.h" + +using namespace kqir; + +static auto Parse(const std::string& in) { return sql::ParseToIR(peg::string_input(in, "test")); } + +static IndexMap MakeIndexMap() { + auto f1 = FieldInfo("f1", std::make_unique<redis::TagFieldMetadata>()); + auto f2 = FieldInfo("f2", std::make_unique<redis::NumericFieldMetadata>()); + auto f3 = FieldInfo("f3", std::make_unique<redis::NumericFieldMetadata>()); + auto ia = std::make_unique<IndexInfo>("ia", redis::IndexMetadata(), ""); + ia->Add(std::move(f1)); + ia->Add(std::move(f2)); + ia->Add(std::move(f3)); + + IndexMap res; + res.Insert(std::move(ia)); + return res; +} + +using testing::MatchesRegex; + +TEST(SemaCheckerTest, Simple) { + auto index_map = MakeIndexMap(); + + { + SemaChecker checker(index_map); + ASSERT_EQ(checker.Check(Parse("select a from b")->get()).Msg(), "index `b` not found"); + ASSERT_EQ(checker.Check(Parse("select a from ia")->get()).Msg(), "field `a` not found in index `ia`"); + ASSERT_EQ(checker.Check(Parse("select f1 from ia")->get()).Msg(), "ok"); + ASSERT_EQ(checker.Check(Parse("select f1 from ia where b = 1")->get()).Msg(), "field `b` not found in index `ia`"); + ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 = 1")->get()).Msg(), "field `f1` is not a numeric field"); + ASSERT_EQ(checker.Check(Parse("select f1 from ia where f2 hastag \"a\"")->get()).Msg(), + "field `f2` is not a tag field"); + ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag \"a\" and f2 = 1")->get()).Msg(), "ok"); + ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag \"\"")->get()).Msg(), + "tag cannot be an empty string"); + ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag \",\"")->get()).Msg(), + "tag cannot contain the separator `,`"); + ASSERT_EQ(checker.Check(Parse("select f1 from ia order by a")->get()).Msg(), "field `a` not found in index `ia`"); + } + + { + SemaChecker checker(index_map); + auto root = *Parse("select f1 from ia where f1 hastag \"a\" and f2 = 1 order by f3"); + + ASSERT_EQ(checker.Check(root.get()).Msg(), "ok"); + } +} diff --git a/tests/cppunit/iterator_test.cc b/tests/cppunit/iterator_test.cc index 3645e59cdac..08705d7719c 100644 --- a/tests/cppunit/iterator_test.cc +++ b/tests/cppunit/iterator_test.cc @@ -390,7 +390,7 @@ TEST_F(WALIteratorTest, BasicString) { auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); redis::String string(storage_.get(), "test_ns0"); string.Set("a", "1"); - string.MSet({{"b", "2"}, {"c", "3"}}); + string.MSet({{"b", "2"}, {"c", "3"}}, 0); ASSERT_TRUE(string.Del("b").ok()); std::vector<std::string> put_keys, delete_keys; @@ -451,10 +451,10 @@ TEST_F(WALIteratorTest, BasicHash) { auto item = iter.Item(); switch (item.type) { case engine::WALItem::Type::kTypePut: { - if (item.column_family_id == kColumnFamilyIDDefault) { + if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::PrimarySubkey)) { InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); put_fields.emplace_back(internal_key.GetSubKey().ToString()); - } else if (item.column_family_id == kColumnFamilyIDMetadata) { + } else if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::Metadata)) { auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); ASSERT_EQ(ns.ToString(), "test_ns1"); put_keys.emplace_back(key.ToString()); @@ -504,10 +504,10 @@ TEST_F(WALIteratorTest, BasicSet) { auto item = iter.Item(); switch (item.type) { case 
engine::WALItem::Type::kTypePut: { - if (item.column_family_id == kColumnFamilyIDDefault) { + if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::PrimarySubkey)) { InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); put_members.emplace_back(internal_key.GetSubKey().ToString()); - } else if (item.column_family_id == kColumnFamilyIDMetadata) { + } else if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::Metadata)) { auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); ASSERT_EQ(ns.ToString(), "test_ns2"); put_keys.emplace_back(key.ToString()); @@ -559,10 +559,10 @@ TEST_F(WALIteratorTest, BasicZSet) { auto item = iter.Item(); switch (item.type) { case engine::WALItem::Type::kTypePut: { - if (item.column_family_id == kColumnFamilyIDDefault) { + if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::PrimarySubkey)) { InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); put_members.emplace_back(internal_key.GetSubKey().ToString()); - } else if (item.column_family_id == kColumnFamilyIDMetadata) { + } else if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::Metadata)) { auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); ASSERT_EQ(ns.ToString(), "test_ns3"); put_keys.emplace_back(key.ToString()); @@ -609,9 +609,9 @@ TEST_F(WALIteratorTest, BasicList) { auto item = iter.Item(); switch (item.type) { case engine::WALItem::Type::kTypePut: { - if (item.column_family_id == kColumnFamilyIDDefault) { + if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::PrimarySubkey)) { put_values.emplace_back(item.value); - } else if (item.column_family_id == kColumnFamilyIDMetadata) { + } else if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::Metadata)) { auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); ASSERT_EQ(ns.ToString(), "test_ns4"); put_keys.emplace_back(key.ToString()); 
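// [Editor's illustration, not part of this patch] The old kColumnFamilyID*
// constants were plain integers, while ColumnFamilyID is an enum class, hence
// the explicit static_cast<uint32_t> when comparing against
// WALItem::column_family_id. If the casts get noisy, a small helper
// (hypothetical, assuming only the names used in these tests) keeps call
// sites terse:
//   constexpr bool IsColumnFamily(uint32_t id, ColumnFamilyID cf) {
//     return id == static_cast<uint32_t>(cf);
//   }
//   // if (IsColumnFamily(item.column_family_id, ColumnFamilyID::Metadata)) ...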
@@ -663,12 +663,12 @@ TEST_F(WALIteratorTest, BasicStream) { auto item = iter.Item(); switch (item.type) { case engine::WALItem::Type::kTypePut: { - if (item.column_family_id == kColumnFamilyIDStream) { + if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::Stream)) { std::vector<std::string> elems; auto s = redis::DecodeRawStreamEntryValue(item.value, &elems); ASSERT_TRUE(s.IsOK() && !elems.empty()); put_values.emplace_back(elems[0]); - } else if (item.column_family_id == kColumnFamilyIDMetadata) { + } else if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::Metadata)) { auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); ASSERT_EQ(ns.ToString(), "test_ns5"); put_keys.emplace_back(key.ToString()); @@ -715,7 +715,7 @@ TEST_F(WALIteratorTest, BasicBitmap) { auto item = iter.Item(); switch (item.type) { case engine::WALItem::Type::kTypePut: { - if (item.column_family_id == kColumnFamilyIDDefault) { + if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::PrimarySubkey)) { put_values.emplace_back(item.value); } break; @@ -754,7 +754,7 @@ TEST_F(WALIteratorTest, BasicJSON) { auto item = iter.Item(); switch (item.type) { case engine::WALItem::Type::kTypePut: { - ASSERT_EQ(item.column_family_id, kColumnFamilyIDMetadata); + ASSERT_EQ(item.column_family_id, static_cast<uint32_t>(ColumnFamilyID::Metadata)); auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); ASSERT_EQ(ns.ToString(), "test_ns7"); put_keys.emplace_back(key.ToString()); break; } case engine::WALItem::Type::kTypeDelete: { - ASSERT_EQ(item.column_family_id, kColumnFamilyIDMetadata); + ASSERT_EQ(item.column_family_id, static_cast<uint32_t>(ColumnFamilyID::Metadata)); auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); ASSERT_EQ(ns.ToString(), "test_ns7"); delete_keys.emplace_back(key.ToString()); @@ -801,7 +801,7 @@ TEST_F(WALIteratorTest, BasicSortedInt) { auto item = iter.Item(); switch (item.type) { case engine::WALItem::Type::kTypePut: { - if (item.column_family_id == kColumnFamilyIDDefault) { + if (item.column_family_id == static_cast<uint32_t>(ColumnFamilyID::PrimarySubkey)) { const InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); auto value = DecodeFixed64(internal_key.GetSubKey().data()); put_values.emplace_back(value); diff --git a/tests/cppunit/metadata_test.cc b/tests/cppunit/metadata_test.cc index 5e3e60ee9b3..bd7282052e4 100644 --- a/tests/cppunit/metadata_test.cc +++ b/tests/cppunit/metadata_test.cc @@ -92,7 +92,7 @@ TEST_F(RedisTypeTest, GetMetadata) { EXPECT_TRUE(s.ok() && fvs.size() == ret); HashMetadata metadata; std::string ns_key = redis_->AppendNamespacePrefix(key_); - s = redis_->GetMetadata({kRedisHash}, ns_key, &metadata); + s = redis_->GetMetadata(redis::Database::GetOptions{}, {kRedisHash}, ns_key, &metadata); EXPECT_EQ(fvs.size(), metadata.size); s = redis_->Del(key_); EXPECT_TRUE(s.ok()); diff --git a/tests/cppunit/plan_executor_test.cc b/tests/cppunit/plan_executor_test.cc new file mode 100644 index 00000000000..00e0c162021 --- /dev/null +++ b/tests/cppunit/plan_executor_test.cc @@ -0,0 +1,419 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ +#include "search/plan_executor.h" + +#include <gtest/gtest.h> + +#include <memory> + +#include "config/config.h" +#include "search/executors/mock_executor.h" +#include "search/indexer.h" +#include "search/interval.h" +#include "search/ir.h" +#include "search/ir_plan.h" +#include "search/value.h" +#include "string_util.h" +#include "test_base.h" +#include "types/redis_json.h" + +using namespace kqir; + +static auto exe_end = ExecutorNode::Result(ExecutorNode::end); + +static IndexMap MakeIndexMap() { + auto f1 = FieldInfo("f1", std::make_unique<redis::TagFieldMetadata>()); + auto f2 = FieldInfo("f2", std::make_unique<redis::NumericFieldMetadata>()); + auto f3 = FieldInfo("f3", std::make_unique<redis::NumericFieldMetadata>()); + auto ia = std::make_unique<IndexInfo>("ia", redis::IndexMetadata(), "search_ns"); + ia->metadata.on_data_type = redis::IndexOnDataType::JSON; + ia->prefixes.prefixes.emplace_back("test2:"); + ia->prefixes.prefixes.emplace_back("test4:"); + ia->Add(std::move(f1)); + ia->Add(std::move(f2)); + ia->Add(std::move(f3)); + + IndexMap res; + res.Insert(std::move(ia)); + return res; +} + +static auto index_map = MakeIndexMap(); + +static auto NextRow(ExecutorContext& ctx) { + auto n = ctx.Next(); + EXPECT_EQ(n.Msg(), Status::ok_msg); + auto v = std::move(n).GetValue(); + EXPECT_EQ(v.index(), 1); + return std::get<ExecutorNode::RowType>(std::move(v)); +} + +TEST(PlanExecutorTest, Mock) { + auto op = std::make_unique<Mock>(std::vector<ExecutorNode::RowType>{}); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + + op = std::make_unique<Mock>(std::vector<ExecutorNode::RowType>{{"a"}, {"b"}, {"c"}}); + + ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "a"); + ASSERT_EQ(NextRow(ctx).key, "b"); + ASSERT_EQ(NextRow(ctx).key, "c"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); +} + +static auto IndexI() -> const IndexInfo* { return index_map.Find("ia", "search_ns")->second.get(); } +static auto FieldI(const std::string& f) -> const FieldInfo* { return &IndexI()->fields.at(f); } + +static auto N(double n) { return MakeValue<Numeric>(n); } +static auto T(const std::string& v) { return MakeValue<StringArray>(util::Split(v, ",")); } + +TEST(PlanExecutorTest, TopNSort) { + std::vector<ExecutorNode::RowType> data{ + {"a", {{FieldI("f3"), N(4)}}, IndexI()}, {"b", {{FieldI("f3"), N(2)}}, IndexI()}, + {"c", {{FieldI("f3"), N(7)}}, IndexI()}, {"d", {{FieldI("f3"), N(3)}}, IndexI()}, + {"e", {{FieldI("f3"), N(1)}}, IndexI()}, {"f", {{FieldI("f3"), N(6)}}, IndexI()}, + {"g", {{FieldI("f3"), N(8)}}, IndexI()}, + }; + { + auto op = std::make_unique<TopNSort>( + std::make_unique<Mock>(data), + std::make_unique<SortByClause>(SortByClause::ASC, std::make_unique<FieldRef>("f3", FieldI("f3"))), + std::make_unique<LimitClause>(0, 4)); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "e"); + ASSERT_EQ(NextRow(ctx).key, "b"); + ASSERT_EQ(NextRow(ctx).key, "d"); + ASSERT_EQ(NextRow(ctx).key, "a"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + { + auto op = std::make_unique<TopNSort>( + std::make_unique<Mock>(data), + std::make_unique<SortByClause>(SortByClause::ASC, std::make_unique<FieldRef>("f3", FieldI("f3"))), + std::make_unique<LimitClause>(1, 4)); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "b"); + ASSERT_EQ(NextRow(ctx).key, "d"); + ASSERT_EQ(NextRow(ctx).key, "a"); + ASSERT_EQ(NextRow(ctx).key, "f"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } +} + 
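// [Editor's illustration, not part of this patch] LimitClause(offset, count)
// applies after the sort: with the rows above ordered by f3 as
// e(1) < b(2) < d(3) < a(4) < f(6) < c(7) < g(8), LimitClause(0, 4) yields
// {e, b, d, a} while LimitClause(1, 4) skips the smallest row and yields
// {b, d, a, f} -- exactly the two expectations in this test.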
std::make_unique("f3", FieldI("f3")); + auto op = std::make_unique( + std::make_unique(data), + AndExpr::Create(Node::List( + std::make_unique(NumericCompareExpr::GT, field->CloneAs(), + std::make_unique(2)), + std::make_unique(NumericCompareExpr::LET, field->CloneAs(), + std::make_unique(6))))); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "a"); + ASSERT_EQ(NextRow(ctx).key, "d"); + ASSERT_EQ(NextRow(ctx).key, "f"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + { + auto field = std::make_unique("f3", FieldI("f3")); + auto op = std::make_unique( + std::make_unique(data), + OrExpr::Create(Node::List( + std::make_unique(NumericCompareExpr::GET, field->CloneAs(), + std::make_unique(6)), + std::make_unique(NumericCompareExpr::LT, field->CloneAs(), + std::make_unique(2))))); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "c"); + ASSERT_EQ(NextRow(ctx).key, "e"); + ASSERT_EQ(NextRow(ctx).key, "f"); + ASSERT_EQ(NextRow(ctx).key, "g"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + + data = {{"a", {{FieldI("f1"), T("cpp,java")}}, IndexI()}, {"b", {{FieldI("f1"), T("python,cpp,c")}}, IndexI()}, + {"c", {{FieldI("f1"), T("c,perl")}}, IndexI()}, {"d", {{FieldI("f1"), T("rust,python")}}, IndexI()}, + {"e", {{FieldI("f1"), T("java,kotlin")}}, IndexI()}, {"f", {{FieldI("f1"), T("c,rust")}}, IndexI()}, + {"g", {{FieldI("f1"), T("c,cpp,java")}}, IndexI()}}; + { + auto field = std::make_unique("f1", FieldI("f1")); + auto op = std::make_unique( + std::make_unique(data), + AndExpr::Create(Node::List( + std::make_unique(field->CloneAs(), std::make_unique("c")), + std::make_unique(field->CloneAs(), std::make_unique("cpp"))))); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "b"); + ASSERT_EQ(NextRow(ctx).key, "g"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + { + auto field = std::make_unique("f1", FieldI("f1")); + auto op = std::make_unique( + std::make_unique(data), + OrExpr::Create(Node::List( + std::make_unique(field->CloneAs(), std::make_unique("rust")), + std::make_unique(field->CloneAs(), std::make_unique("perl"))))); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "c"); + ASSERT_EQ(NextRow(ctx).key, "d"); + ASSERT_EQ(NextRow(ctx).key, "f"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } +} + +TEST(PlanExecutorTest, Limit) { + std::vector data{ + {"a", {{FieldI("f3"), N(4)}}, IndexI()}, {"b", {{FieldI("f3"), N(2)}}, IndexI()}, + {"c", {{FieldI("f3"), N(7)}}, IndexI()}, {"d", {{FieldI("f3"), N(3)}}, IndexI()}, + {"e", {{FieldI("f3"), N(1)}}, IndexI()}, {"f", {{FieldI("f3"), N(6)}}, IndexI()}, + {"g", {{FieldI("f3"), N(8)}}, IndexI()}, + }; + { + auto op = std::make_unique(std::make_unique(data), std::make_unique(1, 2)); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "b"); + ASSERT_EQ(NextRow(ctx).key, "c"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + { + auto field = std::make_unique("f3", FieldI("f3")); + auto op = std::make_unique(std::make_unique(data), std::make_unique(0, 3)); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "a"); + ASSERT_EQ(NextRow(ctx).key, "b"); + ASSERT_EQ(NextRow(ctx).key, "c"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } +} + +TEST(PlanExecutorTest, Merge) { + std::vector data1{ + {"a", {{FieldI("f3"), N(4)}}, IndexI()}, + {"b", {{FieldI("f3"), N(2)}}, IndexI()}, + }; + std::vector data2{{"c", {{FieldI("f3"), N(7)}}, IndexI()}, + {"d", {{FieldI("f3"), N(3)}}, IndexI()}, + {"e", {{FieldI("f3"), N(1)}}, 
IndexI()}}; + { + auto op = + std::make_unique<Merge>(Node::List<PlanOperator>(std::make_unique<Mock>(data1), std::make_unique<Mock>(data2))); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "a"); + ASSERT_EQ(NextRow(ctx).key, "b"); + ASSERT_EQ(NextRow(ctx).key, "c"); + ASSERT_EQ(NextRow(ctx).key, "d"); + ASSERT_EQ(NextRow(ctx).key, "e"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + { + auto op = std::make_unique<Merge>( + Node::List<PlanOperator>(std::make_unique<Mock>(decltype(data1){}), std::make_unique<Mock>(data1))); + + auto ctx = ExecutorContext(op.get()); + ASSERT_EQ(NextRow(ctx).key, "a"); + ASSERT_EQ(NextRow(ctx).key, "b"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } +} + +class PlanExecutorTestC : public TestBase { + protected: + explicit PlanExecutorTestC() : json_(std::make_unique<redis::Json>(storage_.get(), "search_ns")) {} + ~PlanExecutorTestC() override = default; + + void SetUp() override {} + void TearDown() override {} + + std::unique_ptr<redis::Json> json_; +}; + +TEST_F(PlanExecutorTestC, FullIndexScan) { + json_->Set("test1:a", "$", "{}"); + json_->Set("test1:b", "$", "{}"); + json_->Set("test2:c", "$", "{\"f3\": 6}"); + json_->Set("test3:d", "$", "{}"); + json_->Set("test4:e", "$", "{\"f3\": 7}"); + json_->Set("test4:f", "$", "{\"f3\": 2}"); + json_->Set("test4:g", "$", "{\"f3\": 8}"); + json_->Set("test5:h", "$", "{}"); + json_->Set("test5:i", "$", "{}"); + json_->Set("test5:g", "$", "{}"); + + { + auto op = std::make_unique<FullIndexScan>(std::make_unique<IndexRef>("ia", IndexI())); + + auto ctx = ExecutorContext(op.get(), storage_.get()); + ASSERT_EQ(NextRow(ctx).key, "test2:c"); + ASSERT_EQ(NextRow(ctx).key, "test4:e"); + ASSERT_EQ(NextRow(ctx).key, "test4:f"); + ASSERT_EQ(NextRow(ctx).key, "test4:g"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + + { + auto op = std::make_unique<Filter>( + std::make_unique<FullIndexScan>(std::make_unique<IndexRef>("ia", IndexI())), + std::make_unique<NumericCompareExpr>(NumericCompareExpr::GT, std::make_unique<FieldRef>("f3", FieldI("f3")), + std::make_unique<NumericLiteral>(3))); + + auto ctx = ExecutorContext(op.get(), storage_.get()); + ASSERT_EQ(NextRow(ctx).key, "test2:c"); + ASSERT_EQ(NextRow(ctx).key, "test4:e"); + ASSERT_EQ(NextRow(ctx).key, "test4:g"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } +} + +struct ScopedUpdate { + redis::GlobalIndexer::RecordResult rr; + std::string_view key; + std::string ns; + + static auto Create(redis::GlobalIndexer& indexer, std::string_view key, const std::string& ns) { + auto s = indexer.Record(key, ns); + EXPECT_EQ(s.Msg(), Status::ok_msg); + return *s; + } + + ScopedUpdate(redis::GlobalIndexer& indexer, std::string_view key, const std::string& ns) + : rr(Create(indexer, key, ns)), key(key), ns(ns) {} + + ScopedUpdate(const ScopedUpdate&) = delete; + ScopedUpdate(ScopedUpdate&&) = delete; + ScopedUpdate& operator=(const ScopedUpdate&) = delete; + ScopedUpdate& operator=(ScopedUpdate&&) = delete; + + ~ScopedUpdate() { + auto s = redis::GlobalIndexer::Update(rr); + EXPECT_EQ(s.Msg(), Status::ok_msg); + } +}; + +std::vector<std::unique_ptr<ScopedUpdate>> ScopedUpdates(redis::GlobalIndexer& indexer, + const std::vector<std::string_view>& keys, + const std::string& ns) { + std::vector<std::unique_ptr<ScopedUpdate>> sus; + + sus.reserve(keys.size()); + for (auto key : keys) { + sus.emplace_back(std::make_unique<ScopedUpdate>(indexer, key, ns)); + } + + return sus; +} + 
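// [Editor's illustration, not part of this patch] ScopedUpdate brackets a
// write the way the server does: GlobalIndexer::Record captures the key's
// indexed state before the JSON write, and the destructor then calls
// GlobalIndexer::Update with that record so the index absorbs the delta.
// Typical shape, as in the tests below:
//   {
//     auto updates = ScopedUpdates(indexer, {"test2:a"}, "search_ns");
//     json_->Set("test2:a", "$", R"({"f2": 6})");
//   }  // index entries for test2:a are refreshed here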
3}"); + json_->Set("test2:c", "$", "{\"f2\": 8}"); + json_->Set("test2:d", "$", "{\"f2\": 14}"); + json_->Set("test2:e", "$", "{\"f2\": 1}"); + json_->Set("test2:f", "$", "{\"f2\": 3}"); + json_->Set("test2:g", "$", "{\"f2\": 9}"); + } + + { + auto op = std::make_unique(std::make_unique("f2", FieldI("f2")), Interval(3, 9), + SortByClause::ASC); + + auto ctx = ExecutorContext(op.get(), storage_.get()); + ASSERT_EQ(NextRow(ctx).key, "test2:b"); + ASSERT_EQ(NextRow(ctx).key, "test2:f"); + ASSERT_EQ(NextRow(ctx).key, "test2:a"); + ASSERT_EQ(NextRow(ctx).key, "test2:c"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + + { + auto op = std::make_unique(std::make_unique("f2", FieldI("f2")), Interval(3, 9), + SortByClause::DESC); + + auto ctx = ExecutorContext(op.get(), storage_.get()); + ASSERT_EQ(NextRow(ctx).key, "test2:c"); + ASSERT_EQ(NextRow(ctx).key, "test2:a"); + ASSERT_EQ(NextRow(ctx).key, "test2:f"); + ASSERT_EQ(NextRow(ctx).key, "test2:b"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } +} + +TEST_F(PlanExecutorTestC, TagFieldScan) { + redis::GlobalIndexer indexer(storage_.get()); + indexer.Add(redis::IndexUpdater(IndexI())); + + { + auto updates = ScopedUpdates(indexer, {"test2:a", "test2:b", "test2:c", "test2:d", "test2:e", "test2:f", "test2:g"}, + "search_ns"); + json_->Set("test2:a", "$", "{\"f1\": \"c,cpp,java\"}"); + json_->Set("test2:b", "$", "{\"f1\": \"python,c\"}"); + json_->Set("test2:c", "$", "{\"f1\": \"java,scala\"}"); + json_->Set("test2:d", "$", "{\"f1\": \"rust,python,perl\"}"); + json_->Set("test2:e", "$", "{\"f1\": \"python,cpp\"}"); + json_->Set("test2:f", "$", "{\"f1\": \"c,cpp\"}"); + json_->Set("test2:g", "$", "{\"f1\": \"cpp,rust\"}"); + } + + { + auto op = std::make_unique(std::make_unique("f1", FieldI("f1")), "cpp"); + + auto ctx = ExecutorContext(op.get(), storage_.get()); + ASSERT_EQ(NextRow(ctx).key, "test2:a"); + ASSERT_EQ(NextRow(ctx).key, "test2:e"); + ASSERT_EQ(NextRow(ctx).key, "test2:f"); + ASSERT_EQ(NextRow(ctx).key, "test2:g"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } + + { + auto op = std::make_unique(std::make_unique("f1", FieldI("f1")), "python"); + + auto ctx = ExecutorContext(op.get(), storage_.get()); + ASSERT_EQ(NextRow(ctx).key, "test2:b"); + ASSERT_EQ(NextRow(ctx).key, "test2:d"); + ASSERT_EQ(NextRow(ctx).key, "test2:e"); + ASSERT_EQ(ctx.Next().GetValue(), exe_end); + } +} \ No newline at end of file diff --git a/tests/cppunit/redis_query_parser_test.cc b/tests/cppunit/redis_query_parser_test.cc new file mode 100644 index 00000000000..a31051a58bf --- /dev/null +++ b/tests/cppunit/redis_query_parser_test.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include <gtest/gtest.h> +#include <search/redis_query_transformer.h> + +#include "tao/pegtl/string_input.hpp" + +using namespace kqir::redis_query; + +static auto Parse(const std::string& in) { return ParseToIR(string_input(in, "test")); } + +#define AssertSyntaxError(node) ASSERT_EQ(node.Msg(), "invalid syntax"); // NOLINT + +// NOLINTNEXTLINE +#define AssertIR(node, val) \ + ASSERT_EQ(node.Msg(), Status::ok_msg); \ + ASSERT_EQ(node.GetValue()->Dump(), val); + +TEST(RedisQueryParserTest, Simple) { + AssertSyntaxError(Parse("")); + AssertSyntaxError(Parse("a")); + AssertSyntaxError(Parse("@a")); + AssertSyntaxError(Parse("a:")); + AssertSyntaxError(Parse("@a:")); + AssertSyntaxError(Parse("@a:[]")); + AssertSyntaxError(Parse("@a:[1 2")); + AssertSyntaxError(Parse("@a:[(inf 1]")); + AssertSyntaxError(Parse("@a:[((1 2]")); + AssertSyntaxError(Parse("@a:[1]")); + AssertSyntaxError(Parse("@a:[1 2 3]")); + AssertSyntaxError(Parse("@a:{}")); + AssertSyntaxError(Parse("@a:{x")); + AssertSyntaxError(Parse("@a:{|}")); + AssertSyntaxError(Parse("@a:{x|}")); + AssertSyntaxError(Parse("@a:{|y}")); + AssertSyntaxError(Parse("@a:{x|y|}")); + AssertSyntaxError(Parse("@a:{x}|")); + AssertSyntaxError(Parse("@a:{x} -")); + AssertSyntaxError(Parse("@a:{x}|@a:{x}|")); + + AssertIR(Parse("@a:[1 2]"), "(and a >= 1, a <= 2)"); + AssertIR(Parse("@a : [1 2]"), "(and a >= 1, a <= 2)"); + AssertIR(Parse("@a:[(1 2]"), "(and a > 1, a <= 2)"); + AssertIR(Parse("@a:[1 (2]"), "(and a >= 1, a < 2)"); + AssertIR(Parse("@a:[(1 (2]"), "(and a > 1, a < 2)"); + AssertIR(Parse("@a:[inf 2]"), "a <= 2"); + AssertIR(Parse("@a:[-inf 2]"), "a <= 2"); + AssertIR(Parse("@a:[1 inf]"), "a >= 1"); + AssertIR(Parse("@a:[1 +inf]"), "a >= 1"); + AssertIR(Parse("@a:[(1 +inf]"), "a > 1"); + AssertIR(Parse("@a:[-inf +inf]"), "true"); + AssertIR(Parse("@a:{x}"), "a hastag \"x\""); + AssertIR(Parse("@a:{x|y}"), R"((or a hastag "x", a hastag "y"))"); + AssertIR(Parse("@a:{x|y|z}"), R"((or a hastag "x", a hastag "y", a hastag "z"))"); + AssertIR(Parse(R"(@a:{"x"|y})"), R"((or a hastag "x", a hastag "y"))"); + AssertIR(Parse(R"(@a:{"x" | "y"})"), R"((or a hastag "x", a hastag "y"))"); + AssertIR(Parse("@a:{x} @b:[1 inf]"), "(and a hastag \"x\", b >= 1)"); + AssertIR(Parse("@a:{x} | @b:[1 inf]"), "(or a hastag \"x\", b >= 1)"); + AssertIR(Parse("@a:{x} @b:[1 inf] @c:{y}"), "(and a hastag \"x\", b >= 1, c hastag \"y\")"); + AssertIR(Parse("@a:{x}|@b:[1 inf] | @c:{y}"), "(or a hastag \"x\", b >= 1, c hastag \"y\")"); + AssertIR(Parse("@a:[1 inf] @b:[inf 2]| @c:[(3 inf]"), "(or (and a >= 1, b <= 2), c > 3)"); + AssertIR(Parse("@a:[1 inf] | @b:[inf 2] @c:[(3 inf]"), "(or a >= 1, (and b <= 2, c > 3))"); + AssertIR(Parse("(@a:[1 inf] @b:[inf 2])| @c:[(3 inf]"), "(or (and a >= 1, b <= 2), c > 3)"); + AssertIR(Parse("@a:[1 inf] | (@b:[inf 2] @c:[(3 inf])"), "(or a >= 1, (and b <= 2, c > 3))"); + AssertIR(Parse("@a:[1 inf] (@b:[inf 2]| @c:[(3 inf])"), "(and a >= 1, (or b <= 2, c > 3))"); + AssertIR(Parse("(@a:[1 inf] | @b:[inf 2]) @c:[(3 inf]"), "(and (or a >= 1, b <= 2), c > 3)"); + AssertIR(Parse("-@a:{x}"), "not a hastag \"x\""); + AssertIR(Parse("-@a:[(1 +inf]"), "not a > 1"); + AssertIR(Parse("-@a:[1 inf] @b:[inf 2]| -@c:[(3 inf]"), "(or (and not a >= 1, b <= 2), not c > 3)"); + AssertIR(Parse("@a:[1 inf] -(@b:[inf 2]| @c:[(3 inf])"), "(and a >= 1, not (or b <= 2, c > 3))"); + AssertIR(Parse("*"), "true"); + AssertIR(Parse("* *"), "(and true, true)"); + AssertIR(Parse("*|*"), "(or true, true)"); +} 
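(Editor's note, illustration only: both front-end parsers lower to the same IR, so a RediSearch-style query and its SQL counterpart dump identically -- e.g. `@a:[1 2]` here and `select * from x where a >= 1 and a <= 2` in the SQL tests below both print the predicate `(and a >= 1, a <= 2)`.)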
diff --git a/tests/cppunit/sql_parser_test.cc b/tests/cppunit/sql_parser_test.cc new file mode 100644 index 00000000000..a566d965411 --- /dev/null +++ b/tests/cppunit/sql_parser_test.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include <gtest/gtest.h> +#include <search/sql_transformer.h> + +#include "tao/pegtl/string_input.hpp" + +using namespace kqir::sql; + +static auto Parse(const std::string& in) { return ParseToIR(string_input(in, "test")); } + +#define AssertSyntaxError(node) ASSERT_EQ(node.Msg(), "invalid syntax"); // NOLINT + +// NOLINTNEXTLINE +#define AssertIR(node, val) \ + ASSERT_EQ(node.Msg(), Status::ok_msg); \ + ASSERT_EQ(node.GetValue()->Dump(), val); + +TEST(SQLParserTest, Simple) { + AssertSyntaxError(Parse("x")); + AssertSyntaxError(Parse("1")); + AssertSyntaxError(Parse("select")); + AssertSyntaxError(Parse("where")); + AssertSyntaxError(Parse("limit")); + AssertSyntaxError(Parse("from a")); + AssertSyntaxError(Parse("select 0 from")); + AssertSyntaxError(Parse("select 0 from b")); + AssertSyntaxError(Parse("select a from 123")); + AssertSyntaxError(Parse("select a from \"b\"")); + AssertSyntaxError(Parse("select a from b, c")); + AssertSyntaxError(Parse("select a from b where")); + AssertSyntaxError(Parse("select a from b hello")); + AssertSyntaxError(Parse("select a from b where")); + AssertSyntaxError(Parse("select a from b where 1")); + AssertSyntaxError(Parse("select a from b where 0")); + AssertSyntaxError(Parse("select a from b where \"x\"")); + AssertSyntaxError(Parse("select a from b where limit 10")); + AssertSyntaxError(Parse("select a from b where true and")); + AssertSyntaxError(Parse("select a from b where (true")); + AssertSyntaxError(Parse("select a from b where (true))")); + AssertSyntaxError(Parse("select a from b where 1 >")); + AssertSyntaxError(Parse("select a from b where x =")); + AssertSyntaxError(Parse("select a from b where x hastag")); + AssertSyntaxError(Parse("select a from b where =")); + AssertSyntaxError(Parse("select a from b where hastag x")); + AssertSyntaxError(Parse("select a from b where = 1")); + AssertSyntaxError(Parse("select a from b where x hashtag \"")); + AssertSyntaxError(Parse(R"(select a from b where x hashtag "\p")")); + AssertSyntaxError(Parse(R"(select a from b where x hashtag "\u11")")); + AssertSyntaxError(Parse(R"(select a from b where x hashtag "\")")); + AssertSyntaxError(Parse(R"(select a from b where x hashtag "abc)")); + AssertSyntaxError(Parse("select a from b where limit 10")); + AssertSyntaxError(Parse("select a from b limit 1, 1, 1")); + AssertSyntaxError(Parse("select a from b limit -10")); + AssertSyntaxError(Parse("select a from b limit")); + AssertSyntaxError(Parse("select a from b order")); + AssertSyntaxError(Parse("select a from b order by")); + AssertSyntaxError(Parse("select a from b order by a bsc")); + AssertSyntaxError(Parse("select a 
from b order a")); + AssertSyntaxError(Parse("select a from b order asc")); + AssertSyntaxError(Parse("select a from b order by a limit")); + + AssertIR(Parse("select a from b"), "select a from b where true"); + AssertIR(Parse(" select a from b "), "select a from b where true"); + AssertIR(Parse("\nselect\n a\t \tfrom \n\nb "), "select a from b where true"); + AssertIR(Parse("select * from b"), "select * from b where true"); + AssertIR(Parse("select a, b from c"), "select a, b from c where true"); + AssertIR(Parse("select a, b, c from d"), "select a, b, c from d where true"); + AssertIR(Parse("select xY_z12_3 , X00 from b"), "select xY_z12_3, X00 from b where true"); + AssertIR(Parse("select a from b where true"), "select a from b where true"); + AssertIR(Parse("select a from b where false"), "select a from b where false"); + AssertIR(Parse("select a from b where true and true"), "select a from b where (and true, true)"); + AssertIR(Parse("select a from b where false and true and false"), "select a from b where (and false, true, false)"); + AssertIR(Parse("select a from b where false or true"), "select a from b where (or false, true)"); + AssertIR(Parse("select a from b where true or false or true"), "select a from b where (or true, false, true)"); + AssertIR(Parse("select a from b where false and true or false"), + "select a from b where (or (and false, true), false)"); + AssertIR(Parse("select a from b where false or true and false"), + "select a from b where (or false, (and true, false))"); + AssertIR(Parse("select a from b where false and (true or false)"), + "select a from b where (and false, (or true, false))"); + AssertIR(Parse("select a from b where (false or true) and false"), + "select a from b where (and (or false, true), false)"); + AssertIR(Parse("select a from b where (false)"), "select a from b where false"); + AssertIR(Parse("select a from b where ((false))"), "select a from b where false"); + AssertIR(Parse("select a from b where (((false)))"), "select a from b where false"); + AssertIR(Parse("select a from b where ((false) and (false))"), "select a from b where (and false, false)"); + AssertIR(Parse("select a from b where x=1"), "select a from b where x = 1"); + AssertIR(Parse("select a from b where x = 1.0"), "select a from b where x = 1"); + AssertIR(Parse("select a from b where x = -1.234e5"), "select a from b where x = -123400"); + AssertIR(Parse("select a from b where x = -1.234e-5"), "select a from b where x = -1.234e-05"); + AssertIR(Parse("select a from b where x = 222e+5"), "select a from b where x = 22200000"); + AssertIR(Parse("select a from b where 1 = x"), "select a from b where x = 1"); + AssertIR(Parse("select a from b where 2 < y"), "select a from b where y > 2"); + AssertIR(Parse("select a from b where y > 2"), "select a from b where y > 2"); + AssertIR(Parse("select a from b where 3 >= z"), "select a from b where z <= 3"); + AssertIR(Parse("select a from b where x hastag \"hi\""), "select a from b where x hastag \"hi\""); + AssertIR(Parse(R"(select a from b where x hastag "a\nb")"), R"(select a from b where x hastag "a\nb")"); + AssertIR(Parse(R"(select a from b where x hastag "")"), R"(select a from b where x hastag "")"); + AssertIR(Parse(R"(select a from b where x hastag "hello , hi")"), R"(select a from b where x hastag "hello , hi")"); + AssertIR(Parse(R"(select a from b where x hastag "a\nb\t\n")"), R"(select a from b where x hastag "a\nb\t\n")"); + AssertIR(Parse(R"(select a from b where x hastag "a\u0000")"), R"(select a from b where x hastag 
"a\x00")"); + AssertIR(Parse("select a from b where x > 1 and y < 33"), "select a from b where (and x > 1, y < 33)"); + AssertIR(Parse("select a from b where x >= 1 and y hastag \"hi\" or c <= 233"), + "select a from b where (or (and x >= 1, y hastag \"hi\"), c <= 233)"); + AssertIR(Parse("select a from b limit 10"), "select a from b where true limit 0, 10"); + AssertIR(Parse("select a from b limit 2, 3"), "select a from b where true limit 2, 3"); + AssertIR(Parse("select a from b order by a"), "select a from b where true sortby a, asc"); + AssertIR(Parse("select a from b order by c desc"), "select a from b where true sortby c, desc"); + AssertIR(Parse("select a from b order by a limit 10"), "select a from b where true sortby a, asc limit 0, 10"); + AssertIR(Parse("select a from b where c = 1 limit 10"), "select a from b where c = 1 limit 0, 10"); + AssertIR(Parse("select a from b where c = 1 and d hastag \"x\" order by e"), + "select a from b where (and c = 1, d hastag \"x\") sortby e, asc"); + AssertIR(Parse("select a from b where c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10"), + "select a from b where (or c = 1, (and d hastag \"x\", e >= 2)) sortby e, asc limit 0, 10"); +} diff --git a/tests/cppunit/types/bitmap_test.cc b/tests/cppunit/types/bitmap_test.cc index 4795e476ad0..6ec2d9e39a1 100644 --- a/tests/cppunit/types/bitmap_test.cc +++ b/tests/cppunit/types/bitmap_test.cc @@ -179,7 +179,7 @@ TEST_P(RedisBitmapTest, BitPosClearBit) { /// /// String will set a empty string value when initializing, so, when first /// querying, it should return -1. - bitmap_->BitPos(key_, false, 0, -1, /*stop_given=*/false, &pos); + bitmap_->BitPos(key_, false, 0, -1, /*stop_given=*/false, &pos, /*bit_index=*/false); if (i == 0 && !use_bitmap) { EXPECT_EQ(pos, -1); } else { @@ -201,7 +201,7 @@ TEST_P(RedisBitmapTest, BitPosSetBit) { int64_t pos = 0; int start_indexes[] = {0, 1, 124, 1025, 1027, 3 * 1024 + 1}; for (size_t i = 0; i < sizeof(start_indexes) / sizeof(start_indexes[0]); i++) { - bitmap_->BitPos(key_, true, start_indexes[i], -1, true, &pos); + bitmap_->BitPos(key_, true, start_indexes[i], -1, true, &pos, /*bit_index=*/false); EXPECT_EQ(pos, offsets[i]); } auto s = bitmap_->Del(key_); @@ -215,19 +215,19 @@ TEST_P(RedisBitmapTest, BitPosNegative) { } int64_t pos = 0; // First bit is negative - bitmap_->BitPos(key_, false, 0, -1, true, &pos); + bitmap_->BitPos(key_, false, 0, -1, true, &pos, /*bit_index=*/false); EXPECT_EQ(0, pos); // 8 * 1024 - 1 bit is positive - bitmap_->BitPos(key_, true, 0, -1, true, &pos); + bitmap_->BitPos(key_, true, 0, -1, true, &pos, /*bit_index=*/false); EXPECT_EQ(8 * 1024 - 1, pos); // First bit in 1023 byte is negative - bitmap_->BitPos(key_, false, -1, -1, true, &pos); + bitmap_->BitPos(key_, false, -1, -1, true, &pos, /*bit_index=*/false); EXPECT_EQ(8 * 1023, pos); // Last Bit in 1023 byte is positive - bitmap_->BitPos(key_, true, -1, -1, true, &pos); + bitmap_->BitPos(key_, true, -1, -1, true, &pos, /*bit_index=*/false); EXPECT_EQ(8 * 1024 - 1, pos); // Large negative number will be normalized. 
- bitmap_->BitPos(key_, false, -10000, -10000, true, &pos); + bitmap_->BitPos(key_, false, -10000, -10000, true, &pos, /*bit_index=*/false); EXPECT_EQ(0, pos); auto s = bitmap_->Del(key_); @@ -242,9 +242,9 @@ TEST_P(RedisBitmapTest, BitPosStopGiven) { EXPECT_FALSE(bit); } int64_t pos = 0; - bitmap_->BitPos(key_, false, 0, 0, /*stop_given=*/true, &pos); + bitmap_->BitPos(key_, false, 0, 0, /*stop_given=*/true, &pos, /*bit_index=*/false); EXPECT_EQ(-1, pos); - bitmap_->BitPos(key_, false, 0, 0, /*stop_given=*/false, &pos); + bitmap_->BitPos(key_, false, 0, 0, /*stop_given=*/false, &pos, /*bit_index=*/false); EXPECT_EQ(8, pos); // Set a bit at 8 not affect that @@ -253,9 +253,9 @@ TEST_P(RedisBitmapTest, BitPosStopGiven) { bitmap_->SetBit(key_, 8, true, &bit); EXPECT_FALSE(bit); } - bitmap_->BitPos(key_, false, 0, 0, /*stop_given=*/true, &pos); + bitmap_->BitPos(key_, false, 0, 0, /*stop_given=*/true, &pos, /*bit_index=*/false); EXPECT_EQ(-1, pos); - bitmap_->BitPos(key_, false, 0, 1, /*stop_given=*/false, &pos); + bitmap_->BitPos(key_, false, 0, 1, /*stop_given=*/false, &pos, /*bit_index=*/false); EXPECT_EQ(9, pos); auto s = bitmap_->Del(key_); diff --git a/tests/cppunit/types/json_test.cc b/tests/cppunit/types/json_test.cc index 82bfa944097..30d6f60ef53 100644 --- a/tests/cppunit/types/json_test.cc +++ b/tests/cppunit/types/json_test.cc @@ -96,6 +96,14 @@ TEST_F(RedisJsonTest, Set) { ASSERT_EQ(json_val_.Dump().GetValue(), "[{},[]]"); ASSERT_THAT(json_->Set(key_, "$[1]", "invalid").ToString(), MatchesRegex(".*syntax_error.*")); ASSERT_TRUE(json_->Del(key_, "$", &result).ok()); + + ASSERT_TRUE(json_->Set(key_, "$", R"({"a":1})").ok()); + ASSERT_TRUE(json_->Set(key_, "$.b", "2").ok()); + ASSERT_TRUE(json_->Set(key_, "$.c", R"({"x":3})").ok()); + ASSERT_TRUE(json_->Set(key_, "$.c.y", "4").ok()); + + ASSERT_TRUE(json_->Get(key_, {}, &json_val_).ok()); + ASSERT_EQ(json_val_.value, jsoncons::json::parse(R"({"a":1,"b":2,"c":{"x":3,"y":4}})")); } TEST_F(RedisJsonTest, Get) { @@ -612,7 +620,7 @@ TEST_F(RedisJsonTest, NumMultBy) { ASSERT_EQ(res.Print(0, true).GetValue(), "[2]"); res.value.clear(); ASSERT_TRUE(json_->NumMultBy(key_, "$.foo", "0.5", &res).ok()); - ASSERT_EQ(res.Print(0, true).GetValue(), "[1.0]"); + ASSERT_EQ(res.Print(0, true).GetValue(), "[1]"); res.value.clear(); ASSERT_TRUE(json_->NumMultBy(key_, "$.bar", "1", &res).ok()); @@ -626,7 +634,7 @@ TEST_F(RedisJsonTest, NumMultBy) { // num object ASSERT_TRUE(json_->Set(key_, "$", "1.0").ok()); ASSERT_TRUE(json_->NumMultBy(key_, "$", "1", &res).ok()); - ASSERT_EQ(res.Print(0, true).GetValue(), "[1.0]"); + ASSERT_EQ(res.Print(0, true).GetValue(), "[1]"); res.value.clear(); ASSERT_TRUE(json_->NumMultBy(key_, "$", "1.5", &res).ok()); ASSERT_EQ(res.Print(0, true).GetValue(), "[1.5]"); diff --git a/tests/cppunit/types/string_test.cc b/tests/cppunit/types/string_test.cc index 1c850a714c9..3fed5ce8944 100644 --- a/tests/cppunit/types/string_test.cc +++ b/tests/cppunit/types/string_test.cc @@ -23,6 +23,7 @@ #include <gtest/gtest.h> #include "test_base.h" +#include "time_util.h" #include "types/redis_string.h" class RedisStringTest : public TestBase { @@ -68,7 +69,7 @@ TEST_F(RedisStringTest, GetAndSet) { } TEST_F(RedisStringTest, MGetAndMSet) { - string_->MSet(pairs_); + string_->MSet(pairs_, 0); std::vector<Slice> keys; std::vector<std::string> values; keys.reserve(pairs_.size()); @@ -172,10 +173,10 @@ TEST_F(RedisStringTest, GetDel) { TEST_F(RedisStringTest, MSetXX) { bool flag = false; - string_->SetXX(key_, "test-value", 3000, &flag); + string_->SetXX(key_, "test-value", 
util::GetTimeStampMS() + 3000, &flag); EXPECT_FALSE(flag); string_->Set(key_, "test-value"); - string_->SetXX(key_, "test-value", 3000, &flag); + string_->SetXX(key_, "test-value", util::GetTimeStampMS() + 3000, &flag); EXPECT_TRUE(flag); int64_t ttl = 0; auto s = string_->TTL(key_, &ttl); @@ -211,7 +212,7 @@ TEST_F(RedisStringTest, MSetNX) { TEST_F(RedisStringTest, MSetNXWithTTL) { bool flag = false; - string_->SetNX(key_, "test-value", 3000, &flag); + string_->SetNX(key_, "test-value", util::GetTimeStampMS() + 3000, &flag); int64_t ttl = 0; auto s = string_->TTL(key_, &ttl); EXPECT_TRUE(ttl >= 2000 && ttl <= 4000); @@ -219,7 +220,7 @@ TEST_F(RedisStringTest, MSetNXWithTTL) { } TEST_F(RedisStringTest, SetEX) { - string_->SetEX(key_, "test-value", 3000); + string_->SetEX(key_, "test-value", util::GetTimeStampMS() + 3000); int64_t ttl = 0; auto s = string_->TTL(key_, &ttl); EXPECT_TRUE(ttl >= 2000 && ttl <= 4000); @@ -258,15 +259,15 @@ TEST_F(RedisStringTest, CAS) { auto status = string_->Set(key, value); ASSERT_TRUE(status.ok()); - status = string_->CAS("non_exist_key", value, new_value, 10000, &flag); + status = string_->CAS("non_exist_key", value, new_value, util::GetTimeStampMS() + 10000, &flag); ASSERT_TRUE(status.ok()); EXPECT_EQ(-1, flag); - status = string_->CAS(key, "cas_value_err", new_value, 10000, &flag); + status = string_->CAS(key, "cas_value_err", new_value, util::GetTimeStampMS() + 10000, &flag); ASSERT_TRUE(status.ok()); EXPECT_EQ(0, flag); - status = string_->CAS(key, value, new_value, 10000, &flag); + status = string_->CAS(key, value, new_value, util::GetTimeStampMS() + 10000, &flag); ASSERT_TRUE(status.ok()); EXPECT_EQ(1, flag); diff --git a/tests/gocase/integration/cluster/cluster_test.go b/tests/gocase/integration/cluster/cluster_test.go index 8bc42fdafdb..fad7d0128f7 100644 --- a/tests/gocase/integration/cluster/cluster_test.go +++ b/tests/gocase/integration/cluster/cluster_test.go @@ -24,10 +24,12 @@ import ( "fmt" "strings" "testing" + "time" - "github.com/apache/kvrocks/tests/gocase/util" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/require" + + "github.com/apache/kvrocks/tests/gocase/util" ) func TestDisableCluster(t *testing.T) { @@ -130,6 +132,61 @@ func TestClusterNodes(t *testing.T) { }) } +func TestClusterReplicas(t *testing.T) { + srv := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + nodes := "" + + master1ID := "bb2e5b3c5282086df51eff6b3e35519aede96fa6" + master1Node := fmt.Sprintf("%s %s %d master - 0-8191", master1ID, srv.Host(), srv.Port()) + nodes += master1Node + "\n" + + master2ID := "159dde1194ebf5bfc5a293dff839c3d1476f2a49" + master2Node := fmt.Sprintf("%s %s %d master - 8192-16383", master2ID, srv.Host(), srv.Port()) + nodes += master2Node + "\n" + + replica2ID := "7dbee3d628f04cc5d763b36e92b10533e627a1d0" + replica2Node := fmt.Sprintf("%s %s %d slave %s", replica2ID, srv.Host(), srv.Port(), master2ID) + nodes += replica2Node + + require.NoError(t, rdb.Do(ctx, "clusterx", "SETNODES", nodes, "2").Err()) + require.EqualValues(t, "2", rdb.Do(ctx, "clusterx", "version").Val()) + + t.Run("with replicas", func(t *testing.T) { + replicas, err := rdb.Do(ctx, "cluster", "replicas", "159dde1194ebf5bfc5a293dff839c3d1476f2a49").Text() + require.NoError(t, err) + fields := strings.Split(replicas, " ") + require.Len(t, fields, 8) + require.Equal(t, fmt.Sprintf("%s@%d", srv.HostPort(), 
srv.Port()+10000), fields[1]) + require.Equal(t, "slave", fields[2]) + require.Equal(t, master2ID, fields[3]) + require.Equal(t, "connected\n", fields[7]) + }) + + t.Run("without replicas", func(t *testing.T) { + replicas, err := rdb.Do(ctx, "cluster", "replicas", "bb2e5b3c5282086df51eff6b3e35519aede96fa6").Text() + require.NoError(t, err) + require.Empty(t, replicas) + }) + + t.Run("send command to replica", func(t *testing.T) { + err := rdb.Do(ctx, "cluster", "replicas", "7dbee3d628f04cc5d763b36e92b10533e627a1d0").Err() + require.Error(t, err) + require.Contains(t, err.Error(), "The node isn't a master") + }) + + t.Run("unknown node", func(t *testing.T) { + err := rdb.Do(ctx, "cluster", "replicas", "unknown").Err() + require.Error(t, err) + require.Contains(t, err.Error(), "Invalid cluster node id") + }) +} + func TestClusterDumpAndLoadClusterNodesInfo(t *testing.T) { srv1 := util.StartServer(t, map[string]string{ "bind": "0.0.0.0", @@ -305,6 +362,15 @@ func TestClusterMultiple(t *testing.T) { require.NoError(t, rdb[i].Do(ctx, "clusterx", "setnodes", clusterNodes, "1").Err()) } + t.Run("check if the node id is correct", func(t *testing.T) { + // only node1, node2 and node3 were members of the cluster + for i := 1; i < 4; i++ { + myid, err := rdb[i].Do(ctx, "clusterx", "myid").Text() + require.NoError(t, err) + require.Equal(t, nodeID[i], myid) + } + }) + t.Run("cluster info command", func(t *testing.T) { r := rdb[1].ClusterInfo(ctx).Val() require.Contains(t, r, "cluster_state:ok") @@ -334,6 +400,11 @@ func TestClusterMultiple(t *testing.T) { require.ErrorContains(t, rdb[3].Set(ctx, util.SlotTable[16383], 16383, 0).Err(), "MOVED") // request a read-only command to node3 that serves slot 16383, that's ok util.WaitForOffsetSync(t, rdb[2], rdb[3]) + // the default option is READWRITE, which redirects both reads and writes to the master + require.ErrorContains(t, rdb[3].Get(ctx, util.SlotTable[16383]).Err(), "MOVED") + + require.NoError(t, rdb[3].Do(ctx, "READONLY").Err()) + require.Equal(t, "16383", rdb[3].Get(ctx, util.SlotTable[16383]).Val()) }) @@ -369,4 +440,115 @@ func TestClusterMultiple(t *testing.T) { require.ErrorContains(t, rdb[1].Do(ctx, "EXEC").Err(), "EXECABORT") require.Equal(t, "no-multi", rdb[1].Get(ctx, util.SlotTable[0]).Val()) }) + + t.Run("requests on cluster are ok when readonly is enabled", func(t *testing.T) { + + require.NoError(t, rdb[3].Do(ctx, "READONLY").Err()) + require.NoError(t, rdb[2].Set(ctx, util.SlotTable[8192], 8192, 0).Err()) + util.WaitForOffsetSync(t, rdb[2], rdb[3]) + // request node3 that serves slot 8192, that's ok + require.Equal(t, "8192", rdb[3].Get(ctx, util.SlotTable[8192]).Val()) + + require.NoError(t, rdb[3].Do(ctx, "READWRITE").Err()) + + // when READWRITE is enabled, a request to node3 that serves slot 8192 is not ok + util.ErrorRegexp(t, rdb[3].Get(ctx, util.SlotTable[8192]).Err(), fmt.Sprintf("MOVED 8192.*%d.*", srv[2].Port())) + }) +} + +func TestClusterReset(t *testing.T) { + ctx := context.Background() + + srv0 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { srv0.Close() }() + rdb0 := srv0.NewClientWithOption(&redis.Options{PoolSize: 1}) + defer func() { require.NoError(t, rdb0.Close()) }() + id0 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODEID", id0).Err()) + + srv1 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { srv1.Close() }() + rdb1 := srv1.NewClientWithOption(&redis.Options{PoolSize: 1}) + defer func() { 
require.NoError(t, rdb1.Close()) }() + id1 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODEID", id1).Err()) + + srv2 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { srv2.Close() }() + rdb2 := srv2.NewClientWithOption(&redis.Options{PoolSize: 1}) + defer func() { require.NoError(t, rdb2.Close()) }() + id2 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx02" + require.NoError(t, rdb2.Do(ctx, "clusterx", "SETNODEID", id2).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-8191\n", id0, srv0.Host(), srv0.Port()) + clusterNodes += fmt.Sprintf("%s %s %d master - 8192-16383\n", id1, srv1.Host(), srv1.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", id2, srv2.Host(), srv2.Port(), id1) + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, rdb2.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + t.Run("cannot reset cluster if the db is not empty", func(t *testing.T) { + key := util.SlotTable[0] + require.NoError(t, rdb0.Set(ctx, key, "value", 0).Err()) + require.ErrorContains(t, rdb0.ClusterResetHard(ctx).Err(), "Can't reset cluster while database is not empty") + require.NoError(t, rdb0.Del(ctx, key).Err()) + require.NoError(t, rdb0.ClusterResetSoft(ctx).Err()) + require.EqualValues(t, "-1", rdb0.Do(ctx, "clusterx", "version").Val()) + // reset the cluster topology to avoid breaking other test cases + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + }) + + t.Run("replica should become master after reset", func(t *testing.T) { + require.Eventually(t, func() bool { + return util.FindInfoEntry(rdb2, "role") == "slave" + }, 5*time.Second, 50*time.Millisecond) + require.NoError(t, rdb2.ClusterResetHard(ctx).Err()) + require.Equal(t, "master", util.FindInfoEntry(rdb2, "role")) + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + }) + + t.Run("cannot reset cluster if the db is importing the slot", func(t *testing.T) { + slotNum := 1 + require.Equal(t, "OK", rdb1.Do(ctx, "cluster", "import", slotNum, 0).Val()) + clusterInfo := rdb1.ClusterInfo(ctx).Val() + require.Contains(t, clusterInfo, "importing_slot: 1") + require.Contains(t, clusterInfo, "import_state: start") + require.ErrorContains(t, rdb1.ClusterResetHard(ctx).Err(), "Can't reset cluster while importing slot") + require.Equal(t, "OK", rdb1.Do(ctx, "cluster", "import", slotNum, 1).Val()) + clusterInfo = rdb1.ClusterInfo(ctx).Val() + require.Contains(t, clusterInfo, "import_state: success") + require.NoError(t, rdb0.ClusterResetHard(ctx).Err()) + require.EqualValues(t, "-1", rdb0.Do(ctx, "clusterx", "version").Val()) + // reset the cluster topology to avoid breaking other test cases + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + }) + + t.Run("cannot reset cluster if the db is migrating the slot", func(t *testing.T) { + slotNum := 2 + // slow down the migration speed to avoid breaking other test cases + require.NoError(t, rdb0.ConfigSet(ctx, "migrate-speed", "128").Err()) + for i := 0; i < 1024; i++ { + require.NoError(t, rdb0.RPush(ctx, "my-list", fmt.Sprintf("element%d", i)).Err()) + } + + require.Equal(t, "OK", rdb0.Do(ctx, "clusterx", "migrate", slotNum, id1).Val()) + clusterInfo := rdb0.ClusterInfo(ctx).Val() + require.Contains(t, clusterInfo, "migrating_slot: 2") + require.Contains(t, clusterInfo, 
"migrating_state: start") + require.Contains(t, rdb0.ClusterResetHard(ctx).Err(), "Can't reset cluster while migrating slot") + + // wait for the migration to finish + require.Eventually(t, func() bool { + clusterInfo := rdb0.ClusterInfo(context.Background()).Val() + return strings.Contains(clusterInfo, fmt.Sprintf("migrating_state: %s", "success")) + }, 10*time.Second, 100*time.Millisecond) + // Need to flush keys in the source node since the success migration will not mean + // the keys are removed from the source node right now. + require.NoError(t, rdb0.FlushAll(ctx).Err()) + + require.NoError(t, rdb0.ClusterResetHard(ctx).Err()) + require.EqualValues(t, "-1", rdb0.Do(ctx, "clusterx", "version").Val()) + // reset the cluster topology to avoid breaking other test cases + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + }) } diff --git a/tests/gocase/integration/replication/replication_test.go b/tests/gocase/integration/replication/replication_test.go index 20de35819b0..71e08c77f61 100644 --- a/tests/gocase/integration/replication/replication_test.go +++ b/tests/gocase/integration/replication/replication_test.go @@ -32,6 +32,64 @@ import ( "github.com/stretchr/testify/require" ) +func TestClusterReplication(t *testing.T) { + ctx := context.Background() + + masterSrv := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { masterSrv.Close() }() + masterClient := masterSrv.NewClient() + defer func() { require.NoError(t, masterClient.Close()) }() + masterNodeID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODEID", masterNodeID).Err()) + + replicaSrv := util.StartServer(t, map[string]string{ + "cluster-enabled": "yes", + // enabled the replication namespace to reproduce the issue #2214 + "repl-namespace-enabled": "yes", + }) + defer func() { replicaSrv.Close() }() + replicaClient := replicaSrv.NewClient() + // allow to run the read-only command in the replica + require.NoError(t, replicaClient.ReadOnly(ctx).Err()) + defer func() { require.NoError(t, replicaClient.Close()) }() + replicaNodeID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, replicaClient.Do(ctx, "clusterx", "SETNODEID", replicaNodeID).Err()) + + clusterNodes := fmt.Sprintf("%s 127.0.0.1 %d master - 0-16383", masterNodeID, masterSrv.Port()) + clusterNodes = fmt.Sprintf("%s\n%s 127.0.0.1 %d slave %s", clusterNodes, replicaNodeID, replicaSrv.Port(), masterNodeID) + + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, replicaClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + t.Run("Cluster replication should work", func(t *testing.T) { + util.WaitForSync(t, replicaClient) + require.Equal(t, "slave", util.FindInfoEntry(replicaClient, "role")) + masterClient.Set(ctx, "k0", "v0", 0) + masterClient.LPush(ctx, "k1", "e0", "e1", "e2") + util.WaitForOffsetSync(t, masterClient, replicaClient) + + require.Equal(t, "v0", replicaClient.Get(ctx, "k0").Val()) + require.Equal(t, []string{"e2", "e1", "e0"}, replicaClient.LRange(ctx, "k1", 0, -1).Val()) + }) + + t.Run("Cluster replication should work normally after restart(issue #2214)", func(t *testing.T) { + replicaSrv.Close() + masterClient.Set(ctx, "k0", "v1", 0) + masterClient.HSet(ctx, "k2", "f0", "v0", "f1", "v1") + + // start the replica server again + replicaSrv.Start() + _ = replicaClient.Close() + replicaClient = replicaSrv.NewClient() + // allow to run the read-only 
command in the replica + require.NoError(t, replicaClient.ReadOnly(ctx).Err()) + + util.WaitForOffsetSync(t, masterClient, replicaClient) + require.Equal(t, "v1", replicaClient.Get(ctx, "k0").Val()) + require.Equal(t, map[string]string{"f0": "v0", "f1": "v1"}, replicaClient.HGetAll(ctx, "k2").Val()) + }) +} + func TestReplicationWithHostname(t *testing.T) { srvA := util.StartServer(t, map[string]string{}) defer srvA.Close() diff --git a/tests/gocase/integration/slotimport/slotimport_test.go b/tests/gocase/integration/slotimport/slotimport_test.go index b86f9275e58..a3566cabfe6 100644 --- a/tests/gocase/integration/slotimport/slotimport_test.go +++ b/tests/gocase/integration/slotimport/slotimport_test.go @@ -22,6 +22,7 @@ package slotimport import ( "context" "fmt" + "strings" "testing" "time" @@ -61,23 +62,39 @@ func TestImportSlaveServer(t *testing.T) { func TestImportedServer(t *testing.T) { ctx := context.Background() - srv := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) - defer func() { srv.Close() }() - rdb := srv.NewClient() - defer func() { require.NoError(t, rdb.Close()) }() - srvID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" - clusterNodes := fmt.Sprintf("%s 127.0.0.1 %d master - 0-16383", srvID, srv.Port()) - require.NoError(t, rdb.Do(ctx, "clusterx", "SETNODEID", srvID).Err()) - require.NoError(t, rdb.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + srvA := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { srvA.Close() }() + rdbA := srvA.NewClient() + defer func() { require.NoError(t, rdbA.Close()) }() + srvAID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, rdbA.Do(ctx, "clusterx", "SETNODEID", srvAID).Err()) + + srvB := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { srvB.Close() }() + rdbB := srvB.NewClient() + defer func() { require.NoError(t, rdbB.Close()) }() + srvBID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, rdbB.Do(ctx, "clusterx", "SETNODEID", srvBID).Err()) + + clusterNodes := fmt.Sprintf("%s 127.0.0.1 %d master - 0-8191", srvAID, srvA.Port()) + clusterNodes = fmt.Sprintf("%s\n%s 127.0.0.1 %d master - 8192-16383", clusterNodes, srvBID, srvB.Port()) + + require.NoError(t, rdbA.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, rdbB.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) t.Run("IMPORT - error slot", func(t *testing.T) { - require.ErrorContains(t, rdb.Do(ctx, "cluster", "import", -1, 0).Err(), "Slot is out of range") - require.ErrorContains(t, rdb.Do(ctx, "cluster", "import", 16384, 0).Err(), "Slot is out of range") + require.ErrorContains(t, rdbA.Do(ctx, "cluster", "import", -1, 0).Err(), "Slot is out of range") + require.ErrorContains(t, rdbA.Do(ctx, "cluster", "import", 16384, 0).Err(), "Slot is out of range") }) t.Run("IMPORT - slot with error state", func(t *testing.T) { - require.ErrorContains(t, rdb.Do(ctx, "cluster", "import", 1, 4).Err(), "Invalid import state") - require.ErrorContains(t, rdb.Do(ctx, "cluster", "import", 1, -3).Err(), "Invalid import state") + require.ErrorContains(t, rdbA.Do(ctx, "cluster", "import", 1, 4).Err(), "Invalid import state") + require.ErrorContains(t, rdbA.Do(ctx, "cluster", "import", 1, -3).Err(), "Invalid import state") + }) + + t.Run("IMPORT - slot with wrong state", func(t *testing.T) { + require.ErrorContains(t, rdbA.Do(ctx, "cluster", "import", 1, 0).Err(), + "Can't import slot which belongs to me") }) t.Run("IMPORT - slot states in right order", 
func(t *testing.T) { @@ -85,60 +102,114 @@ func TestImportedServer(t *testing.T) { slotNum := 1 slotKey := util.SlotTable[slotNum] // import start - require.Equal(t, "OK", rdb.Do(ctx, "cluster", "import", slotNum, 0).Val()) - require.NoError(t, rdb.Set(ctx, slotKey, "slot1", 0).Err()) - require.Equal(t, "slot1", rdb.Get(ctx, slotKey).Val()) - clusterInfo := rdb.ClusterInfo(ctx).Val() + require.NoError(t, rdbA.Set(ctx, slotKey, "slot1", 0).Err()) + require.Equal(t, "slot1", rdbA.Get(ctx, slotKey).Val()) + require.Equal(t, "OK", rdbB.Do(ctx, "cluster", "import", slotNum, 0).Val()) + clusterInfo := rdbB.ClusterInfo(ctx).Val() require.Contains(t, clusterInfo, "importing_slot: 1") require.Contains(t, clusterInfo, "import_state: start") + clusterNodes := rdbB.ClusterNodes(ctx).Val() + require.Contains(t, clusterNodes, fmt.Sprintf("[%d-<-%s]", slotNum, srvAID)) + + require.NoError(t, rdbA.Do(ctx, "clusterx", "migrate", slotNum, srvBID).Err()) + require.Eventually(t, func() bool { + clusterInfo := rdbA.ClusterInfo(context.Background()).Val() + return strings.Contains(clusterInfo, fmt.Sprintf("migrating_slot: %d", slotNum)) && + strings.Contains(clusterInfo, fmt.Sprintf("migrating_state: %s", "success")) + }, 5*time.Second, 100*time.Millisecond) // import success - require.Equal(t, "OK", rdb.Do(ctx, "cluster", "import", slotNum, 1).Val()) - clusterInfo = rdb.ClusterInfo(ctx).Val() + require.Equal(t, "OK", rdbB.Do(ctx, "cluster", "import", slotNum, 1).Val()) + clusterInfo = rdbB.ClusterInfo(ctx).Val() require.Contains(t, clusterInfo, "importing_slot: 1") require.Contains(t, clusterInfo, "import_state: success") + // import finished, so the nodes info should no longer contain the import section + clusterNodes = rdbB.ClusterNodes(ctx).Val() + require.NotContains(t, clusterNodes, fmt.Sprintf("[%d-<-%s]", slotNum, srvAID)) + time.Sleep(50 * time.Millisecond) - require.Equal(t, "slot1", rdb.Get(ctx, slotKey).Val()) + require.Equal(t, "slot1", rdbB.Get(ctx, slotKey).Val()) }) t.Run("IMPORT - slot state 'error'", func(t *testing.T) { slotNum := 10 slotKey := util.SlotTable[slotNum] - require.Equal(t, "OK", rdb.Do(ctx, "cluster", "import", slotNum, 0).Val()) - require.NoError(t, rdb.Set(ctx, slotKey, "slot10_again", 0).Err()) - require.Equal(t, "slot10_again", rdb.Get(ctx, slotKey).Val()) + require.Equal(t, "OK", rdbB.Do(ctx, "cluster", "import", slotNum, 0).Val()) + require.NoError(t, rdbB.Set(ctx, slotKey, "slot10_again", 0).Err()) + require.Equal(t, "slot10_again", rdbB.Get(ctx, slotKey).Val()) // import error - require.Equal(t, "OK", rdb.Do(ctx, "cluster", "import", slotNum, 2).Val()) + require.Equal(t, "OK", rdbB.Do(ctx, "cluster", "import", slotNum, 2).Val()) time.Sleep(50 * time.Millisecond) - clusterInfo := rdb.ClusterInfo(ctx).Val() + clusterInfo := rdbB.ClusterInfo(ctx).Val() require.Contains(t, clusterInfo, "importing_slot: 10") require.Contains(t, clusterInfo, "import_state: error") // get empty - require.Zero(t, rdb.Exists(ctx, slotKey).Val()) + require.Zero(t, rdbB.Exists(ctx, slotKey).Val()) }) t.Run("IMPORT - connection broken", func(t *testing.T) { slotNum := 11 slotKey := util.SlotTable[slotNum] - require.Equal(t, "OK", rdb.Do(ctx, "cluster", "import", slotNum, 0).Val()) - require.NoError(t, rdb.Set(ctx, slotKey, "slot11", 0).Err()) - require.Equal(t, "slot11", rdb.Get(ctx, slotKey).Val()) + require.Equal(t, "OK", rdbB.Do(ctx, "cluster", "import", slotNum, 0).Val()) + require.NoError(t, rdbB.Set(ctx, slotKey, "slot11", 0).Err()) + require.Equal(t, "slot11", rdbB.Get(ctx, slotKey).Val()) // close connection, server will stop 
importing - require.NoError(t, rdb.Close()) - rdb = srv.NewClient() + require.NoError(t, rdbB.Close()) + rdbB = srvB.NewClient() time.Sleep(50 * time.Millisecond) - clusterInfo := rdb.ClusterInfo(ctx).Val() + clusterInfo := rdbB.ClusterInfo(ctx).Val() require.Contains(t, clusterInfo, "importing_slot: 11") require.Contains(t, clusterInfo, "import_state: error") // get empty - require.Zero(t, rdb.Exists(ctx, slotKey).Val()) + require.Zero(t, rdbB.Exists(ctx, slotKey).Val()) + }) +} + +func TestServiceImportingSlot(t *testing.T) { + ctx := context.Background() + + mockID0 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + mockSrv0Host := "127.0.0.1" + mockSrv0Port := 6666 + + srv1 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { srv1.Close() }() + rdb1 := srv1.NewClient() + defer func() { require.NoError(t, rdb1.Close()) }() + id1 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODEID", id1).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-8191\n", mockID0, mockSrv0Host, mockSrv0Port) + clusterNodes += fmt.Sprintf("%s %s %d master - 8192-16383\n", id1, srv1.Host(), srv1.Port()) + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + slotNum := 1 + require.Equal(t, "OK", rdb1.Do(ctx, "cluster", "import", slotNum, 0).Val()) + + // create a new client that is not importing + cli := srv1.NewClient() + slotKey := util.SlotTable[slotNum] + + t.Run("IMPORT - query a key in importing slot without asking", func(t *testing.T) { + util.ErrorRegexp(t, cli.Type(ctx, slotKey).Err(), fmt.Sprintf("MOVED %d.*%d.*", slotNum, mockSrv0Port)) + }) + + t.Run("IMPORT - query a key in importing slot after asking", func(t *testing.T) { + require.Equal(t, "OK", cli.Do(ctx, "asking").Val()) + require.NoError(t, cli.Type(ctx, slotKey).Err()) + }) + + t.Run("IMPORT - asking flag will be reset after executing", func(t *testing.T) { + require.Equal(t, "OK", cli.Do(ctx, "asking").Val()) + require.NoError(t, cli.Type(ctx, slotKey).Err()) + util.ErrorRegexp(t, cli.Type(ctx, slotKey).Err(), fmt.Sprintf("MOVED %d.*%d.*", slotNum, mockSrv0Port)) }) } diff --git a/tests/gocase/integration/slotmigrate/slotmigrate_test.go b/tests/gocase/integration/slotmigrate/slotmigrate_test.go index d782e64d3c5..f54a428f71c 100644 --- a/tests/gocase/integration/slotmigrate/slotmigrate_test.go +++ b/tests/gocase/integration/slotmigrate/slotmigrate_test.go @@ -1023,6 +1023,10 @@ func TestSlotMigrateDataType(t *testing.T) { require.NoError(t, rdb0.LPush(ctx, util.SlotTable[testSlot], i).Err()) } require.Equal(t, "OK", rdb0.Do(ctx, "clusterx", "migrate", testSlot, id1).Val()) + + clusterNodes := rdb0.ClusterNodes(ctx).Val() + require.Contains(t, clusterNodes, fmt.Sprintf("[%d->-%s]", testSlot, id1)) + // should not finish 1.5s time.Sleep(1500 * time.Millisecond) requireMigrateState(t, rdb0, testSlot, SlotMigrationStateStarted) diff --git a/tests/gocase/unit/auth/auth_test.go b/tests/gocase/unit/auth/auth_test.go index a044e84d3a0..562281487e9 100644 --- a/tests/gocase/unit/auth/auth_test.go +++ b/tests/gocase/unit/auth/auth_test.go @@ -53,7 +53,7 @@ func TestAuth(t *testing.T) { t.Run("AUTH fails when a wrong password is given", func(t *testing.T) { r := rdb.Do(ctx, "AUTH", "wrong!") - require.ErrorContains(t, r.Err(), "invalid password") + require.ErrorContains(t, r.Err(), "Invalid password") }) t.Run("Arbitrary command gives an error when AUTH is required", func(t *testing.T) { diff --git a/tests/gocase/unit/copy/copy_test.go 
b/tests/gocase/unit/copy/copy_test.go new file mode 100644 index 00000000000..869c9d93bec --- /dev/null +++ b/tests/gocase/unit/copy/copy_test.go @@ -0,0 +1,1115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package copycmd + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestCopyString(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("Copy string replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.Set(ctx, "a1", "world", 0).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 10*time.Second).Err()) + require.NoError(t, rdb.Set(ctx, "a1", "world", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "world").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + + // copy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) 
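All of the type-specific suites in this new file assert the same COPY contract: COPY returns 1 when the value (and any TTL on the source) is copied, 0 when the destination already exists and REPLACE is not given, and a REPLACE copy overwrites both the destination's value and its old TTL. A condensed sketch of that contract using go-redis's `Copy` helper, where the `0` argument is the destination DB index and the boolean is the REPLACE flag (the function and key names are illustrative, and the imports are the same `context`/`testing`/`time`/`util`/`require` set used by this file):

```go
// Condensed sketch of the COPY contract exercised throughout this file.
func TestCopyContractSketch(t *testing.T) {
	srv := util.StartServer(t, map[string]string{})
	defer srv.Close()

	ctx := context.Background()
	rdb := srv.NewClient()
	defer func() { require.NoError(t, rdb.Close()) }()

	require.NoError(t, rdb.Set(ctx, "src", "v", 10*time.Second).Err())
	require.NoError(t, rdb.Set(ctx, "dst", "old", 0).Err())

	// Without REPLACE an existing destination blocks the copy: COPY returns 0.
	require.EqualValues(t, 0, rdb.Copy(ctx, "src", "dst", 0, false).Val())
	require.Equal(t, "old", rdb.Get(ctx, "dst").Val())

	// With REPLACE the destination is overwritten and the source TTL carries over.
	require.EqualValues(t, 1, rdb.Copy(ctx, "src", "dst", 0, true).Val())
	require.Equal(t, "v", rdb.Get(ctx, "dst").Val())
	require.Greater(t, rdb.TTL(ctx, "dst").Val(), time.Duration(0))
}
```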
+ require.NoError(t, rdb.Set(ctx, "a1", "world1", 0).Err()) + require.NoError(t, rdb.Set(ctx, "a2", "world2", 0).Err()) + require.NoError(t, rdb.Set(ctx, "a3", "world3", 0).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a2").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a3").Val()) + }) + + t.Run("Copy string not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.Set(ctx, "a1", "world", 0).Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "world", rdb.Get(ctx, "a1").Val()) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + }) + +} + +func TestCopyJSON(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + setCmd := "JSON.SET" + getCmd := "JSON.GET" + jsonA := `{"x":1,"y":2}` + jsonB := `{"x":1}` + + t.Run("Copy json replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a1", "$", jsonB).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a1", "$", jsonA).Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, 
rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "world").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + + // copy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a1", "$", jsonB).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a2", "$", jsonB).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a3", "$", jsonB).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a2").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a3").Val()) + }) + + t.Run("Copy json not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a1", "$", jsonB).Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonB, rdb.Do(ctx, getCmd, "a1").Val()) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + }) + +} + +func TestCopyList(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualListValues := func(t *testing.T, key string, value []string) { + require.EqualValues(t, len(value), rdb.LLen(ctx, key).Val()) + for i := 0; i < len(value); i++ { + require.EqualValues(t, value[i], rdb.LIndex(ctx, key, int64(i)).Val()) + } + } + + t.Run("Copy string replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, 3, rdb.LLen(ctx, "a").Val()) + require.EqualValues(t, 3, rdb.LLen(ctx, "a1").Val()) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", 
"3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualListValues(t, "a", []string{"3", "2", "1"}) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualListValues(t, "a", []string{"3", "2", "1"}) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Set(ctx, "a1", "world", 0).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualListValues(t, "a", []string{"3", "2", "1"}) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + + // coopy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "2").Err()) + require.NoError(t, rdb.LPush(ctx, "a2", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a3", "1").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + EqualListValues(t, "a", []string{"3", "2", "1"}) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + EqualListValues(t, "a2", []string{"3", "2", "1"}) + EqualListValues(t, "a3", []string{"3", "2", "1"}) + }) + + t.Run("Copy string not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "3").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualListValues(t, "a", []string{"3", "2", "1"}) + EqualListValues(t, "a1", []string{"3"}) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualListValues(t, "a", []string{"3", "2", "1"}) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + EqualListValues(t, "a", []string{"3", "2", "1"}) + }) + +} + +func TestCopyHash(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualListValues := func(t *testing.T, key string, value map[string]string) { + require.EqualValues(t, len(value), 
rdb.HLen(ctx, key).Val()) + for subKey := range value { + require.EqualValues(t, value[subKey], rdb.HGet(ctx, key, subKey).Val()) + } + } + + t.Run("Copy hash replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.HSet(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.HSet(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + + // copy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.HSet(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.HSet(ctx, "a2", "a", "1").Err()) + require.NoError(t, rdb.HSet(ctx, "a3", "a", "1").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a2", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a3", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + }) + + t.Run("Copy hash not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, 
"a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.HSet(ctx, "a1", "a", "1").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + }) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + }) + +} + +func TestCopySet(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualSetValues := func(t *testing.T, key string, value []string) { + require.EqualValues(t, len(value), rdb.SCard(ctx, key).Val()) + for index := range value { + require.EqualValues(t, true, rdb.SIsMember(ctx, key, value[index]).Val()) + } + } + + t.Run("Copy set replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.SAdd(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.SAdd(ctx, "a1", "1").Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "1").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + EqualSetValues(t, 
"a1", []string{"1", "2", "3"}) + + // copy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.SAdd(ctx, "a1", "1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a2", "a2", "1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a3", "a3", "1").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + EqualSetValues(t, "a2", []string{"1", "2", "3"}) + EqualSetValues(t, "a3", []string{"1", "2", "3"}) + }) + + t.Run("Copy set not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.SAdd(ctx, "a1", "1").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + EqualSetValues(t, "a1", []string{"1"}) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + + }) + +} + +func TestCopyZset(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualZSetValues := func(t *testing.T, key string, value map[string]int) { + require.EqualValues(t, len(value), rdb.ZCard(ctx, key).Val()) + for subKey := range value { + score := value[subKey] + require.EqualValues(t, []string{subKey}, rdb.ZRangeByScore(ctx, key, + &redis.ZRangeBy{Max: strconv.Itoa(score), Min: strconv.Itoa(score)}).Val()) + require.EqualValues(t, float64(score), rdb.ZScore(ctx, key, subKey).Val()) + } + } + + zMember := []redis.Z{{Member: "a", Score: 1}, {Member: "b", Score: 2}, {Member: "c", Score: 3}} + zMember2 := []redis.Z{{Member: "a", Score: 2}} + + t.Run("Copy zset", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a1", zMember2...).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, 
rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a1", zMember2...).Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", 1, 2, 3).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + // copy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a1", zMember2...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a2", zMember2...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a3", zMember2...).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a2", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a3", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + }) + + t.Run("Copy zset not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a1", zMember2...).Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a1", map[string]int{ + "a": 2, + }) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + }) + +} + +func TestCopyBitmap(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualBitSetValues := func(t *testing.T, key string, value []int64) { + for i := 0; i < len(value); i++ { + require.EqualValues(t, int64(value[i]), rdb.Do(ctx, "BITPOS", key, 1, value[i]/8).Val()) + } + } + + SetBits 
:= func(t *testing.T, key string, value []int64) { + for i := 0; i < len(value); i++ { + require.NoError(t, rdb.Do(ctx, "SETBIT", key, value[i], 1).Err()) + } + } + bitSetA := []int64{16, 1024 * 8 * 2, 1024 * 8 * 12} + bitSetB := []int64{1} + + t.Run("Copy bitmap replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualBitSetValues(t, "a", bitSetA) + EqualBitSetValues(t, "a1", bitSetA) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // newkey has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + SetBits(t, "a1", bitSetB) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualBitSetValues(t, "a", bitSetA) + EqualBitSetValues(t, "a1", bitSetA) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // newkey has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualBitSetValues(t, "a", bitSetA) + EqualBitSetValues(t, "a1", bitSetA) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + SetBits(t, "a", bitSetA) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + EqualBitSetValues(t, "a", bitSetA) + + // copy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + SetBits(t, "a", bitSetA) + SetBits(t, "a1", bitSetB) + SetBits(t, "a2", bitSetB) + SetBits(t, "a3", bitSetB) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + EqualBitSetValues(t, "a", bitSetA) + EqualBitSetValues(t, "a1", bitSetA) + EqualBitSetValues(t, "a2", bitSetA) + EqualBitSetValues(t, "a3", bitSetA) + }) + + t.Run("Copy bitmap not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + SetBits(t, "a1", bitSetB) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualBitSetValues(t, "a", bitSetA) + EqualBitSetValues(t, "a1", bitSetB) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualBitSetValues(t, "a", bitSetA) + EqualBitSetValues(t, "a1", bitSetA) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + EqualBitSetValues(t, "a", bitSetA) + }) + +} + +func TestCopySint(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualSIntValues := func(t *testing.T, key string, value []int) { + require.EqualValues(t, len(value), rdb.Do(ctx, "SICARD", key).Val()) + for i := 0; i < len(value); i++ { + require.EqualValues(t, []interface{}{int64(1)}, rdb.Do(ctx, "SIEXISTS", key, value[i]).Val()) + } + } + + t.Run("Copy sint replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, 
rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a1", 99).Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + + // copy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a1", 85).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a2", 77, 0).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a3", 111, 222, 333).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + EqualSIntValues(t, "a2", []int{3, 4, 5, 123, 245}) + EqualSIntValues(t, "a3", []int{3, 4, 5, 123, 245}) + }) + + t.Run("Copy sint not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a1", 99).Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + EqualSIntValues(t, "a1", []int{99}) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + + }) + +} + +func TestCopyBloom(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + bfAdd := 
"BF.ADD" + bfExists := "BF.EXISTS" + + t.Run("Copy bloom replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "hello").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a1", "world").Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "hello").Val()) + require.EqualValues(t, 0, rdb.Do(ctx, bfExists, "a1", "world").Val()) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "hello").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + + // copy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a1", "world1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a2", "world2").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a3", "world3").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "hello").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a2", "hello").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a3", "hello").Val()) + }) + + t.Run("Copy bloom not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a1", "world").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "world").Val()) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "hello").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", 
"a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + }) + +} + +func TestCopyStream(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + XADD := "XADD" + XREAD := "XREAD" + t.Run("Copy stream replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "hello") + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a1", "*", "a", "world").Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "hello") + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "hello") + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a", 0, true).Err()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + + // copy * 3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a1", "*", "a", "world1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a2", "*", "a", "world2").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a3", "*", "a", "world3").Err()) + require.NoError(t, rdb.Copy(ctx, "a", "a1", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a1", "a2", 0, true).Err()) + require.NoError(t, rdb.Copy(ctx, "a2", "a3", 0, true).Err()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "hello") + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a2", "0").String(), "hello") + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a3", "0").String(), "hello") + }) + + t.Run("Copy stream not replace", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a1", "*", "a", "world").Err()) + require.EqualValues(t, int64(0), 
rdb.Copy(ctx, "a", "a1", 0, false).Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "world") + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.EqualValues(t, int64(1), rdb.Copy(ctx, "a", "a1", 0, false).Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "hello") + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, "a", "a", 0, false).Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + }) + +} + +func TestCopyError(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("Copy to not db 0", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.Error(t, rdb.Copy(ctx, "", "a", 1, true).Err()) + require.Error(t, rdb.Copy(ctx, "", "a", 3, false).Err()) + }) + + t.Run("Copy from empty key", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, ".empty", "a").Err()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, ".empty", "a", 0, false).Val()) + require.EqualValues(t, int64(0), rdb.Copy(ctx, ".empty", "a", 0, true).Val()) + }) + +} diff --git a/tests/gocase/unit/dump/dump_test.go b/tests/gocase/unit/dump/dump_test.go new file mode 100644 index 00000000000..bca9300d1ea --- /dev/null +++ b/tests/gocase/unit/dump/dump_test.go @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package dump + +import ( + "context" + "fmt" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestDump_String(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + keyValues := map[string]string{ + "test_string_key0": "hello,world!", + "test_string_key1": "654321", + } + for key, value := range keyValues { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Set(ctx, key, value, 0).Err()) + serialized, err := rdb.Dump(ctx, key).Result() + require.NoError(t, err) + + restoredKey := fmt.Sprintf("restore_%s", key) + require.NoError(t, rdb.RestoreReplace(ctx, restoredKey, 0, serialized).Err()) + require.Equal(t, value, rdb.Get(ctx, restoredKey).Val()) + } +} + +func TestDump_Hash(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + key := "test_hash_key" + fields := map[string]string{ + "name": "redis tutorial", + "description": "redis basic commands for caching", + "likes": "20", + "visitors": "23000", + } + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.HMSet(ctx, key, fields).Err()) + serialized, err := rdb.Dump(ctx, key).Result() + require.NoError(t, err) + + restoredKey := fmt.Sprintf("restore_%s", key) + require.NoError(t, rdb.RestoreReplace(ctx, restoredKey, 0, serialized).Err()) + require.EqualValues(t, fields, rdb.HGetAll(ctx, restoredKey).Val()) +} + +func TestDump_ZSet(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + memberScores := []redis.Z{{Member: "kvrocks1", Score: 1}, {Member: "kvrocks2", Score: 2}, {Member: "kvrocks3", Score: 3}} + key := "test_zset_key" + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.ZAdd(ctx, key, memberScores...).Err()) + serialized, err := rdb.Dump(ctx, key).Result() + require.NoError(t, err) + + restoredKey := fmt.Sprintf("restore_%s", key) + require.NoError(t, rdb.RestoreReplace(ctx, restoredKey, 0, serialized).Err()) + + require.EqualValues(t, memberScores, rdb.ZRangeWithScores(ctx, restoredKey, 0, -1).Val()) +} + +func TestDump_List(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + elements := []string{"kvrocks1", "kvrocks2", "kvrocks3"} + key := "test_list_key" + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.RPush(ctx, key, elements).Err()) + serialized, err := rdb.Dump(ctx, key).Result() + require.NoError(t, err) + require.Equal(t, "\x0e\x03\x15\x15\x00\x00\x00\n\x00\x00\x00\x01\x00\x00\bkvrocks1\xff\x15\x15\x00\x00\x00\n\x00\x00\x00\x01\x00\x00\bkvrocks2\xff\x15\x15\x00\x00\x00\n\x00\x00\x00\x01\x00\x00\bkvrocks3\xff\x06\x00u\xc7\x19h\x1da\xd0\xd8", serialized) + + restoredKey := fmt.Sprintf("restore_%s", key) + require.NoError(t, rdb.RestoreReplace(ctx, restoredKey, 0, serialized).Err()) + require.EqualValues(t, elements, rdb.LRange(ctx, restoredKey, 0, -1).Val()) + + //test special case + elements = []string{"A", " ", "", util.RandString(0, 4000, util.Alpha)} + 
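// cover single-character, whitespace-only, empty, and long random elements + 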
require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.RPush(ctx, key, elements).Err()) + serialized, err = rdb.Dump(ctx, key).Result() + require.NoError(t, err) + + require.NoError(t, rdb.RestoreReplace(ctx, restoredKey, 0, serialized).Err()) + require.EqualValues(t, elements, rdb.LRange(ctx, restoredKey, 0, -1).Val()) +} + +func TestDump_Set(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + members := []string{"kvrocks1", "kvrocks2", "kvrocks3"} + key := "test_set_key" + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.SAdd(ctx, key, members).Err()) + serialized, err := rdb.Dump(ctx, key).Result() + require.NoError(t, err) + + restoredKey := fmt.Sprintf("restore_%s", key) + require.NoError(t, rdb.RestoreReplace(ctx, restoredKey, 0, serialized).Err()) + require.ElementsMatch(t, members, rdb.SMembers(ctx, restoredKey).Val()) +} diff --git a/tests/gocase/unit/hello/hello_test.go b/tests/gocase/unit/hello/hello_test.go index ab3d417b2fd..5b6eea35d01 100644 --- a/tests/gocase/unit/hello/hello_test.go +++ b/tests/gocase/unit/hello/hello_test.go @@ -38,7 +38,7 @@ func TestHello(t *testing.T) { t.Run("hello with wrong protocol", func(t *testing.T) { r := rdb.Do(ctx, "HELLO", "1") - require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version") + require.ErrorContains(t, r.Err(), "NOPROTO unsupported protocol version") }) t.Run("hello with protocol 2", func(t *testing.T) { @@ -61,7 +61,7 @@ func TestHello(t *testing.T) { t.Run("hello with wrong protocol", func(t *testing.T) { r := rdb.Do(ctx, "HELLO", "5") - require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version") + require.ErrorContains(t, r.Err(), "NOPROTO unsupported protocol version") }) t.Run("hello with non protocol", func(t *testing.T) { @@ -114,12 +114,12 @@ func TestHelloWithAuth(t *testing.T) { t.Run("AUTH fails when a wrong password is given", func(t *testing.T) { r := rdb.Do(ctx, "HELLO", "3", "AUTH", "wrong!") - require.ErrorContains(t, r.Err(), "invalid password") + require.ErrorContains(t, r.Err(), "Invalid password") }) t.Run("AUTH fails when a wrong username is given", func(t *testing.T) { r := rdb.Do(ctx, "HELLO", "3", "AUTH", "wrong!", "foobar") - require.ErrorContains(t, r.Err(), "invalid password") + require.ErrorContains(t, r.Err(), "Invalid password") }) t.Run("Arbitrary command gives an error when AUTH is required", func(t *testing.T) { diff --git a/tests/gocase/unit/movex/movex_test.go b/tests/gocase/unit/movex/movex_test.go new file mode 100644 index 00000000000..cb777396aa7 --- /dev/null +++ b/tests/gocase/unit/movex/movex_test.go @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package movex + +import ( + "context" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestMoveDummy(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("Dummy move", func(t *testing.T) { + r := rdb.Move(ctx, "key1", 1) + require.NoError(t, r.Err()) + require.Equal(t, false, r.Val()) + + require.NoError(t, rdb.Del(ctx, "key1").Err()) + require.NoError(t, rdb.Set(ctx, "key1", "value1", 0).Err()) + r = rdb.Move(ctx, "key1", 1) + require.NoError(t, r.Err()) + require.Equal(t, true, r.Val()) + }) +} + +func TestMoveXWithoutPwd(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("MoveX without password", func(t *testing.T) { + r := rdb.Do(ctx, "MOVEX", "key1", "token1") + require.EqualError(t, r.Err(), "ERR Forbidden to move key when requirepass is empty") + }) +} + +func TestMoveX(t *testing.T) { + token := "pwd" + srv := util.StartServer(t, map[string]string{ + "requirepass": token, + }) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClientWithOption(&redis.Options{ + Password: token, + }) + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("MoveX test", func(t *testing.T) { + nsTokens := map[string]string{ + "ns1": "token1", + "ns2": "token2", + "ns3": "token3", + } + for ns, token := range nsTokens { + r := rdb.Do(ctx, "NAMESPACE", "ADD", ns, token) + require.NoError(t, r.Err()) + require.Equal(t, "OK", r.Val()) + } + for ns, token := range nsTokens { + r := rdb.Do(ctx, "NAMESPACE", "GET", ns) + require.NoError(t, r.Err()) + require.Equal(t, token, r.Val()) + } + + // add 3 kvs to default namespace + require.NoError(t, rdb.Del(ctx, "key1", "key2", "key3").Err()) + require.NoError(t, rdb.Set(ctx, "key1", "value1", 0).Err()) + require.NoError(t, rdb.Set(ctx, "key2", "value2", 0).Err()) + require.NoError(t, rdb.Set(ctx, "key3", "value3", 0).Err()) + require.EqualValues(t, "value1", rdb.Get(ctx, "key1").Val()) + require.EqualValues(t, "value2", rdb.Get(ctx, "key2").Val()) + require.EqualValues(t, "value3", rdb.Get(ctx, "key3").Val()) + + // move key1 to ns1 + r := rdb.Do(ctx, "MOVEX", "key1", "token1") + require.NoError(t, r.Err()) + require.EqualValues(t, int64(1), r.Val()) + require.EqualValues(t, "", rdb.Get(ctx, "key1").Val()) + require.NoError(t, rdb.Do(ctx, "AUTH", "token1").Err()) + require.EqualValues(t, "value1", rdb.Get(ctx, "key1").Val()) + require.NoError(t, rdb.Do(ctx, "AUTH", token).Err()) + + // move key2 to ns2, with wrong token first + r = rdb.Do(ctx, "MOVEX", "key2", "token4") + require.EqualError(t, r.Err(), "ERR Invalid password") + r = rdb.Do(ctx, "MOVEX", "key2", "token2") + require.NoError(t, r.Err()) + require.EqualValues(t, int64(1), r.Val()) + require.EqualValues(t, "", rdb.Get(ctx, "key2").Val()) + require.NoError(t, rdb.Do(ctx, "AUTH", "token2").Err()) + require.EqualValues(t, "value2", rdb.Get(ctx, "key2").Val()) + require.NoError(t, rdb.Do(ctx, "AUTH", token).Err()) + + // move non-existent keys + r = rdb.Do(ctx, "MOVEX", "key2", "token2") + require.NoError(t, r.Err()) + require.EqualValues(t, int64(0), r.Val()) + + // move 
key that exists in the target namespace + require.NoError(t, rdb.Set(ctx, "key1", "value4", 0).Err()) + r = rdb.Do(ctx, "MOVEX", "key1", "token1") + require.NoError(t, r.Err()) + require.EqualValues(t, int64(0), r.Val()) + + // move key3 to ns3, and move back + r = rdb.Do(ctx, "MOVEX", "key3", "token3") + require.NoError(t, r.Err()) + require.EqualValues(t, int64(1), r.Val()) + require.EqualValues(t, "", rdb.Get(ctx, "key3").Val()) + require.NoError(t, rdb.Do(ctx, "AUTH", "token3").Err()) + require.EqualValues(t, "value3", rdb.Get(ctx, "key3").Val()) + r = rdb.Do(ctx, "MOVEX", "key3", token) + require.NoError(t, r.Err()) + require.EqualValues(t, int64(1), r.Val()) + require.EqualValues(t, "", rdb.Get(ctx, "key3").Val()) + require.NoError(t, rdb.Do(ctx, "AUTH", token).Err()) + require.EqualValues(t, "value3", rdb.Get(ctx, "key3").Val()) + + // move in place + require.NoError(t, rdb.Do(ctx, "AUTH", "token1").Err()) + require.EqualValues(t, "value1", rdb.Get(ctx, "key1").Val()) + r = rdb.Do(ctx, "MOVEX", "key1", "token1") + require.NoError(t, r.Err()) + require.EqualValues(t, int64(0), r.Val()) + }) +} diff --git a/tests/gocase/unit/namespace/namespace_test.go b/tests/gocase/unit/namespace/namespace_test.go index 1b80e5b94b0..755766bbcf2 100644 --- a/tests/gocase/unit/namespace/namespace_test.go +++ b/tests/gocase/unit/namespace/namespace_test.go @@ -238,7 +238,16 @@ func TestNamespaceReplicate(t *testing.T) { }) t.Run("Turn off namespace replication is not allowed", func(t *testing.T) { + r := masterRdb.Do(ctx, "NAMESPACE", "ADD", "test-ns", "ns-token") + require.NoError(t, r.Err()) + require.Equal(t, "OK", r.Val()) util.ErrorRegexp(t, masterRdb.ConfigSet(ctx, "repl-namespace-enabled", "no").Err(), ".*cannot switch off repl_namespace_enabled when namespaces exist in db.*") + + // it should be allowed after deleting all namespaces + r = masterRdb.Do(ctx, "NAMESPACE", "DEL", "test-ns") + require.NoError(t, r.Err()) + require.Equal(t, "OK", r.Val()) + require.NoError(t, masterRdb.ConfigSet(ctx, "repl-namespace-enabled", "no").Err()) }) } diff --git a/tests/gocase/unit/protocol/protocol_test.go b/tests/gocase/unit/protocol/protocol_test.go index 9becba5e46c..6be669bb872 100644 --- a/tests/gocase/unit/protocol/protocol_test.go +++ b/tests/gocase/unit/protocol/protocol_test.go @@ -224,6 +224,21 @@ func TestProtocolRESP2(t *testing.T) { }) } +func handshakeWithRESP3(t *testing.T, c *util.TCPClient) { + require.NoError(t, c.WriteArgs("HELLO", "3")) + values := []string{"%6", + "$6", "server", "$5", "redis", + "$7", "version", "$5", "4.0.0", + "$5", "proto", ":3", + "$4", "mode", "$10", "standalone", + "$4", "role", "$6", "master", + "$7", "modules", "_", + } + for _, line := range values { + c.MustRead(t, line) + } +} + func TestProtocolRESP3(t *testing.T) { srv := util.StartServer(t, map[string]string{ "resp3-enabled": "yes", @@ -236,20 +251,9 @@ func TestProtocolRESP3(t *testing.T) { require.NoError(t, c.Close()) require.NoError(t, rdb.Close()) }() + handshakeWithRESP3(t, c) t.Run("debug protocol string", func(t *testing.T) { - require.NoError(t, c.WriteArgs("HELLO", "3")) - values := []string{"%6", - "$6", "server", "$5", "redis", - "$7", "version", "$5", "4.0.0", - "$5", "proto", ":3", - "$4", "mode", "$10", "standalone", - "$4", "role", "$6", "master", - "$7", "modules", "_", - } - for _, line := range values { - c.MustRead(t, line) - } types := map[string][]string{ "string": {"$11", "Hello World"}, @@ -285,6 +289,17 @@ func TestProtocolRESP3(t *testing.T) { c.MustRead(t, "_") }) + 
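// RESP3 delivers subscribe confirmations as push frames (">" type marker) + 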
t.Run("should return PUSH type", func(t *testing.T) { + // use a new client to avoid affecting other tests + require.NoError(t, c.WriteArgs("SUBSCRIBE", "test-channel")) + c.MustRead(t, ">3") + c.MustRead(t, "$9") + c.MustRead(t, "subscribe") + c.MustRead(t, "$12") + c.MustRead(t, "test-channel") + c.MustRead(t, ":1") + }) + t.Run("null array", func(t *testing.T) { require.NoError(t, c.WriteArgs("ZRANK", "no-exists-zset", "m0", "WITHSCORE")) c.MustRead(t, "_") @@ -295,19 +310,6 @@ func TestProtocolRESP3(t *testing.T) { Members: []redis.Z{{1, "one"}, {2, "two"}, {3, "three"}}, }) - require.NoError(t, c.WriteArgs("HELLO", "3")) - values := []string{"%6", - "$6", "server", "$5", "redis", - "$7", "version", "$5", "4.0.0", - "$5", "proto", ":3", - "$4", "mode", "$10", "standalone", - "$4", "role", "$6", "master", - "$7", "modules", "_", - } - for _, line := range values { - c.MustRead(t, line) - } - // should return an array of strings if without score require.NoError(t, c.WriteArgs("ZRANGE", "zset", "0", "-1")) c.MustRead(t, "*3") diff --git a/tests/gocase/unit/pubsub/pubsub_test.go b/tests/gocase/unit/pubsub/pubsub_test.go index d3f8c4a1c11..bf4d655a726 100644 --- a/tests/gocase/unit/pubsub/pubsub_test.go +++ b/tests/gocase/unit/pubsub/pubsub_test.go @@ -36,8 +36,18 @@ func receiveType[T any](t *testing.T, pubsub *redis.PubSub, typ T) T { return msg.(T) } -func TestPubSub(t *testing.T) { - srv := util.StartServer(t, map[string]string{}) +func TestPubSubWithRESP2(t *testing.T) { + testPubSub(t, "no") +} + +func TestPubSubWithRESP3(t *testing.T) { + testPubSub(t, "yes") +} + +func testPubSub(t *testing.T, enabledRESP3 string) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": enabledRESP3, + }) defer srv.Close() ctx := context.Background() diff --git a/tests/gocase/unit/rename/rename_test.go b/tests/gocase/unit/rename/rename_test.go index 7bbd4a5284d..7cc0248c288 100644 --- a/tests/gocase/unit/rename/rename_test.go +++ b/tests/gocase/unit/rename/rename_test.go @@ -30,7 +30,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestRename_String(t *testing.T) { +func TestRenameString(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -117,7 +117,7 @@ func TestRename_String(t *testing.T) { } -func TestRename_JSON(t *testing.T) { +func TestRenameJSON(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -212,7 +212,7 @@ func TestRename_JSON(t *testing.T) { } -func TestRename_List(t *testing.T) { +func TestRenameList(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -309,7 +309,7 @@ func TestRename_List(t *testing.T) { } -func TestRename_hash(t *testing.T) { +func TestRenameHash(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -443,7 +443,7 @@ func TestRename_hash(t *testing.T) { } -func TestRename_set(t *testing.T) { +func TestRenameSet(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -540,7 +540,7 @@ func TestRename_set(t *testing.T) { } -func TestRename_zset(t *testing.T) { +func TestRenameZset(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -681,7 +681,7 @@ func TestRename_zset(t *testing.T) { } -func TestRename_Bitmap(t *testing.T) { +func TestRenameBitmap(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -776,7 +776,7 @@ func TestRename_Bitmap(t *testing.T) { } -func TestRename_SInt(t *testing.T) { +func TestRenameSint(t 
*testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -864,7 +864,7 @@ func TestRename_SInt(t *testing.T) { } -func TestRename_Bloom(t *testing.T) { +func TestRenameBloom(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -947,7 +947,7 @@ func TestRename_Bloom(t *testing.T) { }) } -func TestRename_Stream(t *testing.T) { +func TestRenameStream(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() @@ -1028,7 +1028,7 @@ func TestRename_Stream(t *testing.T) { }) } -func TestRename_Error(t *testing.T) { +func TestRenameError(t *testing.T) { srv := util.StartServer(t, map[string]string{}) defer srv.Close() diff --git a/tests/gocase/unit/scan/scan_test.go b/tests/gocase/unit/scan/scan_test.go index 1e6488d78d1..62b91aff9fd 100644 --- a/tests/gocase/unit/scan/scan_test.go +++ b/tests/gocase/unit/scan/scan_test.go @@ -290,6 +290,84 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx context.Context) { require.Len(t, zsetKeys, test.count) }) } + + t.Run("SCAN reject invalid input", func(t *testing.T) { + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "hello").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "hello", "hi").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1", "hello", "hi").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "hello", "hi", "count", "1").Err(), ".*syntax error.*") + require.NoError(t, rdb.Do(ctx, "SCAN", "0", "count", "1", "match", "a*").Err()) + require.NoError(t, rdb.Do(ctx, "SCAN", "0", "match", "a*", "count", "1").Err()) + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1", "match", "a*", "hello").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1", "match", "a*", "hello", "hi").Err(), ".*syntax error.*") + }) + + t.Run("SCAN with type args", func(t *testing.T) { + // string type + require.NoError(t, rdb.Set(ctx, "stringtype1", "fee1", 0).Err()) + require.NoError(t, rdb.Set(ctx, "stringtype2", "fee1", 0).Err()) + require.NoError(t, rdb.Set(ctx, "stringtype3", "fee1", 0).Err()) + require.Equal(t, []string{"stringtype1", "stringtype2", "stringtype3"}, scanAll(t, rdb, "match", "stringtype*", "type", "string")) + require.Equal(t, []string{"stringtype1", "stringtype2", "stringtype3"}, scanAll(t, rdb, "match", "stringtype*", "count", "3", "type", "string")) + // hash type + require.NoError(t, rdb.HSet(ctx, "hashtype1", "key1", "val1", "key2", "val2").Err()) + require.NoError(t, rdb.HSet(ctx, "hashtype2", "key1", "val1", "key2", "val2").Err()) + require.NoError(t, rdb.HSet(ctx, "hashtype3", "key1", "val1", "key2", "val2").Err()) + require.Equal(t, []string{"hashtype1", "hashtype2", "hashtype3"}, scanAll(t, rdb, "match", "hashtype*", "type", "hash")) + require.Equal(t, []string{"hashtype1", "hashtype2", "hashtype3"}, scanAll(t, rdb, "match", "hashtype*", "count", "3", "type", "hash")) + // list type + require.NoError(t, rdb.RPush(ctx, "listtype1", "1").Err()) + require.NoError(t, rdb.RPush(ctx, "listtype2", "2").Err()) + require.NoError(t, rdb.RPush(ctx, "listtype3", "3").Err()) + require.Equal(t, []string{"listtype1", "listtype2", "listtype3"}, scanAll(t, rdb, "match", "listtype*", "type", "list")) + require.Equal(t, []string{"listtype1", "listtype2", "listtype3"}, scanAll(t, rdb, "match", "listtype*", "count", "3", "type", "list")) + 
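// each remaining type repeats the pattern: seed three keys, then expect SCAN TYPE to match exactly those + 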
require.NoError(t, rdb.SAdd(ctx, "settype3", "1").Err()) + require.Equal(t, []string{"settype1", "settype2", "settype3"}, scanAll(t, rdb, "match", "settype*", "type", "set")) + require.Equal(t, []string{"settype1", "settype2", "settype3"}, scanAll(t, rdb, "match", "settype*", "count", "3", "type", "set")) + //zet type + members := []redis.Z{ + {Score: 1, Member: "1"}, + {Score: 2, Member: "2"}, + {Score: 3, Member: "3"}, + {Score: 10, Member: "4"}, + } + require.NoError(t, rdb.ZAdd(ctx, "zsettype1", members...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "zsettype2", members...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "zsettype3", members...).Err()) + require.Equal(t, []string{"zsettype1", "zsettype2", "zsettype3"}, scanAll(t, rdb, "match", "zsettype*", "type", "zset")) + require.Equal(t, []string{"zsettype1", "zsettype2", "zsettype3"}, scanAll(t, rdb, "match", "zsettype*", "count", "3", "type", "zset")) + //bitmap type + require.NoError(t, rdb.SetBit(ctx, "bitmaptype1", 0, 0).Err()) + require.NoError(t, rdb.SetBit(ctx, "bitmaptype2", 0, 0).Err()) + require.NoError(t, rdb.SetBit(ctx, "bitmaptype3", 0, 0).Err()) + require.Equal(t, []string{"bitmaptype1", "bitmaptype2", "bitmaptype3"}, scanAll(t, rdb, "match", "bitmaptype*", "type", "bitmap")) + require.Equal(t, []string{"bitmaptype1", "bitmaptype2", "bitmaptype3"}, scanAll(t, rdb, "match", "bitmaptype*", "count", "3", "type", "bitmap")) + //stream type + require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{Stream: "streamtype1", Values: []string{"item", "1", "value", "a"}}).Err()) + require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{Stream: "streamtype2", Values: []string{"item", "1", "value", "a"}}).Err()) + require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{Stream: "streamtype3", Values: []string{"item", "1", "value", "a"}}).Err()) + require.Equal(t, []string{"streamtype1", "streamtype2", "streamtype3"}, scanAll(t, rdb, "match", "streamtype*", "type", "stream")) + require.Equal(t, []string{"streamtype1", "streamtype2", "streamtype3"}, scanAll(t, rdb, "match", "streamtype*", "count", "3", "type", "stream")) + //MBbloom type + require.NoError(t, rdb.Do(ctx, "bf.reserve", "MBbloomtype1", "0.02", "1000").Err()) + require.NoError(t, rdb.Do(ctx, "bf.reserve", "MBbloomtype2", "0.02", "1000").Err()) + require.NoError(t, rdb.Do(ctx, "bf.reserve", "MBbloomtype3", "0.02", "1000").Err()) + require.Equal(t, []string{"MBbloomtype1", "MBbloomtype2", "MBbloomtype3"}, scanAll(t, rdb, "match", "MBbloomtype*", "type", "MBbloom--")) + require.Equal(t, []string{"MBbloomtype1", "MBbloomtype2", "MBbloomtype3"}, scanAll(t, rdb, "match", "MBbloomtype*", "count", "3", "type", "MBbloom--")) + //ReJSON-RL type + require.NoError(t, rdb.Do(ctx, "JSON.SET", "ReJSONtype1", "$", ` {"x":1, "y":2} `).Err()) + require.NoError(t, rdb.Do(ctx, "JSON.SET", "ReJSONtype2", "$", ` {"x":1, "y":2} `).Err()) + require.NoError(t, rdb.Do(ctx, "JSON.SET", "ReJSONtype3", "$", ` {"x":1, "y":2} `).Err()) + require.Equal(t, []string{"ReJSONtype1", "ReJSONtype2", "ReJSONtype3"}, scanAll(t, rdb, "match", "ReJSONtype*", "type", "ReJSON-RL")) + require.Equal(t, []string{"ReJSONtype1", "ReJSONtype2", "ReJSONtype3"}, scanAll(t, rdb, "match", "ReJSONtype*", "count", "3", "type", "ReJSON-RL")) + //invalid type + util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1", "match", "a*", "type", "hi").Err(), "Invalid type") + + }) + } // SCAN of Kvrocks returns _cursor instead of cursor. 
Thus, redis.Client Scan can fail with diff --git a/tests/gocase/unit/search/search_test.go b/tests/gocase/unit/search/search_test.go new file mode 100644 index 00000000000..651323c1710 --- /dev/null +++ b/tests/gocase/unit/search/search_test.go @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package search + +import ( + "context" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestSearch(t *testing.T) { + t.Skip("search commands are disabled") + + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("FT.CREATE", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "FT.CREATE", "testidx1", "ON", "JSON", "PREFIX", "1", "test1:", "SCHEMA", "a", "TAG", "b", "NUMERIC").Err()) + + verify := func(t *testing.T) { + require.Equal(t, []interface{}{"testidx1"}, rdb.Do(ctx, "FT._LIST").Val()) + infoRes := rdb.Do(ctx, "FT.INFO", "testidx1") + require.NoError(t, infoRes.Err()) + idxInfo := infoRes.Val().([]interface{}) + require.Equal(t, "index_name", idxInfo[0]) + require.Equal(t, "testidx1", idxInfo[1]) + require.Equal(t, "on_data_type", idxInfo[2]) + require.Equal(t, "ReJSON-RL", idxInfo[3]) + require.Equal(t, "prefixes", idxInfo[4]) + require.Equal(t, []interface{}{"test1:"}, idxInfo[5]) + require.Equal(t, "fields", idxInfo[6]) + require.Equal(t, []interface{}{"a", "tag"}, idxInfo[7].([]interface{})[0]) + require.Equal(t, []interface{}{"b", "numeric"}, idxInfo[7].([]interface{})[1]) + } + verify(t) + + srv.Restart() + verify(t) + }) + + t.Run("FT.SEARCH", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k1", "$", `{"a": "x,y", "b": 11}`).Err()) + require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k2", "$", `{"a": "x,z", "b": 22}`).Err()) + require.NoError(t, rdb.Do(ctx, "JSON.SET", "test1:k3", "$", `{"a": "y,z", "b": 33}`).Err()) + require.NoError(t, rdb.Do(ctx, "JSON.SET", "test2:k4", "$", `{"a": "x,y,z", "b": 44}`).Err()) + + verify := func(t *testing.T, res *redis.Cmd) { + require.NoError(t, res.Err()) + require.Equal(t, 7, len(res.Val().([]interface{}))) + require.Equal(t, int64(3), res.Val().([]interface{})[0]) + require.Equal(t, "test1:k1", res.Val().([]interface{})[1]) + require.Equal(t, "test1:k2", res.Val().([]interface{})[3]) + require.Equal(t, "test1:k3", res.Val().([]interface{})[5]) + } + + res := rdb.Do(ctx, "FT.SEARCHSQL", "select * from testidx1") + verify(t, res) + res = rdb.Do(ctx, "FT.SEARCH", "testidx1", "*") + verify(t, res) + + verify = func(t *testing.T, res *redis.Cmd) { + require.NoError(t, res.Err()) + require.Equal(t, 3, 
len(res.Val().([]interface{}))) + require.Equal(t, int64(1), res.Val().([]interface{})[0]) + require.Equal(t, "test1:k2", res.Val().([]interface{})[1]) + fields := res.Val().([]interface{})[2].([]interface{}) + if fields[0] == "a" { + require.Equal(t, "x,z", fields[1]) + require.Equal(t, "b", fields[2]) + require.Equal(t, "22", fields[3]) + } else if fields[0] == "b" { + require.Equal(t, "22", fields[1]) + require.Equal(t, "a", fields[2]) + require.Equal(t, "x,z", fields[3]) + } else { + require.Fail(t, "fields did not start with a or b") + } + } + + res = rdb.Do(ctx, "FT.SEARCHSQL", `select * from testidx1 where a hastag "z" and b < 30`) + verify(t, res) + res = rdb.Do(ctx, "FT.SEARCH", "testidx1", `@a:{z} @b:[-inf (30]`) + verify(t, res) + }) + + t.Run("FT.DROPINDEX", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "FT.DROPINDEX", "testidx1").Err()) + + verify := func(t *testing.T) { + require.Equal(t, []interface{}{}, rdb.Do(ctx, "FT._LIST").Val()) + infoRes := rdb.Do(ctx, "FT.INFO", "testidx1") + require.Equal(t, "ERR index not found", infoRes.Err().Error()) + } + verify(t) + + srv.Restart() + verify(t) + }) +} diff --git a/tests/gocase/unit/sort/sort_test.go b/tests/gocase/unit/sort/sort_test.go new file mode 100644 index 00000000000..6715ed783a0 --- /dev/null +++ b/tests/gocase/unit/sort/sort_test.go @@ -0,0 +1,881 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package sort + +import ( + "context" + "fmt" + "testing" + + "github.com/redis/go-redis/v9" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/stretchr/testify/require" +) + +func TestSortParser(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SORT Parser", func(t *testing.T) { + rdb.RPush(ctx, "bad-case-key", 5, 4, 3, 2, 1) + + _, err := rdb.Do(ctx, "Sort").Result() + require.EqualError(t, err, "ERR wrong number of arguments") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "BadArg").Result() + require.EqualError(t, err, "ERR syntax error") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "LIMIT").Result() + require.EqualError(t, err, "ERR no more item to parse") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "LIMIT", 1).Result() + require.EqualError(t, err, "ERR no more item to parse") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "LIMIT", 1, "not-number").Result() + require.EqualError(t, err, "ERR not started as an integer") + + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "STORE").Result() + require.EqualError(t, err, "ERR no more item to parse") + + rdb.MSet(ctx, "rank_1", 1, "rank_2", 2, "rank_3", 3, "rank_4", 4, "rank_5", 5) + _, err = rdb.Do(ctx, "Sort", "bad-case-key", "BY", "dontsort", "BY", "rank_*").Result() + require.EqualError(t, err, "ERR don't use multiple BY parameters") + + _, err = rdb.Do(ctx, "Sort_RO", "bad-case-key", "STORE", "store_ro_key").Result() + require.EqualError(t, err, "ERR SORT_RO is read-only and does not support the STORE parameter") + }) +} + +func TestSortLengthLimit(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SORT Length Limit", func(t *testing.T) { + for i := 0; i <= 512; i++ { + rdb.LPush(ctx, "many-list-elems-key", i) + } + _, err := rdb.Sort(ctx, "many-list-elems-key", &redis.Sort{}).Result() + require.EqualError(t, err, "The number of elements to be sorted exceeds SORT_LENGTH_LIMIT = 512") + + for i := 0; i <= 512; i++ { + rdb.SAdd(ctx, "many-set-elems-key", i) + } + _, err = rdb.Sort(ctx, "many-set-elems-key", &redis.Sort{}).Result() + require.EqualError(t, err, "The number of elements to be sorted exceeds SORT_LENGTH_LIMIT = 512") + + for i := 0; i <= 512; i++ { + rdb.ZAdd(ctx, "many-zset-elems-key", redis.Z{Score: float64(i), Member: fmt.Sprintf("%d", i)}) + } + _, err = rdb.Sort(ctx, "many-zset-elems-key", &redis.Sort{}).Result() + require.EqualError(t, err, "The number of elements to be sorted exceeds SORT_LENGTH_LIMIT = 512") + }) +} + +func TestListSort(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SORT Basic", func(t *testing.T) { + rdb.LPush(ctx, "today_cost", 30, 1.5, 10, 8) + + sortResult, err := rdb.Sort(ctx, "today_cost", &redis.Sort{}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, 
[]string{"30", "10", "8", "1.5"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"30", "10", "8", "1.5"}, sortResult) + }) + + t.Run("SORT ALPHA", func(t *testing.T) { + rdb.LPush(ctx, "website", "www.reddit.com", "www.slashdot.com", "www.infoq.com") + + sortResult, err := rdb.Sort(ctx, "website", &redis.Sort{Alpha: true}).Result() + require.NoError(t, err) + require.Equal(t, []string{"www.infoq.com", "www.reddit.com", "www.slashdot.com"}, sortResult) + + _, err = rdb.Sort(ctx, "website", &redis.Sort{Alpha: false}).Result() + require.EqualError(t, err, "One or more scores can't be converted into double") + }) + + t.Run("SORT LIMIT", func(t *testing.T) { + rdb.RPush(ctx, "rank", 1, 3, 5, 7, 9, 2, 4, 6, 8, 10) + + sortResult, err := rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"10", "9", "8", "7", "6"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 11, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 11}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + }) + + t.Run("SORT BY + GET", func(t *testing.T) { + rdb.LPush(ctx, "uid", 1, 2, 3, 4) + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") + rdb.MSet(ctx, "user_level_1", 9999, "user_level_2", 10, "user_level_3", 25, "user_level_4", 70) + + sortResult, err := rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + 
require.Equal(t, []string{"admin", "jack", "peter", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*", Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"9999", "admin", "10", "jack", "25", "peter", "70", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) + + // not sorted + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "3", "2", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"3", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"3", "2", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "70", "mary", "3", "25", "peter", "2", "10", "jack", "1", "9999", "admin"}, sortResult) + + // pattern with hash tag + rdb.HMSet(ctx, "user_info_1", "name", "admin", "level", 9999) + rdb.HMSet(ctx, "user_info_2", "name", "jack", "level", 10) + rdb.HMSet(ctx, "user_info_3", "name", "peter", 
"level", 25) + rdb.HMSet(ctx, "user_info_4", "name", "mary", "level", 70) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level", Get: []string{"user_info_*->name"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + // get/by empty and nil + rdb.LPush(ctx, "uid_empty_nil", 4, 5, 6) + rdb.MSet(ctx, "user_name_5", "tom", "user_level_5", -1) + + getResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", nil}, getResult) + byResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + + getResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", ""}, getResult) + + byResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + }) + + t.Run("SORT STORE", func(t *testing.T) { + rdb.RPush(ctx, "numbers", 1, 3, 5, 7, 9, 2, 4, 6, 8, 10) + + storedLen, err := rdb.Do(ctx, "Sort", "numbers", "STORE", "sorted-numbers").Result() + require.NoError(t, err) + require.Equal(t, int64(10), storedLen) + + sortResult, err := rdb.LRange(ctx, "sorted-numbers", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + rdb.LPush(ctx, "no-force-alpha-sort-key", 123, 3, 21) + storedLen, err = rdb.Do(ctx, "Sort", "no-force-alpha-sort-key", "BY", "not-exists-key", "STORE", "no-alpha-sorted").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "no-alpha-sorted", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"21", "3", "123"}, sortResult) + + // get empty and nil + rdb.LPush(ctx, "uid_get_empty_nil", 4, 5, 6) + rdb.MSet(ctx, "user_name_4", "mary", "user_level_4", 70, "user_name_5", "tom", "user_level_5", -1) + + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) + }) +} + +func TestSetSort(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SORT Basic", func(t *testing.T) { + rdb.SAdd(ctx, "today_cost", 30, 1.5, 10, 8) + + sortResult, err := rdb.Sort(ctx, "today_cost", 
&redis.Sort{}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"30", "10", "8", "1.5"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1.5", "8", "10", "30"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"30", "10", "8", "1.5"}, sortResult) + }) + + t.Run("SORT ALPHA", func(t *testing.T) { + rdb.SAdd(ctx, "website", "www.reddit.com", "www.slashdot.com", "www.infoq.com") + + sortResult, err := rdb.Sort(ctx, "website", &redis.Sort{Alpha: true}).Result() + require.NoError(t, err) + require.Equal(t, []string{"www.infoq.com", "www.reddit.com", "www.slashdot.com"}, sortResult) + + _, err = rdb.Sort(ctx, "website", &redis.Sort{Alpha: false}).Result() + require.EqualError(t, err, "One or more scores can't be converted into double") + }) + + t.Run("SORT LIMIT", func(t *testing.T) { + rdb.SAdd(ctx, "rank", 1, 3, 5, 7, 9, 2, 4, 6, 8, 10) + + sortResult, err := rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"10", "9", "8", "7", "6"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 11, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 11}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + }) + + t.Run("SORT BY + GET", func(t *testing.T) { + rdb.SAdd(ctx, "uid", 4, 3, 2, 1) + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", 
"user_name_4", "mary") + rdb.MSet(ctx, "user_level_1", 9999, "user_level_2", 10, "user_level_3", 25, "user_level_4", 70) + + sortResult, err := rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"admin", "jack", "peter", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*", Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"9999", "admin", "10", "jack", "25", "peter", "70", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) + + // not sorted + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: 
"not-exists-key", Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) + + // pattern with hash tag + rdb.HMSet(ctx, "user_info_1", "name", "admin", "level", 9999) + rdb.HMSet(ctx, "user_info_2", "name", "jack", "level", 10) + rdb.HMSet(ctx, "user_info_3", "name", "peter", "level", 25) + rdb.HMSet(ctx, "user_info_4", "name", "mary", "level", 70) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level", Get: []string{"user_info_*->name"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + // get/by empty and nil + rdb.SAdd(ctx, "uid_empty_nil", 4, 5, 6) + rdb.MSet(ctx, "user_name_5", "tom", "user_level_5", -1) + + getResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", nil}, getResult) + byResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + + getResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", ""}, getResult) + + byResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + + }) + + t.Run("SORT STORE", func(t *testing.T) { + rdb.SAdd(ctx, "numbers", 1, 3, 5, 7, 9, 2, 4, 6, 8, 10) + + storedLen, err := rdb.Do(ctx, "Sort", "numbers", "STORE", "sorted-numbers").Result() + require.NoError(t, err) + require.Equal(t, int64(10), storedLen) + + sortResult, err := rdb.LRange(ctx, "sorted-numbers", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + rdb.SAdd(ctx, "force-alpha-sort-key", 123, 3, 21) + storedLen, err = rdb.Do(ctx, "Sort", "force-alpha-sort-key", "BY", "not-exists-key", "STORE", "alpha-sorted").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "alpha-sorted", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"123", "21", "3"}, sortResult) + + // get empty and nil + rdb.SAdd(ctx, "uid_get_empty_nil", 4, 5, 6) + rdb.MSet(ctx, "user_name_4", "mary", "user_level_4", 70, "user_name_5", "tom", "user_level_5", -1) + + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, 
[]string{"mary", "tom", ""}, sortResult) + }) +} + +func TestZSetSort(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SORT Basic", func(t *testing.T) { + rdb.ZAdd(ctx, "today_cost", redis.Z{Score: 30, Member: "1"}, redis.Z{Score: 1.5, Member: "2"}, redis.Z{Score: 10, Member: "3"}, redis.Z{Score: 8, Member: "4"}) + + sortResult, err := rdb.Sort(ctx, "today_cost", &redis.Sort{}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "3", "2", "1"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "ASC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.SortRO(ctx, "today_cost", &redis.Sort{Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "3", "2", "1"}, sortResult) + }) + + t.Run("SORT ALPHA", func(t *testing.T) { + rdb.ZAdd(ctx, "website", redis.Z{Score: 1, Member: "www.reddit.com"}, redis.Z{Score: 2, Member: "www.slashdot.com"}, redis.Z{Score: 3, Member: "www.infoq.com"}) + + sortResult, err := rdb.Sort(ctx, "website", &redis.Sort{Alpha: true}).Result() + require.NoError(t, err) + require.Equal(t, []string{"www.infoq.com", "www.reddit.com", "www.slashdot.com"}, sortResult) + + _, err = rdb.Sort(ctx, "website", &redis.Sort{Alpha: false}).Result() + require.EqualError(t, err, "One or more scores can't be converted into double") + }) + + t.Run("SORT LIMIT", func(t *testing.T) { + rdb.ZAdd(ctx, "rank", + redis.Z{Score: 1, Member: "1"}, + redis.Z{Score: 2, Member: "3"}, + redis.Z{Score: 3, Member: "5"}, + redis.Z{Score: 4, Member: "7"}, + redis.Z{Score: 5, Member: "9"}, + redis.Z{Score: 6, Member: "2"}, + redis.Z{Score: 7, Member: "4"}, + redis.Z{Score: 8, Member: "6"}, + redis.Z{Score: 9, Member: "8"}, + redis.Z{Score: 10, Member: "10"}, + ) + + sortResult, err := rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 0, Count: 5, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"10", "9", "8", "7", "6"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 10, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 11, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: 
-2, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -1, Count: 11}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "rank", &redis.Sort{Offset: -2, Count: -2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + }) + + t.Run("SORT BY + GET", func(t *testing.T) { + rdb.ZAdd(ctx, "uid", + redis.Z{Score: 1, Member: "4"}, + redis.Z{Score: 2, Member: "3"}, + redis.Z{Score: 3, Member: "2"}, + redis.Z{Score: 4, Member: "1"}) + + rdb.MSet(ctx, "user_name_1", "admin", "user_name_2", "jack", "user_name_3", "peter", "user_name_4", "mary") + rdb.MSet(ctx, "user_level_1", 9999, "user_level_2", 10, "user_level_3", 25, "user_level_4", 70) + + sortResult, err := rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"admin", "jack", "peter", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_level_*", Get: []string{"user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"9999", "admin", "10", "jack", "25", "peter", "70", "mary"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "9999", "admin", "2", "10", "jack", "3", "25", "peter", "4", "70", "mary"}, sortResult) + + // not sorted + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "3", "2", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2}).Result() + require.NoError(t, err) + require.Equal(t, []string{"3", "2"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1}).Result() + require.NoError(t, err) + require.Equal(t, []string{"3", "2", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, 
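/* with sorting skipped by the missing BY key, DESC walks the zset in reverse score order */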
"uid", &redis.Sort{By: "not-exists-key", Offset: 0, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: 2, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 4, Count: 0, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Offset: 1, Count: -1, Order: "DESC"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "not-exists-key", Get: []string{"#", "user_level_*", "user_name_*"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"4", "70", "mary", "3", "25", "peter", "2", "10", "jack", "1", "9999", "admin"}, sortResult) + + // pattern with hash tag + rdb.HMSet(ctx, "user_info_1", "name", "admin", "level", 9999) + rdb.HMSet(ctx, "user_info_2", "name", "jack", "level", 10) + rdb.HMSet(ctx, "user_info_3", "name", "peter", "level", 25) + rdb.HMSet(ctx, "user_info_4", "name", "mary", "level", 70) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level"}).Result() + require.NoError(t, err) + require.Equal(t, []string{"2", "3", "4", "1"}, sortResult) + + sortResult, err = rdb.Sort(ctx, "uid", &redis.Sort{By: "user_info_*->level", Get: []string{"user_info_*->name"}}).Result() + require.NoError(t, err) + require.Equal(t, []string{"jack", "peter", "mary", "admin"}, sortResult) + + // get/by empty and nil + rdb.ZAdd(ctx, "uid_empty_nil", + redis.Z{Score: 4, Member: "6"}, + redis.Z{Score: 5, Member: "5"}, + redis.Z{Score: 6, Member: "4"}) + rdb.MSet(ctx, "user_name_5", "tom", "user_level_5", -1) + + getResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", nil}, getResult) + byResult, err := rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + + getResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "Get", "user_name_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"mary", "tom", ""}, getResult) + + byResult, err = rdb.Do(ctx, "Sort", "uid_empty_nil", "By", "user_level_*").Slice() + require.NoError(t, err) + require.Equal(t, []interface{}{"5", "6", "4"}, byResult) + }) + + t.Run("SORT STORE", func(t *testing.T) { + rdb.ZAdd(ctx, "numbers", + redis.Z{Score: 1, Member: "1"}, + redis.Z{Score: 2, Member: "3"}, + redis.Z{Score: 3, Member: "5"}, + redis.Z{Score: 4, Member: "7"}, + redis.Z{Score: 5, Member: "9"}, + redis.Z{Score: 6, Member: "2"}, + redis.Z{Score: 7, Member: "4"}, + redis.Z{Score: 8, Member: "6"}, + redis.Z{Score: 9, Member: "8"}, + redis.Z{Score: 10, Member: "10"}, + ) + + storedLen, err := rdb.Do(ctx, "Sort", "numbers", "STORE", "sorted-numbers").Result() + require.NoError(t, err) + require.Equal(t, int64(10), storedLen) + + sortResult, err := rdb.LRange(ctx, "sorted-numbers", 0, -1).Result() + 
require.NoError(t, err) + require.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, sortResult) + + rdb.ZAdd(ctx, "no-force-alpha-sort-key", + redis.Z{Score: 1, Member: "123"}, + redis.Z{Score: 2, Member: "3"}, + redis.Z{Score: 3, Member: "21"}, + ) + + storedLen, err = rdb.Do(ctx, "Sort", "no-force-alpha-sort-key", "BY", "not-exists-key", "STORE", "no-alpha-sorted").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "no-alpha-sorted", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"123", "3", "21"}, sortResult) + + // get empty and nil + rdb.ZAdd(ctx, "uid_get_empty_nil", + redis.Z{Score: 4, Member: "6"}, + redis.Z{Score: 5, Member: "5"}, + redis.Z{Score: 6, Member: "4"}) + rdb.MSet(ctx, "user_name_4", "mary", "user_level_4", 70, "user_name_5", "tom", "user_level_5", -1) + + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) + + rdb.MSet(ctx, "user_name_6", "", "user_level_6", "") + storedLen, err = rdb.Do(ctx, "Sort", "uid_get_empty_nil", "Get", "user_name_*", "Store", "get_empty_nil_store").Result() + require.NoError(t, err) + require.Equal(t, int64(3), storedLen) + + sortResult, err = rdb.LRange(ctx, "get_empty_nil_store", 0, -1).Result() + require.NoError(t, err) + require.Equal(t, []string{"mary", "tom", ""}, sortResult) + }) +} diff --git a/tests/gocase/unit/type/bitmap/bitmap_test.go b/tests/gocase/unit/type/bitmap/bitmap_test.go index 508f52dd82b..55528e30991 100644 --- a/tests/gocase/unit/type/bitmap/bitmap_test.go +++ b/tests/gocase/unit/type/bitmap/bitmap_test.go @@ -378,4 +378,194 @@ func TestBitmap(t *testing.T) { require.EqualValues(t, 't', res.Val()[0]) require.ErrorContains(t, rdb.Do(ctx, "BITFIELD_RO", "str", "INCRBY", "u8", "32", 2).Err(), "BITFIELD_RO only supports the GET subcommand") }) + + t.Run("BITPOS BIT option check", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "mykey", "\x00\xff\xf0", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "mykey", 1, 7, 15, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 8, cmd.Val()) + }) + + t.Run("BITPOS BIT with bit=0 found at range start check", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "mykey", "\x00\xff\xf0", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "mykey", 0, 0, 5, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 0, cmd.Val()) + }) + + t.Run("BITPOS BIT with bit=0 found inside range check", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "mykey", "\x00\xff\xf0", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "mykey", 0, 2, 3, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 2, cmd.Val()) + }) + + /* Test cases adapted from the Redis test suite: https://github.com/redis/redis/blob/unstable/tests/unit/bitops.tcl + */ + t.Run("BITPOS bit=0 with empty key returns 0", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "str").Err()) + cmd := rdb.BitPosSpan(ctx, "str", 0, 0, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 0, cmd.Val()) + }) + + t.Run("BITPOS bit=0 with string less than 1 word works", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "\xff\xf0\x00", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "str", 0, 0, -1, "bit") + require.NoError(t, cmd.Err()) +
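/* "\xff\xf0\x00" has bits 0-11 set, so the first 0 bit is at offset 12 */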
require.EqualValues(t, 12, cmd.Val()) + }) + + t.Run("BITPOS bit=1 with string less than 1 word works", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "\x00\x0f\x00", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "str", 1, 0, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 12, cmd.Val()) + }) + + t.Run("BITPOS bit=0 starting at unaligned address", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "\xff\xf0\x00", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "str", 0, 1, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 12, cmd.Val()) + }) + + t.Run("BITPOS bit=1 starting at unaligned address", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "\x00\x0f\xff", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "str", 1, 1, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 12, cmd.Val()) + }) + + t.Run("BITPOS bit=0 unaligned+full word+remainder", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "\xff\xff\xff", 0).Err()) + require.NoError(t, rdb.Append(ctx, "str", "\xff\xff\xff\xff\xff\xff\xff\xff").Err()) + require.NoError(t, rdb.Append(ctx, "str", "\xff\xff\xff\xff\xff\xff\xff\xff").Err()) + require.NoError(t, rdb.Append(ctx, "str", "\xff\xff\xff\xff\xff\xff\xff\xff").Err()) + require.NoError(t, rdb.Append(ctx, "str", "\x0f").Err()) + // Test start offsets 1, 9, 17, 25, 33, 41, 49, 57, 65 + for i := 0; i < 9; i++ { + if i == 6 { + // this iteration probes from the fixed start offset 41 instead of 49 + cmd := rdb.BitPosSpan(ctx, "str", 0, 41, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 216, cmd.Val()) + } else { + cmd := rdb.BitPosSpan(ctx, "str", 0, int64(i*8)+1, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 216, cmd.Val()) + } + } + }) + + t.Run("BITPOS bit=1 unaligned+full word+remainder", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "\x00\x00\x00", 0).Err()) + require.NoError(t, rdb.Append(ctx, "str", "\x00\x00\x00\x00\x00\x00\x00\x00").Err()) + require.NoError(t, rdb.Append(ctx, "str", "\x00\x00\x00\x00\x00\x00\x00\x00").Err()) + require.NoError(t, rdb.Append(ctx, "str", "\x00\x00\x00\x00\x00\x00\x00\x00").Err()) + require.NoError(t, rdb.Append(ctx, "str", "\xf0").Err()) + // Test start offsets 1, 9, 17, 25, 33, 41, 49, 57, 65 + for i := 0; i < 9; i++ { + if i == 6 { + // this iteration probes from the fixed start offset 41 instead of 49 + cmd := rdb.BitPosSpan(ctx, "str", 1, 41, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 216, cmd.Val()) + } else { + cmd := rdb.BitPosSpan(ctx, "str", 1, int64(i*8)+1, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 216, cmd.Val()) + } + } + }) + + t.Run("BITPOS bit=1 returns -1 if string is all 0 bits", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "", 0).Err()) + for i := 0; i < 20; i++ { + cmd := rdb.BitPosSpan(ctx, "str", 1, 0, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, -1, cmd.Val()) + require.NoError(t, rdb.Append(ctx, "str", "\x00").Err()) + } + }) + + t.Run("BITPOS bit=0 works with intervals", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "\x00\xff\x00", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "str", 0, 0, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 0, cmd.Val()) + cmd = rdb.BitPosSpan(ctx, "str", 0, 8, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 16, cmd.Val()) + cmd = rdb.BitPosSpan(ctx, "str", 0, 16, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 16, cmd.Val()) + cmd = rdb.BitPosSpan(ctx, "str", 0, 16, 200, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 16, cmd.Val()) + cmd =
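/* bit 8 of "\x00\xff\x00" is a 1, so no 0 bit exists in the single-bit span [8,8] */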
rdb.BitPosSpan(ctx, "str", 0, 8, 8, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, -1, cmd.Val()) + }) + + t.Run("BITPOS bit=1 works with intervals", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "\x00\xff\x00", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "str", 1, 0, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 8, cmd.Val()) + cmd = rdb.BitPosSpan(ctx, "str", 1, 8, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 8, cmd.Val()) + cmd = rdb.BitPosSpan(ctx, "str", 1, 16, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, -1, cmd.Val()) + cmd = rdb.BitPosSpan(ctx, "str", 1, 16, 200, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, -1, cmd.Val()) + cmd = rdb.BitPosSpan(ctx, "str", 1, 8, 8, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, 8, cmd.Val()) + }) + + t.Run("BITPOS bit=0 changes behavior if end is given", func(t *testing.T) { + require.NoError(t, rdb.Set(ctx, "str", "\xff\xff\xff", 0).Err()) + cmd := rdb.BitPosSpan(ctx, "str", 0, 0, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, -1, cmd.Val()) + }) + + t.Run("BITPOS bit=1 fuzzy testing using SETBIT", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "str").Err()) + var max int64 = 524288 + var firstOnePos int64 = -1 + for j := 0; j < 1000; j++ { + cmd := rdb.BitPosSpan(ctx, "str", 1, 0, -1, "bit") + require.NoError(t, cmd.Err()) + require.EqualValues(t, firstOnePos, cmd.Val()) + pos := util.RandomInt(max) + require.NoError(t, rdb.SetBit(ctx, "str", int64(pos), 1).Err()) + if firstOnePos == -1 || firstOnePos > pos { + firstOnePos = pos + } + } + }) + + t.Run("BITPOS bit=0 fuzzy testing using SETBIT", func(t *testing.T) { + var max int64 = 524288 + firstZeroPos := max + require.NoError(t, rdb.Set(ctx, "str", strings.Repeat("\xff", int(max/8)), 0).Err()) + for j := 0; j < 1000; j++ { + cmd := rdb.BitPosSpan(ctx, "str", 0, 0, -1, "bit") + require.NoError(t, cmd.Err()) + if firstZeroPos == max { + require.EqualValues(t, -1, cmd.Val()) + } else { + require.EqualValues(t, firstZeroPos, cmd.Val()) + } + pos := util.RandomInt(max) + require.NoError(t, rdb.SetBit(ctx, "str", int64(pos), 0).Err()) + if firstZeroPos > pos { + firstZeroPos = pos + } + } + }) + } diff --git a/tests/gocase/unit/type/hash/hash_test.go b/tests/gocase/unit/type/hash/hash_test.go index 38b0a576089..3c769649295 100644 --- a/tests/gocase/unit/type/hash/hash_test.go +++ b/tests/gocase/unit/type/hash/hash_test.go @@ -895,3 +895,44 @@ func TestHashWithAsyncIOEnabled(t *testing.T) { require.Len(t, rdb.HVals(ctx, testKey).Val(), 50) }) } + +func TestHashWithAsyncIODisabled(t *testing.T) { + srv := util.StartServer(t, map[string]string{ + "rocksdb.read_options.async_io": "no", + }) + defer srv.Close() + + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + ctx := context.Background() + + t.Run("Test bug with large value after compaction", func(t *testing.T) { + testKey := "test-hash-1" + require.NoError(t, rdb.Del(ctx, testKey).Err()) + + src := rand.NewSource(time.Now().UnixNano()) + dd := make([]byte, 5000) + for i := 1; i <= 50; i++ { + for j := range dd { + dd[j] = byte(src.Int63()) + } + key := util.RandString(10, 20, util.Alpha) + require.NoError(t, rdb.HSet(ctx, testKey, key, string(dd)).Err()) + } + + require.EqualValues(t, 50, rdb.HLen(ctx, testKey).Val()) + require.Len(t, rdb.HGetAll(ctx, testKey).Val(), 50) + require.Len(t, rdb.HKeys(ctx, testKey).Val(), 50) + require.Len(t, rdb.HVals(ctx, 
testKey).Val(), 50) + + require.NoError(t, rdb.Do(ctx, "COMPACT").Err()) + + time.Sleep(5 * time.Second) + + require.EqualValues(t, 50, rdb.HLen(ctx, testKey).Val()) + require.Len(t, rdb.HGetAll(ctx, testKey).Val(), 50) + require.Len(t, rdb.HKeys(ctx, testKey).Val(), 50) + require.Len(t, rdb.HVals(ctx, testKey).Val(), 50) + }) +} diff --git a/tests/gocase/unit/type/json/json_test.go b/tests/gocase/unit/type/json/json_test.go index a1489f7acd4..e720c89be4d 100644 --- a/tests/gocase/unit/type/json/json_test.go +++ b/tests/gocase/unit/type/json/json_test.go @@ -180,6 +180,9 @@ func TestJson(t *testing.T) { result2 = append(result2, int64(3), int64(5), interface{}(nil)) require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31}}`).Err()) require.Equal(t, rdb.Do(ctx, "JSON.STRLEN", "a", "$..a").Val(), result2) + require.Error(t, rdb.Do(ctx, "JSON.STRLEN", "not_exists", "$").Err()) + require.ErrorIs(t, rdb.Do(ctx, "JSON.STRLEN", "not_exists").Err(), redis.Nil) + }) t.Run("Merge basics", func(t *testing.T) { @@ -353,7 +356,7 @@ func TestJson(t *testing.T) { t.Run("JSON.ARRTRIM basics", func(t *testing.T) { require.NoError(t, rdb.Del(ctx, "a").Err()) // key no exists - require.EqualError(t, rdb.Do(ctx, "JSON.ARRTRIM", "not_exists", "$", 0, 0).Err(), redis.Nil.Error()) + require.ErrorContains(t, rdb.Do(ctx, "JSON.ARRTRIM", "not_exists", "$", 0, 0).Err(), "could not perform this operation on a key that doesn't exist") // key not json require.NoError(t, rdb.Do(ctx, "SET", "no_json", "1").Err()) require.Error(t, rdb.Do(ctx, "JSON.ARRTRIM", "no_json", "$", 0, 0).Err()) @@ -510,6 +513,9 @@ func TestJson(t *testing.T) { EqualJSON(t, `[3]`, rdb.Do(ctx, "JSON.NUMINCRBY", "a", "$.foo", 2).Val()) EqualJSON(t, `[3.5]`, rdb.Do(ctx, "JSON.NUMINCRBY", "a", "$.foo", 0.5).Val()) + require.Error(t, rdb.Do(ctx, "JSON.NUMINCRBY", "a", "$.foo", "9e99999").Err()) + require.Error(t, rdb.Do(ctx, "JSON.NUMINCRBY", "a", "$.foo", "999999999999999999999999999999").Err()) + // wrong type require.Equal(t, `[null]`, rdb.Do(ctx, "JSON.NUMINCRBY", "a", "$.bar", 1).Val()) @@ -578,7 +584,9 @@ func TestJson(t *testing.T) { require.NoError(t, err) require.EqualValues(t, 0, len(vals)) - err = rdb.Do(ctx, "JSON.OBJLEN", "no-such-json-key", "$").Err() + require.Error(t, rdb.Do(ctx, "JSON.OBJLEN", "no-such-json-key", "$").Err()) + err = rdb.Do(ctx, "JSON.OBJLEN", "no-such-json-key").Err() + require.EqualError(t, err, redis.Nil.Error()) }) @@ -612,6 +620,43 @@ func TestJson(t *testing.T) { require.EqualValues(t, "[]", vals[1]) }) + + t.Run("JSON.MSET basics", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "JSON.DEL", "a0").Err()) + require.Error(t, rdb.Do(ctx, "JSON.MSET", "a0", "$.a", `{"a": 1, "b": 2, "nested": {"a": 3}, "c": null}`, "a1", "$", `{"a": 4, "b": 5, "nested": {"a": 6}, "c": null}`).Err()) + require.NoError(t, rdb.Do(ctx, "JSON.MSET", "a0", "$", `{"a": 1, "b": 2, "nested": {"a": 3}, "c": null}`, "a1", "$", `{"a": 4, "b": 5, "nested": {"a": 6}, "c": null}`).Err()) + + EqualJSON(t, `{"a": 1, "b": 2, "nested": {"a": 3}, "c": null}`, rdb.Do(ctx, "JSON.GET", "a0").Val()) + EqualJSON(t, `[{"a": 1, "b": 2, "nested": {"a": 3}, "c": null}]`, rdb.Do(ctx, "JSON.GET", "a0", "$").Val()) + EqualJSON(t, `[1]`, rdb.Do(ctx, "JSON.GET", "a0", "$.a").Val()) + + EqualJSON(t, `{"a": 4, "b": 5, "nested": {"a": 6}, "c": null}`, rdb.Do(ctx, "JSON.GET", "a1").Val()) + EqualJSON(t, `[{"a": 4, "b": 5, "nested": {"a": 6}, "c": null}]`, rdb.Do(ctx, "JSON.GET", "a1", "$").Val()) + EqualJSON(t, 
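/* the second document written by the same JSON.MSET call is independently readable */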
`[4]`, rdb.Do(ctx, "JSON.GET", "a1", "$.a").Val()) + }) + + t.Run("JSON.DEBUG MEMORY basics", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"b":true,"x":1, "y":1.2, "z": {"x":[1,2,3], "y": null}, "v":{"x":"y"},"f":{"x":[]}}`).Err()) + // whole document (an object) + var result1 = make([]interface{}, 0) + result1 = append(result1, int64(43)) + require.Equal(t, result1, rdb.Do(ctx, "JSON.DEBUG", "MEMORY", "a", "$").Val()) + // integer, string, array, and empty array + var result2 = make([]interface{}, 0) + result2 = append(result2, int64(1), int64(1), int64(2), int64(4)) + require.Equal(t, result2, rdb.Do(ctx, "JSON.DEBUG", "MEMORY", "a", "$..x").Val()) + // double and null + var result3 = make([]interface{}, 0) + result3 = append(result3, int64(9), int64(1)) + require.Equal(t, result3, rdb.Do(ctx, "JSON.DEBUG", "MEMORY", "a", "$..y").Val()) + // path does not exist + require.Equal(t, []interface{}{}, rdb.Do(ctx, "JSON.DEBUG", "MEMORY", "a", "$..no_exists").Val()) + // key does not exist, no path given + require.Equal(t, rdb.Do(ctx, "JSON.DEBUG", "MEMORY", "not_exists").Val(), int64(0)) + // key does not exist, path given + require.Equal(t, []interface{}{}, rdb.Do(ctx, "JSON.DEBUG", "MEMORY", "not_exists", "$").Val()) + + }) } func EqualJSON(t *testing.T, expected string, actual interface{}) { diff --git a/tests/gocase/unit/type/stream/stream_test.go b/tests/gocase/unit/type/stream/stream_test.go index 66ae0705dab..4ba461ff341 100644 --- a/tests/gocase/unit/type/stream/stream_test.go +++ b/tests/gocase/unit/type/stream/stream_test.go @@ -867,355 +867,1043 @@ func TestStreamOffset(t *testing.T) { require.EqualValues(t, providedSeqNum, seqNum) }) - t.Run("XGROUP CREATE with different kinds of commands and XGROUP DESTROY", func(t *testing.T) { - streamName := "test-stream-a" - groupName := "test-group-a" - require.NoError(t, rdb.Del(ctx, streamName).Err()) - // No such stream (No such key) - require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$").Err()) - require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "ENTRIESREAD", "10").Err()) - require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "ENTRIESREAD").Err()) - require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "MKSTREAM", "ENTRIESREAD").Err()) - require.NoError(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "MKSTREAM").Err()) - require.NoError(t, rdb.XInfoStream(ctx, streamName).Err()) - require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$").Err()) - // Invalid syntax - groupName = "test-group-b" - require.Error(t, rdb.Do(ctx, "XGROUP", "CREAT", streamName, groupName, "$").Err()) - require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "ENTRIEREAD", "10").Err()) - require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "ENTRIESREAD", "-10").Err()) - require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, "1test-group-c", "$").Err()) - - require.NoError(t, rdb.Del(ctx, "myStream").Err()) - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{Stream: "myStream", Values: []string{"iTeM", "1", "vAluE", "a"}}).Err()) - require.NoError(t, rdb.XGroupCreate(ctx, "myStream", "myGroup", "$").Err()) - result, err := rdb.XGroupDestroy(ctx, "myStream", "myGroup").Result() - require.NoError(t, err) - require.Equal(t, int64(1), result) - result, err = rdb.XGroupDestroy(ctx, "myStream", "myGroup").Result() - require.NoError(t, err) - require.Equal(t, int64(0), result) - }) - - t.Run("XGROUP CREATECONSUMER with different kinds of
commands", func(t *testing.T) { - streamName := "test-stream" - groupName := "test-group" - consumerName := "test-consumer" - require.NoError(t, rdb.Del(ctx, streamName).Err()) - //No such stream - require.Error(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "1-0", - Values: []string{"data", "a"}, - }).Err()) - //no such group - require.Error(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) - require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "$").Err()) - - r := rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Val() - require.Equal(t, int64(1), r) - r = rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Val() - require.Equal(t, int64(0), r) - }) - - t.Run("XGROUP DELCONSUMER with different kinds of commands", func(t *testing.T) { - streamName := "test-stream" - groupName := "test-group" - consumerName := "test-consumer" - require.NoError(t, rdb.Del(ctx, streamName).Err()) - //No such stream - require.Error(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "1-0", - Values: []string{"data", "a"}, - }).Err()) - //no such group - require.Error(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) - require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "$").Err()) - require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) - - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "*", - Values: []string{"data1", "a1"}, - }).Err()) - require.NoError(t, rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, ">"}, - Count: 1, - NoAck: false, - }).Err()) - ri, erri := rdb.XInfoGroups(ctx, streamName).Result() - require.NoError(t, erri) - require.Equal(t, int64(1), ri[0].Consumers) - require.Equal(t, int64(1), ri[0].Pending) - - r, err := rdb.XGroupDelConsumer(ctx, streamName, groupName, consumerName).Result() - require.NoError(t, err) - require.Equal(t, int64(1), r) - ri, erri = rdb.XInfoGroups(ctx, streamName).Result() - require.NoError(t, erri) - require.Equal(t, int64(0), ri[0].Consumers) - require.Equal(t, int64(0), ri[0].Pending) - - require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "*", - Values: []string{"data2", "a2"}, - }).Err()) - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "*", - Values: []string{"data3", "a3"}, - }).Err()) - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "*", - Values: []string{"data4", "a4"}, - }).Err()) - require.NoError(t, rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, ">"}, - Count: 3, - NoAck: false, - }).Err()) - ri, erri = rdb.XInfoGroups(ctx, streamName).Result() - require.NoError(t, erri) - require.Equal(t, int64(1), ri[0].Consumers) - require.Equal(t, int64(3), ri[0].Pending) - r, err = rdb.XGroupDelConsumer(ctx, streamName, groupName, consumerName).Result() - require.NoError(t, err) - require.Equal(t, int64(3), r) - ri, erri = rdb.XInfoGroups(ctx, streamName).Result() - require.NoError(t, erri) - require.Equal(t, int64(0), ri[0].Consumers) - require.Equal(t, 
int64(0), ri[0].Pending) - }) - - t.Run("XGROUP SETID with different kinds of commands", func(t *testing.T) { - streamName := "test-stream" - groupName := "test-group" - require.NoError(t, rdb.Del(ctx, streamName).Err()) - //No such stream - require.Error(t, rdb.XGroupSetID(ctx, streamName, groupName, "$").Err()) - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "1-0", - Values: []string{"data", "a"}, - }).Err()) - //No such group - require.Error(t, rdb.XGroupSetID(ctx, streamName, groupName, "$").Err()) - require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "$").Err()) - - require.NoError(t, rdb.XGroupSetID(ctx, streamName, groupName, "0-0").Err()) - require.Error(t, rdb.Do(ctx, "xgroup", "setid", streamName, groupName, "$", "entries", "100").Err()) - require.Error(t, rdb.Do(ctx, "xgroup", "setid", streamName, groupName, "$", "entriesread", "-100").Err()) - require.NoError(t, rdb.Do(ctx, "xgroup", "setid", streamName, groupName, "$", "entriesread", "100").Err()) - }) - - t.Run("XINFO GROUPS and XINFO CONSUMERS", func(t *testing.T) { - streamName := "test-stream" - group1 := "t1" - group2 := "t2" - consumer1 := "c1" - consumer2 := "c2" - consumer3 := "c3" - require.NoError(t, rdb.Del(ctx, streamName).Err()) - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "1-0", - Values: []string{"data", "a"}, - }).Err()) - require.NoError(t, rdb.XGroupCreate(ctx, streamName, group1, "$").Err()) - r := rdb.XInfoGroups(ctx, streamName).Val() - require.Equal(t, group1, r[0].Name) - require.Equal(t, int64(0), r[0].Consumers) - require.Equal(t, int64(0), r[0].Pending) - require.Equal(t, "1-0", r[0].LastDeliveredID) - require.Equal(t, int64(0), r[0].EntriesRead) - require.Equal(t, int64(0), r[0].Lag) - - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "2-0", - Values: []string{"data1", "b"}, - }).Err()) - require.NoError(t, rdb.XGroupCreate(ctx, streamName, group2, "$").Err()) - r = rdb.XInfoGroups(ctx, streamName).Val() - require.Equal(t, group2, r[1].Name) - require.Equal(t, "2-0", r[1].LastDeliveredID) - - require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, group1, consumer1).Err()) - require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, group1, consumer2).Err()) - require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, group2, consumer3).Err()) - r = rdb.XInfoGroups(ctx, streamName).Val() - require.Equal(t, int64(2), r[0].Consumers) - require.Equal(t, int64(1), r[1].Consumers) - - r1 := rdb.XInfoConsumers(ctx, streamName, group1).Val() - require.Equal(t, consumer1, r1[0].Name) - require.Equal(t, consumer2, r1[1].Name) - r1 = rdb.XInfoConsumers(ctx, streamName, group2).Val() - require.Equal(t, consumer3, r1[0].Name) - }) - - t.Run("XREAD After XGroupCreate and XGroupCreateConsumer, for issue #2109", func(t *testing.T) { - streamName := "test-stream" - group := "group" - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "*", - Values: []string{"data1", "b"}, - }).Err()) - require.NoError(t, rdb.XGroupCreate(ctx, streamName, group, "0").Err()) - require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, group, "consumer").Err()) - require.NoError(t, rdb.XRead(ctx, &redis.XReadArgs{ - Streams: []string{streamName, "0"}, - }).Err()) - }) - - t.Run("XREADGROUP with different kinds of commands", func(t *testing.T) { - streamName := "mystream" - groupName := "mygroup" - require.NoError(t, rdb.Del(ctx, streamName).Err()) - require.NoError(t, rdb.XAdd(ctx, 
&redis.XAddArgs{ - Stream: streamName, - ID: "1-0", - Values: []string{"field1", "data1"}, - }).Err()) - require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) - consumerName := "myconsumer" - r, err := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, ">"}, - Count: 1, - NoAck: false, - }).Result() - require.NoError(t, err) - require.Equal(t, []redis.XStream{{ - Stream: streamName, - Messages: []redis.XMessage{{ID: "1-0", Values: map[string]interface{}{"field1": "data1"}}}, - }}, r) - - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "2-0", - Values: []string{"field2", "data2"}, - }).Err()) - r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, ">"}, - Count: 1, - NoAck: false, - }).Result() - require.NoError(t, err) - require.Equal(t, []redis.XStream{{ - Stream: streamName, - Messages: []redis.XMessage{{ID: "2-0", Values: map[string]interface{}{"field2": "data2"}}}, - }}, r) - - r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, "0"}, - Count: 2, - NoAck: false, - }).Result() - require.NoError(t, err) - require.Equal(t, []redis.XStream{{ - Stream: streamName, - Messages: []redis.XMessage{{ID: "1-0", Values: map[string]interface{}{"field1": "data1"}}, - {ID: "2-0", Values: map[string]interface{}{"field2": "data2"}}}, - }}, r) - - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "3-0", - Values: []string{"field3", "data3"}, - }).Err()) - r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, ">"}, - Count: 1, - NoAck: true, - }).Result() - require.NoError(t, err) - require.Equal(t, []redis.XStream{{ - Stream: streamName, - Messages: []redis.XMessage{{ID: "3-0", Values: map[string]interface{}{"field3": "data3"}}}, - }}, r) - r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, "0"}, - Count: 2, - NoAck: false, - }).Result() - require.NoError(t, err) - require.Equal(t, []redis.XStream{{ - Stream: streamName, - Messages: []redis.XMessage{{ID: "1-0", Values: map[string]interface{}{"field1": "data1"}}, - {ID: "2-0", Values: map[string]interface{}{"field2": "data2"}}}, - }}, r) - - c := srv.NewClient() - defer func() { require.NoError(t, c.Close()) }() - ch := make(chan []redis.XStream) - go func() { - ch <- c.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, ">"}, - Count: 2, - Block: 10 * time.Second, - NoAck: false, - }).Val() - }() - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "4-0", - Values: []string{"field4", "data4"}, - }).Err()) - r = <-ch - require.Equal(t, []redis.XStream{{ - Stream: streamName, - Messages: []redis.XMessage{{ID: "4-0", Values: map[string]interface{}{"field4": "data4"}}}, - }}, r) - - require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ - Stream: streamName, - ID: "5-0", - Values: []string{"field5", "data5"}, - }).Err()) - require.NoError(t, rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ - Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, ">"}, - Count: 1, - NoAck: false, - }).Err()) - require.NoError(t, rdb.XDel(ctx, streamName, "5-0").Err()) - r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ - 
Group: groupName, - Consumer: consumerName, - Streams: []string{streamName, "5"}, - Count: 1, - NoAck: false, - }).Result() - require.NoError(t, err) - require.Equal(t, []redis.XStream{{ - Stream: streamName, - Messages: []redis.XMessage{{ID: "5-0", Values: map[string]interface{}(nil)}}, - }}, r) - }) + // t.Run("XGROUP CREATE with different kinds of commands and XGROUP DESTROY", func(t *testing.T) { + // streamName := "test-stream-a" + // groupName := "test-group-a" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // // No such stream (No such key) + // require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$").Err()) + // require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "ENTRIESREAD", "10").Err()) + // require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "ENTRIESREAD").Err()) + // require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "MKSTREAM", "ENTRIESREAD").Err()) + // require.NoError(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "MKSTREAM").Err()) + // require.NoError(t, rdb.XInfoStream(ctx, streamName).Err()) + // require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$").Err()) + // // Invalid syntax + // groupName = "test-group-b" + // require.Error(t, rdb.Do(ctx, "XGROUP", "CREAT", streamName, groupName, "$").Err()) + // require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "ENTRIEREAD", "10").Err()) + // require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, groupName, "$", "ENTRIESREAD", "-10").Err()) + // require.Error(t, rdb.Do(ctx, "XGROUP", "CREATE", streamName, "1test-group-c", "$").Err()) + + // require.NoError(t, rdb.Del(ctx, "myStream").Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{Stream: "myStream", Values: []string{"iTeM", "1", "vAluE", "a"}}).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, "myStream", "myGroup", "$").Err()) + // result, err := rdb.XGroupDestroy(ctx, "myStream", "myGroup").Result() + // require.NoError(t, err) + // require.Equal(t, int64(1), result) + // result, err = rdb.XGroupDestroy(ctx, "myStream", "myGroup").Result() + // require.NoError(t, err) + // require.Equal(t, int64(0), result) + // }) + + // t.Run("XGROUP CREATECONSUMER with different kinds of commands", func(t *testing.T) { + // streamName := "test-stream" + // groupName := "test-group" + // consumerName := "test-consumer" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // //No such stream + // require.Error(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"data", "a"}, + // }).Err()) + // //no such group + // require.Error(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "$").Err()) + + // r := rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Val() + // require.Equal(t, int64(1), r) + // r = rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Val() + // require.Equal(t, int64(0), r) + // }) + + // t.Run("XGROUP DELCONSUMER with different kinds of commands", func(t *testing.T) { + // streamName := "test-stream" + // groupName := "test-group" + // consumerName := "test-consumer" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // //No such stream + // require.Error(t, rdb.XGroupCreateConsumer(ctx, streamName, 
groupName, consumerName).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"data", "a"}, + // }).Err()) + // //no such group + // require.Error(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "$").Err()) + // require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) + + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"data1", "a1"}, + // }).Err()) + // require.NoError(t, rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: false, + // }).Err()) + // ri, erri := rdb.XInfoGroups(ctx, streamName).Result() + // require.NoError(t, erri) + // require.Equal(t, int64(1), ri[0].Consumers) + // require.Equal(t, int64(1), ri[0].Pending) + + // r, err := rdb.XGroupDelConsumer(ctx, streamName, groupName, consumerName).Result() + // require.NoError(t, err) + // require.Equal(t, int64(1), r) + // ri, erri = rdb.XInfoGroups(ctx, streamName).Result() + // require.NoError(t, erri) + // require.Equal(t, int64(0), ri[0].Consumers) + // require.Equal(t, int64(0), ri[0].Pending) + + // require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, groupName, consumerName).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"data2", "a2"}, + // }).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"data3", "a3"}, + // }).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"data4", "a4"}, + // }).Err()) + // require.NoError(t, rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 3, + // NoAck: false, + // }).Err()) + // ri, erri = rdb.XInfoGroups(ctx, streamName).Result() + // require.NoError(t, erri) + // require.Equal(t, int64(1), ri[0].Consumers) + // require.Equal(t, int64(3), ri[0].Pending) + // r, err = rdb.XGroupDelConsumer(ctx, streamName, groupName, consumerName).Result() + // require.NoError(t, err) + // require.Equal(t, int64(3), r) + // ri, erri = rdb.XInfoGroups(ctx, streamName).Result() + // require.NoError(t, erri) + // require.Equal(t, int64(0), ri[0].Consumers) + // require.Equal(t, int64(0), ri[0].Pending) + // }) + + // t.Run("XGROUP SETID with different kinds of commands", func(t *testing.T) { + // streamName := "test-stream" + // groupName := "test-group" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // //No such stream + // require.Error(t, rdb.XGroupSetID(ctx, streamName, groupName, "$").Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"data", "a"}, + // }).Err()) + // //No such group + // require.Error(t, rdb.XGroupSetID(ctx, streamName, groupName, "$").Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "$").Err()) + + // require.NoError(t, rdb.XGroupSetID(ctx, streamName, groupName, "0-0").Err()) + // require.Error(t, rdb.Do(ctx, "xgroup", "setid", streamName, groupName, "$", "entries", "100").Err()) + // require.Error(t, rdb.Do(ctx, "xgroup", "setid", streamName, groupName, "$", "entriesread", 
"-100").Err()) + // require.NoError(t, rdb.Do(ctx, "xgroup", "setid", streamName, groupName, "$", "entriesread", "100").Err()) + // }) + + // t.Run("XINFO GROUPS and XINFO CONSUMERS", func(t *testing.T) { + // streamName := "test-stream" + // group1 := "t1" + // group2 := "t2" + // consumer1 := "c1" + // consumer2 := "c2" + // consumer3 := "c3" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"data", "a"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, group1, "$").Err()) + // r := rdb.XInfoGroups(ctx, streamName).Val() + // require.Equal(t, group1, r[0].Name) + // require.Equal(t, int64(0), r[0].Consumers) + // require.Equal(t, int64(0), r[0].Pending) + // require.Equal(t, "1-0", r[0].LastDeliveredID) + // require.Equal(t, int64(0), r[0].EntriesRead) + // require.Equal(t, int64(0), r[0].Lag) + + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "2-0", + // Values: []string{"data1", "b"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, group2, "$").Err()) + // r = rdb.XInfoGroups(ctx, streamName).Val() + // require.Equal(t, group2, r[1].Name) + // require.Equal(t, "2-0", r[1].LastDeliveredID) + + // require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, group1, consumer1).Err()) + // require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, group1, consumer2).Err()) + // require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, group2, consumer3).Err()) + // r = rdb.XInfoGroups(ctx, streamName).Val() + // require.Equal(t, int64(2), r[0].Consumers) + // require.Equal(t, int64(1), r[1].Consumers) + + // r1 := rdb.XInfoConsumers(ctx, streamName, group1).Val() + // require.Equal(t, consumer1, r1[0].Name) + // require.Equal(t, consumer2, r1[1].Name) + // r1 = rdb.XInfoConsumers(ctx, streamName, group2).Val() + // require.Equal(t, consumer3, r1[0].Name) + // }) + + // t.Run("XINFO after delete pending message and related consumer, for issue #2350", func(t *testing.T) { + // streamName := "test-stream-2350" + // groupName := "test-group-2350" + // consumerName := "test-consumer-2350" + // require.NoError(t, rdb.XGroupCreateMkStream(ctx, streamName, groupName, "$").Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"testing", "overflow"}, + // }).Err()) + // readRsp := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: false, + // }) + // require.NoError(t, readRsp.Err()) + // require.Len(t, readRsp.Val(), 1) + // streamRsp := readRsp.Val()[0] + // require.Len(t, streamRsp.Messages, 1) + // msgID := streamRsp.Messages[0] + // require.NoError(t, rdb.XAck(ctx, streamName, groupName, msgID.ID).Err()) + // require.NoError(t, rdb.XGroupDelConsumer(ctx, streamName, groupName, consumerName).Err()) + // infoRsp := rdb.XInfoGroups(ctx, streamName) + // require.NoError(t, infoRsp.Err()) + // infoGroups := infoRsp.Val() + // require.Len(t, infoGroups, 1) + // infoGroup := infoGroups[0] + // require.Equal(t, groupName, infoGroup.Name) + // require.Equal(t, int64(0), infoGroup.Consumers) + // require.Equal(t, int64(0), infoGroup.Pending) + // require.Equal(t, msgID.ID, infoGroup.LastDeliveredID) + // }) + + // t.Run("XREAD After XGroupCreate and XGroupCreateConsumer, for issue #2109", func(t *testing.T) { + 
// streamName := "test-stream" + // group := "group" + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"data1", "b"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, group, "0").Err()) + // require.NoError(t, rdb.XGroupCreateConsumer(ctx, streamName, group, "consumer").Err()) + // require.NoError(t, rdb.XRead(ctx, &redis.XReadArgs{ + // Streams: []string{streamName, "0"}, + // }).Err()) + // }) + + // t.Run("XREADGROUP with different kinds of commands", func(t *testing.T) { + // streamName := "mystream" + // groupName := "mygroup" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"field1", "data1"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + // consumerName := "myconsumer" + // r, err := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: false, + // }).Result() + // require.NoError(t, err) + // require.Equal(t, []redis.XStream{{ + // Stream: streamName, + // Messages: []redis.XMessage{{ID: "1-0", Values: map[string]interface{}{"field1": "data1"}}}, + // }}, r) + + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "2-0", + // Values: []string{"field2", "data2"}, + // }).Err()) + // r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: false, + // }).Result() + // require.NoError(t, err) + // require.Equal(t, []redis.XStream{{ + // Stream: streamName, + // Messages: []redis.XMessage{{ID: "2-0", Values: map[string]interface{}{"field2": "data2"}}}, + // }}, r) + + // r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, "0"}, + // Count: 2, + // NoAck: false, + // }).Result() + // require.NoError(t, err) + // require.Equal(t, []redis.XStream{{ + // Stream: streamName, + // Messages: []redis.XMessage{{ID: "1-0", Values: map[string]interface{}{"field1": "data1"}}, + // {ID: "2-0", Values: map[string]interface{}{"field2": "data2"}}}, + // }}, r) + + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "3-0", + // Values: []string{"field3", "data3"}, + // }).Err()) + // r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: true, + // }).Result() + // require.NoError(t, err) + // require.Equal(t, []redis.XStream{{ + // Stream: streamName, + // Messages: []redis.XMessage{{ID: "3-0", Values: map[string]interface{}{"field3": "data3"}}}, + // }}, r) + // r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, "0"}, + // Count: 2, + // NoAck: false, + // }).Result() + // require.NoError(t, err) + // require.Equal(t, []redis.XStream{{ + // Stream: streamName, + // Messages: []redis.XMessage{{ID: "1-0", Values: map[string]interface{}{"field1": "data1"}}, + // {ID: "2-0", Values: map[string]interface{}{"field2": "data2"}}}, + // }}, r) + + // c := srv.NewClient() + // defer func() { require.NoError(t, c.Close()) }() + // ch := make(chan []redis.XStream) + // go 
func() { + // ch <- c.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 2, + // Block: 10 * time.Second, + // NoAck: false, + // }).Val() + // }() + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "4-0", + // Values: []string{"field4", "data4"}, + // }).Err()) + // r = <-ch + // require.Equal(t, []redis.XStream{{ + // Stream: streamName, + // Messages: []redis.XMessage{{ID: "4-0", Values: map[string]interface{}{"field4": "data4"}}}, + // }}, r) + + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "5-0", + // Values: []string{"field5", "data5"}, + // }).Err()) + // require.NoError(t, rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: false, + // }).Err()) + // require.NoError(t, rdb.XDel(ctx, streamName, "5-0").Err()) + // r, err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, "5"}, + // Count: 1, + // NoAck: false, + // }).Result() + // require.NoError(t, err) + // require.Equal(t, []redis.XStream{{ + // Stream: streamName, + // Messages: []redis.XMessage{{ID: "5-0", Values: map[string]interface{}(nil)}}, + // }}, r) + // }) + + // t.Run("Check xreadgroup fetches the newest data after create consumer in the command", func(t *testing.T) { + // streamName := "mystream" + // groupName := "mygroup" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"field1", "data1"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + // consumerName := "myconsumer" + // err := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: false, + // }).Err() + // require.NoError(t, err) + // ri, erri := rdb.XInfoGroups(ctx, streamName).Result() + // require.NoError(t, erri) + // require.Equal(t, int64(1), ri[0].Consumers) + // }) + + // t.Run("XACK with different kinds of commands", func(t *testing.T) { + // streamName := "mystream" + // groupName := "mygroup" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // r, err := rdb.XAck(ctx, streamName, groupName, "0-0").Result() + // require.NoError(t, err) + // require.Equal(t, int64(0), r) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"field1", "data1"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + // consumerName := "myconsumer" + // err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: false, + // }).Err() + // require.NoError(t, err) + // r, err = rdb.XAck(ctx, streamName, groupName, "1-0").Result() + // require.NoError(t, err) + // require.Equal(t, int64(1), r) + + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "2-0", + // Values: []string{"field1", "data1"}, + // }).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "3-0", + // Values: []string{"field1", "data1"}, + // }).Err()) + // require.NoError(t, 
rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "4-0", + // Values: []string{"field1", "data1"}, + // }).Err()) + // err = rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 3, + // NoAck: false, + // }).Err() + // require.NoError(t, err) + // r, err = rdb.XAck(ctx, streamName, groupName, "2-0", "3-0", "4-0").Result() + // require.NoError(t, err) + // require.Equal(t, int64(3), r) + // }) + + // t.Run("Simple XCLAIM command tests", func(t *testing.T) { + // streamName := "mystream" + // groupName := "mygroup" + // consumerName := "myconsumer" + // consumer1Name := "myconsumer1" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"field1", "data1"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + // r, err := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: false, + // }).Result() + // require.NoError(t, err) + // require.Equal(t, []redis.XStream{{ + // Stream: streamName, + // Messages: []redis.XMessage{{ID: "1-0", Values: map[string]interface{}{"field1": "data1"}}}, + // }}, r) + + // claimedMessages, err := rdb.XClaim(ctx, &redis.XClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumer1Name, + // MinIdle: 0, + // Messages: []string{"1-0"}, + // }).Result() + // require.NoError(t, err) + // require.Len(t, claimedMessages, 1, "Expected to claim 1 message") + // require.Equal(t, "1-0", claimedMessages[0].ID, "Expected claimed message ID to match") + + // time.Sleep(2000 * time.Millisecond) + // minIdleTime := 1000 * time.Millisecond + // claimedMessages, err = rdb.XClaim(ctx, &redis.XClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumerName, + // MinIdle: minIdleTime, + // Messages: []string{"1-0"}, + // }).Result() + // require.NoError(t, err) + // require.Len(t, claimedMessages, 1, "Expected to claim 1 message if idle time is large enough") + // require.Equal(t, "1-0", claimedMessages[0].ID, "Expected claimed message ID to match") + + // minIdleTime = 60000 * time.Millisecond + // claimedMessages, err = rdb.XClaim(ctx, &redis.XClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumer1Name, + // MinIdle: minIdleTime, + // Messages: []string{"1-0"}, + // }).Result() + + // require.NoError(t, err) + // require.Empty(t, claimedMessages, "Expected no messages to be claimed due to insufficient idle time") + // }) + + // t.Run("XCLAIM with different timing situations and options", func(t *testing.T) { + // streamName := "mystream" + // groupName := "mygroup" + // consumerName := "myconsumer" + // consumer1Name := "myconsumer1" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"field1", "data1"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + // r, err := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumerName, + // Streams: []string{streamName, ">"}, + // Count: 1, + // NoAck: false, + // }).Result() + // require.NoError(t, err) + // require.Equal(t, []redis.XStream{{ + // Stream: streamName, + // Messages: 
[]redis.XMessage{{ID: "1-0", Values: map[string]interface{}{"field1": "data1"}}}, + // }}, r) + + // rawClaimedMessages, err := rdb.Do(ctx, "XCLAIM", streamName, groupName, consumer1Name, "0", "1-0", "IDLE", "5000").Result() + // require.NoError(t, err) + // messages, ok := rawClaimedMessages.([]interface{}) + // require.True(t, ok, "Expected the result to be a slice of interface{}") + // firstMsg, ok := messages[0].([]interface{}) + // require.True(t, ok, "Expected message details to be a slice of interface{}") + // msgID, ok := firstMsg[0].(string) + // require.True(t, ok, "Expected message ID to be a string") + // require.Equal(t, "1-0", msgID, "Expected claimed message ID to match") + + // claimedMessages, err := rdb.XClaim(ctx, &redis.XClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumerName, + // MinIdle: 2000 * time.Millisecond, + // Messages: []string{"1-0"}, + // }).Result() + // require.NoError(t, err) + // require.Len(t, claimedMessages, 1, "Expected to claim 1 message if idle time is large enough") + // require.Equal(t, "1-0", claimedMessages[0].ID, "Expected claimed message ID to match") + + // tenSecondsAgo := time.Now().Add(-10 * time.Second).UnixMilli() + // rawClaimedMessages, err = rdb.Do(ctx, "XCLAIM", streamName, groupName, consumer1Name, "0", "1-0", "TIME", tenSecondsAgo).Result() + // require.NoError(t, err) + // messages, ok = rawClaimedMessages.([]interface{}) + // require.True(t, ok, "Expected the result to be a slice of interface{}") + // firstMsg, ok = messages[0].([]interface{}) + // require.True(t, ok, "Expected message details to be a slice of interface{}") + // msgID, ok = firstMsg[0].(string) + // require.True(t, ok, "Expected message ID to be a string") + // require.Equal(t, "1-0", msgID, "Expected claimed message ID to match") + + // claimedMessages, err = rdb.XClaim(ctx, &redis.XClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumerName, + // MinIdle: 5000 * time.Millisecond, + // Messages: []string{"1-0"}, + // }).Result() + // require.NoError(t, err) + // require.Len(t, claimedMessages, 1, "Expected to claim 1 message if idle time is large enough") + // require.Equal(t, "1-0", claimedMessages[0].ID, "Expected claimed message ID to match") + // }) + + // t.Run("XCLAIM command with different options", func(t *testing.T) { + // streamName := "mystream" + // groupName := "mygroup" + // consumerName := "myconsumer" + // consumer1Name := "myconsumer1" + + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: []string{"field1", "data1"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + + // rawClaimedMessages, err := rdb.Do(ctx, "XCLAIM", streamName, groupName, consumerName, "0", "1-0", "FORCE").Result() + // require.NoError(t, err) + // messages, ok := rawClaimedMessages.([]interface{}) + // require.True(t, ok, "Expected the result to be a slice of interface{}") + // firstMsg, ok := messages[0].([]interface{}) + // require.True(t, ok, "Expected message details to be a slice of interface{}") + // msgID, ok := firstMsg[0].(string) + // require.True(t, ok, "Expected message ID to be a string") + // require.Equal(t, "1-0", msgID, "Expected claimed message ID to match") + + // cmd := rdb.XClaimJustID(ctx, &redis.XClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumer1Name, + // MinIdle: 0, + // Messages: []string{"1-0"}, + 
// }) + + // claimedIDs, err := cmd.Result() + // require.NoError(t, err) + // require.Len(t, claimedIDs, 1, "Expected to claim exactly one message ID") + // require.Equal(t, "1-0", claimedIDs[0], "Expected claimed message ID to match") + // }) + + // t.Run("XAUTOCLAIM can claim PEL items from another consume", func(t *testing.T) { + + // streamName := "mystream" + // groupName := "mygroup" + // var id1 string + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // { + // rsp := rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"a", "1"}, + // }) + // require.NoError(t, rsp.Err()) + // id1 = rsp.Val() + // } + // var id2 string + // { + // rsp := rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"b", "2"}, + // }) + // require.NoError(t, rsp.Err()) + // id2 = rsp.Val() + // } + // var id3 string + // { + // rsp := rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"c", "3"}, + // }) + // require.NoError(t, rsp.Err()) + // id3 = rsp.Val() + // } + // var id4 string + // { + // rsp := rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"d", "4"}, + // }) + // require.NoError(t, rsp.Err()) + // id4 = rsp.Val() + // } + + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + + // consumer1 := "consumer1" + // consumer2 := "consumer2" + // { + // rsp := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumer1, + // Streams: []string{streamName, ">"}, + // Count: 1, + // }) + // require.NoError(t, rsp.Err()) + // require.Len(t, rsp.Val(), 1) + // require.Len(t, rsp.Val()[0].Messages, 1) + // require.Equal(t, id1, rsp.Val()[0].Messages[0].ID) + // require.Len(t, rsp.Val()[0].Messages[0].Values, 1) + // require.Equal(t, "1", rsp.Val()[0].Messages[0].Values["a"]) + // } + + // { + // time.Sleep(200 * time.Millisecond) + // rsp := rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumer2, + // MinIdle: 10 * time.Millisecond, + // Count: 1, + // Start: "-", + // }) + // require.NoError(t, rsp.Err()) + // msgs, start := rsp.Val() + // require.Equal(t, "0-0", start) + // require.Len(t, msgs, 1) + // require.Len(t, msgs[0].Values, 1) + // require.Equal(t, "1", msgs[0].Values["a"]) + // } + + // { + // rsp := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumer1, + // Streams: []string{streamName, ">"}, + // Count: 3, + // }) + // require.NoError(t, rsp.Err()) + + // time.Sleep(time.Millisecond * 200) + // require.NoError(t, rdb.XDel(ctx, streamName, id2).Err()) + // } + + // { + // cmd := rdb.Do(ctx, "XAUTOCLAIM", streamName, groupName, consumer2, 10, "-", "COUNT", 3) + // require.NoError(t, cmd.Err()) + // require.Equal(t, []interface{}{ + // id4, + // []interface{}{ + // []interface{}{ + // id1, + // []interface{}{"a", "1"}, + // }, + // []interface{}{ + // id3, + // []interface{}{"c", "3"}, + // }, + // }, + // []interface{}{ + // id2, + // }, + // }, cmd.Val()) + // } + + // { + // time.Sleep(time.Millisecond * 200) + // require.NoError(t, rdb.XDel(ctx, streamName, id4).Err()) + // rsp := rdb.XAutoClaimJustID(ctx, &redis.XAutoClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumer2, + // MinIdle: 10 * time.Millisecond, + // Start: "-", + // }) + // require.NoError(t, rsp.Err()) + // msgs, start := rsp.Val() + // require.Equal(t, "0-0", start) + // 
require.Len(t, msgs, 2) + // require.Equal(t, id1, msgs[0]) + // require.Equal(t, id3, msgs[1]) + // } + // }) + + // t.Run("XAUTOCLAIM as an iterator", func(t *testing.T) { + // streamName := "mystream" + // groupName := "mygroup" + // var id3, id5 string + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // { + // rsp := rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"a", "1"}, + // }) + // require.NoError(t, rsp.Err()) + // } + // { + // rsp := rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"b", "2"}, + // }) + // require.NoError(t, rsp.Err()) + // } + // { + // rsp := rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"c", "3"}, + // }) + // require.NoError(t, rsp.Err()) + // id3 = rsp.Val() + // } + // { + // rsp := rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"d", "4"}, + // }) + // require.NoError(t, rsp.Err()) + // } + // { + // rsp := rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "*", + // Values: []string{"e", "5"}, + // }) + // require.NoError(t, rsp.Err()) + // id5 = rsp.Val() + // } + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + + // consumer1, consumer2 := "consumer1", "consumer2" + // { + // rsp := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: consumer1, + // Streams: []string{streamName, ">"}, + // Count: 90, + // }) + // require.NoError(t, rsp.Err()) + // time.Sleep(200 * time.Millisecond) + // } + // { + // rsp := rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumer2, + // MinIdle: 10 * time.Millisecond, + // Count: 2, + // Start: "-", + // }) + // require.NoError(t, rsp.Err()) + // msgs, start := rsp.Val() + // require.Equal(t, id3, start) + // require.Len(t, msgs, 2) + // require.Len(t, msgs[0].Values, 1) + // require.Equal(t, "1", msgs[0].Values["a"]) + // } + + // { + // rsp := rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumer2, + // MinIdle: 10 * time.Millisecond, + // Start: id3, + // Count: 2, + // }) + // require.NoError(t, rsp.Err()) + // msgs, start := rsp.Val() + // require.Equal(t, id5, start) + // require.Len(t, msgs, 2) + // require.Len(t, msgs[0].Values, 1) + // require.Equal(t, "3", msgs[0].Values["c"]) + // } + + // { + // rsp := rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ + // Stream: streamName, + // Group: groupName, + // Consumer: consumer2, + // MinIdle: 10 * time.Millisecond, + // Start: id5, + // Count: 1, + // }) + // require.NoError(t, rsp.Err()) + // msgs, start := rsp.Val() + // require.Equal(t, "0-0", start) + // require.Len(t, msgs, 1) + // require.Len(t, msgs[0].Values, 1) + // require.Equal(t, "5", msgs[0].Values["e"]) + // } + // }) + + // t.Run("XAUTOCLAIM with XDEL", func(t *testing.T) { + // streamName := "x" + // groupName := "grp" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: map[string]interface{}{"f": "v"}, + // }).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "2-0", + // Values: map[string]interface{}{"f": "v"}, + // }).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "3-0", + // Values: map[string]interface{}{"f": "v"}, + // 
}).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + // { + // rsp := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: "Alice", + // Streams: []string{streamName, ">"}, + // }) + // require.NoError(t, rsp.Err()) + // require.Len(t, rsp.Val(), 1) + // require.Len(t, rsp.Val()[0].Messages, 3) + // require.Equal(t, "1-0", rsp.Val()[0].Messages[0].ID) + // require.Equal(t, "v", rsp.Val()[0].Messages[0].Values["f"]) + // require.Equal(t, "2-0", rsp.Val()[0].Messages[1].ID) + // require.Equal(t, "v", rsp.Val()[0].Messages[1].Values["f"]) + // require.Equal(t, "3-0", rsp.Val()[0].Messages[2].ID) + // require.Equal(t, "v", rsp.Val()[0].Messages[2].Values["f"]) + // } + // { + // require.NoError(t, rdb.XDel(ctx, streamName, "2-0").Err()) + // cmd := rdb.Do(ctx, "XAUTOCLAIM", streamName, groupName, "Bob", 0, "0-0") + // require.NoError(t, cmd.Err()) + // require.Equal(t, []interface{}{ + // "0-0", + // []interface{}{ + // []interface{}{ + // "1-0", + // []interface{}{"f", "v"}, + // }, + // []interface{}{ + // "3-0", + // []interface{}{"f", "v"}, + // }, + // }, + // []interface{}{ + // "2-0", + // }, + // }, cmd.Val()) + // } + // }) + + // t.Run("XAUTOCLAIM with XDEL and count", func(t *testing.T) { + // streamName := "x" + // groupName := "grp" + // require.NoError(t, rdb.Del(ctx, streamName).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "1-0", + // Values: map[string]interface{}{"f": "v"}, + // }).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "2-0", + // Values: map[string]interface{}{"f": "v"}, + // }).Err()) + // require.NoError(t, rdb.XAdd(ctx, &redis.XAddArgs{ + // Stream: streamName, + // ID: "3-0", + // Values: map[string]interface{}{"f": "v"}, + // }).Err()) + // require.NoError(t, rdb.XGroupCreate(ctx, streamName, groupName, "0").Err()) + // { + // rsp := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{ + // Group: groupName, + // Consumer: "Alice", + // Streams: []string{streamName, ">"}, + // }) + // require.NoError(t, rsp.Err()) + // require.Len(t, rsp.Val(), 1) + // require.Len(t, rsp.Val()[0].Messages, 3) + // require.Equal(t, "1-0", rsp.Val()[0].Messages[0].ID) + // require.Equal(t, "v", rsp.Val()[0].Messages[0].Values["f"]) + // require.Equal(t, "2-0", rsp.Val()[0].Messages[1].ID) + // require.Equal(t, "v", rsp.Val()[0].Messages[1].Values["f"]) + // require.Equal(t, "3-0", rsp.Val()[0].Messages[2].ID) + // require.Equal(t, "v", rsp.Val()[0].Messages[2].Values["f"]) + // } + // { + // require.NoError(t, rdb.XDel(ctx, streamName, "1-0").Err()) + // require.NoError(t, rdb.XDel(ctx, streamName, "2-0").Err()) + // cmd := rdb.Do(ctx, "XAUTOCLAIM", streamName, groupName, "Bob", 0, "0-0", "COUNT", 1) + // require.NoError(t, cmd.Err()) + // require.Equal(t, []interface{}{ + // "2-0", + // []interface{}{}, + // []interface{}{ + // "1-0", + // }, + // }, cmd.Val()) + // } + // { + // cmd := rdb.Do(ctx, "XAUTOCLAIM", streamName, groupName, "Bob", 0, "2-0", "COUNT", 1) + // require.NoError(t, cmd.Err()) + // require.Equal(t, []interface{}{ + // "3-0", + // []interface{}{}, + // []interface{}{ + // "2-0", + // }, + // }, cmd.Val()) + // } + // { + // cmd := rdb.Do(ctx, "XAUTOCLAIM", streamName, groupName, "Bob", 0, "3-0", "COUNT", 1) + // require.NoError(t, cmd.Err()) + // require.Equal(t, []interface{}{ + // "0-0", + // []interface{}{ + // []interface{}{ + // "3-0", + // []interface{}{"f", "v"}, + // }, + // }, + // 
[]interface{}{}, + // }, cmd.Val()) + // } + // // assert_equal [XPENDING x grp - + 10 Alice] {} + // // add xpending to this test case when it is supported + // }) + + // t.Run("XAUTOCLAIM with out of range count", func(t *testing.T) { + // err := rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{ + // Stream: "x", + // Group: "grp", + // Consumer: "Bob", + // MinIdle: 0, + // Start: "3-0", + // Count: 8070450532247928833, + // }).Err() + // require.Error(t, err) + // require.True(t, strings.HasPrefix(err.Error(), "ERR COUNT")) + // }) + + // t.Run("XAUTOCLAIM COUNT must be > 0", func(t *testing.T) { + // cmd := rdb.Do(ctx, "XAUTOCLAIM", "key", "group", "consumer", 1, 1, "COUNT", 0) + // require.Error(t, cmd.Err()) + // require.Equal(t, "ERR COUNT must be > 0", cmd.Err().Error()) + // }) } func parseStreamEntryID(id string) (ts int64, seqNum int64) { diff --git a/tests/gocase/util/server.go b/tests/gocase/util/server.go index a3f4314e342..849b47ad102 100644 --- a/tests/gocase/util/server.go +++ b/tests/gocase/util/server.go @@ -136,7 +136,10 @@ func (s *KvrocksServer) close(keepDir bool) { func (s *KvrocksServer) Restart() { s.close(true) + s.Start() +} +func (s *KvrocksServer) Start() { b := *binPath require.NotEmpty(s.t, b, "please set the binary path by `-binPath`") cmd := exec.Command(b) diff --git a/utils/kvrocks2redis/parser.cc b/utils/kvrocks2redis/parser.cc index 9d4db5ec698..c131f14654f 100644 --- a/utils/kvrocks2redis/parser.cc +++ b/utils/kvrocks2redis/parser.cc @@ -33,7 +33,7 @@ Status Parser::ParseFullDB() { rocksdb::DB *db = storage_->GetDB(); - rocksdb::ColumnFamilyHandle *metadata_cf_handle = storage_->GetCFHandle(engine::kMetadataColumnFamilyName); + rocksdb::ColumnFamilyHandle *metadata_cf_handle = storage_->GetCFHandle(ColumnFamilyID::Metadata); // Due to RSI(Rocksdb Secondary Instance) not supporting "Snapshots based read", we don't need to set the snapshot // parameter. However, until we proactively invoke TryCatchUpWithPrimary, this replica is read-only, which can be // considered as a snapshot. diff --git a/utils/kvrocks2redis/tests/README.md b/utils/kvrocks2redis/tests/README.md index 6bc6048d9b9..7b3b136ea21 100644 --- a/utils/kvrocks2redis/tests/README.md +++ b/utils/kvrocks2redis/tests/README.md @@ -7,7 +7,7 @@ For testing the `kvrocks2redis` utility, manually check generate AOF. 
* Start `kvrocks` and `kvrocks2redis` * [ ] TODO automatic create docker env * Install dependency:: - * pip install git+https://github.com/andymccurdy/redis-py.git@2.10.3 + * pip install redis==4.3.6 * Usage: ```bash diff --git a/utils/kvrocks2redis/tests/check_consistency.py b/utils/kvrocks2redis/tests/check_consistency.py index 3a176cc87ec..72477978120 100644 --- a/utils/kvrocks2redis/tests/check_consistency.py +++ b/utils/kvrocks2redis/tests/check_consistency.py @@ -100,7 +100,7 @@ def _import_and_compare(self, num): self.src_cli.incr(incr_key) hash_key = f'hash_key_{i}' hash_value = {'field1': f'field1_value_{i}', 'field2': f'field2_value_{i}'} - self.src_cli.hmset(hash_key, hash_value) + self.src_cli.hset(hash_key, mapping=hash_value) set_key = f'set_key_{i}' set_value = [f'set_value_{i}_1', f'set_value_{i}_2', f'set_value_{i}_3'] self.src_cli.sadd(set_key, *set_value) @@ -129,4 +129,4 @@ def _import_and_compare(self, num): args = parser.parse_args() redis_comparator = RedisComparator(args.src_host, args.src_port, args.src_password, args.dst_host, args.dst_port, args.dst_password) - redis_comparator.compare_redis_data(args.key_file) \ No newline at end of file + redis_comparator.compare_redis_data(args.key_file) diff --git a/utils/kvrocks2redis/tests/populate-kvrocks.py b/utils/kvrocks2redis/tests/populate-kvrocks.py index 8883a0dfd36..fefae86a0fe 100644 --- a/utils/kvrocks2redis/tests/populate-kvrocks.py +++ b/utils/kvrocks2redis/tests/populate-kvrocks.py @@ -146,8 +146,10 @@ def run_test(client, cases : list): print('******* Some case test fail *******') for cmd in fails: print(cmd) + return False else: print('All case passed.') + return True if __name__ == '__main__': @@ -155,5 +157,10 @@ def run_test(client, cases : list): client = redis.Redis(host=args.host, port=args.port, decode_responses=True, password=args.password) if args.flushdb: client.flushdb() - run_test(client, PopulateCases) - run_test(client, AppendCases) + succ = True + if not run_test(client, PopulateCases): + succ = False + if not run_test(client, AppendCases): + succ = False + if not succ: + raise AssertionError("Test failed. 
See details above.") diff --git a/x.py b/x.py index 464d3987e42..710534a2632 100755 --- a/x.py +++ b/x.py @@ -19,9 +19,10 @@ from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, REMAINDER from glob import glob -from os import makedirs +import os from pathlib import Path import re +import filecmp from subprocess import Popen, PIPE import sys from typing import List, Any, Optional, TextIO, Tuple @@ -92,6 +93,24 @@ def check_version(current: str, required: Tuple[int, int, int], prog_name: Optio return semver +def prepare() -> None: + basedir = Path(__file__).parent.absolute() + + # Install Git hooks + hooks = basedir / "dev" / "hooks" + git_hooks = basedir / ".git" / "hooks" + + git_hooks.mkdir(exist_ok=True) + for hook in hooks.iterdir(): + dst = git_hooks / hook.name + if dst.exists(): + if filecmp.cmp(hook, dst, shallow=False): + print(f"{hook.name} already installed.") + continue + raise RuntimeError(f"{dst} already exists; please remove it first") + else: + dst.symlink_to(hook) + print(f"{hook.name} installed at {dst}.") def build(dir: str, jobs: Optional[int], ghproxy: bool, ninja: bool, unittest: bool, compiler: str, cmake_path: str, D: List[str], skip_build: bool) -> None: @@ -106,7 +125,7 @@ def build(dir: str, jobs: Optional[int], ghproxy: bool, ninja: bool, unittest: b cmake_version = output.read().strip() check_version(cmake_version, CMAKE_REQUIRE_VERSION, "CMake") - makedirs(dir, exist_ok=True) + os.makedirs(dir, exist_ok=True) cmake_options = ["-DCMAKE_BUILD_TYPE=RelWithDebInfo"] if ghproxy: @@ -415,6 +434,13 @@ def test_go(dir: str, cli_path: str, rest: List[str]) -> None: parser_test_go.add_argument('rest', nargs=REMAINDER, help="the rest of arguments to forward to go test") parser_test_go.set_defaults(func=test_go) + parser_prepare = subparsers.add_parser( + 'prepare', + description="Prepare scripts such as git hooks", + help="Prepare scripts such as git hooks" + ) + parser_prepare.set_defaults(func=prepare) + args = parser.parse_args() arg_dict = dict(vars(args))