Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable SSL on Redis connection based on env config (to enable AWS lambda connectivity) #209

Merged
merged 4 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion development/stream_interface/inference_pipeline_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def command_thread(pipeline: InferencePipeline, watchdog: PipelineWatchDog) -> N
help=f"Flag to decide if output to be streamed or displayed on screen",
required=False,
type=str,
default="screen",
default="display",
)
args = parser.parse_args()
main(
Expand Down
2 changes: 2 additions & 0 deletions docker/dockerfiles/Dockerfile.onnx.lambda
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ ENV ALLOW_NUMPY_INPUT=False
ENV INFERENCE_SERVER_ID=HostedInferenceLambda
ENV DISABLE_VERSION_CHECK=true
ENV DOCTR_MULTIPROCESSING_DISABLE=TRUE
ENV REDIS_SSL=true

WORKDIR ${LAMBDA_TASK_ROOT}

CMD [ "lambda.handler" ]
2 changes: 2 additions & 0 deletions docker/dockerfiles/Dockerfile.onnx.lambda.slim
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ ENV LAMBDA=True
ENV ALLOW_NUMPY_INPUT=False
ENV INFERENCE_SERVER_ID=HostedInferenceLambda
ENV DISABLE_VERSION_CHECK=true
ENV REDIS_SSL=true

WORKDIR ${LAMBDA_TASK_ROOT}

CMD [ "lambda.handler" ]
15 changes: 13 additions & 2 deletions inference/core/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from redis.exceptions import ConnectionError, TimeoutError

from inference.core import logger
from inference.core.cache.memory import MemoryCache
from inference.core.cache.redis import RedisCache
from inference.core.env import REDIS_HOST, REDIS_PORT, REDIS_SSL, REDIS_TIMEOUT

# Module-level cache singleton: use Redis when REDIS_HOST is configured,
# falling back to an in-process MemoryCache when Redis is unreachable or
# no host is configured at all.
if REDIS_HOST is not None:
    try:
        cache = RedisCache(
            host=REDIS_HOST, port=REDIS_PORT, ssl=REDIS_SSL, timeout=REDIS_TIMEOUT
        )
    except (TimeoutError, ConnectionError):
        # RedisCache raises on an unreachable server (per this PR it pings
        # at construction time); degrade gracefully to the in-memory cache
        # instead of failing process startup.
        logger.error(
            f"Could not connect to Redis under {REDIS_HOST}:{REDIS_PORT}. MemoryCache to be used."
        )
        cache = MemoryCache()
else:
    cache = MemoryCache()
20 changes: 18 additions & 2 deletions inference/core/cache/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,27 @@ class RedisCache(BaseCache):
_expire_thread (threading.Thread): A thread that runs the _expire method.
"""

def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0) -> None:
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
ssl: bool = False,
timeout: float = 2.0,
) -> None:
"""
Initializes a new instance of the MemoryCache class.
"""
self.client = redis.Redis(host=host, port=port, db=db, decode_responses=True)
self.client = redis.Redis(
host=host,
port=port,
db=db,
decode_responses=True,
ssl=ssl,
socket_timeout=timeout,
socket_connect_timeout=timeout,
)
self.client.ping()

self.zexpires = dict()

Expand Down
2 changes: 2 additions & 0 deletions inference/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@

# Redis port, default is 6379
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
# Whether to connect to Redis over TLS (needed e.g. for AWS ElastiCache
# from Lambda); default is disabled.
# NOTE(review): the default handed to os.getenv is the bool False, not a
# string — os.getenv returns it unchanged when the variable is unset, so
# str2bool must tolerate a bool argument; confirm against str2bool's
# implementation.
REDIS_SSL = str2bool(os.getenv("REDIS_SSL", False))
# Socket connect/read timeout in seconds for the Redis client; default 2.0.
# When REDIS_TIMEOUT is unset, os.getenv returns the float 2.0 and
# float(2.0) is a no-op; when set, float(<str>) parses the value.
REDIS_TIMEOUT = float(os.getenv("REDIS_TIMEOUT", 2.0))

# Required ONNX providers, default is None
REQUIRED_ONNX_PROVIDERS = safe_split_value(os.getenv("REQUIRED_ONNX_PROVIDERS", None))
Expand Down
Loading