Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Binh Vu committed Oct 8, 2024
1 parent b987453 commit 742c4e8
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 9 deletions.
15 changes: 11 additions & 4 deletions sand/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Literal
from typing import Literal, TypedDict

import serde.yaml
from sm.misc.funcs import import_attr
Expand Down Expand Up @@ -65,7 +65,7 @@ def from_yaml(infile: Path | str) -> AppConfig:
),
assistant=FnConfig(
default=obj["assistant"].pop("default"),
funcs=obj["assistant"],
funcs=AppConfig._parse_args(obj["assistant"], cwd),
),
export=FnConfig(default=obj["export"].pop("default"), funcs=obj["export"]),
)
Expand All @@ -86,6 +86,8 @@ def _parse_args(obj: dict, cwd: Path):
for k, v in obj.items():
if isinstance(v, str) and v.startswith(RELPATH_CONST):
out[k] = str(cwd / v[len(RELPATH_CONST) :])
elif isinstance(v, dict):
out[k] = AppConfig._parse_args(v, cwd)
else:
out[k] = v
return out
Expand All @@ -97,12 +99,17 @@ class SearchConfig:
ontology: str


class FnConstructor(TypedDict):
constructor: str
args: dict


@dataclass
class FnConfig:
default: str
funcs: dict[str, str]
funcs: dict[str, str | FnConstructor]

def get_func(self, name: Literal["default"] | str) -> str:
def get_func(self, name: Literal["default"] | str) -> str | FnConstructor:
if name == "default":
return self.funcs[self.default]
return self.funcs[name]
Expand Down
69 changes: 67 additions & 2 deletions sand/controllers/project.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import tempfile
from dataclasses import asdict
from pathlib import Path
from typing import List

import orjson
from flask import jsonify, request
from flask import jsonify, make_response, request
from gena import generate_api
from gena.deserializer import get_dataclass_deserializer
from sm.dataset import Dataset, Example, FullTable
from sm.prelude import I, M, O
from werkzeug.exceptions import BadRequest

from sand.controllers.helpers.upload import (
ALLOWED_EXTENSIONS,
CSVParserOpts,
Expand All @@ -14,8 +20,10 @@
parse_upload,
save_upload,
)
from sand.controllers.table import get_friendly_fs_name
from sand.models import Project
from werkzeug.exceptions import BadRequest
from sand.models.semantic_model import SemanticModel
from sand.models.table import Table, TableRow

project_bp = generate_api(Project)

Expand Down Expand Up @@ -127,3 +135,60 @@ def upload(id: int):
],
}
)


@project_bp.route(f"/{project_bp.name}/<id>/export", methods=["GET"])
def export(id: int):
"""Export tables from the project"""
try:
project = Project.get_by_id(id)
except:
raise BadRequest("Project not found")

examples = []
for tbl in Table.select().where(Table.project == project):
tblrows = list(TableRow.select().where(TableRow.table == tbl))
basetbl = I.ColumnBasedTable.from_rows(
records=[row.row for row in tblrows],
table_id=tbl.name,
headers=tbl.columns,
strict=True,
)
table = FullTable(
table=basetbl,
context=(
I.Context(
page_title=tbl.context_page.title,
page_url=tbl.context_page.url,
entities=(
[I.EntityId(tbl.context_page.entity, "")]
if tbl.context_page.entity is not None
else []
),
)
if tbl.context_page is not None
else I.Context()
),
links=M.Matrix.default(basetbl.shape(), list),
)
table.links = table.links.map_index(
lambda ri, ci: tblrows[ri].links.get(ci, [])
)

ex = Example(id=table.table.table_id, sms=[], table=table)

for sm in SemanticModel.select().where(SemanticModel.table == tbl):
ex.sms.append(sm.data)

examples.append(ex)

with tempfile.NamedTemporaryFile(suffix=".zip") as file:
Dataset(Path(file.name)).save(examples, table_fmt_indent=2)
dataset = Path(file.name).read_bytes()

resp = make_response(dataset)
resp.headers["Content-Type"] = "application/zip; charset=utf-8"
resp.headers["Content-Disposition"] = (
f"attachment; filename={get_friendly_fs_name(str(project.name))}.zip"
)
return resp
12 changes: 9 additions & 3 deletions sand/helpers/service_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from typing import Generic, Literal, TypeVar

from sm.misc.funcs import import_func

from sand.config import FnConfig
from sm.misc.funcs import import_func

T = TypeVar("T")

Expand All @@ -19,7 +18,14 @@ def get_default(self) -> T:

def get(self, name: Literal["default"] | str) -> T:
if name not in self.services:
self.services[name] = import_func(self.cfg.get_func(name))()
fnconstructor = self.cfg.get_func(name)
if isinstance(fnconstructor, str):
fn = import_func(fnconstructor)()
else:
fn = import_func(fnconstructor["constructor"])(
**fnconstructor.get("args", {})
)
self.services[name] = fn
return self.services[name]

def get_available_providers(self) -> list[str]:
Expand Down
8 changes: 8 additions & 0 deletions sand/models/table.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations

class Table:
id: int
name: str
description: str
columns: list[str]
project_id: int

0 comments on commit 742c4e8

Please sign in to comment.