Skip to content

Commit

Permalink
more changes
Browse files Browse the repository at this point in the history
  • Loading branch information
drf7 committed Dec 20, 2024
1 parent dd48d67 commit 2cc60e0
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 23 deletions.
30 changes: 17 additions & 13 deletions nuclia/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

from nuclia.exceptions import KBNotAvailable
from nuclia.lib.kb import AsyncNucliaDBClient, Environment, NucliaDBClient
from nuclia.config import read_config
from nuclia.sdk.auth import NucliaAuth
from nuclia.sdk.auth import AsyncNucliaAuth

if TYPE_CHECKING:
from nuclia.config import Config
Expand All @@ -26,33 +29,34 @@ def set_config(config: Config):


def get_config(config_path: Optional[str] = None) -> Config:
if DATA.config is None:
from nuclia.config import read_config

if DATA.config is None or DATA.config.filepath != config_path:
DATA.config = read_config(config_path=config_path)
return DATA.config


def get_auth() -> NucliaAuth:
get_config()
if DATA.auth is None:
from nuclia.sdk.auth import NucliaAuth
def get_auth(config_path: Optional[str] = None) -> NucliaAuth:
get_config(config_path=config_path)
if config_path is not None:
DATA.auth = NucliaAuth(config_path=config_path)
return DATA.auth

if DATA.auth is None:
DATA.auth = NucliaAuth()
return DATA.auth


def get_async_auth() -> AsyncNucliaAuth:
get_config()
if DATA.async_auth is None:
from nuclia.sdk.auth import AsyncNucliaAuth
def get_async_auth(config_path: Optional[str] = None) -> AsyncNucliaAuth:
get_config(config_path=config_path)
if config_path is not None:
return AsyncNucliaAuth(config_path=config_path)

if DATA.async_auth is None:
DATA.async_auth = AsyncNucliaAuth()
return DATA.async_auth


def get_client(kbid: str) -> NucliaDBClient:
auth = get_auth()
def get_client(kbid: str, config_path: Optional[str] = None) -> NucliaDBClient:
auth = get_auth(config_path=config_path)
kb_obj = auth._config.get_kb(kbid)

if kb_obj is None:
Expand Down
18 changes: 9 additions & 9 deletions nuclia/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def async_wrapper_checkout_kb(*args, **kwargs):
return await func(*args, **kwargs)
url = kwargs.get("url")
api_key = kwargs.get("api_key")
auth = get_async_auth()
auth = get_async_auth(config_path=kwargs.get("config_path"))
if url is None:
# Get default KB
kbid = auth._config.get_default_kb()
Expand All @@ -68,13 +68,13 @@ def wrapper_checkout_kb(*args, **kwargs):
return func(*args, **kwargs)
url = kwargs.get("url")
api_key = kwargs.get("api_key")
auth = get_auth()
auth = get_auth(config_path=kwargs.get("config_path"))
if url is None:
# Get default KB
kbid = auth._config.get_default_kb()
if kbid is None:
raise NotDefinedDefault()
ndb = get_client(kbid)
ndb = get_client(kbid, config_path=kwargs.get("config_path"))
elif url.find(BASE_DOMAIN) >= 0:
region = url.split(".")[0].split("/")[-1]
ndb = NucliaDBClient(
Expand All @@ -101,7 +101,7 @@ def wrapper_checkout_nucliadb(*args, **kwargs):
if "ndb" in kwargs:
return func(*args, **kwargs)
url = kwargs.get("url")
auth = get_auth()
auth = get_auth(config_path=kwargs.get("config_path"))
if url is None:
# Get default KB
nucliadb = auth._config.get_default_nucliadb()
Expand All @@ -122,7 +122,7 @@ def wrapper_checkout_nucliadb(*args, **kwargs):
def nua(func):
@wraps(func)
async def async_wrapper_checkout_nua(*args, **kwargs):
auth = get_auth()
auth = get_auth(config_path=kwargs.get("config_path"))
nua_id = auth._config.get_default_nua()
if nua_id is None:
raise NotDefinedDefault()
Expand All @@ -137,7 +137,7 @@ async def async_wrapper_checkout_nua(*args, **kwargs):

@wraps(func)
async def async_generative_wrapper_checkout_nua(*args, **kwargs):
auth = get_auth()
auth = get_auth(config_path=kwargs.get("config_path"))
nua_id = auth._config.get_default_nua()
if nua_id is None:
raise NotDefinedDefault()
Expand All @@ -153,7 +153,7 @@ async def async_generative_wrapper_checkout_nua(*args, **kwargs):

@wraps(func)
def wrapper_checkout_nua(*args, **kwargs):
auth = get_auth()
auth = get_auth(config_path=kwargs.get("config_path"))
nua_id = auth._config.get_default_nua()
if nua_id is None:
raise NotDefinedDefault()
Expand All @@ -179,7 +179,7 @@ def account(func):
def wrapper(*args, **kwargs):
account_slug = kwargs.get("account")
account_id = kwargs.get("account_id")
auth = get_auth()
auth = get_auth(config_path=kwargs.get("config_path"))
if account_id is None and account_slug is None:
account_slug = auth._config.get_default_account()
if account_slug is None:
Expand Down Expand Up @@ -223,7 +223,7 @@ def zone(func):
def wrapper_checkout_zone(*args, **kwargs):
zone = kwargs.get("zone")
if not zone:
auth = get_auth()
auth = get_auth(config_path=kwargs.get("config_path"))
kwargs["zone"] = auth._config.get_default_zone()
return func(*args, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion nuclia/sdk/kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def copy(
slug: Optional[str] = None,
destination: str,
override: Optional[bool] = False,
config_path: Optional[str] = None,
**kwargs,
):
ndb = kwargs["ndb"]
Expand Down Expand Up @@ -248,7 +249,7 @@ def copy(
remote_files[file_id] = file.value
else:
files_to_upload.append({"id": file_id, "data": file.value})
destination_kb = get_client(destination)
destination_kb = get_client(destination, config_path=config_path)
if override:
try:
self.resource.delete(
Expand Down

0 comments on commit 2cc60e0

Please sign in to comment.