Skip to content

Add support for COPY IN #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion asyncpg/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import shutil
import socket
import subprocess
import sys
import tempfile
import textwrap
import time
Expand Down Expand Up @@ -213,10 +214,15 @@ def start(self, wait=60, *, server_settings={}, **opts):
'pg_ctl start exited with status {:d}: {}'.format(
process.returncode, stderr.decode()))
else:
if os.getenv('ASYNCPG_DEBUG_SERVER'):
stdout = sys.stdout
else:
stdout = subprocess.DEVNULL

self._daemon_process = \
subprocess.Popen(
[self._postgres, '-D', self._data_dir, *extra_args],
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
stdout=stdout, stderr=subprocess.STDOUT,
preexec_fn=ensure_dead_with_parent)

self._daemon_pid = self._daemon_process.pid
Expand Down
164 changes: 164 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import collections
import collections.abc
import struct
import time

Expand Down Expand Up @@ -451,6 +452,115 @@ async def copy_from_query(self, query, *args, output,

return await self._copy_out(copy_stmt, output, timeout)

async def copy_to_table(self, table_name, *, source,
                        columns=None, schema_name=None, timeout=None,
                        format=None, oids=None, freeze=None,
                        delimiter=None, null=None, header=None,
                        quote=None, escape=None, force_quote=None,
                        force_not_null=None, force_null=None,
                        encoding=None):
    """Copy data from *source* into the specified table.

    A ``COPY <table> FROM STDIN`` statement is assembled from the
    arguments and executed with *source* as its input.

    :param str table_name:
        The name of the table to copy data to.

    :param source:
        A :term:`path-like object <python:path-like object>`,
        or a :term:`file-like object <python:file-like object>`, or
        an :term:`asynchronous iterable <python:asynchronous iterable>`
        that returns ``bytes``, or an object supporting the
        :term:`buffer protocol <python:buffer protocol>`.

    :param list columns:
        An optional list of column names to copy.

    :param str schema_name:
        An optional schema name to qualify the table.

    :param float timeout:
        Optional timeout value in seconds.

    The remaining keyword arguments are ``COPY`` statement options,
    see `COPY statement documentation`_ for details.  Note that
    ``force_quote`` is accepted by the signature but is not forwarded
    to the statement built by this method.

    :return: The status string of the COPY command.

    .. versionadded:: 0.11.0

    .. _`COPY statement documentation`: https://www.postgresql.org/docs/\
current/static/sql-copy.html

    """
    # Quote every identifier; the table name is optionally qualified
    # with the (also quoted) schema name.
    target = utils._quote_ident(table_name)
    if schema_name:
        target = utils._quote_ident(schema_name) + '.' + target

    # An explicit column list becomes "(col1, col2, ...)"; otherwise the
    # clause is left out of the statement entirely.
    column_clause = ''
    if columns:
        quoted_columns = [utils._quote_ident(c) for c in columns]
        column_clause = '(' + ', '.join(quoted_columns) + ')'

    copy_options = self._format_copy_opts(
        format=format,
        oids=oids,
        freeze=freeze,
        delimiter=delimiter,
        null=null,
        header=header,
        quote=quote,
        escape=escape,
        force_not_null=force_not_null,
        force_null=force_null,
        encoding=encoding,
    )

    statement = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
        tab=target, cols=column_clause, opts=copy_options)

    return await self._copy_in(statement, source, timeout)

async def copy_records_to_table(self, table_name, *, records,
                                columns=None, schema_name=None,
                                timeout=None):
    """Copy a list of records to the specified table using binary COPY.

    Before the copy starts, a ``SELECT ... LIMIT 1`` over the same
    table and columns is prepared; the state of that prepared statement
    is handed to the protocol together with *records*.

    :param str table_name:
        The name of the table to copy data to.

    :param records:
        An iterable returning row tuples to copy into the table.

    :param list columns:
        An optional list of column names to copy.

    :param str schema_name:
        An optional schema name to qualify the table.

    :param float timeout:
        Optional timeout value in seconds.

    :return: The status string of the COPY command.

    .. versionadded:: 0.11.0
    """
    target = utils._quote_ident(table_name)
    if schema_name:
        target = utils._quote_ident(schema_name) + '.' + target

    # The same quoted column list is used twice: bare in the
    # introspection SELECT and parenthesised in the COPY statement.
    if columns:
        select_list = ', '.join(utils._quote_ident(c) for c in columns)
        column_clause = '({})'.format(select_list)
    else:
        select_list, column_clause = '*', ''

    # NOTE(review): the prepared statement presumably supplies the column
    # type information needed to encode the records -- confirm in the
    # protocol's copy_in() implementation.
    intro_ps = await self.prepare(
        'SELECT {cols} FROM {tab} LIMIT 1'.format(
            tab=target, cols=select_list))

    statement = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
        tab=target, cols=column_clause, opts='(FORMAT binary)')

    return await self._copy_in_records(
        statement, records, intro_ps._state, timeout)

def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
delimiter=None, null=None, header=None, quote=None,
escape=None, force_quote=None, force_not_null=None,
Expand Down Expand Up @@ -519,6 +629,60 @@ async def _writer(data):
if opened_by_us:
f.close()

async def _copy_in(self, copy_stmt, source, timeout):
    """Execute *copy_stmt*, feeding the COPY input from *source*.

    *source* is interpreted as the first of the following that matches:

    * a path-like object -- the file is opened here in binary read mode
      and closed again once the copy has finished (or failed);
    * an object with a ``read`` attribute -- treated as a file-like
      object, read in chunks and never closed by this method;
    * an asynchronous iterable -- handed to the protocol unchanged;
    * anything else -- assumed to support the buffer protocol and
      handed to the protocol as one block of data.

    :param str copy_stmt: The complete ``COPY ... FROM STDIN`` statement.
    :param source: The data source, as described above.
    :param float timeout: Optional timeout value in seconds.

    :return: Whatever ``self._protocol.copy_in()`` returns.
    """
    try:
        path = compat.fspath(source)
    except TypeError:
        # source is not a path-like object
        path = None

    f = None
    reader = None
    data = None
    opened_by_us = False
    run_in_executor = self._loop.run_in_executor

    if path is not None:
        # A path.  The file is the *input* of the COPY, so it has to be
        # opened for reading: write mode would truncate the caller's
        # file and make every subsequent read() fail.
        f = await run_in_executor(None, open, path, 'rb')
        opened_by_us = True
    elif hasattr(source, 'read'):
        # file-like
        f = source
    elif isinstance(source, collections.abc.AsyncIterable):
        # An asynchronous iterable is passed to the protocol as is.
        reader = source
    else:
        # assuming source is an instance supporting the buffer protocol.
        data = source

    if f is not None:
        # Copying from a file-like object: wrap it in an asynchronous
        # iterator that performs the read() calls through the loop's
        # executor instead of calling them directly from the coroutine.
        class _Reader:
            @compat.aiter_compat
            def __aiter__(self):
                return self

            async def __anext__(self):
                chunk = await run_in_executor(None, f.read, 524288)
                if len(chunk) == 0:
                    # An empty read marks the end of the input.
                    raise StopAsyncIteration
                else:
                    return chunk

        reader = _Reader()

    try:
        return await self._protocol.copy_in(
            copy_stmt, reader, data, None, None, timeout)
    finally:
        # Only close what was opened here; a caller-supplied file object
        # stays under the caller's control.
        if opened_by_us:
            await run_in_executor(None, f.close)

async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
    """Execute *copy_stmt*, feeding the COPY input from *records*.

    The protocol's ``copy_in()`` is called with its ``reader`` and
    ``data`` positions left as ``None``; *records* and *intro_stmt* are
    passed in the two positions that follow.

    :return: Whatever ``self._protocol.copy_in()`` returns.
    """
    protocol = self._protocol
    status = await protocol.copy_in(
        copy_stmt, None, None, records, intro_stmt, timeout)
    return status

async def set_type_codec(self, typename, *,
schema='public', encoder, decoder, binary=False):
"""Set an encoder/decoder pair for the specified data type.
Expand Down
2 changes: 2 additions & 0 deletions asyncpg/protocol/consts.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ DEF _BUFFER_FREELIST_SIZE = 256
DEF _RECORD_FREELIST_SIZE = 1024
DEF _MEMORY_FREELIST_SIZE = 1024
DEF _MAXINT32 = 2**31 - 1
DEF _COPY_BUFFER_SIZE = 524288
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
7 changes: 7 additions & 0 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ cdef class CoreProtocol:
cdef _process__bind(self, char mtype)
cdef _process__copy_out(self, char mtype)
cdef _process__copy_out_data(self, char mtype)
cdef _process__copy_in(self, char mtype)
cdef _process__copy_in_data(self, char mtype)

cdef _parse_msg_authentication(self)
cdef _parse_msg_parameter_status(self)
Expand All @@ -124,6 +126,10 @@ cdef class CoreProtocol:
cdef _parse_msg_error_response(self, is_error)
cdef _parse_msg_command_complete(self)

cdef _write_copy_data_msg(self, object data)
cdef _write_copy_done_msg(self)
cdef _write_copy_fail_msg(self, str cause)

cdef _auth_password_message_cleartext(self)
cdef _auth_password_message_md5(self, bytes salt)

Expand Down Expand Up @@ -157,6 +163,7 @@ cdef class CoreProtocol:
cdef _close(self, str name, bint is_portal)
cdef _simple_query(self, str query)
cdef _copy_out(self, str copy_stmt)
cdef _copy_in(self, str copy_stmt)
cdef _terminate(self)

cdef _decode_row(self, const char* buf, ssize_t buf_len)
Expand Down
84 changes: 84 additions & 0 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ cdef class CoreProtocol:
state == PROTOCOL_COPY_OUT_DONE):
self._process__copy_out_data(mtype)

elif state == PROTOCOL_COPY_IN:
self._process__copy_in(mtype)

elif state == PROTOCOL_COPY_IN_DATA:
self._process__copy_in_data(mtype)

elif state == PROTOCOL_CANCELLED:
# discard all messages until the sync message
if mtype == b'E':
Expand Down Expand Up @@ -356,6 +362,33 @@ cdef class CoreProtocol:
self._parse_msg_ready_for_query()
self._push_result()

cdef _process__copy_in(self, char mtype):
    # Handle one backend message while in PROTOCOL_COPY_IN, i.e. before
    # the switch to PROTOCOL_COPY_IN_DATA.  Message types other than
    # the three below are not handled here.
    if mtype == b'G':
        # CopyInResponse: move on to the data phase and drop the
        # message body without parsing it.
        self._set_state(PROTOCOL_COPY_IN_DATA)
        self.buffer.consume_message()
        return

    if mtype == b'E':
        # ErrorResponse
        self._parse_msg_error_response(True)
        return

    if mtype == b'Z':
        # ReadyForQuery: the exchange is over, publish the result.
        self._parse_msg_ready_for_query()
        self._push_result()

cdef _process__copy_in_data(self, char mtype):
    # Handle one backend message while in PROTOCOL_COPY_IN_DATA.
    # Message types other than the three below are not handled here.
    if mtype == b'C':
        # CommandComplete
        self._parse_msg_command_complete()
        return

    if mtype == b'E':
        # ErrorResponse
        self._parse_msg_error_response(True)
        return

    if mtype == b'Z':
        # ReadyForQuery: the exchange is over, publish the result.
        self._parse_msg_ready_for_query()
        self._push_result()

cdef _parse_msg_command_complete(self):
cdef:
char* cbuf
Expand Down Expand Up @@ -387,6 +420,42 @@ cdef class CoreProtocol:
self._on_result()
self.result = None

cdef _write_copy_data_msg(self, object data):
    # Write a single 'd' (CopyData) message whose payload is the bytes
    # exposed by *data*.  *data* must be acceptable to
    # PyMemoryView_GetContiguous(), i.e. support the buffer protocol.
    cdef:
        WriteBuffer buf
        object mview
        Py_buffer *pybuf

    # The second argument of PyMemoryView_GetContiguous() is the
    # *buffertype*, documented as PyBUF_READ or PyBUF_WRITE.  The data
    # is only read here, so PyBUF_READ is the correct value (PyBUF_SIMPLE
    # is a PyObject_GetBuffer() request flag, not a buffertype).
    mview = PyMemoryView_GetContiguous(data, cpython.PyBUF_READ, b'C')

    try:
        pybuf = PyMemoryView_GET_BUFFER(mview)

        buf = WriteBuffer.new_message(b'd')
        buf.write_cstr(<const char *>pybuf.buf, pybuf.len)
        buf.end_message()
    finally:
        # PR review question: "Don't you need to close mview after you
        # use it?" -- yes; the view is released here, on both the
        # success and the error path, as soon as its bytes have been
        # written into the message buffer.
        mview.release()

    self._write(buf)

cdef _write_copy_done_msg(self):
    # Write a 'c' (CopyDone) message; it carries no payload.
    cdef WriteBuffer done_msg = WriteBuffer.new_message(b'c')

    done_msg.end_message()
    self._write(done_msg)

cdef _write_copy_fail_msg(self, str cause):
    # Write an 'f' (CopyFail) message whose payload is *cause* encoded
    # with the connection encoding.  A None or empty *cause* is sent as
    # an empty string.
    cdef:
        WriteBuffer fail_msg
        str reason

    reason = cause if cause else ''

    fail_msg = WriteBuffer.new_message(b'f')
    fail_msg.write_str(reason, self.encoding)
    fail_msg.end_message()
    self._write(fail_msg)

cdef _parse_data_msgs(self):
cdef:
ReadBuffer buf = self.buffer
Expand Down Expand Up @@ -592,6 +661,10 @@ cdef class CoreProtocol:
new_state == PROTOCOL_COPY_OUT_DONE):
self.state = new_state

elif (self.state == PROTOCOL_COPY_IN and
new_state == PROTOCOL_COPY_IN_DATA):
self.state = new_state

elif self.state == PROTOCOL_FAILED:
raise RuntimeError(
'cannot switch to state {}; '
Expand Down Expand Up @@ -810,6 +883,17 @@ cdef class CoreProtocol:
buf.end_message()
self._write(buf)

cdef _copy_in(self, str copy_stmt):
    # Start a copy-in exchange: switch to PROTOCOL_COPY_IN and send
    # *copy_stmt* to the server as a 'Q' (simple query) message.
    cdef WriteBuffer query_msg

    # Both checks run before anything is written to the connection.
    self._ensure_connected()
    self._set_state(PROTOCOL_COPY_IN)

    query_msg = WriteBuffer.new_message(b'Q')
    query_msg.write_str(copy_stmt, self.encoding)
    query_msg.end_message()
    self._write(query_msg)

cdef _terminate(self):
cdef WriteBuffer buf
self._ensure_connected()
Expand Down
2 changes: 2 additions & 0 deletions asyncpg/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ cdef class BaseProtocol(CoreProtocol):

str last_query

bint writing_paused
bint closing

readonly uint64_t queries_count
Expand All @@ -58,6 +59,7 @@ cdef class BaseProtocol(CoreProtocol):
cdef _on_result__simple_query(self, object waiter)
cdef _on_result__bind(self, object waiter)
cdef _on_result__copy_out(self, object waiter)
cdef _on_result__copy_in(self, object waiter)

cdef _handle_waiter_on_connection_lost(self, cause)

Expand Down
Loading