diff --git a/ibis-server/app/mdl/substitute.py b/ibis-server/app/mdl/substitute.py index c91759424..4c20c6c36 100644 --- a/ibis-server/app/mdl/substitute.py +++ b/ibis-server/app/mdl/substitute.py @@ -1,17 +1,15 @@ -import base64 - -from orjson import orjson from sqlglot import exp, parse_one from sqlglot.optimizer.scope import build_scope from app.model import UnprocessableEntityError from app.model.data_source import DataSource +from app.util import base64_to_dict class ModelSubstitute: def __init__(self, data_source: DataSource, manifest_str: str): self.data_source = data_source - self.manifest = orjson.loads(base64.b64decode(manifest_str).decode("utf-8")) + self.manifest = base64_to_dict(manifest_str) self.model_dict = self._build_model_dict(self.manifest["models"]) def substitute(self, sql: str, write: str | None = None) -> str: diff --git a/ibis-server/app/model/validator.py b/ibis-server/app/model/validator.py index b4665196c..5a3c4e931 100644 --- a/ibis-server/app/model/validator.py +++ b/ibis-server/app/model/validator.py @@ -1,11 +1,9 @@ from __future__ import annotations -import base64 -import json - from app.mdl.rewriter import Rewriter from app.model import NotFoundError, UnprocessableEntityError from app.model.connector import Connector +from app.util import base64_to_dict rules = ["column_is_valid", "relationship_is_valid"] @@ -48,8 +46,8 @@ async def _validate_relationship_is_valid( relationship_name = parameters.get("relationshipName") if relationship_name is None: raise MissingRequiredParameterError("relationship") - decoded_manifest = base64.b64decode(manifest_str).decode("utf-8") - manifest = json.loads(decoded_manifest) + + manifest = base64_to_dict(manifest_str) relationship = list( filter(lambda r: r["name"] == relationship_name, manifest["relationships"]) diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index a23240559..b80616857 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -1,3 +1,4 @@ +import base64 import datetime import decimal @@ -6,6 +7,10 @@ from pandas.core.dtypes.common import is_datetime64_any_dtype +def base64_to_dict(base64_str: str) -> dict: + return orjson.loads(base64.b64decode(base64_str).decode("utf-8")) + + def to_json(df: pd.DataFrame) -> dict: for column in df.columns: if is_datetime64_any_dtype(df[column].dtype):