diff --git a/.coveragerc b/.coveragerc index 33733e8..d1ec1b1 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,3 +4,11 @@ omit = src/coherence/messages_pb2_grpc.py src/coherence/services_pb2.py src/coherence/services_pb2_grpc.py + src/coherence/cache_service_messages_v1_pb2.py + src/coherence/cache_service_messages_v1_pb2_grpc.py + src/coherence/common_messages_v1_pb2.py + src/coherence/common_messages_v1_pb2_grpc.py + src/coherence/proxy_service_messages_v1_pb2.py + src/coherence/proxy_service_messages_v1_pb2_grpc.py + src/coherence/proxy_service_v1_pb2.py + src/coherence/proxy_service_v1_pb2_grpc.py diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index 0584b97..49f9a11 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -17,11 +17,11 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8.x", "3.9.x", "3.10.x", "3.11.x"] - poetry-version: ["1.5.0"] + python-version: ["3.9.x", "3.10.x", "3.11.x", "3.12.x", "3.13.x"] + poetry-version: ["1.8.4"] os: [ubuntu-latest] coherenceVersion: - - 24.03 + - 24.09 - 22.06.10 base-image: - gcr.io/distroless/java17-debian11 @@ -29,7 +29,7 @@ jobs: - ",-jakarta,javax" - ",jakarta,-javax" exclude: - - coherenceVersion: 24.03 + - coherenceVersion: 24.09 profile: ",-jakarta,javax" - coherenceVersion: 22.06.10 profile: ",jakarta,-javax" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ad3017f..3e76243 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ exclude: \w*(_pb2)\w* repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: cef0300fd0fc4d2a87a85fa2093c6b283ea36f4b # frozen: v5.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -16,21 +16,21 @@ repos: - id: check-added-large-files - repo: https://github.com/PyCQA/flake8 - rev: 7.1.1 + rev: e43806be3607110919eff72939fda031776e885a # frozen: 7.1.1 hooks: - id: flake8 - repo: https://github.com/psf/black - rev: 24.8.0 + rev: 1b2427a2b785cc4aac97c19bb4b9a0de063f9547 # frozen: 24.10.0 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: f56614daa94d5cd733d3b7004c5df9caad267b4a # frozen: v1.13.0 hooks: - id: mypy - repo: https://github.com/PyCQA/isort - rev: 5.13.2 + rev: c235f5e450b4b84e58d114ed4c589cbf454175a3 # frozen: 5.13.2 hooks: - id: isort diff --git a/Makefile b/Makefile index 687d8c2..a8f50c7 100644 --- a/Makefile +++ b/Makefile @@ -38,8 +38,8 @@ COHERENCE_WKA1 ?= server1 COHERENCE_WKA2 ?= server1 CLUSTER_PORT ?= 7574 # Profiles to include for building -PROFILES ?= ",-jakarta,javax" -COHERENCE_BASE_IMAGE ?= gcr.io/distroless/java17-debian11 +PROFILES ?= ",jakarta,-javax" +COHERENCE_BASE_IMAGE ?= gcr.io/distroless/java17-debian12 # ---------------------------------------------------------------------------------------------------------------------- # Set the location of various build tools @@ -66,6 +66,27 @@ CURRDIR := $(shell pwd) COMPOSE:=$(shell type -p docker-compose || echo docker compose) $(info COMPOSE = $(COMPOSE)) +# ---------------------------------------------------------------------------------------------------------------------- +# List of unit tests +# ---------------------------------------------------------------------------------------------------------------------- +UNIT_TESTS := tests/unit/test_cache_options.py \ + tests/unit/test_local_cache.py \ + tests/unit/test_environment.py \ + tests/unit/test_serialization.py \ + tests/unit/test_extractors.py + +# 
---------------------------------------------------------------------------------------------------------------------- +# List of E2E tests +# ---------------------------------------------------------------------------------------------------------------------- +E2E_TESTS := tests/e2e/test_session.py \ + tests/e2e/test_client.py \ + tests/e2e/test_events.py \ + tests/e2e/test_filters.py \ + tests/e2e/test_processors.py \ + tests/e2e/test_aggregators.py \ + tests/e2e/test_near_caching.py \ +# tests/e2e/test_ai.py \ + # ---------------------------------------------------------------------------------------------------------------------- # Clean-up all of the build artifacts # ---------------------------------------------------------------------------------------------------------------------- @@ -116,21 +137,53 @@ generate-proto: ## Generate Proto Files sed -e 's/import messages_pb2 as messages__pb2/import coherence.messages_pb2 as messages__pb2/' \ < $(CURRDIR)/src/coherence/services_pb2_grpc.py > $(CURRDIR)/src/coherence/services_pb2_grpc.py.out mv $(CURRDIR)/src/coherence/services_pb2_grpc.py.out $(CURRDIR)/src/coherence/services_pb2_grpc.py + curl -o $(PROTO_DIR)/proxy_service_v1.proto \ + https://raw.githubusercontent.com/oracle/coherence/$(COHERENCE_VERSION)/prj/coherence-grpc/src/main/proto/proxy_service_v1.proto + curl -o $(PROTO_DIR)/proxy_service_messages_v1.proto \ + https://raw.githubusercontent.com/oracle/coherence/$(COHERENCE_VERSION)/prj/coherence-grpc/src/main/proto/proxy_service_messages_v1.proto + curl -o $(PROTO_DIR)/common_messages_v1.proto \ + https://raw.githubusercontent.com/oracle/coherence/$(COHERENCE_VERSION)/prj/coherence-grpc/src/main/proto/common_messages_v1.proto + curl -o $(PROTO_DIR)/cache_service_messages_v1.proto \ + https://raw.githubusercontent.com/oracle/coherence/$(COHERENCE_VERSION)/prj/coherence-grpc/src/main/proto/cache_service_messages_v1.proto + python -m grpc_tools.protoc --proto_path=$(CURRDIR)/etc/proto --pyi_out=$(CURRDIR)/src/coherence --python_out=$(CURRDIR)/src/coherence \ + --grpc_python_out=$(CURRDIR)/src/coherence \ + $(CURRDIR)/etc/proto/proxy_service_v1.proto \ + $(CURRDIR)/etc/proto/proxy_service_messages_v1.proto \ + $(CURRDIR)/etc/proto/common_messages_v1.proto \ + $(CURRDIR)/etc/proto/cache_service_messages_v1.proto + sed -e 's/import proxy_service_messages_v1_pb2 as proxy__service__messages__v1__pb2/import coherence.proxy_service_messages_v1_pb2 as proxy__service__messages__v1__pb2/' \ + < $(CURRDIR)/src/coherence/proxy_service_v1_pb2.py > $(CURRDIR)/src/coherence/proxy_service_v1_pb2.py.out + mv $(CURRDIR)/src/coherence/proxy_service_v1_pb2.py.out $(CURRDIR)/src/coherence/proxy_service_v1_pb2.py + sed -e 's/import common_messages_v1_pb2 as common__messages__v1__pb2/import coherence.common_messages_v1_pb2 as common__messages__v1__pb2/' \ + < $(CURRDIR)/src/coherence/proxy_service_messages_v1_pb2.py > $(CURRDIR)/src/coherence/proxy_service_messages_v1_pb2.py.out + mv $(CURRDIR)/src/coherence/proxy_service_messages_v1_pb2.py.out $(CURRDIR)/src/coherence/proxy_service_messages_v1_pb2.py + sed -e 's/import proxy_service_messages_v1_pb2 as proxy__service__messages__v1__pb2/import coherence.proxy_service_messages_v1_pb2 as proxy__service__messages__v1__pb2/' \ + < $(CURRDIR)/src/coherence/proxy_service_v1_pb2_grpc.py > $(CURRDIR)/src/coherence/proxy_service_v1_pb2_grpc.py.out + mv $(CURRDIR)/src/coherence/proxy_service_v1_pb2_grpc.py.out $(CURRDIR)/src/coherence/proxy_service_v1_pb2_grpc.py + sed -e 's/import common_messages_v1_pb2 as 
common__messages__v1__pb2/import coherence.common_messages_v1_pb2 as common__messages__v1__pb2/' \ + < $(CURRDIR)/src/coherence/cache_service_messages_v1_pb2.py > $(CURRDIR)/src/coherence/cache_service_messages_v1_pb2.py.out + mv $(CURRDIR)/src/coherence/cache_service_messages_v1_pb2.py.out $(CURRDIR)/src/coherence/cache_service_messages_v1_pb2.py # ---------------------------------------------------------------------------------------------------------------------- # Run tests with code coverage # ---------------------------------------------------------------------------------------------------------------------- .PHONY: test test: ## - pytest -W error --cov src/coherence --cov-report=term --cov-report=html \ - tests/test_serialization.py \ - tests/test_extractors.py \ - tests/test_session.py \ - tests/test_client.py \ - tests/test_events.py \ - tests/test_filters.py \ - tests/test_processors.py \ - tests/test_aggregators.py \ + pytest -W error --cov src/coherence --cov-report=term --cov-report=html $(UNIT_TESTS) $(E2E_TESTS) + +# ---------------------------------------------------------------------------------------------------------------------- +# Run unit tests with code coverage +# ---------------------------------------------------------------------------------------------------------------------- +.PHONY: test-unit +test-unit: ## + pytest -W error --cov src/coherence --cov-report=term --cov-report=html $(UNIT_TESTS) + +# ---------------------------------------------------------------------------------------------------------------------- +# Run e2e tests with code coverage +# ---------------------------------------------------------------------------------------------------------------------- +.PHONY: test-e2e +test-e2e: ## + pytest -W error --cov src/coherence --cov-report=term --cov-report=html $(E2E_TESTS) # ---------------------------------------------------------------------------------------------------------------------- # Run standards validation across project diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 5a38842..aa44091 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -11,6 +11,18 @@ Aggregators .. autoclass:: coherence.Aggregators :members: +CacheOptions +------------ +.. autoclass:: coherence.CacheOptions + :members: + + .. automethod:: __init__ + +CacheStats +------------ +.. autoclass:: coherence.CacheStats + :members: + Comparator ---------- .. autoclass:: coherence.Comparator @@ -58,6 +70,13 @@ NamedMap :show-inheritance: :members: +NearCacheOptions +---------------- +.. autoclass:: coherence.NearCacheOptions + :members: + + .. automethod:: __init__ + Options ------- .. autoclass:: coherence.Options diff --git a/etc/proto/cache_service_messages_v1.proto b/etc/proto/cache_service_messages_v1.proto new file mode 100644 index 0000000..cef054c --- /dev/null +++ b/etc/proto/cache_service_messages_v1.proto @@ -0,0 +1,404 @@ +/* + * Copyright (c) 2020, 2024, Oracle and/or its affiliates. + * + * Licensed under the Universal Permissive License v 1.0 as shown at + * https://oss.oracle.com/licenses/upl. + */ + +// ----------------------------------------------------------------- +// Messages used by the Coherence gRPC NamedCache Service. +// +// NOTE: If you add a new request message to this message the current +// protocol version in com.oracle.coherence.grpc.NamedCacheProtocol must +// be increased. This only needs to be done once for any given Coherence +// release. 
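A note on the sed/mv steps in the generate-proto target above: grpc_tools.protoc emits sibling imports such as "import common_messages_v1_pb2 as common__messages__v1__pb2", which only resolve when the generated files sit on sys.path as top-level modules. Rewriting them to package-qualified imports lets the stubs be used straight from the coherence package. A minimal sketch, with module names taken from the .coveragerc additions above:

    # After the rewrite, the generated modules resolve each other through the package.
    from coherence import cache_service_messages_v1_pb2 as cache_v1
    from coherence import proxy_service_v1_pb2_grpc as proxy_v1_grpc  # internally imports coherence.proxy_service_messages_v1_pb2

    print(cache_v1.DESCRIPTOR.name, proxy_v1_grpc.__name__)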
+// ----------------------------------------------------------------- + +syntax = "proto3"; + +package coherence.cache.v1; + +import "common_messages_v1.proto"; +import "google/protobuf/any.proto"; +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; + +option java_multiple_files = true; +option java_package = "com.oracle.coherence.grpc.messages.cache.v1"; + +// An enum representing the types of request for a Named Cache Service proxy +// +// NOTE: The index numbers for the enum elements MUST NOT BE CHANGED as +// that would break backwards compatibility. Only new index numbers can +// be added. +// +enum NamedCacheRequestType { + // An unknown message. + // This request type is not used, it is here as enums must have a zero value, + // but we need to know the difference between a zero value and the field being + // incorrectly set. + Unknown = 0; + // Called to ensure a cache. + // Must be the first message called prior to any other cache requests. + // The message field must be an EnsureCacheRequest. + // The response will contain the Cache Id and an empty response field. + EnsureCache = 1; + // Execute an aggregator on the cache and return the result. + // The message field must contain an ExecuteRequest where the agent field + // is the serialized aggregator. + // The response will be a BytesValue containing the aggregator result. + Aggregate = 2; + // Clear the specified cache. + // The message field should not be set. + // The response will just be a Complete message corresponding to the request id. + Clear = 3; + // Determine whether the specified cache is contains a specified key mapped + // to a specified value. + // The message field must contain a BinaryKeyAndValue that contains the + // serialized key and value. + // The response will contain the Cache Id and a BoolValue in the response field. + ContainsEntry = 4; + // Determine whether the specified cache is contains a specified key. + // The message field must contain a BytesValue that contains the + // serialized key. + // The response will contain the Cache Id and a BoolValue in the response field. + ContainsKey = 5; + // Determine whether the specified cache is contains a specified value. + // The message field must contain a BytesValue that contains the + // serialized value. + // The response will contain the Cache Id and a BoolValue in the response field. + ContainsValue = 6; + // Destroy the specified cache. + // The message field should not be set. + // The response will just be a Complete message corresponding to the request id. + Destroy = 7; + // Determine whether the specified cache is empty. + // The message field should not be set. + // The response will contain the Cache Id and a BoolValue in the response field. + IsEmpty = 8; + // Determine whether the specified cache is ready. + // The message field should not be set. + // The response will contain the Cache Id and a BoolValue in the response field. + IsReady = 9; + // Determine whether the specified cache is contains a specified key. + // The message field must contain a BytesValue that contains the + // serialized key of the entry to get. + // The response will contain the Cache Id and an OptionalValue that will be empty + // if the cache did not contain an entry for the key or will contain the value from + // the cache. + Get = 10; + // Return the values from the specified cache that are mapped to a collection of keys. 
+ // The message field must contain a CollectionOfBytesValues that contains the + // serialized keys of the entries to get. + // There will be multiple responses containing the Cache Id and an BinaryKeyAndValue for + // each requested key. + GetAll = 11; + // Add or remove an index. + // The message field must contain an IndexRequest. + // The response will just be a Complete message corresponding to the request id. + Index = 12; + // Execute an entry processor against a single entry in the cache and return the result. + // The message field must contain a ExecuteRequest where the agent field + // is the serialized entry processor. + // The response will be a stream of BinaryKeyAndValue values followed by a Complete + // message to signal the end of the response stream. + Invoke = 13; + // Add or remove a MapListener. + // The message field must contain a MapListenerRequest. + // The response will just be a Complete message corresponding to the request id. + MapListener = 14; + // Retrieve a page of entries from the cache + // The message field must contain a BytesValue that is the opaque cookie returned + // by a previous page request, or an empty (or not set) BytesValue to retrieve the + // first page. + // The response will be a stream of results. The first response will be a BytesValue + // which is the new cookie, followed by a stream of BinaryKeyAndValue messages for + // each cache entry in the page, finally followed by a Complete message to signal + // the end of the response stream. + PageOfEntries = 15; + // Retrieve a page of keys from the cache + // The message field must contain a BytesValue that is the opaque cookie returned + // by a previous page request, or an empty (or not set) BytesValue to retrieve the + // first page. + // The response will be a stream of results. The first response will be a BytesValue + // which is the new cookie, followed by a stream of BytesValue messages for + // each cache key in the page, finally followed by a Complete message to signal + // the end of the response stream. + PageOfKeys = 16; + // Add a key and value to the cache, with an optional TTL. + // The message field must contain a PutRequest that contains the + // serialized key, serialized value and optional TTL. + // The response will contain the Cache Id and an BytesValue that will be empty + // if the cache did not contain an entry for the key or will contain the previous + // value from the cache that was mapped to the key. + Put = 17; + // Add a set of keys and values to the cache, with an optional TTL. + // The message field must contain a PutAllRequest that contains the + // serialized keys and values and optional TTL. + // The response will just be a Complete message corresponding to the request id. + PutAll = 18; + // Add a key and value to the cache if a value is not already mapped to the key. + // The message field should contain a PutRequest that contains the + // serialized key, serialized value. + // The response will contain the Cache Id and an BytesValue that will contain the + // serialized previous value mapped to the key. + PutIfAbsent = 19; + // Execute a query for cache entries + // The message field must contain a QueryRequest + // The response will be a stream of BinaryKeyAndValue representing each cache entry + // in the results of the query, finally followed by a Complete message to signal + // the end of the response stream. 
+ QueryEntries = 20; + // Execute a query for cache keys + // The message field must contain a QueryRequest + // The response will be a stream of BytesValue representing each cache key in + // the results of the query, finally followed by a Complete message to signal + // the end of the response stream. + QueryKeys = 21; + // Execute a query for cache values + // The message field must contain a QueryRequest + // The response will be a stream of BytesValue representing each cache value in + // the results of the query, finally followed by a Complete message to signal + // the end of the response stream. + QueryValues = 22; + // Remove an entry from the cache. + // The message field must contain a BytesValue that contains the + // serialized key of the entry to remove. + // The response will contain the Cache Id and an BytesValue that will be empty + // if the cache did not contain an entry for the key or will contain the value from + // the cache. + Remove = 23; + // Remove an entry from the cache if the specified key maps to the specified value. + // The message field must contain a BinaryKeyAndValue that contains the + // serialized key and expected value of the entry to remove. + // The response will contain the Cache Id and an BoolValue that will true if the + // entry was removed. + RemoveMapping = 24; + // Replace an entry in the cache only if the key is currently mapped to a value. + // The message field must contain a BinaryKeyAndValue that contains the + // serialized key of the entry to replace and the serialized value to map to the + // key. + // The response will contain the Cache Id and an BytesValue that will contain the + // serialized previous value mapped to the key. + Replace = 25; + // Replace an entry in the cache only if the key is currently mapped to a + // specified value. + // The message field must contain a ReplaceMappingRequest that contains the + // serialized key of the entry to replace, the serialized expected value and the + // serialized new value to map to the key. + // The response will contain the Cache Id and an BoolValue that will be true if + // the cache mapping was updated. + ReplaceMapping = 26; + // Obtain the size of the specified cache. + // The message field should not be set. + // The response will contain the Cache Id and an Int32Value in the response field. + Size = 27; + // Truncate the specified cache. + // The message field should not be set. + // The response will just be a Complete message corresponding to the request id. + Truncate = 28; +} + +// A request to perform an operation on a remote NamedCache. +message NamedCacheRequest { + // The type of the request + NamedCacheRequestType type = 1; + // The cache identifier for the request. + // The identifier must be the same value returned by the initial ensure cache request. + // This is optional only for EnsureCache as this cannot have a cache identifier + optional int32 cacheId = 2; + // The actual request message, this is optional because some messages do not require + // a message body, for example cache.size() + // The actual request message should be packed inside an Any message and set in this field. + // The proxy will know which message type to expect here based on the "type" field's value. + optional google.protobuf.Any message = 3; +} + +// An enum representing different types of response. +// +// NOTE: The index numbers for the enum elements MUST NOT BE CHANGED as +// that would break backwards compatibility. Only new index numbers can +// be added. 
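The envelope pattern described above can be exercised directly from the generated Python stubs. A rough sketch, assuming the package-qualified module name produced by the Makefile changes earlier in this diff (EnsureCacheRequest is defined just below; the client classes in this PR are expected to hide this plumbing behind Session/NamedCache):

    from coherence import cache_service_messages_v1_pb2 as cache_v1

    ensure = cache_v1.EnsureCacheRequest(cache="orders")
    request = cache_v1.NamedCacheRequest(type=cache_v1.NamedCacheRequestType.EnsureCache)
    request.message.Pack(ensure)  # the proxy dispatches on "type" and unpacks "message"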
+enum ResponseType { + // The response is a message. + Message = 0; + // The response is a map event. + MapEvent = 1; + // The response is destroy event + Destroyed = 2; + // The response is truncated event + Truncated = 3; +} + +// A response message from a Named Cache Service proxy. +// +// NOTE: If you add a new request message to this message the protocol +// version in com.oracle.coherence.grpc.NamedCacheProtocol must be +// increased. This only needs to be done once for any given Coherence +// release. +message NamedCacheResponse { + // The cache identifier for the request + int32 cacheId = 1; + // An enum representing different response types. + // The type of the request. + ResponseType type = 2; + // The response can contain one of a number of response types + // The sender of the corresponding request should know which + // response type it expects + optional google.protobuf.Any message = 3; +} + +// A request to ensure a specific cache. +message EnsureCacheRequest { + // The name of the cache. + string cache = 1; +} + +// A request to associate the specified value with the +// specified key in a cache with an optional TTL. +message PutRequest { + // The cache entry key. + bytes key = 1; + // The value of the entry. + bytes value = 2; + // The time to live in millis. + optional int64 ttl = 3; +} + +// A request to associate the specified value with the +// specified key in a cache with an optional TTL. +message PutAllRequest { + // The cache entries to put. + repeated coherence.common.v1.BinaryKeyAndValue entries = 1; + // The time to live in millis. + optional int64 ttl = 2; +} + +// A request to replace the mapping for the specified key +// with the specified newValue in a cache only if the specified +// key is associated with the specified previousValue in +// that cache. +message ReplaceMappingRequest { + // The key of the entry to be replaced. + bytes key = 1; + // The previous value that should exist in the cache. + bytes previousValue = 2; + // The new value to put. + bytes newValue = 3; +} + +// A request to add or remove an index to a cache +message IndexRequest { + // True to add an index, false to remove an index + bool add = 1; + // The serialized ValueExtractor to use to create or remove the index. + bytes extractor = 2; + // A flag indicating whether to sort the index. + // This is not required for index removal. + optional bool sorted = 3; + // The optional comparator to use to sort the index. + // This is not required for index removal. + optional bytes comparator = 4; +} + +// A message containing either a single serialized key, or a +// collection of serialized keys, or a serialized Filter. +message KeysOrFilter { + oneof keyOrFilter { + // A single serialized key + bytes key = 1; + // The collection of serialized keys + coherence.common.v1.CollectionOfBytesValues keys = 2; + // The serialized filter + bytes filter = 3; + } +} + +// A message containing either a single serialized key, +// or a serialized Filter. +message KeyOrFilter { + oneof keyOrFilter { + // A single serialized key + bytes key = 1; + // The serialized filter + bytes filter = 2; + } +} + +// A request to aggregate entries in a cache. +message ExecuteRequest { + // The serialized executable agent (for example an entry processor or aggregator). + bytes agent = 1; + // The optional collection of keys or filter to use to execute the agent. + optional KeysOrFilter keys = 3; +} + +// A request cache query request. +message QueryRequest { + // The serialized Filter to identify the data to return. 
+ optional bytes filter = 1; + // The optional comparator to use to sort the returned data. + optional bytes comparator = 2; +} + +// A message to subscribe to or unsubscribe from MapEvents for a cache. +message MapListenerRequest { + // A flag indicating whether to subscribe to (true) or unsubscribe from (false) events. + bool subscribe = 1; + // The optional serialized key, or serialized Filter, to identify the entry + // (or entries) to subscribe to. + // If neither key nor filter are set then an Always filter will be used. + optional KeyOrFilter keyOrFilter = 2; + // A unique filter identifier used if the keyOrFilter contains a Filter. + int64 filterId = 3; + // A flag set to true to indicate that the MapEvent objects do + // not have to include the OldValue and NewValue property values + // in order to allow optimizations + bool lite = 4; + // Whether the listener is synchronous + bool synchronous = 5; + // A flag set to true to indicate that the listener is a priming listener. + // A priming listener can only be used when the keyOrFilter field contains + // a single key, or an InKeySetFilter. + bool priming = 6; + // An optional serialized MapTrigger. + bytes trigger = 7; +} + +// A response containing a MapEvent for a MapListener +message MapEventMessage { + // The type of the event + int32 id = 1; + // The key of the entry + bytes key = 2; + // The new value of the entry + bytes newValue = 3; + // The old value of the entry + bytes oldValue = 4; + // An enum of TransformationState values to describes how a CacheEvent has been or should be transformed. + enum TransformationState { + // Value used to indicate that an event is non-transformable and should + // not be passed to any transformer-based listeners. + NON_TRANSFORMABLE = 0; + // Value used to indicate that an event is transformable and could be + // passed to transformer-based listeners. + TRANSFORMABLE = 1; + // Value used to indicate that an event has been transformed, and should + // only be passed to transformer-based listeners. + TRANSFORMED = 2; + } + // TransformationState describes how a CacheEvent has been or should be transformed. + TransformationState transformationState = 5; + // The Filter identifiers applicable to the event. + repeated int64 filterIds = 6; + // A flag indicating whether the event is a synthetic event. + bool synthetic = 7; + // A flag indicating whether the event is a priming event. + bool priming = 8; + // A flag indicating whether this is an expiry event. + bool expired = 9; + // true iff this event is caused by a synthetic version update sent + // by the server to notify clients of the current version. + bool versionUpdate = 10; +} diff --git a/etc/proto/common_messages_v1.proto b/etc/proto/common_messages_v1.proto new file mode 100644 index 0000000..37bdfa1 --- /dev/null +++ b/etc/proto/common_messages_v1.proto @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2020, 2024, Oracle and/or its affiliates. + * + * Licensed under the Universal Permissive License v 1.0 as shown at + * https://oss.oracle.com/licenses/upl. + */ + +// ----------------------------------------------------------------- +// Common messages used by various Coherence services. 
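As an illustration of the v1 MapListenerRequest defined above: a subscription that sets neither key nor filter falls back to an Always filter, per the field comment. A small, hypothetical sketch using the generated module name from this diff:

    from coherence import cache_service_messages_v1_pb2 as cache_v1

    # Subscribe to all events on the cache; keyOrFilter is deliberately left unset,
    # so the proxy registers an Always filter as described above.
    listen = cache_v1.MapListenerRequest(subscribe=True, lite=True, synchronous=False)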
+// ----------------------------------------------------------------- + +syntax = "proto3"; + +package coherence.common.v1; + +import "google/protobuf/any.proto"; + +option java_multiple_files = true; +option java_package = "com.oracle.coherence.grpc.messages.common.v1"; + +// An error message +message ErrorMessage { + // The text of the error message + string message = 1; + // An optional Exception serialized using the client's serializer + optional bytes error = 2; +} + +// A message to indicate completion of a request response. +message Complete { +} + +// A heart beat message. +message HeartbeatMessage { + // The UUID of the client + optional bytes uuid = 1; + // True to send a heartbeat response + bool ack = 2; +} + +// An optional value. +message OptionalValue { + // A flag indicating whether the value is present. + bool present = 1; + // The serialized value. + bytes value = 2; +} + +// A message that contains a collection of serialized binary values. +message CollectionOfBytesValues { + // The serialized values + repeated bytes values = 1; +} + +// A message containing a serialized key and value. +message BinaryKeyAndValue { + // The serialized binary key. + bytes key = 1; + // The serialized binary value. + bytes value = 2; +} diff --git a/etc/proto/messages.proto b/etc/proto/messages.proto new file mode 100644 index 0000000..7f0881f --- /dev/null +++ b/etc/proto/messages.proto @@ -0,0 +1,506 @@ +/* + * Copyright (c) 2020, 2023 Oracle and/or its affiliates. + * + * Licensed under the Universal Permissive License v 1.0 as shown at + * https://oss.oracle.com/licenses/upl. + */ + +// Authors: +// Mahesh Kannan +// Jonathan Knight + +// NamedCacheService message types +// + +syntax = "proto3"; + +package coherence; + +option java_multiple_files = true; +option java_package = "com.oracle.coherence.grpc"; + +// A request to clear all the entries in the cache. +message ClearRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; +} + +// A request to determine whether an entry exists in a cache +// with a specific key and value. +message ContainsEntryRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The serialization format. + bytes key = 4; + // The value of the entry to verify. + bytes value = 5; +} + +// A request to determine whether an entry exists in a cache +// for the specified key. +message ContainsKeyRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The key of the entry to verify. + bytes key = 4; +} + +// A request to determine whether an entry exists in a cache +// with the specified value. +message ContainsValueRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The value of the entry to verify. + bytes value = 4; +} + +// A request to destroy a cache. +message DestroyRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; +} + +// A request to determine whether a cache is empty or not. +message IsEmptyRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. 
+ string cache = 2; +} + +// A request to determine the number of entries in a cache. +message SizeRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; +} + +// A request to obtain the value to which a cache maps the +// specified key. +message GetRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The key of the entry to retrieve. + bytes key = 4; +} + +// A request to obtain the values that map to the specified keys +message GetAllRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The key of the entry to retrieve. + repeated bytes key = 4; +} + +// A request to associate the specified value with the +// specified key in a cache. +message PutRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The cache entry key. + bytes key = 4; + // The value of the entry. + bytes value = 5; + // The time to live in millis. + int64 ttl = 6; +} + +// A request to associate the specified value with the +// specified key in a cache. +message PutAllRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The cache entries to put. + repeated Entry entry = 4; +} + +// A request to associate the specified value with the +// specified key in a cache only if the specified key +// is not associated with any value (including null). +message PutIfAbsentRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The cache entry key. + bytes key = 4; + // The value to be put. + bytes value = 5; + // The time to live in millis. + int64 ttl = 6; +} + +// A request to remove the mapping for a key from a cache +// if it is present. +message RemoveRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The key of the entry to be removed. + bytes key = 4; +} + +// A request to remove the mapping for a key from a cache +// only if the specified key is associated with the specified +// value in that cache. +message RemoveMappingRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The key of the entry to be removed. + bytes key = 4; + // The value of the entry to verify. + bytes value = 5; +} + +// A request to replace the mapping for the specified key +// with the specified value in a cache only if the specified +// key is associated with some value in that cache. +message ReplaceRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The key of the entry to be replaced. + bytes key = 4; + // The value of the entry to be replaced. 
+ bytes value = 5; +} + +// A request to replace the mapping for the specified key +// with the specified newValue in a cache only if the specified +// key is associated with the specified previousValue in +// that cache. +message ReplaceMappingRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The key of the entry to be replaced. + bytes key = 4; + // The previous value that should exist in the cache. + bytes previousValue = 5; + // The new value to put. + bytes newValue = 6; +} + +// A request for a page of data from a cache. +// This request is used for implementing methods such as NamedCache.keySet(), +// NamedCache.entrySet() and NamedCache.values() where it would be impractical +// to return the whole data set in one response. +message PageRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format + string format = 3; + // An opaque cookie to track the requested page. + bytes cookie = 4; +} + +// A cache entry key/value pair. +message EntryResult { + // The cache entry key. + bytes key = 1; + // The cache entry value. + bytes value = 2; + // An opaque cookie to track the requested page. + bytes cookie = 3; +} + +// A key value pair. +message Entry { + // The cache entry key. + bytes key = 1; + // The value of the entry. + bytes value = 2; +} + +// A request to truncate a cache. +message TruncateRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; +} + +// A request to add an index to a cache +message AddIndexRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The serialized ValueExtractor to use to create the index. + bytes extractor = 4; + // A flag indicating whether to sort the index. + bool sorted = 5; + // The optional comparator to use to sort the index. + bytes comparator = 6; +} + +// A request to remove an index from a cache +message RemoveIndexRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The serialized ValueExtractor to use to create the index. + bytes extractor = 4; +} + +// A request to aggreagte entries in a cache. +message AggregateRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The serialized EntryAggregator to aggregate. + bytes aggregator = 4; + // The optional set of serialized keys of the entries to aggregate. + repeated bytes keys = 5; + // The optional serialized Filter to identify the entries to aggregate. + bytes filter = 6; +} + +// A request to invoke an EntryProcessor against a single entry. +message InvokeRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The serialized EntryProcessor to invoke. + bytes processor = 4; + // The serialized key of the entry to process. + bytes key = 5; +} + +// A request to invoke an entry processor against a number of entries. +message InvokeAllRequest { + // The scope name to use to obtain the cache. 
+ string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The serialized EntryProcessor to invoke. + bytes processor = 4; + // The optional set of serialized keys of the entries to process. + repeated bytes keys = 5; + // The optional serialized Filter to identify the entries to process. + bytes filter = 6; +} + +// A request to get a set of entries from a cache. +message EntrySetRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The serialized Filter to identify the entries to return. + bytes filter = 4; + // The optional comparator to use to sort the returned entries. + bytes comparator = 5; +} + +// A request to get a set of keys from a cache. +message KeySetRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The serialized Filter to identify the keys to return. + bytes filter = 4; +} + +// A request to get a collection of values from a cache. +message ValuesRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // The serialized Filter to identify the values to return. + bytes filter = 4; + // The optional comparator to use to sort the returned values. + bytes comparator = 5; +} + +// An optional value. +message OptionalValue { + // A flag indicating whether the value is present. + bool present = 1; + // The serialized value. + bytes value = 2; +} + +// A message to subscribe to or unsubscribe from MapEvents for a cache. +message MapListenerRequest { + // The scope name to use to obtain the cache. + string scope = 1; + // The name of the cache. + string cache = 2; + // The serialization format. + string format = 3; + // A unique identifier for the request so that the client + // can match a request to a response + string uid = 4; + // An enum representing the request type + enum RequestType { + // The request to initialise the channel. + INIT = 0; + // The request is for a key listener. + KEY = 1; + // The request is for a Filter listener. + FILTER = 2; + } + // The type of the request. + RequestType type = 5; + // The serialized Filter to identify the entries to subscribe to. + bytes filter = 6; + // The serialized key to identify the entry to subscribe to. + bytes key = 7; + // A flag set to true to indicate that the MapEvent objects do + // not have to include the OldValue and NewValue property values + // in order to allow optimizations + bool lite = 8; + // A flag indicating whether to subscribe to (true) or unsubscribe from (false) events. + bool subscribe = 9; + // A flag set to true to indicate that the listener is a priming listener. + bool priming = 10; + // An optional serialized MapTrigger. + bytes trigger = 11; + // A unique filter identifier. + int64 filterId = 12; +} + +// A response to indicate that a MapListener was subscribed to a cache. +message MapListenerResponse { + // A response can be one of either a subscribed response or an event response. 
+ oneof response_type { + MapListenerSubscribedResponse subscribed = 1; + MapListenerUnsubscribedResponse unsubscribed = 2; + MapEventResponse event = 3; + MapListenerErrorResponse error = 4; + CacheDestroyedResponse destroyed = 5; + CacheTruncatedResponse truncated = 6; + } +} + +// A response to indicate that a MapListener was subscribed to a cache. +message MapListenerSubscribedResponse { + string uid = 1; +} + +// A response to indicate that a MapListener was unsubscribed from a cache. +message MapListenerUnsubscribedResponse { + string uid = 1; +} + +// A response to indicate that a cache was destroyed. +message CacheDestroyedResponse { + string cache = 1; +} + +// A response to indicate that a cache was truncated. +message CacheTruncatedResponse { + string cache = 1; +} + +// A response to indicate that an error occurred processing a MapListener request. +message MapListenerErrorResponse { + string uid = 1; + string message = 2; + int32 code = 3; + repeated string stack = 4; +} + +// A response containing a MapEvent for a MapListener +message MapEventResponse { + // The type of the event + int32 id = 1; + // The key of the entry + bytes key = 2; + // The new value of the entry + bytes newValue = 3; + // The old value of the entry + bytes oldValue = 4; + // An enum of TransformationState values to describes how a CacheEvent has been or should be transformed. + enum TransformationState { + // Value used to indicate that an event is non-transformable and should + // not be passed to any transformer-based listeners. + NON_TRANSFORMABLE = 0; + // Value used to indicate that an event is transformable and could be + // passed to transformer-based listeners. + TRANSFORMABLE = 1; + // Value used to indicate that an event has been transformed, and should + // only be passed to transformer-based listeners. + TRANSFORMED = 2; + } + // TransformationState describes how a CacheEvent has been or should be transformed. + TransformationState transformationState = 5; + // The Filter identifiers applicable to the event. + repeated int64 filterIds = 6; + // A flag indicating whether the event is a synthetic event. + bool synthetic = 7; + // A flag indicating whether the event is a priming event. + bool priming = 8; +} diff --git a/etc/proto/proxy_service_messages_v1.proto b/etc/proto/proxy_service_messages_v1.proto new file mode 100644 index 0000000..c8cde72 --- /dev/null +++ b/etc/proto/proxy_service_messages_v1.proto @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2020, 2024, Oracle and/or its affiliates. + * + * Licensed under the Universal Permissive License v 1.0 as shown at + * https://oss.oracle.com/licenses/upl. + */ + +syntax = "proto3"; + +package coherence.proxy.v1; + +import "common_messages_v1.proto"; +import "google/protobuf/any.proto"; + +option java_multiple_files = true; +option java_package = "com.oracle.coherence.grpc.messages.proxy.v1"; + +// ----------------------------------------------------------------- +// Messages used by the Coherence gRPC Proxy Service. +// ----------------------------------------------------------------- + +// A request to the Coherence gRPC proxy. +// Except for a Heartbeat, every request must have a unique id field. +message ProxyRequest { + int64 id = 1; + oneof request { + // The initialization request, which must be the first request sent. + InitRequest init = 3; + // A message that is specific to a Coherence gRPC service. + // Each service on the proxy will know what type to expect here. 
+ google.protobuf.Any message = 4; + // A periodic heartbeat message sent by the client + coherence.common.v1.HeartbeatMessage heartbeat = 5; + } +} + +// A response from a Coherence gRPC proxy. +// Except for a Heartbeat, every response will contain an id field +// that corresponds to the id of the request that the response if for. +message ProxyResponse { + // The identifier of the request messages this response is for, or zero if + // this message is non-request related, for example it is an event. + int64 id = 1; + // The actual response message. + oneof response { + // The response to the initial InitRequest. + InitResponse init = 4; + // A response of a type specific to a Coherence gRPC service. + // The client that sent the corresponding request will know what + // type of message it expects in this field. + google.protobuf.Any message = 5; + // An error response to a specific request id + coherence.common.v1.ErrorMessage error = 6; + // A complete message is sent to indicate that a stream of messages for + // the same request id have been completed. + coherence.common.v1.Complete complete = 7; + // A periodic heart beat sent by the server + coherence.common.v1.HeartbeatMessage heartbeat = 8; + } +} + +// Initialize a connection. +message InitRequest { + // The scope name to use to obtain the server resources. + string scope = 2; + // The serialization format to use. + string format = 3; + // The protocol to use for the channel + string protocol = 4; + // The protocol version requested by the client + int32 protocolVersion = 5; + // The minimum protocol version supported by the client + int32 supportedProtocolVersion = 6; + // The requested frequency that heartbeat messages should be sent by the server (in millis) + optional int64 heartbeat = 7; + // The optional client UUID (usually from Coherence clients that have a local Member UUID). + optional bytes clientUuid = 8; +} + +// The response to an InitRequest +message InitResponse { + // This client connection's UUID. + bytes uuid = 1; + // The Coherence version of the proxy + string version = 2; + // The encoded version of the proxy + int32 encodedVersion = 3; + // The protocol version the client should use + int32 protocolVersion = 4; + // The proxy member Id + int32 proxyMemberId = 5; + // The proxy member UUID + bytes proxyMemberUuid = 6; +} diff --git a/etc/proto/proxy_service_v1.proto b/etc/proto/proxy_service_v1.proto new file mode 100644 index 0000000..f45353a --- /dev/null +++ b/etc/proto/proxy_service_v1.proto @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2020, 2024, Oracle and/or its affiliates. + * + * Licensed under the Universal Permissive License v 1.0 as shown at + * https://oss.oracle.com/licenses/upl. + */ + +// NamedCacheService V2 service definition. + +syntax = "proto3"; + +package coherence.proxy.v1; + +import "proxy_service_messages_v1.proto"; +import "google/protobuf/empty.proto"; +import "google/protobuf/wrappers.proto"; + +option java_multiple_files = true; +option java_package = "com.oracle.coherence.grpc.services.proxy.v1"; + +// ----------------------------------------------------------------- +// The Coherence gRPC Proxy Service definition. +// ----------------------------------------------------------------- + +service ProxyService { + // Sets up a bidirectional channel for cache requests and responses. 
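A rough sketch of how a client might drive this bidirectional channel with grpc.aio, using the messages defined above. The endpoint, serialization format and protocol name below are illustrative assumptions, not values taken from this PR:

    import grpc
    from coherence import proxy_service_messages_v1_pb2 as proxy_v1
    from coherence import proxy_service_v1_pb2_grpc as proxy_v1_grpc

    async def open_proxy_channel() -> None:
        async with grpc.aio.insecure_channel("localhost:1408") as channel:
            stream = proxy_v1_grpc.ProxyServiceStub(channel).subChannel()
            init = proxy_v1.InitRequest(scope="", format="json", protocol="CacheService",
                                        protocolVersion=1, supportedProtocolVersion=1)
            # The InitRequest must be the first message sent on the stream.
            await stream.write(proxy_v1.ProxyRequest(id=1, init=init))
            response = await stream.read()  # a ProxyResponse whose init field carries the InitResponse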
+ rpc subChannel (stream ProxyRequest) returns (stream ProxyResponse) { + } +} diff --git a/etc/proto/services.proto b/etc/proto/services.proto new file mode 100644 index 0000000..6fd7280 --- /dev/null +++ b/etc/proto/services.proto @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. + * + * Licensed under the Universal Permissive License v 1.0 as shown at + * http://oss.oracle.com/licenses/upl. + */ + +// Authors: +// Mahesh Kannan +// Jonathan Knight + +// NamedCacheService service definition. + +syntax = "proto3"; + +package coherence; + +import "messages.proto"; +import "google/protobuf/empty.proto"; +import "google/protobuf/wrappers.proto"; + +option java_multiple_files = true; +option java_package = "com.oracle.coherence.grpc"; + +// A gRPC NamedCache service. +// +service NamedCacheService { + + // Add an index to a cache. + rpc addIndex (AddIndexRequest) returns (google.protobuf.Empty) { + } + + // Obtain the results of running an entry aggregator against the cache. + // The aggregator may run against entries specified by key or entries + // matching a given filter. + rpc aggregate (AggregateRequest) returns (google.protobuf.BytesValue) { + } + + // Clear a cache. + rpc clear (ClearRequest) returns (google.protobuf.Empty) { + } + + // Check if this map contains a mapping for the specified key to the specified value. + rpc containsEntry (ContainsEntryRequest) returns (google.protobuf.BoolValue) { + } + + // Check if this map contains a mapping for the specified key. + rpc containsKey (ContainsKeyRequest) returns (google.protobuf.BoolValue) { + } + + // Check if this map contains a mapping for the specified value. + rpc containsValue (ContainsValueRequest) returns (google.protobuf.BoolValue) { + } + + // Destroy a cache. + rpc destroy (DestroyRequest) returns (google.protobuf.Empty) { + } + + // Obtain all of the entries in the cache where the cache entries + // match a given filter. + rpc entrySet (EntrySetRequest) returns (stream Entry) { + } + + // Sets up a bidirectional channel for cache events. + rpc events (stream MapListenerRequest) returns (stream MapListenerResponse) { + } + + // Get a value for a given key from a cache. + rpc get (GetRequest) returns (OptionalValue) { + } + + // Get all of the values from a cache for a given collection of keys. + rpc getAll (GetAllRequest) returns (stream Entry) { + } + + // Invoke an entry processor against an entry in a cache. + rpc invoke (InvokeRequest) returns (google.protobuf.BytesValue) { + } + + // Invoke an entry processor against a number of entries in a cache. + rpc invokeAll (InvokeAllRequest) returns (stream Entry) { + } + + // Determine whether a cache is empty. + rpc isEmpty (IsEmptyRequest) returns (google.protobuf.BoolValue) { + } + + // Obtain all of the keys in the cache where the cache entries + // match a given filter. + rpc keySet (KeySetRequest) returns (stream google.protobuf.BytesValue) { + } + + // Get the next page of a paged entry set request. + rpc nextEntrySetPage (PageRequest) returns (stream EntryResult) { + } + + // Get the next page of a paged key set request. + rpc nextKeySetPage (PageRequest) returns (stream google.protobuf.BytesValue) { + } + + // Associate the specified value with the specified key in this cache. + // If the cache previously contained a mapping for the key, the old value + // is replaced by the specified value. + // An optional expiry (TTL) value may be set for the entry to expire the + // entry from the cache after that time. 
+ rpc put (PutRequest) returns (google.protobuf.BytesValue) { + } + + // Copies all of the mappings from the request into the cache. + rpc putAll (PutAllRequest) returns (google.protobuf.Empty) { + } + + // If the specified key is not already associated with a value (or is mapped + // to null associate it with the given value and returns null, else return + // the current value. + rpc putIfAbsent (PutIfAbsentRequest) returns (google.protobuf.BytesValue) { + } + + // Remove the mapping that is associated with the specified key. + rpc remove (RemoveRequest) returns (google.protobuf.BytesValue) { + } + + // Remove an index from the cache. + rpc removeIndex (RemoveIndexRequest) returns (google.protobuf.Empty) { + } + + // Remove the mapping that is associated with the specified key only + // if the mapping exists in the cache. + rpc removeMapping (RemoveMappingRequest) returns (google.protobuf.BoolValue) { + } + + // Replace the entry for the specified key only if it is currently + // mapped to some value. + rpc replace (ReplaceRequest) returns (google.protobuf.BytesValue) { + } + + // Replace the mapping for the specified key only if currently mapped + // to the specified value. + rpc replaceMapping (ReplaceMappingRequest) returns (google.protobuf.BoolValue) { + } + + // Determine the number of entries in a cache. + rpc size (SizeRequest) returns (google.protobuf.Int32Value) { + } + + // Truncate a cache. This is the same as clearing a cache but no + // cache entry events will be generated. + rpc truncate (TruncateRequest) returns (google.protobuf.Empty) { + } + + // Obtain all of the values in the cache where the cache entries + // match a given filter. + rpc values (ValuesRequest) returns (stream google.protobuf.BytesValue) { + } +} diff --git a/examples/events.py b/examples/events.py index fcf2424..eeb7695 100644 --- a/examples/events.py +++ b/examples/events.py @@ -31,7 +31,7 @@ async def do_run() -> None: await asyncio.sleep(1) print("Releasing the NamedMap; this should generate an event ...") - named_map.release() + await named_map.release() await asyncio.sleep(1) print("Destroying the NamedMap; this should generate an event ...") diff --git a/pyproject.toml b/pyproject.toml index b84e336..7924f11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,17 +21,19 @@ classifiers = [ ] [tool.poetry.dependencies] -python = "^3.8" +python = "^3.9" protobuf = "5.29.1" grpcio = "1.68.1" grpcio-tools = "1.68.1" jsonpickle = ">=3.0,<4.1" pymitter = ">=0.4,<0.6" typing-extensions = ">=4.11,<4.13" +types-protobuf = "5.27.0.20240626" +pympler = "1.1" [tool.poetry.dev-dependencies] pytest = "~8.3" -pytest-asyncio = "~0.24" +pytest-asyncio = "~0.25" pytest-cov = "~5.0" pytest-unordered = "~0.6" pre-commit = "~3.5" @@ -40,8 +42,10 @@ sphinx-rtd-theme = "~2.0" sphinxcontrib-napoleon = "~0.7" m2r = "~0.3" third-party-license-file-generator = "~2024.8" +pyinstrument="5.0.0" [tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" pythonpath = ["src"] [tool.isort] diff --git a/src/coherence/__init__.py b/src/coherence/__init__.py index 3fe1e6c..cc70564 100644 --- a/src/coherence/__init__.py +++ b/src/coherence/__init__.py @@ -6,25 +6,33 @@ __version__ = "1.1.1" +import contextvars import logging +from typing import Final # expose these symbols in top-level namespace from .aggregator import Aggregators as Aggregators -from .client import MapEntry as MapEntry +from .client import CacheOptions as CacheOptions from .client import NamedCache as NamedCache from .client import NamedMap as 
NamedMap from .client import Options as Options from .client import Session as Session from .client import TlsOptions as TlsOptions +from .client import request_timeout as request_timeout from .comparator import Comparator as Comparator +from .entry import MapEntry as MapEntry from .extractor import Extractors as Extractors from .filter import Filters as Filters +from .local_cache import CacheStats as CacheStats +from .local_cache import NearCacheOptions as NearCacheOptions from .processor import Processors as Processors # default logging configuration for coherence handler: logging.StreamHandler = logging.StreamHandler() # type: ignore handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) +_TIMEOUT_CONTEXT_VAR: Final[contextvars.ContextVar[float]] = contextvars.ContextVar("coherence-request-timeout") + COH_LOG = logging.getLogger("coherence") COH_LOG.setLevel(logging.INFO) COH_LOG.addHandler(handler) diff --git a/src/coherence/aggregator.py b/src/coherence/aggregator.py index 4c22591..27e187a 100644 --- a/src/coherence/aggregator.py +++ b/src/coherence/aggregator.py @@ -346,8 +346,9 @@ def __init__( that is based on an array of corresponding :class:`coherence.extractor.UniversalExtractor` objects; may not be `NONE` - :param aggregator: an EntryAggregator object; may not be null - :param filter: an optional Filter object used to filter out results of individual group aggregation results + :param aggregator: an EntryAggregator object; may not be null + :param filter: an optional Filter object used to filter out results + of individual group aggregation results """ super().__init__(extractor_or_property) if aggregator is not None: @@ -654,8 +655,8 @@ def group_by( :param extractor_or_property: the extractor or method/property name to provide values for aggregation :param aggregator: the underlying :class:`coherence.aggregator.EntryAggregator` - :param filter: an optional :class:`coherence.filter.Filter` object used to filter out results of individual - group aggregation results + :param filter: an optional :class:`coherence.filter.Filter` object used to filter out results + of individual group aggregation results :return: a :class:`coherence.aggregator.GroupAggregator` based on a specified property or method name(s) and an :class:`coherence.aggregator.EntryAggregator`. """ diff --git a/src/coherence/ai.py b/src/coherence/ai.py new file mode 100644 index 0000000..939403f --- /dev/null +++ b/src/coherence/ai.py @@ -0,0 +1,334 @@ +# Copyright (c) 2022, 2024, Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. + +from __future__ import annotations + +import base64 +import math +from abc import ABC +from collections import OrderedDict +from typing import Any, Dict, List, Optional, TypeVar, Union, cast + +import jsonpickle + +from coherence.aggregator import EntryAggregator +from coherence.extractor import ValueExtractor +from coherence.filter import Filter +from coherence.serialization import _META_CLASS, JavaProxyUnpickler, proxy + +E = TypeVar("E") +T = TypeVar("T") +K = TypeVar("K") +V = TypeVar("V") + + +class Vector(ABC): + """ + Base class that represents a Vector + """ + + def __init__(self) -> None: + """ + Constructs a new `Vector`. 
+ """ + super().__init__() + + +@proxy("ai.BitVector") +class BitVector(Vector): + """ + Class that represents a Vector of Bits + """ + + def __init__( + self, + hex_string: str, + byte_array: Optional[bytes] = None, + int_array: Optional[List[int]] = None, + ): + """ + Creates an instance of BitVector + + :param hex_string: hexadecimal string used to create the BitVector + :param byte_array: optional byte array used to create the BitVector + :param int_array: optional int array used to create the BitVector + """ + super().__init__() + if hex_string is not None: + if hex_string.startswith("0x"): + self.bits = hex_string[2:] + else: + self.bits = hex_string + elif byte_array is not None: + self.bits = byte_array.hex() + else: + self.bits = "" + for i in int_array: + self.bits += hex(i)[2:] + self.bits = "0x" + self.bits + + +@proxy("ai.Int8Vector") +class ByteVector(Vector): + """ + Class that represents Vector of bytes + """ + + def __init__(self, byte_array: bytes): + """ + Creates an instance of ByteVector + + :param byte_array: byte array used to create a ByteVector + """ + super().__init__() + self.array = base64.b64encode(byte_array).decode("UTF-8") + + +@proxy("ai.Float32Vector") +class FloatVector(Vector): + """ + Class that represents Vector of floats + """ + + def __init__(self, float_array: List[float]): + """ + Creates an instance of FloatVector + + :param float_array: array of floats used to create a FloatVector + """ + super().__init__() + self.array = float_array + + +class AbstractEvolvable(ABC): + def __init__(self, data_version: int = 0, bin_future: Optional[Any] = None): + self.dataVersion = data_version + self.binFuture = bin_future + + +@proxy("ai.DocumentChunk") +class DocumentChunk(AbstractEvolvable): + """ + Class that represents a chunk of text extracted from a document. 
+ """ + + def __init__( + self, + text: str, + metadata: Optional[Dict[str, Any] | OrderedDict[str, Any]] = None, + vector: Optional[Vector] = None, + ): + """ + Creates an instance of DocumentChunk class + + :param text: the chunk of text extracted from a document + :param metadata: optional document metadata + :param vector: the vector associated with the document chunk + """ + super().__init__() + self.text = text + if metadata is None: + self.metadata: Dict[str, Any] = OrderedDict() + else: + self.metadata = metadata + self.vector = vector + + +@jsonpickle.handlers.register(DocumentChunk) +class DocumentChunkHandler(jsonpickle.handlers.BaseHandler): + def flatten(self, obj: object, data: Dict[str, Any]) -> Dict[str, Any]: + dc: DocumentChunk = cast(DocumentChunk, obj) + result_dict: Dict[Any, Any] = dict() + result_dict[_META_CLASS] = "ai.DocumentChunk" + result_dict["dataVersion"] = dc.dataVersion + if hasattr(dc, "binFuture") and dc.binFuture is not None: + result_dict["binFuture"] = dc.binFuture + if hasattr(dc, "metadata") and dc.metadata is not None: + result_dict["metadata"] = dict() + if isinstance(dc.metadata, OrderedDict): + result_dict["metadata"]["@ordered"] = True + entries = list() + for k, v in dc.metadata.items(): + entries.append({"key": k, "value": v}) + result_dict["metadata"]["entries"] = entries + if hasattr(dc, "vector"): + v = dc.vector + if v is not None: + if isinstance(v, BitVector): + result_dict["vector"] = dict() + result_dict["vector"][_META_CLASS] = "ai.BitVector" + # noinspection PyUnresolvedReferences + result_dict["vector"]["bits"] = v.bits + elif isinstance(v, ByteVector): + result_dict["vector"] = dict() + result_dict["vector"][_META_CLASS] = "ai.Int8Vector" + # noinspection PyUnresolvedReferences + result_dict["vector"]["array"] = v.array + elif isinstance(v, FloatVector): + result_dict["vector"] = dict() + result_dict["vector"][_META_CLASS] = "ai.Float32Vector" + # noinspection PyUnresolvedReferences + result_dict["vector"]["array"] = v.array + result_dict["text"] = dc.text + return result_dict + + def restore(self, obj: Dict[str, Any]) -> DocumentChunk: + jpu = JavaProxyUnpickler() + d = DocumentChunk("") + o = jpu._restore_from_dict(obj, d) + return o + + +class DistanceAlgorithm(ABC): + """ + Base class that represents algorithm that can calculate distance to a given vector + """ + + def __init__(self) -> None: + super().__init__() + + +@proxy("ai.distance.CosineSimilarity") +class CosineDistance(DistanceAlgorithm): + """ + Represents a DistanceAlgorithm that performs a cosine similarity calculation + between two vectors. Cosine similarity measures the similarity between two + vectors of an inner product space. It is measured by the cosine of the angle + between two vectors and determines whether two vectors are pointing in + roughly the same direction. It is often used to measure document similarity + in text analysis. + """ + + def __init__(self) -> None: + super().__init__() + + +@proxy("ai.distance.InnerProductSimilarity") +class InnerProductDistance(DistanceAlgorithm): + """ + Represents a DistanceAlgorithm that performs inner product distance + calculation between two vectors. + """ + + def __init__(self) -> None: + super().__init__() + + +@proxy("ai.distance.L2SquaredDistance") +class L2SquaredDistance(DistanceAlgorithm): + """ + Represents a DistanceAlgorithm that performs an L2 squared distance + calculation between two vectors. 
+ """ + + def __init__(self) -> None: + super().__init__() + + +@proxy("ai.search.SimilarityAggregator") +class SimilaritySearch(EntryAggregator): + """ + This class represents an aggregator to execute a similarity query. + """ + + def __init__( + self, + extractor_or_property: Union[ValueExtractor[T, E], str], + vector: Vector, + max_results: int, + algorithm: Optional[DistanceAlgorithm] = CosineDistance(), + filter: Optional[Filter] = None, + brute_force: Optional[bool] = False, + ) -> None: + """ + Create a SimilaritySearch aggregator that will use cosine distance to + calculate and return up to `max_results` results that are closest to the + specified `vector`. + + :param extractor_or_property: the ValueExtractor to extract the vector + from the cache value + :param vector: the vector to calculate similarity with + :param max_results: the maximum number of results to return + :param algorithm: the distance algorithm to use + :param filter: filter to use to limit the set of entries to search. + :param brute_force: Force brute force search, ignoring any available indexes. + """ + super().__init__(extractor_or_property) + self.algorithm = algorithm + self.bruteForce = brute_force + self.filter = filter + self.maxResults = max_results + self.vector = vector + + +class BaseQueryResult(ABC): + """ + A base class for QueryResult implementation + """ + + def __init__(self, result: float, key: K, value: V) -> None: + self.distance = result + self.key = key + self.value = value + + +@proxy("ai.results.QueryResult") +class QueryResult(BaseQueryResult): + """ + QueryResult class + """ + + def __init__(self, result: float, key: K, value: V) -> None: + """ + Creates an instance of the QueryResult class + + :param result: the query result + :param key: the key of the vector the result applies to + :param value: the optional result value + """ + super().__init__(result, key, value) + + def __str__(self) -> str: + return "QueryResult{ " + "result=" + str(self.distance) + ", key=" + str(self.key) + "}" + + +@proxy("ai.index.BinaryQuantIndex") +class BinaryQuantIndex(AbstractEvolvable): + """ + This class represents a custom index using binary quantization of vectors + """ + + def __init__(self, extractor: Union[ValueExtractor[T, E], str], over_sampling_factor: int = 3) -> None: + """ + Creates an instance of BinaryQuantIndex class + + :param extractor: the ValueExtractor to use to extract the Vector + :param over_sampling_factor: the oversampling factor + """ + super().__init__() + self.extractor = extractor + self.oversamplingFactor = over_sampling_factor + + +class Vectors: + + EPSILON = 1e-30 # Python automatically handles float precision + + @staticmethod + def normalize(array: List[float]) -> List[float]: + norm = 0.0 + c_dim = len(array) + + # Calculate the norm (sum of squares) + for v in array: + norm += v * v + + # Compute the normalization factor (inverse of the square root of the sum of squares) + norm = 1.0 / (math.sqrt(norm) + Vectors.EPSILON) + + # Apply the normalization factor to each element in the array + for i in range(c_dim): + array[i] = array[i] * norm + + return array diff --git a/src/coherence/cache_service_messages_v1_pb2.py b/src/coherence/cache_service_messages_v1_pb2.py new file mode 100644 index 0000000..d3d092a --- /dev/null +++ b/src/coherence/cache_service_messages_v1_pb2.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: cache_service_messages_v1.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import coherence.common_messages_v1_pb2 as common__messages__v1__pb2 +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 +from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1f\x63\x61\x63he_service_messages_v1.proto\x12\x12\x63oherence.cache.v1\x1a\x18\x63ommon_messages_v1.proto\x1a\x19google/protobuf/any.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/wrappers.proto\"\xa6\x01\n\x11NamedCacheRequest\x12\x37\n\x04type\x18\x01 \x01(\x0e\x32).coherence.cache.v1.NamedCacheRequestType\x12\x14\n\x07\x63\x61\x63heId\x18\x02 \x01(\x05H\x00\x88\x01\x01\x12*\n\x07message\x18\x03 \x01(\x0b\x32\x14.google.protobuf.AnyH\x01\x88\x01\x01\x42\n\n\x08_cacheIdB\n\n\x08_message\"\x8d\x01\n\x12NamedCacheResponse\x12\x0f\n\x07\x63\x61\x63heId\x18\x01 \x01(\x05\x12.\n\x04type\x18\x02 \x01(\x0e\x32 .coherence.cache.v1.ResponseType\x12*\n\x07message\x18\x03 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x88\x01\x01\x42\n\n\x08_message\"#\n\x12\x45nsureCacheRequest\x12\r\n\x05\x63\x61\x63he\x18\x01 \x01(\t\"B\n\nPutRequest\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x10\n\x03ttl\x18\x03 \x01(\x03H\x00\x88\x01\x01\x42\x06\n\x04_ttl\"b\n\rPutAllRequest\x12\x37\n\x07\x65ntries\x18\x01 \x03(\x0b\x32&.coherence.common.v1.BinaryKeyAndValue\x12\x10\n\x03ttl\x18\x02 \x01(\x03H\x00\x88\x01\x01\x42\x06\n\x04_ttl\"M\n\x15ReplaceMappingRequest\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\x15\n\rpreviousValue\x18\x02 \x01(\x0c\x12\x10\n\x08newValue\x18\x03 \x01(\x0c\"v\n\x0cIndexRequest\x12\x0b\n\x03\x61\x64\x64\x18\x01 \x01(\x08\x12\x11\n\textractor\x18\x02 \x01(\x0c\x12\x13\n\x06sorted\x18\x03 \x01(\x08H\x00\x88\x01\x01\x12\x17\n\ncomparator\x18\x04 \x01(\x0cH\x01\x88\x01\x01\x42\t\n\x07_sortedB\r\n\x0b_comparator\"|\n\x0cKeysOrFilter\x12\r\n\x03key\x18\x01 \x01(\x0cH\x00\x12<\n\x04keys\x18\x02 \x01(\x0b\x32,.coherence.common.v1.CollectionOfBytesValuesH\x00\x12\x10\n\x06\x66ilter\x18\x03 \x01(\x0cH\x00\x42\r\n\x0bkeyOrFilter\"=\n\x0bKeyOrFilter\x12\r\n\x03key\x18\x01 \x01(\x0cH\x00\x12\x10\n\x06\x66ilter\x18\x02 \x01(\x0cH\x00\x42\r\n\x0bkeyOrFilter\"]\n\x0e\x45xecuteRequest\x12\r\n\x05\x61gent\x18\x01 \x01(\x0c\x12\x33\n\x04keys\x18\x03 \x01(\x0b\x32 .coherence.cache.v1.KeysOrFilterH\x00\x88\x01\x01\x42\x07\n\x05_keys\"V\n\x0cQueryRequest\x12\x13\n\x06\x66ilter\x18\x01 \x01(\x0cH\x00\x88\x01\x01\x12\x17\n\ncomparator\x18\x02 \x01(\x0cH\x01\x88\x01\x01\x42\t\n\x07_filterB\r\n\x0b_comparator\"\xc9\x01\n\x12MapListenerRequest\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x39\n\x0bkeyOrFilter\x18\x02 \x01(\x0b\x32\x1f.coherence.cache.v1.KeyOrFilterH\x00\x88\x01\x01\x12\x10\n\x08\x66ilterId\x18\x03 \x01(\x03\x12\x0c\n\x04lite\x18\x04 \x01(\x08\x12\x13\n\x0bsynchronous\x18\x05 \x01(\x08\x12\x0f\n\x07priming\x18\x06 \x01(\x08\x12\x0f\n\x07trigger\x18\x07 
\x01(\x0c\x42\x0e\n\x0c_keyOrFilter\"\xd5\x02\n\x0fMapEventMessage\x12\n\n\x02id\x18\x01 \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x10\n\x08newValue\x18\x03 \x01(\x0c\x12\x10\n\x08oldValue\x18\x04 \x01(\x0c\x12T\n\x13transformationState\x18\x05 \x01(\x0e\x32\x37.coherence.cache.v1.MapEventMessage.TransformationState\x12\x11\n\tfilterIds\x18\x06 \x03(\x03\x12\x11\n\tsynthetic\x18\x07 \x01(\x08\x12\x0f\n\x07priming\x18\x08 \x01(\x08\x12\x0f\n\x07\x65xpired\x18\t \x01(\x08\x12\x15\n\rversionUpdate\x18\n \x01(\x08\"P\n\x13TransformationState\x12\x15\n\x11NON_TRANSFORMABLE\x10\x00\x12\x11\n\rTRANSFORMABLE\x10\x01\x12\x0f\n\x0bTRANSFORMED\x10\x02*\xbd\x03\n\x15NamedCacheRequestType\x12\x0b\n\x07Unknown\x10\x00\x12\x0f\n\x0b\x45nsureCache\x10\x01\x12\r\n\tAggregate\x10\x02\x12\t\n\x05\x43lear\x10\x03\x12\x11\n\rContainsEntry\x10\x04\x12\x0f\n\x0b\x43ontainsKey\x10\x05\x12\x11\n\rContainsValue\x10\x06\x12\x0b\n\x07\x44\x65stroy\x10\x07\x12\x0b\n\x07IsEmpty\x10\x08\x12\x0b\n\x07IsReady\x10\t\x12\x07\n\x03Get\x10\n\x12\n\n\x06GetAll\x10\x0b\x12\t\n\x05Index\x10\x0c\x12\n\n\x06Invoke\x10\r\x12\x0f\n\x0bMapListener\x10\x0e\x12\x11\n\rPageOfEntries\x10\x0f\x12\x0e\n\nPageOfKeys\x10\x10\x12\x07\n\x03Put\x10\x11\x12\n\n\x06PutAll\x10\x12\x12\x0f\n\x0bPutIfAbsent\x10\x13\x12\x10\n\x0cQueryEntries\x10\x14\x12\r\n\tQueryKeys\x10\x15\x12\x0f\n\x0bQueryValues\x10\x16\x12\n\n\x06Remove\x10\x17\x12\x11\n\rRemoveMapping\x10\x18\x12\x0b\n\x07Replace\x10\x19\x12\x12\n\x0eReplaceMapping\x10\x1a\x12\x08\n\x04Size\x10\x1b\x12\x0c\n\x08Truncate\x10\x1c*G\n\x0cResponseType\x12\x0b\n\x07Message\x10\x00\x12\x0c\n\x08MapEvent\x10\x01\x12\r\n\tDestroyed\x10\x02\x12\r\n\tTruncated\x10\x03\x42/\n+com.oracle.coherence.grpc.messages.cache.v1P\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'cache_service_messages_v1_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n+com.oracle.coherence.grpc.messages.cache.v1P\001' + _NAMEDCACHEREQUESTTYPE._serialized_start=1840 + _NAMEDCACHEREQUESTTYPE._serialized_end=2285 + _RESPONSETYPE._serialized_start=2287 + _RESPONSETYPE._serialized_end=2358 + _NAMEDCACHEREQUEST._serialized_start=203 + _NAMEDCACHEREQUEST._serialized_end=369 + _NAMEDCACHERESPONSE._serialized_start=372 + _NAMEDCACHERESPONSE._serialized_end=513 + _ENSURECACHEREQUEST._serialized_start=515 + _ENSURECACHEREQUEST._serialized_end=550 + _PUTREQUEST._serialized_start=552 + _PUTREQUEST._serialized_end=618 + _PUTALLREQUEST._serialized_start=620 + _PUTALLREQUEST._serialized_end=718 + _REPLACEMAPPINGREQUEST._serialized_start=720 + _REPLACEMAPPINGREQUEST._serialized_end=797 + _INDEXREQUEST._serialized_start=799 + _INDEXREQUEST._serialized_end=917 + _KEYSORFILTER._serialized_start=919 + _KEYSORFILTER._serialized_end=1043 + _KEYORFILTER._serialized_start=1045 + _KEYORFILTER._serialized_end=1106 + _EXECUTEREQUEST._serialized_start=1108 + _EXECUTEREQUEST._serialized_end=1201 + _QUERYREQUEST._serialized_start=1203 + _QUERYREQUEST._serialized_end=1289 + _MAPLISTENERREQUEST._serialized_start=1292 + _MAPLISTENERREQUEST._serialized_end=1493 + _MAPEVENTMESSAGE._serialized_start=1496 + _MAPEVENTMESSAGE._serialized_end=1837 + _MAPEVENTMESSAGE_TRANSFORMATIONSTATE._serialized_start=1757 + _MAPEVENTMESSAGE_TRANSFORMATIONSTATE._serialized_end=1837 +# @@protoc_insertion_point(module_scope) diff --git a/src/coherence/cache_service_messages_v1_pb2.pyi 
b/src/coherence/cache_service_messages_v1_pb2.pyi new file mode 100644 index 0000000..c33c0f0 --- /dev/null +++ b/src/coherence/cache_service_messages_v1_pb2.pyi @@ -0,0 +1,199 @@ +# mypy: ignore-errors +import common_messages_v1_pb2 as _common_messages_v1_pb2 +from google.protobuf import any_pb2 as _any_pb2 +from google.protobuf import empty_pb2 as _empty_pb2 +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf import wrappers_pb2 as _wrappers_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +Aggregate: NamedCacheRequestType +Clear: NamedCacheRequestType +ContainsEntry: NamedCacheRequestType +ContainsKey: NamedCacheRequestType +ContainsValue: NamedCacheRequestType +DESCRIPTOR: _descriptor.FileDescriptor +Destroy: NamedCacheRequestType +Destroyed: ResponseType +EnsureCache: NamedCacheRequestType +Get: NamedCacheRequestType +GetAll: NamedCacheRequestType +Index: NamedCacheRequestType +Invoke: NamedCacheRequestType +IsEmpty: NamedCacheRequestType +IsReady: NamedCacheRequestType +MapEvent: ResponseType +MapListener: NamedCacheRequestType +Message: ResponseType +PageOfEntries: NamedCacheRequestType +PageOfKeys: NamedCacheRequestType +Put: NamedCacheRequestType +PutAll: NamedCacheRequestType +PutIfAbsent: NamedCacheRequestType +QueryEntries: NamedCacheRequestType +QueryKeys: NamedCacheRequestType +QueryValues: NamedCacheRequestType +Remove: NamedCacheRequestType +RemoveMapping: NamedCacheRequestType +Replace: NamedCacheRequestType +ReplaceMapping: NamedCacheRequestType +Size: NamedCacheRequestType +Truncate: NamedCacheRequestType +Truncated: ResponseType +Unknown: NamedCacheRequestType + +class EnsureCacheRequest(_message.Message): + __slots__ = ["cache"] + CACHE_FIELD_NUMBER: _ClassVar[int] + cache: str + def __init__(self, cache: _Optional[str] = ...) -> None: ... + +class ExecuteRequest(_message.Message): + __slots__ = ["agent", "keys"] + AGENT_FIELD_NUMBER: _ClassVar[int] + KEYS_FIELD_NUMBER: _ClassVar[int] + agent: bytes + keys: KeysOrFilter + def __init__(self, agent: _Optional[bytes] = ..., keys: _Optional[_Union[KeysOrFilter, _Mapping]] = ...) -> None: ... + +class IndexRequest(_message.Message): + __slots__ = ["add", "comparator", "extractor", "sorted"] + ADD_FIELD_NUMBER: _ClassVar[int] + COMPARATOR_FIELD_NUMBER: _ClassVar[int] + EXTRACTOR_FIELD_NUMBER: _ClassVar[int] + SORTED_FIELD_NUMBER: _ClassVar[int] + add: bool + comparator: bytes + extractor: bytes + sorted: bool + def __init__(self, add: bool = ..., extractor: _Optional[bytes] = ..., sorted: bool = ..., comparator: _Optional[bytes] = ...) -> None: ... + +class KeyOrFilter(_message.Message): + __slots__ = ["filter", "key"] + FILTER_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + filter: bytes + key: bytes + def __init__(self, key: _Optional[bytes] = ..., filter: _Optional[bytes] = ...) -> None: ... 
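The KeyOrFilter stub above models a proto oneof: each instance carries either a serialized key or a serialized filter, never both. A hedged sketch follows; the byte payloads are placeholders for what the client's Serializer would normally produce.

# Hedged sketch of the KeyOrFilter oneof; payload bytes are placeholders only.
from coherence.cache_service_messages_v1_pb2 import KeyOrFilter

serialized_key = b"\x01"       # placeholder for Serializer output
serialized_filter = b"\x02"    # placeholder for a serialized Filter

by_key = KeyOrFilter(key=serialized_key)            # only the 'key' branch is set
by_filter = KeyOrFilter(filter=serialized_filter)   # only the 'filter' branch is set
assert by_key.WhichOneof("keyOrFilter") == "key"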
+ +class KeysOrFilter(_message.Message): + __slots__ = ["filter", "key", "keys"] + FILTER_FIELD_NUMBER: _ClassVar[int] + KEYS_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + filter: bytes + key: bytes + keys: _common_messages_v1_pb2.CollectionOfBytesValues + def __init__(self, key: _Optional[bytes] = ..., keys: _Optional[_Union[_common_messages_v1_pb2.CollectionOfBytesValues, _Mapping]] = ..., filter: _Optional[bytes] = ...) -> None: ... + +class MapEventMessage(_message.Message): + __slots__ = ["expired", "filterIds", "id", "key", "newValue", "oldValue", "priming", "synthetic", "transformationState", "versionUpdate"] + class TransformationState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + EXPIRED_FIELD_NUMBER: _ClassVar[int] + FILTERIDS_FIELD_NUMBER: _ClassVar[int] + ID_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + NEWVALUE_FIELD_NUMBER: _ClassVar[int] + NON_TRANSFORMABLE: MapEventMessage.TransformationState + OLDVALUE_FIELD_NUMBER: _ClassVar[int] + PRIMING_FIELD_NUMBER: _ClassVar[int] + SYNTHETIC_FIELD_NUMBER: _ClassVar[int] + TRANSFORMABLE: MapEventMessage.TransformationState + TRANSFORMATIONSTATE_FIELD_NUMBER: _ClassVar[int] + TRANSFORMED: MapEventMessage.TransformationState + VERSIONUPDATE_FIELD_NUMBER: _ClassVar[int] + expired: bool + filterIds: _containers.RepeatedScalarFieldContainer[int] + id: int + key: bytes + newValue: bytes + oldValue: bytes + priming: bool + synthetic: bool + transformationState: MapEventMessage.TransformationState + versionUpdate: bool + def __init__(self, id: _Optional[int] = ..., key: _Optional[bytes] = ..., newValue: _Optional[bytes] = ..., oldValue: _Optional[bytes] = ..., transformationState: _Optional[_Union[MapEventMessage.TransformationState, str]] = ..., filterIds: _Optional[_Iterable[int]] = ..., synthetic: bool = ..., priming: bool = ..., expired: bool = ..., versionUpdate: bool = ...) -> None: ... + +class MapListenerRequest(_message.Message): + __slots__ = ["filterId", "keyOrFilter", "lite", "priming", "subscribe", "synchronous", "trigger"] + FILTERID_FIELD_NUMBER: _ClassVar[int] + KEYORFILTER_FIELD_NUMBER: _ClassVar[int] + LITE_FIELD_NUMBER: _ClassVar[int] + PRIMING_FIELD_NUMBER: _ClassVar[int] + SUBSCRIBE_FIELD_NUMBER: _ClassVar[int] + SYNCHRONOUS_FIELD_NUMBER: _ClassVar[int] + TRIGGER_FIELD_NUMBER: _ClassVar[int] + filterId: int + keyOrFilter: KeyOrFilter + lite: bool + priming: bool + subscribe: bool + synchronous: bool + trigger: bytes + def __init__(self, subscribe: bool = ..., keyOrFilter: _Optional[_Union[KeyOrFilter, _Mapping]] = ..., filterId: _Optional[int] = ..., lite: bool = ..., synchronous: bool = ..., priming: bool = ..., trigger: _Optional[bytes] = ...) -> None: ... + +class NamedCacheRequest(_message.Message): + __slots__ = ["cacheId", "message", "type"] + CACHEID_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + TYPE_FIELD_NUMBER: _ClassVar[int] + cacheId: int + message: _any_pb2.Any + type: NamedCacheRequestType + def __init__(self, type: _Optional[_Union[NamedCacheRequestType, str]] = ..., cacheId: _Optional[int] = ..., message: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...) -> None: ... 
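The NamedCacheRequest stub above is the envelope of the new v1 protocol: a typed request message is packed into a google.protobuf.Any and tagged with a NamedCacheRequestType and cache id. A hedged sketch, with placeholder cacheId and byte values:

# Hedged sketch of the v1 request envelope; cacheId and payload bytes are placeholders.
from google.protobuf.any_pb2 import Any

from coherence.cache_service_messages_v1_pb2 import NamedCacheRequest, Put, PutRequest

put = PutRequest(key=b"\x01", value=b"\x02", ttl=0)
envelope = Any()
envelope.Pack(put)                                   # wrap the typed message in an Any
request = NamedCacheRequest(type=Put, cacheId=42, message=envelope)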
+ +class NamedCacheResponse(_message.Message): + __slots__ = ["cacheId", "message", "type"] + CACHEID_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + TYPE_FIELD_NUMBER: _ClassVar[int] + cacheId: int + message: _any_pb2.Any + type: ResponseType + def __init__(self, cacheId: _Optional[int] = ..., type: _Optional[_Union[ResponseType, str]] = ..., message: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...) -> None: ... + +class PutAllRequest(_message.Message): + __slots__ = ["entries", "ttl"] + ENTRIES_FIELD_NUMBER: _ClassVar[int] + TTL_FIELD_NUMBER: _ClassVar[int] + entries: _containers.RepeatedCompositeFieldContainer[_common_messages_v1_pb2.BinaryKeyAndValue] + ttl: int + def __init__(self, entries: _Optional[_Iterable[_Union[_common_messages_v1_pb2.BinaryKeyAndValue, _Mapping]]] = ..., ttl: _Optional[int] = ...) -> None: ... + +class PutRequest(_message.Message): + __slots__ = ["key", "ttl", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + TTL_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: bytes + ttl: int + value: bytes + def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., ttl: _Optional[int] = ...) -> None: ... + +class QueryRequest(_message.Message): + __slots__ = ["comparator", "filter"] + COMPARATOR_FIELD_NUMBER: _ClassVar[int] + FILTER_FIELD_NUMBER: _ClassVar[int] + comparator: bytes + filter: bytes + def __init__(self, filter: _Optional[bytes] = ..., comparator: _Optional[bytes] = ...) -> None: ... + +class ReplaceMappingRequest(_message.Message): + __slots__ = ["key", "newValue", "previousValue"] + KEY_FIELD_NUMBER: _ClassVar[int] + NEWVALUE_FIELD_NUMBER: _ClassVar[int] + PREVIOUSVALUE_FIELD_NUMBER: _ClassVar[int] + key: bytes + newValue: bytes + previousValue: bytes + def __init__(self, key: _Optional[bytes] = ..., previousValue: _Optional[bytes] = ..., newValue: _Optional[bytes] = ...) -> None: ... + +class NamedCacheRequestType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + +class ResponseType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] diff --git a/src/coherence/cache_service_messages_v1_pb2_grpc.py b/src/coherence/cache_service_messages_v1_pb2_grpc.py new file mode 100644 index 0000000..2daafff --- /dev/null +++ b/src/coherence/cache_service_messages_v1_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/coherence/client.py b/src/coherence/client.py index 14a061d..5517ec9 100644 --- a/src/coherence/client.py +++ b/src/coherence/client.py @@ -8,10 +8,13 @@ import asyncio import logging import os +import sys +import textwrap import time +import traceback import uuid -from asyncio import Condition, Task -from threading import Lock +from asyncio import Condition, Event, Task +from contextlib import asynccontextmanager from typing import ( Any, AsyncIterator, @@ -31,18 +34,46 @@ # noinspection PyPackageRequirements import grpc +from google.protobuf.json_format import MessageToJson +from grpc.aio import Channel, StreamStreamMultiCallable from pymitter import EventEmitter from .aggregator import AverageAggregator, EntryAggregator, PriorityAggregator, SumAggregator +from .cache_service_messages_v1_pb2 import MapEventMessage, NamedCacheResponse, ResponseType from .comparator import Comparator -from .event import MapLifecycleEvent, MapListener, SessionLifecycleEvent +from .entry import MapEntry +from .event import ( + MapEvent, + MapEventType, + MapLifecycleEvent, + MapListener, + SessionLifecycleEvent, + _ListenerGroup, + _MapEventsManagerV0, + _MapEventsManagerV1, +) from .extractor import ValueExtractor from .filter import Filter +from .local_cache import CacheStats, LocalCache, NearCacheOptions from .messages_pb2 import PageRequest from .processor import EntryProcessor +from .proxy_service_messages_v1_pb2 import ProxyRequest, ProxyResponse +from .proxy_service_v1_pb2_grpc import ProxyServiceStub from .serialization import Serializer, SerializerRegistry from .services_pb2_grpc import NamedCacheServiceStub -from .util import RequestFactory +from .util import ( + Dispatcher, + PagingDispatcher, + RequestFactory, + RequestFactoryV1, + ResponseObserver, + StreamingDispatcher, + UnaryDispatcher, + cur_time_millis, +) + +_SPECIFY_EXTRACTOR: Final[str] = "A ValueExtractor must be specified" +_SPECIFY_MAP_LISTENER: Final[str] = "A MapListener must be specified" E = TypeVar("E") K = TypeVar("K") @@ -53,6 +84,68 @@ COH_LOG = logging.getLogger("coherence") +@asynccontextmanager +async def request_timeout(seconds: float): # type: ignore + from . 
import _TIMEOUT_CONTEXT_VAR + + request_timeout = _TIMEOUT_CONTEXT_VAR.set(seconds) + try: + yield + finally: + _TIMEOUT_CONTEXT_VAR.reset(request_timeout) + + +# noinspection PyUnresolvedReferences,PyProtectedMember +class _Handshake: + def __init__(self, session: Session): + self._protocol_version: int = 0 + self._proxy_version: str = "unknown" + self._proxy_member_id: int = 0 + self._session = session + self._channel: Channel = session.channel + self._stream: Optional[StreamStreamMultiCallable] = None + + async def handshake(self) -> None: + stub: ProxyServiceStub = ProxyServiceStub(self._channel) + stream: StreamStreamMultiCallable = stub.subChannel() + try: + await stream.write(RequestFactoryV1.init_sub_channel()) + response = await asyncio.wait_for(stream.read(), self._session.options.request_timeout_seconds) + stream.cancel() # cancel the stream; no longer needed + self._proxy_version = response.init.version + self._protocol_version = response.init.protocolVersion + self._proxy_member_id = response.init.proxyMemberId + except grpc.aio._call.AioRpcError as e: + error_code: int = e.code().value[0] + if ( + # Check for StatusCode INTERNAL as work around for + # grpc issue https://github.com/grpc/grpc/issues/36066 + error_code == grpc.StatusCode.UNIMPLEMENTED.value[0] + or error_code == grpc.StatusCode.INTERNAL.value[0] + ): + return + else: + raise RuntimeError( + f"Unexpected error, {e}, when attempting to handshake with proxy: {e.details()}" + ) from e + except asyncio.TimeoutError as e: + raise RuntimeError("Handshake with proxy timed out") from e + finally: + stream.cancel() + + @property + def protocol_version(self) -> int: + return self._protocol_version + + @property + def proxy_version(self) -> str: + return self._proxy_version + + @property + def proxy_member_id(self) -> int: + return self._proxy_member_id + + @no_type_check def _pre_call_cache(func): def inner(self, *args, **kwargs): @@ -96,17 +189,67 @@ async def inner_async(self, *args, **kwargs): return inner -class MapEntry(Generic[K, V]): - """ - A map entry (key-value pair). - """ +class CacheOptions: + def __init__(self, default_expiry: int = 0, near_cache_options: Optional[NearCacheOptions] = None): + """ + Constructs a new CacheOptions which may be used in configuring the behavior of a NamedMap or NamedCache. + + :param default_expiry: the default expiration time, in millis, that will be applied to entries + inserted into a NamedMap or NamedCache + :param near_cache_options: the near caching configuration this NamedMap or NamedCache should use + """ + super().__init__() + self._default_expiry: int = default_expiry if default_expiry >= 0 else -1 + self._near_cache_options = near_cache_options + + def __str__(self) -> str: + """ + :return: the string representation of this CacheOptions instance. + """ + result: str = f"CacheOptions(default-expiry={self._default_expiry}" + result += ")" if self.near_cache_options is None else f", near-cache-options={self.near_cache_options})" + return result - def __init__(self, key: K, value: V): - self.key = key - self.value = value + def __eq__(self, other: Any) -> bool: + """ + Compare two CacheOptions for equality. 
+ + :param other: the CacheOptions to compare against + :return: True if equal otherwise False + """ + if self is other: + return True + + if isinstance(other, CacheOptions): + return ( + self.default_expiry == other.default_expiry and True + if self.near_cache_options is None + else self.near_cache_options == other.near_cache_options + ) + + return False + + @property + def default_expiry(self) -> int: + """ + The configured default entry time-to-live. + + :return: the default entry ttl + """ + return self._default_expiry + + @property + def near_cache_options(self) -> Optional[NearCacheOptions]: + """ + The configured NearCacheOptions. + + :return: the configured NearCacheOptions, if any + """ + return self._near_cache_options class NamedMap(abc.ABC, Generic[K, V]): + # noinspection PyUnresolvedReferences """ A Map-based data-structure that manages entries across one or more processes. Entries are typically managed in memory, and are often comprised of data that is also stored persistently, on disk. @@ -118,7 +261,24 @@ class NamedMap(abc.ABC, Generic[K, V]): @property @abc.abstractmethod def name(self) -> str: - """documentation""" + """Returns the logical name of this NamedMap""" + + @property + @abc.abstractmethod + def session(self) -> Session: + """Returns the Session associated with this NamedMap""" + + @property + @abc.abstractmethod + def options(self) -> Optional[CacheOptions]: + """Returns the CacheOptions associated with this NamedMap""" + + @property + def near_cache_stats(self) -> Optional[CacheStats]: + """ + Returns the CacheStats of the near cache, if one has been configured. + """ + return None @abc.abstractmethod def on(self, event: MapLifecycleEvent, callback: Callable[[str], None]) -> None: @@ -206,7 +366,7 @@ async def get_or_default(self, key: K, default_value: Optional[V] = None) -> Opt """ @abc.abstractmethod - def get_all(self, keys: set[K]) -> AsyncIterator[MapEntry[K, V]]: + async def get_all(self, keys: set[K]) -> AsyncIterator[MapEntry[K, V]]: """ Get all the specified keys if they are in the map. For each key that is in the map, that key and its corresponding value will be placed in the map that is returned by @@ -219,7 +379,7 @@ def get_all(self, keys: set[K]) -> AsyncIterator[MapEntry[K, V]]: """ @abc.abstractmethod - async def put(self, key: K, value: V) -> V: + async def put(self, key: K, value: V) -> Optional[V]: """ Associates the specified value with the specified key in this map. If the map previously contained a mapping for this key, the old value is replaced. @@ -233,7 +393,7 @@ async def put(self, key: K, value: V) -> V: """ @abc.abstractmethod - async def put_if_absent(self, key: K, value: V) -> V: + async def put_if_absent(self, key: K, value: V) -> Optional[V]: """ If the specified key is not already associated with a value (or is mapped to `None`) associates it with the given value and returns `None`, else returns the current value. @@ -247,11 +407,12 @@ async def put_if_absent(self, key: K, value: V) -> V: """ @abc.abstractmethod - async def put_all(self, map: dict[K, V]) -> None: + async def put_all(self, map: dict[K, V], ttl: Optional[int] = 0) -> None: """ Copies all mappings from the specified map to this map :param map: the map to copy from + :param ttl: the time to live for the map entries """ @abc.abstractmethod @@ -273,7 +434,7 @@ async def destroy(self) -> None: """ @abc.abstractmethod - def release(self) -> None: + async def release(self) -> None: """ Release local resources associated with instance. 
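A hedged usage sketch for the async API changes in this hunk: release() (like destroy()) is now a coroutine and must be awaited, and the new request_timeout() context manager scopes a per-call deadline via the module-level ContextVar. Session.create(), get_cache(), and the cache name are assumptions borrowed from the project's examples.

# Hedged sketch; session setup and cache name are assumptions for illustration.
import asyncio

from coherence import NamedCache, Session, request_timeout


async def main() -> None:
    session: Session = await Session.create()
    cache: NamedCache[str, str] = await session.get_cache("example")  # assumed cache name
    async with request_timeout(5.0):       # apply a 5-second deadline to calls in this block
        await cache.put("key-1", "value-1")
        print(await cache.get("key-1"))
    await cache.release()                  # now a coroutine; must be awaited
    await session.close()


asyncio.run(main())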
@@ -288,7 +449,7 @@ async def truncate(self) -> None: """ @abc.abstractmethod - async def remove(self, key: K) -> V: + async def remove(self, key: K) -> Optional[V]: """ Removes the mapping for a key from this map if it is present. @@ -307,7 +468,7 @@ async def remove_mapping(self, key: K, value: V) -> bool: """ @abc.abstractmethod - async def replace(self, key: K, value: V) -> V: + async def replace(self, key: K, value: V) -> Optional[V]: """ Replaces the entry for the specified key only if currently mapped to the specified value. @@ -364,7 +525,7 @@ async def size(self) -> int: """ @abc.abstractmethod - async def invoke(self, key: K, processor: EntryProcessor[R]) -> R: + async def invoke(self, key: K, processor: EntryProcessor[R]) -> Optional[R]: """ Invoke the passed EntryProcessor against the Entry specified by the passed key, returning the result of the invocation. @@ -375,7 +536,7 @@ async def invoke(self, key: K, processor: EntryProcessor[R]) -> R: """ @abc.abstractmethod - def invoke_all( + async def invoke_all( self, processor: EntryProcessor[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None ) -> AsyncIterator[MapEntry[K, R]]: """ @@ -401,7 +562,7 @@ def invoke_all( @abc.abstractmethod async def aggregate( self, aggregator: EntryAggregator[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None - ) -> R: + ) -> Optional[R]: """ Perform an aggregating operation against the entries specified by the passed keys. @@ -412,13 +573,13 @@ async def aggregate( """ @abc.abstractmethod - def values( + async def values( self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None, by_page: bool = False ) -> AsyncIterator[V]: """ Return a Set of the values contained in this map that satisfy the criteria expressed by the filter. If no filter or comparator is specified, it returns a Set view of the values contained in this map.The - collection is backed by the map, so changes to the map are reflected in the collection, and vice-versa. If + collection is backed by the map, so changes to the map are reflected in the collection, and vice versa. If the map is modified while an iteration over the collection is in progress (except through the iterator's own `remove` operation), the results of the iteration are undefined. @@ -431,7 +592,7 @@ def values( """ @abc.abstractmethod - def keys(self, filter: Optional[Filter] = None, by_page: bool = False) -> AsyncIterator[K]: + async def keys(self, filter: Optional[Filter] = None, by_page: bool = False) -> AsyncIterator[K]: """ Return a set view of the keys contained in this map for entries that satisfy the criteria expressed by the filter. @@ -443,7 +604,7 @@ def keys(self, filter: Optional[Filter] = None, by_page: bool = False) -> AsyncI """ @abc.abstractmethod - def entries( + async def entries( self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None, by_page: bool = False ) -> AsyncIterator[MapEntry[K, V]]: """ @@ -488,6 +649,7 @@ def remove_index(self, extractor: ValueExtractor[T, E]) -> None: class NamedCache(NamedMap[K, V]): + # noinspection PyUnresolvedReferences """ A Map-based data-structure that manages entries across one or more processes. 
Entries are typically managed in memory, and are often comprised of data that is also stored in an external system, for example, a database, @@ -499,14 +661,15 @@ class NamedCache(NamedMap[K, V]): """ @abc.abstractmethod - async def put(self, key: K, value: V, ttl: int = -1) -> V: + async def put(self, key: K, value: V, ttl: Optional[int] = None) -> Optional[V]: """ Associates the specified value with the specified key in this map. If the map previously contained a mapping for this key, the old value is replaced. :param key: the key with which the specified value is to be associated :param value: the value to be associated with the specified key - :param ttl: the expiry time in millis (optional) + :param ttl: the expiry time in millis (optional). If not specific, it will default to the default + ttl defined in the cache options provided when the cache was obtained :return: resolving to the previous value associated with specified key, or `None` if there was no mapping for key. A `None` return can also indicate that the map previously associated `None` with the specified key if the implementation supports `None` values @@ -514,14 +677,15 @@ async def put(self, key: K, value: V, ttl: int = -1) -> V: """ @abc.abstractmethod - async def put_if_absent(self, key: K, value: V, ttl: int = -1) -> V: + async def put_if_absent(self, key: K, value: V, ttl: Optional[int] = None) -> Optional[V]: """ If the specified key is not already associated with a value (or is mapped to null) associates it with the given value and returns `None`, else returns the current value. :param key: the key with which the specified value is to be associated :param value: the value to be associated with the specified key - :param ttl: the expiry time in millis (optional) + :param ttl: the expiry time in millis (optional). If not specific, it will default to the default + ttl defined in the cache options provided when the cache was obtained. :return: resolving to the previous value associated with specified key, or `None` if there was no mapping for key. 
A `None` return can also indicate that the map previously associated `None` with the specified key if the implementation supports `None` values @@ -530,7 +694,9 @@ async def put_if_absent(self, key: K, value: V, ttl: int = -1) -> V: class NamedCacheClient(NamedCache[K, V]): - def __init__(self, cache_name: str, session: Session, serializer: Serializer): + def __init__( + self, cache_name: str, session: Session, serializer: Serializer, cache_options: Optional[CacheOptions] = None + ) -> None: self._cache_name: str = cache_name self._serializer: Serializer = serializer self._client_stub: NamedCacheServiceStub = NamedCacheServiceStub(session.channel) @@ -540,11 +706,12 @@ def __init__(self, cache_name: str, session: Session, serializer: Serializer): self._destroyed: bool = False self._released: bool = False self._session: Session = session - from .event import _MapEventsManager + self._cache_options: Optional[CacheOptions] = cache_options + self._default_expiry: int = cache_options.default_expiry if cache_options is not None else 0 self._setup_event_handlers() - self._events_manager: _MapEventsManager[K, V] = _MapEventsManager( + self._events_manager: _MapEventsManagerV0[K, V] = _MapEventsManagerV0( self, session, self._client_stub, serializer, self._internal_emitter ) @@ -552,6 +719,10 @@ def __init__(self, cache_name: str, session: Session, serializer: Serializer): def name(self) -> str: return self._cache_name + @property + def session(self) -> Session: + return self._session + @property def destroyed(self) -> bool: return self._destroyed @@ -560,6 +731,10 @@ def destroyed(self) -> bool: def released(self) -> bool: return self._released + @property + def options(self) -> Optional[CacheOptions]: + return self._cache_options + @_pre_call_cache def on(self, event: MapLifecycleEvent, callback: Callable[[str], None]) -> None: self._emitter.on(str(event.value), callback) @@ -569,7 +744,7 @@ async def get(self, key: K) -> Optional[V]: g = self._request_factory.get_request(key) v = await self._client_stub.get(g) if v.present: - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) else: return None @@ -582,23 +757,23 @@ async def get_or_default(self, key: K, default_value: Optional[V] = None) -> Opt return default_value @_pre_call_cache - def get_all(self, keys: set[K]) -> AsyncIterator[MapEntry[K, V]]: + async def get_all(self, keys: set[K]) -> AsyncIterator[MapEntry[K, V]]: r = self._request_factory.get_all_request(keys) stream = self._client_stub.getAll(r) - return _Stream(self._request_factory.get_serializer(), stream, _entry_producer) + return _Stream(self._request_factory.serializer, stream, _entry_producer) @_pre_call_cache - async def put(self, key: K, value: V, ttl: int = -1) -> V: - p = self._request_factory.put_request(key, value, ttl) + async def put(self, key: K, value: V, ttl: Optional[int] = None) -> Optional[V]: + p = self._request_factory.put_request(key, value, ttl if ttl is not None else self._default_expiry) v = await self._client_stub.put(p) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache - async def put_if_absent(self, key: K, value: V, ttl: int = -1) -> V: - p = self._request_factory.put_if_absent_request(key, value, ttl) + async def put_if_absent(self, key: K, value: V, ttl: Optional[int] = None) -> Optional[V]: + p = self._request_factory.put_if_absent_request(key, value, ttl if ttl is not None else 
self._default_expiry) v = await self._client_stub.putIfAbsent(p) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache async def put_all(self, map: dict[K, V]) -> None: @@ -611,16 +786,16 @@ async def clear(self) -> None: await self._client_stub.clear(r) async def destroy(self) -> None: - self.release() + await self.release() self._internal_emitter.once(MapLifecycleEvent.DESTROYED.value) self._internal_emitter.emit(MapLifecycleEvent.DESTROYED.value, self.name) r = self._request_factory.destroy_request() await self._client_stub.destroy(r) - @_pre_call_cache - def release(self) -> None: - self._internal_emitter.once(MapLifecycleEvent.RELEASED.value) - self._internal_emitter.emit(MapLifecycleEvent.RELEASED.value, self.name) + async def release(self) -> None: + if self.active: + self._internal_emitter.once(MapLifecycleEvent.RELEASED.value) + self._internal_emitter.emit(MapLifecycleEvent.RELEASED.value, self.name) @_pre_call_cache async def truncate(self) -> None: @@ -629,81 +804,82 @@ async def truncate(self) -> None: await self._client_stub.truncate(r) @_pre_call_cache - async def remove(self, key: K) -> V: + async def remove(self, key: K) -> Optional[V]: r = self._request_factory.remove_request(key) v = await self._client_stub.remove(r) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache async def remove_mapping(self, key: K, value: V) -> bool: r = self._request_factory.remove_mapping_request(key, value) v = await self._client_stub.removeMapping(r) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache - async def replace(self, key: K, value: V) -> V: + async def replace(self, key: K, value: V) -> Optional[V]: r = self._request_factory.replace_request(key, value) v = await self._client_stub.replace(r) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache async def replace_mapping(self, key: K, old_value: V, new_value: V) -> bool: r = self._request_factory.replace_mapping_request(key, old_value, new_value) v = await self._client_stub.replaceMapping(r) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache async def contains_key(self, key: K) -> bool: r = self._request_factory.contains_key_request(key) v = await self._client_stub.containsKey(r) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache async def contains_value(self, value: V) -> bool: r = self._request_factory.contains_value_request(value) v = await self._client_stub.containsValue(r) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache async def is_empty(self) -> bool: r = self._request_factory.is_empty_request() v = await self._client_stub.isEmpty(r) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache async def size(self) -> int: r = self._request_factory.size_request() v = await self._client_stub.size(r) - return self._request_factory.get_serializer().deserialize(v.value) + return 
self._request_factory.serializer.deserialize(v.value) @_pre_call_cache - async def invoke(self, key: K, processor: EntryProcessor[R]) -> R: + async def invoke(self, key: K, processor: EntryProcessor[R]) -> Optional[R]: r = self._request_factory.invoke_request(key, processor) v = await self._client_stub.invoke(r) - return self._request_factory.get_serializer().deserialize(v.value) + return self._request_factory.serializer.deserialize(v.value) @_pre_call_cache - def invoke_all( + async def invoke_all( self, processor: EntryProcessor[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None ) -> AsyncIterator[MapEntry[K, R]]: r = self._request_factory.invoke_all_request(processor, keys, filter) stream = self._client_stub.invokeAll(r) - return _Stream(self._request_factory.get_serializer(), stream, _entry_producer) + return _Stream(self._request_factory.serializer, stream, _entry_producer) @_pre_call_cache async def aggregate( self, aggregator: EntryAggregator[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None - ) -> R: + ) -> Optional[R]: r = self._request_factory.aggregate_request(aggregator, keys, filter) results = await self._client_stub.aggregate(r) - value: Any = self._request_factory.get_serializer().deserialize(results.value) + value: Any = self._request_factory.serializer.deserialize(results.value) # for compatibility with 22.06 if isinstance(aggregator, SumAggregator) and isinstance(value, str): return cast(R, float(value)) elif isinstance(aggregator, AverageAggregator) and isinstance(value, str): return cast(R, float(value)) elif isinstance(aggregator, PriorityAggregator): + # noinspection PyTypeChecker,PyUnresolvedReferences pri_agg: PriorityAggregator[R] = aggregator if ( isinstance(pri_agg.aggregator, AverageAggregator) or isinstance(pri_agg.aggregator, SumAggregator) @@ -714,7 +890,7 @@ async def aggregate( return cast(R, value) @_pre_call_cache - def values( + async def values( self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None, by_page: bool = False ) -> AsyncIterator[V]: if by_page and comparator is None and filter is None: @@ -723,20 +899,20 @@ def values( r = self._request_factory.values_request(filter) stream = self._client_stub.values(r) - return _Stream(self._request_factory.get_serializer(), stream, _scalar_producer) + return _Stream(self._request_factory.serializer, stream, _scalar_producer) @_pre_call_cache - def keys(self, filter: Optional[Filter] = None, by_page: bool = False) -> AsyncIterator[K]: + async def keys(self, filter: Optional[Filter] = None, by_page: bool = False) -> AsyncIterator[K]: if by_page and filter is None: return _PagedStream(self, _scalar_deserializer, True) else: r = self._request_factory.keys_request(filter) stream = self._client_stub.keySet(r) - return _Stream(self._request_factory.get_serializer(), stream, _scalar_producer) + return _Stream(self._request_factory.serializer, stream, _scalar_producer) @_pre_call_cache - def entries( + async def entries( self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None, by_page: bool = False ) -> AsyncIterator[MapEntry[K, V]]: if by_page and comparator is None and filter is None: @@ -745,7 +921,7 @@ def entries( r = self._request_factory.entries_request(filter, comparator) stream = self._client_stub.entrySet(r) - return _Stream(self._request_factory.get_serializer(), stream, _entry_producer) + return _Stream(self._request_factory.serializer, stream, _entry_producer) from .event import MapListener @@ -755,7 +931,7 @@ async def 
add_map_listener( self, listener: MapListener[K, V], listener_for: Optional[K | Filter] = None, lite: bool = False ) -> None: if listener is None: - raise ValueError("A MapListener must be specified") + raise ValueError(_SPECIFY_MAP_LISTENER) if listener_for is None or isinstance(listener_for, Filter): await self._events_manager._register_filter_listener(listener, listener_for, lite) @@ -766,7 +942,7 @@ async def add_map_listener( @_pre_call_cache async def remove_map_listener(self, listener: MapListener[K, V], listener_for: Optional[K | Filter] = None) -> None: if listener is None: - raise ValueError("A MapListener must be specified") + raise ValueError(_SPECIFY_MAP_LISTENER) if listener_for is None or isinstance(listener_for, Filter): await self._events_manager._remove_filter_listener(listener, listener_for) @@ -778,14 +954,14 @@ async def add_index( self, extractor: ValueExtractor[T, E], ordered: bool = False, comparator: Optional[Comparator] = None ) -> None: if extractor is None: - raise ValueError("A ValueExtractor must be specified") + raise ValueError(_SPECIFY_EXTRACTOR) r = self._request_factory.add_index_request(extractor, ordered, comparator) await self._client_stub.addIndex(r) @_pre_call_cache async def remove_index(self, extractor: ValueExtractor[T, E]) -> None: if extractor is None: - raise ValueError("A ValueExtractor must be specified") + raise ValueError(_SPECIFY_EXTRACTOR) r = self._request_factory.remove_index_request(extractor) await self._client_stub.removeIndex(r) @@ -827,6 +1003,451 @@ def __str__(self) -> str: ) +class NamedCacheClientV1(NamedCache[K, V]): + + def __init__( + self, cache_name: str, session: Session, serializer: Serializer, cache_options: Optional[CacheOptions] = None + ): + self._cache_name: str = cache_name + self._cache_id: int = 0 + self._serializer: Serializer = serializer + self._request_factory: RequestFactoryV1 = RequestFactoryV1( + cache_name, self._cache_id, session.scope, serializer, lambda: session.options.request_timeout_seconds + ) + self._emitter: EventEmitter = EventEmitter() + self._internal_emitter: EventEmitter = EventEmitter() + self._destroyed: bool = False + self._released: bool = False + self._session: Session = session + self._cache_options: Optional[CacheOptions] = cache_options + self._default_expiry: int = cache_options.default_expiry if cache_options is not None else 0 + self._near_cache: Optional[LocalCache[K, V]] = None + self._near_cache_listener: Optional[MapListener[K, V]] = None + self._near_cache_lock: asyncio.Lock = asyncio.Lock() + + self._events_manager: _MapEventsManagerV1[K, V] = _MapEventsManagerV1( + self, session, serializer, self._internal_emitter, self._request_factory + ) + + self._stream_handler: StreamHandler = StreamHandler(session, self._request_factory, self._events_manager) + self._setup_event_handlers() + + near_options: Optional[NearCacheOptions] = None if cache_options is None else cache_options.near_cache_options + if near_options is not None: + self._near_cache = LocalCache(cache_name, near_options) + + async def _post_create(self) -> None: + near: Optional[LocalCache[K, V]] = self._near_cache + if near is not None: + # setup event listener + async def callback(event: MapEvent[K, V]) -> None: + if event.type == MapEventType.ENTRY_INSERTED or event.type == MapEventType.ENTRY_UPDATED: + if await near.contains_key(event.key): + val: Optional[V] = event.new + if val is not None: + await near.put(event.key, val) + elif event.type == MapEventType.ENTRY_DELETED: + # processing a remove + await 
near.remove(event.key) + + self._near_cache_listener = MapListener(synchronous=True).on_any(callback) # type: ignore + await self.add_map_listener(self._near_cache_listener) + + def _setup_event_handlers(self) -> None: + """ + Setup handlers to notify cache-level handlers of events. + """ + emitter: EventEmitter = self._emitter + internal_emitter: EventEmitter = self._internal_emitter + this: NamedCacheClientV1[K, V] = self + cache_name = self._cache_name + + # noinspection PyProtectedMember + def on_destroyed(name: str) -> None: + if name == cache_name and not this.destroyed: + this._events_manager._close() + this._destroyed = True + this._released = True + emitter.emit(MapLifecycleEvent.DESTROYED.value, name) + + # noinspection PyProtectedMember + def on_released(name: str) -> None: + if name == cache_name and not this.released: + this._events_manager._close() + this._released = True + emitter.emit(MapLifecycleEvent.RELEASED.value, name) + + def on_truncated(name: str) -> None: + if name == cache_name: + emitter.emit(MapLifecycleEvent.TRUNCATED.value, name) + + internal_emitter.on(MapLifecycleEvent.DESTROYED.value, on_destroyed) + internal_emitter.on(MapLifecycleEvent.RELEASED.value, on_released) + internal_emitter.on(MapLifecycleEvent.TRUNCATED.value, on_truncated) + + near: Optional[LocalCache[K, V]] = this._near_cache + if near is not None: + # setup lifecycle callbacks to clear the near cache + async def do_clear() -> None: + await near.clear() + + self.on(MapLifecycleEvent.TRUNCATED, do_clear) # type: ignore + self.on(MapLifecycleEvent.DESTROYED, do_clear) # type: ignore + + @property + def cache_id(self) -> int: + return self._cache_id + + @cache_id.setter + def cache_id(self, cache_id: int) -> None: + self._cache_id = cache_id + + @property + def name(self) -> str: + return self._cache_name + + @property + def session(self) -> Session: + return self._session + + @property + def options(self) -> Optional[CacheOptions]: + return self._cache_options + + @property + def near_cache_stats(self) -> Optional[CacheStats]: + near_cache: Optional[LocalCache[K, V]] = self._near_cache + return None if near_cache is None else near_cache.stats + + def on(self, event: MapLifecycleEvent, callback: Callable[[str], None]) -> None: + self._emitter.on(str(event.value), callback) + + @property + def destroyed(self) -> bool: + return self._destroyed + + @property + def released(self) -> bool: + return self._released + + async def _ensure_cache(self) -> None: + dispatcher: UnaryDispatcher[int] = self._request_factory.ensure_request(self._cache_name) + await dispatcher.dispatch(self._stream_handler) + + self.cache_id = dispatcher.result() + self._request_factory.cache_id = self.cache_id + + @_pre_call_cache + async def get(self, key: K) -> Optional[V]: + near_cache: Optional[LocalCache[K, V]] = self._near_cache + + # check the near cache first + if near_cache is not None: + async with self._near_cache_lock: + result: Optional[V] = await near_cache.get(key) + if result is not None: + return result + + start: int = cur_time_millis() + dispatcher: UnaryDispatcher[Optional[V]] = self._request_factory.get_request(key) + await dispatcher.dispatch(self._stream_handler) + result = dispatcher.result() + + if result is not None: + await near_cache.put(key, result) + # noinspection PyProtectedMember + near_cache.stats._register_misses_millis(cur_time_millis() - start) + else: + dispatcher = self._request_factory.get_request(key) + await dispatcher.dispatch(self._stream_handler) + result = dispatcher.result() + + 
return result + + @_pre_call_cache + async def put(self, key: K, value: V, ttl: Optional[int] = None) -> Optional[V]: + dispatcher: UnaryDispatcher[Optional[V]] = self._request_factory.put_request( + key, value, ttl if ttl is not None else self._default_expiry + ) + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def put_if_absent(self, key: K, value: V, ttl: Optional[int] = None) -> Optional[V]: + dispatcher: UnaryDispatcher[Optional[V]] = self._request_factory.put_if_absent_request( + key, value, ttl if ttl is not None else self._default_expiry + ) + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + # noinspection PyProtectedMember + @_pre_call_cache + async def add_map_listener( + self, listener: MapListener[K, V], listener_for: Optional[K | Filter] = None, lite: bool = False + ) -> None: + if listener is None: + raise ValueError(_SPECIFY_MAP_LISTENER) + + if listener_for is None or isinstance(listener_for, Filter): + await self._events_manager._register_filter_listener(listener, listener_for, lite) + else: + await self._events_manager._register_key_listener(listener, listener_for, lite) + + # noinspection PyProtectedMember + @_pre_call_cache + async def remove_map_listener(self, listener: MapListener[K, V], listener_for: Optional[K | Filter] = None) -> None: + if listener is None: + raise ValueError(_SPECIFY_MAP_LISTENER) + + if listener_for is None or isinstance(listener_for, Filter): + await self._events_manager._remove_filter_listener(listener, listener_for) + else: + await self._events_manager._remove_key_listener(listener, listener_for) + + @_pre_call_cache + async def get_or_default(self, key: K, default_value: Optional[V] = None) -> Optional[V]: + v: Optional[V] = await self.get(key) + if v is not None: + return v + else: + return default_value + + # noinspection PyProtectedMember + @_pre_call_cache + async def get_all(self, keys: set[K]) -> AsyncIterator[MapEntry[K, V]]: + near_cache: Optional[LocalCache[K, V]] = self._near_cache + result: dict[K, V] + + # check the near cache first + if near_cache is not None: + async with self._near_cache_lock: + result = await near_cache.get_all(keys) + if result is not None: + if len(result) == len(keys): + # all keys were found, return an AsyncIterator + # over those results + async def async_iter() -> AsyncIterator[MapEntry[K, V]]: + for key, value in result.items(): + yield MapEntry(key, value) + + return async_iter() + else: + # some keys are present within the near cache; make + # a remote call to obtain the keys that are missing. + stats: CacheStats = near_cache.stats + remote_keys: set[K] = keys.difference(result) + start: int = cur_time_millis() + dispatcher: StreamingDispatcher[MapEntry[K, V]] = self._request_factory.get_all_request( + remote_keys + ) + await dispatcher.dispatch(self._stream_handler) + stats._register_misses_millis(cur_time_millis() - start) + + # we could return a composite AsyncIterator that would + # yield the local keys followed by the results from the + # remote call, but doing could result in additional + # remote calls if there are concurrent get_all() calls + # that start the same time. Instead, populate the + # near cache while locked and then return results + # This is not the most memory efficient, but it makes + # the stats more likely to make sense to the user. 
+ remote_entries: list[MapEntry[K, V]] = [] + async for entry in dispatcher: + await near_cache.put(entry.key, entry.value) + remote_entries.append(entry) + + stats._register_misses_millis(cur_time_millis() - start) + + # noinspection PyProtectedMember + async def async_iter() -> AsyncIterator[MapEntry[K, V]]: + for key, value in result.items(): + yield MapEntry(key, value) + for remote_entry in remote_entries: + yield remote_entry + + return async_iter() + else: + dispatcher = self._request_factory.get_all_request(keys) + await dispatcher.dispatch(self._stream_handler) + return dispatcher + else: + dispatcher = self._request_factory.get_all_request(keys) + await dispatcher.dispatch(self._stream_handler) + return dispatcher + + @_pre_call_cache + async def put_all(self, kv_map: dict[K, V], ttl: Optional[int] = 0) -> None: + dispatcher: Dispatcher = self._request_factory.put_all_request(kv_map, ttl) + await dispatcher.dispatch(self._stream_handler) + + @_pre_call_cache + async def clear(self) -> None: + dispatcher: Dispatcher = self._request_factory.clear_request() + await dispatcher.dispatch(self._stream_handler) + if self._near_cache is not None: + await self._near_cache.clear() + + async def destroy(self) -> None: + self._internal_emitter.once(MapLifecycleEvent.DESTROYED.value) + self._internal_emitter.emit(MapLifecycleEvent.DESTROYED.value, self.name) + dispatcher: Dispatcher = self._request_factory.destroy_request() + await dispatcher.dispatch(self._stream_handler) + + async def release(self) -> None: + if self.active: + await self._stream_handler.close() + self._internal_emitter.once(MapLifecycleEvent.RELEASED.value) + self._internal_emitter.emit(MapLifecycleEvent.RELEASED.value, self.name) + + if self._near_cache is not None: + await self._near_cache.release() + + @_pre_call_cache + async def truncate(self) -> None: + dispatcher: Dispatcher = self._request_factory.truncate_request() + await dispatcher.dispatch(self._stream_handler) + + # clear the near cache as the lifecycle listeners are not synchronous + if self._near_cache is not None: + await self._near_cache.clear() + + @_pre_call_cache + async def remove(self, key: K) -> Optional[V]: + dispatcher: UnaryDispatcher[Optional[V]] = self._request_factory.remove_request(key) + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def remove_mapping(self, key: K, value: V) -> bool: + dispatcher: UnaryDispatcher[bool] = self._request_factory.remove_mapping_request(key, value) + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def replace(self, key: K, value: V) -> Optional[V]: + dispatcher: UnaryDispatcher[Optional[V]] = self._request_factory.replace_request(key, value) + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def replace_mapping(self, key: K, old_value: V, new_value: V) -> bool: + dispatcher: UnaryDispatcher[bool] = self._request_factory.replace_mapping_request(key, old_value, new_value) + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def contains_key(self, key: K) -> bool: + near_cache: Optional[LocalCache[K, V]] = self._near_cache + + # check the near cache first + if near_cache is not None: + result: Optional[V] = await near_cache.get(key) + if result is not None: + return True + + dispatcher: UnaryDispatcher[bool] = self._request_factory.contains_key_request(key) + await 
dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def contains_value(self, value: V) -> bool: + dispatcher: UnaryDispatcher[bool] = self._request_factory.contains_value_request(value) + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def is_empty(self) -> bool: + dispatcher: UnaryDispatcher[bool] = self._request_factory.is_empty_request() + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def size(self) -> int: + dispatcher: UnaryDispatcher[int] = self._request_factory.size_request() + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def invoke(self, key: K, processor: EntryProcessor[R]) -> Optional[R]: + dispatcher: UnaryDispatcher[Optional[R]] = self._request_factory.invoke_request(key, processor) + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def invoke_all( + self, processor: EntryProcessor[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None + ) -> AsyncIterator[MapEntry[K, R]]: + dispatcher: StreamingDispatcher[MapEntry[K, R]] = self._request_factory.invoke_all_request( + processor, keys, filter + ) + await dispatcher.dispatch(self._stream_handler) + return dispatcher + + @_pre_call_cache + async def aggregate( + self, aggregator: EntryAggregator[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None + ) -> Optional[R]: + dispatcher: UnaryDispatcher[Optional[R]] = self._request_factory.aggregate_request(aggregator, keys, filter) + await dispatcher.dispatch(self._stream_handler) + return dispatcher.result() + + @_pre_call_cache + async def values( + self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None, by_page: bool = False + ) -> AsyncIterator[V]: + if by_page and comparator is None and filter is None: + page_dispatcher: PagingDispatcher[V] = self._request_factory.page_request(values_only=True) + await page_dispatcher.dispatch(self._stream_handler) + return page_dispatcher + else: + dispatcher: StreamingDispatcher[V] = self._request_factory.values_request(filter, comparator) + await dispatcher.dispatch(self._stream_handler) + return dispatcher + + # gTODO + @_pre_call_cache + async def keys(self, filter: Optional[Filter] = None, by_page: bool = False) -> AsyncIterator[K]: + if by_page and filter is None: + page_dispatcher: PagingDispatcher[K] = self._request_factory.page_request(keys_only=True) + await page_dispatcher.dispatch(self._stream_handler) + return page_dispatcher + else: + dispatcher: StreamingDispatcher[K] = self._request_factory.keys_request(filter) + await dispatcher.dispatch(self._stream_handler) + return dispatcher + + @_pre_call_cache + async def entries( + self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None, by_page: bool = False + ) -> AsyncIterator[MapEntry[K, V]]: + if by_page and comparator is None and filter is None: + page_dispatcher: PagingDispatcher[MapEntry[K, V]] = self._request_factory.page_request() + await page_dispatcher.dispatch(self._stream_handler) + return page_dispatcher + else: + dispatcher: StreamingDispatcher[MapEntry[K, V]] = self._request_factory.entries_request(filter, comparator) + await dispatcher.dispatch(self._stream_handler) + return dispatcher + + @_pre_call_cache + async def add_index( + self, extractor: ValueExtractor[T, E], ordered: bool = False, comparator: Optional[Comparator] 
= None + ) -> None: + if extractor is None: + raise ValueError(_SPECIFY_EXTRACTOR) + + dispatcher: Dispatcher = self._request_factory.add_index_request(extractor, ordered, comparator) + await dispatcher.dispatch(self._stream_handler) + + @_pre_call_cache + async def remove_index(self, extractor: ValueExtractor[T, E]) -> None: + if extractor is None: + raise ValueError(_SPECIFY_EXTRACTOR) + + dispatcher: Dispatcher = self._request_factory.remove_index_request(extractor) + await dispatcher.dispatch(self._stream_handler) + + class TlsOptions: """ Options specific to the configuration of TLS. @@ -868,11 +1489,9 @@ def __init__( self._locked = locked self._enabled = enabled - self._ca_cert_path = ca_cert_path if ca_cert_path is not None else os.getenv(TlsOptions.ENV_CA_CERT) - self._client_cert_path = ( - client_cert_path if client_cert_path is not None else os.getenv(TlsOptions.ENV_CLIENT_CERT) - ) - self._client_key_path = client_key_path if client_key_path is not None else os.getenv(TlsOptions.ENV_CLIENT_KEY) + self._ca_cert_path = os.getenv(TlsOptions.ENV_CA_CERT, ca_cert_path) + self._client_cert_path = os.getenv(TlsOptions.ENV_CLIENT_CERT, client_cert_path) + self._client_key_path = os.getenv(TlsOptions.ENV_CLIENT_KEY, client_key_path) @property def enabled(self) -> bool: @@ -1029,18 +1648,14 @@ def __init__( https://github.com/grpc/grpc/blob/master/include/grpc/impl/grpc_types.h :param tls_options: Optional TLS configuration. """ - addr = os.getenv(Options.ENV_SERVER_ADDRESS) - if addr is not None: - self._address = addr - else: - self._address = address + self._address = os.getenv(Options.ENV_SERVER_ADDRESS, address) self._request_timeout_seconds = Options._get_float_from_env( Options.ENV_REQUEST_TIMEOUT, request_timeout_seconds ) self._ready_timeout_seconds = Options._get_float_from_env(Options.ENV_READY_TIMEOUT, ready_timeout_seconds) self._session_disconnect_timeout_seconds = Options._get_float_from_env( - Options.ENV_READY_TIMEOUT, session_disconnect_seconds + Options.ENV_SESSION_DISCONNECT_TIMEOUT, session_disconnect_seconds ) self._scope = scope @@ -1222,9 +1837,6 @@ class Session: """ - DEFAULT_FORMAT: Final[str] = "json" - """The default serialization format""" - def __init__(self, session_options: Optional[Options] = None): """ Construct a new `Session` based on the provided :func:`coherence.client.Options` @@ -1233,10 +1845,11 @@ def __init__(self, session_options: Optional[Options] = None): """ self._closed: bool = False self._session_id: str = str(uuid.uuid4()) - self._ready = False + self._ready: bool = False + self._initialized: bool = False self._ready_condition: Condition = Condition() self._caches: dict[str, NamedCache[Any, Any]] = dict() - self._lock: Lock = Lock() + self._lock: asyncio.Lock = asyncio.Lock() if session_options is not None: self._session_options = session_options else: @@ -1264,6 +1877,9 @@ def __init__(self, session_options: Optional[Options] = None): ("grpc.lb_policy_name", "round_robin"), ] + self._is_server_grpc_v1 = False + self._v1_init_response_details: dict[str, Any] = dict() + if self._session_options.tls_options is None: self._channel: grpc.aio.Channel = grpc.aio.insecure_channel( self._session_options.address, @@ -1283,6 +1899,8 @@ def __init__(self, session_options: Optional[Options] = None): interceptors=interceptors, ) + self._handshake = _Handshake(self) + watch_task: Task[None] = asyncio.create_task(watch_channel_state(self)) self._tasks.add(watch_task) self._emitter: EventEmitter = EventEmitter() @@ -1292,6 +1910,7 @@ def __init__(self, 
session_options: Optional[Options] = None): async def create(session_options: Optional[Options] = None) -> Session: session: Session = Session(session_options) await session._set_ready(False) + await session._handshake.handshake() return session # noinspection PyTypeHints @@ -1315,6 +1934,10 @@ def on( :param event: the event to listener for :param callback: the callback to invoke when the event is raised """ + if event == SessionLifecycleEvent.CONNECTED and self.is_ready(): + callback() # type: ignore + return + self._emitter.on(str(event.value), callback) @property @@ -1374,54 +1997,64 @@ def session_id(self) -> str: return self._session_id def __str__(self) -> str: - return ( - f"Session(id={self.session_id}, closed={self.closed}, state={self._channel.get_state(False)}," - f" caches/maps={len(self._caches)}, options={self.options})" - ) + if self._protocol_version > 0: + return ( + f"Session(id={self.session_id}, closed={self.closed}, state={self._channel.get_state(False)}," + f" caches/maps={len(self._caches)}, protocol-version={self._protocol_version} options={self.options}" + f" proxy-version={self._proxy_version}, proxy-member-id={self._proxy_member_id})" + ) + else: + return ( + f"Session(id={self.session_id}, closed={self.closed}, state={self._channel.get_state(False)}," + f" caches/maps={len(self._caches)}, protocol-version={self._protocol_version} options={self.options})" + ) # noinspection PyProtectedMember @_pre_call_session - async def get_cache(self, name: str, ser_format: str = DEFAULT_FORMAT) -> "NamedCache[K, V]": + async def get_cache(self, name: str, cache_options: Optional[CacheOptions] = None) -> NamedCache[K, V]: """ Returns a :func:`coherence.client.NamedCache` for the specified cache name. :param name: the cache name - :param ser_format: the serialization format for keys and values stored within the cache + :param cache_options: a :class:`coherence.client.CacheOptions` :return: Returns a :func:`coherence.client.NamedCache` for the specified cache name. """ - serializer = SerializerRegistry.serializer(ser_format) - with self._lock: + serializer = SerializerRegistry.serializer(self._session_options.format) + + async with self._lock: c = self._caches.get(name) if c is None: - c = NamedCacheClient(name, self, serializer) - # initialize the event stream now to ensure lifecycle listeners will work as expected - await c._events_manager._ensure_stream() + if self._protocol_version == 0: + c = NamedCacheClient(name, self, serializer, cache_options) + # initialize the event stream now to ensure lifecycle listeners will work as expected + await c._events_manager._ensure_stream() + else: + c = NamedCacheClientV1(name, self, serializer, cache_options) + await c._ensure_cache() + await c._post_create() + self._setup_event_handlers(c) self._caches.update({name: c}) + else: + if c.options != cache_options: + raise ValueError( + "A NamedMap or NamedCache with the same name already exists with different CacheOptions" + ) return c # noinspection PyProtectedMember @_pre_call_session - async def get_map(self, name: str, ser_format: str = DEFAULT_FORMAT) -> "NamedMap[K, V]": + async def get_map(self, name: str, cache_options: Optional[CacheOptions] = None) -> NamedMap[K, V]: """ Returns a :func:`coherence.client.NameMap` for the specified cache name. 
:param name: the map name - :param ser_format: the serialization format for keys and values stored within the cache + :param cache_options: a :class:`coherence.client.CacheOptions` :return: Returns a :func:`coherence.client.NamedMap` for the specified cache name. """ - serializer = SerializerRegistry.serializer(ser_format) - with self._lock: - c = self._caches.get(name) - if c is None: - c = NamedCacheClient(name, self, serializer) - # initialize the event stream now to ensure lifecycle listeners will work as expected - await c._events_manager._ensure_stream() - self._setup_event_handlers(c) - self._caches.update({name: c}) - return c + return cast(NamedMap[K, V], await self.get_cache(name, cache_options)) def is_ready(self) -> bool: """ @@ -1433,6 +2066,18 @@ def is_ready(self) -> bool: return True if not self._ready_enabled else self._ready + @property + def _proxy_version(self) -> str: + return self._handshake.proxy_version + + @property + def _protocol_version(self) -> int: + return self._handshake.protocol_version + + @property + def _proxy_member_id(self) -> int: + return self._handshake.proxy_member_id + async def _set_ready(self, ready: bool) -> None: self._ready = ready if self._ready: @@ -1465,18 +2110,19 @@ async def close(self) -> None: self._emitter.emit(SessionLifecycleEvent.CLOSED.value) for task in self._tasks: task.cancel() + await task self._tasks.clear() caches_copy: dict[str, NamedCache[Any, Any]] = self._caches.copy() for cache in caches_copy.values(): - cache.release() + await cache.release() self._caches.clear() await self._channel.close() # TODO: consider grace period? self._channel = None - def _setup_event_handlers(self, client: NamedCacheClient[K, V]) -> None: + def _setup_event_handlers(self, client: NamedCacheClient[K, V] | NamedCacheClientV1[K, V]) -> None: this: Session = self def on_destroyed(name: str) -> None: @@ -1512,9 +2158,11 @@ async def _do_intercept(self, continuation, client_call_details, request): :param request: the gRPC request (if any) :return: the result of the call """ + from . 
import _TIMEOUT_CONTEXT_VAR + new_details = grpc.aio.ClientCallDetails( client_call_details.method, - self._session.options.request_timeout_seconds, + _TIMEOUT_CONTEXT_VAR.get(self._session.options.request_timeout_seconds), client_call_details.metadata, client_call_details.credentials, True if self._session._ready_enabled else None, @@ -1687,7 +2335,7 @@ def __init__( self._result_handler: Callable[[Serializer, Any], Any] = result_handler # the serializer to be used when deserializing streamed results - self._serializer: Serializer = client._request_factory.get_serializer() + self._serializer: Serializer = client._request_factory.serializer # cookie that tracks page streaming; used for each new page request self._cookie: bytes = bytes() @@ -1796,3 +2444,230 @@ async def _entry_producer(serializer: Serializer, stream: grpc.Channel.unary_str async for item in stream: return _entry_deserializer(serializer, item) raise StopAsyncIteration + + +async def _entry_producer_from_list(serializer: Serializer, the_list: list[Any]) -> MapEntry[K, V]: # type: ignore + if len(the_list) == 0: + raise StopAsyncIteration + for item in the_list: + the_list.pop(0) + return _entry_deserializer(serializer, item) + + +class _ListAsyncIterator(abc.ABC, AsyncIterator[T]): + def __init__( + self, + serializer: Serializer, + the_list: list[T], + next_producer: Callable[[Serializer, list[T]], Awaitable[T]], + ) -> None: + super().__init__() + # A function that may be called to produce a series of results + self._next_producer = next_producer + + # the Serializer that should be used to deserialize results + self._serializer = serializer + + # the gRPC stream providing results + self._the_list = the_list + + def __aiter__(self) -> AsyncIterator[T]: + return self + + def __anext__(self) -> Awaitable[T]: + return self._next_producer(self._serializer, self._the_list) + + +# noinspection PyProtectedMember +class StreamHandler: + # noinspection PyTypeChecker + def __init__( + self, + session: Session, + request_factory: RequestFactoryV1, + events_manager: _MapEventsManagerV1[K, V], + ): + self._debug: str = os.environ.get("COHERENCE_MESSAGING_DEBUG", "off") + self._session: Session = session + self._channel = session.channel + self._reconnect_timeout: float = session.options.session_disconnect_timeout_seconds + self._proxy_stub = ProxyServiceStub(session.channel) + + self._request_factory: RequestFactoryV1 = request_factory + self._events_manager: _MapEventsManagerV1[K, V] = events_manager + self._stream: Optional[StreamStreamMultiCallable] = None + self._observers: dict[int, ResponseObserver] = dict() + self.result_available = Event() + self.result_available.clear() + self._background_tasks: Set[Task[Any]] = set() + self._closed: bool = False + self._connected = Event() + self._connected.clear() + self._ensure_lock = asyncio.Lock() + self._write_lock = asyncio.Lock() + + task = asyncio.create_task(self.handle_response()) + task.add_done_callback(self._background_tasks.discard) + self._background_tasks.add(task) + + def on_disconnect() -> None: + self._connected.clear() + self._stream = None + + async def on_reconnect() -> None: + self._connected.set() + # noinspection PyUnresolvedReferences + await self._events_manager._named_map._ensure_cache() + await self._events_manager._reconnect() + + session.on(SessionLifecycleEvent.DISCONNECTED, on_disconnect) + session.on(SessionLifecycleEvent.RECONNECTED, on_reconnect) + + def _log_message(self, message: Any, send: bool = True) -> None: + debug: str = self._debug + + if debug != 
sys.intern("off"): + session_id = self._session.session_id + prefix: str = f"c.m.d SND [{session_id}] -> " if send else f"c.m.d RCV [{session_id}] <- " + if debug == sys.intern("on"): + COH_LOG.debug(prefix + textwrap.shorten(MessageToJson(message=message, indent=None), 256)) + elif debug == sys.intern("full"): + COH_LOG.debug(prefix + MessageToJson(message=message, indent=None)) + + @property + async def stream(self) -> StreamStreamMultiCallable: + await self._ensure_stream() + + return self._stream + + # noinspection PyUnresolvedReferences + async def close(self) -> None: + tasks: Set[Task[Any]] = set(self._background_tasks) + for task in tasks: + task.cancel() + await task + + if self._stream is not None: + self._stream.cancel() + self._stream = None + + self._closed = True + + async def _ensure_stream(self) -> StreamStreamMultiCallable: + if self._stream is None: + async with self._ensure_lock: + if self._stream is None: + stream = self._proxy_stub.subChannel() + + try: + await stream.write(self._request_factory.init_sub_channel()) + except grpc.aio._call.AioRpcError as e: + print(e) + + await stream.read() + self._stream = stream # we don't care about the result, only that it completes + self._connected.set() + + return self._stream + + # noinspection PyUnresolvedReferences + async def send_proxy_request(self, proxy_request: ProxyRequest) -> None: + stream: StreamStreamMultiCallable = await self.stream + + self._log_message(proxy_request) + + async with self._write_lock: + await stream.write(proxy_request) + + def register_observer(self, observer: ResponseObserver) -> None: + assert observer.id not in self._observers + + self._observers[observer.id] = observer + + def deregister_observer(self, observer: ResponseObserver) -> None: + self._observers.pop(observer.id, None) + + async def handle_response(self) -> None: + while not self._closed: + try: + stream: StreamStreamMultiCallable = await self.stream + # noinspection PyUnresolvedReferences + response = await stream.read() + response_id = response.id + + self._log_message(response, False) + + if response_id == 0: + await self.handle_zero_id_response(response) + else: + if response.HasField("message"): + observer = self._observers.get(response_id, None) + if observer is not None: + named_cache_response = NamedCacheResponse() + response.message.Unpack(named_cache_response) + observer._next(named_cache_response) + continue + elif response.HasField("init"): + self.result_available.set() + elif response.HasField("error"): + observer = self._observers.get(response_id, None) + if observer is not None: + self._observers.pop(response_id, None) + observer._err(Exception(response.error.message)) + continue + elif response.HasField("complete"): + observer = self._observers.get(response_id, None) + if observer is not None: + self._observers.pop(response_id, None) + observer._done() + except asyncio.CancelledError: + return + except grpc.aio._call.AioRpcError as e: + if e.code().name == "CANCELLED": + continue + COH_LOG.error("Received unexpected error from proxy: " + str(e)) + + # noinspection PyUnresolvedReferences + async def handle_zero_id_response(self, response: ProxyResponse) -> None: + if response.HasField("message"): + named_cache_response = NamedCacheResponse() + response.message.Unpack(named_cache_response) + response_type = named_cache_response.type + if response_type == ResponseType.Message: + return + elif response_type == ResponseType.MapEvent: + # Handle MapEvent Response + event_response = MapEventMessage() + 
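+                # unpack the wrapped MapEventMessage and fan the event out to every
+                # filter-based and key-based listener group registered for this cache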
named_cache_response.message.Unpack(event_response) + + if event_response.id == 0: + # v0 map event received - drop on the floor + return + + try: + event: MapEvent[Any, Any] = MapEvent( + self._events_manager._named_map, event_response, self._events_manager._serializer + ) + for _id in event_response.filterIds: + filter_group: Optional[_ListenerGroup[Any, Any, Any]] = ( + self._events_manager._filter_id_listener_group_map.get(_id, None) + ) + if filter_group is not None: + await filter_group._notify_listeners(event) + + key_group = self._events_manager._key_map.get(event.key, None) + if key_group is not None: + await key_group._notify_listeners(event) + except Exception as e: + traceback.print_exc() + COH_LOG.warning("Unhandled Event Message: " + str(e)) + elif response_type == ResponseType.Destroyed: + if self._events_manager._named_map.cache_id == named_cache_response.cacheId: + self._events_manager._emitter.emit( + MapLifecycleEvent.DESTROYED.value, self._events_manager._named_map.name + ) + elif response_type == ResponseType.Truncated: + if self._events_manager._named_map.cache_id == named_cache_response.cacheId: + self._events_manager._emitter.emit( + MapLifecycleEvent.TRUNCATED.value, self._events_manager._named_map.name + ) diff --git a/src/coherence/common_messages_v1_pb2.py b/src/coherence/common_messages_v1_pb2.py new file mode 100644 index 0000000..8ceff5d --- /dev/null +++ b/src/coherence/common_messages_v1_pb2.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: common_messages_v1.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x63ommon_messages_v1.proto\x12\x13\x63oherence.common.v1\x1a\x19google/protobuf/any.proto\"=\n\x0c\x45rrorMessage\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x12\n\x05\x65rror\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x42\x08\n\x06_error\"\n\n\x08\x43omplete\";\n\x10HeartbeatMessage\x12\x11\n\x04uuid\x18\x01 \x01(\x0cH\x00\x88\x01\x01\x12\x0b\n\x03\x61\x63k\x18\x02 \x01(\x08\x42\x07\n\x05_uuid\"/\n\rOptionalValue\x12\x0f\n\x07present\x18\x01 \x01(\x08\x12\r\n\x05value\x18\x02 \x01(\x0c\")\n\x17\x43ollectionOfBytesValues\x12\x0e\n\x06values\x18\x01 \x03(\x0c\"/\n\x11\x42inaryKeyAndValue\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x42\x30\n,com.oracle.coherence.grpc.messages.common.v1P\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'common_messages_v1_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n,com.oracle.coherence.grpc.messages.common.v1P\001' + _ERRORMESSAGE._serialized_start=76 + _ERRORMESSAGE._serialized_end=137 + _COMPLETE._serialized_start=139 + _COMPLETE._serialized_end=149 + _HEARTBEATMESSAGE._serialized_start=151 + _HEARTBEATMESSAGE._serialized_end=210 + _OPTIONALVALUE._serialized_start=212 + _OPTIONALVALUE._serialized_end=259 + _COLLECTIONOFBYTESVALUES._serialized_start=261 + _COLLECTIONOFBYTESVALUES._serialized_end=302 + 
_BINARYKEYANDVALUE._serialized_start=304 + _BINARYKEYANDVALUE._serialized_end=351 +# @@protoc_insertion_point(module_scope) diff --git a/src/coherence/common_messages_v1_pb2.pyi b/src/coherence/common_messages_v1_pb2.pyi new file mode 100644 index 0000000..f5747d9 --- /dev/null +++ b/src/coherence/common_messages_v1_pb2.pyi @@ -0,0 +1,50 @@ +# mypy: ignore-errors +from google.protobuf import any_pb2 as _any_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Optional as _Optional + +DESCRIPTOR: _descriptor.FileDescriptor + +class BinaryKeyAndValue(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: bytes + value: bytes + def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ...) -> None: ... + +class CollectionOfBytesValues(_message.Message): + __slots__ = ["values"] + VALUES_FIELD_NUMBER: _ClassVar[int] + values: _containers.RepeatedScalarFieldContainer[bytes] + def __init__(self, values: _Optional[_Iterable[bytes]] = ...) -> None: ... + +class Complete(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + +class ErrorMessage(_message.Message): + __slots__ = ["error", "message"] + ERROR_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + error: bytes + message: str + def __init__(self, message: _Optional[str] = ..., error: _Optional[bytes] = ...) -> None: ... + +class HeartbeatMessage(_message.Message): + __slots__ = ["ack", "uuid"] + ACK_FIELD_NUMBER: _ClassVar[int] + UUID_FIELD_NUMBER: _ClassVar[int] + ack: bool + uuid: bytes + def __init__(self, uuid: _Optional[bytes] = ..., ack: bool = ...) -> None: ... + +class OptionalValue(_message.Message): + __slots__ = ["present", "value"] + PRESENT_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + present: bool + value: bytes + def __init__(self, present: bool = ..., value: _Optional[bytes] = ...) -> None: ... diff --git a/src/coherence/common_messages_v1_pb2_grpc.py b/src/coherence/common_messages_v1_pb2_grpc.py new file mode 100644 index 0000000..2daafff --- /dev/null +++ b/src/coherence/common_messages_v1_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/coherence/entry.py b/src/coherence/entry.py new file mode 100644 index 0000000..8c2f7cb --- /dev/null +++ b/src/coherence/entry.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + +K = TypeVar("K") +V = TypeVar("V") + + +class MapEntry(Generic[K, V]): + """ + A map entry (key-value pair). 
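+
+    Instances are read-only and are produced by streaming operations such as
+    get_all(), entries(), and invoke_all().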
+ """ + + def __init__(self, key: K, value: V): + self._key = key + self._value = value + + @property + def key(self) -> K: + return self._key + + @property + def value(self) -> V: + return self._value + + def __str__(self) -> str: + return f"MapEntry(key={self.key}, value={self.value})" diff --git a/src/coherence/event.py b/src/coherence/event.py index bd4ad0f..ca32558 100644 --- a/src/coherence/event.py +++ b/src/coherence/event.py @@ -5,10 +5,10 @@ from __future__ import annotations import asyncio -from abc import ABCMeta, abstractmethod +from abc import ABC, ABCMeta, abstractmethod from asyncio import Event, Task from enum import Enum, unique -from typing import Callable, Generic, Optional, Set, TypeVar, cast +from typing import Any, Callable, Generic, Optional, Set, TypeVar, cast # noinspection PyPackageRequirements import grpc @@ -16,11 +16,12 @@ import coherence.client +from .cache_service_messages_v1_pb2 import MapEventMessage, NamedCacheRequest from .filter import Filter, Filters, MapEventFilter from .messages_pb2 import MapEventResponse, MapListenerRequest, MapListenerResponse from .serialization import Serializer from .services_pb2_grpc import NamedCacheServiceStub -from .util import RequestFactory +from .util import RequestFactory, RequestFactoryV1 K = TypeVar("K") """the type of the map entry keys.""" @@ -28,6 +29,9 @@ V = TypeVar("V") """the type of the map entry values.""" +RT = TypeVar("RT") +"""the gRPC request type.""" + @unique class MapEventType(Enum): @@ -89,7 +93,10 @@ class MapEvent(Generic[K, V]): """ def __init__( - self, source: coherence.client.NamedMap[K, V], response: MapEventResponse, serializer: Serializer + self, + source: coherence.client.NamedMap[K, V], + response: MapEventResponse | MapEventMessage, + serializer: Serializer, ) -> None: """ Constructs a new MapEvent. @@ -225,8 +232,15 @@ class MapListener(Generic[K, V]): _emitter: EventEmitter """The internal emitter used to emit events.""" - def __init__(self) -> None: - """Constructs a new MapListener.""" + def __init__(self, synchronous: bool = False, priming: bool = False) -> None: + """ + Constructs a new MapListener. + + :param synchronous: if True, the listener will be registered as a synchronous listener + :param priming: if True, the listener will be registered as a synchronous listener + """ + self._synchronous: bool = synchronous + self._priming: bool = priming self._emitter = EventEmitter() def _on(self, event: MapEventType, callback: MapListenerCallback[K, V]) -> MapListener[K, V]: @@ -271,12 +285,22 @@ def on_any(self, callback: MapListenerCallback[K, V]) -> MapListener[K, V]: """ return self.on_deleted(callback).on_updated(callback).on_inserted(callback) + @property + def synchronous(self) -> bool: + """ + :return: True if this listener is synchronous, otherwise False + """ + return self._synchronous + + @property + def priming(self) -> bool: + """ + :return: True if this listener is priming, otherwise False + """ + return self._priming -class _ListenerGroup(Generic[K, V], metaclass=ABCMeta): - """Manages a collection of MapEventListeners that will be notified when an event is raised. 
- This also manages the on-wire activities for registering/deregistering a listener with the - gRPC proxy.""" +class _ListenerGroup(Generic[K, V, RT], metaclass=ABCMeta): _key_or_filter: K | Filter """The key or Filter for which this group of listeners will receive events.""" @@ -289,31 +313,13 @@ class _ListenerGroup(Generic[K, V], metaclass=ABCMeta): _lite_false_count: int """The number of callbacks that aren't lite.""" - _manager: _MapEventsManager[K, V] - """The associated MapEventsManager for this group.""" - - _request: MapListenerRequest - """The subscription request. A reference is maintained for unsubscribe purposes.""" - _subscription_waiter: Event """Used by a caller to be notified when the listener subscription as been completed.""" _unsubscription_waiter: Event """Used by a caller to be notified when the listener unsubscribe as been completed.""" - def _init_(self, manager: _MapEventsManager[K, V], key_or_filter: K | Filter) -> None: - """ - Constructs a new _ListenerGroup. - :param manager: the _MapEventManager - :param key_or_filter: the key or filter for this group of listeners - :raises ValueError: if either `manager` or `key_or_filter` is `None` - """ - if manager is None: - raise ValueError("Argument `manager` must not be None") - if key_or_filter is None: - raise ValueError("Argument `key_or_filter` must not be None") - - self._manager = manager + def __init__(self, key_or_filter: K | Filter) -> None: self._key_or_filter = key_or_filter self._listeners = {} self._lite_false_count = 0 @@ -321,6 +327,34 @@ def _init_(self, manager: _MapEventsManager[K, V], key_or_filter: K | Filter) -> self._subscribed_waiter = Event() self._unsubscribed_waiter = Event() + @abstractmethod + async def _subscribe(self, lite: bool, sync: bool = False, priming: bool = False) -> None: + pass + + @abstractmethod + async def _unsubscribe(self) -> None: + pass + + @abstractmethod + def _post_subscribe(self, request: RT) -> None: + """ + Custom actions that implementations may need to make after a subscription has been completed. + :param request: the request that was used to subscribe + """ + pass + + @abstractmethod + def _post_unsubscribe(self, request: RT) -> None: + """ + Custom actions that implementations may need to make after an unsubscription has been completed. + :param request: the request that was used to unsubscribe + """ + pass + + @abstractmethod + def _subscribe_complete(self) -> None: + pass + async def add_listener(self, listener: MapListener[K, V], lite: bool) -> None: """ Add a callback to this group. This causes a subscription message to be sent through the stream @@ -348,7 +382,7 @@ async def add_listener(self, listener: MapListener[K, V], lite: bool) -> None: if size > 1: await self._unsubscribe() - await self._subscribe(lite) + await self._subscribe(lite, listener.synchronous, listener.priming) async def remove_listener(self, listener: MapListener[K, V]) -> None: """ @@ -373,6 +407,63 @@ async def remove_listener(self, listener: MapListener[K, V]) -> None: await self._unsubscribe() await self._subscribe(True) + # noinspection PyProtectedMember + async def _notify_listeners(self, event: MapEvent[K, V]) -> None: + """ + Notify all listeners within this group of the provided event. 
+ :param event: + """ + event_label: str = self._get_emitter_label(event) + listener: MapListener[K, V] + for listener in self._listeners.keys(): + await listener._emitter.emit_async(event_label, event) + await asyncio.sleep(0) + + # noinspection PyProtectedMember + @staticmethod + def _get_emitter_label(event: MapEvent[K, V]) -> str: + """ + The string label required by the internal event emitter. + :param event: the MapEvent whose label will be generated + :return: the emitter-friendly event label + """ + if event.type == MapEventType.ENTRY_DELETED: + return MapEventType.ENTRY_DELETED.value + elif event.type == MapEventType.ENTRY_INSERTED: + return MapEventType.ENTRY_INSERTED.value + elif event.type == MapEventType.ENTRY_UPDATED: + return MapEventType.ENTRY_UPDATED.value + else: + raise AssertionError(f"Unknown EventType [{event}]") + + +class _ListenerGroupV0(_ListenerGroup[K, V, MapListenerRequest], metaclass=ABCMeta): + """Manages a collection of MapEventListeners that will be notified when an event is raised. + This also manages the on-wire activities for registering/de-registering a listener with the + gRPC proxy.""" + + _manager: _MapEventsManagerV0[K, V] + """The associated MapEventsManager for this group.""" + + _request: MapListenerRequest + """The subscription request. A reference is maintained for unsubscribe purposes.""" + + def __init__(self, manager: _MapEventsManagerV0[K, V], key_or_filter: K | Filter) -> None: + """ + Constructs a new _ListenerGroup. + :param manager: the _MapEventManager + :param key_or_filter: the key or filter for this group of listeners + :raises ValueError: if either `manager` or `key_or_filter` is `None` + """ + if manager is None: + raise ValueError("Argument `manager` must not be None") + if key_or_filter is None: + raise ValueError("Argument `key_or_filter` must not be None") + + super().__init__(key_or_filter=key_or_filter) + + self._manager = manager + # noinspection PyProtectedMember async def _write(self, request: MapListenerRequest) -> None: """Write the request to the event stream.""" @@ -380,7 +471,7 @@ async def _write(self, request: MapListenerRequest) -> None: await event_stream.write(request) # noinspection PyProtectedMember - async def _subscribe(self, lite: bool) -> None: + async def _subscribe(self, lite: bool, sync: bool = False, priming: bool = False) -> None: """ Send a gRPC MapListener subscription request for a key or filter. :param lite: `True` if the event should only include the key, or `False` @@ -402,17 +493,6 @@ async def _subscribe(self, lite: bool) -> None: await self._subscribed_waiter.wait() self._subscribed_waiter.clear() - # noinspection PyProtectedMember - def _subscribe_complete(self) -> None: - """Called when the response to the subscription request has been received.""" - - # no longer pending - del self._manager._pending_registrations[self._request.uid] - self._post_subscribe(self._request) - - # notify caller that subscription is active - self._subscribed_waiter.set() - # noinspection PyProtectedMember async def _unsubscribe(self) -> None: """ @@ -430,60 +510,86 @@ async def _unsubscribe(self) -> None: self._post_unsubscribe(request) # noinspection PyProtectedMember - def _notify_listeners(self, event: MapEvent[K, V]) -> None: - """ - Notify all listeners within this group of the provided event. 
- :param event: - """ - event_label: str = self._get_emitter_label(event) - listener: MapListener[K, V] - for listener in self._listeners.keys(): - listener._emitter.emit(event_label, event) + def _subscribe_complete(self) -> None: + del self._manager._pending_registrations[self._request.uid] + self._post_subscribe(self._request) - # noinspection PyProtectedMember - @staticmethod - def _get_emitter_label(event: MapEvent[K, V]) -> str: + # notify caller that subscription is active + self._subscribed_waiter.set() + + +# noinspection PyProtectedMember +class _ListenerGroupV1(_ListenerGroup[K, V, NamedCacheRequest], ABC): + + def __init__(self, manager: _MapEventsManagerV1[K, V], key_or_filter: K | Filter): + if manager is None: + raise ValueError("Argument `manager` must not be None") + if key_or_filter is None: + raise ValueError("Argument `key_or_filter` must not be None") + + super().__init__(key_or_filter=key_or_filter) + + self._manager = manager + + async def _subscribe(self, lite: bool, sync: bool = False, priming: bool = False) -> None: """ - The string label required by the internal event emitter. - :param event: the MapEvent whose label will be generated - :return: the emitter-friendly event label + Send a gRPC MapListener subscription request for a key or filter. + :param lite: `True` if the event should only include the key, or `False` + if the event should include old and new values as well as the key """ - if event.type == MapEventType.ENTRY_DELETED: - return MapEventType.ENTRY_DELETED.value - elif event.type == MapEventType.ENTRY_INSERTED: - return MapEventType.ENTRY_INSERTED.value - elif event.type == MapEventType.ENTRY_UPDATED: - return MapEventType.ENTRY_UPDATED.value + request: NamedCacheRequest + filter_id: int + if isinstance(self._key_or_filter, Filter): + (dispatcher, request, filter_id) = self._manager.request_factory.map_listener_request( + True, lite, filter=self._key_or_filter, sync=sync, priming=priming + ) else: - raise AssertionError(f"Unknown EventType [{event}]") + (dispatcher, request, filter_id) = self._manager.request_factory.map_listener_request( + True, lite, key=self._key_or_filter, sync=sync, priming=priming + ) - @abstractmethod - def _post_subscribe(self, request: MapListenerRequest) -> None: - """ - Custom actions that implementations may need to make after a subscription has been completed. - :param request: the request that was used to subscribe - """ - pass + self._request = request + self._filter_id = filter_id - @abstractmethod - def _post_unsubscribe(self, request: MapListenerRequest) -> None: - """ - Custom actions that implementations may need to make after an unsubscription has been completed. 
- :param request: the request that was used to unsubscribe - """ - pass + # set this registration as pending + self._manager._pending_registrations[filter_id] = self + + # noinspection PyUnresolvedReferences + await dispatcher.dispatch(self._manager._named_map._stream_handler) + + self._subscribe_complete() + + async def _unsubscribe(self) -> None: + request: NamedCacheRequest + if isinstance(self._key_or_filter, MapEventFilter): + # noinspection PyTypeChecker + (dispatcher, request, filter_id) = self._manager.request_factory.map_listener_request( + subscribe=False, filter=self._key_or_filter, filter_id=self._filter_id + ) + else: + (dispatcher, request, filter_id) = self._manager.request_factory.map_listener_request( + subscribe=False, key=self._key_or_filter + ) + + # noinspection PyUnresolvedReferences + await dispatcher.dispatch(self._manager._named_map._stream_handler) + self._post_unsubscribe(request) + + def _subscribe_complete(self) -> None: + del self._manager._pending_registrations[self._filter_id] + self._post_subscribe(self._request) -class _KeyListenerGroup(_ListenerGroup[K, V]): +class _KeyListenerGroupV0(_ListenerGroupV0[K, V]): """A ListenerGroup for key-based MapListeners""" - def __init__(self, manager: _MapEventsManager[K, V], key: K) -> None: + def __init__(self, manager: _MapEventsManagerV0[K, V], key: K) -> None: """ Creates a new _KeyListenerGroup :param manager: the _MapEventManager :param key: the group key """ - super()._init_(manager, key) + super().__init__(manager, key) # noinspection PyProtectedMember def _post_subscribe(self, request: MapListenerRequest) -> None: @@ -498,16 +604,46 @@ def _post_unsubscribe(self, request: MapListenerRequest) -> None: manager._key_group_unsubscribed(key) -class _FilterListenerGroup(_ListenerGroup[K, V]): +class _KeyListenerGroupV1(_ListenerGroupV1[K, V]): + _manager: _MapEventsManagerV1[K, V] + """The associated MapEventsManager for this group.""" + + _request: MapListenerRequest + """The subscription request. A reference is maintained for unsubscribe purposes.""" + + def __init__(self, manager: _MapEventsManagerV1[K, V], key_or_filter: K | Filter) -> None: + """ + Constructs a new _ListenerGroup. 
+ :param manager: the _MapEventManager + :param key_or_filter: the key or filter for this group of listeners + :raises ValueError: if either `manager` or `key_or_filter` is `None` + """ + if manager is None: + raise ValueError("Argument `manager` must not be None") + if key_or_filter is None: + raise ValueError("Argument `key_or_filter` must not be None") + + super().__init__(manager, key_or_filter) + + # noinspection PyProtectedMember + def _post_subscribe(self, request: MapListenerRequest) -> None: + self._manager._key_group_subscribed(cast(K, self._key_or_filter), self) + + # noinspection PyProtectedMember + def _post_unsubscribe(self, request: MapListenerRequest) -> None: + self._manager._key_group_unsubscribed(cast(K, self._key_or_filter)) + + +class _FilterListenerGroupV0(_ListenerGroupV0[K, V]): """A ListenerGroup for Filter-based MapListeners""" - def __init__(self, manager: _MapEventsManager[K, V], filter: Filter) -> None: + def __init__(self, manager: _MapEventsManagerV0[K, V], filter: Filter) -> None: """ - Creates a new _KeyListenerGroup + Creates a new _FilterListenerGroupV0 :param manager: the _MapEventManager :param filter: the group Filter """ - super()._init_(manager, filter) + super().__init__(manager, filter) # noinspection PyProtectedMember def _post_subscribe(self, request: MapListenerRequest) -> None: @@ -518,7 +654,27 @@ def _post_unsubscribe(self, request: MapListenerRequest) -> None: self._manager._filter_group_unsubscribed(request.filterId, cast(Filter, self._key_or_filter)) -class _MapEventsManager(Generic[K, V]): +class _FilterListenerGroupV1(_ListenerGroupV1[K, V]): + """A ListenerGroup for Filter-based MapListeners""" + + def __init__(self, manager: _MapEventsManagerV1[K, V], filter: Filter) -> None: + """ + Creates a new _KeyListenerGroup + :param manager: the _MapEventManager + :param filter: the group Filter + """ + super().__init__(manager, filter) + + # noinspection PyProtectedMember + def _post_subscribe(self, request: MapListenerRequest) -> None: + self._manager._filter_group_subscribed(self._filter_id, cast(Filter, self._key_or_filter), self) + + # noinspection PyProtectedMember + def _post_unsubscribe(self, request: MapListenerRequest) -> None: + self._manager._filter_group_unsubscribed(self._filter_id, cast(Filter, self._key_or_filter)) + + +class _MapEventsManager(Generic[K, V], ABC): """MapEventsManager handles registration, de-registration of callbacks, and notification of {@link MapEvent}s to callbacks. Since multiple callbacks can be registered for a single key / filter, this class relies on another internal @@ -554,18 +710,15 @@ class called ListenerGroup which maintains the collection of callbacks. 
_map_name: str """The logical name of the provided NamedMap.""" - _key_map: dict[K, _ListenerGroup[K, V]] + _key_map: dict[K, _ListenerGroup[K, V, Any]] """Contains mappings between a key and its group of MapListeners.""" - _filter_map: dict[Filter, _ListenerGroup[K, V]] + _filter_map: dict[Filter, _ListenerGroup[K, V, Any]] """Contains mappings between a Filter and its group of MapListeners.""" - _filter_id_listener_group_map: dict[int, _ListenerGroup[K, V]] + _filter_id_listener_group_map: dict[int, _ListenerGroup[K, V, Any]] """Contains mappings between a logical filter ID and its ListenerGroup.""" - _request_factory: RequestFactory - """The RequestFactory used to obtain the necessary gRPC requests.""" - _event_stream: Optional[grpc.aio.StreamStreamCall] """gRPC bidirectional stream for subscribing/unsubscribing MapListeners and receiving MapEvents from the proxy.""" @@ -574,17 +727,15 @@ class called ListenerGroup which maintains the collection of callbacks. """"Flag indicating the event stream is open and ready for listener registrations and incoming events.""" - _pending_registrations: dict[str, _ListenerGroup[K, V]] + _pending_registrations: dict[str | int, _ListenerGroup[K, V, Any]] """The mapping of pending listener registrations keyed by request uid.""" _background_tasks: Set[Task[None]] - # noinspection PyProtectedMember def __init__( self, named_map: coherence.client.NamedMap[K, V], session: coherence.Session, - client: NamedCacheServiceStub, serializer: Serializer, emitter: EventEmitter, ) -> None: @@ -592,12 +743,10 @@ def __init__( Constructs a new _MapEventManager. :param named_map: the 'source' of the events :param session: the Session associated with this NamedMap - :param client: the gRPC client :param serializer: the Serializer that will be used for ser/deser operations :param emitter: the internal event emitter used to notify registered MapListeners """ self._named_map = named_map - self._client = client self._serializer = serializer self._emitter = emitter self._map_name = named_map.name @@ -608,8 +757,6 @@ def __init__( self._filter_id_listener_group_map = {} self._pending_registrations = {} - self._request_factory = RequestFactory(self._map_name, session.scope, serializer) - self._event_stream = None self._open = False self._background_tasks = set() @@ -617,57 +764,26 @@ def __init__( session.on(SessionLifecycleEvent.DISCONNECTED, self._close) - # intentionally ignoring the typing here to avoid complicating the - # callback API exposed on the session - # noinspection PyTypeChecker - session.on(SessionLifecycleEvent.RECONNECTED, self._reconnect) - + @abstractmethod def _close(self) -> None: - """Close the gRPC event stream and any background tasks.""" - - event_stream: grpc.aio.StreamStreamCall = self._event_stream - if event_stream is not None: - event_stream.cancel() - self._event_stream = None - - self._open = False - for task in self._background_tasks: - task.cancel() - self._background_tasks.clear() + pass # noinspection PyProtectedMember async def _reconnect(self) -> None: - group: _ListenerGroup[K, V] + group: _ListenerGroup[K, V, Any] for group in self._key_map.values(): await group._subscribe(group._registered_lite) for group in self._filter_map.values(): await group._subscribe(group._registered_lite) - async def _ensure_stream(self) -> grpc.aio.StreamStreamCall: - """ - Initialize the event stream for MapListener events. 
- """ - if self._event_stream is None: - event_stream: grpc.aio.StreamStreamCall = self._client.events() - await event_stream.write(self._request_factory.map_event_subscribe()) - self._event_stream = event_stream - read_task: Task[None] = asyncio.create_task(self._handle_response()) - self._background_tasks.add(read_task) - # we use asyncio.timeout here instead of using the gRPC timeout - # as any deadline set on the stream will result in a loss of events - try: - await asyncio.wait_for(self._stream_waiter.wait(), self._session.options.request_timeout_seconds) - except TimeoutError: - s = ( - "Deadline [{0} seconds] exceeded waiting for event stream" - " to become ready. Server address - {1})".format( - str(self._session.options.request_timeout_seconds), self._session.options.address - ) - ) - raise TimeoutError(s) + @abstractmethod + def _new_key_group(self, key: K) -> _ListenerGroup[K, V, Any]: + pass - return self._event_stream + @abstractmethod + def _new_filter_group(self, filter: Filter) -> _ListenerGroup[K, V, Any]: + pass async def _register_key_listener(self, listener: MapListener[K, V], key: K, lite: bool = False) -> None: """ @@ -677,10 +793,10 @@ async def _register_key_listener(self, listener: MapListener[K, V], key: K, lite :param lite: `True` if the event should only include the key, or `False` if the event should include old and new values as well as the key """ - group: Optional[_ListenerGroup[K, V]] = self._key_map.get(key, None) + group: Optional[_ListenerGroup[K, V, Any]] = self._key_map.get(key, None) if group is None: - group = _KeyListenerGroup(self, key) + group = self._new_key_group(key) self._key_map[key] = group await group.add_listener(listener, lite) @@ -691,7 +807,7 @@ async def _remove_key_listener(self, listener: MapListener[K, V], key: K) -> Non :param listener: the MapListener to remove :param key: they key the listener was associated with """ - group: Optional[_ListenerGroup[K, V]] = self._key_map.get(key, None) + group: Optional[_ListenerGroup[K, V, Any]] = self._key_map.get(key, None) if group is not None: await group.remove_listener(listener) @@ -704,13 +820,13 @@ async def _register_filter_listener( :param listener: the MapListener to register :param filter: the Filter associated with the listener :param lite: `True` if the event should only include the key, or `False` - if the event should include old and new values as well as the key + if the event should include old and new values as well as the key """ filter_local: Filter = filter if filter is not None else self._DEFAULT_FILTER - group: Optional[_ListenerGroup[K, V]] = self._filter_map.get(filter_local, None) + group: Optional[_ListenerGroup[K, V, Any]] = self._filter_map.get(filter_local, None) if group is None: - group = _FilterListenerGroup(self, filter_local) + group = self._new_filter_group(filter_local) self._filter_map[filter_local] = group await group.add_listener(listener, lite) @@ -721,12 +837,12 @@ async def _remove_filter_listener(self, listener: MapListener[K, V], filter: Opt :param filter: the Filter that was used with the listener registration """ filter_local: Filter = filter if filter is not None else self._DEFAULT_FILTER - group: Optional[_ListenerGroup[K, V]] = self._filter_map.get(filter_local, None) + group: Optional[_ListenerGroup[K, V, Any]] = self._filter_map.get(filter_local, None) if group is not None: await group.remove_listener(listener) - def _key_group_subscribed(self, key: K, group: _ListenerGroup[K, V]) -> None: + def _key_group_subscribed(self, key: K, group: 
_ListenerGroup[K, V, Any]) -> None: """ Called internally by _KeyListenerGroup when a key listener is subscribed. :param key: the registration key @@ -741,7 +857,7 @@ def _key_group_unsubscribed(self, key: K) -> None: """ del self._key_map[key] - def _filter_group_subscribed(self, filter_id: int, filter: Filter, group: _ListenerGroup[K, V]) -> None: + def _filter_group_subscribed(self, filter_id: int, filter: Filter, group: _ListenerGroup[K, V, Any]) -> None: """ Called internally by _FilterListenerGroup when a filter listener is subscribed. :param filter_id: the ID of the filter @@ -760,6 +876,78 @@ def _filter_group_unsubscribed(self, filter_id: int, filter: Filter) -> None: del self._filter_id_listener_group_map[filter_id] del self._filter_map[filter] + +class _MapEventsManagerV0(_MapEventsManager[K, V]): + """MapEventsManager implementation for V0 of the gRPC proxy.""" + + # noinspection PyProtectedMember + def __init__( + self, + named_map: coherence.client.NamedMap[K, V], + session: coherence.Session, + client: NamedCacheServiceStub, + serializer: Serializer, + emitter: EventEmitter, + ) -> None: + """ + Constructs a new _MapEventManager. + :param named_map: the 'source' of the events + :param session: the Session associated with this NamedMap + :param client: the gRPC client + :param serializer: the Serializer that will be used for ser/deser operations + :param emitter: the internal event emitter used to notify registered MapListeners + """ + super().__init__(named_map, session, serializer, emitter) + self._client = client + self._request_factory = RequestFactory(self._map_name, session.scope, serializer) + + # noinspection PyTypeChecker + session.on(SessionLifecycleEvent.RECONNECTED, self._reconnect) + + def _new_key_group(self, key: K) -> _ListenerGroup[K, V, Any]: + return _KeyListenerGroupV0(self, key) + + def _new_filter_group(self, filter: Filter) -> _ListenerGroup[K, V, Any]: + return _FilterListenerGroupV0(self, filter) + + def _close(self) -> None: + """Close the gRPC event stream and any background tasks.""" + + event_stream: grpc.aio.StreamStreamCall = self._event_stream + if event_stream is not None: + event_stream.cancel() + self._event_stream = None + + self._open = False + for task in self._background_tasks: + task.cancel() + self._background_tasks.clear() + + async def _ensure_stream(self) -> grpc.aio.StreamStreamCall: + """ + Initialize the event stream for MapListener events. + """ + if self._event_stream is None: + event_stream: grpc.aio.StreamStreamCall = self._client.events() + await event_stream.write(self._request_factory.map_event_subscribe()) + self._event_stream = event_stream + read_task: Task[None] = asyncio.create_task(self._handle_response()) + self._background_tasks.add(read_task) + # we use asyncio.timeout here instead of using the gRPC timeout + # as any deadline set on the stream will result in a loss of events + try: + await asyncio.wait_for(self._stream_waiter.wait(), self._session.options.request_timeout_seconds) + except TimeoutError: + s = ( + "Deadline [{0} seconds] exceeded waiting for event stream" + " to become ready. 
Server address - {1})".format( + str(self._session.options.request_timeout_seconds), self._session.options.address + ) + ) + raise TimeoutError(s) + + return self._event_stream + # noinspection PyProtectedMember async def _handle_response(self) -> None: """ @@ -778,7 +966,9 @@ async def _handle_response(self) -> None: response: MapListenerResponse = await event_stream.read() if response.HasField("subscribed"): subscribed = response.subscribed - group: Optional[_ListenerGroup[K, V]] = self._pending_registrations.get(subscribed.uid, None) + group: Optional[_ListenerGroup[K, V, Any]] = self._pending_registrations.get( + subscribed.uid, None + ) if group is not None: group._subscribe_complete() elif response.HasField("destroyed"): @@ -793,14 +983,47 @@ async def _handle_response(self) -> None: response_event = response.event event: MapEvent[K, V] = MapEvent(self._named_map, response_event, self._serializer) for _id in response_event.filterIds: - filter_group: Optional[_ListenerGroup[K, V]] = self._filter_id_listener_group_map.get( + filter_group: Optional[_ListenerGroup[K, V, Any]] = self._filter_id_listener_group_map.get( _id, None ) if filter_group is not None: - filter_group._notify_listeners(event) + await filter_group._notify_listeners(event) key_group = self._key_map.get(event.key, None) if key_group is not None: - key_group._notify_listeners(event) + await key_group._notify_listeners(event) except asyncio.CancelledError: return + + +class _MapEventsManagerV1(_MapEventsManager[K, V]): + def __init__( + self, + named_map: coherence.client.NamedMap[K, V], + session: coherence.Session, + serializer: Serializer, + emitter: EventEmitter, + request_factory: RequestFactoryV1, + ) -> None: + super().__init__(named_map, session, serializer, emitter) + self.request_factory = request_factory + + @property + def request_factory(self) -> RequestFactoryV1: + return self._request_factory + + @request_factory.setter + def request_factory(self, value: RequestFactoryV1) -> None: + self._request_factory = value + + async def _ensure_stream(self) -> grpc.aio.StreamStreamCall: + pass # in v1, this is a no-op + + def _new_key_group(self, key: K) -> _ListenerGroup[K, V, Any]: + return _KeyListenerGroupV1(self, key) + + def _new_filter_group(self, filter: Filter) -> _ListenerGroup[K, V, Any]: + return _FilterListenerGroupV1(self, filter) + + def _close(self) -> None: + return diff --git a/src/coherence/local_cache.py b/src/coherence/local_cache.py new file mode 100644 index 0000000..5e47b07 --- /dev/null +++ b/src/coherence/local_cache.py @@ -0,0 +1,786 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. +from __future__ import annotations + +import asyncio +from collections import OrderedDict +from typing import Any, Generic, Optional, Tuple, TypeVar + +from pympler import asizeof + +from .entry import MapEntry +from .util import cur_time_millis, millis_format_date + +K = TypeVar("K") +V = TypeVar("V") + + +class NearCacheOptions: + def __init__( + self, ttl: Optional[int] = None, high_units: int = 0, high_units_memory: int = 0, prune_factor: float = 0.80 + ) -> None: + """ + Constructs a new NearCacheOptions. These options, when present, will configure + a NamedMap or NamedCache with a \"near cache\". This near cache will locally + cache entries to reduce the need to go to the network for entries that are + frequently accessed. 
Changes made to, or removal of entries from the remote + cache will be reflected in the near cache. + + :param ttl: the time-to-live, in millis, for entries held in the near cache. + Expiration resolution is to the 1/4 second, thus the minimum positive + ttl value is 250. If the ttl is zero, then no expiry will be applied + :param high_units: the maximum number of entries to be held by + the near cache. If this value is exceeded, the cache will be pruned + down by the percentage defined by the prune_factor parameter + :param high_units_memory: the maximum number of entries, in bytes, + that may be held by the near cache. If this value is exceeded, the + cache will be pruned down by the percentage defined by the + prune_factor parameter + :param prune_factor: the prune factor defines the target cache + size after exceeding the high_units or high_units_memory + high-water mark + """ + super().__init__() + if high_units < 0 or high_units_memory < 0: + raise ValueError("values for high_units and high_units_memory must be positive") + if ttl is None and high_units == 0 and high_units_memory == 0: + raise ValueError("at least one option must be specified") + if ttl is not None and ttl < 0: + raise ValueError("ttl cannot be less than zero") + if ttl is not None and 0 < ttl < 250: + raise ValueError("ttl has 1/4 second resolution; minimum TTL is 250") + if high_units != 0 and high_units_memory != 0: + raise ValueError("high_units and high_units_memory cannot be used together; specify one or the other") + if prune_factor < 0.1 or prune_factor > 1: + raise ValueError("prune_factor must be between .1 and 1") + + self._ttl = ttl if ttl is not None and ttl >= 0 else 0 + self._high_units = high_units + self._high_units_memory = high_units_memory + self._prune_factor = prune_factor + + def __str__(self) -> str: + """ + Returns a string representation of this NearCacheOptions instance. + + :return: string representation of this NearCacheOptions instance + """ + return ( + f"NearCacheOptions(ttl={self.ttl}ms, high-units={self.high_units}" + f", high-units-memory={self.high_unit_memory}" + f", prune-factor={self.prune_factor:.2f})" + ) + + def __eq__(self, other: Any) -> bool: + """ + Compare two NearCacheOptions for equality. + + :param other: the NearCacheOptions to compare against + :return: True if equal otherwise False + """ + if self is other: + return True + + if isinstance(other, NearCacheOptions): + return ( + self.ttl == other.ttl + and self.high_units == other.high_units + and self.high_unit_memory == other.high_unit_memory + and self.prune_factor == other.prune_factor + ) + + return False + + @property + def ttl(self) -> int: + """ + The time-to-live to be applied to entries inserted into the + near cache. + + :return: the ttl to be applied to entries inserted into the + near cache + """ + return self._ttl + + @property + def high_units(self) -> int: + """ + The maximum number of entries that may be held by the near cache. + If this value is exceeded, the cache will be pruned down by the + percentage defined by the prune_factor. + + :return: the maximum number of entries that may be held by the + near cache + """ + return self._high_units + + @property + def high_unit_memory(self) -> int: + """ + The maximum number of entries, in bytes, that may be held in the near cache. + If this value is exceeded, the cache will be pruned down by the + percentage defined by the prune_factor. 
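# Illustrative sketch (not part of the change itself): constructing NearCacheOptions.
# high_units and high_units_memory are mutually exclusive, and prune_factor must
# fall between 0.1 and 1.0, so the first call below succeeds and the second raises.
from coherence.local_cache import NearCacheOptions

options = NearCacheOptions(ttl=30_000, high_units=1_000, prune_factor=0.75)

try:
    NearCacheOptions(high_units=1_000, high_units_memory=1_000_000)
except ValueError as err:
    print(err)  # high_units and high_units_memory cannot be used together ...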
+ + :return: the maximum number of entries, in bytes, that may be held + by the near cache + """ + return self._high_units_memory + + @property + def prune_factor(self) -> float: + """ + This is percentage of units that will remain after a cache has + been pruned. When high_units is configured, this will be the number + of entries. When high_units_memory is configured, this will be the + size, in bytes, of the cache will be pruned down to. + + :return: the target cache size after pruning occurs + """ + return self._prune_factor + + +class LocalEntry(MapEntry[K, V]): + """ + A MapEntry implementation that includes metadata + to allow for expiry and pruning within a local + cache. + """ + + def __init__(self, key: K, value: V, ttl: int): + """ + Constructs a new LocalEntry. + + :param key: the entry key + :param value: the entry value + :param ttl: the time-to-live for this entry + """ + super().__init__(key, value) + self._ttl: int = ttl + now: int = cur_time_millis() + # store when this entry expires (1/4 second resolution) + self._expires: int = ((now + ttl) & ~0xFF) if ttl > 0 else 0 + self._last_access: int = now + self._size = asizeof.asizeof(self) + self._size += asizeof.asizeof(self._size) + + @property + def bytes(self) -> int: + """ + :return: the size in bytes of this entry + """ + return self._size + + @property + def ttl(self) -> int: + """ + :return: the time-to-live , in millis, of this entry + """ + return self._ttl + + @property + def last_access(self) -> int: + """ + :return: the last time, in millis, this entry was accessed + """ + return self._last_access + + @property + def expires_at(self) -> int: + """ + :return: the time when this entry will expire + """ + return self._expires + + def touch(self) -> None: + """ + Updates the last accessed time of this entry. + """ + self._last_access = cur_time_millis() + + def __str__(self) -> str: + return ( + f"LocalEntry(key={self.key}, value={self.value}," + f" ttl={self.ttl}ms," + f" last-access={millis_format_date(self.last_access)})" + ) + + +class CacheStats: + """ + Tracking statistics for LocalCaches. + """ + + def __init__(self, local_cache: "LocalCache[K, V]"): + """ + Constructs a new CacheStats. + + :param local_cache: the associated LocalCache + """ + self._local_cache: "LocalCache[K, V]" = local_cache + self._hits: int = 0 + self._misses: int = 0 + self._puts: int = 0 + self._memory: int = 0 + self._prunes: int = 0 + self._pruned_count: int = 0 + self._expires: int = 0 + self._expired_count: int = 0 + self._expires_millis: int = 0 + self._prunes_millis: int = 0 + self._misses_millis: int = 0 + + @property + def hits(self) -> int: + """ + The number of times an entry was found in the near cache. + + :return: the number of cache hits + """ + return self._hits + + @property + def misses(self) -> int: + """ + The number of times an entry was not found in the near cache. + + :return: the number of cache misses + """ + return self._misses + + @property + def misses_duration(self) -> int: + """ + The accumulated time, in millis, spent for a cache miss. + + :return: the accumulated total of millis spent when a cache + miss occurs and a remote get is made + """ + return self._misses_millis + + @property + def hit_rate(self) -> float: + """ + The ration of hits to misses. 
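# Sanity check of the hit-rate arithmetic used by CacheStats below: the rate is
# hits divided by total gets (hits + misses), rounded to three places. The
# numbers here are made up purely to illustrate the formula.
hits, misses = 75, 25
hit_rate = 0.0 if hits + misses == 0 else round(hits / (hits + misses), 3)
print(hit_rate)  # 0.75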
+ + :return: the ratio of hits to misses + """ + hits: int = self.hits + misses: int = self.misses + total = hits + misses + + if misses == 0 and hits > 0: + return 1.0 if hits > 0 else 0.0 + + return 0.0 if total == 0 else round((float(hits) / (float(total))), 3) + + @property + def puts(self) -> int: + """ + The total number of puts that have been made against the near + cache. + + :return: the total number of puts + """ + return self._puts + + @property + def gets(self) -> int: + """ + The total number of gets that have been made against the near + cache. + + :return: the total number of gets + """ + return self.hits + self.misses + + @property + def prunes(self) -> int: + """ + :return: the number of times the cache was pruned due to exceeding + the configured high-water mark + """ + return self._prunes + + @property + def expires(self) -> int: + """ + The number of times expiry of entries has been processed. + + :return: the number of times expiry was processed + """ + return self._expires + + @property + def num_pruned(self) -> int: + """ + The total number of entries that have been removed due to + exceeding the configured high-water mark. + + :return: the number of entries pruned + """ + return self._pruned_count + + @property + def num_expired(self) -> int: + """ + The total number of entries that have been removed due to + expiration. + + :return: the number of entries that was expired + """ + return self._expired_count + + @property + def prunes_duration(self) -> int: + """ + The accumulated total time, in millis, spent pruning the + near cache + + :return: the accumulated total of millis spent pruning the near + cache + """ + return self._prunes_millis + + @property + def expires_duration(self) -> int: + """ + The accumulated total time, in millis, spent processing expiration + of entries in the near cache + + :return: the accumulated total of millis spent expiring entries + in the near cache + """ + return self._expires_millis + + @property + def size(self) -> int: + """ + The total number of entries held by the near cache. + + :return: the number of local cache entries + """ + return len(self._local_cache.storage) + + @property + def bytes(self) -> int: + """ + The total number of bytes the entries of the near cache is + consuming. + + :return: The total number of bytes the entries of the near cache is + consuming + """ + return self._memory + + def reset(self) -> None: + """ + Resets all statistics aside from memory consumption and size. + + :return: None + """ + self._prunes = 0 + self._prunes_millis = 0 + self._pruned_count = 0 + self._misses = 0 + self._misses_millis = 0 + self._hits = 0 + self._puts = 0 + self._expires = 0 + self._expires_millis = 0 + self._expired_count = 0 + + def _register_hit(self) -> None: + """ + Register a hit. + + :return: None + """ + self._hits += 1 + + def _register_miss(self) -> None: + """ + Register a miss. + + :return: None + """ + self._misses += 1 + + def _register_put(self) -> None: + """ + Register a put. + + :return: None + """ + self._puts += 1 + + def _update_memory(self, size: int) -> None: + """ + Update the current memory total. + + :param size: the memory amount to increase/decrease + :return: None + """ + self._memory += size + + def _register_prunes(self, count: int, millis: int) -> None: + """ + Register prune statistics. 
+ + :param count: the number of entries pruned + :param millis: the number of millis spent on a prune operation + :return: None + """ + self._prunes += 1 + self._pruned_count += count + self._prunes_millis += millis if millis > 0 else 1 + + def _register_misses_millis(self, millis: int) -> None: + """ + Register miss millis. + + :param millis: the millis spent when a cache miss occurs + :return: None + """ + self._misses_millis += millis + + def _register_expires(self, count: int, millis: int) -> None: + """ + Register the number of entries expired and the millis spent processing + the expiry logic. + + :param count: the number of entries expired + :param millis: the time spent processing + :return: None + """ + self._expires += 1 + self._expired_count += count + self._expires_millis += millis if millis > 0 else 1 + + def __str__(self) -> str: + """ + :return: the string representation of this CacheStats instance. + """ + return ( + f"CacheStats(puts={self.puts}, gets={self.gets}, hits={self.hits}" + f", misses={self.misses}, misses-duration={self.misses_duration}ms" + f", hit-rate={self.hit_rate}, prunes={self.prunes}, num-pruned={self.num_pruned}" + f", prunes-duration={self.prunes_duration}ms, size={self.size}" + f", expires={self.num_expired}, num-expired={self.expires}" + f", expires-duration={self.expires_duration}ms" + f", memory-bytes={self.bytes})" + ) + + +# noinspection PyProtectedMember +class LocalCache(Generic[K, V]): + """ + A local cache of entries. This cache will expire entries as they ripen + and will prune the cache down to any configured watermarks defined + in the NearCacheOptions + """ + + def __init__(self, name: str, options: NearCacheOptions): + """ + Constructs a new LocalCache. + + :param name: the name of the local cache + :param options: the NearCacheOptions configuring this LocalCache + """ + self._name: str = name + self._options: NearCacheOptions = options + self._stats: CacheStats = CacheStats(self) + self._storage: dict[K, Optional[LocalEntry[K, V]]] = dict() + self._expiries: dict[int, set[K]] = OrderedDict() + self._lock: asyncio.Lock = asyncio.Lock() + self._next_expiry: int = 0 + + async def put(self, key: K, value: V, ttl: Optional[int] = None) -> Optional[V]: + """ + Associates the specified value with the specified key in this cache. If the + cache previously contained a mapping for this key, the old value is replaced. + + :param key: the key with which the specified value is to be associated + :param value: the value to be associated with the specified key + :param ttl: the time-to-live (in millis) of this entry + :return: the previous value associated with the specified key, or `None` + if there was no mapping for key. A `None` return can also indicate + that the map previously associated `None` with the specified key + if the implementation supports `None` values + """ + async with self._lock: + stats: CacheStats = self.stats + storage: dict[K, Optional[LocalEntry[K, V]]] = self.storage + stats._register_put() + self._prune() + + old_entry: Optional[LocalEntry[K, V]] = storage.get(key, None) + if old_entry is not None: + stats._update_memory(-old_entry.bytes) + + entry: LocalEntry[K, V] = LocalEntry(key, value, ttl if ttl is not None else self.options.ttl) + self._register_expiry(entry) + stats._update_memory(entry.bytes) + + storage[key] = entry + + return None if old_entry is None else old_entry.value + + async def get(self, key: K) -> Optional[V]: + """ + Returns the value to which this cache maps the specified key. 
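# Minimal usage sketch for the LocalCache put/get path (asyncio is required, since
# the cache guards its storage with an asyncio.Lock). Cache name and values here
# are illustrative only.
import asyncio
from coherence.local_cache import LocalCache, NearCacheOptions

async def demo() -> None:
    cache: LocalCache[str, str] = LocalCache("demo", NearCacheOptions(high_units=100))
    await cache.put("a", "value-1", ttl=5_000)
    print(await cache.get("a"))        # value-1 -> recorded as a hit
    print(await cache.get("missing"))  # None    -> recorded as a miss
    print(cache.stats.hit_rate)        # 0.5

asyncio.run(demo())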
+ + :param key: the key whose associated value is to be returned + """ + async with self._lock: + storage: dict[K, Optional[LocalEntry[K, V]]] = self.storage + stats: CacheStats = self.stats + self._expire() + + entry: Optional[LocalEntry[K, V]] = storage.get(key, None) + if entry is None: + stats._register_miss() + return None + + stats._register_hit() + entry.touch() + + return entry.value + + async def get_all(self, keys: set[K]) -> dict[K, V]: + """ + Get all the specified keys if they are in the cache. For each key that is in the cache, + that key and its corresponding value will be placed in the cache that is returned by + this method. The absence of a key in the returned map indicates that it was not in the cache, + which may imply (for caches that can load behind the scenes) that the requested data + could not be loaded. + + :param keys: an Iterable of keys that may be in this cache + :return: a dict containing the keys/values that were found in the + local cache + """ + async with self._lock: + self._expire() + + stats: CacheStats = self.stats + results: dict[K, V] = dict() + + for key in keys: + entry: Optional[LocalEntry[K, V]] = self.storage.get(key, None) + if entry is None: + stats._register_miss() + continue + + stats._register_hit() + entry.touch() + + results[key] = entry.value + + return results + + async def remove(self, key: K) -> Optional[V]: + """ + Removes the mapping for a key from this cache if it is present. + + :param key: key whose mapping is to be removed from the cache + :return: the previous value associated with key, or `None` if there was no mapping for key + """ + async with self._lock: + self._expire() + entry: Optional[LocalEntry[K, V]] = self.storage.pop(key, None) + + if entry is None: + return None + + self._remove_expiry(entry) + self.stats._update_memory(-entry.bytes) + return entry.value + + async def contains_key(self, key: K) -> bool: + """ + Returns `true` if the specified key is mapped a value within the cache. + + :param key: the key whose presence in this cache is to be tested + :return: resolving to `true` if the key is mapped to a value, or `false` if it does not + """ + return key in self._storage + + async def size(self) -> int: + """ + Signifies the number of key-value mappings in this cache. + + :return: the number of key-value mappings in this cache + """ + async with self._lock: + self._expire() + + return len(self.storage) + + async def clear(self) -> None: + """ + Clears all the mappings in the cache. + """ + async with self._lock: + self._storage = dict() + self._expiries = OrderedDict() + + self.stats._memory = 0 + + async def release(self) -> None: + """ + Release local resources associated with instance. + """ + await self.clear() + + @property + def stats(self) -> CacheStats: + """ + :return: the statistics for this cache + """ + return self._stats + + @property + def name(self) -> str: + """ + :return: the name of this cache + """ + return self._name + + @property + def storage(self) -> dict[K, Optional[LocalEntry[K, V]]]: + """ + :return: the local storage for this cache + """ + return self._storage + + @property + def options(self) -> NearCacheOptions: + """ + :return: the NearCacheOptions for this cache + """ + return self._options + + def _prune(self) -> None: + """ + Prunes this cache based on NearCacheOptions configuration. 
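# Sketch of the pruning behaviour described here: once a put would push the cache
# past high_units, the least recently accessed entries are evicted until the size
# drops to roughly prune_factor of the current size.
import asyncio
from coherence.local_cache import LocalCache, NearCacheOptions

async def prune_demo() -> None:
    cache: LocalCache[int, str] = LocalCache("prune", NearCacheOptions(high_units=3, prune_factor=0.5))
    for i in range(3):
        await cache.put(i, f"value-{i}")
    await cache.put(3, "value-3")      # triggers a prune before the insert
    print(await cache.size())          # below the pre-prune total of 4
    print(cache.stats.num_pruned)      # how many entries the prune removed

asyncio.run(prune_demo())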
+ + :return: None + """ + self._expire() + + storage: dict[K, Optional[LocalEntry[K, V]]] = self.storage + options: NearCacheOptions = self.options + prune_factor: float = options.prune_factor + high_units: int = options.high_units + high_units_used: bool = high_units > 0 + high_units_mem: int = options.high_unit_memory + mem_units_used: bool = high_units_mem > 0 + cur_size = len(storage) + + if (high_units_used and high_units < cur_size + 1) or (mem_units_used and high_units_mem < self.stats.bytes): + start = cur_time_millis() + stats: CacheStats = self.stats + prune_count: int = 0 + + to_sort: list[Tuple[int, K]] = [] + for key, value in storage.items(): + if value is not None: + to_sort.append((value.last_access, key)) + + to_sort = sorted(to_sort, key=lambda x: x[0]) + + target_size: int = int(round(float((cur_size if high_units_used else stats.bytes) * prune_factor))) + + for item in to_sort: + entry: Optional[LocalEntry[K, V]] = storage.pop(item[1]) + if entry is not None: + self._remove_expiry(entry) + stats._update_memory(-entry.bytes) + prune_count += 1 + + if (len(storage) if high_units_used else stats._memory) <= target_size: + break + + end = cur_time_millis() + stats._register_prunes(prune_count, end - start) + + def _expire(self) -> None: + """ + Process and remove any expired entries from the cache. + + :return: None + """ + expires: dict[int, set[K]] = self._expiries + if len(expires) == 0: + return + + now: int = cur_time_millis() + if self._next_expiry > 0 and now < self._next_expiry: + return + + storage: dict[K, Optional[LocalEntry[K, V]]] = self.storage + stats: CacheStats = self.stats + + expired_count: int = 0 + exp_buckets_to_remove: list[int] = [] + + for expire_time, keys in expires.items(): + if expire_time < now: + exp_buckets_to_remove.append(expire_time) + for key in keys: + entry: Optional[LocalEntry[K, V]] = storage.pop(key, None) + if entry is not None: + expired_count += 1 + stats._update_memory(-entry.bytes) + continue + break + + if len(exp_buckets_to_remove) > 0: + for bucket in exp_buckets_to_remove: + expires.pop(bucket, None) + + end = cur_time_millis() + if expired_count > 0: + stats._register_expires(expired_count, end - now) + + # expiries have 1/4 second resolution, so only check + # expiry in the same interval + self._next_expiry = end + 250 + + def _register_expiry(self, entry: LocalEntry[K, V]) -> None: + """ + Register the expiry, if any, of provided entry. + + :param entry: the entry to register + :return: None + """ + if entry.ttl > 0: + expires_at = entry.expires_at + expires_map: dict[int, set[K]] = self._expiries + if expires_at in expires_map: + keys: set[K] = expires_map[expires_at] + keys.add(entry.key) + else: + expires_map[expires_at] = {entry.key} + + def _remove_expiry(self, entry: LocalEntry[K, V]) -> None: + """ + Removes the provided entry from expiry tracking. + + :param entry: the entry to register + :return: None + """ + if entry.ttl > 0: + expires_at = entry.expires_at + expires_map: dict[int, set[K]] = self._expiries + if expires_at in expires_map: + keys: set[K] = expires_map[expires_at] + expire_key: K = entry.key + if expire_key in keys: + keys.remove(expire_key) + + def __str__(self) -> str: + """ + :return: the string representation of this LocalCache. 
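# The "1/4 second resolution" mentioned for expiry comes from masking off the low
# eight bits of (now + ttl): entries that ripen within the same ~256 ms window
# share one bucket in _expiries, which lets _expire() sweep them as a group.
now = 1_700_000_000_123  # hypothetical epoch millis
ttl = 1_000
bucket = (now + ttl) & ~0xFF
print(bucket % 256)      # 0 -> the low eight bits are always cleared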
+ """ + return f"LocalCache(name={self.name}, options={self.options}, stats={self.stats})" diff --git a/src/coherence/processor.py b/src/coherence/processor.py index 3da0eb9..86b88f3 100644 --- a/src/coherence/processor.py +++ b/src/coherence/processor.py @@ -62,8 +62,8 @@ def when(self, filter: Filter) -> EntryProcessor[R]: applied to the entry evaluates to `true`; otherwise the result of the invocation will return `None`. - :param filter: the filter :return: Returns a :class:`coherence.processor.ConditionalProcessor` comprised of - this processor and the provided filter. + :param filter: the filter :return: Returns a :class:`coherence.processor.ConditionalProcessor` comprised + of this processor and the provided filter. """ return ConditionalProcessor(filter, self) diff --git a/src/coherence/proxy_service_messages_v1_pb2.py b/src/coherence/proxy_service_messages_v1_pb2.py new file mode 100644 index 0000000..115d0f8 --- /dev/null +++ b/src/coherence/proxy_service_messages_v1_pb2.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: proxy_service_messages_v1.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import coherence.common_messages_v1_pb2 as common__messages__v1__pb2 +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fproxy_service_messages_v1.proto\x12\x12\x63oherence.proxy.v1\x1a\x18\x63ommon_messages_v1.proto\x1a\x19google/protobuf/any.proto\"\xbb\x01\n\x0cProxyRequest\x12\n\n\x02id\x18\x01 \x01(\x03\x12/\n\x04init\x18\x03 \x01(\x0b\x32\x1f.coherence.proxy.v1.InitRequestH\x00\x12\'\n\x07message\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x12:\n\theartbeat\x18\x05 \x01(\x0b\x32%.coherence.common.v1.HeartbeatMessageH\x00\x42\t\n\x07request\"\xa5\x02\n\rProxyResponse\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x30\n\x04init\x18\x04 \x01(\x0b\x32 .coherence.proxy.v1.InitResponseH\x00\x12\'\n\x07message\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x12\x32\n\x05\x65rror\x18\x06 \x01(\x0b\x32!.coherence.common.v1.ErrorMessageH\x00\x12\x31\n\x08\x63omplete\x18\x07 \x01(\x0b\x32\x1d.coherence.common.v1.CompleteH\x00\x12:\n\theartbeat\x18\x08 \x01(\x0b\x32%.coherence.common.v1.HeartbeatMessageH\x00\x42\n\n\x08response\"\xc7\x01\n\x0bInitRequest\x12\r\n\x05scope\x18\x02 \x01(\t\x12\x0e\n\x06\x66ormat\x18\x03 \x01(\t\x12\x10\n\x08protocol\x18\x04 \x01(\t\x12\x17\n\x0fprotocolVersion\x18\x05 \x01(\x05\x12 \n\x18supportedProtocolVersion\x18\x06 \x01(\x05\x12\x16\n\theartbeat\x18\x07 \x01(\x03H\x00\x88\x01\x01\x12\x17\n\nclientUuid\x18\x08 \x01(\x0cH\x01\x88\x01\x01\x42\x0c\n\n_heartbeatB\r\n\x0b_clientUuid\"\x8e\x01\n\x0cInitResponse\x12\x0c\n\x04uuid\x18\x01 \x01(\x0c\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x16\n\x0e\x65ncodedVersion\x18\x03 \x01(\x05\x12\x17\n\x0fprotocolVersion\x18\x04 \x01(\x05\x12\x15\n\rproxyMemberId\x18\x05 \x01(\x05\x12\x17\n\x0fproxyMemberUuid\x18\x06 \x01(\x0c\x42/\n+com.oracle.coherence.grpc.messages.proxy.v1P\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proxy_service_messages_v1_pb2', globals()) +if 
_descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n+com.oracle.coherence.grpc.messages.proxy.v1P\001' + _PROXYREQUEST._serialized_start=109 + _PROXYREQUEST._serialized_end=296 + _PROXYRESPONSE._serialized_start=299 + _PROXYRESPONSE._serialized_end=592 + _INITREQUEST._serialized_start=595 + _INITREQUEST._serialized_end=794 + _INITRESPONSE._serialized_start=797 + _INITRESPONSE._serialized_end=939 +# @@protoc_insertion_point(module_scope) diff --git a/src/coherence/proxy_service_messages_v1_pb2.pyi b/src/coherence/proxy_service_messages_v1_pb2.pyi new file mode 100644 index 0000000..ed1b94d --- /dev/null +++ b/src/coherence/proxy_service_messages_v1_pb2.pyi @@ -0,0 +1,70 @@ +# mypy: ignore-errors +import common_messages_v1_pb2 as _common_messages_v1_pb2 +from google.protobuf import any_pb2 as _any_pb2 +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class InitRequest(_message.Message): + __slots__ = ["clientUuid", "format", "heartbeat", "protocol", "protocolVersion", "scope", "supportedProtocolVersion"] + CLIENTUUID_FIELD_NUMBER: _ClassVar[int] + FORMAT_FIELD_NUMBER: _ClassVar[int] + HEARTBEAT_FIELD_NUMBER: _ClassVar[int] + PROTOCOLVERSION_FIELD_NUMBER: _ClassVar[int] + PROTOCOL_FIELD_NUMBER: _ClassVar[int] + SCOPE_FIELD_NUMBER: _ClassVar[int] + SUPPORTEDPROTOCOLVERSION_FIELD_NUMBER: _ClassVar[int] + clientUuid: bytes + format: str + heartbeat: int + protocol: str + protocolVersion: int + scope: str + supportedProtocolVersion: int + def __init__(self, scope: _Optional[str] = ..., format: _Optional[str] = ..., protocol: _Optional[str] = ..., protocolVersion: _Optional[int] = ..., supportedProtocolVersion: _Optional[int] = ..., heartbeat: _Optional[int] = ..., clientUuid: _Optional[bytes] = ...) -> None: ... + +class InitResponse(_message.Message): + __slots__ = ["encodedVersion", "protocolVersion", "proxyMemberId", "proxyMemberUuid", "uuid", "version"] + ENCODEDVERSION_FIELD_NUMBER: _ClassVar[int] + PROTOCOLVERSION_FIELD_NUMBER: _ClassVar[int] + PROXYMEMBERID_FIELD_NUMBER: _ClassVar[int] + PROXYMEMBERUUID_FIELD_NUMBER: _ClassVar[int] + UUID_FIELD_NUMBER: _ClassVar[int] + VERSION_FIELD_NUMBER: _ClassVar[int] + encodedVersion: int + protocolVersion: int + proxyMemberId: int + proxyMemberUuid: bytes + uuid: bytes + version: str + def __init__(self, uuid: _Optional[bytes] = ..., version: _Optional[str] = ..., encodedVersion: _Optional[int] = ..., protocolVersion: _Optional[int] = ..., proxyMemberId: _Optional[int] = ..., proxyMemberUuid: _Optional[bytes] = ...) -> None: ... + +class ProxyRequest(_message.Message): + __slots__ = ["heartbeat", "id", "init", "message"] + HEARTBEAT_FIELD_NUMBER: _ClassVar[int] + ID_FIELD_NUMBER: _ClassVar[int] + INIT_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + heartbeat: _common_messages_v1_pb2.HeartbeatMessage + id: int + init: InitRequest + message: _any_pb2.Any + def __init__(self, id: _Optional[int] = ..., init: _Optional[_Union[InitRequest, _Mapping]] = ..., message: _Optional[_Union[_any_pb2.Any, _Mapping]] = ..., heartbeat: _Optional[_Union[_common_messages_v1_pb2.HeartbeatMessage, _Mapping]] = ...) -> None: ... 
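# Rough illustration of the v1 handshake these generated types describe: the
# client opens the bidirectional channel with a ProxyRequest whose init field
# carries an InitRequest (compare RequestFactoryV1.init_sub_channel in util.py).
# Field values here are examples only.
from coherence.proxy_service_messages_v1_pb2 import InitRequest, ProxyRequest

init = InitRequest(
    scope="",
    format="json",
    protocol="CacheService",
    protocolVersion=1,
    supportedProtocolVersion=1,
)
handshake = ProxyRequest(id=2, init=init)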
+ +class ProxyResponse(_message.Message): + __slots__ = ["complete", "error", "heartbeat", "id", "init", "message"] + COMPLETE_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + HEARTBEAT_FIELD_NUMBER: _ClassVar[int] + ID_FIELD_NUMBER: _ClassVar[int] + INIT_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + complete: _common_messages_v1_pb2.Complete + error: _common_messages_v1_pb2.ErrorMessage + heartbeat: _common_messages_v1_pb2.HeartbeatMessage + id: int + init: InitResponse + message: _any_pb2.Any + def __init__(self, id: _Optional[int] = ..., init: _Optional[_Union[InitResponse, _Mapping]] = ..., message: _Optional[_Union[_any_pb2.Any, _Mapping]] = ..., error: _Optional[_Union[_common_messages_v1_pb2.ErrorMessage, _Mapping]] = ..., complete: _Optional[_Union[_common_messages_v1_pb2.Complete, _Mapping]] = ..., heartbeat: _Optional[_Union[_common_messages_v1_pb2.HeartbeatMessage, _Mapping]] = ...) -> None: ... diff --git a/src/coherence/proxy_service_messages_v1_pb2_grpc.py b/src/coherence/proxy_service_messages_v1_pb2_grpc.py new file mode 100644 index 0000000..2daafff --- /dev/null +++ b/src/coherence/proxy_service_messages_v1_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/src/coherence/proxy_service_v1_pb2.py b/src/coherence/proxy_service_v1_pb2.py new file mode 100644 index 0000000..9817094 --- /dev/null +++ b/src/coherence/proxy_service_v1_pb2.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: proxy_service_v1.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import coherence.proxy_service_messages_v1_pb2 as proxy__service__messages__v1__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 +from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16proxy_service_v1.proto\x12\x12\x63oherence.proxy.v1\x1a\x1fproxy_service_messages_v1.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1egoogle/protobuf/wrappers.proto2g\n\x0cProxyService\x12W\n\nsubChannel\x12 .coherence.proxy.v1.ProxyRequest\x1a!.coherence.proxy.v1.ProxyResponse\"\x00(\x01\x30\x01\x42/\n+com.oracle.coherence.grpc.services.proxy.v1P\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proxy_service_v1_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n+com.oracle.coherence.grpc.services.proxy.v1P\001' + _PROXYSERVICE._serialized_start=140 + _PROXYSERVICE._serialized_end=243 +# @@protoc_insertion_point(module_scope) diff --git a/src/coherence/proxy_service_v1_pb2.pyi b/src/coherence/proxy_service_v1_pb2.pyi new file mode 100644 index 0000000..405eb1e --- /dev/null +++ b/src/coherence/proxy_service_v1_pb2.pyi @@ -0,0 +1,8 @@ +# mypy: ignore-errors +import proxy_service_messages_v1_pb2 as _proxy_service_messages_v1_pb2 +from google.protobuf import empty_pb2 as _empty_pb2 +from 
google.protobuf import wrappers_pb2 as _wrappers_pb2 +from google.protobuf import descriptor as _descriptor +from typing import ClassVar as _ClassVar + +DESCRIPTOR: _descriptor.FileDescriptor diff --git a/src/coherence/proxy_service_v1_pb2_grpc.py b/src/coherence/proxy_service_v1_pb2_grpc.py new file mode 100644 index 0000000..8f6e56b --- /dev/null +++ b/src/coherence/proxy_service_v1_pb2_grpc.py @@ -0,0 +1,79 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import coherence.proxy_service_messages_v1_pb2 as proxy__service__messages__v1__pb2 + + +class ProxyServiceStub(object): + """----------------------------------------------------------------- + The Coherence gRPC Proxy Service definition. + ----------------------------------------------------------------- + + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.subChannel = channel.stream_stream( + '/coherence.proxy.v1.ProxyService/subChannel', + request_serializer=proxy__service__messages__v1__pb2.ProxyRequest.SerializeToString, + response_deserializer=proxy__service__messages__v1__pb2.ProxyResponse.FromString, + ) + + +class ProxyServiceServicer(object): + """----------------------------------------------------------------- + The Coherence gRPC Proxy Service definition. + ----------------------------------------------------------------- + + """ + + def subChannel(self, request_iterator, context): + """Sets up a bidirectional channel for cache requests and responses. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_ProxyServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'subChannel': grpc.stream_stream_rpc_method_handler( + servicer.subChannel, + request_deserializer=proxy__service__messages__v1__pb2.ProxyRequest.FromString, + response_serializer=proxy__service__messages__v1__pb2.ProxyResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'coherence.proxy.v1.ProxyService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class ProxyService(object): + """----------------------------------------------------------------- + The Coherence gRPC Proxy Service definition. 
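# Assumed usage sketch, not part of the generated module: the client opens the
# bidirectional subChannel with this stub and then drives it with ProxyRequest
# and ProxyResponse messages; the real stream handling lives in the client code.
import grpc
from coherence.proxy_service_v1_pb2_grpc import ProxyServiceStub

async def open_sub_channel(address: str) -> grpc.aio.StreamStreamCall:
    channel = grpc.aio.insecure_channel(address)  # plaintext, for illustration only
    stub = ProxyServiceStub(channel)
    return stub.subChannel()                      # bidirectional stream call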
+ ----------------------------------------------------------------- + + """ + + @staticmethod + def subChannel(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream(request_iterator, target, '/coherence.proxy.v1.ProxyService/subChannel', + proxy__service__messages__v1__pb2.ProxyRequest.SerializeToString, + proxy__service__messages__v1__pb2.ProxyResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/coherence/serialization.py b/src/coherence/serialization.py index 0ab0ecd..8b71e3f 100644 --- a/src/coherence/serialization.py +++ b/src/coherence/serialization.py @@ -4,6 +4,7 @@ from __future__ import annotations +import collections from abc import ABC, abstractmethod from decimal import Decimal from typing import Any, Callable, Dict, Final, Optional, Type, TypeVar, cast @@ -18,6 +19,7 @@ _META_CLASS: Final[str] = "@class" _META_VERSION: Final[str] = "@version" _META_ENUM: Final[str] = "enum" +_META_ORDERED: Final[str] = "@ordered" _JSON_KEY = "key" _JSON_VALUE = "value" @@ -80,7 +82,8 @@ def deserialize(self, value: bytes) -> T: # type: ignore return cast(T, None) else: if ord(s[0]) == ord(MAGIC_BYTE): - return jsonpickle.decode(s[1:], context=self._unpickler) + r = jsonpickle.decode(s[1:], context=self._unpickler) + return r else: raise ValueError("Invalid JSON serialization format") else: @@ -205,10 +208,10 @@ class JavaProxyUnpickler(jsonpickle.Unpickler): # noinspection PyUnresolvedReferences def _restore(self, obj: Any) -> Any: if isinstance(obj, dict): - metadata: str = obj.get(_META_CLASS, None) + metadata: Any = obj.get(_META_CLASS, None) if metadata is not None: type_: Optional[Type[Any]] = _type_for(metadata) - actual: dict[str, Any] = dict() + actual: dict[Any, Any] = dict() if type_ is None: if "map" in metadata.lower(): for entry in obj[_JSON_ENTRIES]: @@ -228,6 +231,25 @@ def _restore(self, obj: Any) -> Any: return super().restore(actual, reset=False) + # When "@Ordered" set to true which converts to OrderedDict() + metadata = obj.get(_META_ORDERED, False) + if metadata is True: + o = collections.OrderedDict() + entries = obj.get(_JSON_ENTRIES, None) + if entries is not None: + for entry in obj[_JSON_ENTRIES]: + o[entry[_JSON_KEY]] = entry[_JSON_VALUE] + return o + + # When there is no "@Ordered" set. 
Only "entries" list exists + if len(obj) == 1: + entries = obj.get(_JSON_ENTRIES, None) + if entries is not None: + actual = dict() + for entry in obj[_JSON_ENTRIES]: + actual[entry[_JSON_KEY]] = entry[_JSON_VALUE] + return super().restore(actual, reset=False) + return super()._restore(obj) diff --git a/src/coherence/services_pb2.pyi b/src/coherence/services_pb2.pyi index 6197156..016abbc 100644 --- a/src/coherence/services_pb2.pyi +++ b/src/coherence/services_pb2.pyi @@ -5,4 +5,4 @@ from google.protobuf import wrappers_pb2 as _wrappers_pb2 from google.protobuf import descriptor as _descriptor from typing import ClassVar as _ClassVar -DESCRIPTOR: _descriptor.FileDescriptor +DESCRIPTOR: _descriptor.FileDescriptor \ No newline at end of file diff --git a/src/coherence/util.py b/src/coherence/util.py index 06c073d..edc2316 100644 --- a/src/coherence/util.py +++ b/src/coherence/util.py @@ -4,11 +4,29 @@ from __future__ import annotations +import asyncio +import sys +import threading import time -from typing import Optional, TypeVar +from abc import ABC, abstractmethod +from asyncio import Event +from datetime import datetime, timezone +from typing import Any, AsyncIterator, Callable, Final, Generic, Optional, Tuple, TypeVar + +from google.protobuf.any_pb2 import Any as GrpcAny +from google.protobuf.wrappers_pb2 import BoolValue, BytesValue, Int32Value from .aggregator import EntryAggregator +from .cache_service_messages_v1_pb2 import EnsureCacheRequest, ExecuteRequest, IndexRequest, KeyOrFilter, KeysOrFilter +from .cache_service_messages_v1_pb2 import MapListenerRequest as V1MapListenerRequest +from .cache_service_messages_v1_pb2 import NamedCacheRequest, NamedCacheRequestType, NamedCacheResponse +from .cache_service_messages_v1_pb2 import PutAllRequest as V1PutAllRequest +from .cache_service_messages_v1_pb2 import PutRequest as V1PutRequest +from .cache_service_messages_v1_pb2 import QueryRequest +from .cache_service_messages_v1_pb2 import ReplaceMappingRequest as V1ReplaceMappingRequest +from .common_messages_v1_pb2 import BinaryKeyAndValue, CollectionOfBytesValues, OptionalValue from .comparator import Comparator +from .entry import MapEntry from .extractor import ValueExtractor from .filter import Filter, Filters, MapEventFilter from .messages_pb2 import ( @@ -41,8 +59,12 @@ ValuesRequest, ) from .processor import EntryProcessor +from .proxy_service_messages_v1_pb2 import InitRequest, ProxyRequest from .serialization import Serializer +_FILTER_REQUIRED: Final[str] = "Filter cannot be None" +_KEYS_FILTERS_EXCLUSIVE: Final[str] = "keys and filter are mutually exclusive" + E = TypeVar("E") K = TypeVar("K") R = TypeVar("R") @@ -50,6 +72,348 @@ V = TypeVar("V") +def cur_time_millis() -> int: + """ + :return: the current time, in millis, since epoch + """ + return time.time_ns() // 1_000_000 + + +def millis_format_date(millis: int) -> str: + """ + Format the given time in millis to a readable format. 
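# Quick check of the two time helpers added to util.py: cur_time_millis() is the
# epoch time in milliseconds and millis_format_date() renders such a value in UTC.
from coherence.util import cur_time_millis, millis_format_date

now = cur_time_millis()
print(millis_format_date(now))  # e.g. 2024-11-05T12:34:56.789000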
+ + :param millis: the millis time to format + :return: the formatted date + """ + dt = datetime.fromtimestamp(millis / 1000, timezone.utc) + return dt.strftime("%Y-%m-%dT%H:%M:%S.%f") + + +class Dispatcher(ABC): + def __init__(self, timeout: float): + super().__init__() + self._timeout: float = timeout + + @abstractmethod + async def dispatch(self, stream_handler: Any) -> None: + pass + + +class ResponseTransformer(ABC, Generic[T]): + def __init__(self, serializer: Serializer): + self._serializer = serializer + + @abstractmethod + def transform(self, response: NamedCacheResponse) -> T: + pass + + @property + def serializer(self) -> Serializer: + return self._serializer + + +class ScalarResultProducer(ABC, Generic[T]): + @abstractmethod + def result(self) -> T: + pass + + +class KeyValueTransformer(ResponseTransformer[MapEntry[K, V]]): + def __init__(self, serializer: Serializer): + super().__init__(serializer) + + def transform(self, response: NamedCacheResponse) -> MapEntry[K, V]: + from coherence import MapEntry + + binary_key_value = BinaryKeyAndValue() + response.message.Unpack(binary_key_value) + return MapEntry( + self.serializer.deserialize(binary_key_value.key), self._serializer.deserialize(binary_key_value.value) + ) + + +class ValueTransformer(ResponseTransformer[V]): + def __init__(self, serializer: Serializer): + super().__init__(serializer) + + def transform(self, response: NamedCacheResponse) -> V: + binary_key_value = BinaryKeyAndValue() + response.message.Unpack(binary_key_value) + return self.serializer.deserialize(binary_key_value.value) + + +class OptionalValueTransformer(ResponseTransformer[Optional[T]]): + def __init__(self, serializer: Serializer): + super().__init__(serializer) + + def transform(self, response: NamedCacheResponse) -> Optional[T]: + optional_value = OptionalValue() + response.message.Unpack(optional_value) + if optional_value.present: + return self.serializer.deserialize(optional_value.value) + else: + return None + + +class IntValueTransformer(ResponseTransformer[int]): + def __init__(self, serializer: Serializer): + super().__init__(serializer) + + def transform(self, response: NamedCacheResponse) -> int: + value = Int32Value() + response.message.Unpack(value) + return value.value + + +class BoolValueTransformer(ResponseTransformer[bool]): + def __init__(self, serializer: Serializer): + super().__init__(serializer) + + def transform(self, response: NamedCacheResponse) -> bool: + bool_value = BoolValue() + response.message.Unpack(bool_value) + return bool_value.value + + +class BytesValueTransformer(ResponseTransformer[Optional[T]]): + def __init__(self, serializer: Serializer): + super().__init__(serializer) + + def transform(self, response: NamedCacheResponse) -> Optional[T]: + bytes_value = BytesValue() + response.message.Unpack(bytes_value) + result: T = self.serializer.deserialize(bytes_value.value) + return result + + +class CookieTransformer(ResponseTransformer[bytes]): + def transform(self, response: NamedCacheResponse) -> bytes: + bytes_value = BytesValue() + response.message.Unpack(bytes_value) + return bytes_value.value + + +class CacheIdTransformer(ResponseTransformer[int]): + + def transform(self, response: NamedCacheResponse) -> int: + return response.cacheId + + +class ResponseObserver(ABC): + def __init__(self, request: ProxyRequest): + if request is None: + raise ValueError("Request cannot be None") + + self._request: ProxyRequest = request + self._waiter: Event = Event() + self._complete: bool = False + self._error: 
Optional[Exception] = None + + @abstractmethod + def _next(self, response: NamedCacheResponse) -> None: + pass + + def _err(self, error: Exception) -> None: + self._error = error + self._done() + + def _done(self) -> None: + self._complete = True + self._waiter.set() + + @property + def id(self) -> int: + return self._request.id + + +class UnaryDispatcher(ResponseObserver, Dispatcher, ScalarResultProducer[T]): + def __init__(self, timeout: float, request: ProxyRequest, transformer: Optional[ResponseTransformer[T]] = None): + ResponseObserver.__init__(self, request) + Dispatcher.__init__(self, timeout) + self._waiter = Event() + self._transformer = transformer + self._result: T + self._complete: bool = False + + def _next(self, response: NamedCacheResponse) -> None: + if self._complete is True: + return + + if self._transformer is not None: + self._result = self._transformer.transform(response) + + async def dispatch(self, stream_handler: Any) -> None: + from . import _TIMEOUT_CONTEXT_VAR + + assert self._complete is False + + stream_handler.register_observer(self) + + async def _dispatch_and_wait() -> None: + await stream_handler.send_proxy_request(self._request) + + await self._waiter.wait() + + if self._error is not None: + raise self._error + + try: + await asyncio.wait_for(_dispatch_and_wait(), _TIMEOUT_CONTEXT_VAR.get(self._timeout)) + except Exception as e: + stream_handler.deregister_observer(self) + raise e + + def result(self) -> T: + return self._result + + +class StreamingDispatcher(ResponseObserver, Dispatcher, AsyncIterator[T]): + def __init__(self, timeout: float, request: ProxyRequest, transformer: ResponseTransformer[T]): + ResponseObserver.__init__(self, request) + Dispatcher.__init__(self, timeout) + self._transformer = transformer + self._stream_handler: Any + self._deadline = Event() + + def _next(self, response: NamedCacheResponse) -> None: + if self._complete is True: + return + + self._result: T = self._transformer.transform(response) + self._waiter.set() + + async def dispatch(self, stream_handler: Any) -> None: + # noinspection PyAttributeOutsideInit + self._stream_handler = stream_handler + stream_handler.register_observer(self) + + # setup deadline handling for this call + async def deadline() -> None: + from . 
import _TIMEOUT_CONTEXT_VAR + + try: + await stream_handler.send_proxy_request(self._request) + + if self._error is not None: + raise self._error + + await asyncio.wait_for(self._deadline.wait(), _TIMEOUT_CONTEXT_VAR.get(self._timeout)) + except Exception as e: + stream_handler.deregister_observer(self) + self._error = e + self._waiter.set() # raise error to the caller + + asyncio.get_running_loop().create_task(deadline()) + + def __aiter__(self) -> AsyncIterator[T]: + return self + + async def __anext__(self) -> T: + await self._waiter.wait() + if self._error is not None: + raise self._error + elif self._complete is True: + self._deadline.set() + raise StopAsyncIteration + else: + try: + return self._result + finally: + self._waiter.clear() + + +class PagingDispatcher(ResponseObserver, Dispatcher, AsyncIterator[T]): + def __init__( + self, + timeout: float, + request: ProxyRequest, + request_creator: Callable[[bytes], ProxyRequest], + transformer: ResponseTransformer[T], + ): + ResponseObserver.__init__(self, request) + Dispatcher.__init__(self, timeout) + self._cookie_transformer: ResponseTransformer[Any] = CookieTransformer(transformer.serializer) + self._transformer: ResponseTransformer[T] = transformer + self._request_creator: Callable[[bytes], ProxyRequest] = request_creator + self._first: bool = True + self._cookie: bytes = bytes() + self._exhausted: bool = False + self._stream_handler: Any + self._timeout: float = timeout + self._in_progress: bool = False + self._deadline = Event() + + def _next(self, response: NamedCacheResponse) -> None: + if self._complete is True: + return + + if self._first: + # first response will have the cookie + self._first = False + self._cookie = self._cookie_transformer.transform(response) + self._exhausted = self._cookie == b"" + else: + self._result: T = self._transformer.transform(response) + self._waiter.set() + + def _done(self) -> None: + if self._exhausted: + self._complete = True + self._waiter.set() + else: + self._first = True + self._request = self._request_creator(self._cookie) + asyncio.create_task(self.dispatch(self._stream_handler)) + + async def dispatch(self, stream_handler: Any) -> None: + # noinspection PyAttributeOutsideInit + self._stream_handler = stream_handler + stream_handler.register_observer(self) + + if self._in_progress is False: + self._in_progress = True + + # setup deadline handling for this call + async def deadline() -> None: + from . 
import _TIMEOUT_CONTEXT_VAR + + try: + await stream_handler.send_proxy_request(self._request) + + if self._error is not None: + raise self._error + + await asyncio.wait_for(self._deadline.wait(), _TIMEOUT_CONTEXT_VAR.get(self._timeout)) + except Exception as e: + stream_handler.deregister_observer(self) + self._error = e + self._waiter.set() # raise error to the caller + + asyncio.get_running_loop().create_task(deadline()) + else: + await stream_handler.send_proxy_request(self._request) + + if self._error is not None: + stream_handler.deregister_observer(self) + raise self._error + + def __aiter__(self) -> AsyncIterator[T]: + return self + + async def __anext__(self) -> T: + await self._waiter.wait() + if self._error is not None: + raise self._error + elif self._complete is True: + raise StopAsyncIteration + else: + try: + return self._result + finally: + self._waiter.clear() + + class RequestFactory: def __init__(self, cache_name: str, scope: str, serializer: Serializer) -> None: self._cache_name: str = cache_name @@ -59,7 +423,8 @@ def __init__(self, cache_name: str, scope: str, serializer: Serializer) -> None: self.__next_request_id: int = 0 self.__next_filter_id: int = 0 - def get_serializer(self) -> Serializer: + @property + def serializer(self) -> Serializer: return self._serializer def put_request(self, key: K, value: V, ttl: int = -1) -> PutRequest: @@ -82,16 +447,18 @@ def get_request(self, key: K) -> GetRequest: ) return g - def get_all_request(self, keys: set[K]) -> GetRequest: + def get_all_request(self, keys: set[K]) -> GetAllRequest: if keys is None: raise ValueError("Must specify a set of keys") - g: GetAllRequest = GetAllRequest(scope=self._scope, cache=self._cache_name, format=self._serializer.format) + get_all: GetAllRequest = GetAllRequest( + scope=self._scope, cache=self._cache_name, format=self._serializer.format + ) for key in keys: - g.key.append(self._serializer.serialize(key)) + get_all.key.append(self._serializer.serialize(key)) - return g + return get_all def put_if_absent_request(self, key: K, value: V, ttl: int = -1) -> PutIfAbsentRequest: p = PutIfAbsentRequest( @@ -206,7 +573,7 @@ def invoke_all_request( self, processor: EntryProcessor[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None ) -> InvokeAllRequest: if keys is not None and filter is not None: - raise ValueError("keys and filter are mutually exclusive") + raise ValueError(_KEYS_FILTERS_EXCLUSIVE) r = InvokeAllRequest( scope=self._scope, @@ -227,7 +594,7 @@ def aggregate_request( self, aggregator: EntryAggregator[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None ) -> AggregateRequest: if keys is not None and filter is not None: - raise ValueError("keys and filter are mutually exclusive") + raise ValueError(_KEYS_FILTERS_EXCLUSIVE) r: AggregateRequest = AggregateRequest( scope=self._scope, @@ -246,7 +613,7 @@ def aggregate_request( def values_request(self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None) -> ValuesRequest: if filter is None and comparator is not None: - raise ValueError("Filter cannot be None") + raise ValueError(_FILTER_REQUIRED) r: ValuesRequest = ValuesRequest( scope=self._scope, @@ -278,7 +645,7 @@ def entries_request( self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None ) -> EntrySetRequest: if filter is None and comparator is not None: - raise ValueError("Filter cannot be None") + raise ValueError(_FILTER_REQUIRED) r: EntrySetRequest = EntrySetRequest( scope=self._scope, @@ -327,14 +694,17 @@ def 
map_listener_request( request.priming = False if key is not None: # registering a key listener + # noinspection PyUnresolvedReferences request.type = MapListenerRequest.RequestType.KEY request.key = self._serializer.serialize(key) else: # registering a Filter listener + # noinspection PyUnresolvedReferences request.type = MapListenerRequest.RequestType.FILTER self.__next_filter_id += 1 request.filterId = self.__next_filter_id filter_local: Filter = filter if filter is not None else Filters.always() if not isinstance(filter_local, MapEventFilter): + # noinspection PyUnresolvedReferences filter_local = MapEventFilter.from_filter(filter_local) request.filter = self._serializer.serialize(filter_local) @@ -347,6 +717,7 @@ def map_event_subscribe(self) -> MapListenerRequest: ) request.uid = self.__generate_next_request_id("init") request.subscribe = True + # noinspection PyUnresolvedReferences request.type = MapListenerRequest.RequestType.INIT return request @@ -381,3 +752,559 @@ def remove_index_request(self, extractor: ValueExtractor[T, E]) -> RemoveIndexRe ) return r + + +class RequestIdGenerator: + _generator = None + + def __init__(self) -> None: + self._lock = threading.Lock() + self._counter = 0 + + @classmethod + def generator(cls) -> RequestIdGenerator: + if RequestIdGenerator._generator is None: + RequestIdGenerator._generator = RequestIdGenerator() + return RequestIdGenerator._generator + + @classmethod + def next(cls) -> int: + generator = cls.generator() + with generator._lock: + if generator._counter == sys.maxsize: + generator._counter = 0 + else: + generator._counter += 1 + return generator._counter + + +class RequestFactoryV1: + + def __init__( + self, cache_name: str, cache_id: int, scope: str, serializer: Serializer, timeout: Callable[[], float] + ) -> None: + self._cache_name: str = cache_name + self._cache_id: int = cache_id + self._scope: str = scope + self._timeout: Callable[[], float] = timeout + self._serializer: Serializer = serializer + + @property + def cache_id(self) -> int: + return self._cache_id + + @cache_id.setter + def cache_id(self, value: int) -> None: + self._cache_id = value + + @property + def request_timeout(self) -> float: + return self._timeout() + + def _create_named_cache_request(self, request: Any, request_type: NamedCacheRequestType) -> NamedCacheRequest: + any_cache_request = GrpcAny() + any_cache_request.Pack(request) + + return NamedCacheRequest( + type=request_type, + cacheId=self.cache_id, + message=any_cache_request, + ) + + def create_proxy_request(self, named_cache_request: NamedCacheRequest) -> ProxyRequest: + any_named_cache_request = GrpcAny() + any_named_cache_request.Pack(named_cache_request) + req_id = RequestIdGenerator.next() + proxy_request = ProxyRequest( + id=req_id, + message=any_named_cache_request, + ) + return proxy_request + + @staticmethod + def init_sub_channel( + scope: str = "", + serialization_format: str = "json", + protocol: str = "CacheService", + protocol_version: int = 1, + supported_protocol_version: int = 1, + heartbeat: int = 0, + ) -> ProxyRequest: + init_request = InitRequest( + scope=scope, + format=serialization_format, + protocol=protocol, + protocolVersion=protocol_version, + supportedProtocolVersion=supported_protocol_version, + heartbeat=heartbeat, + ) + + return ProxyRequest(id=2, init=init_request) + + def ensure_request(self, cache_name: str) -> UnaryDispatcher[int]: + cache_request = EnsureCacheRequest(cache=cache_name) + + any_cache_request = GrpcAny() + any_cache_request.Pack(cache_request) + + 
named_cache_request = NamedCacheRequest( + type=NamedCacheRequestType.EnsureCache, + message=any_cache_request, + ) + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), CacheIdTransformer(self._serializer) + ) + + def put_request(self, key: K, value: V, ttl: int = 0) -> UnaryDispatcher[Optional[V]]: + request: NamedCacheRequest = self._create_named_cache_request( + V1PutRequest( + key=self._serializer.serialize(key), # Serialized key + value=self._serializer.serialize(value), # Serialized value + ttl=ttl, + ), + NamedCacheRequestType.Put, + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(request), OptionalValueTransformer(self._serializer) + ) + + def get_request(self, key: K) -> UnaryDispatcher[Optional[V]]: + request: NamedCacheRequest = self._create_named_cache_request( + BytesValue(value=self._serializer.serialize(key)), NamedCacheRequestType.Get + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(request), OptionalValueTransformer(self._serializer) + ) + + def get_all_request(self, keys: set[K]) -> StreamingDispatcher[MapEntry[K, V]]: + if keys is None: + raise ValueError("Must specify a set of keys") + + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + CollectionOfBytesValues( + values=list(self._serializer.serialize(k) for k in keys), + ), + NamedCacheRequestType.GetAll, + ) + + return StreamingDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), KeyValueTransformer(self._serializer) + ) + + def put_if_absent_request(self, key: K, value: V, ttl: int = 0) -> UnaryDispatcher[Optional[V]]: + request: NamedCacheRequest = self._create_named_cache_request( + V1PutRequest( + key=self._serializer.serialize(key), # Serialized key + value=self._serializer.serialize(value), # Serialized value + ttl=ttl, + ), + NamedCacheRequestType.PutIfAbsent, + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(request), BytesValueTransformer(self._serializer) + ) + + def put_all_request(self, kv_map: dict[K, V], ttl: Optional[int] = 0) -> Dispatcher: + request: NamedCacheRequest = self._create_named_cache_request( + V1PutAllRequest( + entries=list( + BinaryKeyAndValue(key=self._serializer.serialize(k), value=self._serializer.serialize(v)) + for k, v in kv_map.items() + ), + ttl=ttl, + ), + NamedCacheRequestType.PutAll, + ) + + return UnaryDispatcher(self.request_timeout, self.create_proxy_request(request)) + + def clear_request(self) -> Dispatcher: + named_cache_request = NamedCacheRequest( + type=NamedCacheRequestType.Clear, + cacheId=self.cache_id, + ) + return UnaryDispatcher(self.request_timeout, self.create_proxy_request(named_cache_request)) + + def destroy_request(self) -> Dispatcher: + named_cache_request: NamedCacheRequest = NamedCacheRequest( + type=NamedCacheRequestType.Destroy, + cacheId=self.cache_id, + ) + + return UnaryDispatcher(self.request_timeout, self.create_proxy_request(named_cache_request)) + + def truncate_request(self) -> Dispatcher: + named_cache_request: NamedCacheRequest = NamedCacheRequest( + type=NamedCacheRequestType.Truncate, + cacheId=self.cache_id, + ) + return UnaryDispatcher(self.request_timeout, self.create_proxy_request(named_cache_request)) + + def remove_request(self, key: K) -> UnaryDispatcher[Optional[V]]: + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + BytesValue(value=self._serializer.serialize(key)), NamedCacheRequestType.Remove + ) + + 
return UnaryDispatcher( + self.request_timeout, + self.create_proxy_request(named_cache_request), + BytesValueTransformer(self._serializer), + ) + + def remove_mapping_request(self, key: K, value: V) -> UnaryDispatcher[bool]: + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + BinaryKeyAndValue(key=self._serializer.serialize(key), value=self._serializer.serialize(value)), + NamedCacheRequestType.RemoveMapping, + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), BoolValueTransformer(self._serializer) + ) + + def replace_request(self, key: K, value: V) -> UnaryDispatcher[Optional[V]]: + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + BinaryKeyAndValue(key=self._serializer.serialize(key), value=self._serializer.serialize(value)), + NamedCacheRequestType.Replace, + ) + + return UnaryDispatcher( + self.request_timeout, + self.create_proxy_request(named_cache_request), + BytesValueTransformer(self._serializer), + ) + + def replace_mapping_request(self, key: K, old_value: V, new_value: V) -> UnaryDispatcher[bool]: + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + V1ReplaceMappingRequest( + key=self._serializer.serialize(key), + previousValue=self._serializer.serialize(old_value), + newValue=self._serializer.serialize(new_value), + ), + NamedCacheRequestType.ReplaceMapping, + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), BoolValueTransformer(self._serializer) + ) + + def contains_key_request(self, key: K) -> UnaryDispatcher[bool]: + named_cache_request = self._create_named_cache_request( + BytesValue(value=self._serializer.serialize(key)), NamedCacheRequestType.ContainsKey + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), BoolValueTransformer(self._serializer) + ) + + def contains_value_request(self, value: V) -> UnaryDispatcher[bool]: + named_cache_request = self._create_named_cache_request( + BytesValue(value=self._serializer.serialize(value)), NamedCacheRequestType.ContainsValue + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), BoolValueTransformer(self._serializer) + ) + + def is_empty_request(self) -> UnaryDispatcher[bool]: + named_cache_request = NamedCacheRequest( + type=NamedCacheRequestType.IsEmpty, + cacheId=self.cache_id, + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), BoolValueTransformer(self._serializer) + ) + + def size_request(self) -> UnaryDispatcher[int]: + named_cache_request = NamedCacheRequest( + type=NamedCacheRequestType.Size, + cacheId=self.cache_id, + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), IntValueTransformer(self._serializer) + ) + + def invoke_request(self, key: K, processor: EntryProcessor[R]) -> UnaryDispatcher[Optional[R]]: + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + ExecuteRequest( + agent=self._serializer.serialize(processor), + keys=KeysOrFilter( + key=self._serializer.serialize(key), + ), + ), + NamedCacheRequestType.Invoke, + ) + + return UnaryDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), ValueTransformer(self._serializer) + ) + + def invoke_all_request( + self, processor: EntryProcessor[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None + ) -> 
StreamingDispatcher[MapEntry[K, R]]: + if keys is not None and filter is not None: + raise ValueError(_KEYS_FILTERS_EXCLUSIVE) + + if keys is not None: + cache_request = ExecuteRequest( + agent=self._serializer.serialize(processor), + keys=KeysOrFilter( + keys=CollectionOfBytesValues( + values=list(self._serializer.serialize(key) for key in keys), + ), + ), + ) + elif filter is not None: + cache_request = ExecuteRequest( + agent=self._serializer.serialize(processor), + keys=KeysOrFilter( + filter=self._serializer.serialize(filter), + ), + ) + else: + cache_request = ExecuteRequest( + agent=self._serializer.serialize(processor), + ) + + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + cache_request, NamedCacheRequestType.Invoke + ) + + return StreamingDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), KeyValueTransformer(self._serializer) + ) + + def aggregate_request( + self, aggregator: EntryAggregator[R], keys: Optional[set[K]] = None, filter: Optional[Filter] = None + ) -> UnaryDispatcher[Optional[R]]: + if keys is not None and filter is not None: + raise ValueError(_KEYS_FILTERS_EXCLUSIVE) + + if keys is not None: + cache_request = ExecuteRequest( + agent=self._serializer.serialize(aggregator), + keys=KeysOrFilter( + keys=CollectionOfBytesValues( + values=list(self._serializer.serialize(key) for key in keys), + ), + ), + ) + elif filter is not None: + cache_request = ExecuteRequest( + agent=self._serializer.serialize(aggregator), + keys=KeysOrFilter( + filter=self._serializer.serialize(filter), + ), + ) + else: + cache_request = ExecuteRequest( + agent=self._serializer.serialize(aggregator), + ) + + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + cache_request, NamedCacheRequestType.Aggregate + ) + return UnaryDispatcher( + self.request_timeout, + self.create_proxy_request(named_cache_request), + BytesValueTransformer(self._serializer), + ) + + def values_request( + self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None + ) -> StreamingDispatcher[V]: + if filter is None and comparator is not None: + raise ValueError(_FILTER_REQUIRED) + + if filter is not None: + query_request = QueryRequest(filter=self._serializer.serialize(filter)) + elif comparator is not None: + query_request = QueryRequest(comparator=self._serializer.serialize(comparator)) + else: + query_request = QueryRequest() + + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + query_request, NamedCacheRequestType.QueryValues + ) + + return StreamingDispatcher( + self.request_timeout, + self.create_proxy_request(named_cache_request), + BytesValueTransformer(self._serializer), # type: ignore + ) + + def keys_request(self, filter: Optional[Filter] = None) -> StreamingDispatcher[K]: + + if filter is not None: + query_request = QueryRequest(filter=self._serializer.serialize(filter)) + else: + query_request = QueryRequest() + + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + query_request, NamedCacheRequestType.QueryKeys + ) + + return StreamingDispatcher( + self.request_timeout, + self.create_proxy_request(named_cache_request), + BytesValueTransformer(self._serializer), # type: ignore + ) + + def entries_request( + self, filter: Optional[Filter] = None, comparator: Optional[Comparator] = None + ) -> StreamingDispatcher[MapEntry[K, V]]: + if filter is None and comparator is not None: + raise ValueError(_FILTER_REQUIRED) + + if filter is not None: + query_request = 
QueryRequest(filter=self._serializer.serialize(filter)) + elif comparator is not None: + query_request = QueryRequest(comparator=self._serializer.serialize(comparator)) + else: + query_request = QueryRequest() + + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + query_request, NamedCacheRequestType.QueryEntries + ) + + return StreamingDispatcher( + self.request_timeout, self.create_proxy_request(named_cache_request), KeyValueTransformer(self._serializer) + ) + + def page_request(self, keys_only: bool = False, values_only: bool = False) -> PagingDispatcher[T]: + """ + Creates a gRPC PageRequest. + + :param keys_only: flag indicating interest in only keys + :param values_only: flag indicating interest in only values + :return: a new PageRequest + """ + if keys_only and values_only: + raise ValueError("keys_only and values_only cannot be True at the same time") + + if keys_only: + return PagingDispatcher( + self.request_timeout, + self._page_of_keys_creator(None), + self._page_of_keys_creator, + BytesValueTransformer(self._serializer), # type: ignore + ) + elif values_only: + return PagingDispatcher( + self.request_timeout, + self._page_of_entries_creator(None), + self._page_of_entries_creator, + ValueTransformer(self._serializer), + ) + else: + return PagingDispatcher( + self.request_timeout, + self._page_of_entries_creator(None), + self._page_of_entries_creator, + KeyValueTransformer(self._serializer), # type: ignore + ) + + def _page_of_keys_creator(self, cookie: Optional[bytes]) -> ProxyRequest: + if cookie is None: + cookie_bytes = BytesValue() + else: + cookie_bytes = BytesValue(value=cookie) + + return self.create_proxy_request( + self._create_named_cache_request(cookie_bytes, NamedCacheRequestType.PageOfKeys) + ) + + def _page_of_entries_creator(self, cookie: Optional[bytes]) -> ProxyRequest: + if cookie is None: + cookie_bytes = BytesValue() + else: + cookie_bytes = BytesValue(value=cookie) + + return self.create_proxy_request( + self._create_named_cache_request(cookie_bytes, NamedCacheRequestType.PageOfEntries) + ) + + def map_listener_request( + self, + subscribe: bool, + lite: bool = False, + sync: bool = False, + priming: bool = False, + *, + key: Optional[K] = None, + filter: Optional[Filter] = None, + filter_id: int = -1, + ) -> Tuple[UnaryDispatcher[Any], ProxyRequest, int]: + """Creates a gRPC generated MapListenerRequest""" + + if key is None and filter is None: + raise AssertionError("Must specify a key or a filter") + + if key is None: # registering a Filter listener + filter_local: Filter = filter # type: ignore + if not isinstance(filter_local, MapEventFilter): + # noinspection PyUnresolvedReferences + filter_local = MapEventFilter.from_filter(filter_local) + listener_request: V1MapListenerRequest = V1MapListenerRequest( + subscribe=subscribe, + lite=lite, + synchronous=sync, + priming=priming, + filterId=RequestIdGenerator.next() if filter_id == -1 else filter_id, + keyOrFilter=KeyOrFilter(filter=self._serializer.serialize(filter_local)), + ) + else: # registering a key listener + listener_request = V1MapListenerRequest( + subscribe=subscribe, + lite=lite, + synchronous=sync, + priming=priming, + keyOrFilter=KeyOrFilter(key=self._serializer.serialize(key)), + ) + + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + listener_request, NamedCacheRequestType.MapListener + ) + proxy_request: ProxyRequest = self.create_proxy_request(named_cache_request) + + return ( + UnaryDispatcher(self.request_timeout, proxy_request), + 
proxy_request, + listener_request.filterId, + ) + + def add_index_request( + self, extractor: ValueExtractor[T, E], ordered: bool = False, comparator: Optional[Comparator] = None + ) -> Dispatcher: + if comparator is None: + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + IndexRequest(add=True, extractor=self._serializer.serialize(extractor), sorted=ordered), + NamedCacheRequestType.Index, + ) + else: + named_cache_request = self._create_named_cache_request( + IndexRequest( + add=True, + extractor=self._serializer.serialize(extractor), + sorted=ordered, + comparator=self._serializer.serialize(comparator), + ), + NamedCacheRequestType.Index, + ) + + return UnaryDispatcher(self.request_timeout, self.create_proxy_request(named_cache_request)) + + def remove_index_request(self, extractor: ValueExtractor[T, E]) -> Dispatcher: + named_cache_request: NamedCacheRequest = self._create_named_cache_request( + IndexRequest( + add=False, + extractor=self._serializer.serialize(extractor), + ), + NamedCacheRequestType.Index, + ) + + return UnaryDispatcher(self.request_timeout, self.create_proxy_request(named_cache_request)) diff --git a/tests/Task.py b/tests/Task.py deleted file mode 100644 index 20a6b73..0000000 --- a/tests/Task.py +++ /dev/null @@ -1,35 +0,0 @@ -from time import time -from uuid import uuid4 - -from coherence import serialization - - -@serialization.proxy("Task") -@serialization.mappings({"created_at": "createdAt"}) -class Task: - def __init__(self, description: str) -> None: - super().__init__() - self.id: str = str(uuid4())[0:6] - self.description: str = description - self.completed: bool = False - self.created_at: int = int(time() * 1000) - - def __hash__(self) -> int: - return hash((self.id, self.description, self.completed, self.created_at)) - - def __eq__(self, o: object) -> bool: - if isinstance(o, Task): - # noinspection PyTypeChecker - t: Task = o - return ( - self.id == t.id - and self.description == t.description - and self.completed == t.completed - and self.created_at == t.created_at - ) - return False - - def __str__(self) -> str: - return 'Task(id="{}", description="{}", completed={}, created_at={})'.format( - self.id, self.description, self.completed, self.created_at - ) diff --git a/tests/__init__.py b/tests/__init__.py index 7f50b60..912f3de 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,16 +1,18 @@ -# Copyright (c) 2022 Oracle and/or its affiliates. +# Copyright (c) 2022, 2024, Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at # https://oss.oracle.com/licenses/upl.
import asyncio import logging.config import os from asyncio import Event -from typing import Final, List, TypeVar +from typing import Any, Final, List, TypeVar import pytest from coherence import Options, Session, TlsOptions from coherence.event import MapEvent, MapListener +from coherence.processor import EntryProcessor +from coherence.serialization import proxy K = TypeVar("K") """Generic type for cache keys""" @@ -19,7 +21,7 @@ """Generic type for cache values""" # logging configuration for tests -logging_config: str = "tests/logging.conf" # executing from project root +logging_config: str = os.path.dirname(__file__) + "/logging.conf" # executing from project root if not os.path.exists(logging_config): logging_config = "logging.conf" # executing from tests directory (most likely IntelliJ) @@ -184,6 +186,9 @@ async def get_session(wait_for_ready: float = 0) -> Session: run_secure: Final[str] = "RUN_SECURE" session: Session + options: Options = Options( + default_address, default_scope, default_request_timeout, wait_for_ready, ser_format=default_format + ) if run_secure in os.environ: # Default TlsOptions constructor will pick up the SSL Certs and @@ -195,17 +200,19 @@ async def get_session(wait_for_ready: float = 0) -> Session: tls_options.enabled = True tls_options.locked() - options: Options = Options( - default_address, default_scope, default_request_timeout, wait_for_ready, ser_format=default_format - ) options.tls_options = tls_options options.channel_options = (("grpc.ssl_target_name_override", "Star-Lord"),) session = await Session.create(options) else: - session = await Session.create(Options(ready_timeout_seconds=wait_for_ready)) + session = await Session.create(options) return session async def wait_for(event: Event, timeout: float) -> None: await asyncio.wait_for(event.wait(), timeout) + + +@proxy("test.longrunning") +class LongRunningProcessor(EntryProcessor[Any]): + pass diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5289909 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. +import asyncio +import time +from typing import Any, AsyncGenerator + +import pytest_asyncio + +import tests +from coherence import NamedCache, Session +from tests.person import Person + + +@pytest_asyncio.fixture +async def test_session() -> AsyncGenerator[Session, None]: + session: Session = await tests.get_session() + yield session + await session.close() + await asyncio.sleep(0) # helps avoid loop already closed errors + + +@pytest_asyncio.fixture +async def cache(test_session: Session) -> AsyncGenerator[NamedCache[Any, Any], None]: + cache: NamedCache[Any, Any] = await test_session.get_cache("test-" + str(time.time_ns())) + yield cache + await cache.truncate() + + +@pytest_asyncio.fixture +async def person_cache(cache: NamedCache[Any, Any]) -> AsyncGenerator[NamedCache[str, Person], None]: + await Person.populate_named_map(cache) + yield cache diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 0000000..0945b75 --- /dev/null +++ b/tests/e2e/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. 
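The new tests/conftest.py above replaces the per-module setup_and_teardown fixtures used previously: pytest-asyncio now injects a shared Session and a per-test NamedCache that is truncated after each test, and the relocated e2e tests below simply declare a `cache` (or `person_cache`/`test_session`) parameter. A minimal sketch of how a test picks up these fixtures follows; the test name and assertions are illustrative only and are not part of the patch:

import pytest

from coherence import NamedCache


@pytest.mark.asyncio
async def test_fixture_usage_sketch(cache: NamedCache[str, str]) -> None:
    # "cache" is supplied by the conftest.py fixture above; no manual
    # Session.create()/close() is needed in the test body.
    await cache.put("k", "v")
    assert await cache.get("k") == "v"
    # the fixture truncates the cache after the test completes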
diff --git a/tests/test_aggregators.py b/tests/e2e/test_aggregators.py similarity index 100% rename from tests/test_aggregators.py rename to tests/e2e/test_aggregators.py diff --git a/tests/e2e/test_ai.py b/tests/e2e/test_ai.py new file mode 100644 index 0000000..59c64ea --- /dev/null +++ b/tests/e2e/test_ai.py @@ -0,0 +1,155 @@ +# Copyright (c) 2022, 2024, Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. +import random +import time +from typing import List, Optional, cast + +import pytest + +from coherence import COH_LOG, Extractors, NamedCache, Session +from coherence.ai import BinaryQuantIndex, DocumentChunk, FloatVector, SimilaritySearch, Vectors + + +class ValueWithVector: + def __init__(self, vector: FloatVector, text: str, number: int) -> None: + self.vector = vector + self.text = text + self.number = number + + def get_vector(self) -> FloatVector: + return self.vector + + def get_text(self) -> str: + return self.text + + def get_number(self) -> int: + return self.number + + def __repr__(self) -> str: + return f"ValueWithVector(vector={self.vector}, text='{self.text}', number={self.number})" + + +def random_floats(n: int) -> List[float]: + floats: List[float] = [0.0] * n + for i in range(n): + floats[i] = random.uniform(-50.0, 50.0) + return floats + + +DIMENSIONS: int = 384 + + +async def populate_vectors(vectors: NamedCache[int, ValueWithVector]) -> ValueWithVector: + matches: List[List[float]] = [[]] * 5 + matches[0] = random_floats(DIMENSIONS) + + # Creating copies of matches[0] for matches[1] to matches[4] + for i in range(1, 5): + matches[i] = matches[0].copy() + matches[i][0] += 1.0 # Modify the first element + + count = 10000 + values: List[Optional[ValueWithVector]] = [None] * count + + # Assign normalized vectors to the first 5 entries + for i in range(5): + values[i] = ValueWithVector(FloatVector(Vectors.normalize(matches[i])), str(i), i) + await vectors.put(i, values[i]) + + # Fill the remaining values with random vectors + for i in range(5, count): + values[i] = ValueWithVector(FloatVector(Vectors.normalize(random_floats(DIMENSIONS))), str(i), i) + await vectors.put(i, values[i]) + + return cast(ValueWithVector, values[0]) + + +async def populate_document_chunk_vectors(vectors: NamedCache[int, DocumentChunk]) -> DocumentChunk: + matches: List[List[float]] = [[]] * 5 + matches[0] = random_floats(DIMENSIONS) + + # Creating copies of matches[0] for matches[1] to matches[4] + for i in range(1, 5): + matches[i] = matches[0].copy() + matches[i][0] += 1.0 # Modify the first element + + count = 10000 + values: List[Optional[DocumentChunk]] = [None] * count + + # Assign normalized vectors to the first 5 entries + for i in range(5): + values[i] = DocumentChunk(str(i), metadata=None, vector=FloatVector(Vectors.normalize(matches[i]))) + await vectors.put(i, values[i]) + + # Fill the remaining values with random vectors + for i in range(5, count): + values[i] = DocumentChunk( + str(i), metadata=None, vector=FloatVector(Vectors.normalize(random_floats(DIMENSIONS))) + ) + await vectors.put(i, values[i]) + + return cast(DocumentChunk, values[0]) + + +@pytest.mark.asyncio +async def test_similarity_search_with_index(test_session: Session) -> None: + cache: NamedCache[int, ValueWithVector] = await test_session.get_cache("vector_cache") + await cache.add_index(BinaryQuantIndex(Extractors.extract("vector"))) + value_with_vector = await populate_vectors(cache) + + # Create a SimilaritySearch aggregator +
value_extractor = Extractors.extract("vector") + k = 10 + ss = SimilaritySearch(value_extractor, value_with_vector.vector, k) + + ss.bruteForce = True # Set bruteForce to True + start_time_bf = time.perf_counter() + hnsw_result = await cache.aggregate(ss) + end_time_bf = time.perf_counter() + elapsed_time = end_time_bf - start_time_bf + COH_LOG.info("Results below for test_SimilaritySearch with BruteForce true:") + for e in hnsw_result: + COH_LOG.info(e) + COH_LOG.info(f"Elapsed time for brute force: {elapsed_time} seconds") + + assert hnsw_result is not None + assert len(hnsw_result) == k + + ss.bruteForce = False + start_time = time.perf_counter() + hnsw_result = await cache.aggregate(ss) + end_time = time.perf_counter() + elapsed_time = end_time - start_time + COH_LOG.info("Results below for test_SimilaritySearch with Index:") + for e in hnsw_result: + COH_LOG.info(e) + COH_LOG.info(f"Elapsed time: {elapsed_time} seconds") + + assert hnsw_result is not None + assert len(hnsw_result) == k + + await cache.truncate() + await cache.destroy() + + +@pytest.mark.asyncio +async def test_similarity_search_with_document_chunk(test_session: Session) -> None: + cache: NamedCache[int, DocumentChunk] = await test_session.get_cache("vector_cache") + dc: DocumentChunk = await populate_document_chunk_vectors(cache) + + # Create a SimilaritySearch aggregator + value_extractor = Extractors.extract("vector") + k = 10 + ss = SimilaritySearch(value_extractor, dc.vector, k) + + hnsw_result = await cache.aggregate(ss) + + assert hnsw_result is not None + assert len(hnsw_result) == k + COH_LOG.info("Results below for test_SimilaritySearch_with_DocumentChunk:") + for e in hnsw_result: + COH_LOG.info(e) + + await cache.truncate() + await cache.destroy() diff --git a/tests/test_client.py b/tests/e2e/test_client.py similarity index 61% rename from tests/test_client.py rename to tests/e2e/test_client.py index 09de73b..0495b9e 100644 --- a/tests/test_client.py +++ b/tests/e2e/test_client.py @@ -2,15 +2,18 @@ # Licensed under the Universal Permissive License v 1.0 as shown at # https://oss.oracle.com/licenses/upl. 
+import asyncio from asyncio import Event from time import sleep, time -from typing import Any, AsyncGenerator, Dict, Final, List, Optional, Set, TypeVar, Union +from typing import Dict, Final, List, Optional, Set, TypeVar, Union import pytest -import pytest_asyncio +from grpc import StatusCode +from grpc.aio import AioRpcError import tests -from coherence import Aggregators, Filters, MapEntry, NamedCache, Session +from coherence import Aggregators, Filters, MapEntry, NamedCache, Session, request_timeout +from coherence.client import CacheOptions from coherence.event import MapLifecycleEvent from coherence.extractor import ChainedExtractor, Extractors, UniversalExtractor from coherence.processor import ExtractorProcessor @@ -40,39 +43,12 @@ async def _insert_large_number_of_entries(cache: NamedCache[str, str]) -> int: return num_entries -@pytest_asyncio.fixture -async def setup_and_teardown() -> AsyncGenerator[NamedCache[Any, Any], None]: - session: Session = await tests.get_session() - - cache: NamedCache[Any, Any] = await session.get_cache("test") - - yield cache # this is what is returned to the test functions - - await cache.truncate() - await session.close() - - -@pytest_asyncio.fixture -async def setup_and_teardown_person_cache() -> AsyncGenerator[NamedCache[str, Person], None]: - session: Session = await tests.get_session() - cache: NamedCache[str, Person] = await session.get_cache("test") - - await Person.populate_named_map(cache) - - yield cache - - await cache.truncate() - await session.close() - - # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_get_and_put(setup_and_teardown: NamedCache[str, Union[str, int, Person]]) -> None: - cache: NamedCache[str, Union[str, int, Person]] = setup_and_teardown - +async def test_get_and_put(cache: NamedCache[str, Union[str, int, Person]]) -> None: k: str = "one" v: str = "only-one" - # c.put(k, v, 60000) + await cache.put(k, v) r = await cache.get(k) assert r == v @@ -94,9 +70,7 @@ async def test_get_and_put(setup_and_teardown: NamedCache[str, Union[str, int, P # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_put_with_ttl(setup_and_teardown: NamedCache[str, Union[str, int]]) -> None: - cache: NamedCache[str, Union[str, int, Person]] = setup_and_teardown - +async def test_put_with_ttl(cache: NamedCache[str, Union[str, int]]) -> None: k: str = "one" v: str = "only-one" await cache.put(k, v, 5000) # TTL of 5 seconds @@ -110,9 +84,7 @@ async def test_put_with_ttl(setup_and_teardown: NamedCache[str, Union[str, int]] # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_put_if_absent(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_put_if_absent(cache: NamedCache[str, str]) -> None: k: str = "one" v: str = "only-one" await cache.put(k, v) @@ -127,9 +99,7 @@ async def test_put_if_absent(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_keys_filtered(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_keys_filtered(cache: NamedCache[str, str]) -> None: k: str = "one" v: str = "only-one" await cache.put(k, v) @@ -141,7 +111,7 @@ async def test_keys_filtered(setup_and_teardown: NamedCache[str, str]) -> None: await cache.put(k2, v2) local_set: Set[str] = set() - async for e in cache.keys(Filters.equals("length()", 8)): + async for e in await cache.keys(Filters.equals("length()", 8)): 
local_set.add(e) assert len(local_set) == 2 @@ -151,16 +121,14 @@ async def test_keys_filtered(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_keys_paged(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_keys_paged(cache: NamedCache[str, str]) -> None: # insert enough data into the cache to ensure results will be paged # by the proxy. num_entries: int = await _insert_large_number_of_entries(cache) # Stream the keys and locally cache the results local_set: Set[str] = set() - async for e in cache.keys(by_page=True): + async for e in await cache.keys(by_page=True): local_set.add(e) assert len(local_set) == num_entries @@ -168,9 +136,7 @@ async def test_keys_paged(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_entries_filtered(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_entries_filtered(cache: NamedCache[str, str]) -> None: k: str = "one" v: str = "only-one" await cache.put(k, v) @@ -182,7 +148,7 @@ async def test_entries_filtered(setup_and_teardown: NamedCache[str, str]) -> Non await cache.put(k2, v2) local_dict: Dict[str, str] = {} - async for e in cache.entries(Filters.equals("length()", 8)): + async for e in await cache.entries(Filters.equals("length()", 8)): local_dict[e.key] = e.value assert len(local_dict) == 2 @@ -192,9 +158,7 @@ async def test_entries_filtered(setup_and_teardown: NamedCache[str, str]) -> Non # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_entries_paged(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_entries_paged(cache: NamedCache[str, str]) -> None: # insert enough data into the cache to ensure results will be paged # by the proxy. num_entries = await _insert_large_number_of_entries(cache) @@ -203,16 +167,14 @@ async def test_entries_paged(setup_and_teardown: NamedCache[str, str]) -> None: # Stream the keys and locally cache the results local_dict: Dict[str, str] = {} - async for e in cache.entries(by_page=True): + async for e in await cache.entries(by_page=True): local_dict[e.key] = e.value assert len(local_dict) == num_entries @pytest.mark.asyncio -async def test_values_filtered(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_values_filtered(cache: NamedCache[str, str]) -> None: k: str = "one" v: str = "only-one" await cache.put(k, v) @@ -224,7 +186,7 @@ async def test_values_filtered(setup_and_teardown: NamedCache[str, str]) -> None await cache.put(k2, v2) local_list: List[str] = [] - async for e in cache.values(Filters.equals("length()", 8)): + async for e in await cache.values(Filters.equals("length()", 8)): local_list.append(e) assert len(local_list) == 2 @@ -234,16 +196,14 @@ async def test_values_filtered(setup_and_teardown: NamedCache[str, str]) -> None # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_values_paged(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_values_paged(cache: NamedCache[str, str]) -> None: # insert enough data into the cache to ensure results will be paged # by the proxy. 
num_entries: int = await _insert_large_number_of_entries(cache) # Stream the keys and locally cache the results local_list: List[str] = [] - async for e in cache.values(by_page=True): + async for e in await cache.values(by_page=True): local_list.append(e) assert len(local_list) == num_entries @@ -251,9 +211,7 @@ async def test_values_paged(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_put_all(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_put_all(cache: NamedCache[str, str]) -> None: k1: str = "three" v1: str = "only-three" k2: str = "four" @@ -268,9 +226,7 @@ async def test_put_all(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_get_or_default(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_get_or_default(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -284,9 +240,7 @@ async def test_get_or_default(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_get_all(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_get_all(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -300,7 +254,7 @@ async def test_get_all(setup_and_teardown: NamedCache[str, str]) -> None: await cache.put(k3, v3) r: Dict[str, str] = {} - async for e in cache.get_all({k1, k3}): + async for e in await cache.get_all({k1, k3}): r[e.key] = e.value assert r == {k1: v1, k3: v3} @@ -308,9 +262,7 @@ async def test_get_all(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_get_all_no_keys_raises_error(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_get_all_no_keys_raises_error(cache: NamedCache[str, str]) -> None: with pytest.raises(ValueError): # noinspection PyTypeChecker await cache.get_all(None) @@ -318,9 +270,7 @@ async def test_get_all_no_keys_raises_error(setup_and_teardown: NamedCache[str, # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_remove(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_remove(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -334,9 +284,7 @@ async def test_remove(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_remove_mapping(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_remove_mapping(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -350,9 +298,7 @@ async def test_remove_mapping(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_replace(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_replace(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -364,9 +310,7 @@ async def test_replace(setup_and_teardown: NamedCache[str, 
str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_replace_mapping(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_replace_mapping(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -378,9 +322,7 @@ async def test_replace_mapping(setup_and_teardown: NamedCache[str, str]) -> None # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_contains_key(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_contains_key(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -394,9 +336,7 @@ async def test_contains_key(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_contains_value(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_contains_value(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -410,9 +350,7 @@ async def test_contains_value(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_is_empty(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_is_empty(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -427,9 +365,7 @@ async def test_is_empty(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_size(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_size(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -449,9 +385,7 @@ async def test_size(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_invoke(setup_and_teardown: NamedCache[str, Union[str, Person]]) -> None: - cache: NamedCache[str, Union[str, Person]] = setup_and_teardown - +async def test_invoke(cache: NamedCache[str, Union[str, Person]]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -482,9 +416,7 @@ async def test_invoke(setup_and_teardown: NamedCache[str, Union[str, Person]]) - # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_invoke_all_keys(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_invoke_all_keys(cache: NamedCache[str, str]) -> None: k1: str = "one" v1: str = "only-one" await cache.put(k1, v1) @@ -499,7 +431,7 @@ async def test_invoke_all_keys(setup_and_teardown: NamedCache[str, str]) -> None r: Dict[str, int] = {} e: MapEntry[str, int] - async for e in cache.invoke_all(ExtractorProcessor(UniversalExtractor("length()")), keys={k1, k3}): + async for e in await cache.invoke_all(ExtractorProcessor(UniversalExtractor("length()")), keys={k1, k3}): r[e.key] = e.value assert r == {k1: 8, k3: 10} @@ -509,9 +441,8 @@ async def test_invoke_all_keys(setup_and_teardown: NamedCache[str, str]) -> None # noinspection PyShadowingNames -@pytest.mark.asyncio -async def test_cache_truncate_event(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown 
+@pytest.mark.asyncio(loop_scope="function") +async def test_cache_truncate_event(cache: NamedCache[str, str]) -> None: name: str = "UNSET" event: Event = Event() @@ -534,7 +465,7 @@ def callback(n: str) -> None: # noinspection PyShadowingNames,DuplicatedCode -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_cache_release_event() -> None: session: Session = await tests.get_session() cache: NamedCache[str, str] = await session.get_cache("test-" + str(int(time() * 1000))) @@ -553,7 +484,7 @@ def callback(n: str) -> None: await cache.put("C", "D") assert await cache.size() == 2 - cache.release() + await cache.release() await tests.wait_for(event, EVENT_TIMEOUT) assert name == cache.name @@ -566,40 +497,118 @@ def callback(n: str) -> None: # noinspection PyShadowingNames,DuplicatedCode,PyUnresolvedReferences @pytest.mark.asyncio -async def test_add_remove_index(setup_and_teardown_person_cache: NamedCache[str, Person]) -> None: - cache: NamedCache[str, Person] = setup_and_teardown_person_cache - - await cache.add_index(Extractors.extract("age")) - result = await cache.aggregate(Aggregators.record(), None, Filters.greater("age", 25)) - # print(result) - # {'@class': 'util.SimpleQueryRecord', 'results': [{'@class': 'util.SimpleQueryRecord.PartialResult', - # 'partitionSet': {'@class': 'net.partition.PartitionSet', 'bits': [2147483647], 'markedCount': -1, - # 'partitionCount': 31, 'tailMask': 2147483647}, 'steps': [{'@class': 'util.SimpleQueryRecord.PartialResult.Step', - # 'efficiency': 5, 'filter': 'GreaterFilter(.age, 25)', - # 'indexLookupRecords': [{'@class': 'util.SimpleQueryRecord.PartialResult.IndexLookupRecord', - # 'bytes': 6839, 'distinctValues': 5, 'extractor': '.age', 'index': 'Partitioned: Footprint=6.67KB, Size=5', - # 'indexDesc': 'Partitioned: ', 'ordered': False}], 'keySetSizePost': 0, 'keySetSizePre': 7, 'millis': 0, - # 'subSteps': []}]}], 'type': {'@class': 'aggregator.QueryRecorder.RecordType', 'enum': 'EXPLAIN'}} +async def test_add_remove_index(person_cache: NamedCache[str, Person]) -> None: + await person_cache.add_index(Extractors.extract("age")) + result = await person_cache.aggregate(Aggregators.record(), None, Filters.greater("age", 25)) idx_rec = result["results"][0].get("steps")[0].get("indexLookupRecords")[0] - # print(idx_rec) - # {'@class': 'util.SimpleQueryRecord.PartialResult.IndexLookupRecord', 'bytes': 6839, 'distinctValues': 5, - # 'extractor': '.age', 'index': 'Partitioned: Footprint=6.67KB, Size=5', 'indexDesc': 'Partitioned: ', - # 'ordered': False} assert "index" in idx_rec - await cache.remove_index(Extractors.extract("age")) - result2 = await cache.aggregate(Aggregators.record(), None, Filters.greater("age", 25)) - print(result2) - # {'@class': 'util.SimpleQueryRecord', 'results': [{'@class': 'util.SimpleQueryRecord.PartialResult', - # 'partitionSet': {'@class': 'net.partition.PartitionSet', 'bits': [2147483647], 'markedCount': -1, - # 'partitionCount': 31, 'tailMask': 2147483647}, 'steps': [{'@class': 'util.SimpleQueryRecord.PartialResult.Step', - # 'efficiency': 7000, 'filter': 'GreaterFilter(.age, 25)', - # 'indexLookupRecords': [{'@class': 'util.SimpleQueryRecord.PartialResult.IndexLookupRecord', 'bytes': -1, - # 'distinctValues': -1, 'extractor': '.age', 'ordered': False}], 'keySetSizePost': 0, 'keySetSizePre': 7, - # 'millis': 0, 'subSteps': []}]}], 'type': {'@class': 'aggregator.QueryRecorder.RecordType', 'enum': 'EXPLAIN'}} + await person_cache.remove_index(Extractors.extract("age")) + result2 = await 
person_cache.aggregate(Aggregators.record(), None, Filters.greater("age", 25)) idx_rec = result2["results"][0].get("steps")[0].get("indexLookupRecords")[0] - # print(idx_rec) - # {'@class': 'util.SimpleQueryRecord.PartialResult.IndexLookupRecord', 'bytes': -1, 'distinctValues': -1, - # 'extractor': '.age', 'ordered': False} assert "index" not in idx_rec + + +# noinspection PyExceptClausesOrder +@pytest.mark.asyncio +async def test_stream_request_timeout(cache: NamedCache[str, str]) -> None: + # insert enough data into the cache to ensure results will be paged + # by the proxy. + await _insert_large_number_of_entries(cache) + + start = time() + try: + async with request_timeout(seconds=1.0): + async for e in await cache.values(): + continue + assert False + except TimeoutError: # v1 + end = time() + assert pytest.approx((end - start), 0.5) == 1.0 + except asyncio.exceptions.TimeoutError: # v1 + end = time() + assert pytest.approx((end - start), 0.5) == 1.0 + except AioRpcError as e: # noqa: F841 + end = time() + assert e.code() == StatusCode.DEADLINE_EXCEEDED + assert pytest.approx((end - start), 0.5) == 1.0 + + +# noinspection PyExceptClausesOrder +@pytest.mark.asyncio +async def test_paged_stream_request_timeout(cache: NamedCache[str, str]) -> None: + # insert enough data into the cache to ensure results will be paged + # by the proxy. + await _insert_large_number_of_entries(cache) + + start = time() + try: + async with request_timeout(seconds=1.0): + async for e in await cache.values(by_page=True): + continue + assert False + except TimeoutError: + end = time() + assert pytest.approx((end - start), 0.5) == 1.0 + except asyncio.exceptions.TimeoutError: # v1 + end = time() + assert pytest.approx((end - start), 0.5) == 1.0 + except AioRpcError as e: # noqa: F841 + end = time() + assert e.code() == StatusCode.DEADLINE_EXCEEDED + assert pytest.approx((end - start), 0.5) == 1.0 + + +# noinspection PyUnresolvedReferences +@pytest.mark.asyncio +async def test_ttl_configuration(test_session: Session) -> None: + cache: NamedCache[str, str] = await test_session.get_cache("none") + assert cache._default_expiry == 0 + await cache.destroy() + + options: CacheOptions = CacheOptions() + cache = await test_session.get_cache("default", options) + assert cache._default_expiry == options.default_expiry + await cache.destroy() + + options = CacheOptions(default_expiry=2000) + cache = await test_session.get_cache("defined", options) + assert cache._default_expiry == options.default_expiry + + await cache.put("a", "b") + assert await cache.size() == 1 + + sleep(2.5) + assert await cache.size() == 0 + await cache.destroy() + + options = CacheOptions(default_expiry=2000) + cache = await test_session.get_cache("override", options) + + await cache.put("a", "b", 5000) + + assert await cache.size() == 1 + + sleep(2.5) + assert await cache.size() == 1 + + sleep(1) + assert await cache.size() == 1 + + sleep(3) + assert await cache.size() == 0 + await cache.destroy() + + +@pytest.mark.asyncio +async def test_unary_error(test_session: Session) -> None: + cache: NamedCache[str, str] = await test_session.get_cache("unary_error") + + d = dict() + d["@class"] = "com.foo.Bar" + + with pytest.raises(Exception) as ex: + await cache.put("a", d) + + assert "Could not deserialize" in str(ex.value) diff --git a/tests/test_events.py b/tests/e2e/test_events.py similarity index 89% rename from tests/test_events.py rename to tests/e2e/test_events.py index 4875db9..8725f1c 100644 --- a/tests/test_events.py +++ 
b/tests/e2e/test_events.py @@ -3,14 +3,11 @@ # https://oss.oracle.com/licenses/upl. import asyncio -import time -from typing import Any, AsyncGenerator, Generic, List, TypeVar, Union, cast +from typing import Generic, List, TypeVar, Union, cast import pytest -import pytest_asyncio -import tests -from coherence import Filters, NamedCache, Session +from coherence import Filters, NamedCache from coherence.event import MapEvent, MapEventType from coherence.filter import Filter, LessFilter, MapEventFilter from tests import CountingMapListener @@ -263,11 +260,14 @@ async def _run_basic_test( :param filter_mask: the event mask, if any """ listener: CountingMapListener[str, str] = CountingMapListener("basic") + query_filter: MapEventFilter[str, str] = ( + None if filter_mask is None else MapEventFilter(filter_mask, Filters.always()) + ) - if filter_mask is None: + if query_filter is None: await cache.add_map_listener(listener) else: - await cache.add_map_listener(listener, MapEventFilter(filter_mask, Filters.always())) + await cache.add_map_listener(listener, query_filter) await cache.put("A", "B") await cache.put("A", "C") @@ -279,7 +279,10 @@ async def _run_basic_test( # remove the listener and trigger some events. Ensure no events captured. listener.reset() - await cache.remove_map_listener(listener) + if query_filter is None: + await cache.remove_map_listener(listener) + else: + await cache.remove_map_listener(listener, query_filter) await cache.put("A", "B") await cache.put("A", "C") @@ -291,45 +294,26 @@ async def _run_basic_test( expected2.validate(listener) -@pytest_asyncio.fixture -async def setup_and_teardown() -> AsyncGenerator[NamedCache[Any, Any], None]: - """ - Fixture for test setup/teardown. - """ - session: Session = await tests.get_session() - cache: NamedCache[Any, Any] = await session.get_cache("test-" + str(time.time_ns())) - - yield cache - - await cache.clear() - await session.close() - - # ----- test functions ------------------------------------------------------ @pytest.mark.asyncio -async def test_add_no_listener(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_add_no_listener(cache: NamedCache[str, str]) -> None: with pytest.raises(ValueError): await cache.add_map_listener(None) @pytest.mark.asyncio -async def test_remove_no_listener(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_remove_no_listener(cache: NamedCache[str, str]) -> None: with pytest.raises(ValueError): await cache.remove_map_listener(None) # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_all(setup_and_teardown: NamedCache[str, str]) -> None: +async def test_all(cache: NamedCache[str, str]) -> None: """Ensure the registered MapListener is able to receive insert, update, and delete events.""" - cache: NamedCache[str, str] = setup_and_teardown name: str = cache.name expected: ExpectedEvents[str, str] = ExpectedEvents( @@ -343,10 +327,9 @@ async def test_all(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_inserts_only(setup_and_teardown: NamedCache[str, str]) -> None: +async def test_inserts_only(cache: NamedCache[str, str]) -> None: """Ensure the registered MapListener is able to receive insert events only.""" - cache: NamedCache[str, str] = setup_and_teardown name: str = cache.name expected: ExpectedEvents[str, str] = ExpectedEvents( @@ -358,10 +341,9 @@ async def 
test_inserts_only(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_updates_only(setup_and_teardown: NamedCache[str, str]) -> None: +async def test_updates_only(cache: NamedCache[str, str]) -> None: """Ensure the registered MapListener is able to receive update events only.""" - cache: NamedCache[str, str] = setup_and_teardown name: str = cache.name expected: ExpectedEvents[str, str] = ExpectedEvents( @@ -373,10 +355,9 @@ async def test_updates_only(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_deletes_only(setup_and_teardown: NamedCache[str, str]) -> None: +async def test_deletes_only(cache: NamedCache[str, str]) -> None: """Ensure the registered MapListener is able to receive delete events only.""" - cache: NamedCache[str, str] = setup_and_teardown name: str = cache.name expected: ExpectedEvents[str, str] = ExpectedEvents( @@ -388,10 +369,9 @@ async def test_deletes_only(setup_and_teardown: NamedCache[str, str]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_multiple_listeners(setup_and_teardown: NamedCache[str, str]) -> None: +async def test_multiple_listeners(cache: NamedCache[str, str]) -> None: """Ensure the multiple registered MapListeners are able to receive insert, update, and delete events.""" - cache: NamedCache[str, str] = setup_and_teardown name: str = cache.name expected: ExpectedEvents[str, str] = ExpectedEvents( @@ -456,10 +436,9 @@ async def test_multiple_listeners(setup_and_teardown: NamedCache[str, str]) -> N # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_custom_filter_listener(setup_and_teardown: NamedCache[str, Person]) -> None: +async def test_custom_filter_listener(cache: NamedCache[str, Person]) -> None: """Ensure a custom filter is applied when filtering values for events.""" - cache: NamedCache[str, Person] = setup_and_teardown name: str = cache.name fred: Person = Person.fred() @@ -493,10 +472,9 @@ async def test_custom_filter_listener(setup_and_teardown: NamedCache[str, Person # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_key_listener(setup_and_teardown: NamedCache[str, Person]) -> None: +async def test_key_listener(cache: NamedCache[str, Person]) -> None: """Ensure a listener can be associated with a key.""" - cache: NamedCache[str, Person] = setup_and_teardown name: str = cache.name fred: Person = Person.fred() @@ -529,11 +507,10 @@ async def test_key_listener(setup_and_teardown: NamedCache[str, Person]) -> None # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_lite_listeners(setup_and_teardown: NamedCache[str, Person]) -> None: +async def test_lite_listeners(cache: NamedCache[str, Person]) -> None: """Ensure lite event handling works as expected alone or when similar listeners are registered that are non-lite. See test comments for details.""" - cache: NamedCache[str, Person] = setup_and_teardown name: str = cache.name always: Filter = Filters.always() diff --git a/tests/test_filters.py b/tests/e2e/test_filters.py similarity index 85% rename from tests/test_filters.py rename to tests/e2e/test_filters.py index 301e891..b4bb86b 100644 --- a/tests/test_filters.py +++ b/tests/e2e/test_filters.py @@ -1,33 +1,18 @@ # Copyright (c) 2022, 2023, Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at # https://oss.oracle.com/licenses/upl. 
- -from typing import Any, AsyncGenerator +from typing import Any import pytest -import pytest_asyncio -import tests -from coherence import NamedCache, Session +from coherence import NamedCache from coherence.filter import Filters from coherence.processor import ConditionalRemove, EntryProcessor -@pytest_asyncio.fixture -async def setup_and_teardown() -> AsyncGenerator[NamedCache[Any, Any], None]: - session: Session = await tests.get_session() - cache: NamedCache[Any, Any] = await session.get_cache("test") - - yield cache - - await cache.clear() - await session.close() - - # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_and(setup_and_teardown: NamedCache[str, Any]) -> None: - cache = setup_and_teardown +async def test_and(cache: NamedCache[str, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -45,8 +30,7 @@ async def test_and(setup_and_teardown: NamedCache[str, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_or(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_or(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -64,8 +48,7 @@ async def test_or(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_xor(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_xor(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -89,8 +72,7 @@ async def test_xor(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_all(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_all(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -114,8 +96,7 @@ async def test_all(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_any(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_any(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -139,8 +120,7 @@ async def test_any(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_greater(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_greater(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -158,8 +138,7 @@ async def test_greater(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_greater_equals(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_greater_equals(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -177,8 +156,7 @@ async def test_greater_equals(setup_and_teardown: 
NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_less(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_less(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -196,8 +174,7 @@ async def test_less(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_less_equals(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_less_equals(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -215,8 +192,7 @@ async def test_less_equals(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_between(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_between(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -240,8 +216,7 @@ async def test_between(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_not_equals(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_not_equals(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -259,8 +234,7 @@ async def test_not_equals(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_not(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_not(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -278,8 +252,7 @@ async def test_not(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_is_none(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_is_none(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -297,8 +270,7 @@ async def test_is_none(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_is_not_none(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_is_not_none(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -316,8 +288,7 @@ async def test_is_not_none(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_contains_any(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_contains_any(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -335,8 +306,7 @@ async def test_contains_any(setup_and_teardown: NamedCache[Any, Any]) -> None: # 
noinspection PyShadowingNames @pytest.mark.asyncio -async def test_contains_all(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_contains_all(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -354,8 +324,7 @@ async def test_contains_all(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_contains(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_contains(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -373,8 +342,7 @@ async def test_contains(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_in(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_in(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -392,8 +360,7 @@ async def test_in(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_like(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_like(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123-my-test-string", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -417,8 +384,7 @@ async def test_like(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_present(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_present(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123-my-test-string", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -436,8 +402,7 @@ async def test_present(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_regex(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache = setup_and_teardown +async def test_regex(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "test", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) diff --git a/tests/e2e/test_near_caching.py b/tests/e2e/test_near_caching.py new file mode 100644 index 0000000..36a92c1 --- /dev/null +++ b/tests/e2e/test_near_caching.py @@ -0,0 +1,427 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. 
+ +import asyncio +import time +from functools import reduce +from typing import Optional + +import pytest + +from coherence import CacheOptions, CacheStats, Filters, NamedCache, NearCacheOptions, Processors, Session + + +@pytest.mark.asyncio +async def test_basic_put_get_remove(test_session: Session) -> None: + if test_session._protocol_version < 1: + return + + cache: NamedCache[str, str] = await test_session.get_cache( + "basic", CacheOptions(near_cache_options=NearCacheOptions(ttl=2000)) + ) + await cache.clear() + + stats: Optional[CacheStats] = cache.near_cache_stats + assert stats is not None + + result: Optional[str] = await cache.put("a", "b") + assert result is None + assert await cache.size() == 1 + assert stats.size == 0 + assert stats.puts == 0 + + result = await cache.get("a") + assert result == "b" + assert stats.size == 1 + assert stats.puts == 1 + assert stats.gets == 1 + assert stats.hits == 0 + assert stats.misses == 1 + + result = await cache.get("a") + assert result == "b" + assert stats.size == 1 + assert stats.puts == 1 + assert stats.gets == 2 + assert stats.hits == 1 + assert stats.misses == 1 + + # allow entry to expire + await asyncio.sleep(2.1) + + result = await cache.get("a") + assert result == "b" + assert stats.size == 1 + assert stats.puts == 2 + assert stats.gets == 3 + assert stats.hits == 1 + assert stats.misses == 2 + assert stats.expires == 1 + + result = await cache.remove("a") + assert result == "b" + assert await cache.size() == 0 + assert stats.size == 0 + assert stats.puts == 2 + assert stats.gets == 3 + assert stats.hits == 1 + assert stats.misses == 2 + assert stats.expires == 1 + + # re-populate the near cache + await cache.put("a", "b") + await cache.get("a") + + assert stats.size == 1 + assert stats.puts == 3 + assert stats.gets == 4 + assert stats.hits == 1 + assert stats.misses == 3 + assert stats.expires == 1 + + # remove the entry via processor and ensure the near cache + # is in the expected state + await cache.invoke("a", Processors.conditional_remove(Filters.always())) + await asyncio.sleep(1) + + assert stats.size == 0 + assert stats.puts == 3 + assert stats.gets == 4 + assert stats.hits == 1 + assert stats.misses == 3 + assert stats.expires == 1 + + # re-populate the near cache + # noinspection PyTypeChecker + await cache.put("a", {"b": "d"}) + assert await cache.get("a") == {"b": "d"} + + assert stats.size == 1 + assert stats.puts == 4 + assert stats.gets == 5 + assert stats.hits == 1 + assert stats.misses == 4 + assert stats.expires == 1 + + # update an entry via processor and ensure the near cache + # is in the expected state + await cache.invoke("a", Processors.update("b", "c")) + assert await cache.get("a") == {"b": "c"} + + assert stats.size == 1 + assert stats.puts == 5 + assert stats.gets == 6 + assert stats.hits == 2 + assert stats.misses == 4 + assert stats.expires == 1 + + +@pytest.mark.asyncio +async def test_get_all(test_session: Session) -> None: + if test_session._protocol_version < 1: + return + + cache: NamedCache[str, str] = await test_session.get_cache( + "basic", CacheOptions(near_cache_options=NearCacheOptions(ttl=2000)) + ) + await cache.clear() + + stats: Optional[CacheStats] = cache.near_cache_stats + assert stats is not None + + await cache.put_all({str(x): str(x) for x in range(10)}) + assert stats.size == 0 + assert stats.puts == 0 + assert stats.gets == 0 + assert stats.hits == 0 + assert stats.misses == 0 + + result: dict[str, str] = {} + async for entry in await cache.get_all({"0", "9"}): + 
result[entry.key] = entry.value + assert stats.size == 2 + assert stats.puts == 2 + assert stats.gets == 2 + assert stats.hits == 0 + assert stats.misses == 2 + + assert result == {"0": "0", "9": "9"} + + # issue a get_all that has a mix of keys that are and are + # not in the near cache + result = {} + async for entry in await cache.get_all({"0", "9", "1", "8"}): + result[entry.key] = entry.value + + assert stats.size == 4 + assert stats.puts == 4 + assert stats.gets == 6 + assert stats.hits == 2 + assert stats.misses == 4 + + assert result == {"0": "0", "9": "9", "1": "1", "8": "8"} + + # issue a get_all for only keys present in the near cache + result = {} + async for entry in await cache.get_all({"0", "9", "1", "8"}): + result[entry.key] = entry.value + + assert stats.size == 4 + assert stats.puts == 4 + assert stats.gets == 10 + assert stats.hits == 6 + assert stats.misses == 4 + + assert result == {"0": "0", "9": "9", "1": "1", "8": "8"} + + +@pytest.mark.asyncio +async def test_remove(test_session: Session) -> None: + if test_session._protocol_version < 1: + return + + cache: NamedCache[str, str] = await test_session.get_cache( + "basic", CacheOptions(near_cache_options=NearCacheOptions(ttl=2000)) + ) + await cache.clear() + + stats: Optional[CacheStats] = cache.near_cache_stats + assert stats is not None + + # populate the near cache + await cache.put("a", "b") + await cache.get("a") + + assert stats.size == 1 + assert stats.puts == 1 + assert stats.gets == 1 + assert stats.hits == 0 + assert stats.misses == 1 + assert stats.expires == 0 + + # invalid mapping should have no impact on near cache + await cache.remove_mapping("a", "c") + + assert stats.size == 1 + assert stats.puts == 1 + assert stats.gets == 1 + assert stats.hits == 0 + assert stats.misses == 1 + assert stats.expires == 0 + + # assert near cache entry is removed + await cache.remove_mapping("a", "b") + + assert stats.size == 0 + assert stats.puts == 1 + assert stats.gets == 1 + assert stats.hits == 0 + assert stats.misses == 1 + assert stats.expires == 0 + + +@pytest.mark.asyncio +async def test_replace(test_session: Session) -> None: + if test_session._protocol_version < 1: + return + + cache: NamedCache[str, str] = await test_session.get_cache( + "basic", CacheOptions(near_cache_options=NearCacheOptions(ttl=2000)) + ) + await cache.clear() + + stats: Optional[CacheStats] = cache.near_cache_stats + assert stats is not None + + # populate the near cache + await cache.put("a", "b") + await cache.get("a") + + assert stats.size == 1 + assert stats.puts == 1 + assert stats.gets == 1 + assert stats.hits == 0 + assert stats.misses == 1 + assert stats.expires == 0 + + # blind replace + await cache.replace("a", "c") + + assert stats.size == 1 + assert stats.puts == 2 + assert stats.gets == 1 + assert stats.hits == 0 + assert stats.misses == 1 + assert stats.expires == 0 + + # invalid mapping should have no impact on near cache + await cache.replace_mapping("a", "b", "c") + + assert stats.size == 1 + assert stats.puts == 2 + assert stats.gets == 1 + assert stats.hits == 0 + assert stats.misses == 1 + assert stats.expires == 0 + + # assert near cache entry is removed + await cache.replace_mapping("a", "c", "b") + + assert stats.size == 1 + assert stats.puts == 3 + assert stats.gets == 1 + assert stats.hits == 0 + assert stats.misses == 1 + assert stats.expires == 0 + + +@pytest.mark.asyncio +async def test_clear(test_session: Session) -> None: + if test_session._protocol_version < 1: + return + + cache: NamedCache[str, str] = 
await test_session.get_cache( + "basic", CacheOptions(near_cache_options=NearCacheOptions(ttl=2000)) + ) + + stats: Optional[CacheStats] = cache.near_cache_stats + assert stats is not None + + await cache.put_all({str(x): str(x) for x in range(10)}) + + async for _ in await cache.get_all({str(x) for x in range(10)}): + continue + + assert stats.size == 10 + assert stats.puts == 10 + assert stats.gets == 10 + assert stats.hits == 0 + assert stats.misses == 10 + + await cache.clear() + + assert stats.size == 0 + assert stats.puts == 10 + assert stats.gets == 10 + assert stats.hits == 0 + assert stats.misses == 10 + + +@pytest.mark.asyncio +async def test_incompatible_near_cache_options(test_session: Session) -> None: + cache: NamedCache[str, str] = await test_session.get_cache( + "basic", CacheOptions(near_cache_options=NearCacheOptions(ttl=2000)) + ) + await cache.clear() + + with pytest.raises(ValueError) as err: + await test_session.get_cache("basic", CacheOptions(near_cache_options=NearCacheOptions(ttl=1900))) + + assert str(err.value) == "A NamedMap or NamedCache with the same name already exists with different CacheOptions" + + cache2: NamedCache[str, str] = await test_session.get_cache( + "basic", CacheOptions(near_cache_options=NearCacheOptions(ttl=2000)) + ) + + assert cache == cache2 + + +@pytest.mark.asyncio +async def test_concurrency(test_session: Session) -> None: + if test_session._protocol_version < 1: + return + + cache: NamedCache[str, str] = await test_session.get_cache( + "basic", CacheOptions(near_cache_options=NearCacheOptions(ttl=0)) + ) + await cache.clear() + stats: CacheStats = cache.near_cache_stats + + # these knobs control: + # - how many concurrent tasks to run + # - how many entries will be inserted and queried + # - how many times the calls will be invoked + task_count: int = 100 + num_entries: int = 1_000 + iterations: int = 4 + + cache_seed: dict[str, str] = {str(x): str(x) for x in range(num_entries)} + cache_seed_keys: set[str] = {key for key in cache_seed.keys()} + print() + + async def get_all_task() -> int: + begin = time.time_ns() + + for _ in range(iterations): + async for _ in await cache.get_all(cache_seed_keys): + continue + + return (time.time_ns() - begin) // 1_000_000 + + async def get_task() -> int: + begin = time.time_ns() + + for _ in range(iterations): + for key in cache_seed_keys: + await cache.get(key) + + return (time.time_ns() - begin) // 1_000_000 + + await cache.put_all(cache_seed) + + begin_outer: int = time.time_ns() + results: list[int] = await asyncio.gather(*[get_all_task() for _ in range(task_count)]) + end_outer: int = time.time_ns() + + print_and_validate( + "get_all", + num_entries, + iterations, + task_count, + (end_outer - begin_outer), + reduce(lambda first, second: first + second, results), + stats, + ) + + stats.reset() + await cache.clear() + await cache.put_all(cache_seed) + + begin_outer = time.time_ns() + results2: list[int] = await asyncio.gather(*[get_task() for _ in range(task_count)]) + end_outer = time.time_ns() + + print_and_validate( + "individual_gets", + num_entries, + iterations, + task_count, + (end_outer - begin_outer), + reduce(lambda first, second: first + second, results2), + stats, + ) + + +def print_and_validate( + task_name: str, + num_entries: int, + iterations: int, + task_count: int, + total_time: int, + task_time: int, + stats: CacheStats, +) -> None: + print() + print(f"[{task_name}] {task_count} Tasks Completed!") + print(f"[{task_name}] Stats at end -> {stats} -> {total_time // 1_000_000}ms") +
print(f"[{task_name}] Tasks completion average: {task_time / task_count}") + + assert stats.puts == num_entries + assert stats.gets == iterations * task_count * num_entries + assert stats.hits == (iterations * task_count * num_entries) - num_entries + assert stats.misses == num_entries + assert stats.size == num_entries + assert stats.hit_rate == pytest.approx(0.99, rel=0.2) + assert stats.expires == 0 + assert stats.prunes == 0 diff --git a/tests/test_processors.py b/tests/e2e/test_processors.py similarity index 81% rename from tests/test_processors.py rename to tests/e2e/test_processors.py index 976e66d..46648fb 100644 --- a/tests/test_processors.py +++ b/tests/e2e/test_processors.py @@ -2,13 +2,11 @@ # Licensed under the Universal Permissive License v 1.0 as shown at # https://oss.oracle.com/licenses/upl. -from typing import Any, AsyncGenerator +from typing import Any import pytest -import pytest_asyncio -import tests -from coherence import NamedCache, Session +from coherence import NamedCache from coherence.filter import Filter, Filters from coherence.processor import EntryProcessor, Numeric, PreloadRequest, Processors, ScriptProcessor, TouchProcessor from coherence.serialization import _META_VERSION, JSONSerializer @@ -16,22 +14,9 @@ from tests.person import Person -@pytest_asyncio.fixture -async def setup_and_teardown() -> AsyncGenerator[NamedCache[Any, Any], None]: - session: Session = await tests.get_session() - cache: NamedCache[Any, Any] = await session.get_cache("test") - - yield cache - - await cache.clear() - await session.close() - - # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_extractor(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache: NamedCache[Any, Any] = setup_and_teardown - +async def test_extractor(cache: NamedCache[Any, Any]) -> None: k1 = "one" v1 = "only-one" await cache.put(k1, v1) @@ -62,9 +47,7 @@ async def test_extractor(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_composite(setup_and_teardown: NamedCache[str, Any]) -> None: - cache: NamedCache[str, Any] = setup_and_teardown - +async def test_composite(cache: NamedCache[str, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -82,9 +65,7 @@ async def test_composite(setup_and_teardown: NamedCache[str, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_conditional(setup_and_teardown: NamedCache[str, Any]) -> None: - cache: NamedCache[str, Any] = setup_and_teardown - +async def test_conditional(cache: NamedCache[str, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -101,9 +82,7 @@ async def test_conditional(setup_and_teardown: NamedCache[str, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_null(setup_and_teardown: NamedCache[str, Any]) -> None: - cache: NamedCache[str, Any] = setup_and_teardown - +async def test_null(cache: NamedCache[str, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -114,9 +93,7 @@ async def test_null(setup_and_teardown: NamedCache[str, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_multiplier(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache: NamedCache[Any, Any] = setup_and_teardown - +async def 
test_multiplier(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -127,9 +104,7 @@ async def test_multiplier(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_incrementor(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache: NamedCache[Any, Any] = setup_and_teardown - +async def test_incrementor(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -140,9 +115,7 @@ async def test_incrementor(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_conditional_put(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache: NamedCache[Any, Any] = setup_and_teardown - +async def test_conditional_put(cache: NamedCache[Any, Any]) -> None: k1 = "one" v1 = "only-one" await cache.put(k1, v1) @@ -163,9 +136,7 @@ async def test_conditional_put(setup_and_teardown: NamedCache[Any, Any]) -> None # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_conditional_put_all(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache: NamedCache[Any, Any] = setup_and_teardown - +async def test_conditional_put_all(cache: NamedCache[Any, Any]) -> None: k1 = "one" v1 = "only-one" await cache.put(k1, v1) @@ -176,7 +147,7 @@ async def test_conditional_put_all(setup_and_teardown: NamedCache[Any, Any]) -> f = Filters.always() # This will always return True cp = Processors.conditional_put_all(f, dict([(k1, "only-one-one"), (k2, "only-two-two")])) - async for _ in cache.invoke_all(cp): + async for _ in await cache.invoke_all(cp): break # ignore the results assert await cache.get(k1) == "only-one-one" @@ -184,7 +155,7 @@ async def test_conditional_put_all(setup_and_teardown: NamedCache[Any, Any]) -> pf = Filters.present() cp = Processors.conditional_put_all(Filters.negate(pf), dict([("three", "only-three")])) - async for _ in cache.invoke_all(cp, {"one", "three"}): + async for _ in await cache.invoke_all(cp, {"one", "three"}): break # ignore the results assert await cache.get(k1) == "only-one-one" @@ -194,9 +165,7 @@ async def test_conditional_put_all(setup_and_teardown: NamedCache[Any, Any]) -> # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_conditional_remove(setup_and_teardown: NamedCache[str, str]) -> None: - cache: NamedCache[str, str] = setup_and_teardown - +async def test_conditional_remove(cache: NamedCache[str, str]) -> None: k1 = "one" v1 = "only-one" await cache.put(k1, v1) @@ -220,9 +189,7 @@ async def test_conditional_remove(setup_and_teardown: NamedCache[str, str]) -> N # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_method_invocation(setup_and_teardown: NamedCache[str, Any]) -> None: - cache: NamedCache[str, Any] = setup_and_teardown - +async def test_method_invocation(cache: NamedCache[str, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": [1, 2, 3], "group:": 1} await cache.put(k, v) @@ -281,9 +248,7 @@ async def test_preload() -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_updater(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache: NamedCache[Any, Any] = setup_and_teardown - +async def test_updater(cache: NamedCache[Any, Any]) -> None: k = "k1" v = {"id": 123, "my_str": "123", "ival": 123, "fval": 12.3, "iarr": 
[1, 2, 3], "group:": 1} await cache.put(k, v) @@ -294,9 +259,7 @@ async def test_updater(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_versioned_put(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache: NamedCache[Any, Any] = setup_and_teardown - +async def test_versioned_put(cache: NamedCache[Any, Any]) -> None: k = "123" versioned123 = { _META_VERSION: 1, @@ -337,8 +300,7 @@ async def test_versioned_put(setup_and_teardown: NamedCache[Any, Any]) -> None: # noinspection PyShadowingNames @pytest.mark.asyncio -async def test_versioned_put_all(setup_and_teardown: NamedCache[Any, Any]) -> None: - cache: NamedCache[Any, Any] = setup_and_teardown +async def test_versioned_put_all(cache: NamedCache[Any, Any]) -> None: k1 = "123" versioned123 = { _META_VERSION: 1, @@ -406,7 +368,7 @@ async def test_versioned_put_all(setup_and_teardown: NamedCache[Any, Any]) -> No vpa = Processors.versioned_put_all(dict([(k1, versioned123_update), (k2, versioned234_update)])) - async for _ in cache.invoke_all(vpa): + async for _ in await cache.invoke_all(vpa): break assert await cache.get(k1) == expected_versioned123_update diff --git a/tests/test_session.py b/tests/e2e/test_session.py similarity index 92% rename from tests/test_session.py rename to tests/e2e/test_session.py index 182b42e..ec78170 100644 --- a/tests/test_session.py +++ b/tests/e2e/test_session.py @@ -14,7 +14,7 @@ import pytest import tests -from coherence import NamedCache, NamedMap, Options, Session, TlsOptions +from coherence import NamedCache, NamedMap, Options, Session from coherence.event import MapLifecycleEvent, SessionLifecycleEvent from tests import CountingMapListener @@ -26,7 +26,6 @@ async def test_basics() -> None: """Test initial session state and post-close invocations raise error""" - run_secure: str = "RUN_SECURE" session: Session = await tests.get_session() def check_basics() -> None: @@ -35,16 +34,6 @@ def check_basics() -> None: assert session.session_id is not None assert session.options is not None - if run_secure in os.environ: - assert session.options.tls_options is not None - assert session.options.tls_options.enabled - assert session.options.tls_options.client_key_path == os.environ.get(TlsOptions.ENV_CLIENT_KEY) - assert session.options.tls_options.ca_cert_path == os.environ.get(TlsOptions.ENV_CA_CERT) - assert session.options.tls_options.client_cert_path == os.environ.get(TlsOptions.ENV_CLIENT_CERT) - else: - assert session.options.tls_options is None - assert session.options.channel_options is None - assert session.options.session_disconnect_timeout_seconds == Options.DEFAULT_SESSION_DISCONNECT_TIMEOUT if Options.ENV_REQUEST_TIMEOUT in os.environ: @@ -137,7 +126,7 @@ def close_callback() -> None: session.on(SessionLifecycleEvent.RECONNECTED, reconn_callback) session.on(SessionLifecycleEvent.CLOSED, close_callback) - await tests.wait_for(conn_event, EVENT_TIMEOUT) + # await tests.wait_for(conn_event, EVENT_TIMEOUT) assert session.is_ready() await _shutdown_proxy() @@ -161,13 +150,17 @@ def close_callback() -> None: @pytest.mark.asyncio async def test_wait_for_ready() -> None: session: Session = await tests.get_session(10.0) - print(f"Session -> {session}") + + print(f"Session (pre-cache) -> {session}") logging.debug("Getting cache ...") try: count: int = 50 cache: NamedCache[str, str] = await session.get_cache("test-" + str(int(time() * 1000))) + + print(f"Session (post-cache) -> {session}") + listener: CountingMapListener[str, str] = 
CountingMapListener("Test") await _run_pre_shutdown_logic(cache, listener, count) @@ -300,7 +293,7 @@ def callback(n: str) -> None: assert await cache.size() == 2 if lifecycle_event == MapLifecycleEvent.RELEASED: - cache.release() + await cache.release() else: await cache.destroy() diff --git a/tests/java/coherence-python-test/src/main/java/com/oracle/coherence/python/testing/LongRunningProcessor.java b/tests/java/coherence-python-test/src/main/java/com/oracle/coherence/python/testing/LongRunningProcessor.java new file mode 100644 index 0000000..9320679 --- /dev/null +++ b/tests/java/coherence-python-test/src/main/java/com/oracle/coherence/python/testing/LongRunningProcessor.java @@ -0,0 +1,24 @@ +package com.oracle.coherence.python.testing; + + +import com.tangosol.util.Base; +import com.tangosol.util.InvocableMap; +import java.util.Map; +import java.util.Set; + + +public class LongRunningProcessor + implements InvocableMap.EntryProcessor + { + public Void process(InvocableMap.Entry entry) + { + Base.sleep(5000); + return null; + } + + public Map processAll(Set> setEntries) + { + Base.sleep(5000); + return InvocableMap.EntryProcessor.super.processAll(setEntries); + } + } diff --git a/tests/java/coherence-python-test/src/main/resources/META-INF/type-aliases.properties b/tests/java/coherence-python-test/src/main/resources/META-INF/type-aliases.properties index 4c62004..c4022d8 100644 --- a/tests/java/coherence-python-test/src/main/resources/META-INF/type-aliases.properties +++ b/tests/java/coherence-python-test/src/main/resources/META-INF/type-aliases.properties @@ -1,9 +1,7 @@ -# # Copyright (c) 2022 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at # https://oss.oracle.com/licenses/upl. -# - test.customer=com.oracle.coherence.python.testing.Customer test.address=com.oracle.coherence.python.testing.Address +test.longrunning=com.oracle.coherence.python.testing.LongRunningProcessor diff --git a/tests/java/pom.xml b/tests/java/pom.xml index e5bba50..e7f688b 100644 --- a/tests/java/pom.xml +++ b/tests/java/pom.xml @@ -40,7 +40,7 @@ com.oracle.coherence.ce ${coherence.version} - gcr.io/distroless/java:11 + gcr.io/distroless/java17-debian12:latest 2.2.1 diff --git a/tests/test_serialization.py b/tests/test_serialization.py deleted file mode 100644 index b8d2942..0000000 --- a/tests/test_serialization.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) 2022, 2023 Oracle and/or its affiliates. -# Licensed under the Universal Permissive License v 1.0 as shown at -# https://oss.oracle.com/licenses/upl. 
- -from decimal import Decimal -from typing import Any - -from coherence.serialization import JSONSerializer, proxy -from tests.Task import Task - - -def test_python_decimal() -> None: - _verify_round_trip(Decimal("12.1345797249237923423872493"), True) - - -def test_python_large_integer() -> None: - _verify_round_trip(9223372036854775810, True) # larger than Java Long (2^63 - 1) - - -def test_python_large_negative_integer() -> None: - _verify_round_trip(-9223372036854775810, True) # less than Java Long -2^63 - - -def test_python_java_long_upper_bound() -> None: - _verify_round_trip(9223372036854775807, False) # Java Long (2^32 - 1) - - -def test_python_java_long_lower_bound() -> None: - _verify_round_trip(-9223372036854775809, False) # Java Long (2^32 - 1) - - -def test_custom_object() -> None: - _verify_round_trip(Task("Milk, eggs, and bread"), True) - - -def test_python_numerics_in_object() -> None: - _verify_round_trip(Simple(), False) - - -@proxy("test.Simple") -class Simple: - def __init__(self) -> None: - super().__init__() - self.n1 = (2**63) - 1 - self.n2 = self.n1 + 5 - self.n3 = Decimal("12.1345797249237923423872493") - - def __eq__(self, o: object) -> bool: - if isinstance(o, Simple): - return self.n1 == getattr(o, "n1") and self.n2 == getattr(o, "n2") and self.n3 == getattr(o, "n3") - - return False - - -def _verify_round_trip(obj: Any, should_have_class: bool) -> None: - serializer: JSONSerializer = JSONSerializer() - ser_result: bytes = serializer.serialize(obj) - print(f"Serialized [{type(obj)}] result: {ser_result.decode()}") - - if should_have_class: - assert "@class" in ser_result.decode() - - deser_result: Any = serializer.deserialize(ser_result) - print(f"Deserialized [{type(deser_result)}] result: {deser_result}") - assert deser_result == obj diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..0945b75 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. 
diff --git a/tests/unit/test_cache_options.py b/tests/unit/test_cache_options.py new file mode 100644 index 0000000..9995a5d --- /dev/null +++ b/tests/unit/test_cache_options.py @@ -0,0 +1,156 @@ +import pytest + +from coherence.client import CacheOptions +from coherence.local_cache import NearCacheOptions + + +def test_cache_options_default_expiry() -> None: + options: CacheOptions = CacheOptions(-10) + assert options.default_expiry == -1 + + options = CacheOptions(10000) + assert options.default_expiry == 10000 + + +def test_near_cache_options_no_explicit_params() -> None: + with pytest.raises(ValueError) as err: + NearCacheOptions() + + assert str(err.value) == "at least one option must be specified" + + +def test_near_cache_options_negative_units() -> None: + message: str = "values for high_units and high_units_memory must be positive" + + with pytest.raises(ValueError) as err: + NearCacheOptions(high_units=-1) + + assert str(err.value) == message + + with pytest.raises(ValueError) as err: + NearCacheOptions(high_units_memory=-1) + + assert str(err.value) == message + + +def test_near_cache_options_both_units() -> None: + message: str = "high_units and high_units_memory cannot be used together; specify one or the other" + + with pytest.raises(ValueError) as err: + NearCacheOptions(high_units=1000, high_units_memory=10000) + + assert str(err.value) == message + + with pytest.raises(ValueError) as err: + NearCacheOptions(ttl=10000, high_units=1000, high_units_memory=10000) + + assert str(err.value) == message + + +def test_near_cache_options_prune_factor() -> None: + message: str = "prune_factor must be between .1 and 1" + + with pytest.raises(ValueError) as err: + NearCacheOptions(high_units=100, prune_factor=-1) + + assert str(err.value) == message + + with pytest.raises(ValueError) as err: + NearCacheOptions(high_units=100, prune_factor=0) + + assert str(err.value) == message + + with pytest.raises(ValueError) as err: + NearCacheOptions(high_units=100, prune_factor=0.05) + + assert str(err.value) == message + + with pytest.raises(ValueError) as err: + NearCacheOptions(high_units=100, prune_factor=1.001) + + assert str(err.value) == message + + +def test_near_cache_options_str() -> None: + options: NearCacheOptions = NearCacheOptions(high_units=100) + assert str(options) == "NearCacheOptions(ttl=0ms, high-units=100, high-units-memory=0, prune-factor=0.80)" + + options = NearCacheOptions(high_units=100, ttl=1000) + assert str(options) == "NearCacheOptions(ttl=1000ms, high-units=100, high-units-memory=0, prune-factor=0.80)" + + options = NearCacheOptions(high_units_memory=100 * 1024) + assert str(options) == "NearCacheOptions(ttl=0ms, high-units=0, high-units-memory=102400, prune-factor=0.80)" + + options = NearCacheOptions(high_units_memory=100 * 1024, prune_factor=0.25) + assert str(options) == "NearCacheOptions(ttl=0ms, high-units=0, high-units-memory=102400, prune-factor=0.25)" + + +def test_near_cache_eq() -> None: + options: NearCacheOptions = NearCacheOptions(high_units=100) + options2: NearCacheOptions = NearCacheOptions(high_units=100, ttl=1000) + options3: NearCacheOptions = NearCacheOptions(high_units=100) + + assert options == options + assert options != options2 + assert options == options3 + assert options != "some string" + + +def test_near_cache_options_ttl() -> None: + options = NearCacheOptions(ttl=1000) + assert options.ttl == 1000 + + # ensure minimum can be set + options = NearCacheOptions(ttl=250) + assert options.ttl == 250 + + +def test_near_cache_ttl_negative() 
-> None: + with pytest.raises(ValueError) as err: + NearCacheOptions(ttl=-1) + + assert str(err.value) == "ttl cannot be less than zero" + + with pytest.raises(ValueError) as err: + NearCacheOptions(ttl=100) + + assert str(err.value) == "ttl has 1/4 second resolution; minimum TTL is 250" + + +def test_near_cache_options_high_units() -> None: + options: NearCacheOptions = NearCacheOptions(high_units=10000) + assert options.high_units == 10000 + + +def test_near_cache_options_high_units_memory() -> None: + options: NearCacheOptions = NearCacheOptions(high_units_memory=10000) + assert options._high_units_memory == 10000 + + +def test_cache_options_str() -> None: + options: CacheOptions = CacheOptions(10000) + assert str(options) == "CacheOptions(default-expiry=10000)" + + options = CacheOptions(5000, NearCacheOptions(high_units=10000)) + assert ( + str(options) == "CacheOptions(default-expiry=5000, near-cache-options=NearCacheOptions(ttl=0ms," + " high-units=10000, high-units-memory=0, prune-factor=0.80))" + ) + + +def test_cache_options_eq() -> None: + options: CacheOptions = CacheOptions(10000) + options2: CacheOptions = CacheOptions(10000) + options3: CacheOptions = CacheOptions(1000) + + assert options == options + assert options == options2 + assert options != options3 + + options = CacheOptions(10000, NearCacheOptions(high_units=10000)) + options2 = CacheOptions(10000, NearCacheOptions(high_units=10000)) + options3 = CacheOptions(10000, NearCacheOptions(high_units=1000)) + + assert options == options + assert options == options2 + assert options != options3 diff --git a/tests/unit/test_environment.py b/tests/unit/test_environment.py new file mode 100644 index 0000000..34d306f --- /dev/null +++ b/tests/unit/test_environment.py @@ -0,0 +1,172 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. 
+ +import pytest + +from coherence import Options, TlsOptions + + +def test_options_address(monkeypatch: pytest.MonkeyPatch) -> None: + try: + monkeypatch.delenv(name=Options.ENV_SERVER_ADDRESS, raising=False) + options: Options = Options() + assert options.address == "localhost:1408" + + custom_address: str = "acme.com:1409" + options = Options(address=custom_address) + assert options.address == custom_address + + custom_address = "acme.com:1409" + monkeypatch.setenv(name=Options.ENV_SERVER_ADDRESS, value=custom_address) + options = Options() + assert options.address == custom_address + + options = Options(address="127.0.0.1:9000") + assert options.address == custom_address + finally: + monkeypatch.undo() + + +def test_options_req_timeout(monkeypatch: pytest.MonkeyPatch) -> None: + try: + monkeypatch.delenv(name=Options.ENV_REQUEST_TIMEOUT, raising=False) + options: Options = Options() + assert options.request_timeout_seconds == 30 + + options = Options(request_timeout_seconds=15) + assert options.request_timeout_seconds == 15 + + monkeypatch.setenv(Options.ENV_REQUEST_TIMEOUT, "35") + options = Options() + assert options.request_timeout_seconds == 35 + + options = Options(request_timeout_seconds=15) + assert options.request_timeout_seconds == 35 + finally: + monkeypatch.undo() + + +def test_options_ready_timeout(monkeypatch: pytest.MonkeyPatch) -> None: + try: + monkeypatch.delenv(name=Options.ENV_READY_TIMEOUT, raising=False) + options: Options = Options() + assert options.ready_timeout_seconds == 0 + + options = Options(ready_timeout_seconds=15) + assert options.ready_timeout_seconds == 15 + + monkeypatch.setenv(Options.ENV_READY_TIMEOUT, "35") + options = Options() + assert options.ready_timeout_seconds == 35 + + options = Options(ready_timeout_seconds=15) + assert options.ready_timeout_seconds == 35 + finally: + monkeypatch.undo() + + +def test_disconnect_timeout(monkeypatch: pytest.MonkeyPatch) -> None: + try: + monkeypatch.delenv(name=Options.ENV_SESSION_DISCONNECT_TIMEOUT, raising=False) + options: Options = Options() + assert options.session_disconnect_timeout_seconds == 30 + + options = Options(session_disconnect_seconds=15) + assert options.session_disconnect_timeout_seconds == 15 + + monkeypatch.setenv(Options.ENV_SESSION_DISCONNECT_TIMEOUT, "35") + options = Options() + assert options.session_disconnect_timeout_seconds == 35 + + options = Options(ready_timeout_seconds=15) + assert options.session_disconnect_timeout_seconds == 35 + finally: + monkeypatch.undo() + + +def test_tls_options(monkeypatch: pytest.MonkeyPatch) -> None: + try: + monkeypatch.delenv(name=TlsOptions.ENV_CA_CERT, raising=False) + monkeypatch.delenv(name=TlsOptions.ENV_CLIENT_CERT, raising=False) + monkeypatch.delenv(name=TlsOptions.ENV_CLIENT_KEY, raising=False) + + tls_options: TlsOptions = TlsOptions() + assert tls_options.enabled is False + assert tls_options.ca_cert_path is None + assert tls_options.client_cert_path is None + assert tls_options.client_key_path is None + + ca_cert_path: str = "/tmp/ca.pem" + client_cert_path: str = "/tmp/client.pem" + client_key_path: str = "/tmp/client.key" + ca_cert_path2: str = "/tmp/ca2.pem" + client_cert_path2: str = "/tmp/client2.pem" + client_key_path2: str = "/tmp/client2.key" + + tls_options = TlsOptions( + enabled=True, ca_cert_path=ca_cert_path, client_cert_path=client_cert_path, client_key_path=client_key_path + ) + assert tls_options.enabled + assert tls_options.ca_cert_path == ca_cert_path + assert tls_options.client_cert_path == client_cert_path + assert 
tls_options.client_key_path == client_key_path + + monkeypatch.setenv(name=TlsOptions.ENV_CA_CERT, value=ca_cert_path) + monkeypatch.setenv(name=TlsOptions.ENV_CLIENT_CERT, value=client_cert_path) + monkeypatch.setenv(name=TlsOptions.ENV_CLIENT_KEY, value=client_key_path) + + tls_options = TlsOptions() + assert tls_options.enabled is False + assert tls_options.ca_cert_path == ca_cert_path + assert tls_options.client_cert_path == client_cert_path + assert tls_options.client_key_path == client_key_path + + tls_options = TlsOptions( + enabled=True, + ca_cert_path=ca_cert_path2, + client_cert_path=client_cert_path2, + client_key_path=client_key_path2, + ) + assert tls_options.enabled + assert tls_options.ca_cert_path == ca_cert_path + assert tls_options.client_cert_path == client_cert_path + assert tls_options.client_key_path == client_key_path + + monkeypatch.delenv(name=TlsOptions.ENV_CA_CERT) + tls_options = TlsOptions( + enabled=True, + ca_cert_path=ca_cert_path2, + client_cert_path=client_cert_path2, + client_key_path=client_key_path2, + ) + assert tls_options.enabled + assert tls_options.ca_cert_path == ca_cert_path2 + assert tls_options.client_cert_path == client_cert_path + assert tls_options.client_key_path == client_key_path + + monkeypatch.delenv(name=TlsOptions.ENV_CLIENT_CERT) + tls_options = TlsOptions( + enabled=True, + ca_cert_path=ca_cert_path2, + client_cert_path=client_cert_path2, + client_key_path=client_key_path2, + ) + assert tls_options.enabled + assert tls_options.ca_cert_path == ca_cert_path2 + assert tls_options.client_cert_path == client_cert_path2 + assert tls_options.client_key_path == client_key_path + + monkeypatch.delenv(name=TlsOptions.ENV_CLIENT_KEY) + tls_options = TlsOptions( + enabled=True, + ca_cert_path=ca_cert_path2, + client_cert_path=client_cert_path2, + client_key_path=client_key_path2, + ) + assert tls_options.enabled + assert tls_options.ca_cert_path == ca_cert_path2 + assert tls_options.client_cert_path == client_cert_path2 + assert tls_options.client_key_path == client_key_path2 + finally: + monkeypatch.undo() diff --git a/tests/test_extractors.py b/tests/unit/test_extractors.py similarity index 100% rename from tests/test_extractors.py rename to tests/unit/test_extractors.py diff --git a/tests/unit/test_local_cache.py b/tests/unit/test_local_cache.py new file mode 100644 index 0000000..5ea2988 --- /dev/null +++ b/tests/unit/test_local_cache.py @@ -0,0 +1,351 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. 
+ +import asyncio +import time +from typing import Any, Callable, Coroutine, Optional + +import pytest + +from coherence.local_cache import CacheStats, LocalCache, LocalEntry, NearCacheOptions +from coherence.util import cur_time_millis, millis_format_date + + +def test_local_entry() -> None: + start: int = cur_time_millis() + entry: LocalEntry[str, str] = LocalEntry("a", "b", 750) + + # check initial state after creation + assert entry.key == "a" + assert entry.value == "b" + assert entry.ttl == 750 + assert entry.last_access >= start + assert entry.bytes > 0 + + # touch the entry and ensure the last_access has increased + # over the insert time + last: int = entry.last_access + time.sleep(0.3) + entry.touch() + assert entry.last_access > last + + +def test_local_entry_str() -> None: + entry: LocalEntry[str, str] = LocalEntry("a", "b", 500) + + result: str = str(entry) + assert result == (f"LocalEntry(key=a, value=b, ttl=500ms, last-access={millis_format_date(entry.last_access)})") + + time.sleep(0.6) + + result = str(entry) + assert result == (f"LocalEntry(key=a, value=b, ttl=500ms, last-access={millis_format_date(entry.last_access)})") + + +@pytest.mark.asyncio +async def test_basic_put_get_remove() -> None: + options: NearCacheOptions = NearCacheOptions(high_units=10) + cache: LocalCache[str, str] = LocalCache("test", options) + result: Optional[str] = await cache.get("a") + + assert result is None + + stats: CacheStats = cache.stats + + # validate stats with a get against an empty cache + assert stats.misses == 1 + assert stats.gets == 1 + assert stats.bytes == 0 + + # check stats after a single put + result = await cache.put("a", "b") + assert result is None + assert stats.puts == 1 + assert stats.gets == 1 + + # check stats after a get for the value previously inserted + result = await cache.get("a") + assert result == "b" + assert stats.misses == 1 + assert stats.gets == 2 + assert stats.bytes > 0 + + # update the value + result = await cache.put("a", "c") + + # snapshot the current size for validation later + stats_bytes: int = stats.bytes + + # ensure previous value returned after update and stats + # are accurate + assert result == "b" + assert stats.puts == 2 + assert stats.misses == 1 + assert stats.gets == 2 + + # insert new value and validate stats + result = await cache.put("b", "d") + assert result is None + assert stats.puts == 3 + assert stats.misses == 1 + assert stats.gets == 2 + assert stats.bytes > stats_bytes + + # issue a series of gets for a non-existent key + for _ in range(10): + await cache.get("c") + + # validate the stats including the hit-rate + assert stats.gets == 12 + assert stats.hits == 1 + assert stats.misses == 11 + assert stats.hit_rate == 0.083 + assert stats.size == 2 + + # remove a value from the cache + # ensure the returned value is what was associated + # with the key. 
Ensure the bytes has decreased + # back to the snapshot taken earlier + result = await cache.remove("b") + assert result == "d" + assert stats.bytes == stats_bytes + + +@pytest.mark.asyncio +async def test_get_all() -> None: + cache: LocalCache[str, str] = LocalCache("test", NearCacheOptions(high_units=100)) + stats: CacheStats = cache.stats + + for i in range(10): + key_value: str = str(i) + await cache.put(key_value, key_value) + + assert stats.puts == 10 + + results: dict[str, str] = await cache.get_all({"1", "2", "3", "4", "5"}) + assert len(results) == 5 + for i in range(1, 5): + key_value = str(i) + assert results[key_value] == key_value + assert stats.gets == 5 + assert stats.hits == 5 + assert stats.misses == 0 + + results = await cache.get_all({"8", "9", "10", "11"}) + assert len(results) == 2 + for i in range(8, 10): + key_value = str(i) + assert results[key_value] == key_value + assert ("10" in results) is False + assert ("11" in results) is False + assert stats.gets == 9 + assert stats.hits == 7 + assert stats.misses == 2 + + +@pytest.mark.asyncio +async def test_expiry() -> None: + options: NearCacheOptions = NearCacheOptions(high_units=1000) + cache: LocalCache[str, str] = LocalCache("test", options) + stats: CacheStats = cache.stats + + for i in range(5): + await cache.put(str(i), str(i), 1000) + + for i in range(5, 10): + await cache.put(str(i), str(i), 2000) + + for i in range(10): + await cache.get(str(i)) + + assert await cache.size() == 10 + + await asyncio.sleep(1.3) + + for i in range(5): + assert await cache.get(str(i)) is None + + for i in range(5, 10): + assert await cache.get(str(i)) == str(i) + + duration: int = stats.expires_duration + assert stats.expires == 1 + assert stats.num_expired == 5 + assert duration > 0 + + await asyncio.sleep(1.05) + + for i in range(10): + assert await cache.get(str(i)) is None + + # assert correct expires count and + # the duration has increased + assert stats.expires == 2 + assert stats.num_expired == 10 + assert stats.expires_duration > duration + + +@pytest.mark.asyncio +async def test_pruning_units() -> None: + options: NearCacheOptions = NearCacheOptions(high_units=100) + cache: LocalCache[str, str] = LocalCache("test", options) + stats: CacheStats = cache.stats + + for i in range(210): + key_value: str = str(i) + await cache.put(key_value, key_value, 0) + + cur_size: int = await cache.size() + assert cur_size < 100 + assert stats.prunes == 6 + assert stats.prunes_duration > 0 + + # assert that the oldest entries were pruned first + for i in range(210 - cur_size, 210): + key_value = str(i) + assert await cache.get(key_value) == key_value + + +@pytest.mark.asyncio +async def test_pruning_memory() -> None: + upper_bound_mem: int = 110 * 1024 # 110KB + options: NearCacheOptions = NearCacheOptions(high_units_memory=upper_bound_mem) + cache: LocalCache[str, str] = LocalCache("test", options) + stats: CacheStats = cache.stats + + for i in range(210): + key_value: str = str(i) + await cache.put(key_value, key_value, 0) + + assert stats.prunes > 0 + assert stats.prunes_duration > 0 + assert stats.bytes < upper_bound_mem + + # assert that the oldest entries were pruned first + for i in range(210 - await cache.size(), 210): + key_value = str(i) + assert await cache.get(key_value) == key_value + + +@pytest.mark.asyncio +async def test_stats_reset() -> None: + upper_bound_mem: int = 110 * 1024 # 110KB + options: NearCacheOptions = NearCacheOptions(ttl=500, high_units_memory=upper_bound_mem) + cache: LocalCache[str, str] = 
LocalCache("test", options) + stats: CacheStats = cache.stats + + for i in range(210): + key_value: str = str(i) + await cache.put(key_value, key_value) + await cache.get(key_value) + + await cache.get("none") + await cache.put("A", "B", 0) + + await asyncio.sleep(0.75) + + assert await cache.size() == 1 + + memory: int = stats.bytes + assert stats.puts == 211 + assert stats.gets == 211 + assert stats.prunes > 0 + assert stats.prunes_duration > 0 + assert stats.expires > 0 + assert stats.expires_duration > 0 + assert stats.hits > 0 + assert stats.hit_rate > 0 + assert stats.misses > 0 + assert memory > 0 + + stats.reset() + + assert stats.puts == 0 + assert stats.gets == 0 + assert stats.prunes == 0 + assert stats.prunes_duration == 0 + assert stats.expires == 0 + assert stats.expires_duration == 0 + assert stats.hits == 0 + assert stats.hit_rate == 0.0 + assert stats.misses == 0 + assert stats.bytes == memory + + +@pytest.mark.asyncio +async def test_clear() -> None: + async def do_clear(cache: LocalCache) -> None: + await cache.clear() + + await _validate_clear_reset(do_clear) + + +@pytest.mark.asyncio +async def test_release() -> None: + async def do_release(cache: LocalCache) -> None: + await cache.release() + + await _validate_clear_reset(do_release) + + +@pytest.mark.asyncio +async def test_local_cache_str() -> None: + options: NearCacheOptions = NearCacheOptions(high_units=300) + cache: LocalCache[str, str] = LocalCache("test", options) + + for i in range(210): + key_value: str = str(i) + await cache.put(key_value, key_value, 500) + await cache.get(key_value) + + await cache.get("none") + await cache.put("A", "B", 0) + + result: str = str(cache) + stats: str = ( + f"CacheStats(puts=211, gets=211, hits=210, misses=1," + f" misses-duration=0ms, hit-rate=0.995, prunes=0, num-pruned=0, prunes-duration=0ms," + f" size=211, expires=0, num-expired=0, expires-duration=0ms, memory-bytes={cache.stats.bytes})" + ) + + assert result == f"LocalCache(name=test, options={str(options)}, stats={stats})" + + +async def _validate_clear_reset(reset: Callable[[LocalCache], Coroutine[Any, Any, None]]) -> None: + options: NearCacheOptions = NearCacheOptions(high_units=300) + cache: LocalCache[str, str] = LocalCache("test", options) + stats: CacheStats = cache.stats + + for i in range(210): + key_value: str = str(i) + await cache.put(key_value, key_value) + await cache.get(key_value) + + assert stats.size == 210 + assert stats.bytes > 1000 + + # store current states to clear impacts the appropriate stats + puts: int = stats.puts + gets: int = stats.gets + misses: int = stats.misses + misses_duration: int = stats.misses_duration + prunes: int = stats.prunes + prunes_duration: int = stats.prunes_duration + expires: int = stats.expires + expires_duration: int = stats.expires_duration + hit_rate: float = stats.hit_rate + + await reset(cache) + + assert stats.puts == puts + assert stats.gets == gets + assert stats.misses == misses + assert stats.misses_duration == misses_duration + assert stats.prunes == prunes + assert stats.prunes_duration == prunes_duration + assert stats.expires == expires + assert stats.expires_duration == expires_duration + assert stats.hit_rate == hit_rate + assert stats.size == 0 + assert stats.bytes == 0 diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py new file mode 100644 index 0000000..23f4348 --- /dev/null +++ b/tests/unit/test_serialization.py @@ -0,0 +1,255 @@ +# Copyright (c) 2022, 2023 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at +# https://oss.oracle.com/licenses/upl. + +from decimal import Decimal +from time import time +from typing import Any +from uuid import uuid4 + +from coherence import Extractors, Filters +from coherence.ai import ( + BinaryQuantIndex, + BitVector, + ByteVector, + CosineDistance, + DocumentChunk, + FloatVector, + QueryResult, + SimilaritySearch, +) +from coherence.extractor import ValueExtractor +from coherence.filter import Filter +from coherence.serialization import JSONSerializer, SerializerRegistry, mappings, proxy + +s = SerializerRegistry.serializer(JSONSerializer.SER_FORMAT) + + +def test_python_decimal() -> None: + _verify_round_trip(Decimal("12.1345797249237923423872493"), True) + + +def test_python_large_integer() -> None: + _verify_round_trip(9223372036854775810, True) # larger than Java Long (2^63 - 1) + + +def test_python_large_negative_integer() -> None: + _verify_round_trip(-9223372036854775810, True) # less than Java Long -2^63 + + +def test_python_java_long_upper_bound() -> None: + _verify_round_trip(9223372036854775807, False) # Java Long upper bound (2^63 - 1) + + +def test_python_java_long_lower_bound() -> None: + _verify_round_trip(-9223372036854775809, False) # just below Java Long lower bound (-2^63) + + +def test_custom_object() -> None: + _verify_round_trip(Task("Milk, eggs, and bread"), True) + + +def test_python_numerics_in_object() -> None: + _verify_round_trip(Simple(), False) + + +@proxy("Task") +@mappings({"created_at": "createdAt"}) +class Task: + def __init__(self, description: str) -> None: + super().__init__() + self.id: str = str(uuid4())[0:6] + self.description: str = description + self.completed: bool = False + self.created_at: int = int(time() * 1000) + + def __hash__(self) -> int: + return hash((self.id, self.description, self.completed, self.created_at)) + + def __eq__(self, o: object) -> bool: + if isinstance(o, Task): + # noinspection PyTypeChecker + t: Task = o + return ( + self.id == t.id + and self.description == t.description + and self.completed == t.completed + and self.created_at == t.created_at + ) + return False + + def __str__(self) -> str: + return 'Task(id="{}", description="{}", completed={}, created_at={})'.format( + self.id, self.description, self.completed, self.created_at + ) + + +@proxy("test.Simple") +class Simple: + def __init__(self) -> None: + super().__init__() + self.n1 = (2**63) - 1 + self.n2 = self.n1 + 5 + self.n3 = Decimal("12.1345797249237923423872493") + + def __eq__(self, o: object) -> bool: + if isinstance(o, Simple): + return self.n1 == getattr(o, "n1") and self.n2 == getattr(o, "n2") and self.n3 == getattr(o, "n3") + + return False + + +def _verify_round_trip(obj: Any, should_have_class: bool) -> None: + serializer: JSONSerializer = JSONSerializer() + ser_result: bytes = serializer.serialize(obj) + print(f"Serialized [{type(obj)}] result: {ser_result.decode()}") + + if should_have_class: + assert "@class" in ser_result.decode() + + deser_result: Any = serializer.deserialize(ser_result) + print(f"Deserialized [{type(deser_result)}] result: {deser_result}") + assert deser_result == obj + + +def test_bit_vector_serialization() -> None: + coh_bv = BitVector(hex_string="AABBCC") + ser = s.serialize(coh_bv) + assert ser == b'\x15{"@class": "ai.BitVector", "bits": "0xAABBCC"}' + o = s.deserialize(ser) + assert isinstance(o, BitVector) + + coh_bv = BitVector(hex_string="0xAABBCC") + ser = s.serialize(coh_bv) + assert ser == b'\x15{"@class": "ai.BitVector", "bits": "0xAABBCC"}' + o =
+    assert isinstance(o, BitVector)
+
+    coh_bv = BitVector(hex_string=None, byte_array=bytes([1, 2, 10]))
+    ser = s.serialize(coh_bv)
+    assert ser == b'\x15{"@class": "ai.BitVector", "bits": "0x01020a"}'
+    o = s.deserialize(ser)
+    assert isinstance(o, BitVector)
+
+    coh_bv = BitVector(hex_string=None, int_array=[1234, 1235])
+    ser = s.serialize(coh_bv)
+    assert ser == b'\x15{"@class": "ai.BitVector", "bits": "0x4d24d3"}'
+    o = s.deserialize(ser)
+    assert isinstance(o, BitVector)
+
+
+def test_byte_vector_serialization() -> None:
+    coh_int8v = ByteVector(bytes([1, 2, 3, 4]))
+    ser = s.serialize(coh_int8v)
+    assert ser == b'\x15{"@class": "ai.Int8Vector", "array": "AQIDBA=="}'
+    o = s.deserialize(ser)
+    assert isinstance(o, ByteVector)
+
+
+def test_float_vector_serialization() -> None:
+    coh_fv = FloatVector([1.0, 2.0, 3.0])
+    ser = s.serialize(coh_fv)
+    assert ser == b'\x15{"@class": "ai.Float32Vector", "array": [1.0, 2.0, 3.0]}'
+    o = s.deserialize(ser)
+    assert isinstance(o, FloatVector)
+
+
+def test_document_chunk_serialization() -> None:
+    dc = DocumentChunk("test")
+    ser = s.serialize(dc)
+    assert ser == (
+        b'\x15{"@class": "ai.DocumentChunk", "dataVersion": 0, '
+        b'"metadata": {"@ordered": true, "entries": []}, "text": "test"}'
+    )
+    o = s.deserialize(ser)
+    assert isinstance(o, DocumentChunk)
+
+    d = {"one": "one-value", "two": "two-value"}
+    dc = DocumentChunk("test", d)
+    ser = s.serialize(dc)
+    assert ser == (
+        b'\x15{"@class": "ai.DocumentChunk", "dataVersion": 0, "metadata": {"entries": ['
+        b'{"key": "one", "value": "one-value"}, {"key": "two", "value": "two-value"}]}, '
+        b'"text": "test"}'
+    )
+    o = s.deserialize(ser)
+    assert isinstance(o, DocumentChunk)
+
+    coh_fv = FloatVector([1.0, 2.0, 3.0])
+    d = {"one": "one-value", "two": "two-value"}
+    dc = DocumentChunk("test", d, coh_fv)
+    ser = s.serialize(dc)
+    assert ser == (
+        b'\x15{"@class": "ai.DocumentChunk", "dataVersion": 0, "metadata": {"entries": ['
+        b'{"key": "one", "value": "one-value"}, {"key": "two", "value": "two-value"}]}, '
+        b'"vector": {"@class": "ai.Float32Vector", "array": [1.0, 2.0, 3.0]}, "text": "test"}'
+    )
+    o = s.deserialize(ser)
+    assert isinstance(o, DocumentChunk)
+
+
+# noinspection PyUnresolvedReferences
+def test_similarity_search_serialization() -> None:
+    coh_fv = FloatVector([1.0, 2.0, 3.0])
+    ve = Extractors.extract("foo")
+    f = Filters.equals("foo", "bar")
+    ss = SimilaritySearch(ve, coh_fv, 19, filter=f)
+    ser = s.serialize(ss)
+    assert ser == (
+        b'\x15{"@class": "ai.search.SimilarityAggregator", '
+        b'"extractor": {"@class": "extractor.UniversalExtractor", "name": "foo", "params": null}, '
+        b'"algorithm": {"@class": "ai.distance.CosineSimilarity"}, '
+        b'"bruteForce": false, '
+        b'"filter": {"@class": "filter.EqualsFilter", '
+        b'"extractor": {"@class": "extractor.UniversalExtractor", '
+        b'"name": "foo", "params": null}, "value": "bar"}, "maxResults": 19, '
+        b'"vector": {"@class": "ai.Float32Vector", "array": [1.0, 2.0, 3.0]}}'
+    )
+
+    o = s.deserialize(ser)
+    assert isinstance(o, SimilaritySearch)
+    assert isinstance(o.extractor, ValueExtractor)
+    assert isinstance(o.algorithm, CosineDistance)
+    assert isinstance(o.filter, Filter)
+    assert o.maxResults == 19
+    assert isinstance(o.vector, FloatVector)
+
+    ss.bruteForce = True
+    ser = s.serialize(ss)
+    assert ser == (
+        b'\x15{"@class": "ai.search.SimilarityAggregator", '
+        b'"extractor": {"@class": "extractor.UniversalExtractor", "name": "foo", "params": null}, '
+        b'"algorithm": {"@class": "ai.distance.CosineSimilarity"}, '
"ai.distance.CosineSimilarity"}, ' + b'"bruteForce": true, ' + b'"filter": {"@class": "filter.EqualsFilter", ' + b'"extractor": {"@class": "extractor.UniversalExtractor", ' + b'"name": "foo", "params": null}, "value": "bar"}, "maxResults": 19, ' + b'"vector": {"@class": "ai.Float32Vector", "array": [1.0, 2.0, 3.0]}}' + ) + + +# noinspection PyUnresolvedReferences +def test_query_result_serialization() -> None: + bqr = QueryResult(3.0, 1, "abc") + ser = s.serialize(bqr) + assert ser == b'\x15{"@class": "ai.results.QueryResult", "distance": 3.0, "key": 1, "value": "abc"}' + + o = s.deserialize(ser) + assert isinstance(o, QueryResult) + assert o.distance == 3.0 + assert o.key == 1 + assert o.value == "abc" + + +# noinspection PyUnresolvedReferences +def test_binary_quant_index_serialization() -> None: + bqi = BinaryQuantIndex(Extractors.extract("foo")) + ser = s.serialize(bqi) + assert ser == ( + b'\x15{"@class": "ai.index.BinaryQuantIndex", "dataVersion": 0, ' + b'"binFuture": null, "extractor": {"@class": "extractor.UniversalExtractor", ' + b'"name": "foo", "params": null}, "oversamplingFactor": 3}' + ) + + o = s.deserialize(ser) + assert isinstance(o, BinaryQuantIndex)