From 7bee01b10596280da6ffda331c5a30ff3310d56d Mon Sep 17 00:00:00 2001
From: Saugat Pachhai (सौगात)
Date: Thu, 12 Dec 2024 08:40:06 +0545
Subject: [PATCH] cli: improve startup time

This PR halves the total startup time from 2.8s to 1.4s, with minimal
effort. There are two more imports that could still be deferred, requests
and numpy, which would reduce the total time by another 0.3s. The total
time spent on imports is now 1.251s.

It's difficult to fix the remaining imports because we expose them as
top-level imports, and fixing them would require too much
duplication/effort or incur runtime costs.
---
 src/datachain/client/__init__.py  | 3 +--
 src/datachain/lib/dc.py           | 6 +++++-
 src/datachain/lib/file.py         | 3 ++-
 src/datachain/lib/meta_formats.py | 3 ++-
 src/datachain/query/dataset.py    | 5 ++++-
 5 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/datachain/client/__init__.py b/src/datachain/client/__init__.py
index e922c949a..08f71799c 100644
--- a/src/datachain/client/__init__.py
+++ b/src/datachain/client/__init__.py
@@ -1,4 +1,3 @@
 from .fsspec import Client
-from .s3 import ClientS3
 
-__all__ = ["Client", "ClientS3"]
+__all__ = ["Client"]
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 57a215eda..85f2618ca 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -19,7 +19,6 @@
 )
 
 import orjson
-import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
@@ -57,6 +56,7 @@
 from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
 
 if TYPE_CHECKING:
+    import pandas as pd
     from pyarrow import DataType as ArrowDataType
     from typing_extensions import Concatenate, ParamSpec, Self
 
@@ -1701,6 +1701,8 @@ def to_pandas(self, flatten=False) -> "pd.DataFrame":
         Parameters:
             flatten : Whether to use a multiindex or flatten column names.
         """
+        import pandas as pd
+
         headers, max_length = self._effective_signals_schema.get_headers_with_length()
         if flatten or max_length < 2:
             columns = [".".join(filter(None, header)) for header in headers]
@@ -1724,6 +1726,8 @@ def show(
             transpose : Whether to transpose rows and columns.
             truncate : Whether or not to truncate the contents of columns.
""" + import pandas as pd + dc = self.limit(limit) if limit > 0 else self # type: ignore[misc] df = dc.to_pandas(flatten) diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py index 10536eb04..35803e130 100644 --- a/src/datachain/lib/file.py +++ b/src/datachain/lib/file.py @@ -17,7 +17,6 @@ from fsspec.callbacks import DEFAULT_CALLBACK, Callback from PIL import Image -from pyarrow.dataset import dataset from pydantic import Field, field_validator from datachain.client.fileslice import FileSlice @@ -452,6 +451,8 @@ class ArrowRow(DataModel): @contextmanager def open(self): """Stream row contents from indexed file.""" + from pyarrow.dataset import dataset + if self.file._caching_enabled: self.file.ensure_cached() path = self.file.get_local_path() diff --git a/src/datachain/lib/meta_formats.py b/src/datachain/lib/meta_formats.py index 70473557a..d2fc40613 100644 --- a/src/datachain/lib/meta_formats.py +++ b/src/datachain/lib/meta_formats.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Callable -import datamodel_code_generator import jmespath as jsp from pydantic import BaseModel, ConfigDict, Field, ValidationError # noqa: F401 @@ -67,6 +66,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None): data_type = "json" # treat json line as plain JSON in auto-schema data_string = json.dumps(json_object) + import datamodel_code_generator + input_file_types = {i.value: i for i in datamodel_code_generator.InputFileType} input_file_type = input_file_types[data_type] with tempfile.TemporaryDirectory() as tmpdir: diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 46058ba83..567156cb8 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -35,7 +35,6 @@ from sqlalchemy.sql.selectable import Select from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper -from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog from datachain.data_storage.schema import ( PARTITION_COLUMN_ID, partition_col_names, @@ -394,6 +393,8 @@ def create_result_query( """ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: + from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE + use_partitioning = self.partition_by is not None batching = self.udf.get_batching(use_partitioning) workers = self.workers @@ -1087,6 +1088,8 @@ def get_table() -> "TableClause": def delete( name: str, version: Optional[int] = None, catalog: Optional["Catalog"] = None ) -> None: + from datachain.catalog import get_catalog + catalog = catalog or get_catalog() version = version or catalog.get_dataset(name).latest_version catalog.remove_dataset(name, version)