From 8a0d881986486d053cd36d41b5c0e862120e941b Mon Sep 17 00:00:00 2001
From: Adam Glustein
Date: Tue, 7 May 2024 17:29:26 -0400
Subject: [PATCH] Remove all caching code from CSP (#213)

Signed-off-by: Adam Glustein
---
 csp/__init__.py | 3 +-
 csp/baselib.py | 5 +-
 csp/cache_support.py | 16 -
 csp/impl/config.py | 51 -
 csp/impl/managed_dataset/__init__.py | 0
 .../aggregation_period_utils.py | 87 -
 .../cache_partition_argument_serializer.py | 101 -
 .../cache_user_custom_object_serializer.py | 11 -
 csp/impl/managed_dataset/dataset_metadata.py | 85 -
 .../managed_dataset/dateset_name_constants.py | 4 -
 csp/impl/managed_dataset/datetime_utils.py | 8 -
 csp/impl/managed_dataset/managed_dataset.py | 487 ----
 .../managed_dataset_lock_file_util.py | 111 -
 .../managed_dataset_merge_utils.py | 431 ---
 .../managed_dataset_path_resolver.py | 470 ----
 .../managed_dataset/managed_parquet_writer.py | 340 ---
 csp/impl/mem_cache.py | 14 +-
 csp/impl/types/instantiation_type_resolver.py | 29 +-
 csp/impl/wiring/base_parser.py | 17 +-
 csp/impl/wiring/cache_support/__init__.py | 0
 .../cache_support/cache_config_resolver.py | 22 -
 .../wiring/cache_support/cache_type_mapper.py | 55 -
 .../dataset_partition_cached_data.py | 662 -----
 .../wiring/cache_support/graph_building.py | 745 -----
 .../partition_files_container.py | 99 -
 .../cache_support/runtime_cache_manager.py | 74 -
 csp/impl/wiring/context.py | 19 +-
 csp/impl/wiring/graph.py | 320 +--
 csp/impl/wiring/graph_parser.py | 72 +-
 csp/impl/wiring/node.py | 35 +-
 csp/impl/wiring/node_parser.py | 14 +-
 csp/impl/wiring/outputs.py | 15 -
 csp/impl/wiring/runtime.py | 37 +-
 csp/impl/wiring/signature.py | 1 -
 csp/impl/wiring/special_output_names.py | 3 -
 csp/impl/wiring/threaded_runtime.py | 4 -
 csp/tests/impl/test_struct.py | 305 ++-
 csp/tests/test_caching.py | 2438 -----------------
 csp/tests/test_engine.py | 27 +-
 csp/tests/test_parsing.py | 33 +-
 40 files changed, 261 insertions(+), 6989 deletions(-)
 delete mode 100644 csp/cache_support.py
 delete mode 100644 csp/impl/config.py
 delete mode 100644 csp/impl/managed_dataset/__init__.py
 delete mode 100644 csp/impl/managed_dataset/aggregation_period_utils.py
 delete mode 100644 csp/impl/managed_dataset/cache_partition_argument_serializer.py
 delete mode 100644 csp/impl/managed_dataset/cache_user_custom_object_serializer.py
 delete mode 100644 csp/impl/managed_dataset/dataset_metadata.py
 delete mode 100644 csp/impl/managed_dataset/dateset_name_constants.py
 delete mode 100644 csp/impl/managed_dataset/datetime_utils.py
 delete mode 100644 csp/impl/managed_dataset/managed_dataset.py
 delete mode 100644 csp/impl/managed_dataset/managed_dataset_lock_file_util.py
 delete mode 100644 csp/impl/managed_dataset/managed_dataset_merge_utils.py
 delete mode 100644 csp/impl/managed_dataset/managed_dataset_path_resolver.py
 delete mode 100644 csp/impl/managed_dataset/managed_parquet_writer.py
 delete mode 100644 csp/impl/wiring/cache_support/__init__.py
 delete mode 100644 csp/impl/wiring/cache_support/cache_config_resolver.py
 delete mode 100644 csp/impl/wiring/cache_support/cache_type_mapper.py
 delete mode 100644 csp/impl/wiring/cache_support/dataset_partition_cached_data.py
 delete mode 100644 csp/impl/wiring/cache_support/graph_building.py
 delete mode 100644 csp/impl/wiring/cache_support/partition_files_container.py
 delete mode 100644 csp/impl/wiring/cache_support/runtime_cache_manager.py
 delete mode 100644 csp/tests/test_caching.py

diff --git a/csp/__init__.py b/csp/__init__.py
index b7cc55c1..05a05a68 100644
--- a/csp/__init__.py
+++
b/csp/__init__.py @@ -4,7 +4,6 @@ from csp.curve import curve from csp.dataframe import DataFrame from csp.impl.builtin_functions import * -from csp.impl.config import Config from csp.impl.constants import UNSET from csp.impl.enum import DynamicEnum, Enum from csp.impl.error_handling import set_print_full_exception_stack @@ -30,7 +29,7 @@ from csp.math import * from csp.showgraph import show_graph -from . import cache_support, stats +from . import stats __version__ = "0.0.3" diff --git a/csp/baselib.py b/csp/baselib.py index 09052044..fb2593c8 100644 --- a/csp/baselib.py +++ b/csp/baselib.py @@ -283,10 +283,7 @@ def get_basket_field(dict_basket: {"K": ts["V"]}, field_name: str) -> OutputBask :param field_name: :return: """ - if isinstance(dict_basket, csp.impl.wiring.cache_support.graph_building.WrappedCachedStructBasket): - return dict_basket.get_basket_field(field_name) - else: - return {k: getattr(v, field_name) for k, v in dict_basket.items()} + return {k: getattr(v, field_name) for k, v in dict_basket.items()} @node(cppimpl=_cspbaselibimpl.sample) diff --git a/csp/cache_support.py b/csp/cache_support.py deleted file mode 100644 index 1ff9ab5a..00000000 --- a/csp/cache_support.py +++ /dev/null @@ -1,16 +0,0 @@ -from csp.impl.config import BaseCacheConfig, CacheCategoryConfig, CacheConfig -from csp.impl.managed_dataset.cache_user_custom_object_serializer import CacheObjectSerializer -from csp.impl.managed_dataset.dataset_metadata import TimeAggregation -from csp.impl.wiring import GraphCacheOptions, NoCachedDataException -from csp.impl.wiring.cache_support.cache_config_resolver import CacheConfigResolver - -__all__ = [ - "BaseCacheConfig", - "CacheCategoryConfig", - "CacheConfig", - "CacheConfigResolver", - "CacheObjectSerializer", - "GraphCacheOptions", - "NoCachedDataException", - "TimeAggregation", -] diff --git a/csp/impl/config.py b/csp/impl/config.py deleted file mode 100644 index a44145cc..00000000 --- a/csp/impl/config.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Dict, List - -from csp.impl.managed_dataset.cache_user_custom_object_serializer import CacheObjectSerializer -from csp.impl.struct import Struct -from csp.utils.file_permissions import FilePermissions, RWXPermissions - - -class BaseCacheConfig(Struct): - data_folder: str - read_folders: List[str] # Additional read folders from which the data should be read if available - lock_file_permissions: FilePermissions = FilePermissions( - user_permissions=RWXPermissions.READ | RWXPermissions.WRITE, - group_permissions=RWXPermissions.READ | RWXPermissions.WRITE, - others_permissions=RWXPermissions.READ | RWXPermissions.WRITE, - ) - data_file_permissions: FilePermissions = FilePermissions( - user_permissions=RWXPermissions.READ | RWXPermissions.WRITE, - group_permissions=RWXPermissions.READ, - others_permissions=RWXPermissions.READ, - ) - merge_existing_files: bool = True - - -class CacheCategoryConfig(BaseCacheConfig): - category: List[str] - - -class CacheConfig(BaseCacheConfig): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if not hasattr(self, "cache_serializers"): - self.cache_serializers = {} - - allow_overwrite: bool - # An optional override of output folders by category - # For example: - # category_overrides = [ - # CacheCategoryConfig(category=['forecasts'], data_folder='possibly_group_cached_forecasts_path'), - # CacheCategoryConfig(category=['forecasts', 'active_research'], data_folder='possibly_user_specific_forecasts_paths'), - # ] - # All forecasts except for forecasts that 
are under active_research will be read from/written to possibly_group_cached_forecasts_path. - # It would commonly be a path that is shared by the research team. On the other hand all forecasts under active_research will be written - # to possibly_user_specific_forecasts_paths which can be a private path of the current user that currently researching the forecast and - # needs to redump it often - it's not ready to share with the team yet. - category_overrides: List[CacheCategoryConfig] - graph_overrides: Dict[object, BaseCacheConfig] - cache_serializers: Dict[type, CacheObjectSerializer] - - -class Config(Struct): - cache_config: CacheConfig diff --git a/csp/impl/managed_dataset/__init__.py b/csp/impl/managed_dataset/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/csp/impl/managed_dataset/aggregation_period_utils.py b/csp/impl/managed_dataset/aggregation_period_utils.py deleted file mode 100644 index bf85d805..00000000 --- a/csp/impl/managed_dataset/aggregation_period_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -import datetime -import glob -import os - -from csp.impl.managed_dataset.dataset_metadata import TimeAggregation - - -class AggregationPeriodUtils: - _AGG_LEVELS_GLOB_EXPRESSIONS = { - TimeAggregation.DAY: ["[0-9]" * 4, "[0-9]" * 2, "[0-9]" * 2], - TimeAggregation.MONTH: ["[0-9]" * 4, "[0-9]" * 2], - TimeAggregation.QUARTER: ["[0-9]" * 4, "Q[0-9]"], - TimeAggregation.YEAR: ["[0-9]" * 4], - } - - def __init__(self, aggregation_period: TimeAggregation): - self._aggregation_period = aggregation_period - - def resolve_period_start(self, cur_time: datetime.datetime): - if self._aggregation_period == TimeAggregation.DAY: - return datetime.datetime(cur_time.year, cur_time.month, cur_time.day) - elif self._aggregation_period == TimeAggregation.MONTH: - return datetime.datetime(cur_time.year, cur_time.month, 1) - elif self._aggregation_period == TimeAggregation.QUARTER: - return datetime.datetime(cur_time.year, ((cur_time.month - 1) // 3) * 3 + 1, 1) - elif self._aggregation_period == TimeAggregation.YEAR: - return datetime.datetime(cur_time.year, 1, 1) - else: - raise RuntimeError(f"Unsupported aggregation period {self._aggregation_period}") - - def resolve_period_end(self, cur_time: datetime.datetime, exclusive_end=True): - if self._aggregation_period == TimeAggregation.DAY: - res = datetime.datetime(cur_time.year, cur_time.month, cur_time.day) + datetime.timedelta(days=1) - elif self._aggregation_period == TimeAggregation.MONTH: - next_month_date = cur_time + datetime.timedelta(days=32 - cur_time.day) - res = datetime.datetime(next_month_date.year, next_month_date.month, 1) - elif self._aggregation_period == TimeAggregation.QUARTER: - extra_months = (3 - cur_time.month) % 3 - next_quarter_date = cur_time + datetime.timedelta(days=31 * extra_months + 32 - cur_time.day) - res = datetime.datetime(next_quarter_date.year, next_quarter_date.month, 1) - elif self._aggregation_period == TimeAggregation.YEAR: - res = datetime.datetime(cur_time.year + 1, 1, 1) - else: - raise RuntimeError(f"Unsupported aggregation period {self._aggregation_period}") - if not exclusive_end: - res -= datetime.timedelta(microseconds=1) - return res - - def resolve_period_start_end(self, cur_time: datetime.datetime, exclusive_end=True): - return self.resolve_period_start(cur_time), self.resolve_period_end(cur_time, exclusive_end=exclusive_end) - - def get_sub_folder_name(self, cur_time: datetime.datetime): - if self._aggregation_period == TimeAggregation.DAY: - return 
cur_time.strftime("%Y/%m/%d") - elif self._aggregation_period == TimeAggregation.MONTH: - return cur_time.strftime("%Y/%m") - elif self._aggregation_period == TimeAggregation.QUARTER: - quarter_index = (cur_time.month - 1) // 3 + 1 - return cur_time.strftime(f"%Y/Q{quarter_index}") - elif self._aggregation_period == TimeAggregation.YEAR: - return cur_time.strftime("%Y") - else: - raise RuntimeError(f"Unsupported aggregation period {self._aggregation_period}") - - def iterate_periods_in_date_range(self, start_time: datetime.datetime, end_time: datetime.datetime): - assert start_time <= end_time - period_start, period_end = self.resolve_period_start_end(start_time) - while period_start <= end_time: - yield period_start, period_end - period_start = period_end - period_end = self.resolve_period_end(period_start) - - def get_agg_bound_folder(self, root_folder: str, is_starttime: bool): - """Return the first/last partition folder for the dataset - :param root_folder: - :param is_starttime: - :return: - """ - glob_expressions = self._AGG_LEVELS_GLOB_EXPRESSIONS[self._aggregation_period] - cur = root_folder - ind = 0 if is_starttime else -1 - for glob_exp in glob_expressions: - cur_list = sorted(glob.glob(os.path.join(glob.escape(cur), glob_exp))) - if not cur_list: - return None - cur = os.path.join(cur, cur_list[ind]) - return cur diff --git a/csp/impl/managed_dataset/cache_partition_argument_serializer.py b/csp/impl/managed_dataset/cache_partition_argument_serializer.py deleted file mode 100644 index 2935b1f6..00000000 --- a/csp/impl/managed_dataset/cache_partition_argument_serializer.py +++ /dev/null @@ -1,101 +0,0 @@ -import hashlib -import io -import ruamel.yaml -from abc import ABCMeta, abstractmethod - -from csp.impl.struct import Struct - - -class SerializedArgument: - def __init__(self, arg, serializer): - self._arg = arg - self._serializer = serializer - self._arg_as_string = None - self._arg_as_dict = None - self._arg_as_yaml_string = None - - def __str__(self): - return self.arg_as_string - - @property - def arg(self): - return self._arg - - @property - def arg_as_string(self): - if self._arg_as_string is None: - self._arg_as_string = self._serializer.to_string(self) - return self._arg_as_string - - @property - def arg_as_yaml_string(self): - if self._arg_as_yaml_string is None: - yaml = ruamel.yaml.YAML() - string_io = io.StringIO() - yaml.dump(self.arg_as_dict, string_io) - self._arg_as_yaml_string = string_io.getvalue() - return self._arg_as_yaml_string - - @property - def arg_as_dict(self): - if self._arg_as_dict is None: - self._arg_as_dict = self._serializer.to_json_dict(self) - return self._arg_as_dict - - -class CachePartitionArgumentSerializer(metaclass=ABCMeta): - @abstractmethod - def to_json_dict(self, value: SerializedArgument): - """ - :param value: The value to serialize - :returns: Should return a dict that will be written to yaml file - """ - raise NotImplementedError() - - @abstractmethod - def from_json_dict(self, value): - """ - :param value: The dict that is read from yaml file - :returns: Should return the deserialized object - """ - raise NotImplementedError() - - @abstractmethod - def to_string(self, value: SerializedArgument): - """Serialize the given object to a string (this string will be the partition folder name) - - :param value: The value to serialize - :returns: Should return a string that will be the folder name - """ - raise NotImplementedError() - - def __call__(self, value): - return SerializedArgument(value, self) - - -class 
StructPartitionArgumentSerializer(CachePartitionArgumentSerializer): - def __init__(self, typ): - self._typ = typ - - def to_json_dict(self, value: SerializedArgument): - """ - :param value: The value to serialize - :returns: Should return a dict that will be written to yaml file - """ - assert isinstance(value.arg, self._typ) - return value.arg.to_dict() - - def from_json_dict(self, value) -> Struct: - """ - :param value: The dict that is read from yaml file - :returns: Should return the deserialized object - """ - return self._typ.from_dict(value) - - def to_string(self, value: SerializedArgument): - """Serialize the given object to a string (this string will be the partition folder name) - - :param value: The value to serialize - :returns: Should return a string that will be the folder name - """ - return f"struct_{hashlib.md5(value.arg_as_yaml_string.encode()).hexdigest()}" diff --git a/csp/impl/managed_dataset/cache_user_custom_object_serializer.py b/csp/impl/managed_dataset/cache_user_custom_object_serializer.py deleted file mode 100644 index ff637796..00000000 --- a/csp/impl/managed_dataset/cache_user_custom_object_serializer.py +++ /dev/null @@ -1,11 +0,0 @@ -from abc import ABCMeta, abstractmethod - - -class CacheObjectSerializer(metaclass=ABCMeta): - @abstractmethod - def serialize_to_bytes(self, value): - raise NotImplementedError - - @abstractmethod - def deserialize_from_bytes(self, value): - raise NotImplementedError diff --git a/csp/impl/managed_dataset/dataset_metadata.py b/csp/impl/managed_dataset/dataset_metadata.py deleted file mode 100644 index 136b3e64..00000000 --- a/csp/impl/managed_dataset/dataset_metadata.py +++ /dev/null @@ -1,85 +0,0 @@ -from enum import Enum, auto -from typing import Dict - -from csp.impl.struct import Struct -from csp.impl.wiring.cache_support.cache_type_mapper import CacheTypeMapper - - -class OutputType(Enum): - PARQUET = auto() - - -class TimeAggregation(Enum): - DAY = auto() - MONTH = auto() - QUARTER = auto() - YEAR = auto() - - -class DictBasketInfo(Struct): - key_type: object - value_type: object - - @classmethod - def _postprocess_dict_to_python(cls, d): - d["key_type"] = CacheTypeMapper.type_to_json(d["key_type"]) - d["value_type"] = CacheTypeMapper.type_to_json(d["value_type"]) - return d - - @classmethod - def _preprocess_dict_from_python(cls, d): - d["key_type"] = CacheTypeMapper.json_to_type(d["key_type"]) - d["value_type"] = CacheTypeMapper.json_to_type(d["value_type"]) - return d - - -class DatasetMetadata(Struct): - version: str = "1.0.0" - name: str - output_type: OutputType = OutputType.PARQUET - time_aggregation: TimeAggregation = TimeAggregation.DAY - columns: Dict[str, object] - dict_basket_columns: Dict[str, DictBasketInfo] - partition_columns: Dict[str, type] - timestamp_column_name: str - split_columns_to_files: bool = False - - @classmethod - def _postprocess_dict_to_python(cls, d): - output_type = d.get("output_type") - if output_type is not None: - d["output_type"] = output_type.name - time_aggregation = d.get("time_aggregation") - if time_aggregation is not None: - d["time_aggregation"] = time_aggregation.name - columns = d["columns"] - if columns: - d["columns"] = {k: CacheTypeMapper.type_to_json(v) for k, v in columns.items()} - - partition_columns = d.get("partition_columns") - if partition_columns: - d["partition_columns"] = {k: CacheTypeMapper.type_to_json(v) for k, v in partition_columns.items()} - - return d - - @classmethod - def _preprocess_dict_from_python(cls, d): - output_type = d.get("output_type") - if 
output_type is not None: - d["output_type"] = OutputType[output_type] - time_aggregation = d.get("time_aggregation") - if time_aggregation is not None: - d["time_aggregation"] = TimeAggregation[time_aggregation] - columns = d["columns"] - if columns: - d["columns"] = {k: CacheTypeMapper.json_to_type(v) for k, v in columns.items()} - partition_columns = d.get("partition_columns") - if partition_columns: - d["partition_columns"] = {k: CacheTypeMapper.json_to_type(v) for k, v in partition_columns.items()} - - return d - - @classmethod - def load_metadata(cls, file_path: str): - with open(file_path, "r") as f: - return DatasetMetadata.from_yaml(f.read()) diff --git a/csp/impl/managed_dataset/dateset_name_constants.py b/csp/impl/managed_dataset/dateset_name_constants.py deleted file mode 100644 index 3e67f812..00000000 --- a/csp/impl/managed_dataset/dateset_name_constants.py +++ /dev/null @@ -1,4 +0,0 @@ -class DatasetNameConstants: - UNNAMED_OUTPUT_NAME = "csp_unnamed_output" - CSP_TIMESTAMP = "csp_timestamp" - PARTITION_ARGUMENT_FILE_NAME = ".csp_argument_value" diff --git a/csp/impl/managed_dataset/datetime_utils.py b/csp/impl/managed_dataset/datetime_utils.py deleted file mode 100644 index e2c350be..00000000 --- a/csp/impl/managed_dataset/datetime_utils.py +++ /dev/null @@ -1,8 +0,0 @@ -from datetime import date, timedelta - -ONE_DAY_DELTA = timedelta(days=1) - - -def get_dates_in_range(start: date, end: date, inclusive_end=True): - n_days = (end - start).days + int(inclusive_end) - return [start + ONE_DAY_DELTA * i for i in range(n_days)] diff --git a/csp/impl/managed_dataset/managed_dataset.py b/csp/impl/managed_dataset/managed_dataset.py deleted file mode 100644 index dbf6d5f5..00000000 --- a/csp/impl/managed_dataset/managed_dataset.py +++ /dev/null @@ -1,487 +0,0 @@ -import logging -import os -import tempfile -from datetime import date, datetime, timedelta -from typing import Dict, List, Optional, Tuple, Union - -import csp -from csp.impl.config import BaseCacheConfig -from csp.impl.enum import Enum -from csp.impl.managed_dataset.cache_partition_argument_serializer import ( - SerializedArgument, - StructPartitionArgumentSerializer, -) -from csp.impl.managed_dataset.dataset_metadata import DatasetMetadata, DictBasketInfo, TimeAggregation -from csp.impl.managed_dataset.dateset_name_constants import DatasetNameConstants -from csp.impl.managed_dataset.managed_dataset_lock_file_util import LockContext, ManagedDatasetLockUtil -from csp.impl.managed_dataset.managed_dataset_path_resolver import DatasetPaths -from csp.impl.struct import Struct -from csp.utils.file_permissions import FilePermissions, apply_file_permissions, create_folder_with_permissions -from csp.utils.rm_utils import rm_file_or_folder - - -class _MetadataRWUtil: - def __init__(self, dataset, metadata_file_path, metadata, lock_file_permissions, data_file_permissions): - self._dataset = dataset - self._metadata = metadata - self._metadata_file_path = metadata_file_path - self._lock_file_util = ManagedDatasetLockUtil(lock_file_permissions) - self._data_file_permissions = data_file_permissions - - def _write_metadata(self): - locked_folder = os.path.dirname(self._metadata_file_path) - with self._lock_file_util.write_lock(locked_folder): - if os.path.exists(self._metadata_file_path): - return - - file_base_name = os.path.basename(self._metadata_file_path) - create_folder_with_permissions(locked_folder, self._dataset.cache_config.data_file_permissions) - with tempfile.NamedTemporaryFile(mode="w+", prefix=file_base_name, 
dir=locked_folder, delete=False) as f: - try: - yaml = self._metadata.to_yaml() - f.file.write(yaml) - f.file.flush() - apply_file_permissions(f.name, self._data_file_permissions) - os.rename(f.name, self._metadata_file_path) - except: - rm_file_or_folder(f.name) - raise - - def load_existing_or_store_metadata(self): - """Loads existing metadata if no metadata exists, also will store the current metadata to file - - :return: A tuple of (loaded_metadata, file_lock) where file lock locks the metadata folder for reading (the file lock is in acquired state) - """ - if not os.path.exists(self._metadata_file_path): - self._write_metadata() - - locked_folder = os.path.dirname(self._metadata_file_path) - read_lock = self._lock_file_util.read_lock(locked_folder) - read_lock.lock() - try: - with open(self._metadata_file_path, "r") as f: - existing_metadata = self._metadata.from_yaml(f.read()) - return existing_metadata, read_lock - except: - read_lock.unlock() - raise - - -class ManagedDatasetPartition: - """A single partition of a dataset, this basically represents the lowest level of the chain dataset->partition. - - Single partition corresponds to a single instance of partition values. For example if there is a dataset that is partitioned on columns - a:int, b:float, c:str then a single partition would for example correspond to (1, 1.0, 'str1') while (2,2.0, 'str2') would be a different - partition object. - Single partition corresponds to a single instance of graph. - """ - - # In the future we are going to support containers as well, for now just primitives - PARTITION_TYPE_STR_CONVERTORS = { - bool: str, - int: str, - float: str, - str: str, - datetime: lambda v: v.strftime("%Y%m%d_%H%M%S_%f"), - date: lambda v: v.strftime("%Y%m%d_%H%M%S_%f"), - timedelta: lambda v: f"td_{int(v.total_seconds() * 1e6)}us", - } - - def __init__(self, dataset, partition_values: Optional[Dict[str, object]] = None): - """ - :param dataset: An instance of ManagedDataset to which the partition belongs - :param partition_values: A dictionary of partition column name to value of the column for the given partition - """ - self._dataset = dataset - self._values_tuple, self._values_dict = self._normalize_partition_values(partition_values) - self._data_paths = None - - def get_data_for_period(self, starttime: datetime, endtime: datetime, missing_range_handler): - return self.data_paths.get_data_files_in_range( - starttime, - endtime, - missing_range_handler=missing_range_handler, - split_columns_to_files=self.dataset.metadata.split_columns_to_files, - ) - - @property - def data_paths(self): - if self._data_paths is None: - dataset_data_paths = self._dataset.data_paths - if dataset_data_paths: - self._data_paths = dataset_data_paths.get_partition_paths(self._values_dict) - return self._data_paths - - @property - def dataset(self): - return self._dataset - - @property - def value_tuple(self): - return self._values_tuple - - @property - def value_dict(self): - return self._values_dict - - def _create_folder_with_permissions(self, cur_root_path, folder_permissions): - if not os.path.exists(cur_root_path): - try: - os.mkdir(cur_root_path) - apply_file_permissions(cur_root_path, folder_permissions) - return True - except FileExistsError: - pass - return False - - def create_root_folder(self, cache_config): - if os.path.exists(self.data_paths.root_folder): - return - cur_root_path = self.dataset.data_paths.root_folder - rel_path = os.path.relpath(self.data_paths.root_folder, cur_root_path) - path_parts = list(filter(None, 
os.path.normpath(rel_path).split(os.sep))) - assert path_parts[0] == "data" - cur_root_path = os.path.join(cur_root_path, path_parts[0]) - - file_permissions = cache_config.data_file_permissions - folder_permissions = cache_config.data_file_permissions.get_folder_permissions() - - self._create_folder_with_permissions(cur_root_path, folder_permissions) - values_dict = self._values_dict - assert len(values_dict) + 1 == len(path_parts) - lock_util = ManagedDatasetLockUtil(cache_config.lock_file_permissions) - - with self.dataset.use_lock_context(): - for sub_folder, argument_value in zip(path_parts[1:], values_dict.values()): - cur_root_path = os.path.join(cur_root_path, sub_folder) - self._create_folder_with_permissions(cur_root_path, folder_permissions) - if isinstance(argument_value, SerializedArgument): - value_file_path = os.path.join(cur_root_path, DatasetNameConstants.PARTITION_ARGUMENT_FILE_NAME) - if not os.path.exists(value_file_path): - with lock_util.write_lock(value_file_path, is_lock_in_root_folder=True) as lock_file: - if not os.path.exists(value_file_path): - with open(value_file_path, "w") as value_file: - value_file.write(argument_value.arg_as_yaml_string) - apply_file_permissions(value_file_path, file_permissions) - rm_file_or_folder(lock_file.file_path, is_file=True) - - def publish_file(self, file_name, start_time, end_time, file_permissions=None, lock_file_permissions=None): - output_file_name = self.data_paths.get_output_file_name( - start_time, end_time, split_columns_to_files=self.dataset.metadata.split_columns_to_files - ) - # We might try to publish some files that are already there. Example - # We ran 20210101-20210102. We now run 20210102-20210103, since the data is not fully in cache we will run the graph, the data for 20210102 will be generated again. 
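# A minimal sketch of the idempotent-publish pattern described in the comment
# above: stage data in a temporary path, then make the final publish an atomic
# os.rename() under a write lock so that a period regenerated by an overlapping
# run can simply be discarded. Helper names here (publish_staged_file,
# acquire_write_lock) are hypothetical stand-ins, not csp APIs.
import os

def publish_staged_file(staged_path, final_path, acquire_write_lock):
    """acquire_write_lock(path) is assumed to return a context manager."""
    if os.path.exists(final_path):
        os.remove(staged_path)  # another run already published this period
        return
    with acquire_write_lock(final_path):
        if os.path.exists(final_path):  # re-check once the lock is held
            os.remove(staged_path)
        else:
            os.rename(staged_path, final_path)  # atomic on the same filesystem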
- if os.path.exists(output_file_name): - rm_file_or_folder(file_name) - return - - if file_permissions is not None: - if os.path.isdir(file_name): - folder_permissions = file_permissions.get_folder_permissions() - apply_file_permissions(file_name, folder_permissions) - for f in os.listdir(file_name): - apply_file_permissions(os.path.join(file_name, f), file_permissions) - else: - apply_file_permissions(file_name, file_permissions) - - with self.dataset.use_lock_context(): - lock_util = ManagedDatasetLockUtil(lock_file_permissions) - with lock_util.write_lock(output_file_name, is_lock_in_root_folder=True) as lock: - if os.path.exists(output_file_name): - logging.warning(f"Not publishing {output_file_name} since it already exists") - rm_file_or_folder(file_name) - else: - os.rename(file_name, output_file_name) - lock.delete_file() - - def merge_files(self, start_time: datetime, end_time: datetime, cache_config, parquet_output_config): - from csp.impl.managed_dataset.managed_dataset_merge_utils import SinglePartitionFileMerger - - with self.dataset.use_lock_context(): - file_merger = SinglePartitionFileMerger( - dataset_partition=self, - start_time=start_time, - end_time=end_time, - cache_config=cache_config, - parquet_output_config=parquet_output_config, - ) - file_merger.merge_files() - - def cleanup_unneeded_files(self, start_time: datetime, end_time: datetime, cache_config): - unused_files = self.data_paths.get_unused_files( - starttime=start_time, endtime=end_time, split_columns_to_files=self.dataset.metadata.split_columns_to_files - ) - if unused_files: - with self.dataset.use_lock_context(): - lock_util = ManagedDatasetLockUtil(cache_config.lock_file_permissions) - for f in unused_files: - try: - with lock_util.write_lock( - f, is_lock_in_root_folder=True, timeout_seconds=0, retry_period_seconds=0 - ) as lock: - rm_file_or_folder(f) - lock.delete_file() - except BlockingIOError: - logging.warning(f"Not removing {f} since it's currently locked") - - def partition_merge_lock(self, start_time: datetime, end_time: datetime): - raise NotImplementedError() - - def _get_type_convertor(self, typ): - if issubclass(typ, Enum): - return str - elif issubclass(typ, Struct): - return StructPartitionArgumentSerializer(typ) - else: - return self.PARTITION_TYPE_STR_CONVERTORS[typ] - - def _normalize_partition_values(self, partition_values): - metadata = self._dataset.metadata - if partition_values: - assert len(partition_values) == len(metadata.partition_columns) - assert partition_values.keys() == metadata.partition_columns.keys() - ordered_partition_values = ((k, partition_values[k]) for k in metadata.partition_columns) - partition_values = {k: self._get_type_convertor(type(v))(v) for k, v in ordered_partition_values} - values_tuple = tuple(partition_values.values()) - else: - assert not hasattr(metadata, "partition_columns") - values_tuple = tuple() - return values_tuple, partition_values - - -class ManagedDataset: - """A single dataset, this basically represents the highest level of the chain dataset->partition. - - Single dataset corresponds to a set of dataset_partitions all having identical schema but having different partition keys. - Example consider having cached trades for each ticker and date. Single dataset represents all the "trades" and has the trade - schema attached to it. Each partition will be part of the dataset but correspond to a different ticker. - Single dataset corresponds to a single "graph" function (defines paths and schemas for all instances of this graph). 
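    For illustration only (assumed example values, mirroring the "trades"/"ticker"
    example above, with cache_config being a pre-built CacheConfig), a dataset
    partitioned on a single ticker column would be created and partitioned as:

        dataset = ManagedDataset(
            "trades",
            timestamp_column_name="timestamp",
            columns_types={"price": float, "size": int},
            partition_columns={"ticker": str},
            cache_config=cache_config,
            split_columns_to_files=False,
            time_aggregation=TimeAggregation.DAY,
        )
        aapl_partition = dataset.get_partition({"ticker": "AAPL"})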
- """ - - SUPPORTED_PARTITION_TYPES = set(ManagedDatasetPartition.PARTITION_TYPE_STR_CONVERTORS.keys()) - - def __init__( - self, - name, - category: List[str] = None, - timestamp_column_name: str = None, - columns_types: Dict[str, object] = None, - partition_columns: Dict[str, type] = None, - *, - cache_config: BaseCacheConfig, - split_columns_to_files: Optional[bool], - time_aggregation: TimeAggregation, - dict_basket_column_types: Dict[str, Union[Tuple[type, type], DictBasketInfo]] = None, - ): - """ - :param name: The name of the dataset: - :param category: The category classification of the dataset, for example ['stats', 'daily'], or ['forecasts'], - this is being used as part of the path of the dataset on disk - :param timestamp_column_name: The name of the timestamp column in the parquet files. - :param columns_types: A dictionary of name->type of dataset column types. - :param partition_columns: A dictionary of partitioning columns of the dataset. This columns are not written into parquet files but instead - are used as part of the dataset partition path. - :param cache_config: The cache configuration for the data set - :param split_columns_to_files: A boolean that specifies whether the data of the dataset is split across files. - :param time_aggregation: The data aggregation period for the dataset - :param dict_basket_column_types: The dictionary basket columns of the dataset - """ - self._name = name - self._category = category if category else [] - self._cache_config = cache_config - self._lock_context = None - self._metadata = DatasetMetadata( - name=name, - split_columns_to_files=True if split_columns_to_files else False, - time_aggregation=time_aggregation, - ) - dict_basket_columns = self._normalize_dict_basket_types(dict_basket_column_types) - if dict_basket_columns: - self._metadata.dict_basket_columns = dict_basket_columns - if timestamp_column_name: - self._metadata.timestamp_column_name = timestamp_column_name - self._metadata.columns = columns_types if columns_types else {} - if partition_columns: - self._metadata.partition_columns = partition_columns - - self._data_paths: Optional[DatasetPaths] = None - self._set_folders(cache_config.data_folder, getattr(cache_config, "read_folders", None)) - - @classmethod - def _normalize_dict_basket_types(cls, dict_basket_column_types): - if not dict_basket_column_types: - return None - dict_types = {} - for name, type_entry in dict_basket_column_types.items(): - if isinstance(type_entry, DictBasketInfo): - dict_types[name] = type_entry - else: - key_type, value_type = type_entry - dict_types[name] = DictBasketInfo(key_type=key_type, value_type=value_type) - return dict_types - - @classmethod - def load_from_disk(cls, cache_config, name, data_category: Optional[List[str]] = None): - data_paths = DatasetPaths( - parent_folder=cache_config.data_folder, - read_folders=getattr(cache_config, "read_folders", None), - name=name, - data_category=data_category, - ) - metadata_file_path = data_paths.get_metadata_file_path(existing=True) - if metadata_file_path: - with open(metadata_file_path, "r") as f: - metadata = DatasetMetadata.from_yaml(f.read()) - res = ManagedDataset( - name=metadata.name, - category=data_category, - timestamp_column_name=metadata.timestamp_column_name, - columns_types=metadata.columns, - cache_config=cache_config, - split_columns_to_files=metadata.split_columns_to_files, - time_aggregation=metadata.time_aggregation, - dict_basket_column_types=getattr(metadata, "dict_basket_columns", None), - ) - if hasattr(metadata, 
"partition_columns"): - res.metadata.partition_columns = metadata.partition_columns - return res - else: - return None - - @classmethod - def is_supported_partition_type(cls, typ): - if typ in ManagedDataset.SUPPORTED_PARTITION_TYPES or ( - isinstance(typ, type) and (issubclass(typ, Enum) or issubclass(typ, Struct)) - ): - return True - else: - return False - - @property - def cache_config(self): - assert self._cache_config is not None - return self._cache_config - - def use_lock_context(self): - if self._lock_context is None: - self._lock_context = LockContext(self) - return ManagedDatasetLockUtil.set_dataset_context(self._lock_context) - - def validate_and_lock_metadata( - self, - lock_file_permissions: Optional[FilePermissions] = None, - data_file_permissions: Optional[FilePermissions] = None, - read: bool = False, - write: bool = False, - ): - """Validate that code metadata correspond to existing metadata on disk. If necessary writes metadata file to disk. - - :param lock_file_permissions: The permissions of the lock files that are created for safely accessing metadata. - :param data_file_permissions: The permissions of the written metadata files. - :param read: A bool that specifies whether the dataset will be read - :param write: A bool that specifies whether the dataset will be written. - - Note: validation for read vs written datasets will be different. For read datasets we allow slightly different more relaxed schemas. - - :return: An obtained "shared" lock that locks the dataset schema. Caller is responsible for releasing the lock - """ - assert self.data_paths is not None - with self.use_lock_context(): - metadata = self.metadata - metadata_file_path = self.data_paths.get_metadata_file_path(existing=not write) - metadata_rw_util = _MetadataRWUtil( - dataset=self, - metadata_file_path=metadata_file_path, - metadata=self.metadata, - lock_file_permissions=lock_file_permissions, - data_file_permissions=data_file_permissions, - ) - existing_meta, read_lock = metadata_rw_util.load_existing_or_store_metadata() - is_metadata_different = False - - if write: - is_metadata_different = existing_meta != self.metadata - else: - for field in DatasetMetadata.metadata(): - if field not in ("columns", "dict_basket_columns"): - if getattr(existing_meta, field, None) != getattr(self.metadata, field, None): - is_metadata_different = True - # The read metadata must be a subset of the existing metadata - existing_meta_columns = existing_meta.columns - existing_meta_dict_columns = getattr(existing_meta, "dict_basket_columns", None) - for col_name, col_type in metadata.columns.items(): - existing_type = existing_meta_columns.get(col_name, None) - if existing_type is None or existing_type != col_type: - is_metadata_different = True - break - cur_dict_basket_columns = getattr(metadata, "dict_basket_columns", None) - if cur_dict_basket_columns or existing_meta_dict_columns: - if cur_dict_basket_columns is None or existing_meta_dict_columns is None: - is_metadata_different = True - else: - for col_name, col_info in metadata.dict_basket_columns.items(): - if not existing_meta_dict_columns: - is_metadata_different = True - break - existing_meta_column_info = existing_meta_dict_columns.get(col_name) - if existing_meta_column_info is None: - is_metadata_different = True - break - existing_type = existing_meta_column_info.value_type - cur_type = col_info.value_type - - if issubclass(existing_type, csp.Struct): - if not issubclass(cur_type, csp.Struct): - is_metadata_different = True - break - existing_meta = 
existing_type.metadata() - for field, field_type in cur_type.metadata().items(): - if existing_meta.get(field) != field_type: - is_metadata_different = True - else: - is_metadata_different = existing_type is not cur_type - - if is_metadata_different: - read_lock.unlock() - raise RuntimeError( - f"Metadata mismatch at {metadata_file_path}\nCurrent:\n{metadata}\nExisting:{existing_meta}\n" - ) - return read_lock - - def get_partition(self, partition_values: Dict[str, object]): - """Get a partition object that corresponds to the given instance of partition key->value mapping. - :param partition_values: - """ - return ManagedDatasetPartition(self, partition_values) - - @property - def category(self): - return self._category - - @property - def parent_folder(self): - if self._data_paths is None: - return None - return self._data_paths.parent_folder - - def _set_folders(self, parent_folder, read_folders): - assert self._data_paths is None - if parent_folder: - self._data_paths = DatasetPaths( - parent_folder=parent_folder, - read_folders=read_folders, - name=self._name, - data_category=self._category, - time_aggregation=self.metadata.time_aggregation, - ) - else: - assert not read_folders, "Provided read folders without parent folder" - self._lock_context = None - - @property - def data_paths(self) -> DatasetPaths: - return self._data_paths - - @property - def metadata(self): - return self._metadata diff --git a/csp/impl/managed_dataset/managed_dataset_lock_file_util.py b/csp/impl/managed_dataset/managed_dataset_lock_file_util.py deleted file mode 100644 index d638cec6..00000000 --- a/csp/impl/managed_dataset/managed_dataset_lock_file_util.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import threading -import typing -from contextlib import contextmanager - -from csp.utils.file_permissions import create_folder_with_permissions -from csp.utils.lock_file import LockFile - -if typing.TYPE_CHECKING: - from .managed_dataset import ManagedDataset - - -class LockContext: - def __init__(self, data_set: "ManagedDataset"): - self._data_set = data_set - - def resolve_lock_file_path_and_create_folders(self, file_path: str, use_read_folders: bool): - parent_folder, lock_file_path = self._data_set.data_paths.resolve_lock_file_path( - file_path, use_read_folders=use_read_folders - ) - # We need to make sure that the root folder is created with the right permissions - create_folder_with_permissions(parent_folder, self._data_set.cache_config.data_file_permissions) - create_folder_with_permissions( - os.path.dirname(lock_file_path), self._data_set.cache_config.lock_file_permissions - ) - return lock_file_path - - -class ManagedDatasetLockUtil: - _READ_WRITE_LOCK_FILE_NAME = ".csp_read_write_lock" - _MERGE_LOCK_FILE_NAME = ".csp_merge_lock" - _TLS = threading.local() - - def __init__(self, lock_file_permissions): - self._lock_file_permissions = lock_file_permissions - - @classmethod - @contextmanager - def set_dataset_context(cls, lock_context: LockContext): - prev = getattr(cls._TLS, "instance", None) - try: - cls._TLS.instance = lock_context - yield lock_context - finally: - if prev is not None: - cls._TLS.instance = prev - else: - delattr(cls._TLS, "instance") - - @classmethod - def get_cur_context(cls): - res = getattr(cls._TLS, "instance", None) - if res is None: - raise RuntimeError("Trying to get lock context without any context set") - return res - - def _create_lock(self, file_path, lock_name, shared, is_lock_in_root_folder, timeout_seconds, retry_period_seconds): - cur_context = self.get_cur_context() - 
if os.path.isfile(file_path) or is_lock_in_root_folder: - base_path = os.path.splitext(os.path.basename(file_path))[0] - dir_name = os.path.dirname(file_path) - lock_file_name = f"{lock_name}.{base_path}" - return LockFile( - file_path=cur_context.resolve_lock_file_path_and_create_folders( - os.path.join(dir_name, lock_file_name), use_read_folders=shared - ), - shared=shared, - file_permissions=self._lock_file_permissions, - timeout_seconds=timeout_seconds, - retry_period_seconds=retry_period_seconds, - ) - else: - return LockFile( - file_path=cur_context.resolve_lock_file_path_and_create_folders( - os.path.join(file_path, lock_name), use_read_folders=shared - ), - shared=shared, - file_permissions=self._lock_file_permissions, - timeout_seconds=timeout_seconds, - retry_period_seconds=retry_period_seconds, - ) - - def write_lock(self, file_path, is_lock_in_root_folder=None, timeout_seconds=None, retry_period_seconds=None): - return self._create_lock( - file_path, - lock_name=self._READ_WRITE_LOCK_FILE_NAME, - shared=False, - is_lock_in_root_folder=is_lock_in_root_folder, - timeout_seconds=timeout_seconds, - retry_period_seconds=retry_period_seconds, - ) - - def read_lock(self, file_path, is_lock_in_root_folder=None, timeout_seconds=None, retry_period_seconds=None): - return self._create_lock( - file_path, - lock_name=self._READ_WRITE_LOCK_FILE_NAME, - shared=True, - is_lock_in_root_folder=is_lock_in_root_folder, - timeout_seconds=timeout_seconds, - retry_period_seconds=retry_period_seconds, - ) - - def merge_lock(self, file_path, is_lock_in_root_folder=None, timeout_seconds=None, retry_period_seconds=None): - return self._create_lock( - file_path, - lock_name=self._MERGE_LOCK_FILE_NAME, - shared=False, - is_lock_in_root_folder=is_lock_in_root_folder, - timeout_seconds=timeout_seconds, - retry_period_seconds=retry_period_seconds, - ) diff --git a/csp/impl/managed_dataset/managed_dataset_merge_utils.py b/csp/impl/managed_dataset/managed_dataset_merge_utils.py deleted file mode 100644 index ef2ffeb7..00000000 --- a/csp/impl/managed_dataset/managed_dataset_merge_utils.py +++ /dev/null @@ -1,431 +0,0 @@ -import datetime -import itertools -import os -import pytz -import tempfile -from typing import Optional - -import csp -from csp.adapters.output_adapters.parquet import ParquetOutputConfig, ParquetWriter -from csp.cache_support import CacheConfig -from csp.impl.managed_dataset.aggregation_period_utils import AggregationPeriodUtils -from csp.impl.managed_dataset.managed_dataset import ManagedDatasetPartition -from csp.impl.managed_dataset.managed_dataset_lock_file_util import ManagedDatasetLockUtil -from csp.utils.file_permissions import apply_file_permissions -from csp.utils.lock_file import MultipleFilesLock - - -def _pa(): - """ - Lazy import pyarrow - """ - import pyarrow - - return pyarrow - - -def _create_wip_file(output_folder, start_time, is_folder: Optional[bool] = False): - prefix = start_time.strftime("%Y%m%d_H%M%S_%f") if start_time else "merge_" - - if is_folder: - return tempfile.mkdtemp(dir=output_folder, suffix="_WIP", prefix=prefix) - else: - fd, cur_file_path = tempfile.mkstemp(dir=output_folder, suffix="_WIP", prefix=prefix) - os.close(fd) - return cur_file_path - - -class _SingleBasketMergeData: - def __init__(self, basket_name, basket_types, input_files, basket_data_input_files): - self.basket_data_input_files = basket_data_input_files - self.basket_name = basket_name - self.basket_types = basket_types - self.count_column_name = f"{basket_name}__csp_value_count" - if 
issubclass(self.basket_types.value_type, csp.Struct): - self.data_column_names = [f"{basket_name}.{c}" for c in self.basket_types.value_type.metadata()] - else: - self.data_column_names = [basket_name] - self.symbol_column_name = f"{basket_name}__csp_symbol" - self._cur_basket_data_row_group = None - self._cur_row_group_data_table = None - self._cur_row_group_symbol_table = None - self._cur_row_group_last_returned_index = int(-1) - - def _load_row_group(self, next_row_group_index=None): - if next_row_group_index is None: - next_row_group_index = self._cur_basket_data_row_group + 1 - self._cur_basket_data_row_group = next_row_group_index - do_iter = True - while do_iter: - if self._cur_basket_data_row_group < self.basket_data_input_files[self.data_column_names[0]].num_row_groups: - self._cur_row_group_data_tables = [ - self.basket_data_input_files[c].read_row_group(self._cur_basket_data_row_group) - for c in self.data_column_names - ] - self._cur_row_group_symbol_table = self.basket_data_input_files[self.symbol_column_name].read_row_group( - self._cur_basket_data_row_group - ) - if self._cur_row_group_data_tables[0].shape[0] > 0: - do_iter = False - else: - self._cur_basket_data_row_group += 1 - else: - self._cur_row_group_data_tables = None - self._cur_row_group_symbol_table = None - do_iter = False - self._cur_row_group_last_returned_index = int(-1) - return self._cur_row_group_data_tables is not None - - @property - def _num_remaining_rows_cur_chunk(self): - if not self._cur_row_group_data_tables: - return 0 - remaining_items_cur_group = ( - self._cur_row_group_data_tables[0].shape[0] - 1 - self._cur_row_group_last_returned_index - ) - return remaining_items_cur_group - - def _skip_rows(self, num_rows_to_skip): - remaining_items_cur_chunk = self._num_remaining_rows_cur_chunk - while num_rows_to_skip > 0: - if num_rows_to_skip >= remaining_items_cur_chunk: - num_rows_to_skip -= remaining_items_cur_chunk - assert self._load_row_group() or num_rows_to_skip == 0 - else: - self._cur_row_group_last_returned_index += int(num_rows_to_skip) - num_rows_to_skip = 0 - - def _iter_chunks(self, row_indices, full_column_tables): - count_table = full_column_tables[self.count_column_name].columns[0] - count_table_cum_sum = count_table.to_pandas().cumsum() - if self._cur_basket_data_row_group is None: - self._load_row_group(0) - - if row_indices is None: - if count_table_cum_sum.empty: - return - num_rows_to_return = int(count_table_cum_sum.iloc[-1]) - else: - if row_indices.size == 0: - if not count_table_cum_sum.empty: - self._skip_rows(count_table_cum_sum.iloc[-1]) - return - - num_rows_to_return = int(count_table_cum_sum[row_indices[-1]]) - if row_indices[0] != 0: - skipped_rows = int(count_table_cum_sum[row_indices[0] - 1]) - self._skip_rows(skipped_rows) - num_rows_to_return -= skipped_rows - - while num_rows_to_return > 0: - s_i = self._cur_row_group_last_returned_index + 1 - if num_rows_to_return < self._num_remaining_rows_cur_chunk: - e_i = s_i + num_rows_to_return - self._skip_rows(num_rows_to_return) - num_rows_to_return = 0 - yield (self._cur_row_group_symbol_table[s_i:e_i],) + tuple( - t[s_i:e_i] for t in self._cur_row_group_data_tables - ) - else: - num_read_rows = self._num_remaining_rows_cur_chunk - e_i = s_i + num_read_rows - num_rows_to_return -= num_read_rows - yield (self._cur_row_group_symbol_table[s_i:e_i],) + tuple( - t[s_i:e_i] for t in self._cur_row_group_data_tables - ) - assert self._load_row_group() or num_rows_to_return == 0 - - -class _MergeFileInfo(csp.Struct): - file_path: 
str - start_time: datetime.datetime - end_time: datetime.datetime - - -class SinglePartitionFileMerger: - def __init__( - self, - dataset_partition: ManagedDatasetPartition, - start_time, - end_time, - cache_config: CacheConfig, - parquet_output_config: ParquetOutputConfig, - ): - self._dataset_partition = dataset_partition - self._start_time = start_time - self._end_time = end_time - self._cache_config = cache_config - self._parquet_output_config = parquet_output_config.copy().resolve_compression() - # TODO: cleanup all reference to existing files and backup files - self._split_columns_to_files = getattr(dataset_partition.dataset.metadata, "split_columns_to_files", False) - self._aggregation_period_utils = AggregationPeriodUtils( - self._dataset_partition.dataset.metadata.time_aggregation - ) - - def _is_overwrite_allowed(self): - allow_overwrite = getattr(self._cache_config, "allow_overwrite", None) - if allow_overwrite is not None: - return allow_overwrite - allow_overwrite = getattr(self._parquet_output_config, "allow_overwrite", None) - return bool(allow_overwrite) - - def _resolve_merged_output_file_name(self, merge_candidates): - output_file_name = self._dataset_partition.data_paths.get_output_file_name( - start_time=merge_candidates[0].start_time, - end_time=merge_candidates[-1].end_time, - split_columns_to_files=self._split_columns_to_files, - ) - - return output_file_name - - def _iterate_file_chunks(self, file_name, start_cutoff=None): - dataset = self._dataset_partition.dataset - parquet_file = _pa().parquet.ParquetFile(file_name) - if start_cutoff: - for i in range(parquet_file.metadata.num_row_groups): - time_stamps = parquet_file.read_row_group(i, [dataset.metadata.timestamp_column_name])[ - dataset.metadata.timestamp_column_name - ].to_pandas() - row_indices = time_stamps.index.values[(time_stamps > pytz.utc.localize(start_cutoff))] - - if row_indices.size == 0: - continue - - full_table = parquet_file.read_row_group(i)[row_indices[0] : row_indices[-1] + 1] - yield full_table - else: - for i in range(parquet_file.metadata.num_row_groups): - yield parquet_file.read_row_group(i) - - def _iter_column_names(self, include_regular_columns=True, include_basket_data_columns=True): - dataset = self._dataset_partition.dataset - if include_regular_columns: - yield dataset.metadata.timestamp_column_name - for c in dataset.metadata.columns.keys(): - yield c - if hasattr(dataset.metadata, "dict_basket_columns"): - for c, t in dataset.metadata.dict_basket_columns.items(): - if include_regular_columns: - yield f"{c}__csp_value_count" - if include_basket_data_columns: - if issubclass(t.value_type, csp.Struct): - for field_name in t.value_type.metadata(): - yield f"{c}.{field_name}" - else: - yield c - yield f"{c}__csp_symbol" - - def _iter_column_files(self, folder, include_regular_columns=True, include_basket_data_columns=True): - for c in self._iter_column_names( - include_regular_columns=include_regular_columns, include_basket_data_columns=include_basket_data_columns - ): - yield c, os.path.join(folder, f"{c}.parquet") - - def _iterate_folder_chunks(self, file_name, start_cutoff=None): - dataset = self._dataset_partition.dataset - input_files = {} - for c, f in self._iter_column_files(file_name, include_basket_data_columns=False): - input_files[c] = _pa().parquet.ParquetFile(f) - - basket_data_input_files = {} - for c, f in self._iter_column_files(file_name, include_regular_columns=False): - basket_data_input_files[c] = _pa().parquet.ParquetFile(f) - - timestamp_column_reader = 
input_files[dataset.metadata.timestamp_column_name] - - basked_data = ( - { - k: _SingleBasketMergeData(k, v, input_files, basket_data_input_files) - for k, v in dataset.metadata.dict_basket_columns.items() - } - if getattr(dataset.metadata, "dict_basket_columns", None) - else {} - ) - - if start_cutoff: - for i in range(timestamp_column_reader.metadata.num_row_groups): - time_stamps = timestamp_column_reader.read_row_group(i, [dataset.metadata.timestamp_column_name])[ - dataset.metadata.timestamp_column_name - ].to_pandas() - row_indices = time_stamps.index.values[(time_stamps > pytz.utc.localize(start_cutoff))] - - full_column_tables = {} - truncated_column_tables = {} - for c in self._iter_column_names(include_basket_data_columns=False): - full_table = input_files[c].read_row_group(i) - full_column_tables[c] = full_table - if row_indices.size > 0: - truncated_column_tables[c] = full_table[row_indices[0] : row_indices[-1] + 1] - - if row_indices.size > 0: - yield ( - truncated_column_tables, - ( - v._iter_chunks(row_indices=row_indices, full_column_tables=full_column_tables) - for v in basked_data.values() - ), - ) - else: - for v in basked_data.values(): - assert ( - len(list(v._iter_chunks(row_indices=row_indices, full_column_tables=full_column_tables))) - == 0 - ) - else: - for i in range(timestamp_column_reader.metadata.num_row_groups): - truncated_column_tables = {} - for c in self._iter_column_names(include_basket_data_columns=False): - truncated_column_tables[c] = input_files[c].read_row_group(i) - yield ( - truncated_column_tables, - ( - v._iter_chunks(row_indices=None, full_column_tables=truncated_column_tables) - for v in basked_data.values() - ), - ) - - def _iterate_chunks(self, file_name, start_cutoff=None): - if self._dataset_partition.dataset.metadata.split_columns_to_files: - return self._iterate_folder_chunks(file_name, start_cutoff) - else: - return self._iterate_file_chunks(file_name, start_cutoff) - - def _iterate_merged_batches(self, merge_candidates): - iters = [] - # Here we need both start time and end time to be exclusive - start_cutoff = merge_candidates[0].start_time - datetime.timedelta(microseconds=1) - end_cutoff = merge_candidates[-1].end_time + datetime.timedelta(microseconds=1) - - for merge_candidate in merge_candidates: - merged_file_cutoff_start = None - if merge_candidate.start_time <= start_cutoff: - merged_file_cutoff_start = start_cutoff - assert end_cutoff > merge_candidate.end_time - iters.append(self._iterate_chunks(merge_candidate.file_path, start_cutoff=merged_file_cutoff_start)) - start_cutoff = merge_candidate.end_time - return itertools.chain(*iters) - - def _merged_data_folders(self, aggregation_folder, merge_candidates): - output_file_name = self._resolve_merged_output_file_name(merge_candidates) - - file_permissions = self._cache_config.data_file_permissions - folder_permission = file_permissions.get_folder_permissions() - - wip_file = _create_wip_file(aggregation_folder, start_time=None, is_folder=True) - apply_file_permissions(wip_file, folder_permission) - writers = {} - try: - for (column1, src_file_name), (column2, file_name) in zip( - self._iter_column_files(merge_candidates[0].file_path), self._iter_column_files(wip_file) - ): - assert column1 == column2 - schema = _pa().parquet.read_schema(src_file_name) - writers[column1] = _pa().parquet.ParquetWriter( - file_name, - schema=schema, - compression=self._parquet_output_config.compression, - version=ParquetWriter.PARQUET_VERSION, - ) - for batch, basket_batches in 
self._iterate_merged_batches(merge_candidates): - for column_name, values in batch.items(): - writers[column_name].write_table(values) - - for single_basket_column_batches in basket_batches: - for batch_columns in single_basket_column_batches: - for single_column_table in batch_columns: - writer = writers[single_column_table.column_names[0]] - writer.write_table(single_column_table) - finally: - for writer in writers.values(): - writer.close() - - for _, f in self._iter_column_files(wip_file): - apply_file_permissions(f, file_permissions) - - os.rename(wip_file, output_file_name) - - def _merge_data_files(self, aggregation_folder, merge_candidates): - output_file_name = self._resolve_merged_output_file_name(merge_candidates) - - file_permissions = self._cache_config.data_file_permissions - - wip_file = _create_wip_file(aggregation_folder, start_time=None, is_folder=False) - schema = _pa().parquet.read_schema(merge_candidates[0].file_path) - with _pa().parquet.ParquetWriter( - wip_file, - schema=schema, - compression=self._parquet_output_config.compression, - version=ParquetWriter.PARQUET_VERSION, - ) as parquet_writer: - for batch in self._iterate_merged_batches(merge_candidates): - parquet_writer.write_table(batch) - - apply_file_permissions(wip_file, file_permissions) - os.rename(wip_file, output_file_name) - - def _resolve_merge_candidates(self, existing_files): - if not existing_files or len(existing_files) <= 1: - return None - - merge_candidates = [] - - for (file_period_start, file_period_end), file_path in existing_files.items(): - if not merge_candidates: - merge_candidates.append( - _MergeFileInfo(file_path=file_path, start_time=file_period_start, end_time=file_period_end) - ) - continue - assert file_period_start >= merge_candidates[-1].start_time - if merge_candidates[-1].end_time + datetime.timedelta(microseconds=1) >= file_period_start: - merge_candidates.append( - _MergeFileInfo(file_path=file_path, start_time=file_period_start, end_time=file_period_end) - ) - elif len(merge_candidates) <= 1: - merge_candidates.clear() - merge_candidates.append( - _MergeFileInfo(file_path=file_path, start_time=file_period_start, end_time=file_period_end) - ) - else: - break - if len(merge_candidates) > 1: - return merge_candidates - return None - - def _merge_single_period(self, aggregation_folder, aggregation_period_start, aggregation_period_end): - lock_file_utils = ManagedDatasetLockUtil(self._cache_config.lock_file_permissions) - continue_merge = True - while continue_merge: - with lock_file_utils.merge_lock(aggregation_folder): - existing_files, _ = self._dataset_partition.data_paths.get_data_files_in_range( - aggregation_period_start, - aggregation_period_end, - missing_range_handler=lambda *args, **kwargs: True, - split_columns_to_files=self._split_columns_to_files, - truncate_data_periods=False, - include_read_folders=False, - ) - merge_candidates = self._resolve_merge_candidates(existing_files) - if not merge_candidates: - break - lock_file_paths = [r.file_path for r in merge_candidates] - locks = [lock_file_utils.write_lock(f, is_lock_in_root_folder=True) for f in lock_file_paths] - all_files_lock = MultipleFilesLock(locks) - if not all_files_lock.lock(): - break - - if self._dataset_partition.dataset.metadata.split_columns_to_files: - self._merged_data_folders(aggregation_folder, merge_candidates) - else: - self._merge_data_files(aggregation_folder, merge_candidates) - all_files_lock.unlock() - - def merge_files(self): - for ( - aggregation_period_start, - aggregation_period_end, 
- ) in self._aggregation_period_utils.iterate_periods_in_date_range( - start_time=self._start_time, end_time=self._end_time - ): - aggregation_period_end -= datetime.timedelta(microseconds=1) - aggregation_folder = self._dataset_partition.data_paths.get_output_folder_name(aggregation_period_start) - self._merge_single_period(aggregation_folder, aggregation_period_start, aggregation_period_end) diff --git a/csp/impl/managed_dataset/managed_dataset_path_resolver.py b/csp/impl/managed_dataset/managed_dataset_path_resolver.py deleted file mode 100644 index 7ded3713..00000000 --- a/csp/impl/managed_dataset/managed_dataset_path_resolver.py +++ /dev/null @@ -1,470 +0,0 @@ -import datetime -import glob -import os -from typing import Callable, Dict, List, Optional, Union - -import csp -from csp.impl.constants import UNSET -from csp.impl.managed_dataset.aggregation_period_utils import AggregationPeriodUtils -from csp.impl.managed_dataset.dataset_metadata import OutputType, TimeAggregation -from csp.impl.managed_dataset.dateset_name_constants import DatasetNameConstants - - -class DatasetPartitionPaths: - _FILE_EXTENSION_BY_TYPE = {OutputType.PARQUET: ".parquet"} - _FOLDER_DATA_GLOB_EXPRESSION = ( - "[0-9]" * 8 + "_" + "[0-9]" * 6 + "_" + "[0-9]" * 6 + "-" + "[0-9]" * 8 + "_" + "[0-9]" * 6 + "_" + "[0-9]" * 6 - ) - - DATA_FOLDER = "data" - - def __init__( - self, - dataset_root_folder: str, - dataset_read_folders, - partitioning_values: Dict[str, str] = None, - time_aggregation: TimeAggregation = TimeAggregation.DAY, - ): - self._partition_values = tuple(partitioning_values.values()) - self._time_aggregation = time_aggregation - if self._partition_values: - sub_folder_parts = list(map(str, self._partition_values)) - else: - sub_folder_parts = [] - - self._root_folder = os.path.join(dataset_root_folder, self.DATA_FOLDER, *sub_folder_parts) - self._read_folders = [os.path.join(v, self.DATA_FOLDER, *sub_folder_parts) for v in dataset_read_folders] - self._aggregation_period_utils = AggregationPeriodUtils(time_aggregation) - - @property - def root_folder(self): - return self._root_folder - - @classmethod - def _parse_file_name_times(cls, file_name): - base_name = os.path.basename(file_name) - start = datetime.datetime.strptime(base_name[:22], "%Y%m%d_%H%M%S_%f") - end = datetime.datetime.strptime(base_name[23:45], "%Y%m%d_%H%M%S_%f") - return (start, end) - - def get_period_start_time(self, start_time: datetime.datetime) -> datetime.datetime: - """Compute the start of the period for the given timestamp - :param start_time: - :return: - """ - return AggregationPeriodUtils(self._time_aggregation).resolve_period_start(start_time) - - def get_file_cutoff_time(self, start_time: datetime.datetime) -> datetime.datetime: - """Compute the latest time that should be written to the file for which the data start at a given time - :param start_time: - :return: - """ - return AggregationPeriodUtils(self._time_aggregation).resolve_period_end(start_time) - - def _get_existing_data_bound_for_root_folder(self, is_starttime, root_folder, split_columns_to_files): - agg_bound_folder = self._aggregation_period_utils.get_agg_bound_folder( - root_folder=root_folder, is_starttime=is_starttime - ) - if agg_bound_folder is None: - return None - if split_columns_to_files: - all_files = sorted(glob.glob(f"{glob.escape(agg_bound_folder)}/{self._FOLDER_DATA_GLOB_EXPRESSION}")) - else: - all_files = sorted(glob.glob(f"{glob.escape(agg_bound_folder)}/*.parquet")) - if not all_files: - return None - index = 0 if is_starttime else -1 - 
return self._parse_file_name_times(all_files[index])[index] - - def _iterate_root_and_read_folders(self, include_root_folder=True, include_read_folders=True): - if include_root_folder: - yield self._root_folder - if include_read_folders: - for f in self._read_folders: - yield f - - def _get_existing_data_bound_time( - self, is_starttime, *, split_columns_to_files: bool, include_root_folder=True, include_read_folders=True - ): - res = None - - for root_folder in self._iterate_root_and_read_folders( - include_root_folder=include_root_folder, include_read_folders=include_read_folders - ): - cur_res = self._get_existing_data_bound_for_root_folder(is_starttime, root_folder, split_columns_to_files) - if res is None or (cur_res is not None and ((cur_res < res) == is_starttime)): - res = cur_res - return res - - def _normalize_start_end_time( - self, - starttime: datetime.datetime, - endtime: Union[datetime.datetime, datetime.timedelta], - split_columns_to_files: bool, - ): - if starttime is None: - starttime = self._get_existing_data_bound_time(True, split_columns_to_files=split_columns_to_files) - if starttime is None: - return None, None - - if endtime is None: - endtime = self._get_existing_data_bound_time(False, split_columns_to_files=split_columns_to_files) - if endtime is None: - return None, None - elif isinstance(endtime, datetime.timedelta): - endtime = starttime + endtime - return starttime, endtime - - def _list_files_on_disk( - self, - starttime: datetime.datetime, - endtime: Union[datetime.datetime, datetime.timedelta], - split_columns_to_files=False, - return_unused=False, - include_read_folders=True, - ): - if starttime is None or endtime is None: - return [] - - files_with_times = [] - unused_files = [] - for period_start, _ in self._aggregation_period_utils.iterate_periods_in_date_range(starttime, endtime): - file_by_base_name = {} - for root_folder in self._iterate_root_and_read_folders(include_read_folders=include_read_folders): - date_output_folder = self.get_output_folder_name(period_start, root_folder) - if split_columns_to_files: - files = glob.glob(f"{glob.escape(date_output_folder)}/" + self._FOLDER_DATA_GLOB_EXPRESSION) - else: - files = glob.glob(f"{glob.escape(date_output_folder)}/*.parquet") - for f in files: - base_name = os.path.basename(f) - if base_name not in file_by_base_name: - file_by_base_name[base_name] = f - sorted_base_names = sorted(file_by_base_name) - files = [file_by_base_name[f] for f in sorted_base_names] - - for file in files: - file_start, file_end = self._parse_file_name_times(file) - # Files are sorted ascending by start_time, end_time. 
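For context, the deleted DatasetPartitionPaths encodes each data file's covered period directly in its name (see _parse_file_name_times above). A minimal standalone sketch of that round-trip, assuming the %Y%m%d_%H%M%S_%f start-end convention used here; the helper names are hypothetical:

import datetime
import os

_TIME_FMT = "%Y%m%d_%H%M%S_%f"  # 22 characters per timestamp, e.g. 20210101_000000_000000

def format_period_file_name(start: datetime.datetime, end: datetime.datetime, extension: str = ".parquet") -> str:
    # "start-end" pair, mirroring the file names built by get_output_file_name below
    return f"{start.strftime(_TIME_FMT)}-{end.strftime(_TIME_FMT)}{extension}"

def parse_period_file_name(path: str):
    # fixed-width slices around the "-" separator, as in _parse_file_name_times
    base = os.path.basename(path)
    return (
        datetime.datetime.strptime(base[:22], _TIME_FMT),
        datetime.datetime.strptime(base[23:45], _TIME_FMT),
    )

start = datetime.datetime(2021, 1, 1)
end = datetime.datetime(2021, 1, 1, 23, 59, 59, 999999)
assert parse_period_file_name(format_period_file_name(start, end)) == (start, end)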
For a given start time, we want to keep the highest end_time - new_record = (file_start, file_end, file) - if files_with_times and files_with_times[-1][0] == file_start: - unused_files.append(files_with_times[-1][-1]) - files_with_times[-1] = new_record - elif files_with_times and files_with_times[-1][1] >= file_end: - # The file is fully included in the previous file range - unused_files.append(file) - else: - files_with_times.append(new_record) - return unused_files if return_unused else files_with_times - - def get_unused_files( - self, - starttime: datetime.datetime, - endtime: Union[datetime.datetime, datetime.timedelta], - split_columns_to_files=False, - ): - starttime, endtime = self._normalize_start_end_time(starttime, endtime, split_columns_to_files) - return self._list_files_on_disk( - starttime=starttime, - endtime=endtime, - split_columns_to_files=split_columns_to_files, - return_unused=True, - include_read_folders=False, - ) - - def get_data_files_in_range( - self, - starttime: datetime.datetime, - endtime: Union[datetime.datetime, datetime.timedelta], - missing_range_handler: Callable[[datetime.datetime, datetime.datetime], bool] = None, - split_columns_to_files=False, - truncate_data_periods=True, - include_read_folders=True, - ): - """Retrieve a list of all files in the given time range (inclusive) - :param starttime: The start time of the period - :param endtime: The end time of the period - :param missing_range_handler: A function that handles missing data. Will be called with (missing_period_starttime, missing_period_endtime), - should return True, if the missing data is not an error, should return False otherwise (in which case an exception will be raised). - By default if no missing_range_handler is specified, the function will raise exception on any missing data. - :param split_columns_to_files: A boolean that specifies whether the columns are split into separate files - :param truncate_data_periods: A boolean that specifies whether the time period of each file should be truncated to the period that is consumed for a given - time range. For example consider a file that exists for period (20210101-20210201) and we pass in the starttime=20210115 and endtime=20120116 then - for the file above the period (key of the returned dict) will be truncated to (20210115,20120116) if the flag is set to false then - (20210101,20210201) will be returned as a key instead. 
- :param include_read_folders: A boolean that specifies whether the files in "read_folders" should be included - :returns A tuple (files, full_coverage) where data is a dictionary of period->file_path and full_coverage is a boolean that is True only - if the whole requested period is covered by the files, False otherwise - """ - starttime, endtime = self._normalize_start_end_time(starttime, endtime, split_columns_to_files) - # It's a boolean but since we need to modify it from within internal function, we need to make it a list of boolean - full_coverage = [True] - - def handle_missing_period_error_reporting(start, end): - if not missing_range_handler or not missing_range_handler(start, end): - raise RuntimeError(f"Missing cache data for range {start} to {end}") - full_coverage[0] = False - - res = {} - - files_with_times = self._list_files_on_disk( - starttime=starttime, - endtime=endtime, - split_columns_to_files=split_columns_to_files, - include_read_folders=include_read_folders, - ) - - if starttime: - for period_start, _ in self._aggregation_period_utils.iterate_periods_in_date_range(starttime, endtime): - prev_end = None - for file_start, file_end, file in files_with_times: - file_new_data_start = file_start - - if prev_end is not None and prev_end >= file_start: - if file_end <= prev_end: - # The period of this file is fully covered in the previous one - continue - if truncate_data_periods: - file_new_data_start = prev_end + datetime.timedelta(microseconds=1) - - if ( - (starttime <= file_new_data_start <= endtime) - or (starttime <= file_end <= endtime) - or (file_new_data_start <= starttime <= endtime <= file_end) - ): - if truncate_data_periods and starttime > file_new_data_start: - file_new_data_start = starttime - if file_end > endtime and truncate_data_periods: - file_end = endtime - res[(file_new_data_start, file_end)] = file - prev_end = file_end - - if not res: - if starttime is not None or endtime is not None: - handle_missing_period_error_reporting(starttime, endtime) - return {}, False - else: - ONE_MICRO = datetime.timedelta(microseconds=1) - - dict_iter = iter(res.keys()) - period_start, period_end = next(dict_iter) - if period_start > starttime: - handle_missing_period_error_reporting(starttime, period_start - ONE_MICRO) - - for cur_start, cur_end in dict_iter: - if cur_start > period_end + ONE_MICRO: - handle_missing_period_error_reporting(period_end + ONE_MICRO, cur_start - ONE_MICRO) - period_end = cur_end - if period_end < endtime: - handle_missing_period_error_reporting(period_end + ONE_MICRO, endtime) - - return res, full_coverage[0] - - def get_output_folder_name(self, start_time: Union[datetime.datetime, datetime.date], root_folder=None): - root_folder = root_folder or self._root_folder - return os.path.join(root_folder, self._aggregation_period_utils.get_sub_folder_name(start_time)) - - def get_output_file_name( - self, - start_time: datetime.datetime, - end_time: datetime.datetime, - output_type: OutputType = OutputType.PARQUET, - split_columns_to_files: bool = False, - ): - assert end_time >= start_time - if output_type not in (OutputType.PARQUET,): - raise NotImplementedError(f"Unsupported output type: {output_type}") - - output_folder = self.get_output_folder_name(start_time=start_time) - assert end_time <= self._aggregation_period_utils.resolve_period_end(start_time, exclusive_end=False) - if split_columns_to_files: - file_extension = "" - else: - file_extension = self._FILE_EXTENSION_BY_TYPE[output_type] - return os.path.join( - output_folder, - 
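The full_coverage bookkeeping above amounts to walking the sorted file periods and reporting any hole larger than one microsecond. A simplified sketch of that check, assuming inclusive (start, end) periods; the function name is hypothetical:

import datetime

ONE_MICRO = datetime.timedelta(microseconds=1)

def find_missing_ranges(requested_start, requested_end, covered_periods):
    # covered_periods: sorted, non-overlapping (start, end) tuples, both ends inclusive
    gaps = []
    cursor = requested_start
    for period_start, period_end in covered_periods:
        if period_start > cursor:
            gaps.append((cursor, period_start - ONE_MICRO))
        cursor = max(cursor, period_end + ONE_MICRO)
    if cursor <= requested_end:
        gaps.append((cursor, requested_end))
    return gaps

# a file covering only the first half hour leaves the second half uncovered
s = datetime.datetime(2021, 1, 15)
assert find_missing_ranges(s, s + datetime.timedelta(hours=1), [(s, s + datetime.timedelta(minutes=30))]) == [
    (s + datetime.timedelta(minutes=30, microseconds=1), s + datetime.timedelta(hours=1))
]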
f"{start_time.strftime('%Y%m%d_%H%M%S_%f')}-{end_time.strftime('%Y%m%d_%H%M%S_%f')}{file_extension}", - ) - - -class DatasetPartitionKey: - def __init__(self, value_dict): - self._value_dict = value_dict - self._key = None - - @property - def kwargs(self): - return self._value_dict - - def _get_key(self): - if self._key is None: - self._key = tuple(self._value_dict.items()) - return self._key - - def __str__(self): - return f"DatasetPartitionKey({self._value_dict})" - - def __repr__(self): - return str(self) - - def __eq__(self, other): - if not isinstance(other, DatasetPartitionKey): - return False - return self._get_key() == other._get_key() - - def __hash__(self): - return hash(self._get_key()) - - -class DatasetPaths(object): - DATASET_METADATA_FILE_NAME = "dataset_meta.yml" - - def __init__( - self, - parent_folder: str, - read_folders: str, - name: str, - time_aggregation=TimeAggregation.DAY, - data_category: Optional[List[str]] = None, - ): - self._name = name - self._time_aggregation = time_aggregation - self._data_category = data_category - - # Note we must call the list on data_category since we want a copy that we're going to modify - dataset_sub_folder_parts = list(data_category) if data_category else [] - dataset_sub_folder_parts.append(name) - self._dataset_sub_folder_parts_str = os.path.join(*dataset_sub_folder_parts) - self._parent_folder = parent_folder - self._dataset_root_folder = os.path.abspath(os.path.join(parent_folder, self._dataset_sub_folder_parts_str)) - self._dataset_read_root_folders = ( - [os.path.abspath(os.path.join(v, *dataset_sub_folder_parts)) for v in read_folders] if read_folders else [] - ) - - def get_partition_paths(self, partitioning_values: Dict[str, str] = None): - return DatasetPartitionPaths( - self.root_folder, - self._dataset_read_root_folders, - partitioning_values, - time_aggregation=self._time_aggregation, - ) - - @property - def parent_folder(self): - return self._parent_folder - - @property - def root_folder(self): - return self._dataset_root_folder - - @classmethod - def _get_metadata_file_path(cls, root_folder): - return os.path.join(root_folder, cls.DATASET_METADATA_FILE_NAME) - - def get_metadata_file_path(self, existing: bool): - """ - Get the metadata file path if "existing" is True then any metadata from either root folder or read folders will be returned (whichever exists) or None if not - metadata file exists. If "existing" is False then the metadata for the "root_folder" will be returned, no matter if it exists or not. 
- :param existing: - :return: - """ - if not existing: - return os.path.join(self.root_folder, self.DATASET_METADATA_FILE_NAME) - - for folder in self._iter_root_folders(True): - file_path = os.path.join(folder, self.DATASET_METADATA_FILE_NAME) - if os.path.exists(file_path): - return file_path - return None - - def _iter_root_folders(self, use_read_folders): - yield self._dataset_root_folder - if use_read_folders: - for f in self._dataset_read_root_folders: - yield f - - def _resolve_partitions_recursively(self, metadata, cur_path, columns, column_index=0): - if column_index >= len(columns): - yield {} - return - - col_name = columns[column_index] - col_type = metadata.partition_columns[col_name] - - for sub_folder in os.listdir(cur_path): - cur_value, sub_folder_full = self._load_value_from_path(cur_path, sub_folder, col_type) - if cur_value is not UNSET: - for res in self._resolve_partitions_recursively(metadata, sub_folder_full, columns, column_index + 1): - d = {col_name: cur_value} - d.update(**res) - yield d - - def _load_value_from_path(self, cur_path, sub_folder, col_type): - cur_value = UNSET - sub_folder_full = os.path.join(cur_path, sub_folder) - if issubclass(col_type, csp.Struct): - if os.path.isdir(sub_folder_full) and sub_folder.startswith("struct_"): - value_file = os.path.join(sub_folder_full, DatasetNameConstants.PARTITION_ARGUMENT_FILE_NAME) - if os.path.exists(os.path.exists(value_file)): - with open(value_file, "r") as f: - cur_value = col_type.from_yaml(f.read()) - elif col_type in (int, float, str): - try: - cur_value = col_type(sub_folder) - except ValueError: - pass - elif col_type is datetime.date: - try: - cur_value = datetime.datetime.strptime(sub_folder, "%Y%m%d_000000_000000").date() - except ValueError: - pass - elif col_type is datetime.datetime: - try: - cur_value = datetime.datetime.strptime(sub_folder, "%Y%m%d_%H%M%S_%f") - except ValueError: - pass - elif col_type is datetime.timedelta: - try: - if sub_folder.startswith("td_") and sub_folder.endswith("us"): - cur_value = datetime.timedelta(microseconds=int(sub_folder[3:-2])) - except ValueError: - pass - elif col_type is bool: - if sub_folder == "True": - cur_value = True - elif sub_folder == "False": - cur_value = False - else: - raise RuntimeError(f"Unsupported partition value type {col_type}: {sub_folder}") - return cur_value, sub_folder_full - - def get_partition_keys(self, metadata): - if not hasattr(metadata, "partition_columns") or not metadata.partition_columns: - return [DatasetPartitionKey({})] - - results_set = set() - results = [] - - columns = list(metadata.partition_columns) - for root_folder in self._iter_root_folders(True): - data_folder = os.path.join(root_folder, DatasetPartitionPaths.DATA_FOLDER) - for res in self._resolve_partitions_recursively(metadata, data_folder, columns=columns): - key = DatasetPartitionKey(res) - if key not in results_set: - results_set.add(key) - results.append(key) - return results - - def resolve_lock_file_path(self, desired_path, use_read_folders): - """ - :param desired_path: The desired path of the lock as if it was in the data folder (this path is modified to a separate path) - :param use_read_folders: A boolean flags whether the read folders should be tried as the prefix for the current desired path - :return: A tuple of (parent_folder, file_path) where parent_folder is the LAST non lock specific folder in the path (anything after this is lock specific and should - be created with different permissions) - """ - for f in 
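The folder-name conventions consumed by _load_value_from_path above can be summarized in a small standalone decoder. This is a sketch only; the struct/yaml branch is omitted:

import datetime

def decode_partition_folder(folder_name: str, col_type: type):
    # scalar types are stored as their plain string form
    if col_type in (int, float, str):
        return col_type(folder_name)
    # dates reuse the timestamp layout with a zeroed time component
    if col_type is datetime.date:
        return datetime.datetime.strptime(folder_name, "%Y%m%d_000000_000000").date()
    if col_type is datetime.datetime:
        return datetime.datetime.strptime(folder_name, "%Y%m%d_%H%M%S_%f")
    # timedeltas are written as "td_<microseconds>us"
    if col_type is datetime.timedelta:
        if folder_name.startswith("td_") and folder_name.endswith("us"):
            return datetime.timedelta(microseconds=int(folder_name[3:-2]))
        raise ValueError(folder_name)
    if col_type is bool:
        return folder_name == "True"
    raise TypeError(f"Unsupported partition value type {col_type}")

assert decode_partition_folder("td_1500000us", datetime.timedelta) == datetime.timedelta(seconds=1.5)
assert decode_partition_folder("20210115_000000_000000", datetime.date) == datetime.date(2021, 1, 15)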
self._iter_root_folders(use_read_folders=use_read_folders): - if os.path.commonprefix((desired_path, f)) == f: - parent_folder = f[: -len(self._dataset_sub_folder_parts_str)] - rel_path = os.path.relpath(desired_path, parent_folder) - return parent_folder, os.path.join(parent_folder, ".locks", rel_path) - raise RuntimeError(f"Unable to resolve lock file path for file {desired_path}") diff --git a/csp/impl/managed_dataset/managed_parquet_writer.py b/csp/impl/managed_dataset/managed_parquet_writer.py deleted file mode 100644 index c1b8ba41..00000000 --- a/csp/impl/managed_dataset/managed_parquet_writer.py +++ /dev/null @@ -1,340 +0,0 @@ -import datetime -import os -from typing import Dict, Optional, TypeVar, Union - -import csp -from csp.adapters.parquet import ParquetOutputConfig, ParquetWriter -from csp.impl.managed_dataset.cache_user_custom_object_serializer import CacheObjectSerializer -from csp.impl.managed_dataset.dateset_name_constants import DatasetNameConstants -from csp.impl.managed_dataset.managed_dataset import ManagedDatasetPartition -from csp.impl.managed_dataset.managed_dataset_merge_utils import _create_wip_file -from csp.impl.wiring import Context -from csp.impl.wiring.cache_support.partition_files_container import PartitionFileContainer -from csp.impl.wiring.outputs import OutputsContainer -from csp.impl.wiring.special_output_names import ALL_SPECIAL_OUTPUT_NAMES, CSP_CACHE_ENABLED_OUTPUT - -T = TypeVar("T") - - -def _pa(): - """ - Lazy import pyarrow - """ - import pyarrow - - return pyarrow - - -def _create_output_file_or_folder(data_paths, cur_file_start_time, split_columns_to_files): - output_folder = data_paths.get_output_folder_name(start_time=cur_file_start_time) - if not os.path.exists(output_folder): - os.makedirs(output_folder, exist_ok=True) - s_cur_file_path = _create_wip_file(output_folder, cur_file_start_time, is_folder=split_columns_to_files) - return s_cur_file_path - - -def _generate_empty_parquet_files(dataset_partition, existing_file, files_to_generate, parquet_output_config): - if not files_to_generate: - return - - if os.path.isdir(existing_file): - file_schemas = { - f: _pa().parquet.ParquetFile(os.path.join(existing_file, f)).schema.to_arrow_schema() - for f in os.listdir(existing_file) - if f.endswith(".parquet") - } - for (s, e), dir_name in files_to_generate.items(): - for f_name, schema in file_schemas.items(): - with _pa().parquet.ParquetWriter( - os.path.join(dir_name, f_name), - schema=schema, - compression=parquet_output_config.compression, - version=ParquetWriter.PARQUET_VERSION, - ): - pass - PartitionFileContainer.get_instance().add_generated_file( - dataset_partition, s, e, dir_name, parquet_output_config - ) - else: - file_info = _pa().parquet.ParquetFile(existing_file) - schema = file_info.schema.to_arrow_schema() - - for (s, e), f_name in files_to_generate.items(): - with _pa().parquet.ParquetWriter( - f_name, - schema=schema, - compression=parquet_output_config.compression, - version=ParquetWriter.PARQUET_VERSION, - ): - PartitionFileContainer.get_instance().add_generated_file( - dataset_partition, s, e, f_name, parquet_output_config - ) - - -@csp.node -def _cache_filename_provider_custom_time( - dataset_partition: ManagedDatasetPartition, - config: Optional[ParquetOutputConfig], - split_columns_to_files: Optional[bool], - timestamp_ts: csp.ts[datetime.datetime], -) -> csp.ts[str]: - with csp.state(): - s_data_paths = dataset_partition.data_paths - s_last_start_time = None - s_cur_file_path = None - s_cur_file_cutoff_time = None - 
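_generate_empty_parquet_files above relies on the fact that opening and immediately closing a pyarrow ParquetWriter yields a valid zero-row file carrying the desired schema. A minimal sketch; the column names and compression choice here are arbitrary:

import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([("csp_timestamp", pa.timestamp("us")), ("price", pa.float64())])

# no write_table() call, so the file ends up with the schema but no row groups
with pq.ParquetWriter("placeholder.parquet", schema=schema, compression="snappy"):
    pass

assert pq.read_table("placeholder.parquet").num_rows == 0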
s_empty_files_to_generate = {} - s_last_closed_file = None - - with csp.start(): - config = config.copy() if config is not None else ParquetOutputConfig() - config.resolve_compression() - - with csp.stop(): - # We need to chack that s_cur_file_path since if the engine had startup error, s_cur_file_path is undefined - if "s_cur_file_path" in locals() and s_cur_file_path: - PartitionFileContainer.get_instance().add_generated_file( - dataset_partition, s_last_start_time, timestamp_ts, s_cur_file_path, config - ) - _generate_empty_parquet_files(dataset_partition, s_last_closed_file, s_empty_files_to_generate, config) - - if csp.ticked(timestamp_ts): - if s_cur_file_cutoff_time is None: - s_last_start_time = timestamp_ts - s_cur_file_cutoff_time = s_data_paths.get_file_cutoff_time(s_last_start_time) - s_cur_file_path = _create_output_file_or_folder(s_data_paths, s_last_start_time, split_columns_to_files) - return s_cur_file_path - elif timestamp_ts >= s_cur_file_cutoff_time: - PartitionFileContainer.get_instance().add_generated_file( - dataset_partition, - s_last_start_time, - s_cur_file_cutoff_time - datetime.timedelta(microseconds=1), - s_cur_file_path, - config, - ) - s_last_closed_file = s_cur_file_path - s_last_start_time = s_cur_file_cutoff_time - s_cur_file_cutoff_time = s_data_paths.get_file_cutoff_time(s_last_start_time) - # There might be some empty files in the middle, we need to take care of this by creating a bunch of empty files on the way - while s_cur_file_cutoff_time <= timestamp_ts: - s_cur_file_path = _create_output_file_or_folder(s_data_paths, s_last_start_time, split_columns_to_files) - s_empty_files_to_generate[ - (s_last_start_time, s_cur_file_cutoff_time - datetime.timedelta(microseconds=1)) - ] = s_cur_file_path - s_last_start_time = s_cur_file_cutoff_time - s_cur_file_cutoff_time = s_data_paths.get_file_cutoff_time(s_last_start_time) - - s_cur_file_path = _create_output_file_or_folder(s_data_paths, s_last_start_time, split_columns_to_files) - return s_cur_file_path - - -def _finalize_current_output_file( - data_paths, config, dataset_partition, now, cur_file_path, split_columns_to_files, last_start_time, cache_enabled -): - if cur_file_path: - PartitionFileContainer.get_instance().add_generated_file( - dataset_partition, last_start_time, now - datetime.timedelta(microseconds=1), cur_file_path, config - ) - - if cache_enabled: - output_folder = data_paths.get_output_folder_name(start_time=now) - if not os.path.exists(output_folder): - os.makedirs(output_folder, exist_ok=True) - return _create_wip_file(output_folder, now, is_folder=split_columns_to_files) - else: - return "" - - -@csp.node -def _cache_filename_provider( - dataset_partition: ManagedDatasetPartition, - config: Optional[ParquetOutputConfig], - split_columns_to_files: Optional[bool], - cache_control_ts: csp.ts[bool], - default_cache_enabled: bool, -) -> csp.ts[str]: - with csp.alarms(): - a_update_file_alarm = csp.alarm(bool) - - with csp.state(): - s_data_paths = dataset_partition.data_paths - s_last_start_time = None - s_cur_file_path = None - s_cache_enabled = default_cache_enabled - - with csp.start(): - config = config if config is not None else ParquetOutputConfig() - csp.schedule_alarm(a_update_file_alarm, datetime.timedelta(), False) - - with csp.stop(): - # We need to chack that s_cur_file_path since if the engine had startup error, s_cur_file_path is undefined - if "s_cur_file_path" in locals() and s_cur_file_path: - if s_cache_enabled: - PartitionFileContainer.get_instance().add_generated_file( - 
dataset_partition, s_last_start_time, csp.now(), s_cur_file_path, config - ) - - if csp.ticked(cache_control_ts): - if cache_control_ts: - # We didn't write and need to start writing - if not s_cache_enabled: - s_cache_enabled = True - s_cur_file_path = _finalize_current_output_file( - s_data_paths, - config, - dataset_partition, - csp.now(), - s_cur_file_path, - split_columns_to_files, - s_last_start_time, - s_cache_enabled, - ) - s_last_start_time = csp.now() - cutoff_time = s_data_paths.get_file_cutoff_time(s_last_start_time) - csp.schedule_alarm(a_update_file_alarm, cutoff_time, False) - return s_cur_file_path - else: - # It's a bit ugly for now, we will keep writing even when cache is disabled but then we will throw away the written data. - # we need a better way to address this in the future - if s_cache_enabled: - s_cache_enabled = False - s_cur_file_path = _finalize_current_output_file( - s_data_paths, - config, - dataset_partition, - csp.now(), - s_cur_file_path, - split_columns_to_files, - s_last_start_time, - s_cache_enabled, - ) - s_last_start_time = csp.now() - cutoff_time = s_data_paths.get_file_cutoff_time(s_last_start_time) - csp.schedule_alarm(a_update_file_alarm, cutoff_time, False) - return s_cur_file_path - - if csp.ticked(a_update_file_alarm) and s_last_start_time != csp.now(): - s_cur_file_path = _finalize_current_output_file( - s_data_paths, - config, - dataset_partition, - csp.now(), - s_cur_file_path, - split_columns_to_files, - s_last_start_time, - s_cache_enabled, - ) - s_last_start_time = csp.now() - cutoff_time = s_data_paths.get_file_cutoff_time(s_last_start_time) - csp.schedule_alarm(a_update_file_alarm, cutoff_time, False) - return s_cur_file_path - - -@csp.node -def _serialize_value(value: csp.ts["T"], type_serializer: CacheObjectSerializer) -> csp.ts[bytes]: - if csp.ticked(value): - csp.output(type_serializer.serialize_to_bytes(value)) - - -def create_managed_parquet_writer_node( - function_name: str, - dataset_partition: ManagedDatasetPartition, - values: OutputsContainer, - field_mapping: Dict[str, Union[str, Dict[str, str]]], - config: Optional[ParquetOutputConfig] = None, - data_timestamp_column_name=None, - controlled_cache: bool = False, - default_cache_enabled: bool = True, -): - metadata = dataset_partition.dataset.metadata - if data_timestamp_column_name is None: - timestamp_column_name = getattr(metadata, "timestamp_column_name", None) - else: - timestamp_column_name = data_timestamp_column_name - config = config.copy() if config else ParquetOutputConfig() - config.allow_overwrite = True - cache_serializers = Context.instance().config.cache_config.cache_serializers - - split_columns_to_files = metadata.split_columns_to_files - - if controlled_cache: - cache_control_ts = values[CSP_CACHE_ENABLED_OUTPUT] - else: - cache_control_ts = csp.const(True) - default_cache_enabled = True - - if not isinstance(values, OutputsContainer): - values = OutputsContainer(**{DatasetNameConstants.UNNAMED_OUTPUT_NAME: values}) - - if data_timestamp_column_name and data_timestamp_column_name != DatasetNameConstants.CSP_TIMESTAMP: - timestamp_ts = values - for k in data_timestamp_column_name.split("."): - timestamp_ts = getattr(timestamp_ts, k) - writer = ParquetWriter( - file_name=None, - timestamp_column_name=None, - config=config, - filename_provider=_cache_filename_provider_custom_time( - dataset_partition=dataset_partition, - config=config, - split_columns_to_files=split_columns_to_files, - timestamp_ts=timestamp_ts, - ), - 
split_columns_to_files=split_columns_to_files, - ) - else: - writer = ParquetWriter( - file_name=None, - timestamp_column_name=timestamp_column_name, - config=config, - filename_provider=_cache_filename_provider( - dataset_partition=dataset_partition, - config=config, - split_columns_to_files=split_columns_to_files, - cache_control_ts=cache_control_ts, - default_cache_enabled=default_cache_enabled, - ), - split_columns_to_files=split_columns_to_files, - ) - - all_columns = set() - for key, value in values._items(): - if key in ALL_SPECIAL_OUTPUT_NAMES: - continue - if isinstance(value, dict): - basket_metadata = metadata.dict_basket_columns[key] - writer.publish_dict_basket( - key, value, key_type=basket_metadata.key_type, value_type=basket_metadata.value_type - ) - elif isinstance(value.tstype.typ, type) and issubclass(value.tstype.typ, csp.Struct): - s_field_map = field_mapping.get(key) - for k, v in s_field_map.items(): - try: - if v in all_columns: - raise RuntimeError(f"Found multiple writers of column {v}") - except TypeError: - raise RuntimeError(f"Invalid cache field name mapping: {v}") - all_columns.add(v) - - writer.publish_struct(value, field_map=field_mapping.get(key)) - else: - col_name = field_mapping.get(key, key) - try: - if col_name in all_columns: - raise RuntimeError(f"Found multiple writers of column {col_name}") - except TypeError: - raise RuntimeError(f"Invalid cache field name mapping: {col_name}") - all_columns.add(col_name) - type_serializer = cache_serializers.get(value.tstype.typ) - if type_serializer: - writer.publish(col_name, _serialize_value(value, type_serializer)) - else: - writer.publish(col_name, value) - if ( - data_timestamp_column_name - and data_timestamp_column_name not in all_columns - and data_timestamp_column_name != DatasetNameConstants.CSP_TIMESTAMP - ): - raise RuntimeError( - f"{data_timestamp_column_name} specified as timestamp column but no writers for this column found" - ) diff --git a/csp/impl/mem_cache.py b/csp/impl/mem_cache.py index f4886465..784b3de2 100644 --- a/csp/impl/mem_cache.py +++ b/csp/impl/mem_cache.py @@ -1,9 +1,9 @@ import copy import inspect -import logging import threading from collections import namedtuple from functools import wraps +from warnings import warn from csp.impl.constants import UNSET @@ -149,11 +149,6 @@ def _preprocess_args(args): yield (arg_name, normalize_arg(arg_value)) -class _WarnedFlag: - def __init__(self): - self.value = False - - def function_full_name(f): """A utility function that can be used for implementation of function_name for csp_memoized_graph_object :param f: @@ -177,7 +172,6 @@ def csp_memoized(func=None, *, force_memoize=False, function_name=None, is_user_ :param is_user_data: A flag that specifies whether the memoized object is user object or graph object :return: """ - warned_flag = _WarnedFlag() def _impl(func): func_args = _resolve_func_args(func) @@ -204,10 +198,8 @@ def __call__(*args, **kwargs): except TypeError as e: if force_memoize: raise - if not warned_flag.value: - logging_context = function_name if function_name else str(func) - logging.debug(f"Not memoizing output of {str(logging_context)}: {str(e)}") - warned_flag.value = True + logging_context = function_name if function_name else str(func) + warn(f"Not memoizing output of {str(logging_context)}: {str(e)}", Warning) cur_item = func(*args, **kwargs) else: if cur_item is UNSET: diff --git a/csp/impl/types/instantiation_type_resolver.py b/csp/impl/types/instantiation_type_resolver.py index 22e5e66c..baafef2e 100644 --- 
a/csp/impl/types/instantiation_type_resolver.py +++ b/csp/impl/types/instantiation_type_resolver.py @@ -21,7 +21,7 @@ def __init__(self): self._type_registry: typing.Dict[typing.Tuple[type, type], type] = {} self._add_type_upcast(int, float, float) - def resolve_type(self, expected_type: type, new_type: type, allow_subtypes: bool, raise_on_error=True): + def resolve_type(self, expected_type: type, new_type: type, raise_on_error=True): if expected_type == new_type: return expected_type if expected_type is object or new_type is object: @@ -57,7 +57,7 @@ def resolve_type(self, expected_type: type, new_type: type, allow_subtypes: bool else: return None - if allow_subtypes and inspect.isclass(expected_type) and inspect.isclass(new_type): + if inspect.isclass(expected_type) and inspect.isclass(new_type): if issubclass(expected_type, new_type): # Generally if B inherits from A, we want to resolve from A, the only exception # is "Generic types". Dict[int, int] inherits from dict but we want the type to be resolved to the generic type @@ -203,7 +203,6 @@ def __init__( values: typing.List[object], forced_tvars: typing.Union[typing.Dict[str, typing.Type], None], is_input=True, - allow_subtypes=True, allow_none_ts=False, ): self._function_name = function_name @@ -211,7 +210,6 @@ def __init__( self._arguments = values self._forced_tvars = forced_tvars self._def_name = "inputdef" if is_input else "outputdef" - self._allow_subtypes = allow_subtypes self._allow_none_ts = allow_none_ts self._tvars: typing.Dict[str, type] = {} @@ -318,9 +316,7 @@ def _rec_validate_type_spec_vs_type_spec_and_resolve_tvars( return False else: # At this point it must be a scalar value - res_type = UpcastRegistry.instance().resolve_type( - expected_sub_type, actual_sub_type, allow_subtypes=self._allow_subtypes, raise_on_error=False - ) + res_type = UpcastRegistry.instance().resolve_type(expected_sub_type, actual_sub_type, raise_on_error=False) return res_type is expected_sub_type return True @@ -391,12 +387,7 @@ def _add_scalar_value(self, arg, in_out_def): def _is_scalar_value_matching_spec(self, inp_def_type, arg): if inp_def_type is typing.Any: return True - if ( - UpcastRegistry.instance().resolve_type( - inp_def_type, type(arg), allow_subtypes=self._allow_subtypes, raise_on_error=False - ) - is inp_def_type - ): + if UpcastRegistry.instance().resolve_type(inp_def_type, type(arg), raise_on_error=False) is inp_def_type: return True if CspTypingUtils.is_union_type(inp_def_type): types = inp_def_type.__args__ @@ -533,9 +524,7 @@ def _add_t_var_resolution(self, tvar, resolved_type, arg=None): self._raise_arg_mismatch_error(arg=self._cur_arg, tvar_info={tvar: old_tvar_type}) return - combined_type = UpcastRegistry.instance().resolve_type( - resolved_type, old_tvar_type, allow_subtypes=self._allow_subtypes, raise_on_error=False - ) + combined_type = UpcastRegistry.instance().resolve_type(resolved_type, old_tvar_type, raise_on_error=False) if combined_type is None: conflicting_tvar_types = self._conflicting_tvar_types.get(tvar) if conflicting_tvar_types is None: @@ -605,9 +594,7 @@ def _try_resolve_tvar_conflicts(self): assert resolved_type, f'"{tvar}" was not resolved' for conflicting_type in conflicting_types: if ( - UpcastRegistry.instance().resolve_type( - resolved_type, conflicting_type, allow_subtypes=self._allow_subtypes, raise_on_error=False - ) + UpcastRegistry.instance().resolve_type(resolved_type, conflicting_type, raise_on_error=False) is not resolved_type ): raise TypeError( @@ -627,7 +614,6 @@ def __init__( 
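The UpcastRegistry calls simplified above (allow_subtypes is dropped, so the subclass path is always taken) resolve two scalar types to a common one. A rough standalone sketch of the idea, not the actual csp implementation, with int/float as the only registered upcast:

_UPCASTS = {(int, float): float, (float, int): float}

def resolve_common_type(expected: type, actual: type):
    if expected is actual:
        return expected
    upcast = _UPCASTS.get((expected, actual))
    if upcast is not None:
        return upcast
    # as the comment above notes: if B inherits from A, resolve to the base class A
    if issubclass(expected, actual):
        return actual
    if issubclass(actual, expected):
        return expected
    return None

assert resolve_common_type(int, float) is float
assert resolve_common_type(bool, int) is int   # bool resolves to its base class
assert resolve_common_type(str, float) is None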
input_definitions: typing.Tuple[InputDef], arguments: typing.List[object], forced_tvars: typing.Union[typing.Dict[str, typing.Type], None], - allow_subtypes: bool = True, allow_none_ts: bool = False, ): self._scalar_inputs: typing.List[object] = [] @@ -637,7 +623,6 @@ def __init__( input_or_output_definitions=input_definitions, values=arguments, forced_tvars=forced_tvars, - allow_subtypes=allow_subtypes, allow_none_ts=allow_none_ts, ) @@ -692,14 +677,12 @@ def __init__( output_definitions: typing.Tuple[OutputDef], values: typing.List[object], forced_tvars: typing.Union[typing.Dict[str, typing.Type], None], - allow_subtypes=True, ): super().__init__( function_name=function_name, input_or_output_definitions=output_definitions, values=values, forced_tvars=forced_tvars, - allow_subtypes=allow_subtypes, allow_none_ts=False, ) diff --git a/csp/impl/wiring/base_parser.py b/csp/impl/wiring/base_parser.py index 760d803d..5b90fc6d 100644 --- a/csp/impl/wiring/base_parser.py +++ b/csp/impl/wiring/base_parser.py @@ -20,11 +20,9 @@ OutputTypeError, ) from csp.impl.types.container_type_normalizer import ContainerTypeNormalizer -from csp.impl.types.tstype import TsType from csp.impl.types.type_annotation_normalizer_transformer import TypeAnnotationNormalizerTransformer from csp.impl.types.typing_utils import CspTypingUtils from csp.impl.warnings import WARN_PYTHONIC -from csp.impl.wiring.special_output_names import CSP_CACHE_ENABLED_OUTPUT LEGACY_METHODS = {"__alarms__", "__state__", "__start__", "__stop__", "__outputs__", "__return__"} @@ -92,7 +90,7 @@ def wrapper(*args, **kwargs): class BaseParser(ast.NodeTransformer, metaclass=ABCMeta): _DEBUG_PARSE = False - def __init__(self, name, raw_func, func_frame, debug_print=False, add_cache_control_output=False): + def __init__(self, name, raw_func, func_frame, debug_print=False): self._name = name self._outputs = [] self._special_outputs = tuple() @@ -115,7 +113,6 @@ def __init__(self, name, raw_func, func_frame, debug_print=False, add_cache_cont body = ast.parse(source) self._funcdef = body.body[0] self._type_annotation_normalizer.normalize_type_annotations(self._funcdef) - self._add_cache_control_output = add_cache_control_output def _eval_expr(self, exp): return eval( @@ -548,15 +545,3 @@ def _postprocess_basket_outputs(self, main_func_signature, enforce_shape_for_bas output.typ.shape_func = self._compile_function(shape_func) else: output.typ.shape_func = lambda *args, s=output.typ.shape: s - - def _resolve_special_outputs(self): - if self._add_cache_control_output: - self._special_outputs += ( - OutputDef( - name=CSP_CACHE_ENABLED_OUTPUT, - typ=TsType[bool], - kind=ArgKind.TS, - ts_idx=self._outputs[-1].ts_idx + 1, - shape=None, - ), - ) diff --git a/csp/impl/wiring/cache_support/__init__.py b/csp/impl/wiring/cache_support/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/csp/impl/wiring/cache_support/cache_config_resolver.py b/csp/impl/wiring/cache_support/cache_config_resolver.py deleted file mode 100644 index a1b51233..00000000 --- a/csp/impl/wiring/cache_support/cache_config_resolver.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import List - -from csp.impl.config import CacheConfig - - -class CacheConfigResolver: - def __init__(self, cache_config: CacheConfig): - from csp.impl.wiring.cache_support.graph_building import CacheCategoryOverridesTree - - self._cache_config = cache_config - if cache_config: - self._cache_category_overrides = CacheCategoryOverridesTree.construct_from_cache_config(cache_config) - 
self._graph_overrides = getattr(cache_config, "graph_overrides", {}) - else: - self._cache_category_overrides = None - self._graph_overrides = None - - def resolve_cache_config(self, graph: object, category: List[str]): - resolved_config = self._graph_overrides.get(graph, None) - if resolved_config is None: - resolved_config = self._cache_category_overrides.resolve_root_folder(category) - return resolved_config diff --git a/csp/impl/wiring/cache_support/cache_type_mapper.py b/csp/impl/wiring/cache_support/cache_type_mapper.py deleted file mode 100644 index f2a69702..00000000 --- a/csp/impl/wiring/cache_support/cache_type_mapper.py +++ /dev/null @@ -1,55 +0,0 @@ -import datetime -from typing import Union - -import csp.typing -from csp.impl.types.typing_utils import CspTypingUtils -from csp.utils.qualified_name_utils import QualifiedNameUtils - - -class CacheTypeMapper: - STRING_TO_TYPE_MAPPING = { - "datetime": datetime.datetime, - "date": datetime.date, - "timedelta": datetime.timedelta, - "int": int, - "float": float, - "str": str, - "bool": bool, - } - TYPE_TO_STRING_MAPPING = {v: k for k, v in STRING_TO_TYPE_MAPPING.items()} - ARRAY_TYPE_NAME_TO_TYPE = { - "ARRAY": csp.typing.Numpy1DArray, - "MULTI_DIM_ARRAY": csp.typing.NumpyNDArray, - } - ARRAY_TYPE_TO_TYPE_NAME = {v: k for k, v in ARRAY_TYPE_NAME_TO_TYPE.items()} - - @classmethod - def json_to_type(cls, typ: Union[str, dict]): - if isinstance(typ, str): - python_type = cls.STRING_TO_TYPE_MAPPING.get(typ) - if python_type is None: - python_type = QualifiedNameUtils.get_object_from_qualified_name(typ) - if python_type is None: - raise TypeError(f"Unsupported arrow serialization type {typ}") - return python_type - else: - array_type = None - if isinstance(typ, dict) and len(typ) == 1: - typ_key, typ_value = next(iter(typ.items())) - array_type = cls.ARRAY_TYPE_NAME_TO_TYPE.get(typ_key) - if array_type is None: - raise TypeError(f"Trying to deserialize invalid type: {typ}") - return array_type[cls.json_to_type(typ_value)] - - @classmethod - def type_to_json(cls, typ): - str_type = cls.TYPE_TO_STRING_MAPPING.get(typ) - if str_type is None: - if CspTypingUtils.is_generic_container(typ): - origin = CspTypingUtils.get_origin(typ) - type_name = cls.ARRAY_TYPE_TO_TYPE_NAME.get(origin) - if type_name is not None: - return {type_name: cls.type_to_json(typ.__args__[0])} - - return QualifiedNameUtils.get_qualified_object_name(typ) - return str_type diff --git a/csp/impl/wiring/cache_support/dataset_partition_cached_data.py b/csp/impl/wiring/cache_support/dataset_partition_cached_data.py deleted file mode 100644 index 4b227aad..00000000 --- a/csp/impl/wiring/cache_support/dataset_partition_cached_data.py +++ /dev/null @@ -1,662 +0,0 @@ -import datetime -import itertools -import logging -import numpy -import os -import pytz -import shutil -from concurrent.futures.thread import ThreadPoolExecutor -from typing import Callable, Dict, List, Optional - -import csp -from csp.adapters.output_adapters.parquet import resolve_array_shape_column_name -from csp.impl.managed_dataset.aggregation_period_utils import AggregationPeriodUtils -from csp.impl.managed_dataset.dateset_name_constants import DatasetNameConstants -from csp.impl.types.typing_utils import CspTypingUtils - - -class DataSetCachedData: - def __init__(self, dataset, cache_serializers, data_set_partition_calculator_func): - self._dataset = dataset - self._cache_serializers = cache_serializers - self._data_set_partition_calculator_func = data_set_partition_calculator_func - - def 
get_partition_keys(self): - return self._dataset.data_paths.get_partition_keys(self._dataset.metadata) - - def __call__(self, *args, **kwargs): - return DatasetPartitionCachedData( - self._data_set_partition_calculator_func(*args, **kwargs), self._cache_serializers - ) - - -class DatasetPartitionCachedData: - def __init__(self, dataset_partition, cache_serializers): - self._dataset_partition = dataset_partition - self._cache_serializers = cache_serializers - - @property - def metadata(self): - return self._dataset_partition.dataset.metadata - - @classmethod - def _normalize_time(cls, time: datetime.datetime, drop_tz_info=False): - res = None - if time is not None: - if isinstance(time, datetime.timedelta): - return time - if time.tzinfo is None: - res = pytz.utc.localize(time) - else: - res = time.astimezone(pytz.UTC) - if res is not None and drop_tz_info: - res = res.replace(tzinfo=None) - return res - - def _get_shape_columns(self, column_list): - for c in column_list: - c_type = self.metadata.columns.get(c) - if c_type and CspTypingUtils.is_numpy_nd_array_type(c_type): - yield c, resolve_array_shape_column_name(c) - - def _get_shape_columns_dict(self, column_list): - return dict(self._get_shape_columns(column_list)) - - def _get_array_columns(self, column_list): - for c in column_list: - c_type = self.metadata.columns.get(c) - if c_type and CspTypingUtils.is_numpy_array_type(c_type): - yield c - - def _get_array_columns_set(self, column_list): - return set(self._get_array_columns(column_list)) - - def get_data_files_for_period( - self, - starttime: Optional[datetime.datetime] = None, - endtime: Optional[datetime.datetime] = None, - missing_range_handler: Callable[[datetime.datetime, datetime.datetime], bool] = None, - ): - """Retrieve a list of all files in the given time range (inclusive) - :param starttime: The start time of the period - :param endtime: The end time of the period - :param missing_range_handler: A function that handles missing data. Will be called with (missing_period_starttime, missing_period_endtime), - should return True, if the missing data is not an error, should return False otherwise (in which case an exception will be raised) - """ - return self._dataset_partition.get_data_for_period( - self._normalize_time(starttime, True), self._normalize_time(endtime, True), missing_range_handler - )[0] - - def _truncate_df(self, starttime, endtime, df): - import pandas - - if starttime is not None: - starttime = pytz.UTC.localize(starttime) - if endtime is not None: - endtime = pytz.UTC.localize(endtime) - - timestamp_column_name = self._remove_unnamed_output_prefix(self.metadata.timestamp_column_name) - - if starttime is not None or endtime is not None: - mask = pandas.Series(True, df.index) - if starttime is not None: - mask &= df[timestamp_column_name] >= starttime - if endtime is not None: - mask &= df[timestamp_column_name] <= endtime - df = df[mask].reset_index(drop=True) - return df - - @classmethod - def _remove_unnamed_output_prefix(cls, value): - unnamed_prefix = f"{DatasetNameConstants.UNNAMED_OUTPUT_NAME}." 
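_truncate_df above is a standard tz-aware window filter over the timestamp column. A self-contained sketch of the same pattern; the column name and bounds are hypothetical:

import datetime
import pandas as pd
import pytz

def truncate_by_time(df: pd.DataFrame, ts_col: str, start=None, end=None) -> pd.DataFrame:
    # bounds arrive tz-naive (engine time) and are localized to UTC before comparing
    mask = pd.Series(True, index=df.index)
    if start is not None:
        mask &= df[ts_col] >= pytz.UTC.localize(start)
    if end is not None:
        mask &= df[ts_col] <= pytz.UTC.localize(end)
    return df[mask].reset_index(drop=True)

df = pd.DataFrame({"ts": pd.to_datetime(["2021-01-01", "2021-01-02"], utc=True), "x": [1, 2]})
assert len(truncate_by_time(df, "ts", start=datetime.datetime(2021, 1, 2))) == 1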
- if isinstance(value, str): - return value.replace(unnamed_prefix, "") - else: - value.columns = [c.replace(unnamed_prefix, "") for c in value.columns] - - def _load_single_file_all_columns( - self, starttime, endtime, file_path, column_list, basket_column_list, struct_basket_sub_columns - ): - import numpy - import pandas - - df = pandas.read_parquet(file_path, columns=column_list) - self._remove_unnamed_output_prefix(df) - df = self._truncate_df(starttime, endtime, df) - - shape_columns = self._get_shape_columns_dict(column_list) - if shape_columns: - columns_to_drop = [] - for k, v in shape_columns.items(): - df[k] = numpy.array([a.reshape(s) for a, s in zip(df[k], df[v])], dtype=object) - columns_to_drop.append(v) - df = df.drop(columns=columns_to_drop) - return df - - def _create_empty_full_array(self, dtype, field_array_shape, pandas_dtype): - import numpy - - if numpy.issubdtype(dtype, numpy.integer) or numpy.issubdtype(dtype, numpy.floating): - field_array = numpy.full(field_array_shape, numpy.nan) - pandas_dtype = float - elif numpy.issubdtype(dtype, numpy.datetime64): - field_array = numpy.full(field_array_shape, None, dtype=dtype) - else: - field_array = numpy.full(field_array_shape, None, dtype=object) - pandas_dtype = object - return field_array, pandas_dtype - - def _convert_array_columns(self, arrow_columns, column_list, array_columns, shape_columns): - if not array_columns: - return arrow_columns, column_list - - new_column_values = [] - new_column_list = [] - - shape_columns_names = set(shape_columns.values()) - shape_column_arrays = {} - for c, v in zip(column_list, arrow_columns): - if c in shape_columns_names: - shape_column_arrays[c] = numpy.array(v) - - for c, v in zip(column_list, arrow_columns): - if c in shape_columns_names: - continue - if c in array_columns: - numpy_v = numpy.array(v, dtype=object) - shape_col_name = shape_columns.get(c) - if shape_col_name: - shape_col = shape_column_arrays[shape_col_name] - numpy_v = numpy.array([v.reshape(shape) for v, shape in zip(numpy_v, shape_col)], dtype=object) - new_column_values.append(numpy_v) - else: - new_column_values.append(v.to_pandas()) - new_column_list.append(c) - return new_column_values, new_column_list - - def _load_data_split_to_columns( - self, starttime, endtime, file_path, column_list, basket_column_list, struct_basket_sub_columns - ): - import numpy - import pandas - from pyarrow import Table - from pyarrow.parquet import ParquetFile - - value_arrays = [] - for c in column_list: - parquet_file = ParquetFile(os.path.join(file_path, f"{c}.parquet")) - value_arrays.append(parquet_file.read().columns[0]) - - array_columns = self._get_array_columns_set(column_list) - shape_columns = self._get_shape_columns_dict(column_list) - # If there are no array use the pyarrow table from arrays to pandas as it is faster, otherwise we need to convert columns since arrays are not - # pyarrow native types - if array_columns and value_arrays and value_arrays and value_arrays[0]: - value_arrays, column_list = self._convert_array_columns( - value_arrays, column_list, array_columns, shape_columns - ) - res = pandas.DataFrame.from_dict(dict(zip(column_list, value_arrays))) - else: - res = Table.from_arrays(value_arrays, column_list).to_pandas() - - self._remove_unnamed_output_prefix(res) - - if basket_column_list: - basket_dfs = [] - columns_l0 = list(res.columns) - columns_l1 = [""] * len(columns_l0) - - for column in basket_column_list: - value_type = self.metadata.dict_basket_columns[column].value_type - - if 
issubclass(value_type, csp.Struct): - columns = struct_basket_sub_columns.get(column, value_type.metadata().keys()) - value_columns = [f"{column}.{k}" for k in columns] - else: - assert ( - column not in struct_basket_sub_columns - ), f"Specified sub columns for {column} but it's not a struct" - value_columns = [column] - value_files = [os.path.join(file_path, f"{value_column}.parquet") for value_column in value_columns] - symbol_file = os.path.join(file_path, f"{column}__csp_symbol.parquet") - value_count_file = os.path.join(file_path, f"{column}__csp_value_count.parquet") - symbol_data = ParquetFile(symbol_file).read().columns[0].to_pandas() - value_data = [ParquetFile(value_file).read().columns[0].to_pandas() for value_file in value_files] - value_count_data_array = ParquetFile(value_count_file).read().columns[0].to_pandas().values - - if len(value_count_data_array) == 0 or value_count_data_array[-1] == 0: - continue - - cycle_indices = value_count_data_array.cumsum() - 1 - value_count_indices = numpy.indices(value_count_data_array.shape)[0] - good_index_mask = numpy.full(cycle_indices.shape, True) - good_index_mask[1:] = cycle_indices[1:] != cycle_indices[:-1] - - index_array = numpy.full(len(symbol_data), numpy.nan) - index_array[cycle_indices[good_index_mask]] = value_count_indices[good_index_mask] - basked_data_index = pandas.Series(index_array).bfill().astype(int).values - - data_dict = {"index": basked_data_index, "symbol": symbol_data} - for value_column, data in zip(value_columns, value_data): - data_dict[value_column] = data - - basket_data_raw = pandas.DataFrame(data_dict) - if basket_data_raw.empty: - continue - else: - all_symbols = basket_data_raw["symbol"].unique() - all_symbols.sort() - field_array_shape = (value_count_indices.size, all_symbols.size) - sym_indices = numpy.searchsorted(all_symbols, basket_data_raw.symbol.values) - - field_matrices = {} - for f in value_columns: - pandas_dtype = basket_data_raw[f].dtype - dtype = basket_data_raw[f].values.dtype - field_array, pandas_dtype = self._create_empty_full_array( - dtype, field_array_shape, pandas_dtype - ) - - field_array[basked_data_index, sym_indices] = basket_data_raw[f] - - field_matrices[f] = pandas.DataFrame(field_array, columns=all_symbols, dtype=pandas_dtype) - - # pandas pivot_table is WAAAAY to slow, we have to implement our own here - basket_data_aligned = pandas.concat( - field_matrices.values(), keys=list(field_matrices.keys()), axis=1 - ) - - if column == DatasetNameConstants.UNNAMED_OUTPUT_NAME: - l0, l1 = zip(*basket_data_aligned.columns) - if issubclass(value_type, csp.Struct): - unnamed_prefix_len = len(DatasetNameConstants.UNNAMED_OUTPUT_NAME) + 1 - l0 = [k[unnamed_prefix_len:] for k in l0] - basket_data_aligned.columns = list(zip(l0, l1)) - columns_l0 += l0 - else: - basket_data_aligned.columns = list(l1) - columns_l0 = None - columns_l1 += l1 - else: - columns_l0 += basket_data_aligned.columns.get_level_values(0).tolist() - columns_l1 += basket_data_aligned.columns.get_level_values(1).tolist() - basket_dfs.append(basket_data_aligned) - res = pandas.concat([res] + basket_dfs, axis=1) - if columns_l0: - res.columns = [columns_l0, columns_l1] - - return self._truncate_df(starttime, endtime, res) - - def _read_flat_data_from_files(self, symbol_file, value_files, num_values_to_skip, num_values_to_read): - import pyarrow.parquet - - parquet_files = [pyarrow.parquet.ParquetFile(symbol_file)] - if value_files: - parquet_files += [pyarrow.parquet.ParquetFile(file) for file in value_files.values()] - 
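The __csp_value_count bookkeeping used by the flat-basket readers maps each engine cycle's timestamp onto the rows that cycle produced. A much-simplified sketch of the idea; the deleted code does this with cumsum and a backfill over index arrays rather than np.repeat:

import numpy as np

def expand_cycle_timestamps(cycle_timestamps: np.ndarray, value_counts: np.ndarray) -> np.ndarray:
    # cycle_timestamps[i] is the engine time of cycle i, value_counts[i] how many basket rows it wrote
    return np.repeat(cycle_timestamps, value_counts)

ts = np.array(["2021-01-01T00:00:00", "2021-01-01T00:00:01"], dtype="datetime64[us]")
expanded = expand_cycle_timestamps(ts, np.array([2, 1]))
assert (expanded == np.array(
    ["2021-01-01T00:00:00", "2021-01-01T00:00:00", "2021-01-01T00:00:01"], dtype="datetime64[us]"
)).all()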
symbol_parquet_file = parquet_files[0] - for row_group_index in range(symbol_parquet_file.num_row_groups): - row_group = symbol_parquet_file.read_row_group(row_group_index, []) - if num_values_to_skip >= row_group.num_rows: - num_values_to_skip -= row_group.num_rows - continue - row_group_batches = [f.read_row_group(row_group_index).to_batches()[0] for f in parquet_files] - column_names = list(itertools.chain(*(batch.schema.names for batch in row_group_batches))) - column_values = list(itertools.chain(*(batch.columns for batch in row_group_batches))) - row_group_table = pyarrow.Table.from_arrays(column_values, column_names) - - cur_row_group_start_index = num_values_to_skip - num_values_to_skip = 0 - cur_row_group_num_values_to_read = int( - min(row_group.num_rows - cur_row_group_start_index, num_values_to_read) - ) - num_values_to_read -= int(cur_row_group_num_values_to_read) - yield row_group_table.slice(cur_row_group_start_index, cur_row_group_num_values_to_read) - if num_values_to_read == 0: - return - assert num_values_to_read == 0 - - def _load_flat_basket_data( - self, - starttime, - endtime, - timestamp_file_name, - symbol_file_name, - value_count_file_name, - value_files, - need_timestamp=True, - ): - import numpy - import pyarrow.parquet - - timestamps_arrow_array = pyarrow.parquet.ParquetFile(timestamp_file_name).read()[0] - timestamps_array = numpy.array(timestamps_arrow_array) - - if timestamps_array.size == 0: - return None - cond = numpy.full(timestamps_array.shape, True) - if starttime is not None: - cond = (timestamps_array >= numpy.datetime64(starttime)) & cond - if endtime is not None: - cond = (timestamps_array <= numpy.datetime64(endtime)) & cond - mask_indices = numpy.where(cond)[0] - if mask_indices.size == 0: - return None - start_index, end_index = mask_indices[0], mask_indices[-1] - value_counts = numpy.array(pyarrow.parquet.ParquetFile(value_count_file_name).read()[0]) - num_values_to_skip = value_counts[:start_index].sum() - value_counts_sub_array = value_counts[start_index : end_index + 1] - value_counts_sub_array_cumsum = value_counts_sub_array.cumsum() - num_values_to_read = value_counts_sub_array_cumsum[-1] if value_counts_sub_array_cumsum.size > 0 else 0 - if num_values_to_read == 0: - return None - res = pyarrow.concat_tables( - filter( - None, - self._read_flat_data_from_files(symbol_file_name, value_files, num_values_to_skip, num_values_to_read), - ) - ) - - if need_timestamp: - timestamps_full = numpy.full(num_values_to_read, None, timestamps_array.dtype) - timestamp_array_size = len(res) - timestamps_sub_array = timestamps_array[start_index : end_index + 1] - timestamps_sub_array = timestamps_sub_array[value_counts_sub_array != 0] - value_counts_sub_array_cumsum_aux = value_counts_sub_array_cumsum[ - value_counts_sub_array_cumsum < timestamp_array_size - ] - timestamps_full[0] = timestamps_sub_array[0] - timestamps_full[value_counts_sub_array_cumsum_aux] = timestamps_sub_array[1:] - null_indices = numpy.where(numpy.isnat(timestamps_full))[0] - non_null_indices = numpy.where(~numpy.isnat(timestamps_full))[0] - fill_indices = non_null_indices[numpy.searchsorted(non_null_indices, null_indices, side="right") - 1] - timestamps_full[null_indices] = timestamps_full[fill_indices] - res = res.add_column(0, self.metadata.timestamp_column_name, pyarrow.array(timestamps_full)) - return res - - def _get_flat_basket_df_for_period( - self, - basket_field_name: str, - symbol_column: str, - struct_fields: List[str] = None, - starttime: Optional[datetime.datetime] = 
None, - endtime: Optional[datetime.datetime] = None, - missing_range_handler: Callable[[datetime.datetime, datetime.datetime], bool] = None, - num_threads=1, - load_values=True, - concat=True, - need_timestamp=True, - ): - starttime = self._normalize_time(starttime, True) - endtime = self._normalize_time(endtime, True) - data_files = self.get_data_files_for_period(starttime, endtime, missing_range_handler) - - if basket_field_name is None: - basket_field_name = DatasetNameConstants.UNNAMED_OUTPUT_NAME - - if basket_field_name not in self.metadata.dict_basket_columns: - raise RuntimeError(f"No basket {basket_field_name} is returned from graph") - - symbol_files = [os.path.join(f, f"{basket_field_name}__csp_symbol.parquet") for f in data_files.values()] - if load_values: - value_type = self.metadata.dict_basket_columns[basket_field_name].value_type - if issubclass(value_type, csp.Struct): - struct_fields = struct_fields if struct_fields is not None else list(value_type.metadata().keys()) - value_files = [ - {field: os.path.join(f, f"{basket_field_name}.{field}.parquet") for field in struct_fields} - for f in data_files.values() - ] - else: - assert ( - struct_fields is None - ), f"Trying to provide struct_fields for non struct output {basket_field_name}" - value_files = [ - {basket_field_name: os.path.join(f, f"{basket_field_name}.parquet")} for f in data_files.values() - ] - else: - value_files = list(itertools.repeat({}, len(symbol_files))) - value_count_files = [ - os.path.join(f, f"{basket_field_name}__csp_value_count.parquet") for f in data_files.values() - ] - timestamp_files = [ - os.path.join(f, f"{self.metadata.timestamp_column_name}.parquet") for f in data_files.values() - ] - file_tuples = [ - (t, s, c, d) for t, s, c, d in zip(timestamp_files, symbol_files, value_count_files, value_files) - ] - - if num_threads > 1: - with ThreadPoolExecutor(max_workers=num_threads) as pool: - tasks = [ - pool.submit(self._load_flat_basket_data, starttime, endtime, *tup, need_timestamp=need_timestamp) - for tup in file_tuples - ] - results = list(task.result() for task in tasks) - else: - results = [ - self._load_flat_basket_data(starttime, endtime, *tup, need_timestamp=need_timestamp) - for tup in file_tuples - ] - results = list(filter(None, results)) - if not results: - return None - if concat: - import pyarrow - - return pyarrow.concat_tables(results) - else: - return results - - def get_flat_basket_df_for_period( - self, - symbol_column: str, - basket_field_name: str = None, - struct_fields: List[str] = None, - starttime: Optional[datetime.datetime] = None, - endtime: Optional[datetime.datetime] = None, - missing_range_handler: Callable[[datetime.datetime, datetime.datetime], bool] = None, - num_threads=1, - ): - res = self._get_flat_basket_df_for_period( - basket_field_name=basket_field_name, - symbol_column=symbol_column, - struct_fields=struct_fields, - starttime=starttime, - endtime=endtime, - missing_range_handler=missing_range_handler, - num_threads=num_threads, - concat=True, - ) - if res is None: - return None - res_df = res.to_pandas() - res_df.rename(columns={res_df.columns[1]: symbol_column}, inplace=True) - res_df.columns = [self._remove_unnamed_output_prefix(c) for c in res_df.columns] - return res_df - - def get_all_basket_ids_in_range( - self, - basket_field_name=None, - starttime: Optional[datetime.datetime] = None, - endtime: Optional[datetime.datetime] = None, - missing_range_handler: Callable[[datetime.datetime, datetime.datetime], bool] = None, - num_threads=1, - ): - 
import numpy - - symbol_column_name = "__csp_symbol__" - parquet_tables = self._get_flat_basket_df_for_period( - basket_field_name=basket_field_name, - symbol_column=symbol_column_name, - starttime=starttime, - endtime=endtime, - missing_range_handler=missing_range_handler, - num_threads=num_threads, - load_values=False, - concat=False, - need_timestamp=False, - ) - unique_arrays = [numpy.unique(numpy.array(t[0])) for t in parquet_tables] - return sorted(numpy.unique(numpy.concatenate(unique_arrays + unique_arrays))) - - def invalidate_cache( - self, starttime: Optional[datetime.datetime] = None, endtime: Optional[datetime.datetime] = None - ): - existing_data = self.get_data_files_for_period(starttime, endtime, lambda *args, **kwargs: True) - - if not existing_data: - return - - aggregation_period_utils = AggregationPeriodUtils(self.metadata.time_aggregation) - if starttime is not None: - agg_period_starttime = aggregation_period_utils.resolve_period_start(starttime) - if starttime != agg_period_starttime: - raise RuntimeError( - f"Trying to invalidate data starting on {starttime} - invalidation should be for full aggregation period (starting on {agg_period_starttime})" - ) - - if endtime is not None: - agg_period_endtime = aggregation_period_utils.resolve_period_end(endtime, exclusive_end=False) - if endtime != agg_period_endtime: - raise RuntimeError( - f"Trying to invalidate data ending on {endtime} - invalidation should be for full aggregation period (ending on {agg_period_endtime})" - ) - - root_folders_to_possibly_remove = set() - for k, v in existing_data.items(): - output_folder_name = self._dataset_partition.data_paths.get_output_folder_name(k[0]) - root_folders_to_possibly_remove.add(os.path.dirname(output_folder_name)) - logging.info(f"Removing {output_folder_name}") - shutil.rmtree(output_folder_name) - partition_root_folder = self._dataset_partition.data_paths.root_folder - while root_folders_to_possibly_remove: - aux = root_folders_to_possibly_remove - root_folders_to_possibly_remove = set() - for v in aux: - if not v.startswith(partition_root_folder): - continue - can_remove = True - for item in os.listdir(v): - if not item.startswith(".") and not item.endswith("_WIP"): - can_remove = False - break - if can_remove: - logging.info(f"Removing {v}") - shutil.rmtree(v) - root_folders_to_possibly_remove.add(os.path.dirname(v)) - - def get_data_df_for_period( - self, - starttime: Optional[datetime.datetime] = None, - endtime: Optional[datetime.datetime] = None, - missing_range_handler: Callable[[datetime.datetime, datetime.datetime], bool] = None, - data_loader_function: Callable[[str, Optional[List[str]]], object] = None, - column_list=None, - basket_column_list=None, - struct_basket_sub_columns: Optional[Dict[str, List[str]]] = None, - combine=True, - num_threads=1, - ): - """Retrieve a list of all files in the given time range (inclusive) - :param starttime: The start time of the period - :param endtime: The end time of the period - :param missing_range_handler: A function that handles missing data. Will be called with (missing_period_starttime, missing_period_endtime), - should return True, if the missing data is not an error, should return False otherwise (in which case an exception will be raised) - :param data_loader_function: A custom loader function that overrides the default pandas read. If non None specified, it implies combine=False. The - function will be called with 2 arguments (file_path, column_list). 
The file_path is the path of the file to be loaded and column_list is the list - of columns to be loaded (if column_list is None then all columns should be loaded) - :param column_list: The list of columns to be loaded. If None specified then all columns will be loaded - :param basket_column_list: The list of basket columns to be loaded. If None specified then all basket columns will be loaded. - :param struct_basket_sub_columns: A dictionary of {basket_name: List[str]} that specifies which sub columns of the basket should be loaded. Only valid for - struct baskets - :param combine: Combine the loaded data frames into a single dataframe (if False, will return a list of dataframes). If data_loader_function - is specified then combine is always treated as False - :param num_threads: The number of threads to use for loading the data - """ - starttime = self._normalize_time(starttime) - endtime = self._normalize_time(endtime) - if endtime is not None and isinstance(endtime, datetime.timedelta): - endtime = starttime + endtime - - data_files = self.get_data_files_for_period(starttime, endtime, missing_range_handler) - if data_loader_function is None: - if self.metadata.split_columns_to_files: - data_loader_function = self._load_data_split_to_columns - else: - data_loader_function = self._load_single_file_all_columns - else: - combine = False - - if column_list is None: - column_list = list(self.metadata.columns.keys()) - shape_columns = self._get_shape_columns_dict(column_list) - if shape_columns: - column_list += list(shape_columns.values()) - - if basket_column_list is None: - basket_columns = getattr(self.metadata, "dict_basket_columns", None) - basket_column_list = list(basket_columns.keys()) if basket_columns else None - - if struct_basket_sub_columns is None: - struct_basket_sub_columns = {} - if basket_column_list: - for col in basket_column_list: - value_type = self.metadata.dict_basket_columns[col].value_type - if issubclass(value_type, csp.Struct): - self.metadata.dict_basket_columns[col] - struct_basket_sub_columns[col] = list(value_type.metadata().keys()) - else: - if "" in struct_basket_sub_columns: - struct_basket_sub_columns[DatasetNameConstants.UNNAMED_OUTPUT_NAME] = struct_basket_sub_columns.pop("") - for k, v in struct_basket_sub_columns.items(): - if k not in basket_column_list: - raise RuntimeError(f"Specified sub columns for basket '{k}' but it's not loaded from file: {v}") - - if self.metadata.timestamp_column_name not in column_list: - column_list = [self.metadata.timestamp_column_name] + column_list - - if num_threads > 1: - with ThreadPoolExecutor(max_workers=num_threads) as pool: - tasks = [ - pool.submit( - data_loader_function, - file_start_time, - file_end_time, - data_file, - column_list, - basket_column_list, - struct_basket_sub_columns, - ) - for (file_start_time, file_end_time), data_file in data_files.items() - ] - dfs = [task.result() for task in tasks] - else: - dfs = [ - data_loader_function( - file_start_time, - file_end_time, - data_file, - column_list, - basket_column_list, - struct_basket_sub_columns, - ) - for (file_start_time, file_end_time), data_file in data_files.items() - ] - - dfs = [df for df in dfs if len(df) > 0] - - # For now we do it in one process, in the future might push it into multiprocessing load - for k, typ in self._dataset_partition.dataset.metadata.columns.items(): - serializer = self._cache_serializers.get(typ) - if serializer: - for df in dfs: - df[k] = df[k].apply(lambda v: serializer.deserialize_from_bytes(v) if v is not None 
else None) - - if combine: - if len(dfs) > 0: - import pandas - - return pandas.concat(dfs, ignore_index=True) - else: - return None - else: - return dfs diff --git a/csp/impl/wiring/cache_support/graph_building.py b/csp/impl/wiring/cache_support/graph_building.py deleted file mode 100644 index e4731e19..00000000 --- a/csp/impl/wiring/cache_support/graph_building.py +++ /dev/null @@ -1,745 +0,0 @@ -import copy -import os -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Set, Tuple, TypeVar, Union - -import csp -from csp.adapters.parquet import ParquetOutputConfig -from csp.impl.config import CacheCategoryConfig, CacheConfig, Config -from csp.impl.managed_dataset.cache_user_custom_object_serializer import CacheObjectSerializer -from csp.impl.managed_dataset.dataset_metadata import TimeAggregation -from csp.impl.managed_dataset.dateset_name_constants import DatasetNameConstants -from csp.impl.managed_dataset.managed_dataset import ManagedDataset -from csp.impl.managed_dataset.managed_dataset_lock_file_util import ManagedDatasetLockUtil -from csp.impl.mem_cache import normalize_arg -from csp.impl.struct import Struct -from csp.impl.types import tstype -from csp.impl.types.common_definitions import ArgKind, OutputBasketContainer, OutputDef -from csp.impl.types.tstype import ts -from csp.impl.types.typing_utils import CspTypingUtils -from csp.utils.qualified_name_utils import QualifiedNameUtils - -# relative to avoid cycles -from ..context import Context -from ..edge import Edge -from ..node import node -from ..outputs import OutputsContainer -from ..signature import Signature -from ..special_output_names import UNNAMED_OUTPUT_NAME -from .cache_config_resolver import CacheConfigResolver - -T = TypeVar("T") - - -class _UnhashableObjectWrapper: - def __init__(self, obj): - self._obj = obj - - def __hash__(self): - return hash(id(self._obj)) - - def __eq__(self, other): - return id(self._obj) == id(other._obj) - - -class _CacheManagerKey: - def __init__(self, scalars, *extra_args): - self._normalized_scalars = tuple(self._normalize_scalars(scalars)) + tuple(extra_args) - self._hash = hash(self._normalized_scalars) - - def __hash__(self): - return self._hash - - def __eq__(self, other): - return self._hash == other._hash and self._normalized_scalars == other._normalized_scalars - - @classmethod - def _normalize_scalars(cls, scalars): - for scalar in scalars: - normalized_scalar = normalize_arg(scalar) - try: - hash(normalized_scalar) - except TypeError: - yield _UnhashableObjectWrapper(scalar) - else: - yield normalized_scalar - - -class WrappedStructEdge(Edge): - def __init__(self, wrapped_edge, parquet_reader, field_map): - super().__init__( - tstype=wrapped_edge.tstype, - nodedef=wrapped_edge.nodedef, - output_idx=wrapped_edge.output_idx, - basket_idx=wrapped_edge.basket_idx, - ) - self._parquet_reader = parquet_reader - self._field_map = field_map - self._single_field_edges = {} - - def __getattr__(self, key): - res = self._single_field_edges.get(key, None) - if res is not None: - return res - elemtype = self.tstype.typ.metadata().get(key) - if elemtype is None: - raise AttributeError("'%s' object has no attribute '%s'" % (self.tstype.typ.__name__, key)) - res = self._parquet_reader.subscribe_all(elemtype, field_map=self._field_map[key]) - self._single_field_edges[key] = res - return res - - -class WrappedCachedStructBasket(dict): - def __init__(self, typ, name, wrapped_edges, parquet_reader): - super().__init__(**wrapped_edges) - self._typ = typ - self._name = 
name - self._parquet_reader = parquet_reader - self._shape = None - self._field_dicts = {} - - def get_basket_field(self, field_name): - res = self._field_dicts.get(field_name) - if res is None: - if self._shape is None: - self._shape = list(self.keys()) - # res = self._parquet_reader.subscribe_dict_basket_struct_column(self._typ, self._name, self._shape, field_name) - res = {k: getattr(v, field_name) for k, v in self.items()} - self._field_dicts[field_name] = res - return res - - -class CacheCategoryOverridesTree: - """A utility class that is used to resolved category overrides for a given category like ['level_1', 'level_2', ...] - The basic implementation is a tree of levels - """ - - def __init__(self, cur_level_key: str = None, cur_level_value: CacheCategoryConfig = None, parent=None): - self.cur_level_key = cur_level_key - self.cur_level_value = cur_level_value - self.parent = parent - self.children = {} - - def _get_path_str(self): - path = [] - cur = self - while cur is not None: - if cur.cur_level_value is not None: - path.append(cur.cur_level_value) - cur = cur.parent - return str(reversed(path)) - - def __str__(self): - return f"CacheCategoryOverridesTree({self._get_path_str()}:{self.cur_level_value})" - - def __repr__(self): - return self.__str__() - - def _get_child(self, key: str): - res = self.children.get(key) - if res is None: - res = CacheCategoryOverridesTree(cur_level_key=key, parent=self) - self.children[key] = res - return res - - def _add_override(self, override: CacheCategoryConfig, cur_level_index=0): - if cur_level_index < len(override.category): - self._get_child(override.category[cur_level_index])._add_override(override, cur_level_index + 1) - else: - if self.cur_level_value is not None: - raise RuntimeError(f"Trying to override cache directory for {self._get_path_str()} more than once") - self.cur_level_value = override - - def resolve_root_folder(self, category: List[str], cur_level: int = 0): - """ - :param category: The category of the dataset - :return: A config override or the default config for the given category - """ - if cur_level == len(category): - return self.cur_level_value - - # We want the longest match possible, so first attempt resolving using children - child = self.children.get(category[cur_level]) - if child is not None: - child_res = child.resolve_root_folder(category, cur_level + 1) - if child_res is not None: - return child_res - - return self.cur_level_value - - @classmethod - def construct_from_cache_config(cls, cache_config: Optional[CacheConfig] = None): - res = CacheCategoryOverridesTree() - - if cache_config is None: - raise RuntimeError("data_folder must be set in global cache_config to use caching") - - res.cur_level_value = cache_config - - if hasattr(cache_config, "category_overrides"): - for override in cache_config.category_overrides: - res._add_override(override) - return res - - @classmethod - def construct_from_config(cls, config: Optional[Config] = None): - if config is not None and hasattr(config, "cache_config"): - return cls.construct_from_cache_config(config.cache_config) - return CacheCategoryOverridesTree() - - -class ContextCacheInfo(Struct): - """Graph building context storage class - - Should be stored inside a Context object, contains all the data collected during the graph building time to enable support - of caching - """ - - # A dictionary from tuple (function_id, scalar_arguments) to a corresponding GraphCacheManager - cache_managers: dict - # Dictionary of a graph function id to MangedDataset and field 
mapping that corresponds to it - managed_datasets_by_graph_object: Dict[object, Tuple[ManagedDataset, Dict[str, str]]] - # The object that is used to resolve the underlying sub folders of the given graph - cache_data_paths_resolver: CacheConfigResolver - - -class _EngineLockRelease: - """A utility that will release the given lock on engine stop""" - - def __init__(self): - self.cur_lock = None - - @node - def release_node(self): - with csp.stop(): - if self.cur_lock: - self.cur_lock.unlock() - self.cur_lock = None - - -class _MissingPeriodCallback: - def __init__(self): - self.first_missing_period = None - - def __call__(self, start, end): - if self.first_missing_period is None: - self.first_missing_period = (start, end) - return True - - -@node -def _deserialize_value(value: ts[bytes], type_serializer: CacheObjectSerializer, typ: "T") -> ts["T"]: - if csp.ticked(value): - return type_serializer.deserialize_from_bytes(value) - - -class GraphBuildPartitionCacheManager(object): - """A utility class that "manages" cache at graph building time - - One instance is created per (dataset, partition_values) - """ - - def __init__( - self, - function_name, - dataset: ManagedDataset, - partition_values, - expected_outputs, - cache_options, - csp_cache_start=None, - csp_cache_end=None, - csp_timestamp_shift=None, - ): - self._function_name = function_name - self.dataset_partition = dataset.get_partition(partition_values) - self._outputs = None - self._written_outputs = None - self._cache_options = cache_options - self._context = Context.TLS.instance - self._csp_cache_start = csp_cache_start if csp_cache_start else Context.TLS.instance.start_time - self._csp_cache_end = csp_cache_end if csp_cache_end else Context.TLS.instance.end_time - self._csp_timestamp_shift = csp_timestamp_shift if csp_timestamp_shift else timedelta() - missing_period_callback = _MissingPeriodCallback() - data_for_period, is_full_period_covered = self.get_data_for_period( - self._csp_cache_start, self._csp_cache_end, missing_range_handler=missing_period_callback - ) - cache_config = Context.TLS.instance.config.cache_config - self._first_missing_period = missing_period_callback.first_missing_period - if is_full_period_covered: - from csp.adapters.parquet import ParquetReader - - cache_serializers = cache_config.cache_serializers - # We need to release lock at the end of the run, generator is not guaranteed to do so if it's not called - engine_lock_releaser = _EngineLockRelease() - reader = ParquetReader( - filename_or_list=self._read_files_provider( - dataset, data_for_period, cache_config.lock_file_permissions, engine_lock_releaser - ), - time_column=self.dataset_partition.dataset.metadata.timestamp_column_name, - split_columns_to_files=self.dataset_partition.dataset.metadata.split_columns_to_files, - start_time=csp_cache_start, - end_time=csp_cache_end, - allow_overlapping_periods=True, - time_shift=self._csp_timestamp_shift, - ) - # We need to instantiate the node ot have it run - engine_lock_releaser.release_node() - self._outputs = OutputsContainer() - is_unnamed_output = False - for output in expected_outputs: - output_name = output.name - if output_name is None: - output_name = DatasetNameConstants.UNNAMED_OUTPUT_NAME - is_unnamed_output = True - if output.kind == ArgKind.BASKET_TS: - output_dict = reader.subscribe_dict_basket(typ=output.typ.typ, name=output_name, shape=output.shape) - output_value = WrappedCachedStructBasket(output.typ.typ, output_name, output_dict, reader) - else: - assert output.kind == ArgKind.TS - if 
isinstance(output.typ.typ, type) and issubclass(output.typ.typ, Struct): - # Reverse field mapping - write_field_map = cache_options.field_mapping.get(output_name) - field_map = {v: k for k, v in write_field_map.items()} - # Wrap the edge to allow single column reading - output_value = WrappedStructEdge( - reader.subscribe_all(typ=output.typ.typ, field_map=field_map), reader, write_field_map - ) - else: - type_serializer = cache_serializers.get(output.typ.typ) - if type_serializer: - output_value = _deserialize_value( - reader.subscribe_all( - typ=bytes, field_map=cache_options.field_mapping.get(output_name, output_name) - ), - type_serializer, - output.typ.typ, - ) - else: - output_value = reader.subscribe_all( - typ=output.typ.typ, field_map=cache_options.field_mapping.get(output_name, output_name) - ) - if is_unnamed_output: - assert len(expected_outputs) == 1 - self._outputs = output_value - else: - self._outputs[output_name] = output_value - - def _read_files_provider(self, dataset, data_files, lock_file_permissions, engine_lock_releaser): - assert data_files - items_iter = iter(data_files.items()) - finished = False - num_failures = 0 - next_filename = None - lock_util = ManagedDatasetLockUtil(lock_file_permissions) - while not finished: - if num_failures > 10: - raise RuntimeError( - f"Failed to read cached files too many times, last attempted file is {next_filename}" - ) - try: - (next_start_time, next_end_time), next_filename = next(items_iter) - except StopIteration: - finished = True - continue - is_file = os.path.isfile(next_filename) - is_dir = os.path.isdir(next_filename) - if not is_file and not is_dir: - data_files, _ = self.get_data_for_period(next_start_time, self._csp_cache_end) - assert data_files - items_iter = iter(data_files.items()) - num_failures += 1 - with dataset.use_lock_context(): - lock = lock_util.read_lock(next_filename, is_file) - lock.lock() - engine_lock_releaser.cur_lock = lock - try: - if os.path.exists(next_filename): - num_failures = 0 - yield next_filename - else: - data_files, _ = self.get_data_for_period(next_start_time, self._csp_cache_end) - assert data_files - items_iter = iter(data_files.items()) - num_failures += 1 - finally: - lock.unlock() - engine_lock_releaser.cur_lock = None - - @property - def first_missing_period(self): - return self._first_missing_period - - @property - def is_force_cache_read(self): - return hasattr(self._cache_options, "data_timestamp_column_name") - - @property - def outputs(self): - return self._outputs - - @property - def written_outputs(self): - return self._written_outputs - - @classmethod - def _resolve_anonymous_dataset_category(cls, cache_options): - category = getattr(cache_options, "category", None) - if category is None: - category = ["csp_unnamed_cache"] - return category - - @classmethod - def get_dataset_for_func(cls, graph, func, cache_options, data_folder): - category = cls._resolve_anonymous_dataset_category(cache_options) - cache_config_resolver = None - if isinstance(data_folder, Config): - cache_config_resolver = CacheConfigResolver(data_folder.cache_config) - elif isinstance(data_folder, CacheConfig): - cache_config_resolver = CacheConfigResolver(data_folder) - if isinstance(data_folder, CacheConfigResolver): - cache_config_resolver = data_folder - - if cache_config_resolver is None: - cache_config_resolver = CacheConfigResolver(CacheConfig(data_folder=data_folder)) - - cache_config = cache_config_resolver.resolve_cache_config(graph, category) - - dataset_name = getattr(cache_options, 
"dataset_name", None) if cache_options else None - if dataset_name is None: - dataset_name = f"{QualifiedNameUtils.get_qualified_object_name(func)}" - return ManagedDataset.load_from_disk(cache_config=cache_config, name=dataset_name, data_category=category) - - @classmethod - def _resolve_dataset(cls, graph, func, signature, cache_options, expected_outputs, tvars): - context_cache_data = Context.TLS.instance.cache_data - # We might have the dataset already - - func_id = id(func) - existing_dataset_and_field_mapping = context_cache_data.managed_datasets_by_graph_object.get(func_id) - if existing_dataset_and_field_mapping is not None: - cache_options.field_mapping = existing_dataset_and_field_mapping[1] - return existing_dataset_and_field_mapping[0] - - dataset_name = getattr(cache_options, "dataset_name", None) - partition_columns = {input.name: input.typ for input in signature.scalars} - column_types = {} - dict_basket_column_types = {} - - if len(expected_outputs) == 1 and expected_outputs[0].name is None: - timestamp_column_auto_prefix = f"{DatasetNameConstants.UNNAMED_OUTPUT_NAME}." - cur_def = expected_outputs[0] - expected_outputs = ( - OutputDef( - name=DatasetNameConstants.UNNAMED_OUTPUT_NAME, - typ=cur_def.typ, - kind=cur_def.kind, - ts_idx=cur_def.ts_idx, - shape=cur_def.shape, - ), - ) - else: - timestamp_column_auto_prefix = "" - - field_mapping = cache_options.field_mapping - for i, out in enumerate(expected_outputs): - if out.kind == ArgKind.BASKET_TS: - # Let's make sure that we're handling dict basket - if isinstance(out.shape, list) and tstype.isTsType(out.typ): - dict_basket_column_types[out.name] = signature.resolve_basket_key_type(i, tvars), out.typ.typ - else: - raise NotImplementedError(f"Caching of basket output {out.name} of type {out.typ} is unsupported") - elif isinstance(out.typ.typ, type) and issubclass(out.typ.typ, Struct): - struct_field_mapping = field_mapping.get(out.name) - if struct_field_mapping is None: - if cache_options.prefix_struct_names: - struct_col_types = {f"{out.name}.{k}": v for k, v in out.typ.typ.metadata().items()} - else: - struct_col_types = out.typ.typ.metadata() - column_types.update(struct_col_types) - struct_field_mapping = {n1: n2 for n1, n2 in zip(out.typ.typ.metadata(), struct_col_types)} - field_mapping[out.name] = struct_field_mapping - else: - for k, v in out.typ.typ.metadata().items(): - cache_col_name = struct_field_mapping.get(k, k) - column_types[cache_col_name] = v - else: - name = field_mapping.get(out.name, out.name) - column_types[name] = out.typ.typ - - if hasattr(cache_options, "data_timestamp_column_name"): - timestamp_column_name = timestamp_column_auto_prefix + cache_options.data_timestamp_column_name - else: - timestamp_column_name = "csp_timestamp" - - category = cls._resolve_anonymous_dataset_category(cache_options) - resolved_cache_config = context_cache_data.cache_data_paths_resolver.resolve_cache_config(graph, category) - dataset = cls._create_dataset( - func, - resolved_cache_config, - category, - dataset_name=dataset_name, - timestamp_column_name=timestamp_column_name, - columns_types=column_types, - partition_columns=partition_columns, - split_columns_to_files=cache_options.split_columns_to_files, - time_aggregation=cache_options.time_aggregation, - dict_basket_column_types=dict_basket_column_types, - ) - context_cache_data.managed_datasets_by_graph_object[func_id] = dataset, field_mapping - return dataset - - @classmethod - def _get_qualified_function_name(cls, func): - return 
f"{func.__module__}.{func.__name__}" - - @classmethod - def _create_dataset( - cls, - func, - cache_config, - category, - dataset_name, - timestamp_column_name: str = None, - columns_types: Dict[str, object] = None, - partition_columns: Dict[str, type] = None, - *, - split_columns_to_files: Optional[bool], - time_aggregation: TimeAggregation, - dict_basket_column_types: Dict[type, Union[type, Tuple[type, type]]], - ): - name = dataset_name if dataset_name else f"{QualifiedNameUtils.get_qualified_object_name(func)}" - dataset = ManagedDataset( - name=name, - category=category, - cache_config=cache_config, - timestamp_column_name=timestamp_column_name, - columns_types=columns_types, - partition_columns=partition_columns, - split_columns_to_files=split_columns_to_files, - time_aggregation=time_aggregation, - dict_basket_column_types=dict_basket_column_types, - ) - return dataset - - def get_data_for_period(self, starttime: datetime, endtime: datetime, missing_range_handler): - res, full_period_covered = self.dataset_partition.get_data_for_period( - starttime=starttime - self._csp_timestamp_shift, - endtime=endtime - self._csp_timestamp_shift, - missing_range_handler=missing_range_handler, - ) - if self._csp_timestamp_shift: - res = { - (start + self._csp_timestamp_shift, end + self._csp_timestamp_shift): path - for (start, end), path in res.items() - } - - return res, full_period_covered - - def _fix_outputs_for_caching(self, outputs): - if isinstance(outputs, OutputsContainer) and UNNAMED_OUTPUT_NAME in outputs: - outputs_dict = dict(outputs._items()) - outputs_dict[DatasetNameConstants.UNNAMED_OUTPUT_NAME] = outputs_dict.pop(UNNAMED_OUTPUT_NAME) - return OutputsContainer(**outputs_dict) - else: - return outputs - - def cache_outputs(self, outputs): - from csp.impl.managed_dataset.managed_parquet_writer import create_managed_parquet_writer_node - - outputs = self._fix_outputs_for_caching(outputs) - assert self._written_outputs is None - self._written_outputs = outputs - create_managed_parquet_writer_node( - function_name=self._function_name, - dataset_partition=self.dataset_partition, - values=outputs, - field_mapping=self._cache_options.field_mapping, - config=getattr(self._cache_options, "parquet_output_config", None), - data_timestamp_column_name=getattr(self.dataset_partition.dataset.metadata, "timestamp_column_name", None), - controlled_cache=self._cache_options.controlled_cache, - default_cache_enabled=self._cache_options.default_cache_enabled, - ) - - @classmethod - def create_cache_manager( - cls, - graph, - func, - signature, - non_ignored_scalars, - all_scalars, - cache_options, - expected_outputs, - tvars, - csp_cache_start=None, - csp_cache_end=None, - csp_timestamp_shift=None, - ): - if not hasattr(Context.TLS, "instance"): - raise RuntimeError("Graph must be instantiated under a wiring context") - - assert Context.TLS.instance.start_time is not None - assert Context.TLS.instance.end_time is not None - key = _CacheManagerKey( - all_scalars, id(func), tuple(tvars.items()), csp_cache_start, csp_cache_end, csp_timestamp_shift - ) - existing_cache_manager = Context.TLS.instance.cache_data.cache_managers.get(key) - if existing_cache_manager is not None: - return existing_cache_manager - - # We're going to modify field mapping, so we need to make a copy - cache_options = cache_options.copy() - cache_options.field_mapping = dict(cache_options.field_mapping) - - for output in expected_outputs: - if output.kind == ArgKind.TS and isinstance(output.typ.typ, type) and 
issubclass(output.typ.typ, Struct): - struct_field_map = cache_options.field_mapping.get(output.name) - if struct_field_map: - # We don't want to omit any fields from the cache otherwise read data will be different from what's written - # so whatever the user doesn't map, we will map to the original field. - full_field_map = copy.copy(struct_field_map) - for k in output.typ.typ.metadata(): - if k not in full_field_map: - full_field_map[k] = k - cache_options.field_mapping[output.name] = full_field_map - - dataset = cls._resolve_dataset(graph, func, signature, cache_options, expected_outputs, tvars) - - partition_values = dict(zip((i.name for i in signature.scalars), non_ignored_scalars)) - res = GraphBuildPartitionCacheManager( - function_name=QualifiedNameUtils.get_qualified_object_name(func), - dataset=dataset, - partition_values=partition_values, - expected_outputs=expected_outputs, - cache_options=cache_options, - csp_cache_start=csp_cache_start, - csp_cache_end=csp_cache_end, - csp_timestamp_shift=csp_timestamp_shift, - ) - Context.TLS.instance.cache_data.cache_managers[key] = res - return res - - -class GraphCacheOptions(Struct): - # The name of the dataset to which the data will be written - optional - dataset_name: str - # An optional column mapping for scalar time series, the mapping should be string (the name of the column in parquet file), - # for struct time series it should be a map of {struct_field_name:column_name}. - field_mapping: Dict[str, Union[str, Dict[str, str]]] - # A boolean that specifies whether struct fields should be prefixed with the output name. For example for a graph output - # named "o" and a field named "f", if prefix_struct_names is True then the column will be "o.f" else the column will be "f" - prefix_struct_names: bool = True - # This is an advanced usage of graph caching, in some instances we want to override timestamp and write data with custom timestamp, - # if this is specified, the values from the given column will be used as the timestamp column - data_timestamp_column_name: str - # Optional category specification for the dataset, can only be specified if using the default dataset. An example of category - # can be ['daily_statistics', 'market_prices']. This category will be part of the files path. Additionally, cache paths can be - # overridden for different categories. 
- category: List[str] - # Inputs to ignore for caching purposes - ignored_inputs: Set[str] - # A boolean that specifies whether each column should be written to a separate file - split_columns_to_files: bool - # The configuration of the written files - parquet_output_config: ParquetOutputConfig - # Data aggregation period - time_aggregation: TimeAggregation = TimeAggregation.DAY - # A boolean flag that specifies whether the node/graph provides a ts that specifies that the output should be cached - controlled_cache: bool = False - # The default value of whether at start the outputs should be cached, ignored if controlled_cache is False - default_cache_enabled: bool = False - - -class ResolvedGraphCacheOptions(Struct): - """A struct with all resolved graph cache options""" - - dataset_name: str - field_mapping: dict - prefix_struct_names: bool - data_timestamp_column_name: str - category: List[str] - ignored_inputs: Set[str] - split_columns_to_files: bool - parquet_output_config: ParquetOutputConfig - time_aggregation: TimeAggregation - controlled_cache: bool - default_cache_enabled: bool - - -def resolve_graph_cache_options(signature: Signature, cache_enabled, cache_options: GraphCacheOptions): - """Called at graph building time to validate that the given caching options are valid for given signature - - :param signature: The signature of the cached graph - :param cache_enabled: A boolean or that specifies whether cache enabled - :param cache_options: Graph cache read/write options - :return: - """ - if cache_enabled: - if cache_options is None: - cache_options = GraphCacheOptions() - - field_mapping = getattr(cache_options, "field_mapping", None) - split_columns_to_files = getattr(cache_options, "split_columns_to_files", None) - has_basket_outputs = False - - if signature._ts_inputs: - all_ts_ignored = False - ignored_inputs = getattr(cache_options, "ignored_inputs", None) - if ignored_inputs: - all_ts_ignored = True - for input in signature._ts_inputs: - if input.name not in ignored_inputs: - all_ts_ignored = False - if not all_ts_ignored: - raise NotImplementedError("Caching of graph with ts arguments is unsupported") - if not signature._outputs: - raise NotImplementedError("Caching of graph without outputs is unsupported") - for output in signature._outputs: - if isinstance(output.typ, OutputBasketContainer): - if CspTypingUtils.get_origin(output.typ.typ) is List: - raise NotImplementedError("Caching of list basket outputs is unsupported") - has_basket_outputs = True - elif not tstype.isTsType(output.typ): - if tstype.isTsStaticBasket(output.typ): - if CspTypingUtils.get_origin(output.typ) is List: - raise NotImplementedError("Caching of list basket outputs is unsupported") - else: - raise TypeError( - f"Cached output basket {output.name} must have shape provided using with_shape or with_shape_of" - ) - assert tstype.isTsType(output.typ) or isinstance(output.typ, OutputBasketContainer) - ignored_inputs = getattr(cache_options, "ignored_inputs", set()) - for input in signature.scalars: - if input.name not in ignored_inputs and not ManagedDataset.is_supported_partition_type(input.typ): - raise NotImplementedError( - f"Caching is unsupported for argument type {input.typ} (argument {input.name})" - ) - resolved_cache_options = ResolvedGraphCacheOptions(prefix_struct_names=cache_options.prefix_struct_names) - - for attr in ( - "dataset_name", - "data_timestamp_column_name", - "category", - "ignored_inputs", - "parquet_output_config", - "time_aggregation", - "controlled_cache", - 
"default_cache_enabled", - ): - if hasattr(cache_options, attr): - setattr(resolved_cache_options, attr, getattr(cache_options, attr)) - if has_basket_outputs: - if split_columns_to_files is False: - raise RuntimeError("Cached graph with output basket must set split_columns_to_files to True") - split_columns_to_files = True - elif split_columns_to_files is None: - split_columns_to_files = False - - resolved_cache_options.split_columns_to_files = split_columns_to_files - - resolved_cache_options.field_mapping = {} if field_mapping is None else field_mapping - else: - if cache_options: - raise RuntimeError("cache_options must be None if caching is disabled") - resolved_cache_options = None - return resolved_cache_options diff --git a/csp/impl/wiring/cache_support/partition_files_container.py b/csp/impl/wiring/cache_support/partition_files_container.py deleted file mode 100644 index 319bae70..00000000 --- a/csp/impl/wiring/cache_support/partition_files_container.py +++ /dev/null @@ -1,99 +0,0 @@ -import threading -from datetime import datetime -from typing import Dict, Tuple - -from csp.adapters.output_adapters.parquet import ParquetOutputConfig - - -class SinglePartitionFiles: - def __init__(self, dataset_partition, parquet_output_config): - self._dataset_partition = dataset_partition - self._parquet_output_config = parquet_output_config - # A mapping of (start, end)->file_path - self._files_by_period: Dict[Tuple[datetime, datetime], str] = {} - - @property - def dataset_partition(self): - return self._dataset_partition - - @property - def parquet_output_config(self): - return self._parquet_output_config - - @property - def files_by_period(self): - return self._files_by_period - - def add_file(self, start_time: datetime, end_time: datetime, file_path: str): - self._files_by_period[(start_time, end_time)] = file_path - - -class PartitionFileContainer: - TLS = threading.local() - - def __init__(self, cache_config): - # A mapping of id(dataset_partition)->(start_time,end_time)->file_path - self._files_by_partition_and_period: Dict[int, SinglePartitionFiles] = {} - self._cache_config = cache_config - - @classmethod - def get_instance(cls): - return cls.TLS.instance - - def __enter__(self): - assert not hasattr(self.TLS, "instance") - self.TLS.instance = self - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - # We don't want to finalize cache if there is an exception - if exc_val is None: - # First let's publish all files - for partition_files in self._files_by_partition_and_period.values(): - for (start_time, end_time), file_path in partition_files.files_by_period.items(): - partition_files.dataset_partition.publish_file( - file_path, - start_time, - end_time, - self._cache_config.data_file_permissions, - lock_file_permissions=self._cache_config.lock_file_permissions, - ) - - # Let's now merge whatever we can - if self._cache_config.merge_existing_files: - for partition_files in self._files_by_partition_and_period.values(): - for (start_time, end_time), file_path in partition_files.files_by_period.items(): - partition_files.dataset_partition.merge_files( - start_time, - end_time, - cache_config=self._cache_config, - parquet_output_config=partition_files.parquet_output_config, - ) - partition_files.dataset_partition.cleanup_unneeded_files( - start_time=start_time, end_time=end_time, cache_config=self._cache_config - ) - finally: - del self.TLS.instance - - @property - def files_by_partition(self): - return self._files_by_partition_and_period - - def add_generated_file( - 
self, - dataset_partition, - start_time: datetime, - end_time: datetime, - file_path: str, - parquet_output_config: ParquetOutputConfig, - ): - key = id(dataset_partition) - - single_partition_files = self._files_by_partition_and_period.get(key) - if single_partition_files is None: - single_partition_files = SinglePartitionFiles(dataset_partition, parquet_output_config) - self._files_by_partition_and_period[key] = single_partition_files - else: - assert single_partition_files.parquet_output_config == parquet_output_config - single_partition_files.add_file(start_time, end_time, file_path) diff --git a/csp/impl/wiring/cache_support/runtime_cache_manager.py b/csp/impl/wiring/cache_support/runtime_cache_manager.py deleted file mode 100644 index 4ef9da7a..00000000 --- a/csp/impl/wiring/cache_support/runtime_cache_manager.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import Dict, List - -import csp -from csp.impl.managed_dataset.managed_dataset import ManagedDataset, ManagedDatasetPartition -from csp.impl.wiring.cache_support.partition_files_container import PartitionFileContainer - - -class _DatasetRecord(csp.Struct): - dataset: ManagedDataset - read: bool = False - write: bool = False - - -class RuntimeCacheManager: - def __init__(self, cache_config, cache_data): - self._cache_config = cache_config - self._partition_file_container = PartitionFileContainer(cache_config) - self._datasets: Dict[int, _DatasetRecord] = {} - self._dataset_write_partitions: List[ManagedDatasetPartition] = [] - self._dataset_read_partitions: List[ManagedDatasetPartition] = [] - self._read_locks = [] - for graph_cache_manager in cache_data.cache_managers.values(): - if graph_cache_manager.outputs is not None: - self.add_read_partition(graph_cache_manager.dataset_partition) - else: - self.add_write_partition(graph_cache_manager.dataset_partition) - - def _validate_and_lock_datasets(self): - res = [] - for dataset_record in self._datasets.values(): - res.append( - dataset_record.dataset.validate_and_lock_metadata( - lock_file_permissions=self._cache_config.lock_file_permissions, - data_file_permissions=self._cache_config.data_file_permissions, - read=dataset_record.read, - write=dataset_record.write, - ) - ) - return res - - def __enter__(self): - self._read_locks = [] - self._read_locks.extend(self._validate_and_lock_datasets()) - for partition in self._dataset_write_partitions: - partition.create_root_folder(self._cache_config) - - self._partition_file_container.__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - for lock in self._read_locks: - lock.unlock() - finally: - self._read_locks.clear() - - self._partition_file_container.__exit__(exc_type, exc_val, exc_tb) - - def _add_dataset(self, dataset: ManagedDataset, read=False, write=False): - dataset_id = id(dataset) - dataset_record = self._datasets.get(dataset_id) - if dataset_record is None: - dataset_record = _DatasetRecord(dataset=dataset, read=read, write=write) - self._datasets[dataset_id] = dataset_record - return - dataset_record.read |= read - dataset_record.write |= write - - def add_write_partition(self, dataset_partition: ManagedDatasetPartition): - self._add_dataset(dataset_partition.dataset, write=True) - self._dataset_write_partitions.append(dataset_partition) - - def add_read_partition(self, dataset_partition: ManagedDatasetPartition): - self._add_dataset(dataset_partition.dataset, read=True) - self._dataset_read_partitions.append(dataset_partition) diff --git a/csp/impl/wiring/context.py b/csp/impl/wiring/context.py index 
73156fcd..d9bacf53 100644 --- a/csp/impl/wiring/context.py +++ b/csp/impl/wiring/context.py @@ -1,8 +1,6 @@ import threading from datetime import datetime -from typing import Optional -from csp.impl.config import Config from csp.impl.mem_cache import CspGraphObjectsMemCache @@ -13,24 +11,13 @@ def __init__( self, start_time: datetime = None, end_time: datetime = None, - config: Optional[Config] = None, is_global_instance: bool = False, ): - from csp.impl.wiring.cache_support.cache_config_resolver import CacheConfigResolver - from csp.impl.wiring.cache_support.graph_building import ContextCacheInfo - self.roots = [] self.start_time = start_time self.end_time = end_time self.mem_cache = None self.delayed_nodes = [] - self.config = config - - self.cache_data = ContextCacheInfo( - cache_managers={}, - managed_datasets_by_graph_object={}, - cache_data_paths_resolver=CacheConfigResolver(getattr(config, "cache_config", None)), - ) self._is_global_instance = is_global_instance if hasattr(self.TLS, "instance") and self.TLS.instance._is_global_instance: @@ -39,11 +26,7 @@ def __init__( # Note we don't copy everything here from the global context since it may cause undesired behaviors. # roots - We want to accumulate roots that are only relevant for the current run, not all the roots in the global context. # start_time, end_time - not even set in the global context - # mem_cache - can cause issues with dynamic graph nodes and with caching (i.e cached graphs can be built differently based on the run period - # and the data that is available on the disk) - # config - can differ between runs - # cache_data - Generally we don't support caching for global contexts. There are a bunch of issues around it at the moment (one reason is that - # at wiring time we need to know start and end time of the graph). 
+ # mem_cache - can cause issues with dynamic graph nodes for delayed_node in prev.delayed_nodes: # The copy of the delayed node will add all the new delayed nodes to the current context delayed_node.copy() diff --git a/csp/impl/wiring/graph.py b/csp/impl/wiring/graph.py index 392bd167..0446b65c 100644 --- a/csp/impl/wiring/graph.py +++ b/csp/impl/wiring/graph.py @@ -1,31 +1,13 @@ -import datetime import inspect -import threading import types -from contextlib import contextmanager -import csp.impl.wiring.edge from csp.impl.constants import UNSET from csp.impl.error_handling import ExceptionContext -from csp.impl.managed_dataset.dateset_name_constants import DatasetNameConstants from csp.impl.mem_cache import csp_memoized_graph_object, function_full_name -from csp.impl.types.common_definitions import InputDef from csp.impl.types.instantiation_type_resolver import GraphOutputTypeResolver -from csp.impl.wiring import Signature -from csp.impl.wiring.cache_support.dataset_partition_cached_data import DataSetCachedData, DatasetPartitionCachedData -from csp.impl.wiring.cache_support.graph_building import ( - GraphBuildPartitionCacheManager, - GraphCacheOptions, - resolve_graph_cache_options, -) -from csp.impl.wiring.context import Context from csp.impl.wiring.graph_parser import GraphParser -from csp.impl.wiring.outputs import CacheWriteOnlyOutputsContainer, OutputsContainer -from csp.impl.wiring.special_output_names import ALL_SPECIAL_OUTPUT_NAMES, UNNAMED_OUTPUT_NAME - - -class NoCachedDataException(RuntimeError): - pass +from csp.impl.wiring.outputs import OutputsContainer +from csp.impl.wiring.special_output_names import UNNAMED_OUTPUT_NAME class _GraphDefMetaUsingAux: @@ -42,120 +24,11 @@ def using(self, **_forced_tvars): def __call__(self, *args, **kwargs): return self._graph_meta._instantiate(self._forced_tvars, *args, **kwargs) - def cache_periods(self, start_time, end_time): - return self._graph_meta.cached_data(start_time, end_time) - - def cached_data(self, data_folder=None, _forced_tvars=None) -> DatasetPartitionCachedData: - """Get the proxy object for accessing the graph cached data. - This is the basic interface for inspecting cache files and loading cached data as dataframes - :param data_folder: The root folder of the cache or an instance of CacheDataPathResolver - :return: An instance of DatasetPartitionCachedData to access the graph cached data - """ - if data_folder is None: - data_folder = Context.instance().config - return self._graph_meta.cached_data(data_folder, _forced_tvars) - - def cached(self, *args, **kwargs): - """A utility function to ensure that a graph is read from cache - For example if there is a cached graph g. - Calling g(a1, a2, ...) can either read it from cache or write the results to cache if no cached data is found. - Calling g.cached(a1, a2, ...) forces reading from cache, if no cache data is found then exception will be raised. - :param args: Positional arguments to the graph - :param kwargs: Keyword arguments to the graph - """ - return self._graph_meta.cached(*args, _forced_tvars=self._forced_tvars, **kwargs) - - -class _ForceCached: - """This class is an ugly workaround to avoid instantiating cached graphs. - The problem: - my_graph.cached(...) - is implemented by calling the regular code path of the graph instantiation and checking whether the graph is actually read from cache. This is a - problem since the user doesn't expect the graph to be instantiated if they use "cached" property. 
We can't also provide an argument "force_cached" to the instantiation - function since it's memcached and extra argument will cause calls to graph.cached(...) and graph(...) to result in different instances which is wrong. - - This class is a workaround to pass this "require_cached" flag not via arguments - """ - - _INSTANCE = threading.local() - - @classmethod - def is_force_cached(cls): - if not hasattr(cls._INSTANCE, "force_cached"): - return False - return cls._INSTANCE.force_cached - - @classmethod - @contextmanager - def force_cached(cls): - prev_value = cls.is_force_cached() - try: - cls._INSTANCE.force_cached = True - yield - finally: - cls._INSTANCE.force_cached = prev_value - - -class _CacheProxy: - """A helper class that allows to access cached data in a given time range, that can be smaller than the engine run time - - Usage: - my_graph.cached[start:end] - The cached property of the graph will return an instance of _CacheProxy which can then be called with the appropriate parameters. - """ - - def __init__(self, owner, csp_cache_start=None, csp_cache_end=None, csp_timestamp_shift=None): - self._owner = owner - self._csp_cache_start = csp_cache_start - self._csp_cache_end = csp_cache_end - self._csp_timestamp_shift = csp_timestamp_shift - - def __getitem__(self, item): - assert isinstance(item, slice), "cached item range must be a slice" - assert item.step is None, "Providing step for cache range is not supported" - res = _CacheProxy(self._owner, csp_timestamp_shift=self._csp_timestamp_shift) - res._csp_cache_start = item.start - # The range values are exclusive but for caching purposes we need inclusive end time - res._csp_cache_end = item.stop - datetime.timedelta(microseconds=1) - return res - - def shifted(self, csp_timestamp_shift: datetime.timedelta): - return _CacheProxy( - self._owner, - csp_cache_start=self._csp_cache_start, - csp_cache_end=self._csp_cache_end, - csp_timestamp_shift=csp_timestamp_shift, - ) - - def __call__(self, *args, _forced_tvars=None, **kwargs): - with _ForceCached.force_cached(): - return self._owner._cached_impl( - _forced_tvars, - self._csp_cache_start, - self._csp_cache_end, - args, - kwargs, - csp_timestamp_shift=self._csp_timestamp_shift, - ) - class GraphDefMeta(type): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._instantiate_func = self._instantiate_impl - ignored_inputs = getattr(self._cache_options, "ignored_inputs", None) - if self._cache_options and ignored_inputs is not None: - non_ignored_inputs = [input for input in self._signature.inputs if input.name not in ignored_inputs] - self._cache_signature = Signature( - name=self._signature._name, - inputs=[ - InputDef(name=s.name, typ=s.typ, kind=s.kind, basket_kind=s.basket_kind, ts_idx=None, arg_idx=i) - for i, s in enumerate(non_ignored_inputs) - ], - outputs=self._signature._outputs, - defaults={k: v for k, v in self._signature._defaults.items() if k not in ignored_inputs}, - ) - else: - self._cache_signature = self._signature if self.memoize or self.force_memoize: if self.wrapped_node: full_name = function_full_name(self.wrapped_node._impl) @@ -171,103 +44,7 @@ def _extract_forced_tvars(cls, d): return d.pop("_forced_tvars") return {} - @property - def _cached_function(self): - return self.wrapped_node._impl if self.wrapped_node else self._func - - def cache_periods(self, start_time, end_time): - from csp.impl.managed_dataset.aggregation_period_utils import AggregationPeriodUtils - - agg_period_utils = 
AggregationPeriodUtils(self._cache_options.time_aggregation) - return list(agg_period_utils.iterate_periods_in_date_range(start_time, end_time)) - - def cached_data(self, data_folder=None, _forced_tvars=None) -> DatasetPartitionCachedData: - """Get the proxy object for accessing the graph cached data. - This is the basic interface for inspecting cache files and loading cached data as dataframes - :param data_folder: An instance of string (data folder), csp.Config or csp.cache_support.CacheConfig with the appropriate cache config. Note, that - only if one of he configs passed in then the category resolution and custom data types serialization handled properly. Pass in string only if None of the above - features is used. - :return: An instance of DatasetPartitionCachedData to access the graph cached data - """ - if not self._cache_options: - raise RuntimeError("Trying to get cached data from graph that doesn't cache") - if data_folder is None: - data_folder = Context.instance().config - - if isinstance(data_folder, csp.Config): - cache_config = data_folder.cache_config - elif isinstance(data_folder, csp.cache_support.CacheConfig): - cache_config = data_folder - else: - cache_config = None - - if cache_config: - cache_serializers = cache_config.cache_serializers - else: - cache_serializers = {} - - cache_signature = self._cache_signature - - dataset = GraphBuildPartitionCacheManager.get_dataset_for_func( - graph=self, func=self._cached_function, cache_options=self._cache_options, data_folder=data_folder - ) - if dataset is None: - return None - - def _get_dataset_partition(*args, **kwargs): - inputs, scalars, tvars = cache_signature.parse_inputs(_forced_tvars, *args, **kwargs) - partition_values = dict(zip((i.name for i in cache_signature._inputs), scalars)) - return dataset.get_partition(partition_values) - - return DataSetCachedData(dataset, cache_serializers, _get_dataset_partition) - - @property - def cached(self) -> _CacheProxy: - """ - Usage: - my_graph.cached[start:end] - Will return an instance of _CacheProxy which can then be called with the appropriate parameters. - :return: A cache proxy that can be used to limit the time of the returned graph. - """ - return _CacheProxy(self) - - def _cached_impl(self, _forced_tvars, csp_cache_start, csp_cache_end, args, kwargs, csp_timestamp_shift=None): - """A utility function to ensure that a graph is read from cache - For example if there is a cached graph g. - Calling g(a1, a2, ...) can either read it from cache or write the results to cache if no cached data is found. - Calling g.cached(a1, a2, ...) forces reading from cache, if no cache data is found then exception will be raised. 
- :param args: Positional arguments to the graph - :param csp_cache_start: The start time of the cached data before which we don't want to load any date - :param csp_cache_end: The end time of the cached data after which we don't want to load any date - :param kwargs: Keyword arguments to the graph - """ - if Context.TLS.instance.config and hasattr(Context.TLS.instance.config, "cache_config") and self._cache_options: - read_from_cache, res, _ = self._instantiate_func( - _forced_tvars, self._signature, args, kwargs, csp_cache_start, csp_cache_end, csp_timestamp_shift - ) - assert read_from_cache - else: - raise NoCachedDataException( - f"No data found in cache for {self._signature._name} for the given run period, seems like cache_config is unset" - ) - return res - - def _raise_if_forced_cache_read(self, missing_period=None): - if _ForceCached.is_force_cached(): - if missing_period: - missing_period_str = f" {str(missing_period[0])} to {str(missing_period[1])}" - raise NoCachedDataException( - f"No data found in cache for {self._signature._name} for period{missing_period_str}" - ) - else: - raise NoCachedDataException( - f"No data found in cache for {self._signature._name} for the given run period" - ) - - def _instantiate_impl( - self, _forced_tvars, signature, args, kwargs, csp_cache_start=None, csp_cache_end=None, csp_timestamp_shift=None - ): - read_from_cache = False + def _instantiate_impl(self, _forced_tvars, signature, args, kwargs): inputs, scalars, tvars = signature.parse_inputs(_forced_tvars, *args, allow_none_ts=True, **kwargs) basket_shape_eval_inputs = list(scalars) @@ -278,62 +55,13 @@ def _instantiate_impl( expected_outputs = signature.resolve_output_definitions( tvars=tvars, basket_shape_eval_inputs=basket_shape_eval_inputs ) - if signature.special_outputs: - expected_regular_outputs = tuple(v for v in expected_outputs if v.name not in ALL_SPECIAL_OUTPUT_NAMES) - else: - expected_regular_outputs = expected_outputs - - cache_manager = None - if ( - hasattr(Context.TLS, "instance") - and Context.TLS.instance.config - and hasattr(Context.TLS.instance.config, "cache_config") - and self._cache_options - ): - ignored_inputs = getattr(self._cache_options, "ignored_inputs", set()) - cache_scalars = tuple(s for s, s_def in zip(scalars, signature.scalars) if s_def.name not in ignored_inputs) - - cache_manager = GraphBuildPartitionCacheManager.create_cache_manager( - graph=self, - func=self._cached_function, - signature=self._cache_signature, - non_ignored_scalars=cache_scalars, - all_scalars=scalars, - cache_options=self._cache_options, - expected_outputs=expected_regular_outputs, - tvars=tvars, - csp_cache_start=csp_cache_start, - csp_cache_end=csp_cache_end, - csp_timestamp_shift=csp_timestamp_shift, - ) - allow_non_cached_read = True - if cache_manager: - if cache_manager.outputs is not None: - res = cache_manager.outputs - read_from_cache = True - elif cache_manager.written_outputs is not None: - res = cache_manager.written_outputs - else: - self._raise_if_forced_cache_read(cache_manager.first_missing_period) - res = self._func(*args, **kwargs) - cache_manager.cache_outputs(res) - allow_non_cached_read = not cache_manager.is_force_cache_read - else: - self._raise_if_forced_cache_read() - res = self._func(*args, **kwargs) - - if read_from_cache: - expected_outputs = expected_regular_outputs + res = self._func(*args, **kwargs) # Validate graph return values if isinstance(res, OutputsContainer): outputs_raw = [] for e_o in expected_outputs: - output_name = ( - e_o.name - if e_o.name 
- else (DatasetNameConstants.UNNAMED_OUTPUT_NAME if read_from_cache else UNNAMED_OUTPUT_NAME) - ) + output_name = e_o.name if e_o.name else UNNAMED_OUTPUT_NAME cur_o = res._get(output_name, UNSET) if cur_o is UNSET: raise KeyError(f"Output {output_name} is not returned from the graph") @@ -360,26 +88,17 @@ def _instantiate_impl( output_definitions=expected_outputs, values=outputs_raw, forced_tvars=tvars, - allow_subtypes=self._cache_options is None, ) if signature.special_outputs: - if not read_from_cache: - if expected_outputs[0].name is None: - res = next(iter(res._values())) - else: - res = OutputsContainer(**{k: v for k, v in res._items() if k not in ALL_SPECIAL_OUTPUT_NAMES}) + if expected_outputs[0].name is None: + res = next(iter(res._values())) + else: + res = OutputsContainer(**{k: v for k, v in res._items() if k != -UNNAMED_OUTPUT_NAME}) - return read_from_cache, res, allow_non_cached_read + return res def _instantiate(self, _forced_tvars, *args, **kwargs): - _, res, allow_non_cached_read = self._instantiate_func(_forced_tvars, self._signature, args=args, kwargs=kwargs) - if not allow_non_cached_read: - if isinstance(res, csp.impl.wiring.edge.Edge): - return CacheWriteOnlyOutputsContainer(iter([res])) - else: - return CacheWriteOnlyOutputsContainer(iter(res)) - else: - return res + return self._instantiate_func(_forced_tvars, self._signature, args=args, kwargs=kwargs) def __call__(cls, *args, **kwargs): return cls._instantiate(None, *args, **kwargs) @@ -400,14 +119,9 @@ def _create_graph( signature, memoize, force_memoize, - cache, - cache_options, wrapped_function=None, wrapped_node=None, ): - resolved_cache_options = resolve_graph_cache_options( - signature=signature, cache_enabled=cache, cache_options=cache_options - ) return GraphDefMeta( func_name, (object,), @@ -416,7 +130,6 @@ def _create_graph( "_func": impl, "memoize": memoize, "force_memoize": force_memoize, - "_cache_options": resolved_cache_options, "__wrapped__": wrapped_function, "__module__": wrapped_function.__module__, "wrapped_node": wrapped_node, @@ -430,18 +143,15 @@ def graph( *, memoize=True, force_memoize=False, - cache: bool = False, - cache_options: GraphCacheOptions = None, name=None, debug_print=False, ): """ :param func: - :param memoize: Specify whether the node should be memoized (default True) - :param force_memoize: If True, the node will be memoized even if csp.memoize(False) was called. Usually it should not be set, set + :param memoize: Specify whether the graph should be memoized (default True) + :param force_memoize: If True, the graph will be memoized even if csp.memoize(False) was called. Usually it should not be set, set this to True ONLY if memoization required to guarantee correctness of the function (i.e the function must be called at most once with the for each set of parameters). 
- :param cache_options: The options for graph caching :param name: Provide a custom name for the constructed graph type :param debug_print: A boolean that specifies that processed function should be printed :return: @@ -450,12 +160,10 @@ def graph( def _impl(func): with ExceptionContext(): - add_cache_control_output = cache_options is not None and getattr(cache_options, "controlled_cache", False) parser = GraphParser( name or func.__name__, func, func_frame, - add_cache_control_output=add_cache_control_output, debug_print=debug_print, ) parser.parse() @@ -468,8 +176,6 @@ def _impl(func): signature, memoize, force_memoize, - cache, - cache_options, wrapped_function=func, ) diff --git a/csp/impl/wiring/graph_parser.py b/csp/impl/wiring/graph_parser.py index 36c24317..efe73cdc 100644 --- a/csp/impl/wiring/graph_parser.py +++ b/csp/impl/wiring/graph_parser.py @@ -2,19 +2,18 @@ from csp.impl.wiring import Signature from csp.impl.wiring.base_parser import BaseParser, CspParseError, _pythonic_depr_warning -from csp.impl.wiring.special_output_names import CSP_CACHE_ENABLED_OUTPUT, UNNAMED_OUTPUT_NAME +from csp.impl.wiring.special_output_names import UNNAMED_OUTPUT_NAME class GraphParser(BaseParser): _DEBUG_PARSE = False - def __init__(self, name, raw_func, func_frame, debug_print=False, add_cache_control_output=False): + def __init__(self, name, raw_func, func_frame, debug_print=False): super().__init__( name=name, raw_func=raw_func, func_frame=func_frame, debug_print=debug_print, - add_cache_control_output=add_cache_control_output, ) def visit_FunctionDef(self, node): @@ -51,26 +50,17 @@ def visit_Return(self, node): if len(self._outputs) and node.value is None: raise CspParseError("return does not return values with non empty outputs") - if self._add_cache_control_output: - if isinstance(node.value, ast.Call): - parsed_return = self.visit_Call(node.value) - if isinstance(parsed_return, ast.Return): - return parsed_return + if isinstance(node.value, ast.Call): + if len(self._outputs) > 1: + self._validate_output(node) - returned_value = node.value - return self._wrap_returned_value_and_add_special_outputs(returned_value) - else: - if isinstance(node.value, ast.Call): - if len(self._outputs) > 1: - self._validate_output(node) - - parsed_return = self.visit_Call(node.value) - if isinstance(parsed_return, ast.Call): - return ast.Return(value=parsed_return, lineno=node.lineno, end_lineno=node.end_lineno) + parsed_return = self.visit_Call(node.value) + if isinstance(parsed_return, ast.Call): + return ast.Return(value=parsed_return, lineno=node.lineno, end_lineno=node.end_lineno) - return parsed_return + return parsed_return - return node + return node def _parse_single_output_definition(self, name, arg_type_node, ts_idx, typ=None): return self._parse_single_output_definition_with_shapes( @@ -101,11 +91,8 @@ def _parse_return(self, node, special_outputs=None): raise CspParseError("returning from graph without any outputs defined", node.lineno) elif ( len(self._signature._outputs) == 1 and self._signature._outputs[0].name is None - ): # graph only has one unnamed output - if self._add_cache_control_output: - return self._wrap_returned_value_and_add_special_outputs(node.args[0]) - else: - return ast.Return(value=node.args[0], lineno=node.lineno, end_lineno=node.end_lineno) + ): # graph only has one unnamed output: + return ast.Return(value=node.args[0], lineno=node.lineno, end_lineno=node.end_lineno) else: node.keywords = [ast.keyword(arg=self._signature._outputs[0].name, value=node.args[0])] 
node.args.clear() @@ -160,32 +147,10 @@ def visit_Expr(self, node): return res.value return res - def _get_special_output_name_mapping(self): - """ - :return: A dict mapping local_variable->output_name - """ - return {CSP_CACHE_ENABLED_OUTPUT: CSP_CACHE_ENABLED_OUTPUT} - def visit_Call(self, node: ast.Call): if (isinstance(node.func, ast.Name) and node.func.id == "__return__") or BaseParser._is_csp_output_call(node): special_outputs = {} - if self._add_cache_control_output: - special_outputs = self._get_special_output_name_mapping() return self._parse_return(node, special_outputs) - if ( - isinstance(node.func, ast.Attribute) - and isinstance(node.func.value, ast.Name) - and node.func.value.id == "csp" - and node.func.attr == "set_cache_enable_ts" - ): - if len(node.args) != 1 or node.keywords: - raise CspParseError("Invalid call to csp.set_cache_enable_ts", node.lineno) - if self._add_cache_control_output: - return ast.Assign(targets=[ast.Name(id=CSP_CACHE_ENABLED_OUTPUT, ctx=ast.Store())], value=node.args[0]) - else: - raise CspParseError( - "Invalid call to csp.set_cache_enable_ts in graph with non controlled cache", node.lineno - ) return self.generic_visit(node) def _is_ts_args_removed_from_signature(self): @@ -193,23 +158,10 @@ def _is_ts_args_removed_from_signature(self): def _parse_impl(self): self._inputs, input_defaults, self._outputs = self.parse_func_signature(self._funcdef) - self._resolve_special_outputs() # Should have inputs and outputs at this point self._signature = Signature( self._name, self._inputs, self._outputs, input_defaults, special_outputs=self._special_outputs ) - # We need to set default value for the cache control variable as the first command in the function - if self._add_cache_control_output: - self._funcdef.body = [ - ast.Assign( - targets=[ast.Name(id=CSP_CACHE_ENABLED_OUTPUT, ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute(value=ast.Name(id="csp", ctx=ast.Load()), attr="null_ts", ctx=ast.Load()), - args=[ast.Name(id="bool", ctx=ast.Load())], - keywords=[], - ), - ) - ] + self._funcdef.body self.generic_visit(self._funcdef) newfuncdef = ast.FunctionDef(name=self._funcdef.name, body=self._funcdef.body, returns=None) diff --git a/csp/impl/wiring/node.py b/csp/impl/wiring/node.py index e4cfeab0..a905bee3 100644 --- a/csp/impl/wiring/node.py +++ b/csp/impl/wiring/node.py @@ -195,16 +195,12 @@ def _create_node( cppimpl, pre_create_hook, name, - cache: bool = False, - cache_options=None, ): - create_wrapper = cache or cache_options parser = NodeParser( name, func, func_frame, debug_print=debug_print, - add_cache_control_output=cache_options and cache_options.controlled_cache, ) parser.parse() @@ -223,26 +219,7 @@ def _create_node( "__doc__": parser._docstring, }, ) - if create_wrapper: - from csp.impl.wiring.graph import _create_graph - - def wrapper(*args, **kwargs): - return nodetype(*args, **kwargs) - - return _create_graph( - name, - func.__doc__, - wrapper, - parser._signature.copy(drop_alarms=True), - memoize, - force_memoize, - cache, - cache_options, - wrapped_function=func, - wrapped_node=nodetype, - ) - else: - return nodetype + return nodetype def _node_internal_use( @@ -254,8 +231,6 @@ def _node_internal_use( debug_print=False, cppimpl=None, pre_create_hook=None, - cache: bool = False, - cache_options=None, name=None, ): """A decorator similar to the @node decorator that exposes some internal arguments that shoudn't be visible to users""" @@ -272,8 +247,6 @@ def _impl(func): cppimpl=cppimpl, pre_create_hook=pre_create_hook, name=name or 
func.__name__, - cache=cache, - cache_options=cache_options, ) if func is None: @@ -291,8 +264,6 @@ def node( force_memoize=False, debug_print=False, cppimpl=None, - cache: bool = False, - cache_options=None, name=None, ): """ @@ -303,8 +274,6 @@ def node( set of parameters). :param debug_print: A boolean that specifies that processed function should be printed :param cppimpl: - :param cache: - :param cache_options: :param name: Provide a custom name for the constructed node type, helpful when viewing a graph with many same-named nodes :return: """ @@ -317,7 +286,5 @@ def node( debug_print=debug_print, cppimpl=cppimpl, pre_create_hook=None, - cache=cache, - cache_options=cache_options, name=name, ) diff --git a/csp/impl/wiring/node_parser.py b/csp/impl/wiring/node_parser.py index 818ea2f7..af6726ea 100644 --- a/csp/impl/wiring/node_parser.py +++ b/csp/impl/wiring/node_parser.py @@ -13,7 +13,6 @@ from csp.impl.wiring import Signature from csp.impl.wiring.ast_utils import ASTUtils from csp.impl.wiring.base_parser import BaseParser, CspParseError, _pythonic_depr_warning -from csp.impl.wiring.special_output_names import CSP_CACHE_ENABLED_OUTPUT class _SingleProxyFuncArgResolver(object): @@ -96,13 +95,12 @@ class NodeParser(BaseParser): _INPUT_PROXY_VARNAME = "input_proxy" _OUTPUT_PROXY_VARNAME = "output_proxy" - def __init__(self, name, raw_func, func_frame, debug_print=False, add_cache_control_output=False): + def __init__(self, name, raw_func, func_frame, debug_print=False): super().__init__( name=name, raw_func=raw_func, func_frame=func_frame, debug_print=debug_print, - add_cache_control_output=add_cache_control_output, ) self._stateblock = [] self._startblock = [] @@ -661,14 +659,6 @@ def _parse_engine_end_time(self, node): keywords=[], ) - def _parse_csp_enable_cache(self, node): - if len(node.args) != 1 or node.keywords: - raise CspParseError("Invalid call to csp.enable_cache", node.lineno) - - output = self._ts_outproxy_expr(CSP_CACHE_ENABLED_OUTPUT) - res = ast.BinOp(left=output, op=ast.Add(), right=node.args[0]) - return res - def _parse_csp_engine_stats(self, node): if len(node.args) or len(node.keywords): raise CspParseError("csp.engine_stats takes no arguments", node.lineno) @@ -770,7 +760,6 @@ def _is_ts_args_removed_from_signature(self): def _parse_impl(self): self._inputs, input_defaults, self._outputs = self.parse_func_signature(self._funcdef) idx = self._parse_special_blocks(self._funcdef.body) - self._resolve_special_outputs() self._signature = Signature( self._name, self._inputs, @@ -918,7 +907,6 @@ def _init_internal_maps(cls): "csp.set_buffering_policy": cls._parse_set_buffering_policy, "csp.engine_start_time": cls._parse_engine_start_time, "csp.engine_end_time": cls._parse_engine_end_time, - "csp.enable_cache": cls._parse_csp_enable_cache, "csp.engine_stats": cls._parse_csp_engine_stats, } diff --git a/csp/impl/wiring/outputs.py b/csp/impl/wiring/outputs.py index 90afe076..1e0b1ecb 100644 --- a/csp/impl/wiring/outputs.py +++ b/csp/impl/wiring/outputs.py @@ -35,18 +35,3 @@ def _get(self, item, dflt=None): def __repr__(self): return "OutputsContainer( %s )" % (",".join("%s=%r" % (k, v) for k, v in self._items())) - - -class CacheWriteOnlyOutputsContainer(list): - def __repr__(self): - return f'CacheWriteOnlyOutputsContainer( {",".join(v for v in self)} )' - - def __getattr__(self, item): - raise RuntimeError( - "Outputs of graphs with custom data_timestamp_column_name must be read using .cached property" - ) - - def __getitem__(self, item): - raise RuntimeError( - "Outputs 
of graphs with custom data_timestamp_column_name must be read using .cached property" - ) diff --git a/csp/impl/wiring/runtime.py b/csp/impl/wiring/runtime.py index bcc7dd15..d334bd1e 100644 --- a/csp/impl/wiring/runtime.py +++ b/csp/impl/wiring/runtime.py @@ -3,15 +3,13 @@ import time from collections import deque from datetime import datetime, timedelta -from typing import Optional from csp.impl.__cspimpl import _cspimpl -from csp.impl.config import Config from csp.impl.error_handling import ExceptionContext from csp.impl.wiring.adapters import _graph_return_adapter from csp.impl.wiring.context import Context from csp.impl.wiring.edge import Edge -from csp.impl.wiring.outputs import CacheWriteOnlyOutputsContainer, OutputsContainer +from csp.impl.wiring.outputs import OutputsContainer from csp.profiler import Profiler, graph_info MAX_END_TIME = datetime(2261, 12, 31, 23, 59, 50, 999999) @@ -34,14 +32,14 @@ def _normalize_run_times(starttime, endtime, realtime): return starttime, endtime -def build_graph(f, *args, config: Optional[Config] = None, starttime=None, endtime=None, realtime=False, **kwargs): +def build_graph(f, *args, starttime=None, endtime=None, realtime=False, **kwargs): assert ( (starttime is None) == (endtime is None) ), "Start time and end time should either both be specified or none of them should be specified when building a graph" if starttime: starttime, endtime = _normalize_run_times(starttime, endtime, realtime) with ExceptionContext(), GraphRunInfo(starttime=starttime, endtime=endtime, realtime=realtime), Context( - start_time=starttime, end_time=endtime, config=config + start_time=starttime, end_time=endtime ) as c: # Setup the profiler if within a profiling context if Profiler.instance() is not None and not Profiler.instance().initialized: @@ -54,7 +52,7 @@ def build_graph(f, *args, config: Optional[Config] = None, starttime=None, endti processed_outputs = OutputsContainer() - if outputs is not None and not isinstance(outputs, CacheWriteOnlyOutputsContainer): + if outputs is not None: if isinstance(outputs, Edge): processed_outputs[0] = outputs elif isinstance(outputs, list): @@ -112,19 +110,6 @@ def _build_engine(engine, context, memo=None): return engine -def _run_engine(engine, starttime, endtime, context_config=None, cache_data=None): - # context = Context.TLS.instance - cache_config = getattr(context_config, "cache_config", None) if context_config else None - if cache_config: - from csp.impl.wiring.cache_support.runtime_cache_manager import RuntimeCacheManager - - runtime_cache_manager = RuntimeCacheManager(cache_config, cache_data) - with runtime_cache_manager: - return engine.run(starttime, endtime) - else: - return engine.run(starttime, endtime) - - class GraphRunInfo: TLS = threading.local() @@ -174,7 +159,6 @@ def run( *args, starttime=None, endtime=MAX_END_TIME, - config: Optional[Config] = None, queue_wait_time=None, realtime=False, output_numpy=False, @@ -193,9 +177,6 @@ def run( orig_g.context = None if isinstance(g, Context): - if config is not None: - raise RuntimeError("Config can not be specified when running a built graph") - if g.start_time is not None: assert ( (g.start_time, g.end_time) == (starttime, endtime) @@ -212,8 +193,6 @@ def run( engine = _cspimpl.PyEngine(**engine_settings) engine = _build_engine(engine, g) - context_config = g.config - cache_data = g.cache_data mem_cache = g.mem_cache # Release graph construct at this point to free up all the edge / nodedef memory thats no longer needed del g @@ -224,18 +203,14 @@ def run( 
time.sleep((starttime - datetime.utcnow()).total_seconds()) with mem_cache: - return _run_engine( - engine, starttime=starttime, endtime=endtime, context_config=context_config, cache_data=cache_data - ) + return engine.run(starttime, endtime) if isinstance(g, Edge): - if config is not None: - raise RuntimeError("Config can not be specified when running a built graph") return run(lambda: g, starttime=starttime, endtime=endtime, **engine_settings) # wrapped in a _WrappedContext so that we can give up the mem before run graph = _WrappedContext( - build_graph(g, *args, starttime=starttime, endtime=endtime, realtime=realtime, config=config, **kwargs) + build_graph(g, *args, starttime=starttime, endtime=endtime, realtime=realtime, **kwargs) ) with GraphRunInfo(starttime=starttime, endtime=endtime, realtime=realtime): return run(graph, starttime=starttime, endtime=endtime, **engine_settings) diff --git a/csp/impl/wiring/signature.py b/csp/impl/wiring/signature.py index eb058274..2a3ebdbb 100644 --- a/csp/impl/wiring/signature.py +++ b/csp/impl/wiring/signature.py @@ -100,7 +100,6 @@ def parse_inputs(self, forced_tvars, *args, allow_subtypes=True, allow_none_ts=F input_definitions=self._inputs[self._num_alarms :], arguments=flat_args, forced_tvars=forced_tvars, - allow_subtypes=allow_subtypes, allow_none_ts=allow_none_ts, ) diff --git a/csp/impl/wiring/special_output_names.py b/csp/impl/wiring/special_output_names.py index 94b2c9dc..3f66c063 100644 --- a/csp/impl/wiring/special_output_names.py +++ b/csp/impl/wiring/special_output_names.py @@ -1,4 +1 @@ -CSP_CACHE_ENABLED_OUTPUT = "__csp_cache_enable_ts" UNNAMED_OUTPUT_NAME = "__csp__unnamed_output__" - -ALL_SPECIAL_OUTPUT_NAMES = {CSP_CACHE_ENABLED_OUTPUT} diff --git a/csp/impl/wiring/threaded_runtime.py b/csp/impl/wiring/threaded_runtime.py index 0726e3c4..55e32601 100644 --- a/csp/impl/wiring/threaded_runtime.py +++ b/csp/impl/wiring/threaded_runtime.py @@ -1,8 +1,6 @@ import threading -from typing import Optional import csp -from csp.impl.config import Config from csp.impl.pushadapter import PushInputAdapter from csp.impl.types.tstype import ts from csp.impl.wiring import MAX_END_TIME, py_push_adapter_def @@ -110,7 +108,6 @@ def run_on_thread( *args, starttime=None, endtime=MAX_END_TIME, - config: Optional[Config] = None, queue_wait_time=None, realtime=False, auto_shutdown=False, @@ -122,7 +119,6 @@ def run_on_thread( *args, starttime=starttime, endtime=endtime, - config=config, queue_wait_time=queue_wait_time, realtime=realtime, auto_shutdown=auto_shutdown, diff --git a/csp/tests/impl/test_struct.py b/csp/tests/impl/test_struct.py index 0920aa5c..6f76339a 100644 --- a/csp/tests/impl/test_struct.py +++ b/csp/tests/impl/test_struct.py @@ -171,20 +171,92 @@ def __init__(self, x: int): # items[:-2] are normal values of the given type that should be handled, # items[-2] is a normal value for non-generic and non-str types and None for generic and str types (the purpose is to test the raise of TypeError if a single object instead of a sequence is passed), # items[-1] is a value of a different type that is not convertible to the give type for non-generic types and None for generic types (the purpose is to test the raise of TypeError if an object of the wrong type is passed). 
-pystruct_list_test_values = { - int : [4, 2, 3, 5, 6, 7, 8, 's'], +pystruct_list_test_values = { + int: [4, 2, 3, 5, 6, 7, 8, "s"], bool: [True, True, True, False, True, False, True, 2], - float: [1.4, 3.2, 2.7, 1.0, -4.5, -6.0, -2.0, 's'], - datetime: [datetime(2022, 12, 6, 1, 2, 3), datetime(2022, 12, 7, 2, 2, 3), datetime(2022, 12, 8, 3, 2, 3), datetime(2022, 12, 9, 4, 2, 3), datetime(2022, 12, 10, 5, 2, 3), datetime(2022, 12, 11, 6, 2, 3), datetime(2022, 12, 13, 7, 2, 3), timedelta(seconds=.123)], - timedelta: [timedelta(seconds=.123), timedelta(seconds=12), timedelta(seconds=1), timedelta(seconds=.5), timedelta(seconds=123), timedelta(seconds=70), timedelta(seconds=700), datetime(2022, 12, 8, 3, 2, 3)], - date: [date(2022, 12, 6), date(2022, 12, 7), date(2022, 12, 8), date(2022, 12, 9), date(2022, 12, 10), date(2022, 12, 11), date(2022, 12, 13), timedelta(seconds=.123)], - time: [time(1, 2, 3), time(2, 2, 3), time(3, 2, 3), time(4, 2, 3), time(5, 2, 3), time(6, 2, 3), time(7, 2, 3), timedelta(seconds=.123)], - str : ['s', 'pqr', 'masd', 'wes', 'as', 'm', None, 5], - csp.Struct: [SimpleStruct(a = 1), AnotherSimpleStruct(b = 'sd'), SimpleStruct(a = 3), AnotherSimpleStruct(b = 'sdf'), SimpleStruct(a = -4), SimpleStruct(a = 5), SimpleStruct(a = 7), 4], # untyped struct list - SimpleStruct: [SimpleStruct(a = 1), SimpleStruct(a = 3), SimpleStruct(a = -1), SimpleStruct(a = -4), SimpleStruct(a = 5), SimpleStruct(a = 100), SimpleStruct(a = 1200), AnotherSimpleStruct(b = 'sd')], - SimpleEnum: [SimpleEnum.A, SimpleEnum.C, SimpleEnum.B, SimpleEnum.B, SimpleEnum.B, SimpleEnum.C, SimpleEnum.C, AnotherSimpleEnum.D], + float: [1.4, 3.2, 2.7, 1.0, -4.5, -6.0, -2.0, "s"], + datetime: [ + datetime(2022, 12, 6, 1, 2, 3), + datetime(2022, 12, 7, 2, 2, 3), + datetime(2022, 12, 8, 3, 2, 3), + datetime(2022, 12, 9, 4, 2, 3), + datetime(2022, 12, 10, 5, 2, 3), + datetime(2022, 12, 11, 6, 2, 3), + datetime(2022, 12, 13, 7, 2, 3), + timedelta(seconds=0.123), + ], + timedelta: [ + timedelta(seconds=0.123), + timedelta(seconds=12), + timedelta(seconds=1), + timedelta(seconds=0.5), + timedelta(seconds=123), + timedelta(seconds=70), + timedelta(seconds=700), + datetime(2022, 12, 8, 3, 2, 3), + ], + date: [ + date(2022, 12, 6), + date(2022, 12, 7), + date(2022, 12, 8), + date(2022, 12, 9), + date(2022, 12, 10), + date(2022, 12, 11), + date(2022, 12, 13), + timedelta(seconds=0.123), + ], + time: [ + time(1, 2, 3), + time(2, 2, 3), + time(3, 2, 3), + time(4, 2, 3), + time(5, 2, 3), + time(6, 2, 3), + time(7, 2, 3), + timedelta(seconds=0.123), + ], + str: ["s", "pqr", "masd", "wes", "as", "m", None, 5], + csp.Struct: [ + SimpleStruct(a=1), + AnotherSimpleStruct(b="sd"), + SimpleStruct(a=3), + AnotherSimpleStruct(b="sdf"), + SimpleStruct(a=-4), + SimpleStruct(a=5), + SimpleStruct(a=7), + 4, + ], # untyped struct list + SimpleStruct: [ + SimpleStruct(a=1), + SimpleStruct(a=3), + SimpleStruct(a=-1), + SimpleStruct(a=-4), + SimpleStruct(a=5), + SimpleStruct(a=100), + SimpleStruct(a=1200), + AnotherSimpleStruct(b="sd"), + ], + SimpleEnum: [ + SimpleEnum.A, + SimpleEnum.C, + SimpleEnum.B, + SimpleEnum.B, + SimpleEnum.B, + SimpleEnum.C, + SimpleEnum.C, + AnotherSimpleEnum.D, + ], list: [[1], [1, 2, 1], [6], [8, 3, 5], [3], [11, 8], None, None], # generic type list - SimpleClass: [SimpleClass(x = 1), SimpleClass(x = 5), SimpleClass(x = 9), SimpleClass(x = -1), SimpleClass(x = 2), SimpleClass(x = 3), None, None], # generic type user-defined + SimpleClass: [ + SimpleClass(x=1), + SimpleClass(x=5), + SimpleClass(x=9), + 
SimpleClass(x=-1), + SimpleClass(x=2), + SimpleClass(x=3), + None, + None, + ], # generic type user-defined } @@ -705,8 +777,8 @@ def __init__(self, iterable=None): class StructWithListDerivedType(csp.Struct): ldt: ListDerivedType - s1 = StructWithListDerivedType(ldt=ListDerivedType([1,2])) - self.assertTrue(isinstance(s1.to_dict()['ldt'], ListDerivedType)) + s1 = StructWithListDerivedType(ldt=ListDerivedType([1, 2])) + self.assertTrue(isinstance(s1.to_dict()["ldt"], ListDerivedType)) s2 = StructWithListDerivedType.from_dict(s1.to_dict()) self.assertEqual(s1, s2) @@ -1813,14 +1885,15 @@ def custom_jsonifier(obj): json.loads(test_struct.to_json(custom_jsonifier)) def test_list_field_append(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [] ) + + s = A(a=[]) s.a.append(v[0]) - + self.assertEqual(s.a, [v[0]]) s.a.append(v[1]) @@ -1834,14 +1907,15 @@ class A(csp.Struct): s.a.append(v[-1]) def test_list_field_insert(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [] ) + + s = A(a=[]) s.a.insert(0, v[0]) - + self.assertEqual(s.a, [v[0]]) s.a.insert(1, v[1]) @@ -1864,19 +1938,20 @@ class A(csp.Struct): s.a.insert(-1, v[-1]) def test_list_field_pop(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A(a = [v[0], v[1], v[2], v[3], v[4]]) + + s = A(a=[v[0], v[1], v[2], v[3], v[4]]) b = s.a.pop() - + self.assertEqual(s.a, [v[0], v[1], v[2], v[3]]) self.assertEqual(b, v[4]) b = s.a.pop(-1) - + self.assertEqual(s.a, [v[0], v[1], v[2]]) self.assertEqual(b, v[3]) @@ -1884,13 +1959,13 @@ class A(csp.Struct): self.assertEqual(s.a, [v[0], v[2]]) self.assertEqual(b, v[1]) - + with self.assertRaises(IndexError) as e: s.a.pop() s.a.pop() s.a.pop() - - s = A(a = [v[0], v[1], v[2], v[3], v[4]]) + + s = A(a=[v[0], v[1], v[2], v[3], v[4]]) b = s.a.pop(-3) @@ -1904,14 +1979,15 @@ class A(csp.Struct): s.a.pop(4) def test_list_field_set_item(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [v[0], v[1], v[2]] ) + + s = A(a=[v[0], v[1], v[2]]) s.a.__setitem__(0, v[3]) - + self.assertEqual(s.a, [v[3], v[1], v[2]]) s.a[1] = v[4] @@ -1927,7 +2003,7 @@ class A(csp.Struct): with self.assertRaises(IndexError) as e: s.a[-100] = v[0] - + s.a[5:6] = [v[0], v[1], v[2]] self.assertEqual(s.a, [v[3], v[4], v[5], v[0], v[1], v[2]]) @@ -1944,7 +2020,7 @@ class A(csp.Struct): self.assertEqual(s.a, [v[3], v[1], v[2], v[2], v[5]]) - # Check if not str or generic type (as str is a sequence of str) + # Check if not str or generic type (as str is a sequence of str) if v[-2] is not None: with 
self.assertRaises(TypeError) as e: s.a[1:4] = v[-2] @@ -1964,41 +2040,67 @@ class A(csp.Struct): self.assertEqual(s.a, [v[3], v[1], v[2], v[2], v[5]]) def test_list_field_reverse(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [v[0], v[1], v[2], v[3]] ) + + s = A(a=[v[0], v[1], v[2], v[3]]) s.a.reverse() - + self.assertEqual(s.a, [v[3], v[2], v[1], v[0]]) - + def test_list_field_sort(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" # Not using pystruct_list_test_values, as sort() tests are of different semantics (order and sorting key existance matters). - values = { - int : [1, 5, 2, 2, -1, -5, 's'], - float: [1.4, 5.2, 2.7, 2.7, -1.4, -5.2, 's'], - datetime: [datetime(2022, 12, 6, 1, 2, 3), datetime(2022, 12, 8, 3, 2, 3), datetime(2022, 12, 7, 2, 2, 3), datetime(2022, 12, 7, 2, 2, 3), datetime(2022, 12, 5, 2, 2, 3), datetime(2022, 12, 3, 2, 2, 3), None], - timedelta: [timedelta(seconds=1), timedelta(seconds=123), timedelta(seconds=12), timedelta(seconds=12), timedelta(seconds=.1), timedelta(seconds=.01), None], - date: [date(2022, 12, 6), date(2022, 12, 8), date(2022, 12, 7), date(2022, 12, 7), date(2022, 12, 5), date(2022, 12, 3), None], + values = { + int: [1, 5, 2, 2, -1, -5, "s"], + float: [1.4, 5.2, 2.7, 2.7, -1.4, -5.2, "s"], + datetime: [ + datetime(2022, 12, 6, 1, 2, 3), + datetime(2022, 12, 8, 3, 2, 3), + datetime(2022, 12, 7, 2, 2, 3), + datetime(2022, 12, 7, 2, 2, 3), + datetime(2022, 12, 5, 2, 2, 3), + datetime(2022, 12, 3, 2, 2, 3), + None, + ], + timedelta: [ + timedelta(seconds=1), + timedelta(seconds=123), + timedelta(seconds=12), + timedelta(seconds=12), + timedelta(seconds=0.1), + timedelta(seconds=0.01), + None, + ], + date: [ + date(2022, 12, 6), + date(2022, 12, 8), + date(2022, 12, 7), + date(2022, 12, 7), + date(2022, 12, 5), + date(2022, 12, 3), + None, + ], time: [time(5, 2, 3), time(7, 2, 3), time(6, 2, 3), time(6, 2, 3), time(4, 2, 3), time(3, 2, 3), None], - str : ['s', 'xyz', 'w', 'w', 'bds', 'a', None], + str: ["s", "xyz", "w", "w", "bds", "a", None], } for typ, v in values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [v[0], v[1], v[2], v[3], v[4], v[5]] ) - + + s = A(a=[v[0], v[1], v[2], v[3], v[4], v[5]]) + s.a.sort() - + self.assertEqual(s.a, [v[5], v[4], v[0], v[2], v[3], v[1]]) s.a.sort(reverse=True) - + self.assertEqual(s.a, [v[1], v[2], v[3], v[0], v[4], v[5]]) with self.assertRaises(TypeError) as e: @@ -2012,16 +2114,17 @@ class A(csp.Struct): s.a.sort(key=abs) self.assertEqual(s.a, [v[0], v[4], v[2], v[3], v[1], v[5]]) - + def test_list_field_extend(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [v[0], v[1], v[2]] ) + + s = A(a=[v[0], v[1], v[2]]) s.a.extend([v[3]]) - + self.assertEqual(s.a, [v[0], v[1], v[2], v[3]]) s.a.extend([]) @@ -2029,25 +2132,26 @@ class A(csp.Struct): self.assertEqual(s.a, [v[0], v[1], v[2], v[3], v[4], v[5]]) - # 
Check if not str or generic type (as str is a sequence of str) + # Check if not str or generic type (as str is a sequence of str) if v[-2] is not None: with self.assertRaises(TypeError) as e: s.a.extend(v[-2]) - - # Check if not generic type + + # Check if not generic type if v[-1] is not None: with self.assertRaises(TypeError) as e: s.a.extend([v[-1]]) - + def test_list_field_remove(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [v[0], v[1], v[0], v[2]] ) + + s = A(a=[v[0], v[1], v[0], v[2]]) s.a.remove(v[0]) - + self.assertEqual(s.a, [v[1], v[0], v[2]]) s.a.remove(v[2]) @@ -2058,32 +2162,34 @@ class A(csp.Struct): s.a.remove(v[3]) def test_list_field_clear(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [v[0], v[1], v[2], v[3]] ) + + s = A(a=[v[0], v[1], v[2], v[3]]) s.a.clear() - + self.assertEqual(s.a, []) - + def test_list_field_del(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [v[0], v[1], v[2], v[3]] ) + + s = A(a=[v[0], v[1], v[2], v[3]]) del s.a[0] - + self.assertEqual(s.a, [v[1], v[2], v[3]]) del s.a[1] self.assertEqual(s.a, [v[1], v[3]]) - s = A( a = [v[0], v[1], v[2], v[3]] ) + s = A(a=[v[0], v[1], v[2], v[3]]) del s.a[1:3] self.assertEqual(s.a, [v[0], v[3]]) @@ -2094,16 +2200,17 @@ class A(csp.Struct): with self.assertRaises(IndexError) as e: del s.a[5] - + def test_list_field_inplace_concat(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [v[0], v[1]] ) + + s = A(a=[v[0], v[1]]) s.a.__iadd__([v[2], v[3]]) - + self.assertEqual(s.a, [v[0], v[1], v[2], v[3]]) s.a += (v[4], v[5]) @@ -2117,22 +2224,23 @@ class A(csp.Struct): with self.assertRaises(TypeError) as e: s.a += v[-1] - # Check if not generic type + # Check if not generic type if v[-1] is not None: with self.assertRaises(TypeError) as e: s.a += [v[-1]] - + self.assertEqual(s.a, [v[0], v[1], v[2], v[3], v[4], v[5]]) - + def test_list_field_inplace_repeat(self): - ''' Was a BUG when the struct with list field was not recognizing changes made to this field in python''' + """Was a BUG when the struct with list field was not recognizing changes made to this field in python""" for typ, v in pystruct_list_test_values.items(): + class A(csp.Struct): a: [typ] - - s = A( a = [v[0], v[1]] ) + + s = A(a=[v[0], v[1]]) s.a.__imul__(1) - + self.assertEqual(s.a, [v[0], v[1]]) s.a *= 2 @@ -2143,10 +2251,10 @@ class A(csp.Struct): s.a *= [3] with self.assertRaises(TypeError) as e: - s.a *= 's' - + s.a *= "s" + s.a *= 0 - + self.assertEqual(s.a, []) s.a += [v[2], v[3]] @@ 
-2154,18 +2262,19 @@ class A(csp.Struct): self.assertEqual(s.a, [v[2], v[3]]) s.a *= -1 - + self.assertEqual(s.a, []) - + def test_list_field_lifetime(self): - '''Ensure that the lifetime of PyStructList field exceeds the lifetime of struct holding it''' + """Ensure that the lifetime of PyStructList field exceeds the lifetime of struct holding it""" + class A(csp.Struct): a: [int] - - s = A( a = [1, 2, 3] ) + + s = A(a=[1, 2, 3]) l = s.a del s - + self.assertEqual(l, [1, 2, 3]) diff --git a/csp/tests/test_caching.py b/csp/tests/test_caching.py deleted file mode 100644 index 04c86de1..00000000 --- a/csp/tests/test_caching.py +++ /dev/null @@ -1,2438 +0,0 @@ -# import collections -# import csp -# import csp.typing -# import glob -# import math -# import numpy -# import os -# import pandas -# import pytz -# import re -# import tempfile -# from typing import Dict -# import unittest -# from csp import Config, graph, node, ts -# from csp.adapters.parquet import ParquetOutputConfig -# from csp.cache_support import BaseCacheConfig, CacheCategoryConfig, CacheConfig, CacheConfigResolver, GraphCacheOptions, NoCachedDataException -# from csp.impl.managed_dataset.cache_user_custom_object_serializer import CacheObjectSerializer -# from csp.impl.managed_dataset.dataset_metadata import TimeAggregation -# from csp.impl.managed_dataset.managed_dataset_path_resolver import DatasetPartitionKey -# from csp.impl.types.instantiation_type_resolver import TSArgTypeMismatchError -# from csp.utils.object_factory_registry import Injected, register_injected_object, set_new_registry_thread_instance -# from datetime import date, datetime, timedelta -# from csp.tests.utils.typed_curve_generator import TypedCurveGenerator - - -# class _DummyStructWithTimestamp(csp.Struct): -# val: int -# timestamp: datetime - - -# class _GraphTempCacheFolderConfig: -# def __init__(self, allow_overwrite=False, merge_existing_files=True): -# self._temp_folder = None -# self._allow_overwrite = allow_overwrite -# self._merge_existing_files = merge_existing_files - -# def __enter__(self): -# assert self._temp_folder is None -# self._temp_folder = tempfile.TemporaryDirectory(prefix='csp_unit_tests') -# return Config(cache_config=CacheConfig(data_folder=self._temp_folder.name, allow_overwrite=self._allow_overwrite, -# merge_existing_files=self._merge_existing_files)) - -# def __exit__(self, exc_type, exc_val, exc_tb): -# if self._temp_folder: -# self._temp_folder.cleanup() -# self._temp_folder = None - - -# @csp.node -# def csp_sorted(x: ts[['T']]) -> ts[['T']]: -# if csp.ticked(x): -# return sorted(x) - - -# class TestCaching(unittest.TestCase): - -# EXPECTED_OUTPUT_TEST_SIMPLE = {'i': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 1), (datetime(2020, 3, 1, 22, 32, 2, 2002), 2), (datetime(2020, 3, 1, 23, 33, 3, 3003), 3), -# (datetime(2020, 3, 2, 0, 34, 4, 4004), 4), (datetime(2020, 3, 2, 1, 35, 5, 5005), 5), (datetime(2020, 3, 2, 2, 36, 6, 6006), 6)], -# 'd': [(datetime(2020, 3, 1, 21, 31, 1, 1001), date(2020, 1, 2)), (datetime(2020, 3, 1, 22, 32, 2, 2002), date(2020, 1, 3)), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), date(2020, 1, 4)), (datetime(2020, 3, 2, 0, 34, 4, 4004), date(2020, 1, 5)), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), date(2020, 1, 6)), (datetime(2020, 3, 2, 2, 36, 6, 6006), date(2020, 1, 7))], -# 'dt': [(datetime(2020, 3, 1, 21, 31, 1, 1001), datetime(2020, 1, 2, 0, 0, 0, 1)), -# (datetime(2020, 3, 1, 22, 32, 2, 2002), datetime(2020, 1, 3, 0, 0, 0, 2)), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), datetime(2020, 1, 4, 0, 0, 0, 3)), 
-# (datetime(2020, 3, 2, 0, 34, 4, 4004), datetime(2020, 1, 5, 0, 0, 0, 4)), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), datetime(2020, 1, 6, 0, 0, 0, 5)), -# (datetime(2020, 3, 2, 2, 36, 6, 6006), datetime(2020, 1, 7, 0, 0, 0, 6))], -# 'f': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 0.2), (datetime(2020, 3, 1, 22, 32, 2, 2002), 0.4), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 0.6000000000000001), (datetime(2020, 3, 2, 0, 34, 4, 4004), 0.8), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 1.0), (datetime(2020, 3, 2, 2, 36, 6, 6006), 1.2000000000000002)], -# 's': [(datetime(2020, 3, 1, 21, 31, 1, 1001), '1'), (datetime(2020, 3, 1, 22, 32, 2, 2002), '2'), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), '3'), (datetime(2020, 3, 2, 0, 34, 4, 4004), '4'), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), '5'), (datetime(2020, 3, 2, 2, 36, 6, 6006), '6')], -# 'b': [(datetime(2020, 3, 1, 21, 31, 1, 1001), True), (datetime(2020, 3, 1, 22, 32, 2, 2002), False), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), True), (datetime(2020, 3, 2, 0, 34, 4, 4004), False), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), True), (datetime(2020, 3, 2, 2, 36, 6, 6006), False)], -# 'simple_leaf_node': [(datetime(2020, 3, 1, 20, 30), 1)], -# 'p1_i': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 33), (datetime(2020, 3, 1, 22, 32, 2, 2002), 34), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 35), (datetime(2020, 3, 2, 0, 34, 4, 4004), 36), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 37), (datetime(2020, 3, 2, 2, 36, 6, 6006), 38)], -# 'p1_d': [(datetime(2020, 3, 1, 21, 31, 1, 1001), date(2021, 1, 2)), (datetime(2020, 3, 1, 22, 32, 2, 2002), date(2021, 1, 3)), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), date(2021, 1, 4)), (datetime(2020, 3, 2, 0, 34, 4, 4004), date(2021, 1, 5)), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), date(2021, 1, 6)), (datetime(2020, 3, 2, 2, 36, 6, 6006), date(2021, 1, 7))], -# 'p1_dt': [(datetime(2020, 3, 1, 21, 31, 1, 1001), datetime(2020, 6, 7, 1, 2, 3, 5)), -# (datetime(2020, 3, 1, 22, 32, 2, 2002), datetime(2020, 6, 8, 1, 2, 3, 6)), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), datetime(2020, 6, 9, 1, 2, 3, 7)), -# (datetime(2020, 3, 2, 0, 34, 4, 4004), datetime(2020, 6, 10, 1, 2, 3, 8)), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), datetime(2020, 6, 11, 1, 2, 3, 9)), -# (datetime(2020, 3, 2, 2, 36, 6, 6006), datetime(2020, 6, 12, 1, 2, 3, 10))], -# 'p1_f': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 11.4), (datetime(2020, 3, 1, 22, 32, 2, 2002), 17.1), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 22.8), (datetime(2020, 3, 2, 0, 34, 4, 4004), 28.5), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 34.2), (datetime(2020, 3, 2, 2, 36, 6, 6006), 39.900000000000006)], -# 'p1_s': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 'my_str1'), (datetime(2020, 3, 1, 22, 32, 2, 2002), 'my_str2'), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 'my_str3'), (datetime(2020, 3, 2, 0, 34, 4, 4004), 'my_str4'), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 'my_str5'), (datetime(2020, 3, 2, 2, 36, 6, 6006), 'my_str6')], -# 'p1_b': [(datetime(2020, 3, 1, 21, 31, 1, 1001), False), (datetime(2020, 3, 1, 22, 32, 2, 2002), True), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), False), (datetime(2020, 3, 2, 0, 34, 4, 4004), True), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), False), (datetime(2020, 3, 2, 2, 36, 6, 6006), True)], -# 'p1_simple_leaf_node': [(datetime(2020, 3, 1, 20, 30), 1)], -# 'p2_i': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 33), (datetime(2020, 3, 1, 22, 32, 2, 2002), 34), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 35), (datetime(2020, 3, 2, 0, 34, 4, 4004), 36), -# (datetime(2020, 3, 
2, 1, 35, 5, 5005), 37), (datetime(2020, 3, 2, 2, 36, 6, 6006), 38)], -# 'p2_d': [(datetime(2020, 3, 1, 21, 31, 1, 1001), date(2021, 1, 3)), (datetime(2020, 3, 1, 22, 32, 2, 2002), date(2021, 1, 4)), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), date(2021, 1, 5)), (datetime(2020, 3, 2, 0, 34, 4, 4004), date(2021, 1, 6)), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), date(2021, 1, 7)), (datetime(2020, 3, 2, 2, 36, 6, 6006), date(2021, 1, 8))], -# 'p2_dt': [(datetime(2020, 3, 1, 21, 31, 1, 1001), datetime(2020, 6, 7, 1, 2, 3, 6)), -# (datetime(2020, 3, 1, 22, 32, 2, 2002), datetime(2020, 6, 8, 1, 2, 3, 7)), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), datetime(2020, 6, 9, 1, 2, 3, 8)), -# (datetime(2020, 3, 2, 0, 34, 4, 4004), datetime(2020, 6, 10, 1, 2, 3, 9)), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), datetime(2020, 6, 11, 1, 2, 3, 10)), -# (datetime(2020, 3, 2, 2, 36, 6, 6006), datetime(2020, 6, 12, 1, 2, 3, 11))], -# 'p2_f': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 11.4), (datetime(2020, 3, 1, 22, 32, 2, 2002), 17.1), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 22.8), (datetime(2020, 3, 2, 0, 34, 4, 4004), 28.5), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 34.2), (datetime(2020, 3, 2, 2, 36, 6, 6006), 39.900000000000006)], -# 'p2_s': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 'my_str1'), (datetime(2020, 3, 1, 22, 32, 2, 2002), 'my_str2'), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 'my_str3'), (datetime(2020, 3, 2, 0, 34, 4, 4004), 'my_str4'), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 'my_str5'), (datetime(2020, 3, 2, 2, 36, 6, 6006), 'my_str6')], -# 'p2_b': [(datetime(2020, 3, 1, 21, 31, 1, 1001), True), (datetime(2020, 3, 1, 22, 32, 2, 2002), False), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), True), (datetime(2020, 3, 2, 0, 34, 4, 4004), False), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), True), (datetime(2020, 3, 2, 2, 36, 6, 6006), False)], -# 'p2_simple_leaf_node': [(datetime(2020, 3, 1, 20, 30), 1)], -# 'named1_i': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 2), (datetime(2020, 3, 1, 22, 32, 2, 2002), 3), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 4), (datetime(2020, 3, 2, 0, 34, 4, 4004), 5), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 6), (datetime(2020, 3, 2, 2, 36, 6, 6006), 7)], -# 'named1_f': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 10), (datetime(2020, 3, 1, 22, 32, 2, 2002), 20), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 30), (datetime(2020, 3, 2, 0, 34, 4, 4004), 40), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 50), (datetime(2020, 3, 2, 2, 36, 6, 6006), 60)], -# 'named2_i2': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 3), (datetime(2020, 3, 1, 22, 32, 2, 2002), 4), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 5), (datetime(2020, 3, 2, 0, 34, 4, 4004), 6), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 7), (datetime(2020, 3, 2, 2, 36, 6, 6006), 8)], -# 'named2_f2': [(datetime(2020, 3, 1, 21, 31, 1, 1001), 20), (datetime(2020, 3, 1, 22, 32, 2, 2002), 40), -# (datetime(2020, 3, 1, 23, 33, 3, 3003), 60), (datetime(2020, 3, 2, 0, 34, 4, 4004), 80), -# (datetime(2020, 3, 2, 1, 35, 5, 5005), 100), (datetime(2020, 3, 2, 2, 36, 6, 6006), 120)], -# 'i_sample': [(datetime(2020, 3, 1, 22, 30), 1), (datetime(2020, 3, 2, 0, 30), 3), (datetime(2020, 3, 2, 2, 30), 5)]} -# EXPECTED_FILES = ['csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/01/20200301_203000_000000-20200301_235959_999999.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/02/20200302_000000_000000-20200302_030000_000000.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/dataset_meta.yml', -# 
'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/01/20200301_203000_000000-20200301_235959_999999.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/02/20200302_000000_000000-20200302_030000_000000.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/01/20200301_203000_000000-20200301_235959_999999.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/02/20200302_000000_000000-20200302_030000_000000.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/dataset_meta.yml', -# 'dummy_stats/sub_category/dataset1/data/2020/03/01/20200301_203000_000000-20200301_235959_999999.parquet', -# 'dummy_stats/sub_category/dataset1/data/2020/03/02/20200302_000000_000000-20200302_030000_000000.parquet', -# 'dummy_stats/sub_category/dataset1/dataset_meta.yml', -# 'dummy_stats/sub_category/dataset2/data/2020/03/01/20200301_203000_000000-20200301_235959_999999.parquet', -# 'dummy_stats/sub_category/dataset2/data/2020/03/02/20200302_000000_000000-20200302_030000_000000.parquet', -# 'dummy_stats/sub_category/dataset2/dataset_meta.yml'] -# _SPLIT_COLUMNS_EXPECTED_FILES = ['csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/b.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/csp_timestamp.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/d.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/dt.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/f.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/i.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/s.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/simple_leaf_node.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/b.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/csp_timestamp.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/d.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/dt.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/f.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/i.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/s.parquet', -# 
'csp_unnamed_cache/test_caching.make_sub_graph_no_part/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/simple_leaf_node.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_no_part/dataset_meta.yml', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/01/20200301_203000_000000-20200301_235959_999999/b.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/01/20200301_203000_000000-20200301_235959_999999/csp_timestamp.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/01/20200301_203000_000000-20200301_235959_999999/d.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/01/20200301_203000_000000-20200301_235959_999999/dt.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/01/20200301_203000_000000-20200301_235959_999999/f.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/01/20200301_203000_000000-20200301_235959_999999/i.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/01/20200301_203000_000000-20200301_235959_999999/s.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/01/20200301_203000_000000-20200301_235959_999999/simple_leaf_node.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/02/20200302_000000_000000-20200302_030000_000000/b.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/02/20200302_000000_000000-20200302_030000_000000/csp_timestamp.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/02/20200302_000000_000000-20200302_030000_000000/d.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/02/20200302_000000_000000-20200302_030000_000000/dt.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/02/20200302_000000_000000-20200302_030000_000000/f.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/02/20200302_000000_000000-20200302_030000_000000/i.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/02/20200302_000000_000000-20200302_030000_000000/s.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210101_000000_000000/20200606_010203_000004/5.7/my_str/True/2020/03/02/20200302_000000_000000-20200302_030000_000000/simple_leaf_node.parquet', -# 
'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/01/20200301_203000_000000-20200301_235959_999999/b.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/01/20200301_203000_000000-20200301_235959_999999/csp_timestamp.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/01/20200301_203000_000000-20200301_235959_999999/d.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/01/20200301_203000_000000-20200301_235959_999999/dt.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/01/20200301_203000_000000-20200301_235959_999999/f.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/01/20200301_203000_000000-20200301_235959_999999/i.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/01/20200301_203000_000000-20200301_235959_999999/s.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/01/20200301_203000_000000-20200301_235959_999999/simple_leaf_node.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/02/20200302_000000_000000-20200302_030000_000000/b.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/02/20200302_000000_000000-20200302_030000_000000/csp_timestamp.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/02/20200302_000000_000000-20200302_030000_000000/d.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/02/20200302_000000_000000-20200302_030000_000000/dt.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/02/20200302_000000_000000-20200302_030000_000000/f.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/02/20200302_000000_000000-20200302_030000_000000/i.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/02/20200302_000000_000000-20200302_030000_000000/s.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/data/32/20210102_000000_000000/20200606_010203_000005/5.7/my_str/False/2020/03/02/20200302_000000_000000-20200302_030000_000000/simple_leaf_node.parquet', -# 'csp_unnamed_cache/test_caching.make_sub_graph_partitioned/dataset_meta.yml', -# 'dummy_stats/sub_category/dataset1/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/csp_timestamp.parquet', -# 
'dummy_stats/sub_category/dataset1/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/f.parquet', -# 'dummy_stats/sub_category/dataset1/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/i.parquet', -# 'dummy_stats/sub_category/dataset1/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/csp_timestamp.parquet', -# 'dummy_stats/sub_category/dataset1/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/f.parquet', -# 'dummy_stats/sub_category/dataset1/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/i.parquet', 'dummy_stats/sub_category/dataset1/dataset_meta.yml', -# 'dummy_stats/sub_category/dataset2/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/csp_timestamp.parquet', -# 'dummy_stats/sub_category/dataset2/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/f2.parquet', -# 'dummy_stats/sub_category/dataset2/data/2020/03/01/20200301_203000_000000-20200301_235959_999999/i2.parquet', -# 'dummy_stats/sub_category/dataset2/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/csp_timestamp.parquet', -# 'dummy_stats/sub_category/dataset2/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/f2.parquet', -# 'dummy_stats/sub_category/dataset2/data/2020/03/02/20200302_000000_000000-20200302_030000_000000/i2.parquet', 'dummy_stats/sub_category/dataset2/dataset_meta.yml'] - -# class _EdgeOutputSettings(csp.Enum): -# FIRST_CYCLE = 0x1 -# LAST_CYCLE = 0x2 -# BOTH_EDGES = FIRST_CYCLE | LAST_CYCLE - -# def _create_graph(self, split_columns_to_files): -# func_run_count = [0] - -# def cache_options(**kwargs): -# return GraphCacheOptions(split_columns_to_files=split_columns_to_files, **kwargs) - -# @node -# def pass_through(v: ts['T']) -> ts['T']: -# with csp.start(): -# func_run_count[0] += 1 - -# if csp.ticked(v): -# return v - -# def make_curve_pass_through(f): -# values = [(timedelta(hours=v, minutes=v, seconds=v, milliseconds=v, microseconds=v), f(v)) for v in range(1, 20)] -# typ = type(values[0][1]) -# return pass_through(csp.curve(typ, values)) - -# simple_leaf_node = [None] - -# @graph(cache=True, cache_options=cache_options()) -# def make_sub_graph_no_part() -> csp.Outputs(i=ts[int], d=ts[date], dt=ts[datetime], f=ts[float], s=ts[str], b=ts[bool], simple_leaf_node=ts[int]): -# return csp.output(i=make_curve_pass_through(lambda v: v), -# d=make_curve_pass_through(lambda v: date(2020, 1, 1) + timedelta(days=v)), -# dt=make_curve_pass_through(lambda v: datetime(2020, 1, 1) + timedelta(days=v, microseconds=v)), -# f=make_curve_pass_through(lambda v: v * .2), -# s=make_curve_pass_through(str), -# b=make_curve_pass_through(lambda v: bool(v % 2)), -# simple_leaf_node=simple_leaf_node[0]) - -# @graph(cache=True, cache_options=cache_options()) -# def make_sub_graph_partitioned(i_v: int, d_v: date, dt_v: datetime, f_v: float, s_v: str, b_v: bool) -> csp.Outputs(i=ts[int], d=ts[date], dt=ts[datetime], f=ts[float], s=ts[str], b=ts[bool], simple_leaf_node=ts[int]): -# no_part_sub_graph = make_sub_graph_no_part() - -# return csp.output(i=make_curve_pass_through(lambda v: i_v + v), -# d=make_curve_pass_through(lambda v: d_v + timedelta(days=v)), -# dt=make_curve_pass_through(lambda v: dt_v + timedelta(days=v, microseconds=v)), -# f=make_curve_pass_through(lambda v: v * f_v + f_v), -# s=make_curve_pass_through(lambda v: s_v + str(v)), -# b=make_curve_pass_through(lambda v: bool(v % 2) ^ b_v), -# simple_leaf_node=no_part_sub_graph.simple_leaf_node) - -# @graph(cache=True, 
cache_options=cache_options(dataset_name='dataset1', category=['dummy_stats', 'sub_category'])) -# def named_managed_graph_col_set_1() -> csp.Outputs(i=ts[int], f=ts[float]): -# return csp.output(i=make_curve_pass_through(lambda v: v + 1), -# f=make_curve_pass_through(lambda v: v * 10.0)) - -# @graph(cache=True, cache_options=cache_options(dataset_name='dataset2', category=['dummy_stats', 'sub_category'])) -# def named_managed_graph_col_set_2() -> csp.Outputs(i2=ts[int], f2=ts[float]): -# return csp.output(i2=make_curve_pass_through(lambda v: v + 2), -# f2=make_curve_pass_through(lambda v: v * 20.0)) - -# @graph -# def my_graph(require_cached: bool = False): -# self.maxDiff = 20000 -# simple_leaf_node[0] = pass_through(csp.const(1)) -# sub_graph = make_sub_graph_no_part() -# sub_graph_partitioned = make_sub_graph_partitioned.cached if require_cached else make_sub_graph_partitioned -# named_managed_graph_col_set_1_g = named_managed_graph_col_set_1.cached if require_cached else named_managed_graph_col_set_1 -# named_managed_graph_col_set_2_g = named_managed_graph_col_set_2.cached if require_cached else named_managed_graph_col_set_2 -# sub_graph_part_1 = sub_graph_partitioned(i_v=32, d_v=date(2021, 1, 1), dt_v=datetime(2020, 6, 6, 1, 2, 3, 4), -# f_v=5.7, s_v='my_str', b_v=True) -# sub_graph_part_2 = sub_graph_partitioned(i_v=32, d_v=date(2021, 1, 2), dt_v=datetime(2020, 6, 6, 1, 2, 3, 5), -# f_v=5.7, s_v='my_str', b_v=False) -# named_col_set_1 = named_managed_graph_col_set_1_g() -# named_col_set_2 = named_managed_graph_col_set_2_g() -# for k in sub_graph: -# csp.add_graph_output(k, sub_graph[k]) -# for k in sub_graph_part_1: -# csp.add_graph_output(f'p1_{k}', sub_graph_part_1[k]) -# for k in sub_graph_part_2: -# csp.add_graph_output(f'p2_{k}', sub_graph_part_2[k]) -# for k in named_col_set_1: -# csp.add_graph_output(f'named1_{k}', named_col_set_1[k]) -# for k in named_col_set_2: -# csp.add_graph_output(f'named2_{k}', named_col_set_2[k]) -# csp.add_graph_output('i_sample', pass_through(csp.sample(csp.timer(timedelta(hours=2), 1), sub_graph.i))) - -# return func_run_count, my_graph - -# def test_simple_graph(self): -# for split_columns_to_files in (True, False): -# with csp.memoize(False): -# func_run_count, my_graph = self._create_graph(split_columns_to_files=split_columns_to_files) - -# with _GraphTempCacheFolderConfig() as config: -# g1 = csp.run(my_graph, starttime=datetime(2020, 3, 1, 20, 30), endtime=timedelta(hours=0, minutes=390), -# config=config) -# self.assertTrue(len(g1) > 0) -# func_run_count1 = func_run_count[0] -# # leaf node is that same in all, it's repeated 3 times -# self.assertEqual(len(g1) - 2, func_run_count1) -# self.assertEqual(g1, self.EXPECTED_OUTPUT_TEST_SIMPLE) -# g2 = csp.run(my_graph, starttime=datetime(2020, 3, 1, 20, 30), endtime=timedelta(hours=0, minutes=390), -# config=config) -# self.assertEqual(g1, g2) -# func_run_count2 = func_run_count[0] -# # When the sub graph is read from cache, we only have one "pass_through" for i_sample -# self.assertEqual(func_run_count1 + 1, func_run_count2) -# g3 = csp.run(my_graph, require_cached=True, starttime=datetime(2020, 3, 1, 20, 30), endtime=timedelta(hours=0, minutes=390), -# config=config) -# func_run_count3 = func_run_count[0] -# self.assertEqual(g1, g3) -# self.assertEqual(func_run_count2 + 1, func_run_count3) -# files_in_cache = self._get_files_in_cache(config) - -# if split_columns_to_files: -# aux_files = [] -# for f in files_in_cache: -# if f.endswith('.parquet'): -# aux_files.append(os.path.dirname(f) + 
'.parquet') -# else: -# aux_files.append(f) -# aux_files = sorted(set(aux_files)) -# self.assertEqual(aux_files, self.EXPECTED_FILES) -# self.assertEqual(files_in_cache, self._SPLIT_COLUMNS_EXPECTED_FILES) -# else: -# self.assertEqual(files_in_cache, self.EXPECTED_FILES) - -# def _get_files_in_cache(self, config): -# all_files_and_folders = sorted(glob.glob(f'{config.cache_config.data_folder}/**', recursive=True)) -# files_in_cache = [v.replace(f'{config.cache_config.data_folder}/', '') for v in all_files_and_folders if os.path.isfile(v)] -# # When we run from the command line, the test import paths differ. So let's support that as well -# files_in_cache = [f.replace('csp.tests.test_caching', 'test_caching') for f in files_in_cache] -# files_in_cache = [f.replace('/csp.tests.', '/') for f in files_in_cache] -# return files_in_cache - -# def test_no_cache(self): -# for split_columns_to_files in (True, False): -# func_run_count, my_graph_func = self._create_graph(split_columns_to_files=split_columns_to_files) -# g1 = csp.run(my_graph_func, starttime=datetime(2020, 3, 1, 20, 30), endtime=timedelta(hours=0, minutes=390)) -# func_run_count1 = func_run_count[0] -# g2 = csp.run(my_graph_func, starttime=datetime(2020, 3, 1, 20, 30), endtime=timedelta(hours=0, minutes=390)) -# func_run_count2 = func_run_count[0] -# self.assertEqual(g1, g2) -# self.assertEqual(func_run_count1 * 2, func_run_count2) - -# with self.assertRaises(NoCachedDataException): -# g3 = csp.run(my_graph_func, require_cached=True, starttime=datetime(2020, 3, 1, 20, 30), endtime=timedelta(hours=0, minutes=390)) - -# def _get_all_files(self, config): -# return sorted(glob.glob(f'{config.cache_config.data_folder}/**/*.parquet', recursive=True)) - -# def _get_default_graph_caching_kwargs(self, split_columns_to_files): -# if split_columns_to_files: -# graph_kwargs = {'cache_options': GraphCacheOptions(split_columns_to_files=True)} -# else: -# graph_kwargs = {} -# return graph_kwargs - -# def test_merge(self): -# for merge_existing_files in (True, False): -# for split_columns_to_files in (True, False): -# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_files) - -# def _time_to_seconds(t): -# return t.hour * 3600 + t.minute * 60 + t.second - -# @csp.node() -# def my_node() -> csp.Outputs(hours=ts[int], minutes=ts[int], seconds=ts[int]): -# with csp.alarms(): -# alarm = csp.alarm( int ) -# with csp.start(): -# csp.schedule_alarm(alarm, timedelta(), _time_to_seconds(csp.now())) -# if csp.ticked(alarm): -# csp.schedule_alarm(alarm, timedelta(seconds=60), alarm + 60) -# return csp.output(hours=alarm // 3600, minutes=alarm // 60, seconds=alarm) - -# @csp.graph(cache=True, **graph_kwargs) -# def sub_graph() -> csp.Outputs(hours=ts[int], minutes=ts[int], seconds=ts[int]): -# node = my_node() -# return csp.output(hours=node.hours, minutes=node.minutes, seconds=node.seconds) - -# def _validate_file_df(g, start_time, dt, g_start=None, g_end=None): -# end_time = start_time + dt if isinstance(dt, timedelta) else dt -# g_start = g_start if g_start else start_time -# g_end = g_end if g_end else end_time -# g_end = g_start + g_end if isinstance(g_end, timedelta) else g_end -# df = sub_graph.cached_data(config.cache_config.data_folder)().get_data_df_for_period(start_time, dt) - -# self.assertTrue((df.seconds.diff()[1:] == 60).all()) -# self.assertTrue((df.minutes == df.seconds // 60).all()) -# self.assertTrue((df.hours == df.seconds // 3600).all()) -# self.assertTrue(df.iloc[-1]['csp_timestamp'] == end_time) -# 
self.assertTrue(df.iloc[0]['csp_timestamp'] == start_time) -# self.assertTrue(df.iloc[0]['seconds'] == _time_to_seconds(start_time)) -# self.assertEqual(g['seconds'][0][1], _time_to_seconds(g_start)) -# self.assertEqual(g['seconds'][-1][1], _time_to_seconds(g_end)) - -# def graph(): -# res = sub_graph() -# csp.add_graph_output('seconds', res.seconds) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True, merge_existing_files=merge_existing_files) as config: -# missing_range_handler = lambda start, end: True -# start_time1 = datetime(2020, 3, 1, 9, 30, tzinfo=pytz.utc) -# dt1 = timedelta(hours=0, minutes=60) -# g = csp.run(graph, starttime=start_time1, endtime=dt1, config=config) -# files = list(sub_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period().values()) -# self.assertEqual(len(files), 1) -# _validate_file_df(g, start_time1, dt1) - -# start_time2 = start_time1 + timedelta(minutes=180) -# dt2 = dt1 -# g = csp.run(graph, starttime=start_time2, endtime=dt2, config=config) -# files = list(sub_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period(missing_range_handler=missing_range_handler).values()) - -# if config.cache_config.merge_existing_files: -# self.assertEqual(len(files), 2) -# _validate_file_df(g, start_time2, dt2) -# # Test repeated writing of the same file -# g = csp.run(graph, starttime=start_time2, endtime=dt2, config=config) -# _validate_file_df(g, start_time2, dt2) - -# start_time3 = start_time2 + dt2 - timedelta(minutes=5) -# dt3 = timedelta(minutes=15) -# g = csp.run(graph, starttime=start_time3, endtime=dt3, config=config) -# files = list(sub_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period(missing_range_handler=missing_range_handler).values()) - -# if config.cache_config.merge_existing_files: -# self.assertEqual(len(files), 2) -# _validate_file_df(g, start_time2, start_time3 + dt3, g_start=start_time3, g_end=dt3) - -# start_time4 = start_time2 - timedelta(minutes=5) -# dt4 = timedelta(minutes=15) -# g = csp.run(graph, starttime=start_time4, endtime=dt4, config=config) -# files = list(sub_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period(missing_range_handler=missing_range_handler).values()) -# if config.cache_config.merge_existing_files: -# self.assertEqual(len(files), 2) -# _validate_file_df(g, start_time4, start_time3 + dt3, g_start=start_time4, g_end=dt4) - -# start_time5 = start_time1 + timedelta(minutes=40) -# dt5 = timedelta(minutes=200) -# g = csp.run(graph, starttime=start_time5, endtime=dt5, config=config) -# files = list(sub_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period().values()) -# if config.cache_config.merge_existing_files: -# self.assertEqual(len(files), 1) -# _validate_file_df(g, start_time1, start_time3 + dt3, g_start=start_time5, g_end=dt5) - -# g = csp.run(graph, starttime=start_time1 + timedelta(minutes=10), endtime=dt1, config=config) -# files = list(sub_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period().values()) -# if config.cache_config.merge_existing_files: -# self.assertEqual(len(files), 1) -# _validate_file_df(g, start_time1, start_time3 + dt3, g_start=start_time1 + timedelta(minutes=10), g_end=dt1) - -# start_time6 = start_time1 - timedelta(minutes=10) -# dt6 = start_time3 + dt3 + timedelta(minutes=10) -# g = csp.run(graph, starttime=start_time6, endtime=dt6, config=config) -# files = 
list(sub_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period().values()) -# self.assertEqual(len(files), 1) -# _validate_file_df(g, start_time6, dt6) - -# def test_folder_overrides(self): -# for split_columns_to_files in (True, False): -# start_time = datetime(2020, 3, 1, 20, 30) -# end_time = start_time + timedelta(seconds=1) - -# @csp.graph(cache=True) -# def g1() -> csp.Outputs(o=csp.ts[int]): -# return csp.output(o=csp.null_ts(int)) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(category=['C1', 'C2', 'C3'], split_columns_to_files=split_columns_to_files)) -# def g2() -> csp.Outputs(o=csp.ts[int]): -# return csp.output(o=csp.null_ts(int)) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(category=['C1', 'C2', 'C3_2'], split_columns_to_files=split_columns_to_files)) -# def g3() -> csp.Outputs(o=csp.ts[int]): -# return csp.output(o=csp.null_ts(int)) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(category=['C1', 'C2'], split_columns_to_files=split_columns_to_files)) -# def g4() -> csp.Outputs(o=csp.ts[int]): -# return csp.output(o=csp.null_ts(int)) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(category=['C1'], split_columns_to_files=split_columns_to_files)) -# def g5() -> csp.Outputs(o=csp.ts[int]): -# return csp.output(o=csp.null_ts(int)) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(dataset_name='named_dataset1', split_columns_to_files=split_columns_to_files)) -# def g6() -> csp.Outputs(o=csp.ts[int]): -# return csp.output(o=csp.null_ts(int)) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(dataset_name='named_dataset2', category=['C1', 'C2'], split_columns_to_files=split_columns_to_files)) -# def g7() -> csp.Outputs(o=csp.ts[int]): -# return csp.output(o=csp.null_ts(int)) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(dataset_name='named_dataset3', category=['C1', 'C2'], split_columns_to_files=split_columns_to_files)) -# def g8() -> csp.Outputs(o=csp.ts[int]): -# return csp.output(o=csp.null_ts(int)) - -# @csp.graph -# def g(): -# g1(), g2(), g3(), g4(), g5(), g6(), g7(), g8() - -# def _get_data_folders_for_config(config): -# all_folders = sorted({os.path.dirname(v) for v in self._get_files_in_cache(config)}) -# return sorted({v[:v.index('/data')] for v in all_folders if '/data' in v}) - -# with _GraphTempCacheFolderConfig() as config: -# with _GraphTempCacheFolderConfig() as config2: -# config_copy = Config.from_dict(config.to_dict()) -# root_folder = config_copy.cache_config.data_folder -# config_copy.cache_config.data_folder = os.path.join(root_folder, "default_output_folder") -# config_copy.cache_config.category_overrides = [ -# CacheCategoryConfig(category=['C1'], -# data_folder=os.path.join(root_folder, 'C1_O')), -# CacheCategoryConfig(category=['C1', 'C2'], -# data_folder=os.path.join(root_folder, 'C1_C2_O')), -# CacheCategoryConfig(category=['C1', 'C2', 'C3'], -# data_folder=os.path.join(root_folder, 'C1_C2_C3_O')) -# ] -# config_copy.cache_config.graph_overrides = {g8: BaseCacheConfig(data_folder=config2.cache_config.data_folder)} -# csp.run(g, starttime=start_time, endtime=end_time, config=config_copy) -# data_folders = _get_data_folders_for_config(config) -# data_folders2 = _get_data_folders_for_config(config2) -# expected_dataset_folders = { -# 'g1': 'default_output_folder/csp_unnamed_cache/test_caching.g1', 'g2': 'C1_C2_C3_O/C1/C2/C3/test_caching.g2', -# 'g3': 'C1_C2_O/C1/C2/C3_2/test_caching.g3', 'g4': 'C1_C2_O/C1/C2/test_caching.g4', 'g5': 
'C1_O/C1/test_caching.g5', -# 'g6': 'default_output_folder/csp_unnamed_cache/named_dataset1', 'g7': 'C1_C2_O/C1/C2/named_dataset2'} -# expected_dataset_folders2 = {'g8': 'C1/C2/named_dataset3'} -# self.assertEqual(data_folders, sorted(expected_dataset_folders.values())) -# self.assertEqual(data_folders2, sorted(expected_dataset_folders2.values())) - -# full_path = lambda v: os.path.join(root_folder, v) -# get_data_files = lambda g, f: g.cached_data(full_path(f))().get_data_files_for_period(start_time, end_time) - -# self.assertEqual(1, len(get_data_files(g1, "default_output_folder"))) -# self.assertEqual(1, len(get_data_files(g2, "C1_C2_C3_O"))) -# self.assertEqual(1, len(get_data_files(g3, "C1_C2_O"))) -# self.assertEqual(1, len(get_data_files(g4, "C1_C2_O"))) -# self.assertEqual(1, len(get_data_files(g5, "C1_O"))) -# self.assertEqual(1, len(get_data_files(g6, "default_output_folder"))) -# self.assertEqual(1, len(get_data_files(g7, "C1_C2_O"))) - -# data_path_resolver = CacheConfigResolver(config_copy.cache_config) -# get_data_files = lambda g: g.cached_data(data_path_resolver)().get_data_files_for_period(start_time, end_time) -# self.assertEqual(1, len(get_data_files(g1))) -# self.assertEqual(1, len(get_data_files(g2))) -# self.assertEqual(1, len(get_data_files(g3))) -# self.assertEqual(1, len(get_data_files(g4))) -# self.assertEqual(1, len(get_data_files(g5))) -# self.assertEqual(1, len(get_data_files(g6))) -# self.assertEqual(1, len(get_data_files(g7))) -# self.assertEqual(1, len(get_data_files(g8))) - -# def test_caching_reads_only_needed_columns(self): -# for split_columns_to_files in (True, False): -# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_files) - -# class MyS(csp.Struct): -# x: int -# y: int - -# @graph(cache=True, **graph_kwargs) -# def g(s: str) -> csp.Outputs(o=csp.ts[MyS]): -# t = csp.engine_start_time() -# o_ts = csp.curve(MyS, [(t + timedelta(seconds=v), MyS(x=v, y=v * 2)) for v in range(20)]) -# return csp.output(o=o_ts) - -# @graph -# def g_x_reader(s: str) -> csp.Outputs(o=csp.ts[int]): -# return csp.count(g('A').o.x) - -# @graph -# def g_delayed_demux(s: str) -> csp.ts[int]: -# demux = csp.DelayedDemultiplex(g('A').o.x, g('A').o.x) -# return csp.count(demux.demultiplex(1)) - -# with _GraphTempCacheFolderConfig() as config: -# csp.run(g, 'A', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# files = g.cached_data(config.cache_config.data_folder)('A').get_data_files_for_period(datetime(2020, 1, 1), datetime(2020, 1, 1) + timedelta(seconds=20)) -# self.assertEqual(len(files), 1) -# file = next(iter(files.values())) -# if split_columns_to_files: -# # Let's fake the data file by removing the column y -# file_to_remove = os.path.join(file, 'o.y.parquet') -# self.assertTrue(os.path.exists(file_to_remove)) -# os.unlink(file_to_remove) -# self.assertFalse(os.path.exists(file_to_remove)) -# with self.assertRaisesRegex(Exception, r'.*IOError.*Failed to open .*o\.y.*'): -# csp.run(g, 'A', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# else: -# df = pandas.read_parquet(file) -# # Let's fake the data file by removing the column y. 
We want to make sure that we don't attempt to read column y -# df.drop(columns=['o.y']).to_parquet(file) -# with self.assertRaisesRegex(RuntimeError, r'Missing column o\.y.*'): -# csp.run(g, 'A', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# # This should not raise since we don't try to read the y column -# csp.run(g_x_reader, 'A', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# csp.run(g_delayed_demux, 'A', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) - -# def test_enum_serialization(self): -# for split_columns_to_files in (True, False): -# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_files) - -# class MyEnum(csp.Enum): -# X = csp.Enum.auto() -# Y = csp.Enum.auto() -# ZZZ = csp.Enum.auto() - -# raiseExc = [False] - -# @graph(cache=True, **graph_kwargs) -# def g(s: str) -> csp.Outputs(o=csp.ts[MyEnum]): -# if raiseExc[0]: -# raise RuntimeError("Shouldn't get here") -# o_ts = csp.curve(MyEnum, [(timedelta(seconds=1), MyEnum.X), (timedelta(seconds=1), MyEnum.Y), (timedelta(seconds=2), MyEnum.ZZZ), (timedelta(seconds=3), MyEnum.X)]) -# return csp.output(o=o_ts) - -# from csp.utils.qualified_name_utils import QualifiedNameUtils -# QualifiedNameUtils.register_type(MyEnum) - -# with _GraphTempCacheFolderConfig() as config: -# csp.run(g, 'A', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# raiseExc[0] = True -# cached_res = csp.run(g, 'A', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# enum_values = [v[1] for v in cached_res['o']] -# data_df = g.cached_data(config.cache_config.data_folder)('A').get_data_df_for_period() -# self.assertEqual(data_df['o'].tolist(), ['X', 'Y', 'ZZZ', 'X']) -# self.assertEqual(enum_values, [MyEnum.X, MyEnum.Y, MyEnum.ZZZ, MyEnum.X]) - -# def test_enum_field_serialization(self): -# for split_columns_to_files in (True, False): -# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_files) -# from csp.tests.impl.test_enum import MyEnum - -# class MyStruct(csp.Struct): -# e: MyEnum - -# raiseExc = [False] - -# @graph(cache=True, **graph_kwargs) -# def g(s: str) -> csp.Outputs(o=csp.ts[MyStruct]): -# if raiseExc[0]: -# raise RuntimeError("Shouldn't get here") -# make_s = lambda v: MyStruct(e=v) if v is not None else MyStruct() -# o_ts = csp.curve(MyStruct, [(timedelta(seconds=1), make_s(MyEnum.A)), (timedelta(seconds=1), make_s(MyEnum.B)), -# (timedelta(seconds=2), make_s(MyEnum.C)), (timedelta(seconds=3), make_s(MyEnum.A)), -# (timedelta(seconds=4), make_s(None))]) -# return csp.output(o=o_ts) - -# with _GraphTempCacheFolderConfig() as config: -# csp.run(g, 'A', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# raiseExc[0] = True -# cached_res = csp.run(g, 'A', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# enum_values = [v[1].e if hasattr(v[1], 'e') else None for v in cached_res['o']] -# data_df = g.cached_data(config.cache_config.data_folder)('A').get_data_df_for_period() -# self.assertEqual(data_df['o.e'].tolist(), ['A', 'B', 'C', 'A', None]) -# self.assertEqual(enum_values, [MyEnum.A, MyEnum.B, MyEnum.C, MyEnum.A, None]) - -# def test_nested_struct_caching(self): -# for split_columns_to_files in (True, False): -# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_files) -# from csp.tests.impl.test_enum import MyEnum -# class MyStruct1(csp.Struct): -# v_int: int -# v_str: str -# 
e: MyEnum - -# class MyStruct2(csp.Struct): -# v: MyStruct1 -# v_float: float - -# class MyStruct3(MyStruct2): -# v2: MyStruct2 - -# from csp.utils.qualified_name_utils import QualifiedNameUtils -# QualifiedNameUtils.register_type(MyStruct1) -# QualifiedNameUtils.register_type(MyStruct2) - -# raiseExc = [False] - -# struct_values = [MyStruct3(), -# MyStruct3(v=MyStruct1(v_int=1)), -# MyStruct3(v=MyStruct1(v_int=2)), -# MyStruct3(v=MyStruct1(v_int=3, v_str='3_val')), -# MyStruct3(v=MyStruct1(v_str='4_val')), -# MyStruct3(v=MyStruct1(v_str='5_val'), v2=MyStruct2(v_float=5.5, v=MyStruct1(v_int=6, v_str='6_val', e=MyEnum.B)), v_float=6.5), -# MyStruct3(v=MyStruct1()) -# ] - -# @graph(cache=True, **graph_kwargs) -# def g() -> csp.Outputs(o=csp.ts[MyStruct3]): -# if raiseExc[0]: -# raise RuntimeError("Shouldn't get here") -# o_ts = csp.curve(MyStruct3, [(timedelta(seconds=i), v) for i, v in enumerate(struct_values)]) -# return csp.output(o=o_ts) - -# @graph -# def g2(): -# csp.add_graph_output('o', g().o) -# csp.add_graph_output('o.v', g().o.v) -# csp.add_graph_output('o.v_float', g().o.v_float) - -# @graph -# def g3(): -# csp.add_graph_output('o.v_float', g().o.v_float) - -# with _GraphTempCacheFolderConfig() as config: -# csp.run(g, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# raiseExc[0] = True -# cached_res = csp.run(g2, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# cached_float = csp.run(g3, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# cached_values = list(zip(*cached_res['o']))[1] -# cached_v_values = list(zip(*cached_res['o.v']))[1] -# expected_v_values = [getattr(v, 'v') for v in cached_values if hasattr(v, 'v')] -# self.assertEqual(len(struct_values), len(cached_values)) -# for v1, v2 in zip(struct_values, cached_values): -# self.assertEqual(v1, v2) -# self.assertEqual(len(cached_v_values), len(expected_v_values)) -# for v1, v2 in zip(cached_v_values, expected_v_values): -# self.assertEqual(v1, v2) -# self.assertEqual(cached_float['o.v_float'], cached_res['o.v_float']) - -# def test_caching_same_timestamp_with_missing_values(self): -# for split_columns_to_files in (True, False): -# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_files) - -# @csp.node -# def my_node() -> csp.Outputs(v1=csp.ts[int], v2=csp.ts[int], v3=csp.ts[int]): -# with csp.alarms(): -# a = csp.alarm( int ) -# with csp.start(): -# csp.schedule_alarm(a, timedelta(0), 0) -# csp.schedule_alarm(a, timedelta(0), 1) -# csp.schedule_alarm(a, timedelta(0), 2) -# csp.schedule_alarm(a, timedelta(0), 3) -# if csp.ticked(a): -# if a == 0: -# csp.output(v1=10 + a, v2=20 + a) -# elif a == 1: -# csp.output(v1=10 + a, v3=30 + a) -# else: -# csp.output(v1=10 + a, v2=20 + a, v3=30 + a) - -# @graph(cache=True, **graph_kwargs) -# def g() -> csp.Outputs(v1=csp.ts[int], v2=csp.ts[int], v3=csp.ts[int]): -# outs = my_node() -# return csp.output(v1=outs.v1, v2=outs.v2, v3=outs.v3) - -# @graph -# def main(): -# csp.add_graph_output('l', csp_sorted(csp.collect([g().v1, g().v2, g().v3]))) - -# with _GraphTempCacheFolderConfig() as config: -# out1 = csp.run(main, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# out2 = csp.run(main, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) -# self.assertEqual(out1, out2) - -# def test_timestamp_with_nanos_caching(self): -# for split_columns_to_files in (True, False): -# graph_kwargs = 
self._get_default_graph_caching_kwargs(split_columns_to_files) -# timestamp_value = pandas.Timestamp('2020-01-01 00:00:00') + pandas.to_timedelta(123, 'ns') - -# @csp.node -# def my_node() -> csp.ts[datetime]: -# with csp.alarms(): -# a = csp.alarm( datetime ) -# with csp.start(): -# csp.schedule_alarm(a, timedelta(seconds=1), timestamp_value) -# if csp.ticked(a): -# return a - -# @graph(cache=True, **graph_kwargs) -# def g() -> csp.Outputs(t=csp.ts[datetime]): -# return csp.output(t=my_node()) - -# with _GraphTempCacheFolderConfig() as config: -# csp.run(g, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=20), config=config) - -# data_path_resolver = CacheConfigResolver(config.cache_config) - -# data_df = g.cached_data(data_path_resolver)().get_data_df_for_period() - -# self.assertEqual(timestamp_value.nanosecond, 123) -# self.assertEqual(len(data_df), 1) -# self.assertEqual(data_df.t.iloc[0].tz_localize(None), timestamp_value) - -# def test_unsupported_basket_caching(self): -# with self.assertRaisesRegex(NotImplementedError, "Caching of list basket outputs is unsupported"): -# @csp.graph(cache=True) -# def g_bad() -> csp.Outputs(list_basket=[csp.ts[str]]): -# raise RuntimeError() - -# with self.assertRaisesRegex(TypeError, "Cached output basket dict_basket must have shape provided using with_shape or with_shape_of"): -# @csp.graph(cache=True) -# def g_bad() -> csp.Outputs(dict_basket=csp.OutputBasket(Dict[str, csp.ts[str]])): -# raise RuntimeError() - -# with self.assertRaisesRegex(RuntimeError, "Cached graph with output basket must set split_columns_to_files to True"): -# @csp.graph(cache=True, cache_options=GraphCacheOptions(split_columns_to_files=False)) -# def g_bad() -> csp.Outputs(dict_basket=csp.OutputBasket(Dict[str, csp.ts[str]], shape=[1,2,3])): -# raise RuntimeError() -# # TODO: add shape validation check here - -# def test_simple_dict_basket_caching(self): -# def shape_func(l=None): -# if l is None: -# return ['x', 'y', 'z'] -# return l - -# @csp.node -# def my_node() -> csp.Outputs(scalar1=csp.ts[int], dict_basket= -# csp.OutputBasket(Dict[str, csp.ts[int]], shape=shape_func()), scalar=csp.ts[int]): -# with csp.alarms(): -# a_index = csp.alarm( int ) -# with csp.start(): -# csp.schedule_alarm(a_index, timedelta(), 0) -# if csp.ticked(a_index) and a_index < 10: -# if a_index == 1: -# csp.schedule_alarm(a_index, timedelta(), 2) -# else: -# csp.schedule_alarm(a_index, timedelta(seconds=1), a_index + 1) - -# if a_index == 0: -# csp.output(scalar1=1, dict_basket={'x': 1, 'y': 2, 'z': 3}, scalar=2) -# elif a_index == 1: -# csp.output(dict_basket={'x': 2, 'z': 3}, scalar=3) -# elif a_index == 2: -# csp.output(dict_basket={'x': 3, 'z': 34}) -# elif a_index == 3: -# csp.output(scalar1=5) -# elif a_index == 4: -# csp.output(dict_basket={'x': 45}) - -# @csp.graph(cache=True) -# def g_bad() -> csp.Outputs(scalar1=csp.ts[int], dict_basket=csp.OutputBasket(Dict[str, csp.ts[int]], shape=shape_func()), scalar=csp.ts[int]): -# # __outputs__(dict_basket={'T': csp.ts['K']}.with_shape(shape_func(['xx']))) -# # -# # return csp.output( dict_basket={'xx': csp.const(1)}) - -# return csp.output(scalar1=my_node().scalar1, dict_basket=my_node().dict_basket, scalar=my_node().scalar) - -# # @csp.node -# # def g_bad(): -# # __outputs__(scalar1=csp.ts[int], dict_basket={'T': csp.ts['K']}.with_shape(shape_func()), scalar=csp.ts[int]) -# # return csp.output(scalar1=5, dict_basket={'x': 1}, scalar=3) - -# @graph -# def run_graph(g: object): -# g_bad() - -# with _GraphTempCacheFolderConfig() as 
config: -# csp.run(run_graph, g_bad, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=30), config=config) - -# def test_simple_basket_caching(self): -# for typ in (int, bool, float, str, datetime, date, TypedCurveGenerator.SimpleEnum, TypedCurveGenerator.SimpleStruct, TypedCurveGenerator.NestedStruct): -# @graph(cache=True) -# def cached_graph() -> csp.Outputs(v1=csp.OutputBasket(Dict[str, csp.ts[typ]], shape=['0', '', '2']), v2=csp.ts[int]): -# curve_generator = TypedCurveGenerator() - -# return csp.output(v1={'0': curve_generator.gen_transformed_curve(typ, 0, 100, 1, skip_indices=[5, 6, 7], duplicate_timestamp_indices=[8, 9]), -# '': curve_generator.gen_transformed_curve(typ, 13, 100, 1, skip_indices=[5, 7, 9]), -# '2': curve_generator.gen_transformed_curve(typ, 27, 100, 1, skip_indices=[5, 6]) -# }, -# v2=curve_generator.gen_int_curve(100, 10, 1, skip_indices=[2], duplicate_timestamp_indices=[7, 8])) - -# @graph -# def run_graph(force_cached: bool = False): -# g = cached_graph.cached if force_cached else cached_graph -# csp.add_graph_output('v1[0]', g().v1['0']) -# csp.add_graph_output('v1[1]', g().v1['']) -# csp.add_graph_output('v1[2]', g().v1['2']) -# csp.add_graph_output('v2', g().v2) - -# with _GraphTempCacheFolderConfig() as config: -# res = csp.run(run_graph, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=30), config=config) -# res2 = csp.run(run_graph, True, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=30), config=config) -# self.assertEqual(res, res2) - -# def test_basket_caching_first_last_cycles(self): -# for basket_edge_settings in self._EdgeOutputSettings: -# for scalar_edge_settings in self._EdgeOutputSettings: -# @graph(cache=True) -# def cached_graph(basket_edge_settings: self._EdgeOutputSettings, scalar_edge_settings: self._EdgeOutputSettings) -> csp.Outputs( -# v1=csp.OutputBasket(Dict[str, csp.ts[int]], shape=['0', '1']), v2=csp.ts[int]): -# curve_generator = TypedCurveGenerator() -# output_scalar_on_initial_cycle = bool(scalar_edge_settings.value & self._EdgeOutputSettings.FIRST_CYCLE.value) -# output_basket_on_initial_cycle = bool(basket_edge_settings.value & self._EdgeOutputSettings.FIRST_CYCLE.value) -# basket_skip_indices = [] if bool(basket_edge_settings.value & self._EdgeOutputSettings.LAST_CYCLE.value) else [3] -# scalar_skip_indices = [] if bool(scalar_edge_settings.value & self._EdgeOutputSettings.LAST_CYCLE.value) else [3] - -# return csp.output(v1={'0': curve_generator.gen_int_curve(0, 3, 1, output_on_initial_cycle=output_basket_on_initial_cycle, skip_indices=basket_skip_indices), -# '1': curve_generator.gen_int_curve(13, 3, 1, output_on_initial_cycle=output_basket_on_initial_cycle, skip_indices=basket_skip_indices)}, -# v2=curve_generator.gen_int_curve(100, 3, 1, output_on_initial_cycle=output_scalar_on_initial_cycle, skip_indices=scalar_skip_indices)) - -# @graph -# def run_graph(basket_edge_settings: self._EdgeOutputSettings, scalar_edge_settings: self._EdgeOutputSettings, force_cached: bool = False): -# g = cached_graph.cached if force_cached else cached_graph -# g_res = g(basket_edge_settings, scalar_edge_settings) -# csp.add_graph_output('v1[0]', g_res.v1['0']) -# csp.add_graph_output('v1[1]', g_res.v1['1']) -# csp.add_graph_output('v2', g_res.v2) - -# with _GraphTempCacheFolderConfig() as config: -# res = csp.run(run_graph, basket_edge_settings, scalar_edge_settings, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=30), config=config) -# res2 = csp.run(run_graph, basket_edge_settings, scalar_edge_settings, 
True, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=30), config=config) -# self.assertEqual(res, res2) - -# def test_basket_multiday_read_write(self): -# # 5 hours - means we will have data each day and the last few days are empty. With 49 hours we will have some days in the middle having empty -# # data and we want to see that it's handled properly (this test actually found a hidden bug) -# for curve_hours in (5, 49): -# for typ in (int, bool, float, str, datetime, date, TypedCurveGenerator.SimpleEnum, TypedCurveGenerator.SimpleStruct, TypedCurveGenerator.NestedStruct): -# @graph(cache=True) -# def cached_graph() -> csp.Outputs(v1=csp.OutputBasket(Dict[str, csp.ts[typ]], shape=['0', '', '2']), v2=csp.ts[int]): -# curve_generator = TypedCurveGenerator(period=timedelta(hours=curve_hours)) - -# return csp.output(v1={'0': curve_generator.gen_transformed_curve(typ, 0, 10, 1, skip_indices=[5, 6, 7], duplicate_timestamp_indices=[8, 9]), -# '': curve_generator.gen_transformed_curve(typ, 13, 10, 1, skip_indices=[5, 7, 9]), -# '2': curve_generator.gen_transformed_curve(typ, 27, 10, 1, skip_indices=[5, 6]) -# }, -# v2=curve_generator.gen_int_curve(100, 10, 1, skip_indices=[2], duplicate_timestamp_indices=[7, 8])) - -# @graph -# def run_graph(force_cached: bool = False): -# g = cached_graph.cached if force_cached else cached_graph -# csp.add_graph_output('v1[0]', g().v1['0']) -# csp.add_graph_output('v1[1]', g().v1['']) -# csp.add_graph_output('v1[2]', g().v1['2']) -# csp.add_graph_output('v2', g().v2) - -# self.maxDiff = None -# with _GraphTempCacheFolderConfig() as config: -# res = csp.run(run_graph, starttime=datetime(2020, 1, 1), endtime=timedelta(days=5) - timedelta(microseconds=1), config=config) -# res2 = csp.run(run_graph, True, starttime=datetime(2020, 1, 1), endtime=timedelta(days=5) - timedelta(microseconds=1), config=config) -# self.assertEqual(res, res2) -# data_path_resolver = CacheConfigResolver(config.cache_config) -# # A sanity check that we can load the data with some empty dataframes on some days -# base_data_df = cached_graph.cached_data(data_path_resolver)().get_data_df_for_period() - -# def test_merge_baskets(self): -# def _simple_struct_to_dict(o): -# if o is None: -# return None -# return {c: getattr(o, c, None) for c in TypedCurveGenerator.SimpleStruct.metadata()} - -# for batch_size in (117, None): -# output_config = ParquetOutputConfig() if batch_size is None else ParquetOutputConfig(batch_size=batch_size) -# for typ in (int, bool, float, str, datetime, date, TypedCurveGenerator.SimpleEnum, TypedCurveGenerator.SimpleStruct, TypedCurveGenerator.NestedStruct): -# @graph(cache=True) -# def base_graph() -> csp.Outputs(v1=csp.OutputBasket(Dict[str, csp.ts[typ]], shape=['COL1', 'COL2', 'COL3']), v2=csp.ts[int]): -# curve_generator = TypedCurveGenerator(period=timedelta(seconds=7)) -# return csp.output(v1={'COL1': curve_generator.gen_transformed_curve(typ, 0, 2600, 1, skip_indices=[95, 96, 97], duplicate_timestamp_indices=[98, 99]), -# 'COL2': curve_generator.gen_transformed_curve(typ, 13, 2600, 1, skip_indices=[95, 97, 99, 1090]), -# 'COL3': curve_generator.gen_transformed_curve(typ, 27, 2600, 1, skip_indices=[95, 96]) -# }, -# v2=curve_generator.gen_int_curve(100, 2600, 1, skip_indices=[92], duplicate_timestamp_indices=[97, 98])) - -# @graph(cache=True, cache_options=GraphCacheOptions(parquet_output_config=output_config)) -# def cached_graph() -> csp.Outputs(v1=csp.OutputBasket(Dict[str, csp.ts[typ]], shape=['COL1', 'COL2', 'COL3']), v2=csp.ts[int]): -# return 
csp.output(v1=base_graph.cached().v1, -# v2=base_graph.cached().v2) - -# @graph -# def run_graph(force_cached: bool = False): -# g = cached_graph.cached if force_cached else cached_graph -# csp.add_graph_output('COL1', g().v1['COL1']) -# csp.add_graph_output('COL2', g().v1['COL2']) -# csp.add_graph_output('COL3', g().v1['COL3']) -# csp.add_graph_output('v2', g().v2) - -# # enough to check this just for one type -# merge_existing_files = typ is int -# with _GraphTempCacheFolderConfig(allow_overwrite=True, merge_existing_files=merge_existing_files) as config: -# base_data_outputs = csp.run(base_graph, starttime=datetime(2020, 3, 1, 9, 20, tzinfo=pytz.utc), -# endtime=datetime(2020, 3, 1, 14, 0, tzinfo=pytz.utc), -# config=config) - -# aux_dfs = [pandas.DataFrame(dict(zip(['csp_timestamp', k], zip(*v)))) for k, v in base_data_outputs.items()] -# for aux_df in aux_dfs: -# repeated_timestamp_mask = 1 - (aux_df['csp_timestamp'].shift(1) != aux_df['csp_timestamp']).astype(int) -# aux_df['cycle_count'] = repeated_timestamp_mask.cumsum() * repeated_timestamp_mask -# aux_df.set_index(['csp_timestamp', 'cycle_count'], inplace=True) - -# # this does not work as of pandas==1.4.0 -# # expected_base_df = pandas.concat(aux_dfs, axis=1) -# expected_base_df = aux_dfs[0] -# for df in aux_dfs[1:]: -# expected_base_df = expected_base_df.merge(df, left_index=True, right_index=True, how="outer") -# expected_base_df = expected_base_df.reset_index().drop(columns=['cycle_count']) - -# expected_base_df.columns = [['csp_timestamp', 'v1', 'v1', 'v1', 'v2'], ['', 'COL1', 'COL2', 'COL3', '']] -# expected_base_df = expected_base_df[['csp_timestamp', 'v2', 'v1']] -# expected_base_df['csp_timestamp'] = expected_base_df['csp_timestamp'].dt.tz_localize(pytz.utc) -# if typ is datetime: -# for c in ['COL1', 'COL2', 'COL3']: -# expected_base_df.loc[:, ('v1', c)] = expected_base_df.loc[:, ('v1', c)].dt.tz_localize(pytz.utc) -# if typ is TypedCurveGenerator.SimpleEnum: -# for c in ['COL1', 'COL2', 'COL3']: -# expected_base_df.loc[:, ('v1', c)] = expected_base_df.loc[:, ('v1', c)].apply(lambda v: v.name if isinstance(v, TypedCurveGenerator.SimpleEnum) else v) -# if typ is TypedCurveGenerator.SimpleStruct: -# for k in TypedCurveGenerator.SimpleStruct.metadata(): -# for c in ['COL1', 'COL2', 'COL3']: -# expected_base_df.loc[:, (f'v1.{k}', c)] = expected_base_df.loc[:, ('v1', c)].apply(lambda v: getattr(v, k, None) if v else v) -# expected_base_df.drop(columns=['v1'], inplace=True, level=0) -# if typ is TypedCurveGenerator.NestedStruct: -# for k in TypedCurveGenerator.NestedStruct.metadata(): -# for c in ['COL1', 'COL2', 'COL3']: -# if k == 'value2': -# expected_base_df.loc[:, (f'v1.{k}', c)] = expected_base_df.loc[:, ('v1', c)].apply(lambda v: _simple_struct_to_dict(getattr(v, k, None)) if v else v) -# else: -# expected_base_df.loc[:, (f'v1.{k}', c)] = expected_base_df.loc[:, ('v1', c)].apply(lambda v: getattr(v, k, None) if v else v) -# expected_base_df.drop(columns=['v1'], inplace=True, level=0) - -# data_path_resolver = CacheConfigResolver(config.cache_config) -# base_data_df = base_graph.cached_data(data_path_resolver)().get_data_df_for_period() -# self.assertTrue(base_data_df.fillna(-111111).eq(expected_base_df.fillna(-111111)).all().all()) -# missing_range_handler = lambda start, end: True -# start_time1 = datetime(2020, 3, 1, 9, 30, tzinfo=pytz.utc) -# dt1 = timedelta(hours=0, minutes=60) -# res1 = csp.run(run_graph, starttime=start_time1, endtime=dt1, config=config) -# files = 
list(cached_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period().values()) -# self.assertEqual(len(files), 1) -# res1_df = cached_graph.cached_data(data_path_resolver)().get_data_df_for_period(start_time1, dt1) -# self.assertTrue( -# expected_base_df[expected_base_df.csp_timestamp.between(start_time1, start_time1 + dt1)].reset_index(drop=True).fillna(-111111).eq(res1_df.fillna(-111111)).all().all()) -# start_time2 = start_time1 + timedelta(minutes=180) -# dt2 = dt1 -# res2 = csp.run(run_graph, starttime=start_time2, endtime=dt2, config=config) -# files = list(cached_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period(missing_range_handler=missing_range_handler).values()) -# self.assertEqual(len(files), 2) -# res2_df = cached_graph.cached_data(data_path_resolver)().get_data_df_for_period(start_time2, dt2) -# self.assertTrue( -# expected_base_df[expected_base_df.csp_timestamp.between(start_time2, start_time2 + dt2)].reset_index(drop=True).fillna(-111111).eq(res2_df.fillna(-111111)).all().all()) - -# # # Test repeated writing of the same file -# res2b = csp.run(run_graph, starttime=start_time2, endtime=dt2, config=config) -# self.assertEqual(res2b, res2) - -# start_time3 = start_time2 + dt2 - timedelta(minutes=5) -# dt3 = timedelta(minutes=15) -# res3 = csp.run(run_graph, starttime=start_time3, endtime=dt3, config=config) -# files = list(cached_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period(missing_range_handler=missing_range_handler).values()) -# if config.cache_config.merge_existing_files: -# self.assertEqual(len(files), 2) -# res3_df = cached_graph.cached_data(data_path_resolver)().get_data_df_for_period(start_time3, dt3) -# self.assertTrue( -# expected_base_df[expected_base_df.csp_timestamp.between(start_time3, start_time3 + dt3)].reset_index(drop=True).fillna(-111111).eq(res3_df.fillna(-111111)).all().all()) - -# start_time4 = start_time2 - timedelta(minutes=5) -# dt4 = timedelta(minutes=15) -# res4 = csp.run(run_graph, starttime=start_time4, endtime=dt4, config=config) -# files = list(cached_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period(missing_range_handler=missing_range_handler).values()) -# if config.cache_config.merge_existing_files: -# self.assertEqual(len(files), 2) -# res4_df = cached_graph.cached_data(data_path_resolver)().get_data_df_for_period(start_time4, dt4) -# self.assertTrue( -# expected_base_df[expected_base_df.csp_timestamp.between(start_time4, start_time4 + dt4)].reset_index(drop=True).fillna(-111111).eq(res4_df.fillna(-111111)).all().all()) - -# start_time5 = start_time1 + timedelta(minutes=40) -# dt5 = timedelta(minutes=200) - -# res5 = csp.run(run_graph, starttime=start_time5, endtime=dt5, config=config) -# files = list(cached_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period(missing_range_handler=missing_range_handler).values()) -# if config.cache_config.merge_existing_files: -# self.assertEqual(len(files), 1) -# res5_df = cached_graph.cached_data(data_path_resolver)().get_data_df_for_period(start_time5, dt5) -# self.assertTrue(expected_base_df[expected_base_df.csp_timestamp.between(start_time5, start_time5 + dt5)].reset_index(drop=True).equals(res5_df)) - -# start_time6 = start_time1 + timedelta(minutes=10) -# res6 = csp.run(run_graph, starttime=start_time6, endtime=dt1, config=config) -# files = 
list(cached_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period(missing_range_handler=missing_range_handler).values()) -# if config.cache_config.merge_existing_files: -# self.assertEqual(len(files), 1) -# res6_df = cached_graph.cached_data(data_path_resolver)().get_data_df_for_period(start_time6, dt1) -# self.assertTrue( -# expected_base_df[expected_base_df.csp_timestamp.between(start_time6, start_time6 + dt1)].reset_index(drop=True).fillna(-111111).eq(res6_df.fillna(-111111)).all().all()) -# start_time7 = start_time1 - timedelta(minutes=10) -# dt7 = start_time3 + dt3 + timedelta(minutes=10) -# res7 = csp.run(run_graph, starttime=start_time7, endtime=dt7, config=config) -# files = list(cached_graph.cached_data(config.cache_config.data_folder)().get_data_files_for_period().values()) -# self.assertEqual(len(files), 1) -# res7_df = cached_graph.cached_data(data_path_resolver)().get_data_df_for_period(start_time7, dt7) -# self.assertTrue(expected_base_df[expected_base_df.csp_timestamp.between(start_time7, dt7)].reset_index(drop=True).fillna(-111111).eq(res7_df.fillna(-111111)).all().all()) - -# def test_subtype_dict_caching(self): -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# for cache in (True, False): -# @graph(cache=cache) -# def main() -> csp.Outputs(o=csp.OutputBasket(Dict[str, csp.ts[TypedCurveGenerator.SimpleStruct]], shape=(['A', 'B']))): -# curve_generator = TypedCurveGenerator(period=timedelta(seconds=1)) -# return csp.output(o={ -# 'A': curve_generator.gen_transformed_curve(TypedCurveGenerator.SimpleSubStruct, 100, 10, 1), -# 'B': curve_generator.gen_transformed_curve(TypedCurveGenerator.SimpleSubStruct, 500, 10, 1), -# }) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(seconds=11) -# if cache: -# with self.assertRaises(csp.impl.types.instantiation_type_resolver.ArgTypeMismatchError): -# csp.run(main, starttime=start_time, endtime=end_time, config=config) -# else: -# csp.run(main, starttime=start_time, endtime=end_time, config=config) - -# def test_subclass_caching(self): -# @csp.graph -# def main() -> csp.Outputs(o=csp.ts[TypedCurveGenerator.SimpleStruct]): -# return csp.output(o=csp.const(TypedCurveGenerator.SimpleSubStruct())) - -# @csp.graph(cache=True) -# def main_cached() -> csp.Outputs(o=csp.ts[TypedCurveGenerator.SimpleStruct]): -# return csp.output(o=csp.const(TypedCurveGenerator.SimpleSubStruct())) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(seconds=11) -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# csp.run(main, starttime=start_time, endtime=end_time, config=config) -# # Cached graphs must return exact types -# with self.assertRaises(csp.impl.types.instantiation_type_resolver.TSArgTypeMismatchError): -# csp.run(main_cached, starttime=start_time, endtime=end_time, config=config) - -# def test_key_subset(self): -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# @graph(cache=True, cache_options=GraphCacheOptions(ignored_inputs={'tickers'})) -# def main(tickers: [str]) -> csp.Outputs(prices=csp.OutputBasket(Dict[str, csp.ts[float]], shape="tickers")): -# curve_generator = TypedCurveGenerator(period=timedelta(seconds=1)) -# return csp.output(prices={ -# 'AAPL': curve_generator.gen_transformed_curve(float, 100, 10, 1), -# 'IBM': curve_generator.gen_transformed_curve(float, 500, 10, 1), -# }) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(seconds=11) -# res1 = csp.run(main, ['AAPL', 'IBM'], 
starttime=start_time, endtime=end_time, config=config) -# res2 = csp.run(main, ['AAPL'], starttime=start_time, endtime=end_time, config=config) -# res3 = csp.run(main, ['IBM'], starttime=start_time, endtime=end_time, config=config) -# self.assertEqual(len(res1), 2) -# self.assertEqual(len(res2), 1) -# self.assertEqual(len(res3), 1) -# self.assertEqual(res1['prices[AAPL]'], res2['prices[AAPL]']) -# self.assertEqual(res1['prices[IBM]'], res3['prices[IBM]']) - -# def test_simple_node_caching(self): -# throw_exc = [False] - -# @csp.node(cache=True) -# def main_node() -> csp.Outputs(x=csp.ts[int]): -# with csp.alarms(): -# a = csp.alarm( int ) -# with csp.start(): -# if throw_exc[0]: -# raise RuntimeError("Shouldn't get here, node should be cached") -# csp.schedule_alarm(a, timedelta(), 0) - -# if csp.ticked(a): -# csp.schedule_alarm(a, timedelta(seconds=1), a + 1) -# return csp.output(x=a) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(seconds=11) -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# res1 = csp.run(main_node, starttime=start_time, endtime=end_time, config=config) -# throw_exc[0] = True -# res2 = csp.run(main_node, starttime=start_time, endtime=end_time, config=config) -# self.assertEqual(res1, res2) - -# def test_node_caching_with_args(self): -# throw_exc = [False] - -# @csp.node(cache=True, cache_options=GraphCacheOptions(ignored_inputs={'input_ts', 'input_basket'})) -# def main_node(input_ts: csp.ts[int], input_basket: {str: csp.ts[int]}, addition: int = Injected('addition_value')) -> csp.Outputs( -# o1=csp.ts[int], o2=csp.OutputBasket(Dict[str, csp.ts[int]], shape_of='input_basket')): -# with csp.alarms(): -# a = csp.alarm( int ) -# with csp.start(): -# if throw_exc[0]: -# raise RuntimeError("Shouldn't get here, node should be cached") -# csp.schedule_alarm(a, timedelta(), -42) -# if csp.ticked(input_ts): -# csp.output(o1=input_ts + addition) -# for k, v in input_basket.tickeditems(): -# csp.output(o2={k: v + addition}) - -# def main_graph(): -# curve_generator = TypedCurveGenerator(period=timedelta(seconds=1)) -# return main_node(curve_generator.gen_int_curve(0, 10, 1), {'1': curve_generator.gen_int_curve(10, 10, 1), '2': curve_generator.gen_int_curve(20, 10, 1)}) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(seconds=11) -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# with set_new_registry_thread_instance(): -# register_injected_object('addition_value', 42) -# res1 = csp.run(main_graph, starttime=start_time, endtime=end_time, config=config) -# throw_exc[0] = True -# res2 = csp.run(main_graph, starttime=start_time, endtime=end_time, config=config) -# self.assertEqual(res1, res2) - -# def test_caching_int_as_float(self): -# @csp.graph(cache=True) -# def main_cached() -> csp.Outputs(o=csp.ts[float]): -# return csp.output(o=csp.const.using(T=int)(int(42))) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(seconds=11) -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# res1 = csp.run(main_cached, starttime=start_time, endtime=end_time, config=config) -# res2 = csp.run(main_cached, starttime=start_time, endtime=end_time, config=config) -# cached_val = res2['o'][0][1] -# self.assertIs(type(cached_val), float) -# self.assertEqual(cached_val, 42.0) - -# def test_consecutive_files_merge(self): -# for split_columns_to_files in (True, False): -# @csp.graph(cache=True, 
cache_options=GraphCacheOptions(split_columns_to_files=split_columns_to_files)) -# def main_cached() -> csp.Outputs(o=csp.ts[float]): -# return csp.output(o=csp.const.using(T=int)(int(42))) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(seconds=11) -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# csp.run(main_cached, starttime=start_time, endtime=end_time, config=config) -# csp.run(main_cached, starttime=end_time + timedelta(microseconds=1), endtime=end_time + timedelta(seconds=1), config=config) -# files = list(main_cached.cached_data(config.cache_config.data_folder)().get_data_files_for_period().items()) -# self.assertEqual(len(files), 1) -# self.assertEqual(files[0][0], (start_time, start_time + timedelta(seconds=12))) - -# def test_aggregation(self): -# ref_date = datetime(2021, 1, 1) -# dfs = [] -# for aggregation_period in TimeAggregation: -# for split_columns_to_files in (True, False): -# @csp.node(cache=True, cache_options=GraphCacheOptions(split_columns_to_files=split_columns_to_files, -# time_aggregation=aggregation_period)) -# def n1() -> csp.Outputs(c=csp.ts[int]): -# with csp.alarms(): -# a_t = csp.alarm( date ) -# with csp.start(): -# first_out_time = ref_date + timedelta(days=math.ceil((csp.now() - ref_date).total_seconds() / 86400 / 5) * 5) -# csp.schedule_alarm(a_t, first_out_time, ref_date.date()) - -# if csp.ticked(a_t): -# csp.schedule_alarm(a_t, timedelta(days=5), csp.now().date()) -# return csp.output(c=int((csp.now() - ref_date).total_seconds() / 86400)) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# for i in range(100): -# csp.run(n1, starttime=ref_date + timedelta(days=7 * i), endtime=timedelta(days=8, microseconds=1), config=config) - -# all_parquet_files = glob.glob(os.path.join(config.cache_config.data_folder, '**', '*.parquet'), recursive=True) -# files_for_period = n1.cached_data(config.cache_config.data_folder)().get_data_files_for_period() -# dfs.append(n1.cached_data(config.cache_config.data_folder)().get_data_df_for_period()) -# self.assertTrue((dfs[-1]['c'].diff().iloc[1:] == 5).all()) -# num_parquet_files = len(all_parquet_files) // 2 if split_columns_to_files else len(all_parquet_files) - -# if aggregation_period == TimeAggregation.DAY: -# self.assertEqual(len(files_for_period), 702) -# self.assertEqual(num_parquet_files, 702) -# elif aggregation_period == TimeAggregation.MONTH: -# self.assertEqual(len(files_for_period), 24) -# self.assertEqual(num_parquet_files, 24) -# elif aggregation_period == TimeAggregation.QUARTER: -# self.assertEqual(len(files_for_period), 8) -# self.assertEqual(num_parquet_files, 8) -# else: -# self.assertEqual(len(files_for_period), 2) -# self.assertEqual(num_parquet_files, 2) -# for df1, df2 in zip(dfs[0:-1], dfs[1:]): -# self.assertTrue((df1 == df2).all().all()) - -# def test_struct_column_subset_read(self): -# for split_columns_to_files in (True, False): -# @graph(cache=True, cache_options=GraphCacheOptions(ignored_inputs={'t'}, split_columns_to_files=split_columns_to_files)) -# def g(t: 'T' = TypedCurveGenerator.SimpleSubStruct) -> csp.Outputs(o=csp.ts['T']): -# curve_generator = TypedCurveGenerator(period=timedelta(seconds=1)) -# return csp.output(o=curve_generator.gen_transformed_curve(t, 0, 10, 1)) - -# @graph -# def g_single_col() -> csp.Outputs(value=csp.ts[float]): - -# return csp.output(value=g().o.value2) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# start_time = datetime(2021, 1, 1) -# end_time = start_time + 
timedelta(seconds=11) -# res1 = csp.run(g, starttime=start_time, endtime=end_time, config=config) -# res2 = csp.run(g.cached, TypedCurveGenerator.SimpleSubStruct, starttime=start_time, endtime=end_time, config=config) -# res3 = csp.run(g, TypedCurveGenerator.SimpleStruct, starttime=start_time, endtime=end_time, config=config) -# res4 = csp.run(g.cached, TypedCurveGenerator.SimpleStruct, starttime=start_time, endtime=end_time, config=config) -# # Since now we try to write with a different schema, this should raise -# with self.assertRaisesRegex(RuntimeError, "Metadata mismatch .*"): -# res5 = csp.run(g, TypedCurveGenerator.SimpleStruct, starttime=start_time, endtime=end_time + timedelta(seconds=1), config=config) - -# self.assertEqual(res1, res2) -# self.assertEqual(res3, res4) -# self.assertEqual(len(res1['o']), len(res3['o'])) -# self.assertNotEqual(res1, res3) -# for (t1, v1), (t2, v2) in zip(res1['o'], res3['o']): -# v1_aux = TypedCurveGenerator.SimpleStruct() -# v1_aux.copy_from(v1) -# self.assertEqual(t1, t2) -# self.assertEqual(v1_aux, v2) -# res5 = csp.run(g_single_col, starttime=start_time, endtime=end_time, config=config) -# files = g.cached_data(config)().get_data_files_for_period() -# self.assertEqual(len(files), 1) -# file = next(iter(files.values())) -# if split_columns_to_files: -# os.unlink(os.path.join(file, 'o.value1.parquet')) -# else: -# import pandas -# df = pandas.read_parquet(file) -# df = df.drop(columns=['o.value1']) -# df.to_parquet(file) - -# res6 = csp.run(g_single_col, starttime=start_time, endtime=end_time, config=config) -# self.assertEqual(res5, res6) -# # Since we removed some data when trying to read all again, we should fail -# if split_columns_to_files: -# with self.assertRaisesRegex(Exception, 'IOError.*'): -# res7 = csp.run(g.cached, TypedCurveGenerator.SimpleSubStruct, starttime=start_time, endtime=end_time, config=config) -# else: -# with self.assertRaisesRegex(RuntimeError, '.*Missing column o.value1.*'): -# res7 = csp.run(g.cached, TypedCurveGenerator.SimpleSubStruct, starttime=start_time, endtime=end_time, config=config) - -# def test_basket_struct_column_subset_read(self): -# @graph(cache=True, cache_options=GraphCacheOptions(ignored_inputs={'t'})) -# def g(t: 'T' = TypedCurveGenerator.SimpleSubStruct) -> csp.Outputs(o=csp.OutputBasket(Dict[str, csp.ts['T']], shape=['my_key'])) : -# curve_generator = TypedCurveGenerator(period=timedelta(seconds=1)) -# return csp.output(o={'my_key': curve_generator.gen_transformed_curve(t, 0, 10, 1)}) - -# @graph(cache=True, cache_options=GraphCacheOptions(ignored_inputs={'t'})) -# def g_unnamed_out(t: 'T' = TypedCurveGenerator.SimpleSubStruct) -> csp.OutputBasket(Dict[str, csp.ts['T']], shape=['my_key']): -# return g.cached(t).o - -# @graph -# def g_single_col(unnamed: bool = False) -> csp.Outputs(value=csp.ts[float]): - -# if unnamed: -# res = csp.get_basket_field(g_unnamed_out(), 'value2') -# else: -# res = csp.get_basket_field(g().o, 'value2') - -# return csp.output(value=res['my_key']) - -# def verify_all(x: csp.ts[bool]): -# self.assertTrue(x is not None) - -# @graph -# def g_verify_multiple_type(): -# verify_all(g_unnamed_out(TypedCurveGenerator.SimpleStruct)['my_key'].value2 == g_unnamed_out(TypedCurveGenerator.SimpleSubStruct)['my_key'].value2) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(seconds=11) -# res1 = csp.run(g, starttime=start_time, endtime=end_time, config=config) -# res2 = csp.run(g.cached, 
TypedCurveGenerator.SimpleSubStruct, starttime=start_time, endtime=end_time, config=config) -# res3 = csp.run(g, TypedCurveGenerator.SimpleStruct, starttime=start_time, endtime=end_time, config=config) -# res4 = csp.run(g.cached, TypedCurveGenerator.SimpleStruct, starttime=start_time, endtime=end_time, config=config) -# res_unnamed = csp.run(g_unnamed_out, starttime=start_time, endtime=end_time, config=config) -# res_unnamed_cached = csp.run(g_unnamed_out.cached, starttime=start_time, endtime=end_time, config=config) -# csp.run(g_verify_multiple_type, starttime=start_time, endtime=end_time, config=config) -# self.assertEqual(res_unnamed, res_unnamed_cached) -# self.assertEqual(res_unnamed['my_key'], res1['o[my_key]']) -# # Since now we try to write with a different schema, this should raise -# with self.assertRaisesRegex(RuntimeError, "Metadata mismatch .*"): -# res5 = csp.run(g, TypedCurveGenerator.SimpleStruct, starttime=start_time, endtime=end_time + timedelta(seconds=1), config=config) - -# self.assertEqual(res1, res2) -# self.assertEqual(res3, res4) -# self.assertEqual(len(res1['o[my_key]']), len(res3['o[my_key]'])) -# self.assertNotEqual(res1, res3) -# for (t1, v1), (t2, v2) in zip(res1['o[my_key]'], res3['o[my_key]']): -# v1_aux = TypedCurveGenerator.SimpleStruct() -# v1_aux.copy_from(v1) -# self.assertEqual(t1, t2) -# self.assertEqual(v1_aux, v2) -# res5 = csp.run(g_single_col, False, starttime=start_time, endtime=end_time, config=config) -# res5_unnamed = csp.run(g_single_col, True, starttime=start_time, endtime=end_time, config=config) -# self.assertEqual(res5, res5_unnamed) -# files = g.cached_data(config)().get_data_files_for_period() -# self.assertEqual(len(files), 1) -# file = next(iter(files.values())) -# # TODO: uncomment -# # os.unlink(os.path.join(file, 'o.value1.parquet')) - -# res6 = csp.run(g_single_col, starttime=start_time, endtime=end_time, config=config) -# self.assertEqual(res5, res6) -# # Since we removed some data when trying to read all again, we should fail -# # TODO: uncomment -# # with self.assertRaisesRegex(Exception, 'IOError.*'): -# # res7 = csp.run(g.cached, TypedCurveGenerator.SimpleSubStruct, starttime=start_time, endtime=end_time, config=config) - -# def test_unnamed_output_caching(self): -# for split_columns_to_files in (True, False): -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# @csp.graph(cache=True, cache_options=GraphCacheOptions(split_columns_to_files=split_columns_to_files)) -# def g_scalar() -> csp.ts[int]: -# gen = TypedCurveGenerator() -# return gen.gen_int_curve(0, 10, 1) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(split_columns_to_files=split_columns_to_files)) -# def g_struct() -> csp.ts[TypedCurveGenerator.SimpleStruct]: -# gen = TypedCurveGenerator() -# return gen.gen_transformed_curve(TypedCurveGenerator.SimpleStruct, 0, 10, 1) - -# @csp.graph(cache=True) -# def g_scalar_basket() -> csp.OutputBasket(Dict[str, csp.ts[int]] , shape=['k1', 'k2']): -# gen = TypedCurveGenerator() -# return {'k1': gen.gen_int_curve(0, 10, 1), -# 'k2': gen.gen_int_curve(100, 10, 1)} - -# @csp.graph(cache=True) -# def g_struct_basket() -> csp.OutputBasket(Dict[str, csp.ts[TypedCurveGenerator.SimpleStruct]], shape=['k1', 'k2']): -# gen = TypedCurveGenerator() -# return {'k1': gen.gen_transformed_curve(TypedCurveGenerator.SimpleStruct, 0, 10, 1), -# 'k2': gen.gen_transformed_curve(TypedCurveGenerator.SimpleStruct, 100, 10, 1)} - -# def run_test_single_graph(g_func): -# res1 = csp.run(g_func, starttime=datetime(2020, 3, 
1, 20, 30), endtime=timedelta(hours=0, minutes=390), -# config=config) -# res2 = csp.run(g_func, starttime=datetime(2020, 3, 1, 20, 30), endtime=timedelta(hours=0, minutes=390), -# config=config) -# res3 = csp.run(g_func.cached, starttime=datetime(2020, 3, 1, 20, 30), endtime=timedelta(hours=0, minutes=390), -# config=config) -# self.assertEqual(res1, res2) -# self.assertEqual(res2, res3) -# return res1 - -# run_test_single_graph(g_scalar) -# run_test_single_graph(g_struct) -# run_test_single_graph(g_scalar_basket) -# run_test_single_graph(g_struct_basket) - -# res1_df = g_scalar.cached_data(config)().get_data_df_for_period() -# res2_df = g_struct.cached_data(config)().get_data_df_for_period() -# res3_df = g_scalar_basket.cached_data(config)().get_data_df_for_period() -# res4_df = g_struct_basket.cached_data(config)().get_data_df_for_period() -# self.assertEqual(list(res1_df.columns), ['csp_timestamp', 'csp_unnamed_output']) -# self.assertEqual(list(res2_df.columns), ['csp_timestamp', 'value1', 'value2']) -# self.assertEqual(list(res3_df.columns), ['csp_timestamp', 'k1', 'k2']) -# self.assertEqual(list(res4_df.columns), [('csp_timestamp', ''), ('value1', 'k1'), ('value1', 'k2'), ('value2', 'k1'), ('value2', 'k2')]) - -# for df in (res1_df, res2_df, res3_df, res4_df): -# self.assertEqual(len(df), 11) - -# def test_basket_ids_retrieval(self): -# for aggregation_period in (TimeAggregation.MONTH, TimeAggregation.DAY,): -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# @csp.graph(cache=True, cache_options=GraphCacheOptions(ignored_inputs={'keys'}, time_aggregation=aggregation_period)) -# def g(keys: object) -> csp.OutputBasket(Dict[str, csp.ts[TypedCurveGenerator.SimpleStruct]], shape="keys"): -# gen = TypedCurveGenerator() -# return {keys[0]: gen.gen_transformed_curve(TypedCurveGenerator.SimpleStruct, 0, 10, 1), -# keys[1]: gen.gen_transformed_curve(TypedCurveGenerator.SimpleStruct, 100, 10, 1)} - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(ignored_inputs={'keys'})) -# def g_named_output(keys: object) -> csp.Outputs(out=csp.OutputBasket(Dict[str, csp.ts[TypedCurveGenerator.SimpleStruct]], shape="keys")): - -# gen = TypedCurveGenerator() -# return csp.output(out={keys[0]: gen.gen_transformed_curve(TypedCurveGenerator.SimpleStruct, 0, 10, 1), -# keys[1]: gen.gen_transformed_curve(TypedCurveGenerator.SimpleStruct, 100, 10, 1)}) - -# csp.run(g, ['k1', 'k2'], starttime=datetime(2020, 3, 1), endtime=datetime(2020, 3, 1, 23, 59, 59, 999999), -# config=config) -# csp.run(g, ['k3', 'k4'], starttime=datetime(2020, 3, 2), endtime=datetime(2020, 3, 2, 23, 59, 59, 999999), -# config=config) -# csp.run(g_named_output, ['k1', 'k2'], starttime=datetime(2020, 3, 1), endtime=datetime(2020, 3, 1, 23, 59, 59, 999999), -# config=config) -# csp.run(g_named_output, ['k3', 'k4'], starttime=datetime(2020, 3, 2), endtime=datetime(2020, 3, 2, 23, 59, 59, 999999), -# config=config) -# self.assertEqual(g.cached_data(config)().get_all_basket_ids_in_range(), ['k1', 'k2', 'k3', 'k4']) -# self.assertEqual(g.cached_data(config)().get_all_basket_ids_in_range(starttime=datetime(2020, 3, 1), endtime=datetime(2020, 3, 1, 23, 59, 59, 999999)), -# ['k1', 'k2']) -# self.assertEqual(g.cached_data(config)().get_all_basket_ids_in_range(starttime=datetime(2020, 3, 2), endtime=datetime(2020, 3, 2, 23, 59, 59, 999999)), -# ['k3', 'k4']) -# self.assertEqual(g_named_output.cached_data(config)().get_all_basket_ids_in_range('out'), ['k1', 'k2', 'k3', 'k4']) -# 
self.assertEqual(g_named_output.cached_data(config)().get_all_basket_ids_in_range('out', starttime=datetime(2020, 3, 1), endtime=datetime(2020, 3, 1, 23, 59, 59, 999999)), -# ['k1', 'k2']) -# self.assertEqual(g_named_output.cached_data(config)().get_all_basket_ids_in_range('out', starttime=datetime(2020, 3, 2), endtime=datetime(2020, 3, 2, 23, 59, 59, 999999)), -# ['k3', 'k4']) - -# def test_custom_time_fields(self): -# from csp.impl.wiring.graph import NoCachedDataException -# import numpy - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(data_timestamp_column_name='timestamp')) -# def g1() -> csp.ts[_DummyStructWithTimestamp]: -# s = csp.engine_start_time() -# return csp.curve(_DummyStructWithTimestamp, [(s + timedelta(hours=1 + i), -# _DummyStructWithTimestamp(val=i, timestamp=s + timedelta(hours=(2 * i) ** 2))) for i in range(10)]) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(data_timestamp_column_name='timestamp')) -# def g2() -> csp.Outputs(timestamp=csp.ts[datetime], values=csp.OutputBasket(Dict[str, csp.ts[int]], shape=['v1', 'v2'])): -# s = csp.engine_start_time() -# values = {} -# values['v1'] = csp.curve(int, [(s + timedelta(hours=1 + i), i) for i in range(10)]) -# values['v2'] = csp.curve(int, [(s + timedelta(hours=1 + i), i * 100) for i in range(10)]) -# t = csp.curve(datetime, [(s + timedelta(hours=1 + i), s + timedelta(hours=(2 * i) ** 2)) for i in range(10)]) -# return csp.output(timestamp=t, values=values) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# s = datetime(2021, 1, 1) -# csp.run(g1, starttime=s, endtime=timedelta(hours=100), config=config) -# csp.run(g2, starttime=s, endtime=timedelta(hours=100), config=config) -# data_files = g1.cached_data(config)().get_data_files_for_period() -# data_files2 = g2.cached_data(config)().get_data_files_for_period() -# res = csp.run(g1.cached, starttime=datetime(2021, 1, 1), endtime=datetime(2021, 1, 14, 12, 0), config=config) -# res2 = csp.run(g2.cached, starttime=datetime(2021, 1, 1), endtime=datetime(2021, 1, 14, 12, 0), config=config) -# with self.assertRaises(NoCachedDataException): -# csp.run(g1.cached, starttime=datetime(2021, 1, 1), endtime=datetime(2021, 1, 14, 12, 1), config=config) - -# self.assertEqual(list(data_files.keys()), list(data_files2.keys())) -# self.assertEqual([(k, v.val) for k, v in res[0]], res2['values[v1]']) -# all_file_time_ranges = list(data_files.keys()) -# expected_start_end = res[0][0][1].timestamp, res[0][-1][1].timestamp -# actual_start_end = all_file_time_ranges[0][0], all_file_time_ranges[-1][1] -# self.assertEqual(expected_start_end, actual_start_end) -# data_df = g1.cached_data(config)().get_data_df_for_period() -# data_df2 = g2.cached_data(config)().get_data_df_for_period() -# self.assertTrue(all((data_df.timestamp.diff().dt.total_seconds() / 3600).values[1:].astype(int) == numpy.diff(((numpy.arange(0, 10) * 2) ** 2)))) -# self.assertTrue(all(data_df.val.values == (numpy.arange(0, 10)))) -# self.assertTrue((data_df['val'] == data_df2['values']['v1']).all()) -# self.assertTrue((data_df['val'] * 100 == data_df2['values']['v2']).all()) - -# def test_cached_with_start_stop_times(self): -# @csp.graph(cache=True) -# def g() -> csp.ts[int]: -# return csp.curve(int, [(datetime(2021, 1, 1), 1), (datetime(2021, 1, 2), 2), (datetime(2021, 1, 3), 3)]) - -# @csp.graph -# def g2(csp_cache_start: object = None) -> csp.ts[int]: -# end = csp.engine_end_time() - timedelta(days=1, microseconds=-1) -# if csp_cache_start: -# cached_g = 
g.cached[csp_cache_start:end] -# else: -# cached_g = g.cached[:end] -# return csp.delay(cached_g(), timedelta(days=1)) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# res1 = csp.run(g, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) -# res2 = csp.run(g2, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1) + timedelta(days=1), config=config) -# res1_transformed = [(v1 + timedelta(days=1), v2) for (v1, v2) in res1[0]] -# self.assertEqual(res1_transformed, res2[0]) -# with self.assertRaises(NoCachedDataException): -# csp.run(g2, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1) + timedelta(days=1, microseconds=1), config=config) - -# res3 = csp.run(g2, datetime(2021, 1, 1, 1), starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1) + timedelta(days=1), config=config) -# self.assertEqual(res3[0], res2[0][1:]) - -# def test_cached_graph_not_instantiated(self): -# raise_exception = [False] - -# @csp.graph(cache=True) -# def g() -> csp.ts[int]: -# assert not raise_exception[0] -# return csp.curve(int, [(datetime(2021, 1, 1), 1), (datetime(2021, 1, 2), 2), (datetime(2021, 1, 3), 3)]) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# csp.run(g, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) -# raise_exception[0] = True -# csp.run(g, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) -# with self.assertRaises(NoCachedDataException): -# csp.run(g.cached, starttime=datetime(2021, 1, 1), endtime=timedelta(days=4, microseconds=-1), config=config) - -# def test_caching_with_struct_arguments(self): -# @csp.graph(cache=True) -# def g(value: TypedCurveGenerator.SimpleStruct) -> csp.ts[TypedCurveGenerator.SimpleStruct]: -# return csp.curve(TypedCurveGenerator.SimpleStruct, [(datetime(2021, 1, 1), value)]) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# s = TypedCurveGenerator.SimpleStruct(value1=42) -# res = csp.run(g, s, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) -# res2 = csp.run(g.cached, s, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) -# self.assertEqual(res, res2) - -# def test_caching_user_types(self): -# class OrderedDictSerializer(CacheObjectSerializer): -# def serialize_to_bytes(self, value): -# import pickle -# return pickle.dumps(value) - -# def deserialize_from_bytes(self, value): -# import pickle -# return pickle.loads(value) - -# @csp.graph(cache=True) -# def g() -> csp.ts[collections.OrderedDict]: -# return csp.curve(collections.OrderedDict, [(datetime(2021, 1, 1), collections.OrderedDict({1: 2, 3: 4}))]) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# # We don't know how to serialize ordereddict, this should raise -# with self.assertRaises(TypeError): -# res = csp.run(g, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) - -# config.cache_config.cache_serializers[collections.OrderedDict] = OrderedDictSerializer() -# res = csp.run(g, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) -# res2 = csp.run(g.cached, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) -# self.assertEqual(res, res2) -# res_df = g.cached_data(config)().get_data_df_for_period() -# 
self.assertEqual(res_df['csp_unnamed_output'].iloc[0], res[0][0][1]) - -# def test_special_character_partitioning(self): -# # Since we're using glob to locate the files on disk, there was a bug that special characters in the partition values broke the partition data -# # lookup. This test tests that it works now. - -# for split_columns_to_files in (True, False): -# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_files) - -# @csp.graph(cache=True, **graph_kwargs) -# def g(x1: str, x2: str) -> csp.ts[str]: -# return csp.curve(str, [(datetime(2021, 1, 1), x1), (datetime(2021, 1, 1) + timedelta(seconds=1), x2)]) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# x1 = "[][]" -# x2 = "*x*)(" -# res = csp.run(g, x1, x2, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) -# res2 = csp.run(g.cached, x1, x2, starttime=datetime(2021, 1, 1), endtime=timedelta(days=3, microseconds=-1), config=config) -# self.assertEqual(res, res2) -# df = g.cached_data(config)(x1, x2).get_data_df_for_period() -# self.assertEqual(df['csp_unnamed_output'].tolist(), [x1, x2]) - -# def test_cutoff_bug(self): -# """Test for bug that was there of +-1 micro second offset, that caused some stitch data to be missing -# :return: -# """ -# for split_columns_to_files in (True, False): -# if split_columns_to_files: -# cache_options = GraphCacheOptions(split_columns_to_files=True) -# else: -# cache_options = GraphCacheOptions(split_columns_to_files=True) -# cache_options.time_aggregation = TimeAggregation.MONTH - -# @csp.graph(cache=True, cache_options=cache_options) -# def g() -> csp.ts[int]: -# l = [(datetime(2021, 1, 1), 1), (datetime(2021, 1, 1, 23, 59, 59, 999999), 2), -# (datetime(2021, 1, 2), 3), (datetime(2021, 1, 2, 23, 59, 59, 999999), 4), -# (datetime(2021, 1, 3), 5), (datetime(2021, 1, 3, 23, 59, 59, 999999), 6)] -# l = [v for v in l if csp.engine_start_time() <= v[0] <= csp.engine_end_time()] -# return csp.curve(int, l) - -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# csp.run(g, starttime=datetime(2021, 1, 1), endtime=datetime(2021, 1, 1, 23, 59, 59, 999999), config=config) -# csp.run(g, starttime=datetime(2021, 1, 1, 12), endtime=datetime(2021, 1, 3, 23, 59, 59, 999999), config=config) -# self.assertEqual(g.cached_data(config)().get_data_df_for_period()['csp_unnamed_output'].tolist(), [1, 2, 3, 4, 5, 6]) - -# def test_scalar_flat_basket_loading(self): -# @csp.graph(cache=True) -# def simple_cached() -> csp.Outputs(i=csp.OutputBasket(Dict[str, csp.ts[int]], shape=['V1', 'V2']), -# s=csp.OutputBasket(Dict[str, csp.ts[str]], shape=['V3', 'V4'])): - -# i_v1 = csp.curve(int, [(timedelta(hours=10), 1), (timedelta(hours=10), 1), (timedelta(hours=30), 1)]) -# i_v2 = csp.curve(int, [(timedelta(hours=10), 10), (timedelta(hours=20), 11)]) -# s_v3 = csp.curve(str, [(timedelta(hours=30), "val1")]) -# s_v4 = csp.curve(str, [(timedelta(hours=10), "val2"), (timedelta(hours=20), "val3")]) -# return csp.output(i={'V1': i_v1, 'V2': i_v2}, s={'V3': s_v3, 'V4': s_v4}) - -# @csp.graph(cache=True) -# def simple_cached_unnamed() -> csp.OutputBasket(Dict[str, csp.ts[int]], shape=['V1', 'V2']): - -# return csp.output(simple_cached().i) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(hours=30) -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# csp.run(simple_cached, starttime=start_time, endtime=end_time, config=config) -# csp.run(simple_cached_unnamed, starttime=start_time, 
endtime=end_time, config=config) -# df_ref_full = simple_cached.cached_data(config)().get_data_df_for_period().stack(dropna=False).reset_index().drop(columns=['level_0']).rename(columns={'level_1': 'symbol'}) -# df_ref_full['csp_timestamp'] = df_ref_full['csp_timestamp'].ffill().dt.tz_localize(None) -# df_ref_full = df_ref_full[df_ref_full.symbol.str.len() > 0].reset_index(drop=True) - -# for start_dt, end_dt in ((None, None), -# (timedelta(hours=10), None), -# (timedelta(hours=10), timedelta(hours=10, microseconds=1)), -# (timedelta(hours=10), timedelta(hours=20)), -# (timedelta(hours=10, microseconds=1), timedelta(hours=20)), -# (timedelta(hours=10, microseconds=1), None), -# (timedelta(hours=10, microseconds=1), timedelta(hours=10, microseconds=2)), -# (timedelta(hours=10, microseconds=1), timedelta(hours=30))): - -# cur_start = start_time + start_dt if start_dt else None -# cur_end = start_time + end_dt if end_dt else None -# mask = df_ref_full.index >= 0 -# if cur_start: -# mask &= df_ref_full.csp_timestamp >= cur_start -# if cur_end: -# mask &= df_ref_full.csp_timestamp <= cur_end -# df_ref = df_ref_full[mask] - -# df_ref_i = df_ref[['csp_timestamp', 'symbol', 'i']][~df_ref.i.isna()].reset_index(drop=True) -# df_ref_s = df_ref[['csp_timestamp', 'symbol', 's']][~df_ref.s.isna()].reset_index(drop=True) - -# i_df_flat = simple_cached.cached_data(config)().get_flat_basket_df_for_period(basket_field_name='i', symbol_column='symbol', -# starttime=cur_start, endtime=cur_end) -# s_df_flat = simple_cached.cached_data(config)().get_flat_basket_df_for_period(basket_field_name='s', symbol_column='symbol', -# starttime=cur_start, endtime=cur_end) -# unnamed_flat = simple_cached_unnamed.cached_data(config)().get_flat_basket_df_for_period(symbol_column='symbol', -# starttime=cur_start, endtime=cur_end) -# self.assertTrue((i_df_flat == df_ref_i).all().all()) -# self.assertTrue((s_df_flat == df_ref_s).all().all()) - -# # We can't rename columns when None is returned so we have to add this check -# if unnamed_flat is None: -# self.assertTrue(len(df_ref_i) == 0) -# else: -# self.assertTrue((unnamed_flat.rename(columns={'csp_unnamed_output': 'i'}) == df_ref_i).all().all()) - -# def test_struct_flat_basket_loading(self): -# @csp.graph(cache=True) -# def simple_cached() -> csp.Outputs(ret=csp.OutputBasket(Dict[str, csp.ts[TypedCurveGenerator.SimpleSubStruct]], shape=['V1', 'V2'])): - -# i_v1 = csp.curve(int, [(timedelta(hours=10), 1), (timedelta(hours=10), 1), (timedelta(hours=30), 1)]) -# i_v2 = csp.curve(int, [(timedelta(hours=10), 10), (timedelta(hours=20), 11)]) -# s_v3 = csp.curve(str, [(timedelta(hours=30), "val1")]) -# s_v4 = csp.curve(str, [(timedelta(hours=10), "val2"), (timedelta(hours=20), "val3")]) -# res = {} -# res['V1'] = TypedCurveGenerator.SimpleSubStruct.fromts(value1=i_v1, value3=s_v3) -# res['V2'] = TypedCurveGenerator.SimpleSubStruct.fromts(value1=i_v2, value3=s_v4) -# return csp.output(ret=res) - -# @csp.graph(cache=True) -# def simple_cached_unnamed() -> csp.OutputBasket(Dict[str, csp.ts[TypedCurveGenerator.SimpleSubStruct]], shape=['V1', 'V2']): -# return csp.output(simple_cached().ret) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(hours=30) -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# csp.run(simple_cached, starttime=start_time, endtime=end_time, config=config) -# csp.run(simple_cached_unnamed, starttime=start_time, endtime=end_time, config=config) -# df_ref_full = 
simple_cached.cached_data(config)().get_data_df_for_period().stack(dropna=False).reset_index().drop(columns=['level_0']).rename(columns={'level_1': 'symbol'}) -# df_ref_full['csp_timestamp'] = df_ref_full['csp_timestamp'].ffill().dt.tz_localize(None) -# mask = (df_ref_full.symbol.str.len() > 0) & (~df_ref_full['ret.value1'].isna() | ~df_ref_full['ret.value2'].isna() | ~df_ref_full['ret.value2'].isna()) -# df_ref_full = df_ref_full[mask].reset_index(drop=True) -# df_ref_full = df_ref_full[['csp_timestamp', 'symbol', 'ret.value1', 'ret.value2', 'ret.value3']] - -# for start_dt, end_dt in ((None, None), -# (timedelta(hours=10), None), -# (timedelta(hours=10), timedelta(hours=10, microseconds=1)), -# (timedelta(hours=10), timedelta(hours=20)), -# (timedelta(hours=10, microseconds=1), timedelta(hours=20)), -# (timedelta(hours=10, microseconds=1), None), -# (timedelta(hours=10, microseconds=1), timedelta(hours=10, microseconds=2)), -# (timedelta(hours=10, microseconds=1), timedelta(hours=30))): - -# cur_start = start_time + start_dt if start_dt else None -# cur_end = start_time + end_dt if end_dt else None -# mask = df_ref_full.index >= 0 -# if cur_start: -# mask &= df_ref_full.csp_timestamp >= cur_start -# if cur_end: -# mask &= df_ref_full.csp_timestamp <= cur_end -# df_ref = df_ref_full[mask].fillna(-999).reset_index(drop=True) - -# df_flat = simple_cached.cached_data(config)().get_flat_basket_df_for_period(basket_field_name='ret', symbol_column='symbol', -# starttime=cur_start, endtime=cur_end) - -# unnamed_flat = simple_cached_unnamed.cached_data(config)().get_flat_basket_df_for_period(symbol_column='symbol', -# starttime=cur_start, endtime=cur_end) -# if unnamed_flat is not None: -# unnamed_flat_normalized = unnamed_flat.rename(columns=dict(zip(unnamed_flat.columns, df_flat.columns))) -# if df_flat is None: -# self.assertTrue(len(df_ref) == 0) -# self.assertTrue(unnamed_flat is None) -# else: -# self.assertTrue((df_flat.fillna(-999) == df_ref.fillna(-999)).all().all()) -# self.assertTrue((df_flat.fillna(-999) == unnamed_flat_normalized.fillna(-999)).all().all()) - -# for c in TypedCurveGenerator.SimpleSubStruct.metadata().keys(): -# df_flat_single_col = simple_cached.cached_data(config)().get_flat_basket_df_for_period(basket_field_name='ret', symbol_column='symbol', struct_fields=[c], -# starttime=cur_start, endtime=cur_end) -# if df_flat_single_col is None: -# self.assertTrue(len(df_ref) == 0) -# continue -# df_flat_single_col_ref = df_ref[df_flat_single_col.columns] -# self.assertTrue((df_flat_single_col.fillna(-999) == df_flat_single_col_ref.fillna(-999)).all().all()) - -# def test_simple_time_shift(self): -# @csp.graph(cache=True) -# def simple_cached() -> csp.ts[int]: - -# return csp.curve(int, [(timedelta(hours=i), i) for i in range(72)]) - -# @csp.graph -# def cached_data_shifted(shift: timedelta) -> csp.ts[int]: -# return simple_cached.cached.shifted(csp_timestamp_shift=shift)() - -# def to_df(res): -# return pandas.DataFrame({'timestamp': res[0][0], 'value': res[0][1]}) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(hours=71) -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# csp.run(simple_cached, starttime=start_time, endtime=end_time, config=config) -# with self.assertRaises(NoCachedDataException): -# csp.run(cached_data_shifted, timedelta(minutes=1), starttime=start_time, endtime=end_time, config=config, output_numpy=True) - -# ref_df = to_df(csp.run(cached_data_shifted, timedelta(), starttime=start_time, endtime=end_time, 
config=config, output_numpy=True)) -# td12 = timedelta(hours=12) - -# shifted_df = ref_df.shift(12).iloc[12:, :].reset_index(drop=True) -# shifted_df['timestamp'] += td12 -# res_df1 = to_df(csp.run(cached_data_shifted, td12, starttime=start_time + td12, endtime=end_time, config=config, output_numpy=True)) -# res_df2 = to_df(csp.run(cached_data_shifted, td12, starttime=start_time + td12, endtime=end_time + td12, config=config, output_numpy=True)) -# self.assertTrue((shifted_df == res_df1).all().all()) -# self.assertTrue((ref_df.value == res_df2.value).all()) - -# def test_struct_basket_time_shift(self): -# @csp.graph(cache=True) -# def struct_cached() -> csp.OutputBasket(Dict[str, csp.ts[TypedCurveGenerator.SimpleStruct]], shape=['A', 'B']): - -# generator = TypedCurveGenerator(period=timedelta(hours=1)) -# return { -# 'A': generator.gen_transformed_curve(TypedCurveGenerator.SimpleStruct, start_value=0, num_cycles=71, increment=1, duplicate_timestamp_indices=[11], skip_indices=[3, 15]), -# 'B': generator.gen_transformed_curve(TypedCurveGenerator.SimpleStruct, start_value=100, num_cycles=71, increment=1, duplicate_timestamp_indices=[10, 11, 12, 13, 14], -# skip_indices=[3, 15]) -# } - -# @csp.node -# def dict_builder(x: {str: csp.ts[TypedCurveGenerator.SimpleStruct]}) -> csp.ts[object]: -# res = {'timestamp': csp.now()} -# ticked_items = {k: v for k, v in x.tickeditems()} -# for k in x.keys(): -# res[k] = ticked_items.get(k) -# return res - -# @csp.graph -# def cached_data_shifted(shift: timedelta) -> csp.ts[object]: -# return dict_builder(struct_cached.cached.shifted(csp_timestamp_shift=shift)()) - -# def to_df(res): -# keys = list(res[0][1][0].keys()) -# values = [[v for k, v in d.items()] for d in res[0][1]] -# return pandas.DataFrame(dict(zip(keys, zip(*values)))) - -# start_time = datetime(2021, 1, 1) -# end_time = start_time + timedelta(hours=71) -# with _GraphTempCacheFolderConfig(allow_overwrite=True) as config: -# csp.run(struct_cached, starttime=start_time, endtime=end_time, config=config) -# with self.assertRaises(NoCachedDataException): -# csp.run(cached_data_shifted, timedelta(minutes=1), starttime=start_time, endtime=end_time, config=config, output_numpy=True) - -# ref_df = to_df(csp.run(cached_data_shifted, timedelta(), starttime=start_time, endtime=end_time, config=config, output_numpy=True)) -# td12 = timedelta(hours=12) - -# ref_df1 = ref_df.copy() -# ref_df1.timestamp += td12 -# ref_df1 = ref_df1[ref_df1.timestamp.between(start_time + td12, end_time)].reset_index(drop=True) -# res_df1 = to_df(csp.run(cached_data_shifted, td12, starttime=start_time + td12, endtime=end_time, config=config, output_numpy=True)) -# self.assertTrue((res_df1.fillna(-1) == ref_df1.fillna(-1)).all().all()) - -# ref_df2 = ref_df.copy() -# ref_df2.timestamp += td12 -# ref_df2 = ref_df2[ref_df2.timestamp.between(start_time + td12, end_time + td12)].reset_index(drop=True) -# res_df2 = to_df(csp.run(cached_data_shifted, td12, starttime=start_time + td12, endtime=end_time + td12, config=config, output_numpy=True)) -# self.assertTrue((ref_df2.fillna(-1) == res_df2.fillna(-1)).all().all()) - -# ref_df2 = ref_df.copy() -# ref_df2.timestamp -= td12 -# ref_df2 = ref_df2[ref_df2.timestamp.between(start_time - td12, end_time - td12)].reset_index(drop=True) -# res_df2 = to_df(csp.run(cached_data_shifted, -td12, starttime=start_time - td12, endtime=end_time - td12, config=config, output_numpy=True)) -# self.assertTrue((ref_df2.fillna(-1) == res_df2.fillna(-1)).all().all()) - -# def 
test_caching_separate_folder(self): -# @csp.graph(cache=True) -# def g(name: str) -> csp.ts[float]: -# if name == 'a': -# return csp.curve(float, [(timedelta(seconds=i), i) for i in range(10)]) -# else: -# return csp.curve(float, [(timedelta(seconds=i), i * 2) for i in range(10)]) - -# with _GraphTempCacheFolderConfig() as config: -# with _GraphTempCacheFolderConfig() as config2: -# res1 = csp.run(g, 'a', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=30), config=config) -# # Write to a different folder data for a different day and key -# res2 = csp.run(g, 'a', starttime=datetime(2020, 1, 2), endtime=timedelta(seconds=30), config=config2) -# res2b = csp.run(g, 'b', starttime=datetime(2020, 1, 2), endtime=timedelta(seconds=30), config=config2) - -# config3 = config.copy() -# files1 = g.cached_data(config3)('a').get_data_files_for_period() -# self.assertEqual(len(files1), 1) - -# csp.run(g.cached, 'a', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=30), config=config) -# with self.assertRaises(NoCachedDataException): -# csp.run(g.cached, 'a', starttime=datetime(2020, 1, 2), endtime=timedelta(seconds=30), config=config) -# with self.assertRaises(NoCachedDataException): -# csp.run(g.cached, 'b', starttime=datetime(2020, 1, 2), endtime=timedelta(seconds=30), config=config) - -# config3.cache_config.read_folders = [config2.cache_config.data_folder] -# files2 = g.cached_data(config3)('a').get_data_files_for_period(missing_range_handler=lambda *args, **kwargs: True) -# self.assertEqual(len(files2), 2) - -# res3 = csp.run(g, 'a', starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=30), config=config2) -# res4 = csp.run(g, 'a', starttime=datetime(2020, 1, 2), endtime=timedelta(seconds=30), config=config2) -# res4b = csp.run(g, 'b', starttime=datetime(2020, 1, 2), endtime=timedelta(seconds=30), config=config2) -# self.assertEqual(res1, res3) -# self.assertEqual(res2, res4) -# self.assertEqual(res2b, res4b) - -# def test_cache_invalidation(self): -# for split_columns_to_file in (True, False): -# with _GraphTempCacheFolderConfig() as config: -# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_file) - -# @csp.graph(cache=True, **graph_kwargs) -# def my_graph(val: str) -> csp.ts[float]: -# return csp.curve(float, [(timedelta(days=i), float(i)) for i in range(5)]) - -# start1 = datetime(2021, 1, 1) -# end1 = start1 + timedelta(days=5, microseconds=-1) - -# self.assertTrue(my_graph.cached_data(config) is None) -# csp.run(my_graph, 'val1', starttime=start1, endtime=end1, config=config) -# # We should be able to invalidate cache when no cached data exists yet -# my_graph.cached_data(config)('val2').invalidate_cache() -# csp.run(my_graph, 'val2', starttime=start1, endtime=end1, config=config) -# cached_data1 = my_graph.cached_data(config)('val1').get_data_df_for_period() -# cached_data2 = my_graph.cached_data(config)('val2').get_data_df_for_period() -# self.assertTrue((cached_data1 == cached_data2).all().all()) -# self.assertEqual(len(cached_data1), 5) -# my_graph.cached_data(config)('val2').invalidate_cache(start1 + timedelta(days=1), end1) -# cached_data2_after_invalidation = my_graph.cached_data(config)('val2').get_data_df_for_period() -# self.assertTrue((cached_data1.head(1) == cached_data2_after_invalidation).all().all()) -# with self.assertRaises(NoCachedDataException): -# csp.run(my_graph.cached, 'val2', starttime=start1, endtime=end1, config=config) -# # this should run fine, we still have data -# csp.run(my_graph.cached, 'val2', starttime=start1, 
endtime=start1 + timedelta(days=1, microseconds=-1), config=config) -# my_graph.cached_data(config)('val2').invalidate_cache() -# # now we have no data -# with self.assertRaises(NoCachedDataException): -# csp.run(my_graph.cached, 'val2', starttime=start1, endtime=start1 + timedelta(days=1, microseconds=-1), config=config) -# my_graph.cached_data(config)('val1').invalidate_cache() -# self.assertTrue(my_graph.cached_data(config)('val1').get_data_df_for_period() is None) -# # We should still be able to invalidate -# my_graph.cached_data(config)('val1').invalidate_cache() -# # We should have no data in the data folder -# self.assertFalse(os.listdir(os.path.join(my_graph.cached_data(config)._dataset.data_paths.root_folder, 'data'))) - -# def test_controlled_cache(self): -# for default_cache_enabled in (True, False): -# with _GraphTempCacheFolderConfig() as config: -# @csp.graph(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, default_cache_enabled=default_cache_enabled)) -# def graph_unnamed_output1() -> csp.ts[float]: -# csp.set_cache_enable_ts(csp.curve(bool, [(timedelta(seconds=5), True), (timedelta(seconds=6.1), False), (timedelta(seconds=8), True)])) - -# return csp.output(csp.curve(float, [(timedelta(seconds=i), float(i)) for i in range(10)])) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, default_cache_enabled=default_cache_enabled)) -# def graph_unnamed_output2() -> csp.ts[float]: -# csp.set_cache_enable_ts(csp.curve(bool, [(timedelta(seconds=5), True), (timedelta(seconds=6.1), False), (timedelta(seconds=8), True)])) - -# return (csp.curve(float, [(timedelta(seconds=i), float(i)) for i in range(10)])) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, default_cache_enabled=default_cache_enabled)) -# def graph_single_named_output() -> csp.Outputs(res1=csp.ts[float]): -# csp.set_cache_enable_ts(csp.curve(bool, [(timedelta(seconds=5), True), (timedelta(seconds=6.1), False), (timedelta(seconds=8), True)])) - -# return csp.output(res1=csp.curve(float, [(timedelta(seconds=i), float(i)) for i in range(10)])) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, default_cache_enabled=default_cache_enabled)) -# def graph_multiple_outputs() -> csp.Outputs(res1=csp.ts[float], res2=csp.OutputBasket(Dict[str, csp.ts[float]], shape=['value'])): -# csp.set_cache_enable_ts(csp.curve(bool, [(timedelta(seconds=5), True), (timedelta(seconds=6.1), False), (timedelta(seconds=8), True)])) - -# res = csp.curve(float, [(timedelta(seconds=i), float(i)) for i in range(10)]) -# return csp.output(res1=res, res2={'value': res}) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, default_cache_enabled=default_cache_enabled)) -# def main_graph_cached_named_output() -> csp.Outputs(res1=csp.ts[float]): -# csp.set_cache_enable_ts(csp.curve(bool, [(timedelta(seconds=5), True), (timedelta(seconds=6.1), False), (timedelta(seconds=8), True)])) - -# return csp.output(csp.curve(float, [(timedelta(seconds=i), float(i)) for i in range(10)])) - -# @csp.node(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, default_cache_enabled=default_cache_enabled)) -# def node_unnamed_output() -> csp.ts[float]: -# with csp.alarms(): -# a_enable = csp.alarm( bool ) -# a_value = csp.alarm( float ) -# with csp.start(): -# csp.schedule_alarm(a_enable, timedelta(seconds=5), True) -# csp.schedule_alarm(a_enable, timedelta(seconds=6.1), False) -# csp.schedule_alarm(a_enable, 
timedelta(seconds=8), True) -# csp.schedule_alarm(a_value, timedelta(), 0) -# if csp.ticked(a_enable): -# csp.enable_cache(a_enable) -# if csp.ticked(a_value): -# if a_value < 9: -# csp.schedule_alarm(a_value, timedelta(seconds=1), a_value + 1) -# if a_value == 6: -# return a_value -# elif a_value == 8: -# return csp.output(a_value) -# raise NotImplementedError() -# else: -# csp.output(a_value) - -# @csp.node(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, default_cache_enabled=default_cache_enabled)) -# def node_single_named_output() -> csp.Outputs(res=csp.ts[float]): -# with csp.alarms(): -# a_enable = csp.alarm( bool ) -# a_value = csp.alarm( float ) -# with csp.start(): -# csp.schedule_alarm(a_enable, timedelta(seconds=5), True) -# csp.schedule_alarm(a_enable, timedelta(seconds=6.1), False) -# csp.schedule_alarm(a_enable, timedelta(seconds=8), True) -# csp.schedule_alarm(a_value, timedelta(), 0) -# if csp.ticked(a_enable): -# csp.enable_cache(a_enable) -# if csp.ticked(a_value): -# if a_value < 9: -# csp.schedule_alarm(a_value, timedelta(seconds=1), a_value + 1) -# if a_value == 6: -# return a_value -# elif a_value == 8: -# return csp.output(res=a_value) -# raise NotImplementedError() -# else: -# csp.output(res=a_value) - -# @csp.node(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, default_cache_enabled=default_cache_enabled)) -# def node_multiple_outputs() -> csp.Outputs(res1=csp.ts[float], res2=csp.OutputBasket(Dict[str, csp.ts[float]], shape=['value'])): -# with csp.alarms(): -# a_enable = csp.alarm( bool ) -# a_value = csp.alarm( float ) -# with csp.start(): -# csp.schedule_alarm(a_enable, timedelta(seconds=5), True) -# csp.schedule_alarm(a_enable, timedelta(seconds=6.1), False) -# csp.schedule_alarm(a_enable, timedelta(seconds=8), True) -# csp.schedule_alarm(a_value, timedelta(), 0) -# if csp.ticked(a_enable): -# csp.enable_cache(a_enable) -# if csp.ticked(a_value): -# if a_value < 9: -# csp.schedule_alarm(a_value, timedelta(seconds=1), a_value + 1) -# csp.output(res2={'value': a_value}) -# return csp.output(res1=a_value) -# raise NotImplementedError() - -# starttime = datetime(2020, 1, 1, 9, 29) -# endtime = starttime + timedelta(minutes=20) - -# results = [] -# for g in [graph_unnamed_output1, graph_unnamed_output2, graph_single_named_output, graph_multiple_outputs, node_unnamed_output, -# node_single_named_output, node_multiple_outputs]: -# csp.run(g, starttime=starttime, endtime=endtime, config=config) -# results.append(g.cached_data(config)().get_data_df_for_period(missing_range_handler=lambda *a, **ka: True)) -# for res in results: -# if default_cache_enabled: -# self.assertEqual(len(res), 9) -# else: -# self.assertEqual(len(res), 4) -# combined_df = pandas.concat(results[:1] + [res.drop(columns=['csp_timestamp']) for res in results], axis=1) -# self.assertEqual(list(combined_df.columns), -# ['csp_timestamp', 'csp_unnamed_output', 'csp_unnamed_output', 'csp_unnamed_output', 'res1', -# ('res1', ''), ('res2', 'value'), 'csp_unnamed_output', 'res', ('res1', ''), ('res2', 'value')]) -# self.assertTrue((combined_df.iloc[:, 1:].diff(axis=1).iloc[:, 1:] == 0).all().all()) - -# def test_controlled_cache_never_set(self): -# """ -# Test that if we never output the controolled set control, we don't get any errors -# :return: -# """ -# for default_cache_enabled in (True, False): -# with _GraphTempCacheFolderConfig() as config: -# @csp.graph(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, 
default_cache_enabled=default_cache_enabled)) -# def graph_unnamed_output1() -> csp.ts[float]: -# return csp.output(csp.null_ts(float)) - -# @csp.node(cache=True, cache_options=GraphCacheOptions(controlled_cache=True, default_cache_enabled=default_cache_enabled)) -# def node_unnamed_output() -> csp.ts[float]: -# with csp.alarms(): -# a_enable = csp.alarm( bool ) -# a_value = csp.alarm( float ) -# if False: -# csp.output(a_value) - -# starttime = datetime(2020, 1, 1, 9, 29) -# endtime = starttime + timedelta(minutes=20) - -# csp.run(graph_unnamed_output1, starttime=starttime, endtime=endtime, config=config) - -# def test_controlled_cache_bug(self): -# """ -# There was a bug when we run across multiple aggregation periods that the cache enabled was not handled properly, this test was written to reproduce the bug -# and fix it. Here we have aggregation period of 1 month but we are running across 2 months. -# :return: -# """ -# @csp.graph( -# cache=True, -# cache_options=GraphCacheOptions( -# controlled_cache=True, -# time_aggregation=TimeAggregation.MONTH)) -# def cached_g() -> csp.ts[int]: - -# csp.set_cache_enable_ts(csp.curve(bool, [(datetime(2004, 7, 1), True), (datetime(2004, 8, 2), False)])) -# return csp.curve(int, [(datetime(2004, 6, 30), 1), -# (datetime(2004, 7, 1), 2), -# (datetime(2004, 7, 2), 3), -# (datetime(2004, 8, 1), 4), -# (datetime(2004, 8, 1, 1), 5), -# (datetime(2004, 8, 2), 6), -# ]) - -# @csp.graph( -# cache=True, -# cache_options=GraphCacheOptions( -# controlled_cache=True, -# time_aggregation=TimeAggregation.MONTH)) -# def cached_g_struct() -> csp.ts[_DummyStructWithTimestamp]: - -# csp.set_cache_enable_ts(csp.curve(bool, [(datetime(2004, 7, 1), True), (datetime(2004, 8, 2), False)])) -# return _DummyStructWithTimestamp.fromts(val=cached_g()) - -# @csp.graph( -# cache=True, -# cache_options=GraphCacheOptions( -# controlled_cache=True, -# time_aggregation=TimeAggregation.MONTH)) -# def cached_g_basket() -> csp.OutputBasket(Dict[str, csp.ts[int]], shape=['a', 'b']): - -# csp.set_cache_enable_ts(csp.curve(bool, [(datetime(2004, 7, 1), True), (datetime(2004, 8, 2), False)])) -# return {'a': csp.curve(int, [(datetime(2004, 6, 30), 1), -# (datetime(2004, 7, 1), 2), -# (datetime(2004, 7, 2), 3), -# (datetime(2004, 8, 1), 4), -# (datetime(2004, 8, 1, 1), 5), -# (datetime(2004, 8, 2), 6), -# ]), -# 'b': csp.curve(int, [(datetime(2004, 6, 30), 1), (datetime(2004, 8, 1, 1), 5), ]) -# } - -# @csp.graph( -# cache=True, -# cache_options=GraphCacheOptions( -# controlled_cache=True, -# time_aggregation=TimeAggregation.MONTH)) -# def cached_g_basket_struct() -> csp.OutputBasket(Dict[str, csp.ts[_DummyStructWithTimestamp]], shape=['a', 'b']): - -# csp.set_cache_enable_ts(csp.curve(bool, [(datetime(2004, 7, 1), True), (datetime(2004, 8, 2), False)])) -# aux = cached_g_basket() -# return {'a': _DummyStructWithTimestamp.fromts(val=aux['a']), -# 'b': _DummyStructWithTimestamp.fromts(val=aux['b'])} - -# def g(): -# cached_g() -# cached_g_struct() -# cached_g_basket() -# cached_g_basket_struct() - -# with _GraphTempCacheFolderConfig() as config: -# starttime = datetime(2004, 6, 30) -# endtime = datetime(2004, 8, 2, 23, 59, 59, 999999) -# csp.run(g, starttime=starttime, endtime=endtime, config=config) -# df = cached_g.cached_data(config)().get_data_df_for_period() -# self.assertEqual(df.csp_unnamed_output.tolist(), [2, 3, 4, 5]) -# struct_df = cached_g_struct.cached_data(config)().get_data_df_for_period() -# self.assertEqual(struct_df.val.tolist(), [2, 3, 4, 5]) -# basket_df = 
cached_g_basket.cached_data(config)().get_data_df_for_period() -# self.assertEqual(basket_df.a.tolist(), [2, 3, 4, 5]) -# self.assertEqual(basket_df.b.fillna(-1).tolist(), [-1, -1, -1, 5]) -# struct_basket_df = cached_g_basket_struct.cached_data(config)().get_data_df_for_period() -# self.assertEqual(struct_basket_df.val.a.tolist(), [2, 3, 4, 5]) -# self.assertEqual(struct_basket_df.val.b.fillna(-1).tolist(), [-1, -1, -1, 5]) - -# def test_numpy_1d_array_caching(self): -# for split_columns_to_files in (True, False): -# cache_args = self._get_default_graph_caching_kwargs(split_columns_to_files=split_columns_to_files) - -# for typ in (int, bool, float, str): -# a1 = numpy.array([1, 2, 3, 4, 0], dtype=typ) -# a2 = numpy.array([[1, 2], [3324, 4]], dtype=typ)[:, 0] -# self.assertTrue(a1.flags.c_contiguous) -# self.assertFalse(a2.flags.c_contiguous) - -# @csp.graph(cache=True, **cache_args) -# def g1() -> csp.ts[csp.typing.Numpy1DArray[typ]]: -# return csp.flatten([csp.const(a1), csp.const(a2)]) - -# @csp.graph(cache=True, cache_options=GraphCacheOptions(parquet_output_config=ParquetOutputConfig(batch_size=3), -# split_columns_to_files=split_columns_to_files)) -# def g2() -> csp.ts[csp.typing.Numpy1DArray[typ]]: -# return csp.flatten([csp.const(a1), csp.const(a2)]) - -# @csp.node(cache=True, **cache_args) -# def n1() -> csp.Outputs(arr1=csp.ts[csp.typing.Numpy1DArray[typ]], arr2=csp.ts[numpy.ndarray]): -# with csp.alarms(): -# a_values1 = csp.alarm( csp.typing.Numpy1DArray ) -# a_values2 = csp.alarm( numpy.ndarray ) - -# with csp.start(): -# csp.schedule_alarm(a_values1, timedelta(0), a1) -# csp.schedule_alarm(a_values1, timedelta(seconds=1), a2) -# csp.schedule_alarm(a_values2, timedelta(0), numpy.array([numpy.nan, 1])) -# csp.schedule_alarm(a_values2, timedelta(0), numpy.array([2, numpy.nan, 3])) - -# if csp.ticked(a_values1): -# csp.output(arr1=a_values1) -# if csp.ticked(a_values2): -# csp.output(arr2=a_values2) - -# def verify_equal_array(expected_list, result): -# res_list = [v for t, v in result[0]] -# self.assertEqual(len(expected_list), len(res_list)) -# for e, r in zip(expected_list, res_list): -# self.assertTrue((e == r).all()) - -# def verify_n1_result(expected_list, result): -# verify_equal_array(expected_list, {0: result['arr1']}) -# arr2_values = [v for _, v in result['arr2']] -# expected_arr2_values = numpy.array([numpy.array([numpy.nan, 1.]), numpy.array([2., numpy.nan, 3.])], dtype=object) -# self.assertEqual(len(arr2_values), len(expected_arr2_values)) -# for v1, v2 in zip(arr2_values, expected_arr2_values): -# self.assertTrue(((v1 == v2) | (numpy.isnan(v1) & (numpy.isnan(v1) == numpy.isnan(v1)))).all()) - -# with _GraphTempCacheFolderConfig() as config: -# starttime = datetime(2020, 1, 1, 9, 29) -# endtime = starttime + timedelta(minutes=20) -# expected_list = [a1, a2] - -# res = csp.run(g1, starttime=starttime, endtime=endtime, config=config) -# verify_equal_array(expected_list, res) -# res = csp.run(g1.cached, starttime=starttime, endtime=endtime, config=config) -# verify_equal_array(expected_list, res) -# res = csp.run(n1, starttime=starttime, endtime=endtime, config=config) -# verify_n1_result(expected_list, res) -# res = csp.run(n1.cached, starttime=starttime, endtime=endtime, config=config) -# verify_n1_result(expected_list, res) -# res = csp.run(g2, starttime=starttime, endtime=endtime, config=config) -# verify_equal_array(expected_list, res) -# res = csp.run(g2.cached, starttime=starttime, endtime=endtime, config=config) -# verify_equal_array(expected_list, res) - -# 
def test_numpy_wrong_type_errors(self): -# @csp.graph(cache=True) -# def g1() -> csp.ts[csp.typing.Numpy1DArray[int]]: -# return csp.const(numpy.zeros(1, dtype=float)) - -# @csp.graph(cache=True) -# def g2() -> csp.ts[csp.typing.Numpy1DArray[object]]: -# return csp.const(numpy.zeros(1, dtype=object)) - -# @csp.node(cache=True) -# def n1() -> csp.ts[csp.typing.Numpy1DArray[int]]: -# with csp.alarms(): -# a_out = csp.alarm( bool ) -# with csp.start(): -# csp.schedule_alarm(a_out, timedelta(), True) -# if csp.ticked(a_out): -# return numpy.zeros(1, dtype=float) - -# with _GraphTempCacheFolderConfig() as config: -# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("In function g1: Expected ts[csp.typing.Numpy1DArray[int]] for return value, got ts[csp.typing.Numpy1DArray[float]]")): -# csp.run(g1, starttime=datetime(2020, 1, 1), endtime=timedelta(minutes=20), config=config) - -# with _GraphTempCacheFolderConfig() as config: -# with self.assertRaisesRegex(TypeError, re.escape("Unsupported array value type when writing to parquet:DIALECT_GENERIC")): -# csp.run(g2, starttime=datetime(2020, 1, 1), endtime=timedelta(minutes=20), config=config) - -# with _GraphTempCacheFolderConfig() as config: -# with self.assertRaisesRegex(TypeError, re.escape("Expected array of type dtype('int64') got dtype('float64')")): -# csp.run(n1, starttime=datetime(2020, 1, 1), endtime=timedelta(minutes=20), config=config) - -# def test_basket_array_caching(self): -# @csp.graph(cache=True) -# def g1() -> csp.OutputBasket(Dict[str, csp.ts[csp.typing.Numpy1DArray[int]]], shape=['a']): -# a = numpy.zeros(3, dtype=int) -# return { -# 'a': csp.const(a) -# } - -# with _GraphTempCacheFolderConfig() as config: -# starttime = datetime(2020, 1, 1, 9, 29) -# endtime = starttime + timedelta(minutes=20) - -# with self.assertRaisesRegex(NotImplementedError, re.escape('Writing of baskets with array values is not supported')): -# res = csp.run(g1, starttime=starttime, endtime=endtime, config=config) - -# def test_multi_dimensional_array_caching(self): -# a1 = numpy.array([1, 2, 3, 4, 0], dtype=float) -# a2 = numpy.array([[1, 2], [3, 4]], dtype=float) -# expected_df = pandas.DataFrame.from_dict({'csp_timestamp': [pytz.utc.localize(datetime(2020, 1, 1, 9, 29))] * 2, 'csp_unnamed_output': [a1, a2]}) -# for split_columns_to_files in (True, False): -# cache_args = self._get_default_graph_caching_kwargs(split_columns_to_files=split_columns_to_files) - -# @csp.graph(cache=True, **cache_args) -# def g1() -> csp.ts[csp.typing.NumpyNDArray[float]]: -# return csp.flatten([csp.const(a1), csp.const(a2)]) - -# @csp.graph(cache=True, **cache_args) -# def g2() -> csp.Outputs(res=csp.ts[csp.typing.NumpyNDArray[float]]): -# return csp.output(res=g1()) - -# with _GraphTempCacheFolderConfig() as config: -# starttime = datetime(2020, 1, 1, 9, 29) -# endtime = starttime + timedelta(minutes=20) -# csp.run(g2, starttime=starttime, endtime=endtime, config=config) -# res_cached = csp.run(g1.cached, starttime=starttime, endtime=endtime, config=config) -# df = g1.cached_data(config)().get_data_df_for_period() -# df2 = g2.cached_data(config)().get_data_df_for_period() -# self.assertEqual(len(df), len(expected_df)) -# self.assertTrue((df.csp_timestamp == expected_df.csp_timestamp).all()) -# self.assertTrue(all([(v1 == v2).all() for v1, v2 in zip(df['csp_unnamed_output'], expected_df['csp_unnamed_output'])])) -# cached_values = list(zip(*res_cached[0]))[1] -# self.assertTrue(all([(v1 == v2).all() for v1, v2 in zip(cached_values, 
expected_df['csp_unnamed_output'])])) -# # We need to check the named column as well. -# self.assertEqual(len(df2), len(expected_df)) -# self.assertTrue(all([(v1 == v2).all() for v1, v2 in zip(df2['res'], expected_df['csp_unnamed_output'])])) - -# def test_read_folder_data_load_as_df(self): -# @csp.graph(cache=True) -# def g1() -> csp.ts[float]: -# return csp.const(42.0) - -# with _GraphTempCacheFolderConfig() as config1: -# with _GraphTempCacheFolderConfig() as config2: -# starttime = datetime(2020, 1, 1, 9, 29) -# endtime = starttime + timedelta(minutes=20) -# csp.run(g1, starttime=starttime, endtime=endtime, config=config1) -# res1_cached = g1.cached_data(config1)().get_data_df_for_period() -# config2.cache_config.read_folders = [config1.cache_config.data_folder] -# res2_cached = g1.cached_data(config2)().get_data_df_for_period() -# self.assertTrue((res1_cached == res2_cached).all().all()) - -# def test_multiple_readers_different_shapes(self): -# @csp.graph(cache=True, cache_options=GraphCacheOptions(ignored_inputs={'shape', 'dummy'})) -# def g(shape: [str], dummy: object) -> csp.OutputBasket(Dict[str, csp.ts[str]], shape="shape"): -# res = {} -# for v in shape: -# res[v] = csp.const(f'{v}_value') -# return res - -# @csp.graph -# def read_g(): -# __outputs__(v1={str: csp.ts[str]}, v2={str: csp.ts[str]}) -# df = pandas.DataFrame({'dummy': [1]}) - -# v2_a = g.cached(['b', 'c', 'd'], df) -# v2_b = g.cached(['b', 'c', 'd'], df) -# assert id(v2_a) == id(v2_b) - -# return csp.output(v1=g.cached(['a', 'b'], df), v2=v2_a) - -# with _GraphTempCacheFolderConfig() as config: -# starttime = datetime(2020, 1, 1, 9, 29) -# endtime = starttime + timedelta(minutes=1) -# csp.run(g, ['a', 'b', 'c', 'd'], None, starttime=starttime, endtime=endtime, config=config) -# res = csp.run(read_g, starttime=starttime, endtime=endtime, config=config) -# self.assertEqual(sorted(res.keys()), ['v1[a]', 'v1[b]', 'v2[b]', 'v2[c]', 'v2[d]']) - -# def test_basket_partial_cache_load(self): -# @csp.graph(cache=True) -# def g1() -> csp.OutputBasket(Dict[str, csp.ts[TypedCurveGenerator.SimpleStruct]], shape=['A', 'B']): -# return {'A': csp.curve(TypedCurveGenerator.SimpleStruct, [(timedelta(seconds=0), TypedCurveGenerator.SimpleStruct(value1=0)), -# (timedelta(seconds=1), TypedCurveGenerator.SimpleStruct(value1=2)), -# ]), -# 'B': csp.curve(TypedCurveGenerator.SimpleStruct, [(timedelta(seconds=1), TypedCurveGenerator.SimpleStruct(value1=3)), -# (timedelta(seconds=2), TypedCurveGenerator.SimpleStruct(value1=4)), -# ]) -# } - -# @csp.graph(cache=True) -# def g2() -> csp.Outputs(my_named_output=csp.OutputBasket(Dict[str, csp.ts[TypedCurveGenerator.SimpleStruct]], shape=['A', 'B'])): -# return csp.output(my_named_output=g1()) - -# with _GraphTempCacheFolderConfig() as config: -# starttime = datetime(2020, 1, 1, 9, 29) -# endtime = starttime + timedelta(minutes=1) -# csp.run(g2, starttime=starttime, endtime=endtime, config=config) -# res1 = g1.cached_data(config)().get_data_df_for_period() -# res2 = g1.cached_data(config)().get_data_df_for_period(struct_basket_sub_columns={'': ['value1']}) -# with self.assertRaisesRegex(RuntimeError, re.escape("Specified sub columns for basket 'csp_unnamed_output' but it's not loaded from file") + '.*'): -# res3 = g2.cached_data(config)().get_data_df_for_period(struct_basket_sub_columns={'': ['value1']}) -# res3 = g2.cached_data(config)().get_data_df_for_period(struct_basket_sub_columns={'my_named_output': ['value1']}) - -# self.assertEqual(res1.columns.levels[0].to_list(), ['csp_timestamp', 
'value1', 'value2']) -# self.assertEqual(res1.columns.levels[1].to_list(), ['', 'A', 'B']) - -# self.assertEqual(res2.columns.levels[0].to_list(), ['csp_timestamp', 'value1']) -# self.assertEqual(res2.columns.levels[1].to_list(), ['', 'A', 'B']) - -# self.assertEqual(res3.columns.levels[0].to_list(), ['csp_timestamp', 'my_named_output.value1']) -# self.assertEqual(res3.columns.levels[1].to_list(), ['', 'A', 'B']) - -# self.assertTrue((res1['value1'].fillna(-111111) == res2['value1'].fillna(-111111)).all().all()) -# self.assertTrue((res1['value1'].fillna(-111111) == res3['my_named_output.value1'].fillna(-111111)).all().all()) -# self.assertTrue((res1['csp_timestamp'] == res2['csp_timestamp']).all().all()) -# self.assertTrue((res1['csp_timestamp'] == res3['csp_timestamp']).all().all()) - -# def test_partition_retrieval(self): -# @csp.graph(cache=True) -# def g1(i: int, d: date, dt: datetime, td: timedelta, f_val: float, s: str, b: bool, struct: TypedCurveGenerator.SimpleStruct) -> csp.Outputs(v1=csp.ts[TypedCurveGenerator.SimpleStruct], v2=csp.ts[float]): -# return csp.output(v1=csp.const(struct), v2=csp.const(f_val)) - -# @csp.graph(cache=True) -# def g2() -> csp.ts[int]: -# return csp.const(42) - -# s1 = datetime(2020, 1, 1) -# e1 = s1 + timedelta(hours=70, microseconds=-1) -# with _GraphTempCacheFolderConfig() as config: -# # i: int, d: date, dt: datetime, td: timedelta, f: float, s: str, b: bool, struct: TypedCurveGenerator.SimpleStruct -# csp.run(g1, i=1, d=date(2013, 5, 8), dt=datetime(2025, 3, 6, 11, 20, 59, 999599), td=timedelta(seconds=5), f_val=5.3, s="test1", b=False, -# struct=TypedCurveGenerator.SimpleStruct(value1=53), -# starttime=s1, endtime=e1, config=config) -# csp.run(g1, i=52, d=date(2013, 5, 31), dt=datetime(2025, 3, 5), td=timedelta(days=100), f_val=7.8, s="test2", b=True, struct=TypedCurveGenerator.SimpleStruct(value1=-53), starttime=s1, -# endtime=e1, config=config) -# csp.run(g2, starttime=s1, endtime=e1, config=config) -# g1_keys = g1.cached_data(config).get_partition_keys() -# g2_keys = g2.cached_data(config).get_partition_keys() -# self.assertEqual([DatasetPartitionKey({'i': 1, -# 'd': date(2013, 5, 8), -# 'dt': datetime(2025, 3, 6, 11, 20, 59, 999599), -# 'td': timedelta(seconds=5), -# 'f_val': 5.3, -# 's': 'test1', -# 'b': False, -# 'struct': TypedCurveGenerator.SimpleStruct(value1=53)}), -# DatasetPartitionKey({'i': 52, -# 'd': date(2013, 5, 31), -# 'dt': datetime(2025, 3, 5), -# 'td': timedelta(days=100), -# 'f_val': 7.8, -# 's': 'test2', -# 'b': True, -# 'struct': TypedCurveGenerator.SimpleStruct(value1=-53)})], -# g1_keys) -# # self.assertEqual(len(g1_keys), 2) -# self.assertEqual([DatasetPartitionKey({})], g2_keys) -# df1 = g1.cached_data(config)(**g1_keys[0].kwargs).get_data_df_for_period() -# df2 = g2.cached_data(config)(**g2_keys[0].kwargs).get_data_df_for_period() -# self.assertTrue((df1.fillna(-42) == pandas.DataFrame({'csp_timestamp': [pytz.utc.localize(s1)], 'v1.value1': [53], 'v1.value2': [-42], 'v2': [5.3]})).all().all()) -# self.assertTrue((df2 == pandas.DataFrame({'csp_timestamp': [pytz.utc.localize(s1)], 'csp_unnamed_output': [42]})).all().all()) - - -# if __name__ == '__main__': -# unittest.main() diff --git a/csp/tests/test_engine.py b/csp/tests/test_engine.py index 028b5d6f..09bb25ec 100644 --- a/csp/tests/test_engine.py +++ b/csp/tests/test_engine.py @@ -31,18 +31,6 @@ def _dummy_node(): raise NotImplementedError() -@csp.graph(cache=True) -def _dummy_graph_cached() -> csp.ts[float]: - raise NotImplementedError() - return csp.const(1) - - 
-@csp.node(cache=True) -def _dummy_node_cached() -> csp.ts[float]: - raise NotImplementedError() - return 1 - - class TestEngine(unittest.TestCase): def test_simple(self): @csp.node @@ -1303,7 +1291,7 @@ def my_node(val: int) -> ts[int]: def dummy(v: ts[int]) -> ts[int]: return v - @csp.graph(cache=True) + @csp.graph def my_ranked_node(val: int, rank: int = 0) -> csp.Outputs(val=ts[int]): res = my_node(val) for i in range(rank): @@ -1833,12 +1821,10 @@ def test_graph_node_pickling(self): """Checks for a bug that we had when transitioning to python 3.8 - the graphs and nodes became unpicklable :return: """ - from csp.tests.test_engine import _dummy_graph, _dummy_graph_cached, _dummy_node, _dummy_node_cached + from csp.tests.test_engine import _dummy_graph, _dummy_node self.assertEqual(_dummy_graph, pickle.loads(pickle.dumps(_dummy_graph))) self.assertEqual(_dummy_node, pickle.loads(pickle.dumps(_dummy_node))) - self.assertEqual(_dummy_graph_cached, pickle.loads(pickle.dumps(_dummy_graph_cached))) - self.assertEqual(_dummy_node_cached, pickle.loads(pickle.dumps(_dummy_node_cached))) def test_memoized_object(self): @csp.csp_memoized @@ -2078,8 +2064,9 @@ def raise_interrupt(): csp.schedule_alarm(a, timedelta(seconds=1), True) if csp.ticked(a): import signal + os.kill(os.getpid(), signal.SIGINT) - + # Python nodes @csp.graph def g(l: list): @@ -2094,12 +2081,12 @@ def g(l: list): for element in stopped: self.assertTrue(element) - + # C++ nodes class RTI: def __init__(self): self.stopped = [False, False, False] - + @csp.node(cppimpl=_csptestlibimpl.set_stop_index) def n2(obj_: object, idx: int): return @@ -2114,7 +2101,7 @@ def g2(rti: RTI): rti = RTI() with self.assertRaises(KeyboardInterrupt): csp.run(g2, rti, starttime=datetime.utcnow(), endtime=timedelta(seconds=60), realtime=True) - + for element in rti.stopped: self.assertTrue(element) diff --git a/csp/tests/test_parsing.py b/csp/tests/test_parsing.py index 5fb1fdba..18ec0c6d 100644 --- a/csp/tests/test_parsing.py +++ b/csp/tests/test_parsing.py @@ -1,6 +1,6 @@ import sys import unittest -from datetime import date, datetime, timedelta +from datetime import datetime, timedelta from typing import Callable, Dict, List import csp @@ -986,37 +986,6 @@ def test_list_inside_callable(self): def graph(v: Dict[str, Callable[[], str]]): pass - def test_graph_caching_parsing(self): - with self.assertRaisesRegex( - NotImplementedError, "Caching is unsupported for argument type typing.List\\[int\\] \\(argument x\\)" - ): - - @csp.graph(cache=True) - def graph(x: List[int]): - __outputs__(o=ts[int]) - pass - - with self.assertRaisesRegex( - NotImplementedError, "Caching is unsupported for argument type typing.Dict\\[int, int\\] \\(argument x\\)" - ): - - @csp.graph(cache=True) - def graph(x: Dict[int, int]): - __outputs__(o=ts[int]) - pass - - with self.assertRaisesRegex(NotImplementedError, "Caching of list basket outputs is unsupported"): - - @csp.graph(cache=True) - def graph(): - __outputs__(o=[ts[int]]) - pass - - @csp.graph(cache=True) - def graph(a1: datetime, a2: date, a3: int, a4: float, a5: str, a6: bool): - __outputs__(o=ts[int]) - pass - def test_list_default_value(self): # There was a bug parsing list default value @csp.graph