|
1 | 1 | from typing import List, Dict
|
2 |
| -from ..abcs.database_types import Float, TemporalType, FractionalType, DbPath, TimestampTZ |
| 2 | +from ..abcs.database_types import ( |
| 3 | + Float, |
| 4 | + TemporalType, |
| 5 | + FractionalType, |
| 6 | + DbPath, |
| 7 | + TimestampTZ, |
| 8 | + RedShiftSuper |
| 9 | +) |
3 | 10 | from ..abcs.mixins import AbstractMixin_MD5
|
4 | 11 | from .postgresql import (
|
5 | 12 | PostgreSQL,
|
@@ -40,13 +47,18 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
|
40 | 47 | def normalize_number(self, value: str, coltype: FractionalType) -> str:
|
41 | 48 | return self.to_string(f"{value}::decimal(38,{coltype.precision})")
|
42 | 49 |
|
| 50 | + def normalize_json(self, value: str, _coltype: RedShiftSuper) -> str: |
| 51 | + return f'nvl2({value}, json_serialize({value}), NULL)' |
| 52 | + |
43 | 53 |
|
44 | 54 | class Dialect(PostgresqlDialect):
|
45 | 55 | name = "Redshift"
|
46 | 56 | TYPE_CLASSES = {
|
47 | 57 | **PostgresqlDialect.TYPE_CLASSES,
|
48 | 58 | "double": Float,
|
49 | 59 | "real": Float,
|
| 60 | + # JSON |
| 61 | + "super": RedShiftSuper |
50 | 62 | }
|
51 | 63 | SUPPORTS_INDEXES = False
|
52 | 64 |
|
@@ -109,11 +121,48 @@ def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]:
|
109 | 121 | assert len(d) == len(rows)
|
110 | 122 | return d
|
111 | 123 |
|
| 124 | + def select_view_columns(self, path: DbPath) -> str: |
| 125 | + _, schema, table = self._normalize_table_path(path) |
| 126 | + |
| 127 | + return ( |
| 128 | + """select * from pg_get_cols('{}.{}') |
| 129 | + cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int) |
| 130 | + """.format(schema, table) |
| 131 | + ) |
| 132 | + |
| 133 | + def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]: |
| 134 | + rows = self.query(self.select_view_columns(path), list) |
| 135 | + |
| 136 | + if not rows: |
| 137 | + raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns") |
| 138 | + |
| 139 | + output = {} |
| 140 | + for r in rows: |
| 141 | + col_name = r[2] |
| 142 | + type_info = r[3].split('(') |
| 143 | + base_type = type_info[0] |
| 144 | + precision = None |
| 145 | + scale = None |
| 146 | + |
| 147 | + if len(type_info) > 1: |
| 148 | + if base_type == 'numeric': |
| 149 | + precision, scale = type_info[1][:-1].split(',') |
| 150 | + precision = int(precision) |
| 151 | + scale = int(scale) |
| 152 | + |
| 153 | + out = [col_name, base_type, None, precision, scale] |
| 154 | + output[col_name] = tuple(out) |
| 155 | + |
| 156 | + return output |
| 157 | + |
112 | 158 | def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
|
113 | 159 | try:
|
114 | 160 | return super().query_table_schema(path)
|
115 | 161 | except RuntimeError:
|
116 |
| - return self.query_external_table_schema(path) |
| 162 | + try: |
| 163 | + return self.query_external_table_schema(path) |
| 164 | + except RuntimeError: |
| 165 | + return self.query_pg_get_cols() |
117 | 166 |
|
118 | 167 | def _normalize_table_path(self, path: DbPath) -> DbPath:
|
119 | 168 | if len(path) == 1:
|
|
0 commit comments