Skip to content

Commit

Permalink
Add dryRun parameter for query API (#593)
Browse files Browse the repository at this point in the history
* Add dryRun parameter for query API

And split ibis.py by data source

* Use http status code instead of json response

* Rename to Connector
  • Loading branch information
grieve54706 authored Jun 5, 2024
1 parent 4e24567 commit 2e08721
Show file tree
Hide file tree
Showing 13 changed files with 247 additions and 92 deletions.
24 changes: 24 additions & 0 deletions ibis-server/app/logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from functools import wraps

from app.config import get_config

Expand All @@ -7,3 +8,26 @@

def get_logger(name):
return logging.getLogger(name)


def log_dto(f):
logger = get_logger('app.routers.ibis')

@wraps(f)
def wrapper(*args, **kwargs):
logger.debug(f'DTO: {kwargs["dto"]}')
return f(*args, **kwargs)

return wrapper


def log_rewritten(f):
logger = get_logger('app.mdl.rewriter')

@wraps(f)
def wrapper(*args, **kwargs):
rs = f(*args, **kwargs)
logger.debug(f'Rewritten SQL: {rs}')
return rs

return wrapper
14 changes: 9 additions & 5 deletions ibis-server/app/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from fastapi import FastAPI
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.responses import RedirectResponse
from starlette.responses import PlainTextResponse

from app.config import get_config
from app.model.connector import QueryDryRunError
from app.routers import ibis

app = FastAPI()
Expand All @@ -23,9 +25,11 @@ def config():
return get_config()


@app.exception_handler(QueryDryRunError)
async def query_dry_run_error_handler(request, exc: QueryDryRunError):
return PlainTextResponse(str(exc), status_code=422)


@app.exception_handler(Exception)
async def exception_handler(request, exc: Exception):
return JSONResponse(
status_code=500,
content=str(exc),
)
return PlainTextResponse(str(exc), status_code=500)
2 changes: 1 addition & 1 deletion ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import orjson

from app.config import get_config
from app.util import log_rewritten
from app.logger import log_rewritten

wren_engine_endpoint = get_config().wren_engine_endpoint

Expand Down
33 changes: 33 additions & 0 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from json import loads

from app.mdl.rewriter import rewrite
from app.model.data_source import DataSource, ConnectionInfo


class Connector:
def __init__(self, data_source: DataSource, connection_info: ConnectionInfo, manifest_str: str):
self.data_source = data_source
self.connection = self.data_source.get_connection(connection_info)
self.manifest_str = manifest_str

def query(self, sql) -> dict:
rewritten_sql = rewrite(self.manifest_str, sql)
return self._to_json(self.connection.sql(rewritten_sql, dialect='trino').to_pandas())

def dry_run(self, sql) -> None:
try:
rewritten_sql = rewrite(self.manifest_str, sql)
self.connection.sql(rewritten_sql, dialect='trino')
except Exception as e:
raise QueryDryRunError(f'Exception: {type(e)}, message: {str(e)}')

@staticmethod
def _to_json(df):
json_obj = loads(df.to_json(orient='split'))
del json_obj['index']
json_obj['dtypes'] = df.dtypes.apply(lambda x: x.name).to_dict()
return json_obj


class QueryDryRunError(Exception):
pass
33 changes: 0 additions & 33 deletions ibis-server/app/routers/ibis.py

This file was deleted.

13 changes: 13 additions & 0 deletions ibis-server/app/routers/ibis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from fastapi import APIRouter

import app.routers.ibis.bigquery as bigquery
import app.routers.ibis.postgres as postgres
import app.routers.ibis.snowflake as snowflake

prefix = "/v2/ibis"

router = APIRouter(prefix=prefix)

router.include_router(bigquery.router)
router.include_router(postgres.router)
router.include_router(snowflake.router)
23 changes: 23 additions & 0 deletions ibis-server/app/routers/ibis/bigquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Annotated

from fastapi import APIRouter, Query, Response
from fastapi.responses import JSONResponse

from app.logger import log_dto
from app.model.connector import Connector
from app.model.data_source import DataSource
from app.model.dto import BigQueryDTO

router = APIRouter(prefix='/bigquery', tags=['bigquery'])

data_source = DataSource.bigquery


@router.post("/query")
@log_dto
def query(dto: BigQueryDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response:
connector = Connector(data_source, dto.connection_info, dto.manifest_str)
if dry_run:
connector.dry_run(dto.sql)
return Response(status_code=204)
return JSONResponse(connector.query(dto.sql))
23 changes: 23 additions & 0 deletions ibis-server/app/routers/ibis/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Annotated

from fastapi import APIRouter, Query, Response
from fastapi.responses import JSONResponse

from app.logger import log_dto
from app.model.connector import Connector
from app.model.data_source import DataSource
from app.model.dto import PostgresDTO

router = APIRouter(prefix='/postgres', tags=['postgres'])

data_source = DataSource.postgres


@router.post("/query")
@log_dto
def query(dto: PostgresDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response:
connector = Connector(data_source, dto.connection_info, dto.manifest_str)
if dry_run:
connector.dry_run(dto.sql)
return Response(status_code=204)
return JSONResponse(connector.query(dto.sql))
23 changes: 23 additions & 0 deletions ibis-server/app/routers/ibis/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Annotated

from fastapi import APIRouter, Query, Response
from fastapi.responses import JSONResponse

from app.logger import log_dto
from app.model.connector import Connector
from app.model.data_source import DataSource
from app.model.dto import SnowflakeDTO

router = APIRouter(prefix='/snowflake', tags=['snowflake'])

data_source = DataSource.snowflake


@router.post("/query")
@log_dto
def query(dto: SnowflakeDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> Response:
connector = Connector(data_source, dto.connection_info, dto.manifest_str)
if dry_run:
connector.dry_run(dto.sql)
return Response(status_code=204)
return JSONResponse(connector.query(dto.sql))
36 changes: 0 additions & 36 deletions ibis-server/app/util.py

This file was deleted.

39 changes: 33 additions & 6 deletions ibis-server/tests/routers/ibis/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_connection_info():
"credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON")
}

def test_bigquery(self, manifest_str: str):
def test_query(self, manifest_str: str):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
Expand All @@ -67,13 +67,13 @@ def test_bigquery(self, manifest_str: str):
assert result['data'][0][0] is not None
assert result['dtypes'] is not None

def test_no_manifest(self):
def test_query_without_manifest(self):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
json={
"connectionInfo": connection_info,
"sql": "SELECT * FROM Orders LIMIT 1"
"sql": 'SELECT * FROM "Orders" LIMIT 1'
}
)
assert response.status_code == 422
Expand All @@ -83,7 +83,7 @@ def test_no_manifest(self):
assert result['detail'][0]['loc'] == ['body', 'manifestStr']
assert result['detail'][0]['msg'] == 'Field required'

def test_no_sql(self, manifest_str: str):
def test_query_without_sql(self, manifest_str: str):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
Expand All @@ -99,12 +99,12 @@ def test_no_sql(self, manifest_str: str):
assert result['detail'][0]['loc'] == ['body', 'sql']
assert result['detail'][0]['msg'] == 'Field required'

def test_no_connection_info(self, manifest_str: str):
def test_query_without_connection_info(self, manifest_str: str):
response = client.post(
url="/v2/ibis/bigquery/query",
json={
"manifestStr": manifest_str,
"sql": "SELECT * FROM Orders LIMIT 1"
"sql": 'SELECT * FROM "Orders" LIMIT 1'
}
)
assert response.status_code == 422
Expand All @@ -113,3 +113,30 @@ def test_no_connection_info(self, manifest_str: str):
assert result['detail'][0]['type'] == 'missing'
assert result['detail'][0]['loc'] == ['body', 'connectionInfo']
assert result['detail'][0]['msg'] == 'Field required'

def test_query_with_dry_run(self, manifest_str: str):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
params={"dryRun": True},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1'
}
)
assert response.status_code == 204

def test_query_with_dry_run_and_invalid_sql(self, manifest_str: str):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
params={"dryRun": True},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM X'
}
)
assert response.status_code == 422
assert response.text is not None
Loading

0 comments on commit 2e08721

Please sign in to comment.