-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdatamodels.py
97 lines (71 loc) · 1.88 KB
/
datamodels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import base64
from dataclasses import field
from llmx import TextGenerationConfig
from pydantic.dataclasses import dataclass
from typing import Optional, List, Any, Union, Dict
@dataclass
class Goal:
    """A visualization goal: a question, the visualization that answers it, and why."""

    question: str
    visualization: str
    rationale: str
    # Position of the goal; shown in the Markdown heading.
    index: Optional[int] = 0

    def _repr_markdown_(self):
        """Render the goal as Markdown for Jupyter/IPython rich display."""
        # Blank lines separate the entries: Markdown joins consecutive non-blank
        # lines into one paragraph, which would put all three on a single line.
        return (
            f"\n### Goal {self.index}\n"
            "---\n"
            f"**Question:** {self.question}\n\n"
            f"**Visualization:** `{self.visualization}`\n\n"
            f"**Rationale:** {self.rationale}\n"
        )
@dataclass
class Persona:
    """A persona description together with the rationale for it."""

    persona: str
    rationale: str

    def _repr_markdown_(self):
        """Render the persona as Markdown for Jupyter/IPython rich display."""
        # Blank line keeps Persona and Rationale as separate paragraphs instead
        # of Markdown merging them onto one line.
        return (
            "\n### Persona\n"
            "---\n"
            f"**Persona:** {self.persona}\n\n"
            f"**Rationale:** {self.rationale}\n"
        )
@dataclass
class Summary:
    """Summary of a dataset: its name, source file, description and fields."""

    name: str
    file_name: str
    dataset_description: str
    field_names: List[Any]
    # Optional per-field details, paired positionally with ``field_names``.
    fields: Optional[List[Any]] = None

    def _repr_markdown_(self):
        """Render the summary as Markdown for Jupyter/IPython rich display.

        When ``fields`` is None, only the field names are listed.
        """
        if self.fields is None:
            # ``zip(..., None)`` would raise TypeError; list the bare names instead.
            field_lines = "\n".join(f"- **{name}**" for name in self.field_names)
        else:
            # zip() stops at the shorter of the two lists.
            field_lines = "\n".join(
                f"- **{name}:** {info}"
                for name, info in zip(self.field_names, self.fields)
            )
        # Blank lines keep each entry in its own Markdown paragraph.
        return (
            "\n## Dataset Summary\n"
            "---\n"
            f"**Name:** {self.name}\n\n"
            f"**File Name:** {self.file_name}\n\n"
            "**Dataset Description:**\n\n"
            f"{self.dataset_description}\n\n"
            "**Fields:**\n\n"
            f"{field_lines}\n"
        )
@dataclass
class ChartExecutorResponse:
    """Outcome of a chart execution: its source code, library and rendered outputs."""

    # Published under the Vega-Lite v5 MIME type when not None.
    spec: Optional[Union[str, Dict]]
    status: bool
    # Base64-encoded PNG (decoded by ``savefig``); None when no image exists.
    raster: Optional[str]
    code: str
    library: str
    error: Optional[Dict] = None

    def _repr_mimebundle_(self, include=None, exclude=None):
        """Return the IPython MIME bundle for this chart.

        Always contains the code as ``text/plain``. Adds ``image/png`` when a
        raster exists and the Vega-Lite v5 type when a spec exists. ``include``
        and ``exclude`` are optional collections of MIME types used to filter
        the bundle, matching IPython's display protocol.
        """
        bundle = {"text/plain": self.code}
        if self.raster is not None:
            bundle["image/png"] = self.raster
        if self.spec is not None:
            bundle["application/vnd.vegalite.v5+json"] = self.spec
        if include:
            bundle = {k: v for k, v in bundle.items() if k in include}
        if exclude:
            bundle = {k: v for k, v in bundle.items() if k not in exclude}
        return bundle

    # Backward-compatible alias for the earlier (misspelled) method name, which
    # IPython never called because the protocol requires the trailing underscore.
    _repr_mimebundle = _repr_mimebundle_

    def savefig(self, path):
        """Decode the base64 raster and write it as binary data to ``path``.

        Raises:
            FileNotFoundError: if there is no raster image (None or empty).
        """
        if not self.raster:
            raise FileNotFoundError("No raster image to save")
        with open(path, "wb") as f:
            f.write(base64.b64decode(self.raster))