Skip to content

Commit

Permalink
feat: optionally check if all env vars match
Browse files Browse the repository at this point in the history
From EDA credentials we might pass in env vars for source
plugins, the generic source plugin can optionally have
a dictionary of env vars that it can check
e.g.
```
   check_env_vars:
       ENV_V1: value1
       ENV_V2: value2
```
  • Loading branch information
mkanoor committed Nov 14, 2024
1 parent 2d5edfa commit 03d2eec
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 1 deletion.
35 changes: 34 additions & 1 deletion extensions/eda/plugins/event_source/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
final payload which can be used to trigger a shutdown of
the rulebook, especially when we are using rulebooks to
forward messages to other running rulebooks.
check_env_vars dict Optionally check if all the defined env vars are set
before generating the events. If any of the env_var is missing
or the value doesn't match the source plugin will end
with an exception
"""
Expand All @@ -53,16 +57,35 @@
from __future__ import annotations

import asyncio
import os
import random
import time
from dataclasses import dataclass, fields
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Any, Dict, Optional

import yaml


class MissingEnvVarError(Exception):
"""Exception class for missing env var."""

def __init__(self: "MissingEnvVarError", env_var: str) -> None:
"""Class constructor with the missing env_var."""
super().__init__(f"Env Var {env_var} is required")


class EnvVarMismatchError(Exception):
"""Exception class for mismatch in the env var value."""

def __init__(
self: "EnvVarMismatchError", env_var: str, value: str, expected: str
) -> None:
"""Class constructor with mismatch in env_var value."""
super().__init__(f"Env Var {env_var} expected: {expected} passed in: {value}")


@dataclass
class Args:
"""Class to store all the passed in args."""
Expand All @@ -84,6 +107,7 @@ class ControlArgs:
loop_count: int = 1
repeat_count: int = 1
timestamp: bool = False
check_env_vars: Optional[Dict[str, str]] = None


@dataclass
Expand Down Expand Up @@ -135,6 +159,7 @@ async def __call__(self: Generic) -> None:
msg = "time_format must be one of local, iso8601, epoch"
raise ValueError(msg)

await self._check_env_vars()
await self._load_payload_from_file()

if not isinstance(self.my_args.payload, list):
Expand Down Expand Up @@ -174,6 +199,14 @@ async def _post_event(self: Generic, event: dict[str, Any], index: int) -> None:
print(data) # noqa: T201
await self.queue.put(data)

async def _check_env_vars(self: Generic) -> None:
if self.control_args.check_env_vars:
for key, value in self.control_args.check_env_vars.items():
if key not in os.environ:
raise MissingEnvVarError(key)
if os.environ[key] != value:
raise EnvVarMismatchError(key, os.environ[key], value)

async def _load_payload_from_file(self: Generic) -> None:
if not self.my_args.payload_file:
return
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/event_source/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import pytest
import yaml

from extensions.eda.plugins.event_source.generic import (
EnvVarMismatchError,
MissingEnvVarError,
)
from extensions.eda.plugins.event_source.generic import main as generic_main


Expand Down Expand Up @@ -243,3 +247,57 @@ def test_generic_parsing_payload_file() -> None:
},
)
)


def test_env_vars_missing() -> None:
"""Test missing env vars"""
myqueue = _MockQueue()
event = {"name": "fred"}

with pytest.raises(MissingEnvVarError):
asyncio.run(
generic_main(
myqueue,
{
"payload": event,
"check_env_vars": {"NAME_MISSING": "Fred"},
},
)
)


def test_env_vars_mismatch() -> None:
"""Test env vars with incorrect values"""
myqueue = _MockQueue()
event = {"name": "fred"}

os.environ["TEST_ENV1"] = "Kaboom"
with pytest.raises(EnvVarMismatchError):
asyncio.run(
generic_main(
myqueue,
{
"payload": event,
"check_env_vars": {"TEST_ENV1": "Fred"},
},
)
)


def test_env_vars() -> None:
"""Test env vars with correct values"""
myqueue = _MockQueue()
event = {"name": "fred"}

os.environ["TEST_ENV1"] = "Fred"
asyncio.run(
generic_main(
myqueue,
{
"payload": event,
"check_env_vars": {"TEST_ENV1": "Fred"},
},
)
)
assert len(myqueue.queue) == 1
assert myqueue.queue[0] == {"name": "fred"}

0 comments on commit 03d2eec

Please sign in to comment.