-
Notifications
You must be signed in to change notification settings - Fork 59
/
Copy pathcolumn.py
109 lines (94 loc) · 3.81 KB
/
column.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import re
from dataclasses import dataclass
from typing import ClassVar, Dict
from dbt.adapters.base.column import Column
from dbt_common.exceptions import DbtRuntimeError
# Mirrors the MAX_LENGTH constant of Trino's VarcharType (Java Integer.MAX_VALUE - 1):
# https://github.com/trinodb/trino/blob/master/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java
TRINO_VARCHAR_MAX_LENGTH = 2_147_483_646
@dataclass
class TrinoColumn(Column):
    """dbt ``Column`` specialised for Trino type names and sizing rules."""

    TYPE_LABELS: ClassVar[Dict[str, str]] = {
        "STRING": "VARCHAR",
        "FLOAT": "DOUBLE",
    }

    @property
    def data_type(self):
        """Return the SQL type string used when (re)declaring this column.

        The inherited path renders string-like columns via ``string_type()``,
        which always produces ``varchar(n)`` (``varchar(256)`` when no size is
        known). That is only correct for ``varchar`` itself, so:

        * an unsized ``varchar`` stays unbounded (plain ``dtype``);
        * ``char`` keeps its own type name and optional length;
        * ``varbinary`` and ``json`` are returned as-is instead of being
          silently turned into text columns.

        All other types defer to the base implementation.
        """
        dtype = self.dtype.lower()
        if dtype == "varchar" and self.char_size is None:
            return self.dtype
        if dtype == "char" and self.char_size is not None:
            return "{}({})".format(self.dtype, self.char_size)
        if dtype in ("char", "varbinary", "json"):
            return self.dtype
        return super().data_type

    def is_string(self) -> bool:
        """Return True for Trino types treated as string-like by dbt."""
        return self.dtype.lower() in ["varchar", "char", "varbinary", "json"]

    def is_float(self) -> bool:
        """Return True for Trino floating-point types."""
        return self.dtype.lower() in [
            "real",
            "double precision",
            "double",
        ]

    def is_integer(self) -> bool:
        """Return True for Trino integral types (``int`` accepted as alias)."""
        return self.dtype.lower() in [
            "tinyint",
            "smallint",
            "integer",
            "int",
            "bigint",
        ]

    def is_numeric(self) -> bool:
        """Return True for the fixed-precision ``decimal`` type."""
        return self.dtype.lower() == "decimal"

    @classmethod
    def string_type(cls, size: int) -> str:
        """Return the Trino type string for a varchar of ``size`` characters."""
        return "varchar({})".format(size)

    def string_size(self) -> int:
        """Return the declared string length.

        An unbounded ``varchar`` (no ``char_size``) reports
        ``TRINO_VARCHAR_MAX_LENGTH``; everything else defers to the base class.
        """
        if self.dtype.lower() == "varchar" and self.char_size is None:
            return TRINO_VARCHAR_MAX_LENGTH
        return super().string_size()

    @staticmethod
    def _parse_size_part(raw_data_type: str, part: str) -> int:
        """Parse one comma-separated size parameter of ``raw_data_type``."""
        try:
            return int(part)
        except ValueError as exc:
            raise DbtRuntimeError(
                f'Could not interpret data_type "{raw_data_type}": '
                f'could not convert "{part}" to an integer'
            ) from exc

    @classmethod
    def from_description(cls, name: str, raw_data_type: str) -> "Column":
        """Build a column from a Trino type description such as ``varchar(10)``.

        Only types starting with ``varchar``, ``char`` or ``decimal`` have their
        parenthesised size parameters parsed; any other type string is stored
        verbatim as ``dtype``.

        :param name: column name.
        :param raw_data_type: the type string as reported by Trino.
        :returns: a new column instance of this class.
        :raises DbtRuntimeError: if the type or one of its size parameters
            cannot be interpreted.
        """
        # Most Trino data types specify a type and not a precision/scale/charsize.
        if not raw_data_type.lower().startswith(("varchar", "char", "decimal")):
            return cls(name, raw_data_type)

        match = re.match(
            r"(?P<type>[^(]+)(?P<size>\([^)]+\))?(?P<type_suffix>[\w ]+)?", raw_data_type
        )
        if match is None:
            raise DbtRuntimeError(f'Could not interpret data type "{raw_data_type}"')

        data_type = match.group("type")
        data_type_suffix = match.group("type_suffix")
        if data_type_suffix:
            data_type += data_type_suffix

        char_size = None
        numeric_precision = None
        numeric_scale = None
        size_info = match.group("size")
        if size_info is not None:
            # Strip the surrounding parentheses before splitting the parameters.
            parts = size_info[1:-1].split(",")
            if len(parts) == 1:
                value = cls._parse_size_part(raw_data_type, parts[0])
                if data_type.lower() == "decimal":
                    # Trino's DECIMAL(p) means DECIMAL(p, 0): keep it numeric so the
                    # precision is not lost (it is not a character length).
                    numeric_precision = value
                    numeric_scale = 0
                else:
                    char_size = value
            elif len(parts) == 2:
                numeric_precision = cls._parse_size_part(raw_data_type, parts[0])
                numeric_scale = cls._parse_size_part(raw_data_type, parts[1])
        return cls(name, data_type, char_size, numeric_precision, numeric_scale)