Skip to content

Commit

Permalink
Fix mutate() (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmpetrov authored Jul 18, 2024
1 parent d244943 commit d6a3aef
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 17 deletions.
8 changes: 8 additions & 0 deletions examples/llm-claude-simple-query.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,12 @@ class Rating(BaseModel):
)
)

# Derive new signals in-database (vectorized; no data download) — see DataChain.mutate().
chain = chain.settings(parallel=13).mutate(
    x=Column("file.name"),
    y=Column("rating.status"),
    # Rough cost estimate: output tokens scaled by a per-token price.
    price=Column("claude.usage.output_tokens") * 0.0072,
)

# chain.print_schema()

chain.show()
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
String,
)

TYPE_TO_DATACHAIN = {
PYTHON_TO_SQL = {
int: Int64,
str: String,
Literal: String,
Expand All @@ -34,14 +34,14 @@
}


def convert_to_db_type(typ): # noqa: PLR0911
def python_to_sql(typ): # noqa: PLR0911
if inspect.isclass(typ):
if issubclass(typ, SQLType):
return typ
if issubclass(typ, Enum):
return str

res = TYPE_TO_DATACHAIN.get(typ)
res = PYTHON_TO_SQL.get(typ)
if res:
return res

Expand All @@ -59,19 +59,19 @@ def convert_to_db_type(typ): # noqa: PLR0911
if ModelStore.is_pydantic(args0):
return Array(JSON())

next_type = convert_to_db_type(args0)
next_type = python_to_sql(args0)
return Array(next_type)

if orig is Annotated:
# Ignoring annotations
return convert_to_db_type(args[0])
return python_to_sql(args[0])

if inspect.isclass(orig) and issubclass(dict, orig):
return JSON

if orig == Union:
if len(args) == 2 and (type(None) in args):
return convert_to_db_type(args[0])
return python_to_sql(args[0])

if _is_json_inside_union(orig, args):
return JSON
Expand Down
23 changes: 23 additions & 0 deletions src/datachain/lib/convert/sql_to_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from datetime import datetime
from typing import Any

from sqlalchemy import ARRAY, JSON, Boolean, DateTime, Float, Integer, String

from datachain.data_storage.sqlite import Column

# Mapping from SQLAlchemy column types to the Python types they carry.
SQL_TO_PYTHON = {
    String: str,
    Integer: int,
    Float: float,
    Boolean: bool,
    DateTime: datetime,
    ARRAY: list,
    JSON: dict,
}


def sql_to_python(args_map: dict[str, Column]) -> dict[str, Any]:
    """Resolve each named column expression to its Python type.

    Unknown SQL types deliberately fall back to ``str``.
    """
    resolved: dict[str, Any] = {}
    for name, col in args_map.items():
        resolved[name] = SQL_TO_PYTHON.get(type(col.type), str)  # type: ignore[union-attr]
    return resolved
29 changes: 29 additions & 0 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,35 @@ def select_except(self, *args: str) -> "Self":
chain.signals_schema = new_schema
return chain

@detach
def mutate(self, **kwargs) -> "Self":
    """Create new signals based on existing signals.

    Mutation runs vectorized inside the warehouse, so it is more
    efficient than map() and never extracts or downloads rows from the
    internal database. Only predefined built-in functions (and
    combinations of them) can be used.

    The supported functions:
    Numerical: +, -, *, /, rand(), avg(), count(), func(),
               greatest(), least(), max(), min(), sum()
    String:    length(), split()
    Filename:  name(), parent(), file_stem(), file_ext()
    Array:     length(), sip_hash_64(), euclidean_distance(),
               cosine_distance()

    Example:
    ```py
    dc.mutate(
        area=Column("image.height") * Column("image.width"),
        extension=file_ext(Column("file.name")),
        dist=cosine_distance(embedding_text, embedding_image)
    )
    ```
    """
    # Build the mutated chain first, then attach a schema that reflects
    # the Python types of the newly derived columns.
    mutated = super().mutate(**kwargs)
    mutated.signals_schema = self.signals_schema.mutate(kwargs)
    return mutated

def iterate_flatten(self) -> Iterator[tuple[Any]]: # noqa: D102
db_signals = self.signals_schema.db_signals()
with super().select(*db_signals).as_iterable() as rows:
Expand Down
17 changes: 10 additions & 7 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from typing_extensions import Literal as LiteralEx

from datachain.lib.convert.flatten import DATACHAIN_TO_TYPE
from datachain.lib.convert.type_converter import convert_to_db_type
from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.convert.sql_to_python import sql_to_python
from datachain.lib.convert.unflatten import unflatten_to_json_pos
from datachain.lib.data_model import DataModel, DataType
from datachain.lib.file import File
Expand Down Expand Up @@ -102,14 +103,13 @@ def _init_setup_values(self):
@staticmethod
def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
signals: dict[str, DataType] = {}
for field, type_ in col_types.items():
type_ = DATACHAIN_TO_TYPE.get(type_, None)
if type_ is None:
for field, col_type in col_types.items():
if (py_type := DATACHAIN_TO_TYPE.get(col_type, None)) is None:
raise SignalSchemaError(
f"signal schema cannot be obtained for column '{field}':"
f" unsupported type '{type_}'"
f" unsupported type '{py_type}'"
)
signals[field] = type_
signals[field] = py_type
return SignalSchema(signals)

def serialize(self) -> dict[str, str]:
Expand Down Expand Up @@ -161,7 +161,7 @@ def to_udf_spec(self) -> dict[str, type]:
continue
if not has_subtree:
db_name = DEFAULT_DELIMITER.join(path)
res[db_name] = convert_to_db_type(type_)
res[db_name] = python_to_sql(type_)
return res

def row_to_objs(self, row: Sequence[Any]) -> list[DataType]:
Expand Down Expand Up @@ -278,6 +278,9 @@ def clone_without_file_signals(self) -> "SignalSchema":
del schema[signal]
return SignalSchema(schema)

def mutate(self, args_map: dict) -> "SignalSchema":
    """Return a new schema: current signals extended/overridden by the
    Python types inferred from the mutated column expressions."""
    derived = sql_to_python(args_map)
    return SignalSchema({**self.values, **derived})

def merge(
self,
right_schema: "SignalSchema",
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
from pydantic import BaseModel

from datachain import Column
from datachain.lib.dc import C, DataChain, Sys
from datachain.lib.file import File
from datachain.lib.signal_schema import (
Expand Down Expand Up @@ -899,3 +900,15 @@ def test_to_pandas_multi_level():
assert "nnn" in df["t1"].columns
assert "count" in df["t1"].columns
assert df["t1"]["count"].tolist() == [3, 5, 1]


def test_mutate():
    # Derive one numeric and one string signal from existing columns.
    mutated = DataChain.from_values(t1=features).mutate(
        circle=2 * 3.14 * Column("t1.count"), place="pref_" + Column("t1.nnn")
    )

    schema = mutated.signals_schema.values
    assert schema["circle"] is float
    assert schema["place"] is str

    # Values are computed in-database; compare against the Python-side math.
    expected = [fr.count * 2 * 3.14 for fr in features]
    np.testing.assert_allclose(mutated.collect_one("circle"), expected)
8 changes: 4 additions & 4 deletions tests/unit/lib/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pydantic import BaseModel

from datachain.lib.convert.type_converter import convert_to_db_type
from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.sql.types import JSON, Array, String


Expand All @@ -29,7 +29,7 @@ class MyFeature(BaseModel):
),
)
def test_convert_type_to_datachain(typ, expected):
assert convert_to_db_type(typ) == expected
assert python_to_sql(typ) == expected


@pytest.mark.parametrize(
Expand All @@ -41,7 +41,7 @@ def test_convert_type_to_datachain(typ, expected):
),
)
def test_convert_type_to_datachain_array(typ, expected):
assert convert_to_db_type(typ).to_dict() == expected.to_dict()
assert python_to_sql(typ).to_dict() == expected.to_dict()


@pytest.mark.parametrize(
Expand All @@ -55,4 +55,4 @@ def test_convert_type_to_datachain_array(typ, expected):
)
def test_convert_type_to_datachain_error(typ):
with pytest.raises(TypeError):
convert_to_db_type(typ)
python_to_sql(typ)

0 comments on commit d6a3aef

Please sign in to comment.