From 01830c328cfd48aeccc645bba99de801ed54c4db Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Mon, 8 May 2017 11:51:55 -0400 Subject: [PATCH] Add support for COPY IN This commit adds two new Connection methods: copy_to_table() and copy_records_to_table() that allow copying data to the specified table either in text or, in the latter case, record form. Closes #123. Closes #21. --- asyncpg/cluster.py | 8 +- asyncpg/connection.py | 164 ++++++++++++++++++++++ asyncpg/protocol/consts.pxi | 2 + asyncpg/protocol/coreproto.pxd | 7 + asyncpg/protocol/coreproto.pyx | 84 ++++++++++++ asyncpg/protocol/protocol.pxd | 2 + asyncpg/protocol/protocol.pyx | 138 +++++++++++++++++-- asyncpg/protocol/python.pxd | 1 + tests/test_copy.py | 244 ++++++++++++++++++++++++++++++++- 9 files changed, 640 insertions(+), 10 deletions(-) diff --git a/asyncpg/cluster.py b/asyncpg/cluster.py index 9e9030c0..afb8e6c7 100644 --- a/asyncpg/cluster.py +++ b/asyncpg/cluster.py @@ -15,6 +15,7 @@ import shutil import socket import subprocess +import sys import tempfile import textwrap import time @@ -213,10 +214,15 @@ def start(self, wait=60, *, server_settings={}, **opts): 'pg_ctl start exited with status {:d}: {}'.format( process.returncode, stderr.decode())) else: + if os.getenv('ASYNCPG_DEBUG_SERVER'): + stdout = sys.stdout + else: + stdout = subprocess.DEVNULL + self._daemon_process = \ subprocess.Popen( [self._postgres, '-D', self._data_dir, *extra_args], - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, + stdout=stdout, stderr=subprocess.STDOUT, preexec_fn=ensure_dead_with_parent) self._daemon_pid = self._daemon_process.pid diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 02492824..59945df4 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -7,6 +7,7 @@ import asyncio import collections +import collections.abc import struct import time @@ -451,6 +452,115 @@ async def copy_from_query(self, query, *args, output, return await self._copy_out(copy_stmt, output, timeout) + async def copy_to_table(self, table_name, *, source, + columns=None, schema_name=None, timeout=None, + format=None, oids=None, freeze=None, + delimiter=None, null=None, header=None, + quote=None, escape=None, force_quote=None, + force_not_null=None, force_null=None, + encoding=None): + """Copy data to the specified table. + + :param str table_name: + The name of the table to copy data to. + + :param source: + A :term:`path-like object `, + or a :term:`file-like object `, or + an :term:`asynchronous iterable ` + that returns ``bytes``, or an object supporting the + :term:`buffer protocol `. + + :param list columns: + An optional list of column names to copy. + + :param str schema_name: + An optional schema name to qualify the table. + + :param float timeout: + Optional timeout value in seconds. + + The remaining kewyword arguments are ``COPY`` statement options, + see `COPY statement documentation`_ for details. + + :return: The status string of the COPY command. + + .. versionadded:: 0.11.0 + + .. _`COPY statement documentation`: https://www.postgresql.org/docs/\ + current/static/sql-copy.html + + """ + tabname = utils._quote_ident(table_name) + if schema_name: + tabname = utils._quote_ident(schema_name) + '.' + tabname + + if columns: + cols = '({})'.format( + ', '.join(utils._quote_ident(c) for c in columns)) + else: + cols = '' + + opts = self._format_copy_opts( + format=format, oids=oids, freeze=freeze, delimiter=delimiter, + null=null, header=header, quote=quote, escape=escape, + force_not_null=force_not_null, force_null=force_null, + encoding=encoding + ) + + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( + tab=tabname, cols=cols, opts=opts) + + return await self._copy_in(copy_stmt, source, timeout) + + async def copy_records_to_table(self, table_name, *, records, + columns=None, schema_name=None, + timeout=None): + """Copy a list of records to the specified table using binary COPY. + + :param str table_name: + The name of the table to copy data to. + + :param records: + An iterable returning row tuples to copy into the table. + + :param list columns: + An optional list of column names to copy. + + :param str schema_name: + An optional schema name to qualify the table. + + :param float timeout: + Optional timeout value in seconds. + + :return: The status string of the COPY command. + + .. versionadded:: 0.11.0 + """ + tabname = utils._quote_ident(table_name) + if schema_name: + tabname = utils._quote_ident(schema_name) + '.' + tabname + + if columns: + col_list = ', '.join(utils._quote_ident(c) for c in columns) + cols = '({})'.format(col_list) + else: + col_list = '*' + cols = '' + + intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format( + tab=tabname, cols=col_list) + + intro_ps = await self.prepare(intro_query) + + opts = '(FORMAT binary)' + + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( + tab=tabname, cols=cols, opts=opts) + + return await self._copy_in_records( + copy_stmt, records, intro_ps._state, timeout) + def _format_copy_opts(self, *, format=None, oids=None, freeze=None, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, force_not_null=None, @@ -519,6 +629,60 @@ async def _writer(data): if opened_by_us: f.close() + async def _copy_in(self, copy_stmt, source, timeout): + try: + path = compat.fspath(source) + except TypeError: + # source is not a path-like object + path = None + + f = None + reader = None + data = None + opened_by_us = False + run_in_executor = self._loop.run_in_executor + + if path is not None: + # a path + f = await run_in_executor(None, open, path, 'wb') + opened_by_us = True + elif hasattr(source, 'read'): + # file-like + f = source + elif isinstance(source, collections.abc.AsyncIterable): + # assuming calling output returns an awaitable. + reader = source + else: + # assuming source is an instance supporting the buffer protocol. + data = source + + if f is not None: + # Copying from a file-like object. + class _Reader: + @compat.aiter_compat + def __aiter__(self): + return self + + async def __anext__(self): + data = await run_in_executor(None, f.read, 524288) + if len(data) == 0: + raise StopAsyncIteration + else: + return data + + reader = _Reader() + + try: + return await self._protocol.copy_in( + copy_stmt, reader, data, None, None, timeout) + finally: + if opened_by_us: + await run_in_executor(None, f.close) + + async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout): + return await self._protocol.copy_in( + copy_stmt, None, None, records, intro_stmt, timeout) + async def set_type_codec(self, typename, *, schema='public', encoder, decoder, binary=False): """Set an encoder/decoder pair for the specified data type. diff --git a/asyncpg/protocol/consts.pxi b/asyncpg/protocol/consts.pxi index 6edd93d7..f880940a 100644 --- a/asyncpg/protocol/consts.pxi +++ b/asyncpg/protocol/consts.pxi @@ -11,3 +11,5 @@ DEF _BUFFER_FREELIST_SIZE = 256 DEF _RECORD_FREELIST_SIZE = 1024 DEF _MEMORY_FREELIST_SIZE = 1024 DEF _MAXINT32 = 2**31 - 1 +DEF _COPY_BUFFER_SIZE = 524288 +DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0" diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index 3048e3f9..c3b18f3d 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -113,6 +113,8 @@ cdef class CoreProtocol: cdef _process__bind(self, char mtype) cdef _process__copy_out(self, char mtype) cdef _process__copy_out_data(self, char mtype) + cdef _process__copy_in(self, char mtype) + cdef _process__copy_in_data(self, char mtype) cdef _parse_msg_authentication(self) cdef _parse_msg_parameter_status(self) @@ -124,6 +126,10 @@ cdef class CoreProtocol: cdef _parse_msg_error_response(self, is_error) cdef _parse_msg_command_complete(self) + cdef _write_copy_data_msg(self, object data) + cdef _write_copy_done_msg(self) + cdef _write_copy_fail_msg(self, str cause) + cdef _auth_password_message_cleartext(self) cdef _auth_password_message_md5(self, bytes salt) @@ -157,6 +163,7 @@ cdef class CoreProtocol: cdef _close(self, str name, bint is_portal) cdef _simple_query(self, str query) cdef _copy_out(self, str copy_stmt) + cdef _copy_in(self, str copy_stmt) cdef _terminate(self) cdef _decode_row(self, const char* buf, ssize_t buf_len) diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index b069e783..e8ae79a0 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -88,6 +88,12 @@ cdef class CoreProtocol: state == PROTOCOL_COPY_OUT_DONE): self._process__copy_out_data(mtype) + elif state == PROTOCOL_COPY_IN: + self._process__copy_in(mtype) + + elif state == PROTOCOL_COPY_IN_DATA: + self._process__copy_in_data(mtype) + elif state == PROTOCOL_CANCELLED: # discard all messages until the sync message if mtype == b'E': @@ -356,6 +362,33 @@ cdef class CoreProtocol: self._parse_msg_ready_for_query() self._push_result() + cdef _process__copy_in(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'G': + # CopyInResponse + self._set_state(PROTOCOL_COPY_IN_DATA) + self.buffer.consume_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + + cdef _process__copy_in_data(self, char mtype): + if mtype == b'E': + self._parse_msg_error_response(True) + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + self._push_result() + cdef _parse_msg_command_complete(self): cdef: char* cbuf @@ -387,6 +420,42 @@ cdef class CoreProtocol: self._on_result() self.result = None + cdef _write_copy_data_msg(self, object data): + cdef: + WriteBuffer buf + object mview + Py_buffer *pybuf + + mview = PyMemoryView_GetContiguous(data, cpython.PyBUF_SIMPLE, b'C') + + try: + pybuf = PyMemoryView_GET_BUFFER(mview) + + buf = WriteBuffer.new_message(b'd') + buf.write_cstr(pybuf.buf, pybuf.len) + buf.end_message() + finally: + mview.release() + + self._write(buf) + + cdef _write_copy_done_msg(self): + cdef: + WriteBuffer buf + + buf = WriteBuffer.new_message(b'c') + buf.end_message() + self._write(buf) + + cdef _write_copy_fail_msg(self, str cause): + cdef: + WriteBuffer buf + + buf = WriteBuffer.new_message(b'f') + buf.write_str(cause or '', self.encoding) + buf.end_message() + self._write(buf) + cdef _parse_data_msgs(self): cdef: ReadBuffer buf = self.buffer @@ -592,6 +661,10 @@ cdef class CoreProtocol: new_state == PROTOCOL_COPY_OUT_DONE): self.state = new_state + elif (self.state == PROTOCOL_COPY_IN and + new_state == PROTOCOL_COPY_IN_DATA): + self.state = new_state + elif self.state == PROTOCOL_FAILED: raise RuntimeError( 'cannot switch to state {}; ' @@ -810,6 +883,17 @@ cdef class CoreProtocol: buf.end_message() self._write(buf) + cdef _copy_in(self, str copy_stmt): + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_COPY_IN) + + buf = WriteBuffer.new_message(b'Q') + buf.write_str(copy_stmt, self.encoding) + buf.end_message() + self._write(buf) + cdef _terminate(self): cdef WriteBuffer buf self._ensure_connected() diff --git a/asyncpg/protocol/protocol.pxd b/asyncpg/protocol/protocol.pxd index 8e441537..210d4c83 100644 --- a/asyncpg/protocol/protocol.pxd +++ b/asyncpg/protocol/protocol.pxd @@ -41,6 +41,7 @@ cdef class BaseProtocol(CoreProtocol): str last_query + bint writing_paused bint closing readonly uint64_t queries_count @@ -58,6 +59,7 @@ cdef class BaseProtocol(CoreProtocol): cdef _on_result__simple_query(self, object waiter) cdef _on_result__bind(self, object waiter) cdef _on_result__copy_out(self, object waiter) + cdef _on_result__copy_in(self, object waiter) cdef _handle_waiter_on_connection_lost(self, cause) diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 04d57c77..49a0b72c 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -11,6 +11,7 @@ cimport cython cimport cpython import asyncio +import builtins import codecs import collections import socket @@ -24,7 +25,7 @@ from asyncpg.protocol cimport record from asyncpg.protocol.python cimport ( PyMem_Malloc, PyMem_Realloc, PyMem_Calloc, PyMem_Free, PyMemoryView_GET_BUFFER, PyMemoryView_Check, - PyMemoryView_FromMemory, + PyMemoryView_FromMemory, PyMemoryView_GetContiguous, PyUnicode_AsUTF8AndSize, PyByteArray_AsString, PyByteArray_Check, PyUnicode_AsUCS4Copy, PyByteArray_Size, PyByteArray_Resize, @@ -34,6 +35,7 @@ from asyncpg.protocol.python cimport ( from cpython cimport PyBuffer_FillInfo, PyBytes_AsString from asyncpg.exceptions import _base as apg_exc_base +from asyncpg import compat from asyncpg import types as apg_types from asyncpg import exceptions as apg_exc @@ -106,6 +108,8 @@ cdef class BaseProtocol(CoreProtocol): self.closing = False self.is_reading = True + self.writing_allowed = asyncio.Event(loop=self.loop) + self.writing_allowed.set() self.timeout_handle = None self.timeout_callback = self._on_timeout @@ -332,6 +336,116 @@ cdef class BaseProtocol(CoreProtocol): return status_msg + async def copy_in(self, copy_stmt, reader, data, + records, PreparedStatementState record_stmt, timeout): + cdef: + WriteBuffer wbuf + ssize_t num_cols + Codec codec + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._check_state() + + timeout = self._get_timeout_impl(timeout) + timer = Timer(timeout) + + waiter = self._new_waiter(timer.get_remaining_budget()) + + # Initiate COPY IN. + self._copy_in(copy_stmt) + + try: + if record_stmt is not None: + # copy_in_records in binary mode + wbuf = WriteBuffer.new() + # Signature + wbuf.write_bytes(_COPY_SIGNATURE) + # Flags field + wbuf.write_int32(0) + # No header extension + wbuf.write_int32(0) + + record_stmt._ensure_rows_decoder() + codecs = record_stmt.rows_codecs + num_cols = len(codecs) + settings = self.settings + + for codec in codecs: + if not codec.has_encoder(): + raise RuntimeError( + 'no encoder for OID {}'.format(codec.oid)) + + for row in records: + # Tuple header + wbuf.write_int16(num_cols) + # Tuple data + for i in range(num_cols): + codec = cpython.PyTuple_GET_ITEM(codecs, i) + codec.encode(settings, wbuf, row[i]) + + if wbuf.len() >= _COPY_BUFFER_SIZE: + with timer: + await self.writing_allowed.wait() + self._write_copy_data_msg(wbuf) + wbuf = WriteBuffer.new() + + # End of binary copy. + wbuf.write_int16(-1) + self._write_copy_data_msg(wbuf) + + elif reader is not None: + try: + aiter = reader.__aiter__ + except AttributeError: + raise TypeError('reader is not an asynchronous iterable') + else: + iterator = aiter() + + try: + while True: + # We rely on protocol flow control to moderate the + # rate of data messages. + with timer: + await self.writing_allowed.wait() + with timer: + chunk = await asyncio.wait_for( + iterator.__anext__(), + timeout=timer.get_remaining_budget(), + loop=self.loop) + self._write_copy_data_msg(chunk) + except builtins.StopAsyncIteration: + pass + else: + # Buffer passed in directly. + await self.writing_allowed.wait() + self._write_copy_data_msg(data) + + except asyncio.TimeoutError: + self._write_copy_fail_msg('TimeoutError') + self._on_timeout(self.waiter) + try: + await waiter + except TimeoutError: + raise + else: + raise RuntimeError('TimoutError was not raised') + + except Exception as e: + self._write_copy_fail_msg(str(e)) + self._request_cancel() + raise + + self._write_copy_done_msg() + + status_msg = await waiter + + return status_msg + async def close_statement(self, PreparedStatementState state, timeout): if self.cancel_waiter is not None: await self.cancel_waiter @@ -516,6 +630,10 @@ cdef class BaseProtocol(CoreProtocol): waiter.set_result((self.result, copy_done, status_msg)) + cdef _on_result__copy_in(self, object waiter): + status_msg = self.result_status_msg.decode(self.encoding) + waiter.set_result(status_msg) + cdef _decode_row(self, const char* buf, ssize_t buf_len): if ASYNCPG_DEBUG: if self.statement is None: @@ -576,6 +694,9 @@ cdef class BaseProtocol(CoreProtocol): self.state == PROTOCOL_COPY_OUT_DONE): self._on_result__copy_out(waiter) + elif self.state == PROTOCOL_COPY_IN_DATA: + self._on_result__copy_in(waiter) + else: raise RuntimeError( 'got result for unknown protocol state {}'. @@ -591,13 +712,8 @@ cdef class BaseProtocol(CoreProtocol): if self.cancel_waiter is not None: # We have received the result of a cancelled operation. - # Check that the result waiter (if it exists) has a result. - if self.waiter is not None and not self.waiter.done(): - self.cancel_waiter.set_exception( - RuntimeError( - 'invalid result waiter state on cancellation')) - else: - self.cancel_waiter.set_result(None) + # Simply ignore the result. + self.cancel_waiter.set_result(None) self.cancel_waiter = None self.waiter = None return @@ -629,6 +745,12 @@ cdef class BaseProtocol(CoreProtocol): self.closing = True self._handle_waiter_on_connection_lost(exc) + def pause_writing(self): + self.writing_allowed.clear() + + def resume_writing(self): + self.writing_allowed.set() + class Timer: def __init__(self, budget): diff --git a/asyncpg/protocol/python.pxd b/asyncpg/protocol/python.pxd index 4ac7720e..adce0073 100644 --- a/asyncpg/protocol/python.pxd +++ b/asyncpg/protocol/python.pxd @@ -21,6 +21,7 @@ cdef extern from "Python.h": int PyMemoryView_Check(object) Py_buffer *PyMemoryView_GET_BUFFER(object) object PyMemoryView_FromMemory(char *mem, ssize_t size, int flags) + object PyMemoryView_GetContiguous(object, int buffertype, char order) char* PyUnicode_AsUTF8AndSize(object unicode, ssize_t *size) except NULL char* PyByteArray_AsString(object) diff --git a/tests/test_copy.py b/tests/test_copy.py index 34f0107c..1ec503d4 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -6,13 +6,15 @@ import asyncio +import datetime import io import tempfile from asyncpg import _testbase as tb +from asyncpg import compat -class TestCopy(tb.ConnectedTestCase): +class TestCopyFrom(tb.ConnectedTestCase): async def test_copy_from_table_basics(self): await self.con.execute(''' @@ -353,3 +355,243 @@ async def writer(data): await task self.assertEqual(await self.con.fetchval('SELECT 1'), 1) + + +class TestCopyTo(tb.ConnectedTestCase): + + async def test_copy_to_table_basics(self): + await self.con.execute(''' + CREATE TABLE copytab(a text, "b~" text, i int); + ''') + + try: + f = io.BytesIO() + f.write( + '\n'.join([ + 'a1\tb1\t1', + 'a2\tb2\t2', + 'a3\tb3\t3', + 'a4\tb4\t4', + 'a5\tb5\t5', + '*\t\\N\t\\N', + '' + ]).encode('utf-8') + ) + f.seek(0) + + res = await self.con.copy_to_table('copytab', source=f) + self.assertEqual(res, 'COPY 6') + + output = await self.con.fetch(""" + SELECT * FROM copytab ORDER BY a + """) + self.assertEqual( + output, + [ + ('*', None, None), + ('a1', 'b1', 1), + ('a2', 'b2', 2), + ('a3', 'b3', 3), + ('a4', 'b4', 4), + ('a5', 'b5', 5), + ] + ) + + # Test parameters. + await self.con.execute('TRUNCATE copytab') + await self.con.execute('SET search_path=none') + + f.seek(0) + f.truncate() + + f.write( + '\n'.join([ + 'a|b~', + '*a1*|b1', + '*a2*|b2', + '*a3*|b3', + '*a4*|b4', + '*a5*|b5', + '*!**|*n-u-l-l*', + 'n-u-l-l|bb' + ]).encode('utf-8') + ) + f.seek(0) + + res = await self.con.copy_to_table( + 'copytab', source=f, columns=('a', 'b~'), + schema_name='public', format='csv', + delimiter='|', null='n-u-l-l', header=True, + quote='*', escape='!', force_not_null=('a',), + force_null=('b~',)) + + self.assertEqual(res, 'COPY 7') + + await self.con.execute('SET search_path=public') + + output = await self.con.fetch(""" + SELECT * FROM copytab ORDER BY a + """) + self.assertEqual( + output, + [ + ('*', None, None), + ('a1', 'b1', None), + ('a2', 'b2', None), + ('a3', 'b3', None), + ('a4', 'b4', None), + ('a5', 'b5', None), + ('n-u-l-l', 'bb', None), + ] + ) + + finally: + await self.con.execute('DROP TABLE public.copytab') + + async def test_copy_to_table_large_rows(self): + await self.con.execute(''' + CREATE TABLE copytab(a text, b text); + ''') + + try: + class _Source: + def __init__(self): + self.rowcount = 0 + + @compat.aiter_compat + def __aiter__(self): + return self + + async def __anext__(self): + if self.rowcount >= 100: + raise StopAsyncIteration + else: + self.rowcount += 1 + return b'a1' * 500000 + b'\t' + b'b1' * 500000 + b'\n' + + res = await self.con.copy_to_table('copytab', source=_Source()) + + self.assertEqual(res, 'COPY 100') + + finally: + await self.con.execute('DROP TABLE copytab') + + async def test_copy_to_table_from_bytes_like(self): + await self.con.execute(''' + CREATE TABLE copytab(a text, b text); + ''') + + try: + data = memoryview((b'a1' * 500 + b'\t' + b'b1' * 500 + b'\n') * 2) + res = await self.con.copy_to_table('copytab', source=data) + self.assertEqual(res, 'COPY 2') + finally: + await self.con.execute('DROP TABLE copytab') + + async def test_copy_to_table_fail_in_source_1(self): + await self.con.execute(''' + CREATE TABLE copytab(a text, b text); + ''') + + try: + class _Source: + def __init__(self): + self.rowcount = 0 + + @compat.aiter_compat + def __aiter__(self): + return self + + async def __anext__(self): + raise RuntimeError('failure in source') + + with self.assertRaisesRegexp(RuntimeError, 'failure in source'): + await self.con.copy_to_table('copytab', source=_Source()) + + # Check that the protocol has recovered. + self.assertEqual(await self.con.fetchval('SELECT 1'), 1) + + finally: + await self.con.execute('DROP TABLE copytab') + + async def test_copy_to_table_fail_in_source_2(self): + await self.con.execute(''' + CREATE TABLE copytab(a text, b text); + ''') + + try: + class _Source: + def __init__(self): + self.rowcount = 0 + + @compat.aiter_compat + def __aiter__(self): + return self + + async def __anext__(self): + if self.rowcount == 0: + self.rowcount += 1 + return b'a\tb\n' + else: + raise RuntimeError('failure in source') + + with self.assertRaisesRegexp(RuntimeError, 'failure in source'): + await self.con.copy_to_table('copytab', source=_Source()) + + # Check that the protocol has recovered. + self.assertEqual(await self.con.fetchval('SELECT 1'), 1) + + finally: + await self.con.execute('DROP TABLE copytab') + + async def test_copy_to_table_timeout(self): + await self.con.execute(''' + CREATE TABLE copytab(a text, b text); + ''') + + try: + class _Source: + def __init__(self, loop): + self.rowcount = 0 + self.loop = loop + + @compat.aiter_compat + def __aiter__(self): + return self + + async def __anext__(self): + self.rowcount += 1 + await asyncio.sleep(60, loop=self.loop) + return b'a1' * 50 + b'\t' + b'b1' * 50 + b'\n' + + with self.assertRaises(asyncio.TimeoutError): + await self.con.copy_to_table( + 'copytab', source=_Source(self.loop), timeout=0.10) + + # Check that the protocol has recovered. + self.assertEqual(await self.con.fetchval('SELECT 1'), 1) + + finally: + await self.con.execute('DROP TABLE copytab') + + async def test_copy_records_to_table(self): + await self.con.execute(''' + CREATE TABLE copytab(a text, b int, c timestamptz); + ''') + + try: + date = datetime.datetime.now(tz=datetime.timezone.utc) + delta = datetime.timedelta(days=1) + + records = [ + ('a-{}'.format(i), i, date + delta) + for i in range(100) + ] + + res = await self.con.copy_records_to_table( + 'copytab', records=records) + + self.assertEqual(res, 'COPY 100') + + finally: + await self.con.execute('DROP TABLE copytab')