Skip to content

Commit

Permalink
Fixes settings (#397)
Browse files Browse the repository at this point in the history
* fixes settings

* drop settings from listings

* is_pydantic check for class

---------

Co-authored-by: Matt Seddon <37993418+mattseddon@users.noreply.github.com>
  • Loading branch information
Dave Berenbaum and mattseddon authored Sep 9, 2024
1 parent 7d74a46 commit 823f9f8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
21 changes: 17 additions & 4 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def from_storage(
.save(list_dataset_name, listing=True)
)

dc = cls.from_dataset(list_dataset_name, session=session)
dc = cls.from_dataset(list_dataset_name, session=session, settings=settings)
dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})

return ls(dc, list_path, recursive=recursive, object_name=object_name)
Expand All @@ -426,6 +426,7 @@ def from_dataset(
name: str,
version: Optional[int] = None,
session: Optional[Session] = None,
settings: Optional[dict] = None,
) -> "DataChain":
"""Get data from a saved Dataset. It returns the chain itself.
Expand All @@ -438,7 +439,7 @@ def from_dataset(
chain = DataChain.from_dataset("my_cats")
```
"""
return DataChain(name=name, version=version, session=session)
return DataChain(name=name, version=version, session=session, settings=settings)

@classmethod
def from_json(
Expand Down Expand Up @@ -1622,6 +1623,8 @@ def from_csv(
model_name: str = "",
source: bool = True,
nrows=None,
session: Optional[Session] = None,
settings: Optional[dict] = None,
**kwargs,
) -> "DataChain":
"""Generate chain from csv files.
Expand All @@ -1638,6 +1641,8 @@ def from_csv(
model_name : Generated model name.
source : Whether to include info about the source file.
nrows : Optional row limit.
session : Session to use for the chain.
settings : Settings to use for the chain.
Example:
Reading a csv file:
Expand All @@ -1654,7 +1659,9 @@ def from_csv(
from pyarrow.csv import ConvertOptions, ParseOptions, ReadOptions
from pyarrow.dataset import CsvFileFormat

chain = DataChain.from_storage(path, **kwargs)
chain = DataChain.from_storage(
path, session=session, settings=settings, **kwargs
)

column_names = None
if not header:
Expand Down Expand Up @@ -1701,6 +1708,8 @@ def from_parquet(
object_name: str = "",
model_name: str = "",
source: bool = True,
session: Optional[Session] = None,
settings: Optional[dict] = None,
**kwargs,
) -> "DataChain":
"""Generate chain from parquet files.
Expand All @@ -1713,6 +1722,8 @@ def from_parquet(
object_name : Created object column name.
model_name : Generated model name.
source : Whether to include info about the source file.
session : Session to use for the chain.
settings : Settings to use for the chain.
Example:
Reading a single file:
Expand All @@ -1725,7 +1736,9 @@ def from_parquet(
dc = DataChain.from_parquet("s3://mybucket/dir")
```
"""
chain = DataChain.from_storage(path, **kwargs)
chain = DataChain.from_storage(
path, session=session, settings=settings, **kwargs
)
return chain.parse_tabular(
output=output,
object_name=object_name,
Expand Down
7 changes: 6 additions & 1 deletion src/datachain/lib/model_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import logging
from typing import ClassVar, Optional

Expand Down Expand Up @@ -69,7 +70,11 @@ def remove(cls, fr: type) -> None:

@staticmethod
def is_pydantic(val):
return not hasattr(val, "__origin__") and issubclass(val, BaseModel)
return (
not hasattr(val, "__origin__")
and inspect.isclass(val)
and issubclass(val, BaseModel)
)

@staticmethod
def to_pydantic(val) -> Optional[type[BaseModel]]:
Expand Down

0 comments on commit 823f9f8

Please sign in to comment.