Skip to content

Commit f19a893

Browse files
committed
Add support for COPY IN
This commit adds two new Connection methods: copy_to_table() and copy_records_to_table() that allow copying data to the specified table either in text or, in the latter case, record form. Related-To #123. Closes #21.
1 parent 5662d9f commit f19a893

File tree

9 files changed

+621
-9
lines changed

9 files changed

+621
-9
lines changed

asyncpg/cluster.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import shutil
1616
import socket
1717
import subprocess
18+
import sys
1819
import tempfile
1920
import textwrap
2021
import time
@@ -213,10 +214,15 @@ def start(self, wait=60, *, server_settings={}, **opts):
213214
'pg_ctl start exited with status {:d}: {}'.format(
214215
process.returncode, stderr.decode()))
215216
else:
217+
if os.getenv('ASYNCPG_DEBUG_SERVER'):
218+
stdout = sys.stdout
219+
else:
220+
stdout = subprocess.DEVNULL
221+
216222
self._daemon_process = \
217223
subprocess.Popen(
218224
[self._postgres, '-D', self._data_dir, *extra_args],
219-
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
225+
stdout=stdout, stderr=subprocess.STDOUT,
220226
preexec_fn=ensure_dead_with_parent)
221227

222228
self._daemon_pid = self._daemon_process.pid

asyncpg/connection.py

+164
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import asyncio
99
import collections
10+
import collections.abc
1011
import struct
1112
import time
1213

@@ -451,6 +452,115 @@ async def copy_from_query(self, query, *args, output,
451452

452453
return await self._copy_out(copy_stmt, output, timeout)
453454

455+
async def copy_to_table(self, table_name, *, source,
456+
columns=None, schema_name=None, timeout=None,
457+
format=None, oids=None, freeze=None,
458+
delimiter=None, null=None, header=None,
459+
quote=None, escape=None, force_quote=None,
460+
force_not_null=None, force_null=None,
461+
encoding=None):
462+
"""Copy data to the specified table.
463+
464+
:param str table_name:
465+
The name of the table to copy data to.
466+
467+
:param source:
468+
A :term:`path-like object <python:path-like object>`,
469+
or a :term:`file-like object <python:file-like object>`, or
470+
an :term:`asynchronous iterable <python:asynchronous iterable>`
471+
that returns ``bytes``, or an object supporting the
472+
:term:`buffer protocol <python:buffer protocol>`.
473+
474+
:param list columns:
475+
An optional list of column names to copy.
476+
477+
:param str schema_name:
478+
An optional schema name to qualify the table.
479+
480+
:param float timeout:
481+
Optional timeout value in seconds.
482+
483+
The remaining kewyword arguments are ``COPY`` statement options,
484+
see `COPY statement documentation`_ for details.
485+
486+
:return: The status string of the COPY command.
487+
488+
.. versionadded:: 0.11.0
489+
490+
.. _`COPY statement documentation`: https://www.postgresql.org/docs/\
491+
current/static/sql-copy.html
492+
493+
"""
494+
tabname = utils._quote_ident(table_name)
495+
if schema_name:
496+
tabname = utils._quote_ident(schema_name) + '.' + tabname
497+
498+
if columns:
499+
cols = '({})'.format(
500+
', '.join(utils._quote_ident(c) for c in columns))
501+
else:
502+
cols = ''
503+
504+
opts = self._format_copy_opts(
505+
format=format, oids=oids, freeze=freeze, delimiter=delimiter,
506+
null=null, header=header, quote=quote, escape=escape,
507+
force_not_null=force_not_null, force_null=force_null,
508+
encoding=encoding
509+
)
510+
511+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
512+
tab=tabname, cols=cols, opts=opts)
513+
514+
return await self._copy_in(copy_stmt, source, timeout)
515+
516+
async def copy_records_to_table(self, table_name, *, records,
517+
columns=None, schema_name=None,
518+
timeout=None):
519+
"""Copy a list of records to the specified table using binary COPY.
520+
521+
:param str table_name:
522+
The name of the table to copy data to.
523+
524+
:param records:
525+
An iterable returning row tuples to copy into the table.
526+
527+
:param list columns:
528+
An optional list of column names to copy.
529+
530+
:param str schema_name:
531+
An optional schema name to qualify the table.
532+
533+
:param float timeout:
534+
Optional timeout value in seconds.
535+
536+
:return: The status string of the COPY command.
537+
538+
.. versionadded:: 0.11.0
539+
"""
540+
tabname = utils._quote_ident(table_name)
541+
if schema_name:
542+
tabname = utils._quote_ident(schema_name) + '.' + tabname
543+
544+
if columns:
545+
col_list = ', '.join(utils._quote_ident(c) for c in columns)
546+
cols = '({})'.format(col_list)
547+
else:
548+
col_list = '*'
549+
cols = ''
550+
551+
intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format(
552+
tab=tabname, cols=col_list)
553+
554+
intro_ps = await self.prepare(intro_query)
555+
556+
opts = '(FORMAT binary)'
557+
558+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
559+
tab=tabname, cols=cols, opts=opts)
560+
561+
return await self._copy_in_records(
562+
copy_stmt, records, intro_ps._state, timeout)
563+
454564
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
455565
delimiter=None, null=None, header=None, quote=None,
456566
escape=None, force_quote=None, force_not_null=None,
@@ -519,6 +629,60 @@ async def _writer(data):
519629
if opened_by_us:
520630
f.close()
521631

632+
async def _copy_in(self, copy_stmt, source, timeout):
633+
try:
634+
path = compat.fspath(source)
635+
except TypeError:
636+
# source is not a path-like object
637+
path = None
638+
639+
f = None
640+
reader = None
641+
data = None
642+
opened_by_us = False
643+
run_in_executor = self._loop.run_in_executor
644+
645+
if path is not None:
646+
# a path
647+
f = await run_in_executor(None, open, path, 'wb')
648+
opened_by_us = True
649+
elif hasattr(source, 'read'):
650+
# file-like
651+
f = source
652+
elif isinstance(source, collections.abc.AsyncIterable):
653+
# assuming calling output returns an awaitable.
654+
reader = source
655+
else:
656+
# assuming source is an instance supporting the buffer protocol.
657+
data = source
658+
659+
if f is not None:
660+
# Copying from a file-like object.
661+
class _Reader:
662+
@compat.aiter_compat
663+
def __aiter__(self):
664+
return self
665+
666+
async def __anext__(self):
667+
data = await run_in_executor(None, f.read, 524288)
668+
if len(data) == 0:
669+
raise StopAsyncIteration
670+
else:
671+
return data
672+
673+
reader = _Reader()
674+
675+
try:
676+
return await self._protocol.copy_in(
677+
copy_stmt, reader, data, None, None, timeout)
678+
finally:
679+
if opened_by_us:
680+
await run_in_executor(None, f.close)
681+
682+
async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
683+
return await self._protocol.copy_in(
684+
copy_stmt, None, None, records, intro_stmt, timeout)
685+
522686
async def set_type_codec(self, typename, *,
523687
schema='public', encoder, decoder, binary=False):
524688
"""Set an encoder/decoder pair for the specified data type.

asyncpg/protocol/consts.pxi

+2
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@ DEF _BUFFER_FREELIST_SIZE = 256
1111
DEF _RECORD_FREELIST_SIZE = 1024
1212
DEF _MEMORY_FREELIST_SIZE = 1024
1313
DEF _MAXINT32 = 2**31 - 1
14+
DEF _COPY_BUFFER_SIZE = 524288
15+
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"

asyncpg/protocol/coreproto.pxd

+7
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ cdef class CoreProtocol:
113113
cdef _process__bind(self, char mtype)
114114
cdef _process__copy_out(self, char mtype)
115115
cdef _process__copy_out_data(self, char mtype)
116+
cdef _process__copy_in(self, char mtype)
117+
cdef _process__copy_in_data(self, char mtype)
116118

117119
cdef _parse_msg_authentication(self)
118120
cdef _parse_msg_parameter_status(self)
@@ -124,6 +126,10 @@ cdef class CoreProtocol:
124126
cdef _parse_msg_error_response(self, is_error)
125127
cdef _parse_msg_command_complete(self)
126128

129+
cdef _write_copy_data_msg(self, object data)
130+
cdef _write_copy_done_msg(self)
131+
cdef _write_copy_fail_msg(self, str cause)
132+
127133
cdef _auth_password_message_cleartext(self)
128134
cdef _auth_password_message_md5(self, bytes salt)
129135

@@ -157,6 +163,7 @@ cdef class CoreProtocol:
157163
cdef _close(self, str name, bint is_portal)
158164
cdef _simple_query(self, str query)
159165
cdef _copy_out(self, str copy_stmt)
166+
cdef _copy_in(self, str copy_stmt)
160167
cdef _terminate(self)
161168

162169
cdef _decode_row(self, const char* buf, ssize_t buf_len)

asyncpg/protocol/coreproto.pyx

+79
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ cdef class CoreProtocol:
8888
state == PROTOCOL_COPY_OUT_DONE):
8989
self._process__copy_out_data(mtype)
9090

91+
elif state == PROTOCOL_COPY_IN:
92+
self._process__copy_in(mtype)
93+
94+
elif state == PROTOCOL_COPY_IN_DATA:
95+
self._process__copy_in_data(mtype)
96+
9197
elif state == PROTOCOL_CANCELLED:
9298
# discard all messages until the sync message
9399
if mtype == b'E':
@@ -356,6 +362,33 @@ cdef class CoreProtocol:
356362
self._parse_msg_ready_for_query()
357363
self._push_result()
358364

365+
cdef _process__copy_in(self, char mtype):
366+
if mtype == b'E':
367+
self._parse_msg_error_response(True)
368+
369+
elif mtype == b'G':
370+
# CopyInResponse
371+
self._set_state(PROTOCOL_COPY_IN_DATA)
372+
self.buffer.consume_message()
373+
374+
elif mtype == b'Z':
375+
# ReadyForQuery
376+
self._parse_msg_ready_for_query()
377+
self._push_result()
378+
379+
cdef _process__copy_in_data(self, char mtype):
380+
if mtype == b'E':
381+
self._parse_msg_error_response(True)
382+
383+
elif mtype == b'C':
384+
# CommandComplete
385+
self._parse_msg_command_complete()
386+
387+
elif mtype == b'Z':
388+
# ReadyForQuery
389+
self._parse_msg_ready_for_query()
390+
self._push_result()
391+
359392
cdef _parse_msg_command_complete(self):
360393
cdef:
361394
char* cbuf
@@ -387,6 +420,37 @@ cdef class CoreProtocol:
387420
self._on_result()
388421
self.result = None
389422

423+
cdef _write_copy_data_msg(self, object data):
424+
cdef:
425+
WriteBuffer buf
426+
object mview
427+
Py_buffer *pybuf
428+
429+
mview = PyMemoryView_GetContiguous(data, cpython.PyBUF_SIMPLE, b'C')
430+
pybuf = PyMemoryView_GET_BUFFER(mview)
431+
432+
buf = WriteBuffer.new_message(b'd')
433+
buf.write_cstr(<const char *>pybuf.buf, pybuf.len)
434+
buf.end_message()
435+
self._write(buf)
436+
437+
cdef _write_copy_done_msg(self):
438+
cdef:
439+
WriteBuffer buf
440+
441+
buf = WriteBuffer.new_message(b'c')
442+
buf.end_message()
443+
self._write(buf)
444+
445+
cdef _write_copy_fail_msg(self, str cause):
446+
cdef:
447+
WriteBuffer buf
448+
449+
buf = WriteBuffer.new_message(b'f')
450+
buf.write_str(cause or '', self.encoding)
451+
buf.end_message()
452+
self._write(buf)
453+
390454
cdef _parse_data_msgs(self):
391455
cdef:
392456
ReadBuffer buf = self.buffer
@@ -592,6 +656,10 @@ cdef class CoreProtocol:
592656
new_state == PROTOCOL_COPY_OUT_DONE):
593657
self.state = new_state
594658

659+
elif (self.state == PROTOCOL_COPY_IN and
660+
new_state == PROTOCOL_COPY_IN_DATA):
661+
self.state = new_state
662+
595663
elif self.state == PROTOCOL_FAILED:
596664
raise RuntimeError(
597665
'cannot switch to state {}; '
@@ -810,6 +878,17 @@ cdef class CoreProtocol:
810878
buf.end_message()
811879
self._write(buf)
812880

881+
cdef _copy_in(self, str copy_stmt):
882+
cdef WriteBuffer buf
883+
884+
self._ensure_connected()
885+
self._set_state(PROTOCOL_COPY_IN)
886+
887+
buf = WriteBuffer.new_message(b'Q')
888+
buf.write_str(copy_stmt, self.encoding)
889+
buf.end_message()
890+
self._write(buf)
891+
813892
cdef _terminate(self):
814893
cdef WriteBuffer buf
815894
self._ensure_connected()

asyncpg/protocol/protocol.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ cdef class BaseProtocol(CoreProtocol):
4141

4242
str last_query
4343

44+
bint writing_paused
4445
bint closing
4546

4647
readonly uint64_t queries_count
@@ -58,6 +59,7 @@ cdef class BaseProtocol(CoreProtocol):
5859
cdef _on_result__simple_query(self, object waiter)
5960
cdef _on_result__bind(self, object waiter)
6061
cdef _on_result__copy_out(self, object waiter)
62+
cdef _on_result__copy_in(self, object waiter)
6163

6264
cdef _handle_waiter_on_connection_lost(self, cause)
6365

0 commit comments

Comments
 (0)