diff --git a/packages/python/plotly/plotly/data/__init__.py b/packages/python/plotly/plotly/data/__init__.py index 9102677c9d6..7669c4588c3 100644 --- a/packages/python/plotly/plotly/data/__init__.py +++ b/packages/python/plotly/plotly/data/__init__.py @@ -3,7 +3,7 @@ """ -def gapminder(datetimes=False, centroids=False, year=None): +def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False): """ Each row represents a country on a given year. @@ -24,10 +24,27 @@ def gapminder(datetimes=False, centroids=False, year=None): df["year"] = (df["year"].astype(str) + "-01-01").astype("datetime64[ns]") if not centroids: df = df.drop(["centroid_lat", "centroid_lon"], axis=1) + if pretty_names: + df.rename( + mapper=dict( + country="Country", + continent="Continent", + year="Year", + lifeExp="Life Expectancy", + gdpPercap="GDP per Capita", + pop="Population", + iso_alpha="ISO Alpha Country Code", + iso_num="ISO Numeric Country Code", + centroid_lat="Centroid Latitude", + centroid_lon="Centroid Longitude", + ), + axis="columns", + inplace=True, + ) return df -def tips(): +def tips(pretty_names=False): """ Each row represents a restaurant bill. @@ -36,7 +53,23 @@ def tips(): Returns: A `pandas.DataFrame` with 244 rows and the following columns: `['total_bill', 'tip', 'sex', 'smoker', 'day', 'time', 'size']`.""" - return _get_dataset("tips") + + df = _get_dataset("tips") + if pretty_names: + df.rename( + mapper=dict( + total_bill="Total Bill", + tip="Tip", + sex="Payer Gender", + smoker="Smokers at Table", + day="Day of Week", + time="Meal", + size="Party Size", + ), + axis="columns", + inplace=True, + ) + return df def iris():