@@ -866,7 +866,7 @@ async def copy_to_table(self, table_name, *, source,
866866 delimiter = None , null = None , header = None ,
867867 quote = None , escape = None , force_quote = None ,
868868 force_not_null = None , force_null = None ,
869- encoding = None ):
869+ encoding = None , where = None ):
870870 """Copy data to the specified table.
871871
872872 :param str table_name:
@@ -885,6 +885,15 @@ async def copy_to_table(self, table_name, *, source,
885885 :param str schema_name:
886886 An optional schema name to qualify the table.
887887
888+ :param str where:
889+ An optional SQL expression used to filter rows when copying.
890+
891+ .. note::
892+
893+ Usage of this parameter requires support for the
894+ ``COPY FROM ... WHERE`` syntax, introduced in
895+ PostgreSQL version 12.
896+
888897 :param float timeout:
889898 Optional timeout value in seconds.
890899
@@ -912,6 +921,9 @@ async def copy_to_table(self, table_name, *, source,
912921 https://www.postgresql.org/docs/current/static/sql-copy.html
913922
914923 .. versionadded:: 0.11.0
924+
925+ .. versionadded:: 0.29.0
926+ Added the *where* parameter.
915927 """
916928 tabname = utils ._quote_ident (table_name )
917929 if schema_name :
@@ -923,21 +935,22 @@ async def copy_to_table(self, table_name, *, source,
923935 else :
924936 cols = ''
925937
938+ cond = self ._format_copy_where (where )
926939 opts = self ._format_copy_opts (
927940 format = format , oids = oids , freeze = freeze , delimiter = delimiter ,
928941 null = null , header = header , quote = quote , escape = escape ,
929942 force_not_null = force_not_null , force_null = force_null ,
930943 encoding = encoding
931944 )
932945
933- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
934- tab = tabname , cols = cols , opts = opts )
946+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
947+ tab = tabname , cols = cols , opts = opts , cond = cond )
935948
936949 return await self ._copy_in (copy_stmt , source , timeout )
937950
938951 async def copy_records_to_table (self , table_name , * , records ,
939952 columns = None , schema_name = None ,
940- timeout = None ):
953+ timeout = None , where = None ):
941954 """Copy a list of records to the specified table using binary COPY.
942955
943956 :param str table_name:
@@ -954,6 +967,16 @@ async def copy_records_to_table(self, table_name, *, records,
954967 :param str schema_name:
955968 An optional schema name to qualify the table.
956969
970+ :param str where:
971+ An optional SQL expression used to filter rows when copying.
972+
973+ .. note::
974+
975+ Usage of this parameter requires support for the
976+ ``COPY FROM ... WHERE`` syntax, introduced in
977+ PostgreSQL version 12.
978+
979+
957980 :param float timeout:
958981 Optional timeout value in seconds.
959982
@@ -998,6 +1021,9 @@ async def copy_records_to_table(self, table_name, *, records,
9981021
9991022 .. versionchanged:: 0.24.0
10001023 The ``records`` argument may be an asynchronous iterable.
1024+
1025+ .. versionadded:: 0.29.0
1026+ Added the *where* parameter.
10011027 """
10021028 tabname = utils ._quote_ident (table_name )
10031029 if schema_name :
@@ -1015,14 +1041,27 @@ async def copy_records_to_table(self, table_name, *, records,
10151041
10161042 intro_ps = await self ._prepare (intro_query , use_cache = True )
10171043
1044+ cond = self ._format_copy_where (where )
10181045 opts = '(FORMAT binary)'
10191046
1020- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
1021- tab = tabname , cols = cols , opts = opts )
1047+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
1048+ tab = tabname , cols = cols , opts = opts , cond = cond )
10221049
10231050 return await self ._protocol .copy_in (
10241051 copy_stmt , None , None , records , intro_ps ._state , timeout )
10251052
1053+ def _format_copy_where (self , where ):
1054+ if where and not self ._server_caps .sql_copy_from_where :
1055+ raise exceptions .UnsupportedServerFeatureError (
1056+ 'the `where` parameter requires PostgreSQL 12 or later' )
1057+
1058+ if where :
1059+ where_clause = 'WHERE ' + where
1060+ else :
1061+ where_clause = ''
1062+
1063+ return where_clause
1064+
10261065 def _format_copy_opts (self , * , format = None , oids = None , freeze = None ,
10271066 delimiter = None , null = None , header = None , quote = None ,
10281067 escape = None , force_quote = None , force_not_null = None ,
@@ -2404,7 +2443,7 @@ class _ConnectionProxy:
24042443ServerCapabilities = collections .namedtuple (
24052444 'ServerCapabilities' ,
24062445 ['advisory_locks' , 'notifications' , 'plpgsql' , 'sql_reset' ,
2407- 'sql_close_all' , 'jit' ])
2446+ 'sql_close_all' , 'sql_copy_from_where' , ' jit' ])
24082447ServerCapabilities .__doc__ = 'PostgreSQL server capabilities.'
24092448
24102449
@@ -2417,6 +2456,7 @@ def _detect_server_capabilities(server_version, connection_settings):
24172456 sql_reset = True
24182457 sql_close_all = False
24192458 jit = False
2459+ sql_copy_from_where = False
24202460 elif hasattr (connection_settings , 'crdb_version' ):
24212461 # CockroachDB detected.
24222462 advisory_locks = False
@@ -2425,6 +2465,7 @@ def _detect_server_capabilities(server_version, connection_settings):
24252465 sql_reset = False
24262466 sql_close_all = False
24272467 jit = False
2468+ sql_copy_from_where = False
24282469 elif hasattr (connection_settings , 'crate_version' ):
24292470 # CrateDB detected.
24302471 advisory_locks = False
@@ -2433,6 +2474,7 @@ def _detect_server_capabilities(server_version, connection_settings):
24332474 sql_reset = False
24342475 sql_close_all = False
24352476 jit = False
2477+ sql_copy_from_where = False
24362478 else :
24372479 # Standard PostgreSQL server assumed.
24382480 advisory_locks = True
@@ -2441,13 +2483,15 @@ def _detect_server_capabilities(server_version, connection_settings):
24412483 sql_reset = True
24422484 sql_close_all = True
24432485 jit = server_version >= (11 , 0 )
2486+ sql_copy_from_where = server_version .major >= 12
24442487
24452488 return ServerCapabilities (
24462489 advisory_locks = advisory_locks ,
24472490 notifications = notifications ,
24482491 plpgsql = plpgsql ,
24492492 sql_reset = sql_reset ,
24502493 sql_close_all = sql_close_all ,
2494+ sql_copy_from_where = sql_copy_from_where ,
24512495 jit = jit ,
24522496 )
24532497
0 commit comments