Skip to content

Commit ef0e80b

Browse files
authored
[Data] Add Dataset.write_sql (#38544)
Writing data back to databases is common in many applications, including LLM-based ones. For example, you might want to write vector indices back to a database that supports vector search through an extension like pgvector (https://github.com/pgvector/pgvector). To support this use case, this PR adds an API to write Datasets to SQL databases. Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
1 parent e9c7bc7 commit ef0e80b

File tree

4 files changed

+142
-18
lines changed

4 files changed

+142
-18
lines changed

python/ray/data/dataset.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
)
116116
from ray.data.datasource import (
117117
BlockWritePathProvider,
118+
Connection,
118119
CSVDatasource,
119120
Datasource,
120121
DefaultBlockWritePathProvider,
@@ -123,6 +124,7 @@
123124
NumpyDatasource,
124125
ParquetDatasource,
125126
ReadTask,
127+
SQLDatasource,
126128
TFRecordDatasource,
127129
WriteResult,
128130
)
@@ -3215,6 +3217,68 @@ def write_numpy(
32153217
block_path_provider=block_path_provider,
32163218
)
32173219

3220+
@ConsumptionAPI
3221+
def write_sql(
3222+
self,
3223+
sql: str,
3224+
connection_factory: Callable[[], Connection],
3225+
ray_remote_args: Optional[Dict[str, Any]] = None,
3226+
) -> None:
3227+
"""Write to a database that provides a
3228+
`Python DB API2-compliant <https://peps.python.org/pep-0249/>`_ connector.
3229+
3230+
.. note::
3231+
3232+
This method writes data in parallel using the DB API2 ``executemany``
3233+
method. To learn more about this method, see
3234+
`PEP 249 <https://peps.python.org/pep-0249/#executemany>`_.
3235+
3236+
Examples:
3237+
3238+
.. testcode::
3239+
3240+
import sqlite3
3241+
import ray
3242+
3243+
connection = sqlite3.connect("example.db")
3244+
connection.cursor().execute("CREATE TABLE movie(title, year, score)")
3245+
dataset = ray.data.from_items([
3246+
{"title": "Monty Python and the Holy Grail", "year": 1975, "score": 8.2},
3247+
{"title": "And Now for Something Completely Different", "year": 1971, "score": 7.5}
3248+
])
3249+
3250+
dataset.write_sql(
3251+
"INSERT INTO movie VALUES(?, ?, ?)", lambda: sqlite3.connect("example.db")
3252+
)
3253+
3254+
result = connection.cursor().execute("SELECT * FROM movie ORDER BY year")
3255+
print(result.fetchall())
3256+
3257+
.. testoutput::
3258+
3259+
[('And Now for Something Completely Different', 1971, 7.5), ('Monty Python and the Holy Grail', 1975, 8.2)]
3260+
3261+
.. testcode::
3262+
:hide:
3263+
3264+
import os
3265+
os.remove("example.db")
3266+
3267+
Args:
3268+
sql: An ``INSERT INTO`` statement that specifies the table to write to. The
3269+
number of parameters must match the number of columns in the table.
3270+
connection_factory: A function that takes no arguments and returns a
3271+
Python DB API2
3272+
`Connection object <https://peps.python.org/pep-0249/#connection-objects>`_.
3273+
ray_remote_args: Keyword arguments passed to :func:`~ray.remote` in the
3274+
write tasks.
3275+
""" # noqa: E501
3276+
self.write_datasource(
3277+
SQLDatasource(connection_factory),
3278+
ray_remote_args=ray_remote_args,
3279+
sql=sql,
3280+
)
3281+
32183282
@ConsumptionAPI
32193283
def write_mongo(
32203284
self,

python/ray/data/datasource/sql_datasource.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from contextlib import contextmanager
33
from typing import Any, Callable, Iterable, Iterator, List, Optional
44

5+
from ray.data._internal.execution.interfaces import TaskContext
56
from ray.data.block import Block, BlockAccessor, BlockMetadata
6-
from ray.data.datasource.datasource import Datasource, Reader, ReadTask
7+
from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult
78
from ray.util.annotations import PublicAPI
89

910
Connection = Any # A Python DB API2-compliant `Connection` object.
@@ -23,12 +24,38 @@ def _cursor_to_block(cursor) -> Block:
2324

2425
@PublicAPI(stability="alpha")
2526
class SQLDatasource(Datasource):
27+
28+
_MAX_ROWS_PER_WRITE = 128
29+
2630
def __init__(self, connection_factory: Callable[[], Connection]):
2731
self.connection_factory = connection_factory
2832

2933
def create_reader(self, sql: str) -> "Reader":
3034
return _SQLReader(sql, self.connection_factory)
3135

36+
def write(
37+
self,
38+
blocks: Iterable[Block],
39+
ctx: TaskContext,
40+
sql: str,
41+
) -> WriteResult:
42+
with _connect(self.connection_factory) as cursor:
43+
for block in blocks:
44+
block_accessor = BlockAccessor.for_block(block)
45+
46+
values = []
47+
for row in block_accessor.iter_rows(public_row_format=False):
48+
values.append(tuple(row.values()))
49+
assert len(values) <= self._MAX_ROWS_PER_WRITE, len(values)
50+
if len(values) == self._MAX_ROWS_PER_WRITE:
51+
cursor.executemany(sql, values)
52+
values = []
53+
54+
if values:
55+
cursor.executemany(sql, values)
56+
57+
return "ok"
58+
3259

3360
def _check_connection_is_dbapi2_compliant(connection) -> None:
3461
for attr in "close", "commit", "cursor":
@@ -44,7 +71,7 @@ def _check_connection_is_dbapi2_compliant(connection) -> None:
4471
def _check_cursor_is_dbapi2_compliant(cursor) -> None:
4572
# These aren't all the methods required by the specification, but it's all the ones
4673
# we care about.
47-
for attr in "execute", "fetchone", "fetchall", "description":
74+
for attr in "execute", "executemany", "fetchone", "fetchall", "description":
4875
if not hasattr(cursor, attr):
4976
raise ValueError(
5077
"Your database connector created a `Cursor` object without a "
@@ -63,26 +90,25 @@ def _connect(connection_factory: Callable[[], Connection]) -> Iterator[Cursor]:
6390
cursor = connection.cursor()
6491
_check_cursor_is_dbapi2_compliant(cursor)
6592
yield cursor
66-
67-
finally:
93+
connection.commit()
94+
except Exception:
6895
# `rollback` is optional since not all databases provide transaction support.
6996
try:
7097
connection.rollback()
7198
except Exception as e:
7299
# Each connector implements its own `NotSupportedError` class, so we check
73100
# the exception's name instead of using `isinstance`.
74-
if not (
101+
if (
75102
isinstance(e, AttributeError)
76103
or e.__class__.__name__ == "NotSupportedError"
77104
):
78-
raise e from None
79-
80-
connection.commit()
105+
pass
106+
raise
107+
finally:
81108
connection.close()
82109

83110

84111
class _SQLReader(Reader):
85-
86112
NUM_SAMPLE_ROWS = 100
87113
MIN_ROWS_PER_READ_TASK = 50
88114

python/ray/data/read_api.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,15 +1680,6 @@ def read_sql(
16801680
For examples of reading from larger databases like MySQL and PostgreSQL, see
16811681
:ref:`Reading from SQL Databases <reading_sql>`.
16821682
1683-
.. testcode::
1684-
:hide:
1685-
1686-
import os
1687-
try:
1688-
os.remove("example.db")
1689-
except OSError:
1690-
pass
1691-
16921683
.. testcode::
16931684
16941685
import sqlite3
@@ -1724,6 +1715,12 @@ def create_connection():
17241715
"SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
17251716
)
17261717
1718+
.. testcode::
1719+
:hide:
1720+
1721+
import os
1722+
os.remove("example.db")
1723+
17271724
Args:
17281725
sql: The SQL query to execute.
17291726
connection_factory: A function that takes no arguments and returns a

python/ray/data/tests/test_sql.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,40 @@ def test_read_sql(temp_database: str, parallelism: int):
3333
actual_values = [tuple(record.values()) for record in dataset.take_all()]
3434

3535
assert sorted(actual_values) == sorted(expected_values)
36+
37+
38+
def test_write_sql(temp_database: str):
39+
connection = sqlite3.connect(temp_database)
40+
connection.cursor().execute("CREATE TABLE test(string, number)")
41+
dataset = ray.data.from_items(
42+
[{"string": "spam", "number": 0}, {"string": "ham", "number": 1}]
43+
)
44+
45+
dataset.write_sql(
46+
"INSERT INTO test VALUES(?, ?)", lambda: sqlite3.connect(temp_database)
47+
)
48+
49+
result = connection.cursor().execute("SELECT * FROM test ORDER BY number")
50+
assert result.fetchall() == [("spam", 0), ("ham", 1)]
51+
52+
53+
@pytest.mark.parametrize("num_blocks", (1, 20))
54+
def test_write_sql_many_rows(num_blocks: int, temp_database: str):
55+
connection = sqlite3.connect(temp_database)
56+
connection.cursor().execute("CREATE TABLE test(id)")
57+
dataset = ray.data.range(1000).repartition(num_blocks)
58+
59+
dataset.write_sql(
60+
"INSERT INTO test VALUES(?)", lambda: sqlite3.connect(temp_database)
61+
)
62+
63+
result = connection.cursor().execute("SELECT * FROM test ORDER BY id")
64+
assert result.fetchall() == [(i,) for i in range(1000)]
65+
66+
67+
def test_write_sql_nonexistant_table(temp_database: str):
68+
dataset = ray.data.range(1)
69+
with pytest.raises(sqlite3.OperationalError):
70+
dataset.write_sql(
71+
"INSERT INTO test VALUES(?)", lambda: sqlite3.connect(temp_database)
72+
)

0 commit comments

Comments
 (0)