diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py
index d49f7760d7..b2bb8a7631 100644
--- a/databricks/koalas/frame.py
+++ b/databricks/koalas/frame.py
@@ -1509,6 +1509,45 @@ def to_koalas(self):
         else:
             return DataFrame(self)
 
+    def cache(self):
+        """
+        Yields and caches the current DataFrame.
+
+        The Koalas DataFrame is yielded as a protected resource and its corresponding
+        Spark DataFrame is cached; it gets uncached after execution goes out of the context.
+
+        Examples
+        --------
+        >>> df = ks.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)],
+        ...                   columns=['dogs', 'cats'])
+        >>> df
+           dogs  cats
+        0   0.2   0.3
+        1   0.0   0.6
+        2   0.6   0.0
+        3   0.2   0.1
+
+        >>> with df.cache() as cached_df:
+        ...     print(cached_df.count())
+        ...
+        dogs    4
+        cats    4
+        dtype: int64
+
+        >>> df = df.cache()
+        >>> df.to_pandas().mean(axis=1)
+        0    0.25
+        1    0.30
+        2    0.30
+        3    0.15
+        dtype: float64
+
+        To uncache the dataframe, use the `unpersist` function.
+
+        >>> df.unpersist()
+        """
+        return _CachedDataFrame(self._sdf, self._metadata)
+
     def to_table(self, name: str, format: Optional[str] = None, mode: str = 'error',
                  partition_cols: Union[str, List[str], None] = None,
                  **options):
@@ -3603,3 +3642,37 @@ def _reduce_spark_multi(sdf, aggs):
     l2 = list(row)
     assert len(l2) == len(aggs), (row, l2)
     return l2
+
+
+class _CachedDataFrame(DataFrame):
+    """
+    Cached Koalas DataFrame, which corresponds to a pandas DataFrame logically, but internally
+    it caches the corresponding Spark DataFrame.
+    """
+    def __init__(self, sdf, metadata):
+        self._cached = sdf.cache()  # handle that unpersist() later uses to uncache
+        super(_CachedDataFrame, self).__init__(self._cached, index=metadata)
+
+    def __enter__(self):
+        return self  # the cached frame itself is bound by `with ... as`
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        self.unpersist()  # uncache on leaving the `with` block
+
+    def unpersist(self):
+        """
+        The `unpersist` function is used to uncache the Koalas DataFrame when it
+        is not used within a `with` statement.
+
+        Examples
+        --------
+        >>> df = ks.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)],
+        ...                   columns=['dogs', 'cats'])
+        >>> df = df.cache()
+
+        To uncache the dataframe, use the `unpersist` function.
+
+        >>> df.unpersist()
+        """
+        if self._cached.is_cached:  # only call unpersist() while still marked as cached
+            self._cached.unpersist()