Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add doc decorator #539

Merged
merged 2 commits into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions tests/model/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from trackintel.model.util import (
NonCachedAccessor,
_copy_docstring,
doc,
_register_trackintel_accessor,
_wrapped_gdf_method,
TrackintelGeoDataFrame,
Expand Down Expand Up @@ -275,3 +276,57 @@ def foo(val):
# remove accessor again to make tests independent
pd.DataFrame._accesors = pd.DataFrame._accessors.remove("foo")
del pd.DataFrame.foo


class TestDoc:
"""Test doc decorator"""

def test_default_docstring(self):
"""Test that default docstring is kept."""

def foo():
pass

default = "I am a docstring"
foo.__doc__ = default
foo = doc()(foo)
assert foo.__doc__ == default

def test_None(self):
"""Test that None in args create no docstring"""

def foo():
pass

foo = doc(None)(foo)
assert foo.__doc__ == ""

def test_docstring_component(self):
"""Test that docstring component is formatable"""
d = "this is a {adjective} function"

def foo():
pass

foo._docstring_components = [d]
foo = doc(foo, adjective="cool")(foo)
assert foo.__doc__ == d.format(adjective="cool")

def test_string(self):
"""Test if string can be supplied"""
d = "this is a {adjective} function"

def foo():
pass

foo = doc(d, adjective="cool")(foo)
assert foo.__doc__ == d.format(adjective="cool")

def test_fall_through_case(self):
"""Test fall through case in for loop"""

def foo():
pass

foo = doc(foo)(foo)
assert foo.__doc__ == ""
59 changes: 7 additions & 52 deletions trackintel/io/postgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import trackintel as ti
from trackintel.io.util import _index_warning_default_none
from trackintel.model.util import doc, _shared_docs
hongyeehh marked this conversation as resolved.
Show resolved Hide resolved


def _handle_con_string(func):
Expand Down Expand Up @@ -126,6 +127,7 @@ def read_positionfixes_postgis(
return ti.io.read_positionfixes_gpd(pfs, **(read_gpd_kws or {}))


@doc(_shared_docs["write_postgis"], long="positionfixes", short="pfs")
@_handle_con_string
def write_positionfixes_postgis(
positionfixes, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
Expand Down Expand Up @@ -217,6 +219,7 @@ def read_triplegs_postgis(
return ti.io.read_triplegs_gpd(tpls, **(read_gpd_kws or {}))


@doc(_shared_docs["write_postgis"], long="triplegs", short="tpls")
@_handle_con_string
def write_triplegs_postgis(
triplegs, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
Expand Down Expand Up @@ -320,6 +323,7 @@ def read_staypoints_postgis(
return ti.io.read_staypoints_gpd(sp, **(read_gpd_kws or {}))


@doc(_shared_docs["write_postgis"], long="staypoints", short="sp")
@_handle_con_string
def write_staypoints_postgis(
staypoints, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
Expand Down Expand Up @@ -429,6 +433,7 @@ def read_locations_postgis(
return ti.io.read_locations_gpd(locs, center=center, **(read_gpd_kws or {}))


@doc(_shared_docs["write_postgis"], long="locations", short="locs")
@_handle_con_string
def write_locations_postgis(
locations, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
Expand Down Expand Up @@ -559,6 +564,7 @@ def read_trips_postgis(
return ti.io.read_trips_gpd(trips, **(read_gpd_kws or {}))


@doc(_shared_docs["write_postgis"], long="trips", short="trips")
@_handle_con_string
def write_trips_postgis(
trips, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
Expand Down Expand Up @@ -684,6 +690,7 @@ def read_tours_postgis(
return ti.io.read_tours_gpd(tours, **(read_gpd_kws or {}))


@doc(_shared_docs["write_postgis"], long="tours", short="tours")
@_handle_con_string
def write_tours_postgis(
tours, name, con, schema=None, if_exists="fail", index=True, index_label=None, chunksize=None, dtype=None
Expand All @@ -701,55 +708,3 @@ def write_tours_postgis(
chunksize=chunksize,
dtype=dtype,
)


# helper docstring to change __doc__ of all write functions conveniently in one place
__doc = """Stores {long} to PostGIS. Usually, this is directly called on a {long}
DataFrame (see example below).

Parameters
----------
{long} : GeoDataFrame (as trackintel {long})
The {long} to store to the database.

name : str
The name of the table to write to.

con : sqlalchemy.engine.Connection or sqlalchemy.engine.Engine
active connection to PostGIS database.

schema : str, optional
The schema (if the database supports this) where the table resides.

if_exists : str, {{'fail', 'replace', 'append'}}, default 'fail'
How to behave if the table already exists.

- fail: Raise a ValueError.
- replace: Drop the table before inserting new values.
- append: Insert new values to the existing table.

index : bool, default True
Write DataFrame index as a column. Uses index_label as the column name in the table.

index_label : str or sequence, default None
Column label for index column(s). If None is given (default) and index is True, then the index names are used.

chunksize : int, optional
How many entries should be written at the same time.

dtype: dict of column name to SQL type, default None
Specifying the datatype for columns.
The keys should be the column names and the values should be the SQLAlchemy types.

Examples
--------
>>> {short}.as_{long}.to_postgis(conn_string, table_name)
>>> ti.io.postgis.write_{long}_postgis({short}, conn_string, table_name)
"""

write_positionfixes_postgis.__doc__ = __doc.format(long="positionfixes", short="pfs")
write_triplegs_postgis.__doc__ = __doc.format(long="triplegs", short="tpls")
write_staypoints_postgis.__doc__ = __doc.format(long="staypoints", short="sp")
write_locations_postgis.__doc__ = __doc.format(long="locations", short="locs")
write_trips_postgis.__doc__ = __doc.format(long="trips", short="trips")
write_tours_postgis.__doc__ = __doc.format(long="tours", short="tours")
110 changes: 104 additions & 6 deletions trackintel/model/util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import warnings
from functools import partial, update_wrapper, wraps
from textwrap import dedent

import pandas as pd
import warnings
from geopandas import GeoDataFrame


def _copy_docstring(wrapped, assigned=("__doc__",), updated=[]):
"""Thin wrapper for `functools.update_wrapper` to mimic `functools.wraps` but to only copy the docstring."""
return partial(update_wrapper, wrapped=wrapped, assigned=assigned, updated=updated)


def _wrapped_gdf_method(func):
"""Decorator function that downcast types to trackintel class if is (Geo)DataFrame and has the required columns."""

Expand Down Expand Up @@ -134,3 +131,104 @@ def decorator(accessor):
return accessor

return decorator


def _copy_docstring(wrapped, assigned=("__doc__",), updated=[]):
"""Thin wrapper for `functools.update_wrapper` to mimic `functools.wraps` but to only copy the docstring."""
return partial(update_wrapper, wrapped=wrapped, assigned=assigned, updated=updated)


# doc is derived from pandas.util._decorators (2.1.0)
# module https://github.com/pandas-dev/pandas/blob/main/LICENSE


def doc(*docstrings, **params):
"""
A decorator to take docstring templates, concatenate them and perform string
substitution on them.

This decorator will add a variable "_docstring_components" to the wrapped
callable to keep track the original docstring template for potential usage.
If it should be consider as a template, it will be saved as a string.
Otherwise, it will be saved as callable, and later user __doc__ and dedent
to get docstring.

Parameters
----------
*docstrings : None, str, or callable
The string / docstring / docstring template to be appended in order
after default docstring under callable.
**params
The string which would be used to format docstring template.
"""

def decorator(decorated):
# collecting docstring and docstring templates
components = []
if decorated.__doc__:
components.append(dedent(decorated.__doc__))

for docstring in docstrings:
if docstring is None:
continue
if hasattr(docstring, "_docstring_components"):
components.extend(docstring._docstring_components)
elif isinstance(docstring, str) or docstring.__doc__:
components.append(docstring)

decorated._docstring_components = components
params_applied = (c.format(**params) if (isinstance(c, str) and params) else c for c in components)
decorated.__doc__ = "".join(c if isinstance(c, str) else dedent(c.__doc__ or "") for c in params_applied)
return decorated

return decorator


_shared_docs = {}

# in _shared_docs as all write_postgis_xyz functions use this docstring
_shared_docs[
"write_postgis"
hongyeehh marked this conversation as resolved.
Show resolved Hide resolved
] = """
Stores {long} to PostGIS. Usually, this is directly called on a {long}
DataFrame (see example below).

Parameters
----------
{long} : GeoDataFrame (as trackintel {long})
The {long} to store to the database.

name : str
The name of the table to write to.

con : sqlalchemy.engine.Connection or sqlalchemy.engine.Engine
active connection to PostGIS database.

schema : str, optional
The schema (if the database supports this) where the table resides.

if_exists : str, {{'fail', 'replace', 'append'}}, default 'fail'
How to behave if the table already exists.

- fail: Raise a ValueError.
- replace: Drop the table before inserting new values.
- append: Insert new values to the existing table.

index : bool, default True
Write DataFrame index as a column. Uses index_label as the column name in the table.

index_label : str or sequence, default None
Column label for index column(s). If None is given (default) and index is True, then the index names are used.

chunksize : int, optional
How many entries should be written at the same time.

dtype: dict of column name to SQL type, default None
Specifying the datatype for columns.
The keys should be the column names and the values should be the SQLAlchemy types.

Examples
--------
>>> {short}.as_{long}.to_postgis(conn_string, table_name)
>>> ti.io.postgis.write_{long}_postgis({short}, conn_string, table_name)
"""