99# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010# See the License for the specific language governing permissions and
1111# limitations under the License.
12+ from typing import Any , Optional
13+
1214from sqlalchemy .sql import compiler
1315from sqlalchemy .sql .base import DialectKWArgs
1416
9294
9395
9496class TrinoSQLCompiler (compiler .SQLCompiler ):
95- def limit_clause (self , select , ** kw ) :
97+ def limit_clause (self , select : Any , ** kw : Any ) -> str :
9698 """
9799 Trino support only OFFSET...LIMIT but not LIMIT...OFFSET syntax.
98100 """
@@ -103,15 +105,15 @@ def limit_clause(self, select, **kw):
103105 text += "\n LIMIT " + self .process (select ._limit_clause , ** kw )
104106 return text
105107
106- def visit_table (self , table , asfrom = False , iscrud = False , ashint = False ,
107- fromhints = None , use_schema = True , ** kwargs ) :
108+ def visit_table (self , table : Any , asfrom : bool = False , iscrud : bool = False , ashint : bool = False ,
109+ fromhints : Optional [ Any ] = None , use_schema : bool = True , ** kwargs : Any ) -> str :
108110 sql = super (TrinoSQLCompiler , self ).visit_table (
109111 table , asfrom , iscrud , ashint , fromhints , use_schema , ** kwargs
110112 )
111113 return self .add_catalog (sql , table )
112114
113115 @staticmethod
114- def add_catalog (sql , table ) :
116+ def add_catalog (sql : str , table : Any ) -> str :
115117 if table is None or not isinstance (table , DialectKWArgs ):
116118 return sql
117119
@@ -131,7 +133,7 @@ class TrinoDDLCompiler(compiler.DDLCompiler):
131133
132134
133135class TrinoTypeCompiler (compiler .GenericTypeCompiler ):
134- def visit_FLOAT (self , type_ , ** kw ) :
136+ def visit_FLOAT (self , type_ : Any , ** kw : Any ) -> str :
135137 precision = type_ .precision or 32
136138 if 0 <= precision <= 32 :
137139 return self .visit_REAL (type_ , ** kw )
@@ -140,37 +142,37 @@ def visit_FLOAT(self, type_, **kw):
140142 else :
141143 raise ValueError (f"type.precision must be in range [0, 64], got { type_ .precision } " )
142144
143- def visit_DOUBLE (self , type_ , ** kw ) :
145+ def visit_DOUBLE (self , type_ : Any , ** kw : Any ) -> str :
144146 return "DOUBLE"
145147
146- def visit_NUMERIC (self , type_ , ** kw ) :
148+ def visit_NUMERIC (self , type_ : Any , ** kw : Any ) -> str :
147149 return self .visit_DECIMAL (type_ , ** kw )
148150
149- def visit_NCHAR (self , type_ , ** kw ) :
151+ def visit_NCHAR (self , type_ : Any , ** kw : Any ) -> str :
150152 return self .visit_CHAR (type_ , ** kw )
151153
152- def visit_NVARCHAR (self , type_ , ** kw ) :
154+ def visit_NVARCHAR (self , type_ : Any , ** kw : Any ) -> str :
153155 return self .visit_VARCHAR (type_ , ** kw )
154156
155- def visit_TEXT (self , type_ , ** kw ) :
157+ def visit_TEXT (self , type_ : Any , ** kw : Any ) -> str :
156158 return self .visit_VARCHAR (type_ , ** kw )
157159
158- def visit_BINARY (self , type_ , ** kw ) :
160+ def visit_BINARY (self , type_ : Any , ** kw : Any ) -> str :
159161 return self .visit_VARBINARY (type_ , ** kw )
160162
161- def visit_CLOB (self , type_ , ** kw ) :
163+ def visit_CLOB (self , type_ : Any , ** kw : Any ) -> str :
162164 return self .visit_VARCHAR (type_ , ** kw )
163165
164- def visit_NCLOB (self , type_ , ** kw ) :
166+ def visit_NCLOB (self , type_ : Any , ** kw : Any ) -> str :
165167 return self .visit_VARCHAR (type_ , ** kw )
166168
167- def visit_BLOB (self , type_ , ** kw ) :
169+ def visit_BLOB (self , type_ : Any , ** kw : Any ) -> str :
168170 return self .visit_VARBINARY (type_ , ** kw )
169171
170- def visit_DATETIME (self , type_ , ** kw ) :
172+ def visit_DATETIME (self , type_ : Any , ** kw : Any ) -> str :
171173 return self .visit_TIMESTAMP (type_ , ** kw )
172174
173- def visit_TIMESTAMP (self , type_ , ** kw ) :
175+ def visit_TIMESTAMP (self , type_ : Any , ** kw : Any ) -> str :
174176 datatype = "TIMESTAMP"
175177 precision = getattr (type_ , "precision" , None )
176178 if precision not in range (0 , 13 ) and precision is not None :
@@ -182,7 +184,7 @@ def visit_TIMESTAMP(self, type_, **kw):
182184
183185 return datatype
184186
185- def visit_TIME (self , type_ , ** kw ) :
187+ def visit_TIME (self , type_ : Any , ** kw : Any ) -> str :
186188 datatype = "TIME"
187189 precision = getattr (type_ , "precision" , None )
188190 if precision not in range (0 , 13 ) and precision is not None :
@@ -193,13 +195,13 @@ def visit_TIME(self, type_, **kw):
193195 datatype += " WITH TIME ZONE"
194196 return datatype
195197
196- def visit_JSON (self , type_ , ** kw ) :
198+ def visit_JSON (self , type_ : Any , ** kw : Any ) -> str :
197199 return 'JSON'
198200
199201
200202class TrinoIdentifierPreparer (compiler .IdentifierPreparer ):
201203 reserved_words = RESERVED_WORDS
202204
203- def format_table (self , table , use_schema = True , name = None ):
205+ def format_table (self , table : Any , use_schema : bool = True , name : Optional [ str ] = None ) -> str :
204206 result = super (TrinoIdentifierPreparer , self ).format_table (table , use_schema , name )
205207 return TrinoSQLCompiler .add_catalog (result , table )
0 commit comments