From d8302131e0bd88d942747199ab5974f2f9936210 Mon Sep 17 00:00:00 2001 From: jlewitt1 Date: Thu, 11 Jul 2024 23:26:21 +0300 Subject: [PATCH] remove kvstore, mapper, tables and queue --- docs/api/python.rst | 15 +- docs/api/python/mapper.rst | 42 - docs/api/python/run.rst | 2 +- docs/api/python/table.rst | 18 - docs/architecture.rst | 4 +- runhouse/__init__.py | 4 - runhouse/resources/functionals/__init__.py | 0 runhouse/resources/functionals/mapper.py | 306 ------ runhouse/resources/kvstores/__init__.py | 2 - runhouse/resources/kvstores/kvstore.py | 80 -- runhouse/resources/queues/__init__.py | 2 - runhouse/resources/queues/queue.py | 86 -- runhouse/resources/tables/__init__.py | 5 - runhouse/resources/tables/dask_table.py | 33 - .../resources/tables/huggingface_table.py | 88 -- runhouse/resources/tables/pandas_table.py | 38 - runhouse/resources/tables/rapids_table.py | 25 - runhouse/resources/tables/ray_table.py | 25 - runhouse/resources/tables/table.py | 591 ----------- tests/conftest.py | 13 - .../test_clusters/test_cluster.py | 42 - .../test_functions/test_mapper.py | 199 ---- .../test_modules/test_tables/__init__.py | 0 .../test_modules/test_tables/conftest.py | 65 -- .../test_modules/test_tables/table_tests.py | 919 ------------------ 25 files changed, 4 insertions(+), 2600 deletions(-) delete mode 100644 docs/api/python/mapper.rst delete mode 100644 docs/api/python/table.rst delete mode 100644 runhouse/resources/functionals/__init__.py delete mode 100644 runhouse/resources/functionals/mapper.py delete mode 100644 runhouse/resources/kvstores/__init__.py delete mode 100644 runhouse/resources/kvstores/kvstore.py delete mode 100644 runhouse/resources/queues/__init__.py delete mode 100644 runhouse/resources/queues/queue.py delete mode 100644 runhouse/resources/tables/__init__.py delete mode 100644 runhouse/resources/tables/dask_table.py delete mode 100644 runhouse/resources/tables/huggingface_table.py delete mode 100644 runhouse/resources/tables/pandas_table.py delete mode 100644 runhouse/resources/tables/rapids_table.py delete mode 100644 runhouse/resources/tables/ray_table.py delete mode 100644 runhouse/resources/tables/table.py delete mode 100644 tests/test_resources/test_modules/test_functions/test_mapper.py delete mode 100644 tests/test_resources/test_modules/test_tables/__init__.py delete mode 100644 tests/test_resources/test_modules/test_tables/conftest.py delete mode 100644 tests/test_resources/test_modules/test_tables/table_tests.py diff --git a/docs/api/python.rst b/docs/api/python.rst index f63961e01..cee067dfb 100644 --- a/docs/api/python.rst +++ b/docs/api/python.rst @@ -7,7 +7,7 @@ Resources ------------------------------------ Resources are the Runhouse abstraction for objects that can be saved, shared, and reused. This includes both compute abstractions (clusters, functions, packages, environments) and -data abstractions (blobs, folders, tables). +data abstractions (files, blobs, folders). .. toctree:: :maxdepth: 1 @@ -48,15 +48,10 @@ and rich debugging and accessibility interfaces built-in. python/module -.. toctree:: - :maxdepth: 1 - - python/mapper - Data Abstractions ------------------------------------ -The Folder, Table, Blob, and File APIs provide a simple interface for storing, recalling, and moving data between +The Folder, Blob, and File APIs provide a simple interface for storing, recalling, and moving data between the user's laptop, remote compute, cloud storage, and specialized storage (e.g. data warehouses). 
They provide least-common-denominator APIs across providers, allowing users to easily specify the actions they want to take on the data without needed to dig into provider-specific APIs. We'd like to extend this @@ -67,12 +62,6 @@ to other data concepts in the future, like kv-stores, time-series, vector and gr python/folder - -.. toctree:: - :maxdepth: 1 - - python/table - .. toctree:: :maxdepth: 1 diff --git a/docs/api/python/mapper.rst b/docs/api/python/mapper.rst deleted file mode 100644 index 3b7dbdef3..000000000 --- a/docs/api/python/mapper.rst +++ /dev/null @@ -1,42 +0,0 @@ -Mapper -==================================== -Mapper is a built-in Module for parallelizing a function or module method over a list of inputs. It -holds a pool of replicas of the function or module, distributes the inputs to the replicas, and collects -the results. The replicas are either created by the mapper automatically, or they can be created by the user -and passed into the mapper (or a mix of both). The advantage of that flexibility is that the mapper can call -replicas on totally different infrastructure (e.g. if you have two different clusters). - -When the mapper creates the replicas, it creates duplicate envs -of the original mapped module's env (e.g. `env_replica_1`), and sends the module into the replica env on the same -cluster, thus creating many modules each in separate processes (and potentially separate nodes). Keep in mind that -you must specify the compute resources (e.g. `compute={"CPU": 0.5}`) in the `env` constructor if you have a multinode -cluster and want the replica envs to overflow onto worker nodes. - -The mapper then simply calls each in a threadpool and collects the results. By default, the threadpool is the same -size as the number of replicas, so each thread blocks until a replica is available. You can control this by setting -`concurrency`, which is the number of simultaneous calls that can be made to any replica (e.g. if concurrency is 2, -then 2 threads will be calling each replica at the same time). - -The mapper can either sit locally or on a cluster, but you should generally put it on the cluster if you can. -If the mapper is local, you'll need to send the mapped module to the cluster before passing it to the mapper, -and the mapper will create each new replica on the cluster remotely, which will take longer. The communication -between the mapper and the replicas will also be faster and more dependable if they are on the same cluster. -Just note that if you're sending or returning a large amount of data, it may take time to transfer before or after -you see the loading bar when the mapper is actually processing. Generally you'll get the best performance if you -read and write to blob storage or the filesystem rather than sending the data around. - - -Mapper Factory Method -~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: runhouse.mapper - - -Mapper Class -~~~~~~~~~~~~ - -.. autoclass:: runhouse.Mapper - :members: - :exclude-members: - - .. automethod:: __init__ diff --git a/docs/api/python/run.rst b/docs/api/python/run.rst index 624316b86..1e8290fbb 100644 --- a/docs/api/python/run.rst +++ b/docs/api/python/run.rst @@ -64,7 +64,7 @@ A Run may contain some (or all) of these core components: - **Downstream dependencies**: Runhouse artifacts saved by the Run. .. note:: - Artifacts represent any Runhouse primitive (e.g. :ref:`Blob`, :ref:`Function`, :ref:`Table`, etc.) that is + Artifacts represent any Runhouse primitive (e.g. :ref:`Blob`, :ref:`Function`, etc.) 
that is loaded or saved by the Run. diff --git a/docs/api/python/table.rst b/docs/api/python/table.rst deleted file mode 100644 index 23c743abd..000000000 --- a/docs/api/python/table.rst +++ /dev/null @@ -1,18 +0,0 @@ -Table -==================================== -A Table is a Runhouse primitive used for abstracting a particular tabular data storage configuration. - - -Table Factory Method -~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: runhouse.table - -Table Class -~~~~~~~~~~~ - -.. autoclass:: runhouse.Table - :members: - :exclude-members: - - .. automethod:: __init__ diff --git a/docs/architecture.rst b/docs/architecture.rst index 40aa79937..7bc4ee710 100644 --- a/docs/architecture.rst +++ b/docs/architecture.rst @@ -8,7 +8,7 @@ Runhouse Resources Resources are the Runhouse primitive for objects that can be saved, shared, and reused. This can be split into compute resources (clusters, functions, modules, environments, and runs) and data resources -(folder, table, blob, etc). +(folder, blob, etc). Compute ------- @@ -41,8 +41,6 @@ dig into provider-specific APIs. * **Folder**: Represents a specified location (could be local, remote, or file storage), for managing where various Runhouse resources live. -* **Table**: Provides convenient APIs for writing, partitioning, fetch, and stream various data types. - * **Blob**: Represents a data object stored in a particular system. * **File**: Represents a file object stored in a particular system. diff --git a/runhouse/__init__.py b/runhouse/__init__.py index 22349b14d..b0165de58 100644 --- a/runhouse/__init__.py +++ b/runhouse/__init__.py @@ -2,7 +2,6 @@ from runhouse.resources.blobs import blob, Blob, file, File from runhouse.resources.envs import conda_env, CondaEnv, env, Env from runhouse.resources.folders import Folder, folder, GCSFolder, S3Folder -from runhouse.resources.functionals.mapper import Mapper, mapper from runhouse.resources.functions.aws_lambda import LambdaFunction from runhouse.resources.functions.aws_lambda_factory import aws_lambda_fn from runhouse.resources.functions.function import Function @@ -19,14 +18,11 @@ # WARNING: Any built-in module that is imported here must be capitalized followed by all lowercase, or we will # will not find the module class when attempting to reconstruct it from a config. 
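A minimal sketch of the Mapper workflow described in the deleted mapper docs above, assembled from the docstrings of the mapper module removed further below (pre-removal API; `my_cluster` is a placeholder for any existing Runhouse cluster object, and the exact outputs assume the behavior documented in those docstrings):

>>> import runhouse as rh
>>>
>>> def local_sum(arg1, arg2, arg3):
>>>     return arg1 + arg2 + arg3
>>>
>>> # Send the function to the cluster, replicate it into two envs, and fan inputs out round-robin
>>> remote_fn = rh.function(local_sum).to(my_cluster)
>>> sum_mapper = rh.mapper(remote_fn, replicas=2)
>>> sum_mapper.map([1, 2], [1, 4], [2, 3])      # zips the lists -> [local_sum(1, 1, 2), local_sum(2, 4, 3)] == [4, 9]
>>> sum_mapper.starmap([(1, 2, 3), (4, 5, 6)])  # unpacks each tuple -> [6, 15]
>>> sum_mapper.call(1, 2, 3)                    # single call routed to the next replica -> 6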
-from runhouse.resources.kvstores.kvstore import Kvstore from runhouse.resources.module import Module, module from runhouse.resources.packages import git_package, GitPackage, package, Package from runhouse.resources.provenance import capture_stdout, Run, run, RunStatus, RunType -from runhouse.resources.queues import Queue from runhouse.resources.resource import Resource from runhouse.resources.secrets import provider_secret, ProviderSecret, Secret, secret -from runhouse.resources.tables import Table, table from runhouse.rns.top_level_rns_fns import ( as_caller, diff --git a/runhouse/resources/functionals/__init__.py b/runhouse/resources/functionals/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/runhouse/resources/functionals/mapper.py b/runhouse/resources/functionals/mapper.py deleted file mode 100644 index 2a7ac5666..000000000 --- a/runhouse/resources/functionals/mapper.py +++ /dev/null @@ -1,306 +0,0 @@ -import concurrent.futures -import contextvars -from typing import Callable, List, Optional, Union - -try: - from tqdm import tqdm -except ImportError: - - def tqdm(*args, **kwargs): - return args[0] - - -from runhouse.logger import logger -from runhouse.resources.envs.env import Env -from runhouse.resources.functions import function, Function -from runhouse.resources.hardware.cluster import Cluster - -from runhouse.resources.module import Module - - -class Mapper(Module): - def __init__( - self, - module: Module = None, - method: str = None, - replicas: Union[None, int, List[Module]] = None, - concurrency=1, - **kwargs, - ): - """ - Runhouse Mapper object. It is used for mapping a function or module method over a list of inputs, - across a series of replicas. - - .. note:: - To create a Mapper, please use the factory method :func:`mapper`. 
- """ - super().__init__(**kwargs) - self.module = module - self.method = method - self.concurrency = concurrency - self._num_auto_replicas = None - self._auto_replicas = [] - self._user_replicas = [] - self._last_called = 0 - if isinstance(replicas, int): - if self.module.system: - # Only add replicas if the replicated module is already on a cluster - if replicas > self.num_replicas and replicas > 0: - self.add_replicas(replicas) - else: - # Otherwise, store this for later once we've sent the mapper to the cluster - self._num_auto_replicas = replicas - elif isinstance(replicas, list): - self._user_replicas = replicas - - @property - def replicas(self): - return [self.module] + self._auto_replicas + self._user_replicas - - @property - def num_replicas(self): - return len(self.replicas) - - def add_replicas(self, replicas: Union[int, List[Module]]): - if isinstance(replicas, int): - new_replicas = replicas - self.num_replicas - logger.info(f"Adding {new_replicas} replicas") - self._add_auto_replicas(new_replicas) - else: - self._user_replicas.extend(replicas) - - def drop_replicas(self, num_replicas: int, reap: bool = True): - if reap: - for replica in self._auto_replicas[-num_replicas:]: - replica.system.kill(replica.env.name) - self._auto_replicas = self._auto_replicas[:-num_replicas] - - def _add_auto_replicas(self, num_replicas: int): - self._auto_replicas.extend(self.module.replicate(num_replicas)) - - def increment_counter(self): - self._last_called += 1 - if self._last_called >= len(self.replicas): - self._last_called = 0 - return self._last_called - - def to( - self, - system: Union[str, Cluster], - env: Optional[Union[str, List[str], Env]] = None, - name: Optional[str] = None, - force_install: bool = False, - ): - """Put a copy of the Mapper and its internal module on the destination system and env, and - return the new mapper. - - Example: - >>> local_mapper = rh.mapper(my_module, replicas=2) - >>> cluster_mapper = local_mapper.to(my_cluster) - """ - if not self.module.system: - # Note that we don't pass name here, as this is the name meant for the mapper - self.module = self.module.to( - system=system, env=env, force_install=force_install - ) - remote_mapper = super().to( - system=system, env=env, name=name, force_install=force_install - ) - - if isinstance(self._num_auto_replicas, int): - remote_mapper.add_replicas(self._num_auto_replicas) - return remote_mapper - - def map(self, *args, method: Optional[str] = None, retries: int = 0, **kwargs): - """Map the function or method over a list of arguments. - - Example: - >>> mapper = rh.mapper(local_sum, replicas=2).to(my_cluster) - >>> mapper.map([1, 2], [1, 4], [2, 3], retries=3) - - >>> # If you're mapping over a remote module, you can choose not to specify which method to call initially - >>> # so you can call different methods in different maps (note that our replicas can hold state!) - >>> # Note that in the example below we're careful to use the same number of replicas as data - >>> # we have to process, or the state in a replica would be overwritten by the next call. - >>> shards = len(source_paths) - >>> mapper = rh.mapper(remote_module, replicas=shards).to(my_cluster) - >>> mapper.map(*source_paths, method="load_data") - >>> mapper.map([]*shards, method="process_data") # Calls each replica once with no args - >>> mapper.map(*output_paths, method="save_data") - """ - # Don't stream logs by default unless the mapper is remote (i.e. 
mediating the mapping) - return self.starmap( - arg_list=zip(*args), method=method, retries=retries, **kwargs - ) - - def starmap( - self, arg_list: List, method: Optional[str] = None, retries: int = 0, **kwargs - ): - """Like :func:`map` except that the elements of the iterable are expected to be iterables - that are unpacked as arguments. An iterable of ``[(1,2), (3, 4)]`` results in - ``func(1,2), func(3,4)]``. - - Example: - >>> def local_sum(arg1, arg2, arg3): - >>> return arg1 + arg2 + arg3 - >>> - >>> remote_fn = rh.function(local_sum).to(my_cluster) - >>> mapper = rh.mapper(remote_fn, replicas=2) - >>> arg_list = [(1,2), (3, 4)] - >>> # runs the function twice, once with args (1, 2) and once with args (3, 4) - >>> mapper.starmap(arg_list) - """ - # Don't stream logs by default unless the mapper is remote (i.e. mediating the mapping) - if self.system and not self.system.on_this_cluster(): - kwargs["stream_logs"] = kwargs.get("stream_logs", True) - else: - kwargs["stream_logs"] = kwargs.get("stream_logs", False) - - retry_list = [] - - def call_method_on_replica(job, retry=True): - replica, method_name, context, argies, kwargies = job - # reset context - for var, value in context.items(): - var.set(value) - - try: - return getattr(replica, method_name)(*argies, **kwargies) - except Exception as e: - logger.error(f"Error running {method_name} on {replica.name}: {e}") - if retry: - retry_list.append(job) - else: - return e - - context = contextvars.copy_context() - jobs = [ - ( - self.replicas[self.increment_counter()], - method or self.method, - context, - args, - kwargs, - ) - for args in arg_list - ] - - results = [] - max_threads = round(self.concurrency * self.num_replicas) - with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor: - futs = [ - executor.submit(call_method_on_replica, job, retries > 0) - for job in jobs - ] - for fut in tqdm(concurrent.futures.as_completed(futs), total=len(jobs)): - results.extend([fut.result()]) - for i in range(retries): - if len(retry_list) == 0: - break - logger.info(f"Retry {i}: {len(retry_list)} failed jobs") - jobs, retry_list = retry_list, [] - retry = i != retries - 1 - results.append( - list( - tqdm( - executor.map(call_method_on_replica, jobs, retry), - total=len(jobs), - ) - ) - ) - - return results - - # TODO should we add an async version of this for when we're on the cluster? - # async def call_method_on_args(argies): - # return getattr(self.replicas[self.increment_counter()], self.method)(*argies, **kwargs) - # - # async def gather(): - # return await asyncio.gather( - # *[ - # call_method_on_args(args) - # for args in zip(*args) - # ] - # ) - # return asyncio.run(gather()) - - def call(self, *args, method: Optional[str] = None, **kwargs): - """Call the function or method on a single replica. - - Example: - >>> def local_sum(arg1, arg2, arg3): - >>> return arg1 + arg2 + arg3 - >>> - >>> remote_fn = rh.function(local_sum).to(my_cluster) - >>> mapper = rh.mapper(remote_fn, replicas=2) - >>> for i in range(10): - >>> mapper.call(i, 1, 2) - >>> # output: 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, run in round-robin replica order - - """ - return getattr(self.replicas[self.increment_counter()], method or self.method)( - *args, **kwargs - ) - - -def mapper( - module: Union[Module, Callable], - method: Optional[str] = None, - replicas: Union[None, int, List[Module]] = None, - concurrency: int = 1, - **kwargs, -) -> Mapper: - """ - A factory method for creating Mapper modules. 
A mapper is a module that can map a function or module method over - a list of inputs in various ways. - - Args: - module (Module): The module or function to be mapped. - method (Optional[str], optional): The method of the module to be called. If the module is already a callable, - this value defaults to ``"call"``. - concurrency (int, optional): The number of concurrent calls to each replica, executed in separate threads. - Defaults to 1. - replicas (Optional[List[Module]], optional): List of user-specified replicas, or an int specifying the number - of replicas to be automatically created. Defaults to None. - - Returns: - Mapper: The resulting Mapper object. - - Example: - >>> def local_sum(arg1, arg2, arg3): - >>> return arg1 + arg2 + arg3 - >>> - >>> # Option 1: Pass a function directly to the mapper, and send both to the cluster - >>> mapper = rh.mapper(local_sum, replicas=2).to(my_cluster) - >>> mapper.map([1, 2], [1, 4], [2, 3]) - - >>> # Option 2: Create a remote module yourself and pass it to the mapper, which is still local - >>> remote_fn = rh.function(local_sum).to(my_cluster, env=my_fn_env) - >>> mapper = rh.mapper(remote_fn, replicas=2) - >>> mapper.map([1, 2], [1, 4], [2, 3]) - >>> # output: [4, 9] - - >>> # Option 3: Create a remote module and mapper for greater flexibility, and send both to the cluster - >>> # You can map over a "class" module (stateless) or an "instance" module to preserve state - >>> remote_class = rh.module(cls=MyClass).to(system=cluster, env=my_module_env) - >>> stateless_mapper = rh.mapper(remote_class, method="my_class_method", replicas=2).to(cluster) - >>> mapper.map([1, 2], [1, 4], [2, 3]) - - >>> remote_app = remote_class() - >>> stateful_mapper = rh.mapper(remote_app, method="my_instance_method", replicas=2).to(cluster) - >>> mapper.map([1, 2], [1, 4], [2, 3]) - """ - - if callable(module) and not isinstance(module, Module): - module = function(module, **kwargs) - - if isinstance(module, Function): - method = method or "call" - - return Mapper( - module=module, - method=method, - replicas=replicas, - concurrency=concurrency, - **kwargs, - ) diff --git a/runhouse/resources/kvstores/__init__.py b/runhouse/resources/kvstores/__init__.py deleted file mode 100644 index 2a0b53c7f..000000000 --- a/runhouse/resources/kvstores/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# TODO Redis, Mongo, Feast, BigTable, DynamoDB, etc. -from .kvstore import Kvstore diff --git a/runhouse/resources/kvstores/kvstore.py b/runhouse/resources/kvstores/kvstore.py deleted file mode 100644 index aca8e3844..000000000 --- a/runhouse/resources/kvstores/kvstore.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Any, Optional, Union - -from runhouse import Cluster, Env -from runhouse.resources.module import Module - - -class Kvstore(Module): - RESOURCE_TYPE = "kvstore" - DEFAULT_CACHE_FOLDER = ".cache/runhouse/kvstores" - - """Simple dict wrapper to act as key-value/object storage. Wrapping this in an actor allows us to - access it across Ray processes and nodes, and even keep some things pinned to Python memory.""" - - def __init__( - self, - name: Optional[str] = None, - system: Union[Cluster, str] = None, - env: Optional[Env] = None, - dryrun: bool = False, - **kwargs, - ): - """ - Runhouse KVStore object - - .. note:: - To build a KVStore, please use the factory method :func:`kvstore`. 
- """ - super().__init__(name=name, system=system, env=env, dryrun=dryrun, **kwargs) - self.data = {} - - def put(self, key: str, value: Any): - self.data[key] = value - - def get(self, key: str, default=None): - if default == KeyError: - return self.data[key] - return self.data.get(key, default) - - def pop(self, key: str, *args): - # We accept *args here to match the signature of dict.pop (throw an error if key is not found, - # unless another arg is provided as a default) - return self.data.pop(key, *args) - - def keys(self): - return list(self.data.keys()) - - def values(self): - return list(self.data.values()) - - def items(self): - return list(self.data.items()) - - def clear(self): - self.data = {} - - def rename_key(self, old_key, new_key, *args): - # We accept *args here to match the signature of dict.pop (throw an error if key is not found, - # unless another arg is provided as a default) - self.data[new_key] = self.data.pop(old_key, *args) - - def __len__(self): - return len(self.data) - - def contains(self, key: str): - return key in self - - def __contains__(self, key: str): - return key in self.data - - def __getitem__(self, key: str): - return self.data[key] - - def __setitem__(self, key: str, value: Any): - self.data[key] = value - - def __delitem__(self, key: str): - del self.data[key] - - def __repr__(self): - return repr(self.data) diff --git a/runhouse/resources/queues/__init__.py b/runhouse/resources/queues/__init__.py deleted file mode 100644 index 8355a78ab..000000000 --- a/runhouse/resources/queues/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# TODO RabbitMQ, Flink, Kafka, Faust, etc. -from .queue import Queue diff --git a/runhouse/resources/queues/queue.py b/runhouse/resources/queues/queue.py deleted file mode 100644 index bd9cb1cce..000000000 --- a/runhouse/resources/queues/queue.py +++ /dev/null @@ -1,86 +0,0 @@ -import queue -from typing import Any, List, Optional, Union - -from runhouse import Cluster, Env -from runhouse.resources.module import Module - - -class Queue(Module): - RESOURCE_TYPE = "queue" - DEFAULT_CACHE_FOLDER = ".cache/runhouse/queues" - - """Simple dict wrapper to act as a queue. Wrapping this in an actor allows us to access - it across Ray processes and nodes, and even keep some things pinned to Python memory.""" - - def __init__( - self, - name: Optional[str] = None, - system: Union[Cluster, str] = None, - env: Optional[Env] = None, - max_size: int = 0, - persist: bool = False, # TODO - dryrun: bool = False, - **kwargs, - ): - """ - Runhouse Queue object - - .. note:: - To build a Queue, please use the factory method :func:`queue`. 
- """ - super().__init__(name=name, system=system, env=env, dryrun=dryrun, **kwargs) - if not self._system or self._system.on_this_cluster(): - self.data = queue.Queue(maxsize=max_size) - self.persist = persist - self._subscribers = [] - - def put(self, item: Any, block=True, timeout=None): - self.data.put(item, block=block, timeout=timeout) - for fn, out_queue in self._subscribers: - res = fn(item) - if out_queue: - out_queue.put(res) - - def put_nowait(self, item: Any): - self.data.put_nowait(item) - - def put_batch(self, items: List[Any], block=True, timeout=None): - for item in items: - self.data.put(item, block=block, timeout=timeout) - - def get(self, block=True, timeout=None): - return self.data.get(block=block, timeout=timeout) - - def get_nowait(self): - return self.data.get_nowait() - - def get_batch(self, batch_size: int, block=True, timeout=None): - items = [] - for _ in range(batch_size): - items.append(self.data.get(block=block, timeout=timeout)) - return items - - def __iter__(self): - try: - while True: - yield self.get() - except queue.Empty: - return - - def qsize(self): - return self.data.qsize() - - def empty(self): - return self.data.empty() - - def full(self): - return self.data.full() - - def task_done(self): - return self.data.task_done() - - def join(self): - return self.data.join() - - def subscribe(self, function, out_queue=None): - self._subscribers.append((function, out_queue)) diff --git a/runhouse/resources/tables/__init__.py b/runhouse/resources/tables/__init__.py deleted file mode 100644 index ba0e6cda8..000000000 --- a/runhouse/resources/tables/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# TODO ParquetTable, HuggingFaceTable, BigQueryTable, SnowflakeTable, -# DeltaLakeTable, PostgreSQLTable, RedShiftTable, MySQLTable? 
-# GoogleAnalyticsTable, custom ETL tables - -from .table import Table, table diff --git a/runhouse/resources/tables/dask_table.py b/runhouse/resources/tables/dask_table.py deleted file mode 100644 index 24d5f4790..000000000 --- a/runhouse/resources/tables/dask_table.py +++ /dev/null @@ -1,33 +0,0 @@ -from runhouse.logger import logger - -from .table import Table - - -class DaskTable(Table): - DEFAULT_FOLDER_PATH = "/runhouse/dask-tables" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def write(self, write_index: bool = False): - # https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html - if self._cached_data is not None: - # https://stackoverflow.com/questions/72891631/how-to-remove-null-dask-index-from-parquet-file - self.data.to_parquet( - self.fsspec_url, - write_index=write_index, - storage_options=self.data_config, - ) - logger.info(f"Saved {str(self)} to: {self.fsspec_url}") - - return self - - def fetch(self, **kwargs): - import dask.dataframe as dd - - # https://docs.dask.org/en/stable/generated/dask.dataframe.read_parquet.html - self._cached_data = dd.read_parquet( - self.fsspec_url, storage_options=self.data_config - ) - - return self._cached_data diff --git a/runhouse/resources/tables/huggingface_table.py b/runhouse/resources/tables/huggingface_table.py deleted file mode 100644 index 34d8905f4..000000000 --- a/runhouse/resources/tables/huggingface_table.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Optional - -from runhouse.logger import logger - -from .table import Table - - -class HuggingFaceTable(Table): - DEFAULT_FOLDER_PATH = "/runhouse/huggingface-tables" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def write(self): - hf_dataset = None - if self._cached_data is not None: - import datasets - - if isinstance(self.data, datasets.Dataset): - arrow_table = self.data.data.table - ray_dataset = self._ray_dataset_from_arrow(arrow_table) - self.data, hf_dataset = ray_dataset, self.data - self._write_ray_dataset(self.data) - elif isinstance(self.data, datasets.DatasetDict): - # TODO [JL] Add support for dataset dict - raise NotImplementedError( - "Runhouse does not currently support DatasetDict objects, please convert to " - "a Dataset before saving." - ) - else: - raise TypeError( - "Unsupported data type for HuggingFaceTable. Please use a Dataset" - ) - - logger.info(f"Saved {str(self)} to: {self.fsspec_url}") - - # Restore the original dataset - if hf_dataset is not None: - self.data = self.to_dataset(hf_dataset) - - return self - - def fetch(self, **kwargs): - # Read as pyarrow table, then convert back to HF dataset - arrow_table = super().fetch(**kwargs) - self._cached_data = self.to_dataset(arrow_table) - return self._cached_data - - def stream( - self, - batch_size: int, - drop_last: bool = False, - shuffle_seed: Optional[int] = None, - shuffle_buffer_size: Optional[int] = None, - prefetch_batches: Optional[int] = None, - as_dict: bool = True, - ): - """Stream data as either Dataset object or dict (as generated by ray iter)""" - for batch in super().stream( - batch_size, drop_last, shuffle_seed, shuffle_buffer_size, prefetch_batches - ): - yield batch if as_dict else self.to_dataset(batch) - - @staticmethod - def to_dataset(data): - """Convert to a HuggingFace Dataset. 
Relevant when fetching the data or when choosing to stream the data - in as a Dataset.""" - import pandas as pd - import pyarrow as pa - import ray.data - from datasets import Dataset - - if isinstance(data, Dataset): - return data - - elif isinstance(data, dict): - return Dataset.from_dict(data) - - elif isinstance(data, pd.DataFrame): - return Dataset.from_pandas(data) - - elif isinstance(data, (pa.Table, ray.data.Dataset)): - return Dataset.from_pandas(data.to_pandas()) - - else: - raise TypeError( - f"Data must be a dict, Pandas DataFrame, ray Dataset, or PyArrow table, not {type(data)}" - ) diff --git a/runhouse/resources/tables/pandas_table.py b/runhouse/resources/tables/pandas_table.py deleted file mode 100644 index b97241c6d..000000000 --- a/runhouse/resources/tables/pandas_table.py +++ /dev/null @@ -1,38 +0,0 @@ -from runhouse.rns.utils.api import generate_uuid - -from .table import Table - - -class PandasTable(Table): - DEFAULT_FOLDER_PATH = "/runhouse/pandas-tables" - DEFAULT_STREAM_FORMAT = "pandas" - - def __init__(self, **kwargs): - if not kwargs.get("file_name"): - kwargs["file_name"] = f"{generate_uuid()}.parquet" - super().__init__(**kwargs) - - def __iter__(self): - for block in self.stream(batch_size=self.DEFAULT_BATCH_SIZE): - for idx, row in block.iterrows(): - yield row - - def write(self): - if self._cached_data is not None: - # https://pandas.pydata.org/pandas-docs/version/1.1/reference/api/pandas.DataFrame.to_parquet.html - self.data.to_parquet( - self.fsspec_url, - storage_options=self.data_config, - partition_cols=self.partition_cols, - ) - - return self - - def fetch(self, **kwargs): - import pandas as pd - - # https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html - self._cached_data = pd.read_parquet( - self.fsspec_url, storage_options=self.data_config - ) - return self._cached_data diff --git a/runhouse/resources/tables/rapids_table.py b/runhouse/resources/tables/rapids_table.py deleted file mode 100644 index b594920a9..000000000 --- a/runhouse/resources/tables/rapids_table.py +++ /dev/null @@ -1,25 +0,0 @@ -from runhouse.logger import logger - -from .table import Table - - -class RapidsTable(Table): - DEFAULT_FOLDER_PATH = "/runhouse/rapids-tables" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def write(self): - # https://docs.rapids.ai/api/cudf/nightly/api_docs/api/cudf.dataframe.to_parquet - if self._cached_data is not None: - self.data.to_parquet(self.fsspec_url) - logger.info(f"Saved {str(self)} to: {self.fsspec_url}") - - return self - - def fetch(self, **kwargs): - import cudf - - # https://docs.rapids.ai/api/cudf/nightly/api_docs/api/cudf.read_parquet.html - self._cached_data = cudf.read_parquet(self.fsspec_url) - return self._cached_data diff --git a/runhouse/resources/tables/ray_table.py b/runhouse/resources/tables/ray_table.py deleted file mode 100644 index 64f4aa8af..000000000 --- a/runhouse/resources/tables/ray_table.py +++ /dev/null @@ -1,25 +0,0 @@ -from runhouse.logger import logger - -from .table import Table - - -class RayTable(Table): - DEFAULT_FOLDER_PATH = "/runhouse/ray-tables" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def write(self): - if self._cached_data is not None: - self._write_ray_dataset(self.data) - logger.info(f"Saved {str(self)} to: {self.fsspec_url}") - - return self - - def fetch(self, **kwargs): - import ray - - self._cached_data = ray.data.read_parquet( - self.fsspec_url, filesystem=self._folder.fsspec_fs - ) - return self._cached_data diff --git 
a/runhouse/resources/tables/table.py b/runhouse/resources/tables/table.py deleted file mode 100644 index f12cfda20..000000000 --- a/runhouse/resources/tables/table.py +++ /dev/null @@ -1,591 +0,0 @@ -import copy -from pathlib import Path -from typing import Dict, List, Optional - -import fsspec - -import ray - -from runhouse import Folder -from runhouse.globals import rns_client - -from runhouse.resources.folders import folder -from runhouse.resources.hardware.utils import _current_cluster, _get_cluster_from -from runhouse.resources.resource import Resource - -PREFETCH_KWARG = "prefetch_batches" if ray.__version__ >= "2.4.0" else "prefetch_blocks" - -from runhouse.logger import logger - - -class Table(Resource): - RESOURCE_TYPE = "table" - DEFAULT_FOLDER_PATH = "/runhouse-table" - DEFAULT_CACHE_FOLDER = ".cache/runhouse/tables" - DEFAULT_STREAM_FORMAT = "pyarrow" - DEFAULT_BATCH_SIZE = 256 - DEFAULT_PREFETCH_BATCHES = 1 - - def __init__( - self, - path: str, - name: Optional[str] = None, - file_name: Optional[str] = None, - system: Optional[str] = None, - data_config: Optional[dict] = None, - dryrun: bool = False, - partition_cols: Optional[List] = None, - stream_format: Optional[str] = None, - metadata: Optional[Dict] = None, - **kwargs, - ): - """ - The Runhouse Table object. - - .. note:: - To build a Table, please use the factory method :func:`table`. - """ - super().__init__(name=name, dryrun=dryrun) - self.file_name = file_name - - # Use factory method so correct subclass for system is returned - # strip filename from path if provided - self._folder = folder( - path=str(Path(path).parents[0]) if Path(path).suffix else path, - system=system, - data_config=data_config, - dryrun=dryrun, - ) - - self._cached_data = None - self.partition_cols = partition_cols - self.stream_format = stream_format or self.DEFAULT_STREAM_FORMAT - self.metadata = metadata or {} - - @staticmethod - def from_config(config: dict, dryrun=False, _resolve_children=True): - if _resolve_children: - config = Table._check_for_child_configs(config) - return _load_table_subclass(config, dryrun=dryrun) - - def config(self, condensed=True): - config = super().config(condensed) - if isinstance(self._folder, Resource): - config["system"] = self._resource_string_for_subconfig( - self.system, condensed - ) - config["data_config"] = self._folder._data_config - else: - config["system"] = self.system - self.save_attrs_to_config( - config, ["path", "partition_cols", "metadata", "file_name"] - ) - config.update(config) - - return config - - @classmethod - def _check_for_child_configs(cls, config: dict): - """Overload by child resources to load any resources they hold internally.""" - system = config.get("system") - if isinstance(system, str) or isinstance(system, dict): - config["system"] = _get_cluster_from(system) - return config - - @property - def data(self) -> "ray.data.Dataset": - """Get the table data. If data is not already cached, return a Ray dataset. - - With the dataset object we can stream or convert to other types, for example: - - .. code-block:: python - - data.iter_batches() - data.to_pandas() - data.to_dask() - """ - if self._cached_data is not None: - return self._cached_data - return self._read_ray_dataset() - - @data.setter - def data(self, new_data): - """Update the data blob to new data""" - self._cached_data = new_data - # TODO should we save here? 
- # self.save(overwrite=True) - - @property - def system(self): - return self._folder.system - - @system.setter - def system(self, new_system): - self._folder.system = new_system - - @property - def path(self): - if self.file_name: - return f"{self._folder.path}/{self.file_name}" - return self._folder.path - - @path.setter - def path(self, new_path): - self._folder.path = new_path - - def set_metadata(self, key, val): - self.metadata[key] = val - - def get_metadata(self, key): - return self.metadata.get(key) - - @property - def fsspec_url(self): - if self.file_name: - return f"{self._folder.fsspec_url}/{self.file_name}" - return self._folder.fsspec_url - - @property - def data_config(self): - return self._folder.data_config - - @data_config.setter - def data_config(self, new_data_config): - self._folder.data_config = new_data_config - - # @classmethod - # def from_name(cls, name, dryrun=False): - # """Load existing Table via its name.""" - # config = rns_client.load_config(name=name) - # if not config: - # raise ValueError(f"Table {name} not found.") - # - # # We don't need to load the cluster dict here (if system is a cluster) because the table init - # # goes through the Folder factory method, which handles that. - # - # # Add this table's name to the resource artifact registry if part of a run - # rns_client.add_upstream_resource(name) - # - # # Uses the table subclass associated with the `resource_subtype` - # table_cls = _load_table_subclass(config=config, dryrun=dryrun) - # return table_cls.from_config(config=config, dryrun=dryrun) - - def to(self, system, path=None, data_config=None): - """Copy and return the table on the given filesystem and path. - - Example: - >>> local_table = rh.table(data, path="local/path") - >>> s3_table = local_table.to("s3") - >>> cluster_table = local_table.to(my_cluster) - """ - new_table = copy.copy(self) - new_table._folder = self._folder.to( - system=system, path=path, data_config=data_config - ) - return new_table - - def _save_sub_resources(self, folder: str = None): - if isinstance(self.system, Resource): - self.system.save(folder=folder) - - def write(self): - """Write underlying table data to fsspec URL. - - Example: - >>> rh.table(data, path="path/to/write").write() - """ - import pandas as pd - import pyarrow as pa - import ray.data - - if self._cached_data is not None: - data_to_write = self.data - - if isinstance(data_to_write, pd.DataFrame): - data_to_write = self._ray_dataset_from_pandas(data_to_write) - - if isinstance(data_to_write, pa.Table): - data_to_write = self._ray_dataset_from_arrow(data_to_write) - - if not isinstance(data_to_write, ray.data.Dataset): - raise TypeError(f"Invalid Table format {type(data_to_write)}") - - self._write_ray_dataset(data_to_write) - logger.info(f"Saved {str(self)} to: {self.fsspec_url}") - - return self - - def fetch(self, columns: Optional[list] = None) -> "pa.Table": - """Returns the complete table contents. 
- - Example: - >>> table = rh.table(data) - >>> fomratted_data = table.fetch() - """ - # https://arrow.apache.org/docs/python/generated/pyarrow.parquet.read_table.html - self._cached_data = self.read_table_from_file(columns) - if self._cached_data is not None: - return self._cached_data - - # When trying to read a file like object could lead to IsADirectoryError if the folder path is actually a - # directory and the file has been automatically generated for us inside the folder - # (ex: with pyarrow table or with partitioned data that saves multiple files within the directory) - - try: - import pyarrow.parquet as pq - - table_data = pq.read_table( - self.path, columns=columns, filesystem=self._folder.fsspec_fs - ) - - if not table_data: - raise ValueError(f"No table data found in path: {self.path}") - - self._cached_data = table_data - return self._cached_data - - except: - raise Exception(f"Failed to read table in path: {self.path}") - - def __getstate__(self): - """Override the pickle method to clear _cached_data before pickling.""" - state = self.__dict__.copy() - state["_cached_data"] = None - return state - - def __iter__(self): - for block in self.stream(batch_size=self.DEFAULT_BATCH_SIZE): - for sample in block: - yield sample - - def __len__(self): - import pandas as pd - import ray.data - - if isinstance(self.data, pd.DataFrame): - len_dataset = self.data.shape[0] - - elif isinstance(self.data, ray.data.Dataset): - len_dataset = self.data.count() - - else: - if not hasattr(self.data, "__len__") or not self.data: - raise RuntimeError("Cannot get len for dataset.") - else: - len_dataset = len(self.data) - - return len_dataset - - def __str__(self): - return self.__class__.__name__ - - def stream( - self, - batch_size: int, - drop_last: bool = False, - shuffle_seed: Optional[int] = None, - shuffle_buffer_size: Optional[int] = None, - prefetch_batches: Optional[int] = None, - ): - """Return a local batched iterator over the ray dataset. 
- - Example: - >>> table = rh.table(data) - >>> batches = table.stream(batch_size=4) - >>> for _, batch in batches: - >>> print(batch) - """ - ray_data = self.data - - if self.stream_format == "torch": - # https://docs.ray.io/en/master/data/api/doc/ray.data.Dataset.iter_torch_batches.html#ray.data.Dataset.iter_torch_batches - return ray_data.iter_torch_batches( - batch_size=batch_size, - drop_last=drop_last, - local_shuffle_buffer_size=shuffle_buffer_size, - local_shuffle_seed=shuffle_seed, - # We need to do this to handle the name change of the prefetch_batches argument in ray 2.4.0 - **{PREFETCH_KWARG: prefetch_batches or self.DEFAULT_PREFETCH_BATCHES}, - ) - - elif self.stream_format == "tf": - # https://docs.ray.io/en/master/data/api/doc/ray.data.Dataset.iter_tf_batches.html - return ray_data.iter_tf_batches( - batch_size=batch_size, - drop_last=drop_last, - local_shuffle_buffer_size=shuffle_buffer_size, - local_shuffle_seed=shuffle_seed, - # We need to do this to handle the name change of the prefetch_batches argument in ray 2.4.0 - **{PREFETCH_KWARG: prefetch_batches or self.DEFAULT_PREFETCH_BATCHES}, - ) - else: - # https://docs.ray.io/en/latest/data/api/dataset.html#ray.data.Dataset.iter_batches - return ray_data.iter_batches( - batch_size=batch_size, - batch_format=self.stream_format, - drop_last=drop_last, - local_shuffle_buffer_size=shuffle_buffer_size, - local_shuffle_seed=shuffle_seed, - # We need to do this to handle the name change of the prefetch_batches argument in ray 2.4.0 - **{PREFETCH_KWARG: prefetch_batches or self.DEFAULT_PREFETCH_BATCHES}, - ) - - def _read_ray_dataset(self, columns: Optional[List[str]] = None): - """Read parquet data as a ray dataset object.""" - # https://docs.ray.io/en/latest/data/api/input_output.html#parquet - dataset = ray.data.read_parquet( - self.fsspec_url, columns=columns, filesystem=self._folder.fsspec_fs - ) - return dataset - - def _write_ray_dataset(self, data_to_write: "ray.data.Dataset"): - """Write a ray dataset to a fsspec filesystem""" - if self.partition_cols: - # TODO [JL]: https://arrow.apache.org/docs/python/generated/pyarrow.parquet.write_to_dataset.html - logger.warning("Partitioning by column not currently supported.") - pass - - # delete existing contents or they'll just be appended to - self.rm() - - # https://docs.ray.io/en/master/data/api/doc/ray.data.Dataset.write_parquet.html - # data_to_write.repartition(os.cpu_count() * 2).write_parquet( - data_to_write.write_parquet(self.fsspec_url, filesystem=self._folder.fsspec_fs) - - @staticmethod - def _ray_dataset_from_arrow(data: "pa.Table"): - """Convert an Arrow Table to a Ray Dataset""" - import ray.data - - return ray.data.from_arrow(data) - - @staticmethod - def _ray_dataset_from_pandas(data: "pd.DataFrame"): - """Convert an Pandas DataFrame to a Ray Dataset""" - import ray.data - - return ray.data.from_pandas(data) - - def read_table_from_file(self, columns: Optional[list] = None): - """Read a table from it's path. - - Example: - >>> table = rh.table(path="path/to/table") - >>> table_data = table.read_table_from_file() - """ - try: - import pyarrow.parquet as pq - - with fsspec.open(self.fsspec_url, mode="rb", **self.data_config) as t: - table_data = pq.read_table(t.full_name, columns=columns) - return table_data - except: - return None - - def rm(self, recursive: bool = True): - """Delete table, including its partitioned files where relevant. 
- - Example: - >>> table = rh.table(path="path/to/table") - >>> table.rm() - """ - # https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.rm - self._folder.rm(recursive=recursive) - - def exists_in_system(self): - """Whether the table exists in file system. - - Example: - >>> table.exists_in_system() - """ - return ( - self._folder.exists_in_system() - and len(self._folder.ls(self.fsspec_url)) >= 1 - ) - - -def _load_table_subclass(config: dict, dryrun: bool, data=None): - """Load the relevant Table subclass based on the config or data type provided""" - resource_subtype = config.get("resource_subtype", Table.__name__) - - try: - import datasets - - if resource_subtype == "HuggingFaceTable" or isinstance(data, datasets.Dataset): - from .huggingface_table import HuggingFaceTable - - return HuggingFaceTable(**config, dryrun=dryrun) - except ModuleNotFoundError: - pass - except Exception as e: - raise e - - try: - import pandas as pd - - if resource_subtype == "PandasTable" or isinstance(data, pd.DataFrame): - from .pandas_table import PandasTable - - return PandasTable(**config, dryrun=dryrun) - except ModuleNotFoundError: - pass - except Exception as e: - raise e - - try: - import dask.dataframe as dd - - if resource_subtype == "DaskTable" or isinstance(data, dd.DataFrame): - from .dask_table import DaskTable - - return DaskTable(**config, dryrun=dryrun) - except ModuleNotFoundError: - pass - except Exception as e: - raise e - - try: - import ray.data - - if resource_subtype == "RayTable" or isinstance(data, ray.data.Dataset): - from .ray_table import RayTable - - return RayTable(**config, dryrun=dryrun) - except ModuleNotFoundError: - pass - except Exception as e: - raise e - - try: - import cudf - - if resource_subtype == "CudfTable" or isinstance(data, cudf.DataFrame): - raise NotImplementedError("Cudf not currently supported") - except ModuleNotFoundError: - pass - except Exception as e: - raise e - - try: - import pyarrow as pa - - if resource_subtype == "Table" or isinstance(data, pa.Table): - new_table = Table(**config, dryrun=dryrun) - return new_table - except ModuleNotFoundError: - pass - except Exception as e: - raise e - - raise TypeError( - f"Unsupported data type {type(data)} for Table construction. " - f"For converting data to pyarrow see: " - f"https://arrow.apache.org/docs/7.0/python/generated/pyarrow.Table.html" - ) - - -def table( - data=None, - name: Optional[str] = None, - path: Optional[str] = None, - system: Optional[str] = None, - data_config: Optional[dict] = None, - partition_cols: Optional[list] = None, - mkdir: bool = False, - dryrun: bool = False, - stream_format: Optional[str] = None, - metadata: Optional[dict] = None, -) -> Table: - """Constructs a Table object, which can be used to interact with the table at the given path. - - Args: - data: Data to be stored in the table. - name (Optional[str]): Name for the table, to reuse it later on. - path (Optional[str]): Full path to the data file. - system (Optional[str]): File system. Currently this must be one of: - [``file``, ``github``, ``sftp``, ``ssh``, ``s3``, ``gs``, ``azure``]. - data_config (Optional[dict]): The data config to pass to the underlying fsspec handler. - partition_cols (Optional[list]): List of columns to partition the table by. - mkdir (bool): Whether to create a remote folder for the table. (Default: ``False``) - dryrun (bool): Whether to create the Table if it doesn't exist, or load a Table object as a dryrun. 
- (Default: ``False``) - stream_format (Optional[str]): Format to stream the Table as. - Currently this must be one of: [``pyarrow``, ``torch``, ``tf``, ``pandas``] - metadata (Optional[dict]): Metadata to store for the table. - - Returns: - Table: The resulting Table object. - - Example: - >>> import runhouse as rh - >>> # Create and save (pandas) table - >>> rh.table( - >>> data=data, - >>> name="~/my_test_pandas_table", - >>> path="table_tests/test_pandas_table.parquet", - >>> system="file", - >>> mkdir=True, - >>> ).save() - >>> - >>> # Load table from above - >>> reloaded_table = rh.table(name="~/my_test_pandas_table") - """ - if name and not any( - [ - data is not None, - path, - system, - data_config, - partition_cols, - stream_format, - metadata, - ] - ): - # Try reloading existing table - return Table.from_name(name, dryrun) - - system = _get_cluster_from( - system or _current_cluster(key="config") or Folder.DEFAULT_FS, dryrun=dryrun - ) - - file_name = None - if path: - # Extract the file name from the path if provided - full_path = Path(path) - file_suffix = full_path.suffix - if file_suffix: - path = str(full_path.parent) - file_name = full_path.name - - if path is None: - # If no path is provided we need to create one based on the name of the table - table_name_in_path = rns_client.resolve_rns_data_resource_name(name) - if system == rns_client.DEFAULT_FS or ( - isinstance(system, Resource) and system.on_this_cluster() - ): - # create random path to store in .cache folder of local filesystem - path = str( - Path( - f"~/{Table.DEFAULT_CACHE_FOLDER}/{table_name_in_path}" - ).expanduser() - ) - else: - # save to the default bucket - path = f"{Table.DEFAULT_FOLDER_PATH}/{table_name_in_path}" - - config = { - "system": system, - "name": name, - "path": path, - "file_name": file_name, - "data_config": data_config, - "partition_cols": partition_cols, - "stream_format": stream_format, - "metadata": metadata, - } - - new_table = _load_table_subclass(config=config, dryrun=dryrun, data=data) - if data is not None: - new_table.data = data - - return new_table diff --git a/tests/conftest.py b/tests/conftest.py index 52f87641a..6fc76ac38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -350,19 +350,6 @@ def event_loop(): summer_func_with_auth, # noqa: F401 ) -# ----------------- Tables ----------------- - -from tests.test_resources.test_modules.test_tables.conftest import ( - arrow_table, # noqa: F401 - cudf_table, # noqa: F401 - dask_table, # noqa: F401 - huggingface_table, # noqa: F401 - pandas_table, # noqa: F401 - ray_table, # noqa: F401 - table, # noqa: F401 -) - - ########## DEFAULT LEVELS ########## default_fixtures = {} diff --git a/tests/test_resources/test_clusters/test_cluster.py b/tests/test_resources/test_clusters/test_cluster.py index 8ab6d02d8..f3a551819 100644 --- a/tests/test_resources/test_clusters/test_cluster.py +++ b/tests/test_resources/test_clusters/test_cluster.py @@ -2,7 +2,6 @@ import time from threading import Thread -import pandas as pd import pytest import requests @@ -41,22 +40,6 @@ def load_shared_resource_config(resource_class_name, address): return loaded_resource.config() -def save_resource_and_return_config(): - df = pd.DataFrame( - {"id": [1, 2, 3, 4, 5, 6], "grade": ["a", "b", "b", "a", "a", "e"]} - ) - table = rh.table(df, name="test_table") - return table.config() - - -def test_table_to_rh_here(): - df = pd.DataFrame( - {"id": [1, 2, 3, 4, 5, 6], "grade": ["a", "b", "b", "a", "a", "e"]} - ) - rh.table(df, name="test_table").to(rh.here) - assert 
rh.here.get("test_table") is not None - - def summer(a: int, b: int): return a + b @@ -177,22 +160,6 @@ def test_cluster_factory_and_properties(self, cluster): if "ssl_certfile" in args: assert cluster.cert_config.cert_path == args["ssl_certfile"] - @pytest.mark.level("local") - @pytest.mark.clustertest - def test_docker_cluster_fixture_is_logged_out(self, docker_cluster_pk_ssh_no_auth): - save_resource_and_return_config_cluster = rh.function( - save_resource_and_return_config, - name="save_resource_and_return_config_cluster", - ).to( - system=docker_cluster_pk_ssh_no_auth, - ) - saved_config_on_cluster = save_resource_and_return_config_cluster() - # This cluster was created without any logged in Runhouse config. Make sure that the simple resource - # created on the cluster starts with "~", which is the prefix that local Runhouse configs are saved with. - assert ("/" not in saved_config_on_cluster["name"]) or ( - saved_config_on_cluster["name"].startswith("~") - ) - @pytest.mark.level("local") @pytest.mark.clustertest def test_cluster_recreate(self, cluster): @@ -292,15 +259,6 @@ def test_cluster_delete_env(self, cluster): assert cluster.get(env1.name) assert cluster.get("k1") - @pytest.mark.level("local") - @pytest.mark.clustertest - @pytest.mark.skip(reason="TODO") - def test_rh_here_objects(self, cluster): - save_test_table_remote = rh.function(test_table_to_rh_here, system=cluster) - save_test_table_remote() - assert "test_table" in cluster.keys() - assert isinstance(cluster.get("test_table"), rh.Table) - @pytest.mark.level("local") @pytest.mark.clustertest def test_condensed_config_for_cluster(self, cluster): diff --git a/tests/test_resources/test_modules/test_functions/test_mapper.py b/tests/test_resources/test_modules/test_functions/test_mapper.py deleted file mode 100644 index ce3c1b7ad..000000000 --- a/tests/test_resources/test_modules/test_functions/test_mapper.py +++ /dev/null @@ -1,199 +0,0 @@ -import multiprocessing -import os - -import pytest - -import runhouse as rh - -from tests.utils import get_pid_and_ray_node - -REMOTE_FUNC_NAME = "@/remote_function" - - -def summer(a, b): - return a + b - - -async def async_summer(a, b): - return a + b - - -def np_array(list): - import numpy as np - - return np.array(list) - - -def np_summer(a, b): - import numpy as np - - print(f"Summing {a} and {b}") - return int(np.array([a, b]).sum()) - - -def multiproc_np_sum(inputs): - print(f"CPUs: {os.cpu_count()}") - # See https://pythonspeed.com/articles/python-multiprocessing/ - # and https://github.com/pytorch/pytorch/issues/3492 - with multiprocessing.get_context("spawn").Pool() as P: - return P.starmap(np_summer, inputs) - - -def getpid(a=0): - return os.getpid() + a - - -def sleep_and_return(secs): - # Return the start and end time so we can ensure that the calls are non-blocking - import time - - start = time.time() - time.sleep(secs) - return start, time.time() - - -def throw_exception(a): - raise Exception("mapper exception") - - -class TestMapper: - - """Testing strategy: - - 1. Explicit mapper - for each of the below, test map, starmap, call (round robin), and test that - threaded calls are non-blocking: - 1. local mapper with local functions (rh.mapper(my_fn, processes=8)) - 2. local mapper with remote functions (rh.mapper(rh.fn(my_fn).to(cluster), processes=8)) - 3. remote mapper with local functions (rh.mapper(my_fn, processes=8).to(cluster)) - 4. local mapper with additional replicas added manually (rh.mapper(fn, processes=8).add_replicas(my_replicas) - 2. Implicit mapper - 1. 
fn.map(), fn.starmap(), fn.call() - - """ - - @pytest.mark.level("local") - def test_local_mapper_remote_function(self, cluster): - # Test .map() - num_replicas = 3 - pid_fn = rh.function(getpid).to(cluster) - mapper = rh.mapper(pid_fn, replicas=num_replicas) - assert len(mapper.replicas) == num_replicas - for i in range(num_replicas): - assert mapper.replicas[i].system == cluster - assert mapper.replicas[0].env.name == pid_fn.env.name - assert mapper.replicas[1].env.name == pid_fn.env.name + "_replica_0" - assert mapper.replicas[2].env.name == pid_fn.env.name + "_replica_1" - pids = mapper.map([0] * 10) - assert len(pids) == 10 - assert len(set(pids)) == num_replicas - - # Test .starmap() and reusing the envs - summer_fn = rh.function(summer).to(cluster) - sum_mapper = rh.mapper(summer_fn, replicas=num_replicas) - assert len(sum_mapper.replicas) == num_replicas - for i in range(num_replicas): - assert sum_mapper.replicas[i].system == cluster - assert sum_mapper.replicas[0].env.name == summer_fn.env.name - assert sum_mapper.replicas[1].env.name == summer_fn.env.name + "_replica_0" - assert sum_mapper.replicas[2].env.name == summer_fn.env.name + "_replica_1" - res = sum_mapper.starmap([[1, 2]] * 10) - assert res == [3] * 10 - res = sum_mapper.map([1] * 10, [2] * 10) - assert res == [3] * 10 - - # Doing this down here to confirm that first mapper using the envs isn't corrupted - pids = mapper.starmap([[0]] * 10) - assert len(pids) == 10 - assert len(set(pids)) == num_replicas - - # Test call - assert len(set(mapper.call() for _ in range(4))) == 3 - - @pytest.mark.level("local") - def test_remote_mapper_remote_function(self, cluster): - # Test that calls are non-blocking, and sending the mapper to the cluster - # Also tests passing function directly into mapper without sending it to the cluster first - sleep_mapper = rh.mapper(sleep_and_return).to(cluster) - sleep_mapper.add_replicas(5) - start_end_times = sleep_mapper.map([1] * 5) - assert len(start_end_times) == 5 - assert all(isinstance(t, tuple) and len(t) == 2 for t in start_end_times) - # Ensure that the calls are non-blocking by checking that each end time - # is greater than the start time before it - for i in range(1, len(start_end_times)): - # Assert this one started before the last one ended - assert start_end_times[i][0] < start_end_times[i - 1][1] - - last_end_time = max([end for (_, end) in start_end_times]) - earliest_start_time = min([start for (start, _) in start_end_times]) - assert last_end_time - earliest_start_time < 2 - - @pytest.mark.level("release") - def test_local_multinode_map(self, multinode_cpu_cluster): - num_replicas = 6 - env = rh.env(name="test_env", reqs=["pytest"]) - pid_fn = rh.function(get_pid_and_ray_node).to(multinode_cpu_cluster, env=env) - mapper = rh.mapper(pid_fn, replicas=num_replicas) - assert len(mapper.replicas) == num_replicas - for i in range(num_replicas): - assert mapper.replicas[i].system == multinode_cpu_cluster - ids = mapper.map([0] * 100) - pids, nodes = zip(*ids) - assert len(pids) == 100 - assert len(set(pids)) == num_replicas - # TODO: rohinb2: Fix this to actually schedule on individual nodes - # assert len(set(nodes)) == 2 - assert len(set(node for (_, node) in [mapper.call() for _ in range(10)])) == 2 - - @pytest.mark.level("release") - def test_remote_multinode_map(self, multinode_cpu_cluster): - # Test that calls are non-blocking, and sending the mapper to the cluster - env = rh.env(name="new_env", reqs=["pytest"]) - sleep_fn = 
rh.function(sleep_and_return).to(multinode_cpu_cluster, env=env) - sleep_mapper = rh.mapper(sleep_fn, concurrency=2).to(multinode_cpu_cluster) - sleep_mapper.add_replicas(5) - start_end_times = sleep_mapper.map([1] * 10) - assert len(start_end_times) == 10 - assert all(isinstance(t, tuple) and len(t) == 2 for t in start_end_times) - # Ensure that the calls are non-blocking by checking that each end time - # is greater than the start time before it - for i in range(1, len(start_end_times)): - # Assert this one started before the last one ended - assert start_end_times[i][0] < start_end_times[i - 1][1] - - last_end_time = max([end for (_, end) in start_end_times]) - earliest_start_time = min([start for (start, _) in start_end_times]) - assert last_end_time - earliest_start_time < 2 - - @pytest.mark.skip - @pytest.mark.level("local") - def test_maps(self, cluster): - pid_fn = rh.function(getpid, system=cluster) - num_pids = [1] * 10 - pids = pid_fn.map(num_pids) - assert len(set(pids)) > 1 - assert all(pid > 0 for pid in pids) - - pids = pid_fn.repeat(num_repeats=10) - assert len(set(pids)) > 1 - assert all(pid > 0 for pid in pids) - - pids = [pid_fn.enqueue() for _ in range(10)] - assert len(pids) == 10 - assert all(pid > 0 for pid in pids) - - re_fn = rh.function(summer, system=cluster) - summands = list(zip(range(5), range(4, 9))) - res = re_fn.starmap(summands) - assert res == [4, 6, 8, 10, 12] - - alist, blist = range(5), range(4, 9) - res = re_fn.map(alist, blist) - assert res == [4, 6, 8, 10, 12] - - @pytest.mark.level("local") - def test_throws_exception(self, cluster): - remote_exception = rh.function(throw_exception).to(system=cluster) - mapper = rh.mapper(remote_exception, replicas=2) - results = mapper.map([None, None]) - assert [isinstance(res, Exception) for res in results] diff --git a/tests/test_resources/test_modules/test_tables/__init__.py b/tests/test_resources/test_modules/test_tables/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_resources/test_modules/test_tables/conftest.py b/tests/test_resources/test_modules/test_tables/conftest.py deleted file mode 100644 index 077117666..000000000 --- a/tests/test_resources/test_modules/test_tables/conftest.py +++ /dev/null @@ -1,65 +0,0 @@ -import pandas as pd -import pytest - - -@pytest.fixture(scope="session") -def table(request): - return request.getfixturevalue(request.param) - - -@pytest.fixture -def huggingface_table(): - from datasets import load_dataset - - dataset = load_dataset("yelp_review_full", split="train[:1%]") - return dataset - - -@pytest.fixture -def arrow_table(): - import pyarrow as pa - - df = pd.DataFrame( - { - "int": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - "str": ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"], - } - ) - arrow_table = pa.Table.from_pandas(df) - return arrow_table - - -@pytest.fixture -def cudf_table(): - import cudf - - gdf = cudf.DataFrame( - {"id": [1, 2, 3, 4, 5, 6], "grade": ["a", "b", "b", "a", "a", "e"]} - ) - return gdf - - -@pytest.fixture -def pandas_table(): - df = pd.DataFrame( - {"id": [1, 2, 3, 4, 5, 6], "grade": ["a", "b", "b", "a", "a", "e"]} - ) - return df - - -@pytest.fixture -def dask_table(): - import dask.dataframe as dd - - index = pd.date_range("2021-09-01", periods=2400, freq="1H") - df = pd.DataFrame({"a": range(2400), "b": list("abcaddbe" * 300)}, index=index) - ddf = dd.from_pandas(df, npartitions=10) - return ddf - - -@pytest.fixture -def ray_table(): - import ray - - ds = ray.data.range(10000) - return ds diff --git 
a/tests/test_resources/test_modules/test_tables/table_tests.py b/tests/test_resources/test_modules/test_tables/table_tests.py deleted file mode 100644 index 074cae831..000000000 --- a/tests/test_resources/test_modules/test_tables/table_tests.py +++ /dev/null @@ -1,919 +0,0 @@ -import shutil - -import pandas as pd -import pyarrow as pa -import ray.data -import runhouse as rh - -from runhouse import Folder - -NUM_PARTITIONS = 10 - - -# TODO top to bottom update. Named "table_tests" so it's skipped until we add a proper test class and suite - - -def delete_local_folder(path): - shutil.rmtree(path) - - -def tokenize_function(examples): - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") - return tokenizer(examples["text"], padding="max_length", truncation=True) - - -# ----------------------------------------------- -# ----------------- Local tests ----------------- -# ----------------------------------------------- -def test_create_and_reload_file_locally(tmp_path): - orig_data = pd.DataFrame({"my_col": list(range(50))}) - name = "~/my_local_test_table" - - my_table = ( - rh.table( - data=orig_data, - name=name, - path=str(tmp_path), - system="file", - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.data - - assert reloaded_data.to_pandas().equals(orig_data) - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert isinstance(batch, pd.DataFrame) - assert batch["my_col"].tolist() == list(range(idx * 10, (idx + 1) * 10)) - - del orig_data - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_pandas_locally(pandas_table, tmp_path): - name = "~/my_test_local_pandas_table" - - my_table = ( - rh.table( - data=pandas_table, - path=str(tmp_path), - name=name, - system="file", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.data - - assert pandas_table.equals(reloaded_data.to_pandas()) - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert isinstance(batch, pd.DataFrame) - assert batch["id"].tolist() == list(range(1, 7)) - - del pandas_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_pyarrow_locally(arrow_table, tmp_path): - name = "~/my_test_local_pyarrow_table" - - my_table = ( - rh.table( - data=arrow_table, - name=name, - path=str(tmp_path), - system="file", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.data - - assert arrow_table.to_pandas().equals(reloaded_data.to_pandas()) - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert isinstance(batch, pa.Table) - assert batch["int"].to_pylist() == list(range(1, 11)) - - del arrow_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_ray_locally(ray_table, tmp_path): - name = "~/my_test_local_ray_table" - - my_table = ( - rh.table( - data=ray_table, - path=str(tmp_path), - name=name, - system="file", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset 
= reloaded_table.data - - assert ray_table.to_pandas().equals(reloaded_data.to_pandas()) - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert isinstance(batch, pa.Table) - # NOTE [DG] 2021-08-10: This will generally fail because ray automatically partitions the data into - # blocks, and order is not necessarily preserved when reading the data back in. Ideally we fix this - # when we switch to in-memory tables. - # assert batch["value"].to_pylist() == list(range(idx * 10, (idx + 1) * 10)) - - if idx in [0, 10, 33]: - # Some random batches to check - assert [isinstance(val, int) for val in batch["value"].to_pylist()] - - del ray_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_huggingface_locally(huggingface_table, tmp_path): - name = "~/my_test_local_huggingface_table" - - my_table = ( - rh.table( - data=huggingface_table, - name=name, - path=str(tmp_path), - system="file", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.data - - assert huggingface_table.to_pandas().equals(reloaded_data.to_pandas()) - - batches = reloaded_table.stream(batch_size=10, as_dict=False) - for idx, batch in enumerate(batches): - assert batch.column_names == ["label", "text"] - assert batch.shape == (10, 2) - - del huggingface_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_dask_locally(dask_table, tmp_path): - name = "~/my_test_local_dask_table" - - my_table = ( - rh.table( - data=dask_table, - name=name, - path=str(tmp_path), - system="file", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: "dask.dataframe.core.DataFrame" = reloaded_table.data.to_dask() - assert reloaded_data.columns.to_list() == ["a", "b"] - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert batch.column_names == ["a", "b"] - assert batch.shape == (10, 2) - - del dask_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -# -------------------------------------------- -# ----------------- S3 tests ----------------- -# -------------------------------------------- - - -def test_create_and_reload_pyarrow_data_from_s3(arrow_table, table_s3_bucket): - name = "@/my_test_pyarrow_table" - - my_table = ( - rh.table( - data=arrow_table, - name=name, - path=f"/{table_s3_bucket}/pyarrow_df", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.data - assert reloaded_data.to_pandas().equals(arrow_table.to_pandas()) - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert batch.column_names == ["int", "str"] - - del arrow_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_pandas_data_from_s3(pandas_table, table_s3_bucket): - name = "@/my_test_pandas_table" - - my_table = ( - rh.table( - data=pandas_table, - name=name, - path=f"/{table_s3_bucket}/pandas_df", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = 
reloaded_table.data - assert pandas_table.equals(reloaded_data.to_pandas()) - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert isinstance(batch, pd.DataFrame) - assert batch["id"].tolist() == list(range(1, 7)) - - del pandas_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_huggingface_data_from_s3(huggingface_table, table_s3_bucket): - name = "@/my_test_hf_table" - - my_table = ( - rh.table( - data=huggingface_table, - name=name, - path=f"/{table_s3_bucket}/huggingface_data", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - - # Stream in as huggingface dataset - batches = reloaded_table.stream(batch_size=10, as_dict=False) - for idx, batch in enumerate(batches): - assert batch.column_names == ["label", "text"] - assert batch.shape == (10, 2) - - del huggingface_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_dask_data_from_s3(dask_table, table_s3_bucket): - name = "@/my_test_dask_table" - - my_table = ( - rh.table( - data=dask_table, - name=name, - path=f"/{table_s3_bucket}/dask", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: "dask.dataframe.core.DataFrame" = reloaded_table.data.to_dask() - assert reloaded_data.columns.to_list() == ["a", "b"] - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert batch.column_names == ["a", "b"] - assert batch.shape == (10, 2) - - del dask_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_ray_data_from_s3(ray_table, table_s3_bucket): - name = "@/my_test_ray_table" - - my_table = ( - rh.table( - data=ray_table, - name=name, - path=f"/{table_s3_bucket}/ray_data", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.data - assert reloaded_data.to_pandas().equals(ray_table.to_pandas()) - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert isinstance(batch, pa.Table) - if idx in [0, 10, 33]: - # Some random batches to check - assert [isinstance(val, int) for val in batch["value"].to_pylist()] - - del ray_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -# ----------------- Iter ----------------- -def test_load_pandas_data_as_iter(pandas_table, table_s3_bucket): - name = "@/my_test_pandas_table" - - my_table = ( - rh.table( - data=pandas_table, - name=name, - path=f"/{table_s3_bucket}/pandas", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = next(iter(reloaded_table)) - - assert isinstance(reloaded_data, pd.Series) - assert reloaded_data.to_dict() == {"id": 1, "grade": "a"} - - del pandas_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_load_pyarrow_data_as_iter(arrow_table, table_s3_bucket): - name = "@/my_test_pyarrow_table" - - my_table = ( - rh.table( - data=arrow_table, - name=name, - 
path=f"/{table_s3_bucket}/pyarrow-data", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: pa.ChunkedArray = next(iter(reloaded_table)) - - assert isinstance(reloaded_data, pa.ChunkedArray) - assert reloaded_data.to_pylist() == list(range(1, 11)) - - del arrow_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_load_huggingface_data_as_iter(huggingface_table, table_s3_bucket): - name = "@/my_test_huggingface_table" - - my_table = ( - rh.table( - data=huggingface_table, - name=name, - path=f"/{table_s3_bucket}/huggingface-dataset", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: pa.ChunkedArray = next(iter(reloaded_table)) - assert isinstance(reloaded_data, pa.ChunkedArray) - - del huggingface_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -# ----------------- Shuffling ----------------- -def test_shuffling_pyarrow_data_from_s3(arrow_table, table_s3_bucket): - name = "@/my_test_shuffled_pyarrow_table" - - my_table = ( - rh.table( - data=arrow_table, - name=name, - path=f"/{table_s3_bucket}/pyarrow", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - batches = reloaded_table.stream( - batch_size=10, shuffle_seed=42, shuffle_buffer_size=10 - ) - for idx, batch in enumerate(batches): - assert isinstance(batch, pa.Table) - assert arrow_table.columns[0].to_pylist() != batch.columns[0].to_pylist() - assert arrow_table.columns[1].to_pylist() != batch.columns[1].to_pylist() - - del arrow_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -# ------------------------------------------------- -# ----------------- Cluster tests ----------------- -# ------------------------------------------------- - - -def test_create_and_reload_pandas_data_from_cluster(pandas_table, cluster): - # Make sure the destination folder for the data exists on the cluster - data_path_on_cluster = f"{Folder.DEFAULT_CACHE_FOLDER}/pandas-data" - cluster.run([f"mkdir -p {data_path_on_cluster}"]) - - name = "@/my_test_pandas_table" - my_table = ( - rh.table( - data=pandas_table, - name=name, - path=data_path_on_cluster, - system=cluster, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - - reloaded_data: ray.data.Dataset = reloaded_table.data - assert pandas_table.equals(reloaded_data.to_pandas()) - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert isinstance(batch, pd.DataFrame) - assert batch["id"].tolist() == list(range(1, 7)) - - del pandas_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_ray_data_from_cluster(ray_table, cluster): - data_path_on_cluster = f"{Folder.DEFAULT_CACHE_FOLDER}/ray-data" - cluster.run([f"mkdir -p {data_path_on_cluster}"]) - - name = "@/my_test_ray_cluster_table" - - my_table = ( - rh.table( - data=ray_table, - name=name, - path=data_path_on_cluster, - system=cluster, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.data - assert ray_table.to_pandas().equals(reloaded_data.to_pandas()) - - batches = 
reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert isinstance(batch, pa.Table) - if idx in [0, 10, 33]: - # Some random batches to check - assert [isinstance(val, int) for val in batch["value"].to_pylist()] - - del ray_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_pyarrow_data_from_cluster(arrow_table, cluster): - data_path_on_cluster = f"{Folder.DEFAULT_CACHE_FOLDER}/pyarrow-data" - cluster.run([f"mkdir -p {data_path_on_cluster}"]) - - name = "@/my_test_pyarrow_cluster_table" - - my_table = ( - rh.table( - data=arrow_table, - name=name, - path=data_path_on_cluster, - system=cluster, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.data - assert arrow_table.to_pandas().equals(reloaded_data.to_pandas()) - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert batch.column_names == ["int", "str"] - - del arrow_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_huggingface_data_from_cluster(huggingface_table, cluster): - data_path_on_cluster = f"{Folder.DEFAULT_CACHE_FOLDER}/hf-data" - cluster.run([f"mkdir -p {data_path_on_cluster}"]) - - name = "@/my_test_hf_cluster_table" - - my_table = ( - rh.table( - data=huggingface_table, - name=name, - path=data_path_on_cluster, - system=cluster, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.data - assert huggingface_table.to_pandas().equals(reloaded_data.to_pandas()) - - # Stream in as huggingface dataset - batches = reloaded_table.stream(batch_size=10, as_dict=False) - for idx, batch in enumerate(batches): - assert batch.column_names == ["label", "text"] - assert batch.shape == (10, 2) - - del huggingface_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_reload_dask_data_from_cluster(dask_table, cluster): - data_path_on_cluster = f"{Folder.DEFAULT_CACHE_FOLDER}/dask-data" - cluster.run([f"mkdir -p {data_path_on_cluster}"]) - - name = "@/my_test_dask_cluster_table" - - my_table = ( - rh.table( - data=dask_table, - name=name, - path=data_path_on_cluster, - system=cluster, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: "dask.dataframe.core.DataFrame" = reloaded_table.data.to_dask() - assert reloaded_data.columns.to_list() == ["a", "b"] - - batches = reloaded_table.stream(batch_size=10) - for idx, batch in enumerate(batches): - assert batch.column_names == ["a", "b"] - assert batch.shape == (10, 2) - - del dask_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_to_cluster_attr(pandas_table, cluster, tmp_path): - local_path = tmp_path / "table_tests/local_test_table" - name = "~/my_local_test_table" - - my_table = ( - rh.table( - data=pandas_table, - name=name, - path=str(local_path), - system="file", - ) - .write() - .save() - ) - - cluster_table = my_table.to(system=cluster) - - assert isinstance(cluster_table.system, rh.Cluster) - assert cluster_table._folder._fs_str == "ssh" - - data = cluster_table.data - assert data.to_pandas().equals(pandas_table) - - del pandas_table - del 
my_table - - cluster_table.delete_configs() - cluster_table.rm() - assert not cluster_table.exists_in_system() - - -# ------------------------------------------------- -# ----------------- Fetching tests ----------------- -# ------------------------------------------------- -def test_create_and_fetch_pyarrow_data_from_s3(arrow_table, table_s3_bucket): - name = "@/my_test_fetch_pyarrow_table" - - my_table = ( - rh.table( - data=arrow_table, - name=name, - path=f"/{table_s3_bucket}/pyarrow", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: pa.Table = reloaded_table.fetch() - assert arrow_table == reloaded_data - - del arrow_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_fetch_pandas_data_from_s3(pandas_table, table_s3_bucket): - name = "@/my_test_fetch_pandas_table" - - my_table = ( - rh.table( - data=pandas_table, - name=name, - path=f"/{table_s3_bucket}/pandas", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: pd.DataFrame = reloaded_table.fetch() - assert pandas_table.equals(reloaded_data) - - del pandas_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_fetch_huggingface_data_from_s3(huggingface_table, table_s3_bucket): - name = "@/my_test_fetch_huggingface_table" - - my_table = ( - rh.table( - data=huggingface_table, - name=name, - path=f"/{table_s3_bucket}/huggingface", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data = reloaded_table.fetch() - assert huggingface_table.shape == reloaded_data.shape - - del huggingface_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_fetch_ray_data_from_s3(ray_table, table_s3_bucket): - name = "@/my_test_fetch_ray_table" - - my_table = ( - rh.table( - data=ray_table, - name=name, - path=f"/{table_s3_bucket}/ray", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: ray.data.Dataset = reloaded_table.fetch() - assert ray_table.to_pandas().equals(reloaded_data.to_pandas()) - - del ray_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -def test_create_and_fetch_dask_data_from_s3(dask_table, table_s3_bucket): - name = "@/my_test_fetch_dask_table" - - my_table = ( - rh.table( - data=dask_table, - name=name, - path=f"/{table_s3_bucket}/dask", - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - reloaded_table = rh.table(name=name) - reloaded_data: "dask.dataframe.core.DataFrame" = reloaded_table.fetch() - assert dask_table.npartitions == reloaded_data.npartitions - - del dask_table - del my_table - - reloaded_table.delete_configs() - - reloaded_table.rm() - assert not reloaded_table.exists_in_system() - - -# ------------------------------------------------- -# ----------------- Table Sharing tests ----------------- -# ------------------------------------------------- -def test_sharing_table(pandas_table): - name = "shared_pandas_table" - - my_table = ( - rh.table( - data=pandas_table, - name=name, - system="s3", - mkdir=True, - ) - .write() - .save() - ) - - my_table.share( - users=["donny@run.house", 
"josh@run.house"], - access_level="write", - notify_users=False, - ) - - assert my_table.exists_in_system() - - -def test_read_shared_table(): - my_table = rh.table(name="@/shared_pandas_table") - df: pd.DataFrame = my_table.fetch() - assert not df.empty