|
7 | 7 |
|
8 | 8 | import asyncio
|
9 | 9 | import collections
|
| 10 | +import collections.abc |
10 | 11 | import struct
|
11 | 12 | import time
|
12 | 13 |
|
@@ -451,6 +452,115 @@ async def copy_from_query(self, query, *args, output,
|
451 | 452 |
|
452 | 453 | return await self._copy_out(copy_stmt, output, timeout)
|
453 | 454 |
|
| 455 | + async def copy_to_table(self, table_name, *, source, |
| 456 | + columns=None, schema_name=None, timeout=None, |
| 457 | + format=None, oids=None, freeze=None, |
| 458 | + delimiter=None, null=None, header=None, |
| 459 | + quote=None, escape=None, force_quote=None, |
| 460 | + force_not_null=None, force_null=None, |
| 461 | + encoding=None): |
| 462 | + """Copy data to the specified table. |
| 463 | +
|
| 464 | + :param str table_name: |
| 465 | + The name of the table to copy data to. |
| 466 | +
|
| 467 | + :param source: |
| 468 | + A :term:`path-like object <python:path-like object>`, |
| 469 | + or a :term:`file-like object <python:file-like object>`, or |
| 470 | + an :term:`asynchronous iterable <python:asynchronous iterable>` |
| 471 | + that returns ``bytes``, or an object supporting the |
| 472 | + :term:`buffer protocol <python:buffer protocol>`. |
| 473 | +
|
| 474 | + :param list columns: |
| 475 | + An optional list of column names to copy. |
| 476 | +
|
| 477 | + :param str schema_name: |
| 478 | + An optional schema name to qualify the table. |
| 479 | +
|
| 480 | + :param float timeout: |
| 481 | + Optional timeout value in seconds. |
| 482 | +
|
| 483 | + The remaining kewyword arguments are ``COPY`` statement options, |
| 484 | + see `COPY statement documentation`_ for details. |
| 485 | +
|
| 486 | + :return: The status string of the COPY command. |
| 487 | +
|
| 488 | + .. versionadded:: 0.11.0 |
| 489 | +
|
| 490 | + .. _`COPY statement documentation`: https://www.postgresql.org/docs/\ |
| 491 | + current/static/sql-copy.html |
| 492 | +
|
| 493 | + """ |
| 494 | + tabname = utils._quote_ident(table_name) |
| 495 | + if schema_name: |
| 496 | + tabname = utils._quote_ident(schema_name) + '.' + tabname |
| 497 | + |
| 498 | + if columns: |
| 499 | + cols = '({})'.format( |
| 500 | + ', '.join(utils._quote_ident(c) for c in columns)) |
| 501 | + else: |
| 502 | + cols = '' |
| 503 | + |
| 504 | + opts = self._format_copy_opts( |
| 505 | + format=format, oids=oids, freeze=freeze, delimiter=delimiter, |
| 506 | + null=null, header=header, quote=quote, escape=escape, |
| 507 | + force_not_null=force_not_null, force_null=force_null, |
| 508 | + encoding=encoding |
| 509 | + ) |
| 510 | + |
| 511 | + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( |
| 512 | + tab=tabname, cols=cols, opts=opts) |
| 513 | + |
| 514 | + return await self._copy_in(copy_stmt, source, timeout) |
| 515 | + |
| 516 | + async def copy_records_to_table(self, table_name, *, records, |
| 517 | + columns=None, schema_name=None, |
| 518 | + timeout=None): |
| 519 | + """Copy a list of records to the specified table using binary COPY. |
| 520 | +
|
| 521 | + :param str table_name: |
| 522 | + The name of the table to copy data to. |
| 523 | +
|
| 524 | + :param records: |
| 525 | + An iterable returning row tuples to copy into the table. |
| 526 | +
|
| 527 | + :param list columns: |
| 528 | + An optional list of column names to copy. |
| 529 | +
|
| 530 | + :param str schema_name: |
| 531 | + An optional schema name to qualify the table. |
| 532 | +
|
| 533 | + :param float timeout: |
| 534 | + Optional timeout value in seconds. |
| 535 | +
|
| 536 | + :return: The status string of the COPY command. |
| 537 | +
|
| 538 | + .. versionadded:: 0.11.0 |
| 539 | + """ |
| 540 | + tabname = utils._quote_ident(table_name) |
| 541 | + if schema_name: |
| 542 | + tabname = utils._quote_ident(schema_name) + '.' + tabname |
| 543 | + |
| 544 | + if columns: |
| 545 | + col_list = ', '.join(utils._quote_ident(c) for c in columns) |
| 546 | + cols = '({})'.format(col_list) |
| 547 | + else: |
| 548 | + col_list = '*' |
| 549 | + cols = '' |
| 550 | + |
| 551 | + intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format( |
| 552 | + tab=tabname, cols=col_list) |
| 553 | + |
| 554 | + intro_ps = await self.prepare(intro_query) |
| 555 | + |
| 556 | + opts = '(FORMAT binary)' |
| 557 | + |
| 558 | + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( |
| 559 | + tab=tabname, cols=cols, opts=opts) |
| 560 | + |
| 561 | + return await self._copy_in_records( |
| 562 | + copy_stmt, records, intro_ps._state, timeout) |
| 563 | + |
454 | 564 | def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
|
455 | 565 | delimiter=None, null=None, header=None, quote=None,
|
456 | 566 | escape=None, force_quote=None, force_not_null=None,
|
@@ -519,6 +629,60 @@ async def _writer(data):
|
519 | 629 | if opened_by_us:
|
520 | 630 | f.close()
|
521 | 631 |
|
| 632 | + async def _copy_in(self, copy_stmt, source, timeout): |
| 633 | + try: |
| 634 | + path = compat.fspath(source) |
| 635 | + except TypeError: |
| 636 | + # source is not a path-like object |
| 637 | + path = None |
| 638 | + |
| 639 | + f = None |
| 640 | + reader = None |
| 641 | + data = None |
| 642 | + opened_by_us = False |
| 643 | + run_in_executor = self._loop.run_in_executor |
| 644 | + |
| 645 | + if path is not None: |
| 646 | + # a path |
| 647 | + f = await run_in_executor(None, open, path, 'wb') |
| 648 | + opened_by_us = True |
| 649 | + elif hasattr(source, 'read'): |
| 650 | + # file-like |
| 651 | + f = source |
| 652 | + elif isinstance(source, collections.abc.AsyncIterable): |
| 653 | + # assuming calling output returns an awaitable. |
| 654 | + reader = source |
| 655 | + else: |
| 656 | + # assuming source is an instance supporting the buffer protocol. |
| 657 | + data = source |
| 658 | + |
| 659 | + if f is not None: |
| 660 | + # Copying from a file-like object. |
| 661 | + class _Reader: |
| 662 | + @compat.aiter_compat |
| 663 | + def __aiter__(self): |
| 664 | + return self |
| 665 | + |
| 666 | + async def __anext__(self): |
| 667 | + data = await run_in_executor(None, f.read, 524288) |
| 668 | + if len(data) == 0: |
| 669 | + raise StopAsyncIteration |
| 670 | + else: |
| 671 | + return data |
| 672 | + |
| 673 | + reader = _Reader() |
| 674 | + |
| 675 | + try: |
| 676 | + return await self._protocol.copy_in( |
| 677 | + copy_stmt, reader, data, None, None, timeout) |
| 678 | + finally: |
| 679 | + if opened_by_us: |
| 680 | + await run_in_executor(None, f.close) |
| 681 | + |
| 682 | + async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout): |
| 683 | + return await self._protocol.copy_in( |
| 684 | + copy_stmt, None, None, records, intro_stmt, timeout) |
| 685 | + |
522 | 686 | async def set_type_codec(self, typename, *,
|
523 | 687 | schema='public', encoder, decoder, binary=False):
|
524 | 688 | """Set an encoder/decoder pair for the specified data type.
|
|
0 commit comments