-
Notifications
You must be signed in to change notification settings - Fork 16
/
mappings.py
69 lines (48 loc) · 2.03 KB
/
mappings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import pathlib
from typing import Any, Dict, Type
from jinja2 import Template
from auto_labeling_pipeline.labels import ClassificationLabels, Labels, Seq2seqLabels, SequenceLabels
# Directory holding the bundled Jinja2 template files; resolved relative to
# this module so lookups do not depend on the current working directory.
TEMPLATE_DIR = pathlib.Path(__file__).parent.joinpath('templates')
class MappingTemplate:
    """Turn a raw response into a ``Labels`` collection via a Jinja2 template.

    The template text is rendered with the response bound to the name
    ``input``; the rendered output is parsed as JSON and passed to
    ``label_collection``.

    Subclasses typically set ``label_collection`` and ``template_file`` as
    class attributes. Alternatively both the collection class and the
    template text can be supplied to the constructor.
    """

    # Class used to wrap the parsed JSON in ``render``. The base class only
    # annotates this name (it assigns no value), so it must be provided by a
    # subclass or through the constructor.
    label_collection: Type[Labels]
    # File name looked up inside TEMPLATE_DIR. An empty string means the
    # template text comes from the constructor argument instead.
    template_file: str = ''

    def __init__(self, label_collection: Type[Labels] = Labels, template: str = ''):
        """Set up the template text and, optionally, the label collection.

        Args:
            label_collection: Class used to build the result of ``render``.
                Passing ``Labels`` itself (the default) means "not given";
                the class-level ``label_collection`` is then left in place.
            template: Jinja2 template text. Ignored when ``template_file``
                is non-empty, because the file content takes precedence.
        """
        # A configured template file always wins over the ``template`` argument.
        if self.template_file:
            template = self.load()
        self.template = template
        # ``Labels`` acts as the "not provided" marker, so only a different
        # class overrides the class-level attribute.
        if label_collection is not Labels:
            self.label_collection = label_collection

    def render(self, response: Dict) -> Labels:
        """Render ``response`` through the template and wrap it as labels.

        Args:
            response: Data exposed to the template under the name ``input``.

        Returns:
            ``self.label_collection`` called with the JSON-decoded output.

        Raises:
            json.JSONDecodeError: If the rendered text is not valid JSON.
            AttributeError: If no ``label_collection`` was set by a subclass
                or through the constructor.
        """
        rendered_json = Template(self.template).render(input=response)
        parsed = json.loads(rendered_json)
        return self.label_collection(parsed)

    def load(self) -> str:
        """Return the text of ``template_file`` located in ``TEMPLATE_DIR``.

        The file is decoded as UTF-8 explicitly so the result does not depend
        on the platform's default locale encoding.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        filepath = TEMPLATE_DIR / self.template_file
        with open(filepath, encoding='utf-8') as f:
            return f.read()

    def dict(self) -> Dict[str, Any]:
        """Return the template text and the label collection class as a dict.

        Returns:
            A mapping with the keys ``'template'`` and ``'collection'``.
        """
        return {
            'template': self.template,
            'collection': self.label_collection,
        }
class AmazonComprehendSentimentTemplate(MappingTemplate):
    """Preset that renders responses into ``ClassificationLabels``.

    The template text is read from ``amazon_comprehend_sentiment.jinja2``.
    Judging by the name it targets Amazon Comprehend sentiment responses;
    the expected response schema is not visible in this module.
    """

    template_file = 'amazon_comprehend_sentiment.jinja2'
    label_collection = ClassificationLabels
class GCPImageLabelDetectionTemplate(MappingTemplate):
    """Preset that renders responses into ``ClassificationLabels``.

    The template text is read from ``gcp_image_label_detection.jinja2``.
    Judging by the name it targets GCP image label-detection responses;
    the expected response schema is not visible in this module.
    """

    template_file = 'gcp_image_label_detection.jinja2'
    label_collection = ClassificationLabels
class AmazonComprehendEntityTemplate(MappingTemplate):
    """Preset that renders responses into ``SequenceLabels``.

    The template text is read from ``amazon_comprehend_entity.jinja2``.
    Judging by the name it targets Amazon Comprehend entity responses;
    the expected response schema is not visible in this module.
    """

    template_file = 'amazon_comprehend_entity.jinja2'
    label_collection = SequenceLabels
class GCPEntitiesTemplate(MappingTemplate):
    """Preset that renders responses into ``SequenceLabels``.

    The template text is read from ``gcp_entities.jinja2``. Judging by the
    name it targets GCP entity-analysis responses; the expected response
    schema is not visible in this module.
    """

    template_file = 'gcp_entities.jinja2'
    label_collection = SequenceLabels
class AmazonRekognitionLabelDetectionTemplate(MappingTemplate):
    """Preset that renders responses into ``ClassificationLabels``.

    The template text is read from
    ``amazon_rekognition_label_detection.jinja2``. Judging by the name it
    targets Amazon Rekognition label-detection responses; the expected
    response schema is not visible in this module.
    """

    template_file = 'amazon_rekognition_label_detection.jinja2'
    label_collection = ClassificationLabels
class GCPSpeechToTextTemplate(MappingTemplate):
    """Preset that renders responses into ``Seq2seqLabels``.

    The template text is read from ``gcp_speech_to_text.jinja2``. Judging by
    the name it targets GCP speech-to-text responses; the expected response
    schema is not visible in this module.
    """

    template_file = 'gcp_speech_to_text.jinja2'
    label_collection = Seq2seqLabels