Skip to content

Commit c94d6fc

Browse files
committed
Ignore custom data codec for internal introspection
Fixes: MagicStack#617
1 parent 68b40cb commit c94d6fc

9 files changed

+82
-32
lines changed

asyncpg/connection.py

+33-12
Original file line numberDiff line numberDiff line change
@@ -342,13 +342,16 @@ async def _get_statement(
342342
*,
343343
named: bool=False,
344344
use_cache: bool=True,
345+
ignore_custom_codec=False,
345346
record_class=None
346347
):
347348
if record_class is None:
348349
record_class = self._protocol.get_record_class()
349350

350351
if use_cache:
351-
statement = self._stmt_cache.get((query, record_class))
352+
statement = self._stmt_cache.get(
353+
(query, record_class, ignore_custom_codec)
354+
)
352355
if statement is not None:
353356
return statement
354357

@@ -371,6 +374,7 @@ async def _get_statement(
371374
query,
372375
timeout,
373376
record_class=record_class,
377+
ignore_custom_codec=ignore_custom_codec,
374378
)
375379
need_reprepare = False
376380
types_with_missing_codecs = statement._init_types()
@@ -415,7 +419,8 @@ async def _get_statement(
415419
)
416420

417421
if use_cache:
418-
self._stmt_cache.put((query, record_class), statement)
422+
self._stmt_cache.put(
423+
(query, record_class, ignore_custom_codec), statement)
419424

420425
# If we've just created a new statement object, check if there
421426
# are any statements for GC.
@@ -426,7 +431,12 @@ async def _get_statement(
426431

427432
async def _introspect_types(self, typeoids, timeout):
428433
return await self.__execute(
429-
self._intro_query, (list(typeoids),), 0, timeout)
434+
self._intro_query,
435+
(list(typeoids),),
436+
0,
437+
timeout,
438+
ignore_custom_codec=True,
439+
)
430440

431441
async def _introspect_type(self, typename, schema):
432442
if (
@@ -439,20 +449,22 @@ async def _introspect_type(self, typename, schema):
439449
[typeoid],
440450
limit=0,
441451
timeout=None,
452+
ignore_custom_codec=True,
442453
)
443-
if rows:
444-
typeinfo = rows[0]
445-
else:
446-
typeinfo = None
447454
else:
448-
typeinfo = await self.fetchrow(
449-
introspection.TYPE_BY_NAME, typename, schema)
455+
rows = await self._execute(
456+
introspection.TYPE_BY_NAME,
457+
[typename, schema],
458+
limit=1,
459+
timeout=None,
460+
ignore_custom_codec=True,
461+
)
450462

451-
if not typeinfo:
463+
if not rows:
452464
raise ValueError(
453465
'unknown type: {}.{}'.format(schema, typename))
454466

455-
return typeinfo
467+
return rows[0]
456468

457469
def cursor(
458470
self,
@@ -1325,7 +1337,9 @@ def _mark_stmts_as_closed(self):
13251337
def _maybe_gc_stmt(self, stmt):
13261338
if (
13271339
stmt.refs == 0
1328-
and not self._stmt_cache.has((stmt.query, stmt.record_class))
1340+
and not self._stmt_cache.has(
1341+
(stmt.query, stmt.record_class, stmt.ignore_custom_codec)
1342+
)
13291343
):
13301344
# If low-level `stmt` isn't referenced from any high-level
13311345
# `PreparedStatement` object and is not in the `_stmt_cache`:
@@ -1589,6 +1603,7 @@ async def _execute(
15891603
timeout,
15901604
*,
15911605
return_status=False,
1606+
ignore_custom_codec=False,
15921607
record_class=None
15931608
):
15941609
with self._stmt_exclusive_section:
@@ -1599,6 +1614,7 @@ async def _execute(
15991614
timeout,
16001615
return_status=return_status,
16011616
record_class=record_class,
1617+
ignore_custom_codec=ignore_custom_codec,
16021618
)
16031619
return result
16041620

@@ -1610,6 +1626,7 @@ async def __execute(
16101626
timeout,
16111627
*,
16121628
return_status=False,
1629+
ignore_custom_codec=False,
16131630
record_class=None
16141631
):
16151632
executor = lambda stmt, timeout: self._protocol.bind_execute(
@@ -1620,6 +1637,7 @@ async def __execute(
16201637
executor,
16211638
timeout,
16221639
record_class=record_class,
1640+
ignore_custom_codec=ignore_custom_codec,
16231641
)
16241642

16251643
async def _executemany(self, query, args, timeout):
@@ -1637,20 +1655,23 @@ async def _do_execute(
16371655
timeout,
16381656
retry=True,
16391657
*,
1658+
ignore_custom_codec=False,
16401659
record_class=None
16411660
):
16421661
if timeout is None:
16431662
stmt = await self._get_statement(
16441663
query,
16451664
None,
16461665
record_class=record_class,
1666+
ignore_custom_codec=ignore_custom_codec,
16471667
)
16481668
else:
16491669
before = time.monotonic()
16501670
stmt = await self._get_statement(
16511671
query,
16521672
timeout,
16531673
record_class=record_class,
1674+
ignore_custom_codec=ignore_custom_codec,
16541675
)
16551676
after = time.monotonic()
16561677
timeout -= after - before

asyncpg/protocol/codecs/base.pxd

+2-1
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,6 @@ cdef class DataCodecConfig:
166166
dict _derived_type_codecs
167167
dict _custom_type_codecs
168168

169-
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
169+
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
170+
bint ignore_custom_codec=*)
170171
cdef inline Codec get_any_local_codec(self, uint32_t oid)

asyncpg/protocol/codecs/base.pyx

+12-10
Original file line numberDiff line numberDiff line change
@@ -692,18 +692,20 @@ cdef class DataCodecConfig:
692692

693693
return codec
694694

695-
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format):
695+
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
696+
bint ignore_custom_codec=False):
696697
cdef Codec codec
697698

698-
codec = self.get_any_local_codec(oid)
699-
if codec is not None:
700-
if codec.format != format:
701-
# The codec for this OID has been overridden by
702-
# set_{builtin}_type_codec with a different format.
703-
# We must respect that and not return a core codec.
704-
return None
705-
else:
706-
return codec
699+
if not ignore_custom_codec:
700+
codec = self.get_any_local_codec(oid)
701+
if codec is not None:
702+
if codec.format != format:
703+
# The codec for this OID has been overridden by
704+
# set_{builtin}_type_codec with a different format.
705+
# We must respect that and not return a core codec.
706+
return None
707+
else:
708+
return codec
707709

708710
codec = get_core_codec(oid, format)
709711
if codec is not None:

asyncpg/protocol/prepared_stmt.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cdef class PreparedStatementState:
1212
readonly bint closed
1313
readonly int refs
1414
readonly type record_class
15+
readonly bint ignore_custom_codec
1516

1617

1718
list row_desc

asyncpg/protocol/prepared_stmt.pyx

+7-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ cdef class PreparedStatementState:
1616
str name,
1717
str query,
1818
BaseProtocol protocol,
19-
type record_class
19+
type record_class,
20+
bint ignore_custom_codec
2021
):
2122
self.name = name
2223
self.query = query
@@ -28,6 +29,7 @@ cdef class PreparedStatementState:
2829
self.closed = False
2930
self.refs = 0
3031
self.record_class = record_class
32+
self.ignore_custom_codec = ignore_custom_codec
3133

3234
def _get_parameters(self):
3335
cdef Codec codec
@@ -205,7 +207,8 @@ cdef class PreparedStatementState:
205207
cols_mapping[col_name] = i
206208
cols_names.append(col_name)
207209
oid = row[3]
208-
codec = self.settings.get_data_codec(oid)
210+
codec = self.settings.get_data_codec(
211+
oid, ignore_custom_codec=self.ignore_custom_codec)
209212
if codec is None or not codec.has_decoder():
210213
raise exceptions.InternalClientError(
211214
'no decoder for OID {}'.format(oid))
@@ -230,7 +233,8 @@ cdef class PreparedStatementState:
230233

231234
for i from 0 <= i < self.args_num:
232235
p_oid = self.parameters_desc[i]
233-
codec = self.settings.get_data_codec(p_oid)
236+
codec = self.settings.get_data_codec(
237+
p_oid, ignore_custom_codec=self.ignore_custom_codec)
234238
if codec is None or not codec.has_encoder():
235239
raise exceptions.InternalClientError(
236240
'no encoder for OID {}'.format(p_oid))

asyncpg/protocol/protocol.pyx

+2-1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ cdef class BaseProtocol(CoreProtocol):
145145
async def prepare(self, stmt_name, query, timeout,
146146
*,
147147
PreparedStatementState state=None,
148+
ignore_custom_codec=False,
148149
record_class):
149150
if self.cancel_waiter is not None:
150151
await self.cancel_waiter
@@ -161,7 +162,7 @@ cdef class BaseProtocol(CoreProtocol):
161162
self.last_query = query
162163
if state is None:
163164
state = PreparedStatementState(
164-
stmt_name, query, self, record_class)
165+
stmt_name, query, self, record_class, ignore_custom_codec)
165166
self.statement = state
166167
except Exception as ex:
167168
waiter.set_exception(ex)

asyncpg/protocol/settings.pxd

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ cdef class ConnectionSettings(pgproto.CodecContext):
2626
cpdef inline set_builtin_type_codec(
2727
self, typeoid, typename, typeschema, typekind, alias_to, format)
2828
cpdef inline Codec get_data_codec(
29-
self, uint32_t oid, ServerDataFormat format=*)
29+
self, uint32_t oid, ServerDataFormat format=*,
30+
bint ignore_custom_codec=*)

asyncpg/protocol/settings.pyx

+8-4
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,18 @@ cdef class ConnectionSettings(pgproto.CodecContext):
8787
typekind, alias_to, _format)
8888

8989
cpdef inline Codec get_data_codec(self, uint32_t oid,
90-
ServerDataFormat format=PG_FORMAT_ANY):
90+
ServerDataFormat format=PG_FORMAT_ANY,
91+
bint ignore_custom_codec=False):
9192
if format == PG_FORMAT_ANY:
92-
codec = self._data_codecs.get_codec(oid, PG_FORMAT_BINARY)
93+
codec = self._data_codecs.get_codec(
94+
oid, PG_FORMAT_BINARY, ignore_custom_codec)
9395
if codec is None:
94-
codec = self._data_codecs.get_codec(oid, PG_FORMAT_TEXT)
96+
codec = self._data_codecs.get_codec(
97+
oid, PG_FORMAT_TEXT, ignore_custom_codec)
9598
return codec
9699
else:
97-
return self._data_codecs.get_codec(oid, format)
100+
return self._data_codecs.get_codec(
101+
oid, format, ignore_custom_codec)
98102

99103
def __getattr__(self, name):
100104
if not name.startswith('_'):

tests/test_introspection.py

+15
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,20 @@ def tearDownClass(cls):
4343

4444
super().tearDownClass()
4545

46+
def setUp(self):
47+
super().setUp()
48+
self.loop.run_until_complete(self._add_custom_codec(self.con))
49+
50+
async def _add_custom_codec(self, conn):
51+
# mess up with the codec - builtin introspection shouldn't be affected
52+
await conn.set_type_codec(
53+
"oid",
54+
schema="pg_catalog",
55+
encoder=lambda value: None,
56+
decoder=lambda value: None,
57+
format="text",
58+
)
59+
4660
@tb.with_connection_options(database='asyncpg_intro_test')
4761
async def test_introspection_on_large_db(self):
4862
await self.con.execute(
@@ -142,6 +156,7 @@ async def test_introspection_retries_after_cache_bust(self):
142156
# query would cause introspection to retry.
143157
slow_intro_conn = await self.connect(
144158
connection_class=SlowIntrospectionConnection)
159+
await self._add_custom_codec(slow_intro_conn)
145160
try:
146161
await self.con.execute('''
147162
CREATE DOMAIN intro_1_t AS int;

0 commit comments

Comments
 (0)