Migrate to polars
gutzbenj committed Mar 12, 2023
1 parent 0f2507f commit 20b5329
Showing 16 changed files with 304 additions and 235 deletions.
28 changes: 19 additions & 9 deletions example/observations_station_gaussian_model.py
@@ -69,9 +69,11 @@ def __init__(self, station_data: StationsResult):

valid_data = self.get_valid_data(result_values)

+ valid_data = valid_data.with_row_count("rc")
+
model, pars = self.make_composite_yearly_model(valid_data)

- x = valid_data.with_row_count("rc").get_column("rc").to_numpy()
+ x = valid_data.get_column("rc").to_numpy()
y = valid_data.get_column("value").to_numpy()

out = model.fit(y, pars, x=x)
@@ -101,13 +103,13 @@ def make_composite_yearly_model(self, valid_data: pl.DataFrame) -> Tuple[Gaussia
https://lmfit.github.io/lmfit-py/model.html#composite-models-adding-or-multiplying-models"""
number_of_years = valid_data.get_column("date").dt.year().n_unique()

- x = valid_data.with_row_count("rc").get_column("rc").to_numpy()
+ x = valid_data.get_column("rc").to_numpy()
y = valid_data.get_column("value").to_numpy()

index_per_year = x.max() / number_of_years

pars, composite_model = None, None
- for year, group in valid_data.groupby(pl.col("date").dt.year()):
+ for year, group in valid_data.groupby(pl.col("date").dt.year(), maintain_order=True):
gmod = GaussianModel(prefix=f"g{year}_")
if pars is None:
pars = gmod.make_params()
@@ -121,9 +123,11 @@
return composite_model, pars

@staticmethod
- def model_pars_update(year: int, group: pl.DataFrame, pars: Parameters, index_per_year: float, y_max: float) -> Parameters:
+ def model_pars_update(
+ year: int, group: pl.DataFrame, pars: Parameters, index_per_year: float, y_max: float
+ ) -> Parameters:
"""updates the initial values of the model parameters"""
- idx = group.with_row_count("rc").get_column("rc").to_numpy()
+ idx = group.get_column("rc").to_numpy()
mean_index = idx.mean()

pars[f"g{year}_center"].set(value=mean_index, min=0.75 * mean_index, max=1.25 * mean_index)
@@ -135,8 +139,14 @@ def model_pars_update(year: int, group: pl.DataFrame, pars: Parameters, index_pe
def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefig_to_file=True) -> None:
"""plots the data and the model"""
if savefig_to_file:
- fig, ax = fig, ax = plt.subplots(figsize=(12, 12))
- df = pl.DataFrame({"year": valid_data.get_column("date"), "value": valid_data.get_column("value").to_numpy(), "model": out.best_fit})
+ _ = plt.subplots(figsize=(12, 12))
+ df = pl.DataFrame(
+ {
+ "year": valid_data.get_column("date"),
+ "value": valid_data.get_column("value").to_numpy(),
+ "model": out.best_fit,
+ }
+ )
title = valid_data.get_column("parameter").unique()[0]
df.to_pandas().plot(x="year", y=["value", "model"], title=title)
if savefig_to_file:
@@ -151,8 +161,8 @@ def main():
"""Run example."""
logging.basicConfig(level=logging.INFO)

- station_data_one_year = station_example(start_date="2020-12-25", end_date="2022-01-01")
- _ = ModelYearlyGaussians(station_data_one_year)
+ # station_data_one_year = station_example(start_date="2020-12-25", end_date="2022-01-01")
+ # _ = ModelYearlyGaussians(station_data_one_year)

station_data_many_years = station_example(start_date="1995-12-25", end_date="2022-12-31")
_ = ModelYearlyGaussians(station_data_many_years)
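
Note: the refactor above computes the row index once with `with_row_count` and reuses it, and pins group order with `maintain_order=True`, since polars does not guarantee group order by default (pandas sorts group keys). A minimal sketch of both idioms with made-up data, mirroring the iteration pattern used in this file:

import datetime as dt

import polars as pl

df = pl.DataFrame(
    {
        "date": [dt.date(2020, 1, 1), dt.date(2020, 6, 1), dt.date(2021, 1, 1)],
        "value": [1.0, 2.0, 3.0],
    }
)
df = df.with_row_count("rc")  # explicit UInt32 row index column: 0, 1, 2
x = df.get_column("rc").to_numpy()  # row index as a numpy array
for year, group in df.groupby(pl.col("date").dt.year(), maintain_order=True):
    print(year, group.get_column("rc").to_list())  # 2020 [0, 1], then 2021 [2]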
22 changes: 21 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -129,6 +129,7 @@ tabulate = "^0.8"
timezonefinder = "^6.1"
tqdm = "^4.47"

+ crate = { version = "^0.30.1", optional = true } # Export feature.
dash = { version = "^2.8", optional = true } # Explorer UI feature.
dash-bootstrap-components = { version = "^1.4", optional = true } # Explorer UI feature.
dash-leaflet = { version = "^0.1.23", optional = true } # Explorer UI feature.
85 changes: 47 additions & 38 deletions tests/core/timeseries/test_io.py
@@ -2,16 +2,19 @@
# Copyright (C) 2018-2021, earthobservations developers.
# Distributed under the MIT License. See LICENSE for more info.
import datetime
+ import datetime as dt
import json
import os
import shutil
import sqlite3
from unittest import mock

- import dateutil.parser
import numpy as np
import openpyxl
import pandas as pd
+ import polars as pl
import pytest
+ import pytz
from surrogate import surrogate

from wetterdienst.core.process import filter_by_date_and_resolution
@@ -65,11 +68,11 @@ def dwd_climate_summary_tabular_columns():

@pytest.fixture
def df_station():
- return pd.DataFrame.from_dict(
+ return pl.DataFrame(
{
"station_id": ["19087"],
- "from_date": [dateutil.parser.isoparse("1957-05-01T00:00:00.000Z")],
- "to_date": [dateutil.parser.isoparse("1995-11-30T00:00:00.000Z")],
+ "from_date": [dt.datetime(1957, 5, 1)],
+ "to_date": [dt.datetime(1995, 11, 30)],
"height": [645.0],
"latitude": [48.8049],
"longitude": [13.5528],
@@ -82,12 +85,12 @@ def df_station():

@pytest.fixture
def df_data():
- return pd.DataFrame.from_dict(
+ return pl.DataFrame(
{
"station_id": ["01048"],
"dataset": ["climate_summary"],
"parameter": ["temperature_air_max_200"],
- "date": [dateutil.parser.isoparse("2019-12-28T00:00:00.000Z")],
+ "date": [dt.datetime(2019, 12, 28, tzinfo=pytz.UTC)],
"value": [1.3],
"quality": [None],
}
@@ -112,59 +115,59 @@ def test_to_dict(df_data):
def test_filter_by_date(df_data):
"""Test filter by date"""
df = filter_by_date_and_resolution(df_data, "2019-12-28", Resolution.HOURLY)
- assert not df.empty
+ assert not df.is_empty()
df = filter_by_date_and_resolution(df_data, "2019-12-27", Resolution.HOURLY)
- assert df.empty
+ assert df.is_empty()


def test_filter_by_date_interval(df_data):
"""Test filter by date interval"""
df = filter_by_date_and_resolution(df_data, "2019-12-27/2019-12-29", Resolution.HOURLY)
- assert not df.empty
+ assert not df.is_empty()
df = filter_by_date_and_resolution(df_data, "2020/2022", Resolution.HOURLY)
- assert df.empty
+ assert df.is_empty()


def test_filter_by_date_monthly():
"""Test filter by date in monthly scope"""
- result = pd.DataFrame.from_dict(
+ result = pl.DataFrame(
{
"station_id": ["01048"],
"dataset": ["climate_summary"],
"parameter": ["temperature_air_max_200"],
- "from_date": [dateutil.parser.isoparse("2019-12-28T00:00:00.000Z")],
- "to_date": [dateutil.parser.isoparse("2020-01-28T00:00:00.000Z")],
+ "from_date": [dt.datetime(2019, 12, 28, tzinfo=pytz.UTC)],
+ "to_date": [dt.datetime(2020, 1, 28, tzinfo=pytz.UTC)],
"value": [1.3],
"quality": [None],
}
)
df = filter_by_date_and_resolution(result, "2019-12/2020-01", Resolution.MONTHLY)
- assert not df.empty
+ assert not df.is_empty()
df = filter_by_date_and_resolution(result, "2020/2022", Resolution.MONTHLY)
- assert df.empty
+ assert df.is_empty()
df = filter_by_date_and_resolution(result, "2020", Resolution.MONTHLY)
- assert df.empty
+ assert df.is_empty()


def test_filter_by_date_annual():
"""Test filter by date in annual scope"""
- df = pd.DataFrame.from_dict(
+ df = pl.DataFrame(
{
"station_id": ["01048"],
"dataset": ["climate_summary"],
"parameter": ["temperature_air_max_200"],
- "from_date": [dateutil.parser.isoparse("2019-01-01T00:00:00.000Z")],
- "to_date": [dateutil.parser.isoparse("2019-12-31T00:00:00.000Z")],
+ "from_date": [dt.datetime(2019, 1, 1, tzinfo=pytz.UTC)],
+ "to_date": [dt.datetime(2019, 12, 31, tzinfo=pytz.UTC)],
"value": [1.3],
"quality": [None],
}
)
df = filter_by_date_and_resolution(df, date="2019-05/2019-09", resolution=Resolution.ANNUAL)
- assert not df.empty
+ assert not df.is_empty()
df = filter_by_date_and_resolution(df, date="2020/2022", resolution=Resolution.ANNUAL)
- assert df.empty
+ assert df.is_empty()
df = filter_by_date_and_resolution(df, date="2020", resolution=Resolution.ANNUAL)
- assert df.empty
+ assert df.is_empty()
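
Note: the recurring change in these tests is the pandas `.empty` property becoming polars' `is_empty()` method. A one-line sketch:

import polars as pl

assert pl.DataFrame().is_empty()  # no rows -> True
assert not pl.DataFrame({"a": [1]}).is_empty()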


@pytest.mark.sql
@@ -173,11 +176,11 @@ def test_filter_by_sql(df_data):
df = ExportMixin(df=df_data).filter_by_sql(
sql="SELECT * FROM data WHERE parameter='temperature_air_max_200' AND value < 1.5"
)
- assert not df.empty
+ assert not df.is_empty()
df = ExportMixin(df=df_data).filter_by_sql(
sql="SELECT * FROM data WHERE parameter='temperature_air_max_200' AND value > 1.5"
)
- assert df.empty
+ assert df.is_empty()


def test_format_json(df_data):
@@ -215,7 +218,7 @@ def test_request(default_settings):
settings=default_settings,
).filter_by_station_id(station_id=[1048])
df = request.values.all().df
- assert not df.empty
+ assert not df.is_empty()


@pytest.mark.remote
@@ -426,9 +429,13 @@ def test_export_zarr(tmp_path, settings_si_false_wide_shape, dwd_climate_summary
assert columns == set(dwd_climate_summary_tabular_columns)
# Validate content.
data = group
- assert data["date"][0] == pd.Timestamp(2019, 1, 1, 0, 0, tzinfo=datetime.timezone.utc).to_numpy()
+ assert dt.datetime.fromtimestamp(int(data["date"][0]) / 1e9, tz=datetime.timezone.utc) == dt.datetime(
+ 2019, 1, 1, 0, 0, tzinfo=datetime.timezone.utc
+ )
assert data["temperature_air_min_005"][0] == 1.5
- assert data["date"][-1] == pd.Timestamp(2020, 1, 1, 0, 0, tzinfo=datetime.timezone.utc).to_numpy()
+ assert dt.datetime.fromtimestamp(int(data["date"][-1]) / 1e9, tz=datetime.timezone.utc) == dt.datetime(
+ 2020, 1, 1, 0, 0, tzinfo=datetime.timezone.utc
+ )
assert data["temperature_air_min_005"][-1] == -4.6
shutil.rmtree(filename)
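
Note: zarr hands the timestamp column back as integer epoch nanoseconds rather than pandas Timestamps, hence the round-trip through `fromtimestamp` above. The conversion in isolation (the nanosecond value below is 2019-01-01T00:00:00Z):

import datetime as dt

ns = 1_546_300_800_000_000_000  # epoch nanoseconds
assert dt.datetime.fromtimestamp(ns / 1e9, tz=dt.timezone.utc) == dt.datetime(2019, 1, 1, tzinfo=dt.timezone.utc)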

@@ -460,9 +467,9 @@ def test_export_feather(tmp_path, settings_si_false_wide_shape, dwd_climate_summ
assert table.column_names == dwd_climate_summary_tabular_columns
# Validate content.
data = table.to_pydict()
- assert data["date"][0] == pd.Timestamp(2019, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
+ assert data["date"][0] == dt.datetime(2019, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
assert data["temperature_air_min_005"][0] == 1.5
- assert data["date"][-1] == pd.Timestamp(2020, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
+ assert data["date"][-1] == dt.datetime(2020, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
assert data["temperature_air_min_005"][-1] == -4.6
os.unlink(filename)
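
Note: unlike the zarr case above, Arrow/Feather preserves timezone-aware timestamps, so to_pydict() already yields datetime objects and the assertions compare directly. A minimal sketch, assuming pyarrow's usual inference of tz-aware datetimes as timestamp-with-timezone:

import datetime as dt

import pyarrow as pa

table = pa.table({"date": [dt.datetime(2019, 1, 1, tzinfo=dt.timezone.utc)]})
data = table.to_pydict()
assert data["date"][0] == dt.datetime(2019, 1, 1, tzinfo=dt.timezone.utc)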

@@ -489,9 +496,10 @@ def test_export_sqlite(tmp_path, settings_si_false_wide_shape):
cursor.close()
connection.close()
assert results[0] == (
+ 0,
"01048",
"climate_summary",
- "2019-01-01 00:00:00.000000",
+ "2019-01-01 00:00:00+00:00",
19.9,
10.0,
8.5,
@@ -523,9 +531,10 @@ def test_export_sqlite(tmp_path, settings_si_false_wide_shape):
)

assert results[-1] == (
+ 365,
"01048",
"climate_summary",
- "2020-01-01 00:00:00.000000",
+ "2020-01-01 00:00:00+00:00",
6.9,
10.0,
3.2,
@@ -558,7 +567,10 @@ def test_export_sqlite(tmp_path, settings_si_false_wide_shape):


@pytest.mark.remote
- def test_export_cratedb(settings_si_false):
+ def test_export_cratedb(
+ tmp_path,
+ settings_si_false,
+ ):
"""Test export of DataFrame to cratedb"""
request = DwdObservationRequest(
parameter=DwdObservationDataset.CLIMATE_SUMMARY,
@@ -569,17 +581,14 @@ def test_export_cratedb(settings_si_false):
station_id=[1048],
)
with mock.patch(
- "pandas.DataFrame.to_sql",
+ "polars.DataFrame.write_database",
) as mock_to_sql:
df = request.values.all().df
ExportMixin(df=df).to_target("crate://localhost/?database=test&table=testdrive")
mock_to_sql.assert_called_once_with(
- name="testdrive",
- con="crate://localhost",
- schema="test",
+ table_name="testdrive",
+ connection_uri="crate://localhost",
if_exists="replace",
- index=False,
- chunksize=5000,
)
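
Note: the export path now goes through polars' `DataFrame.write_database` instead of pandas' `DataFrame.to_sql`, which is exactly what the mock above asserts. A sketch of the equivalent direct call under the test's assumptions (the CrateDB URI and table name are the test's placeholders):

import polars as pl

df = pl.DataFrame({"station_id": ["01048"], "value": [1.3]})
# Argument names mirror the mocked call asserted above.
df.write_database(
    table_name="testdrive",
    connection_uri="crate://localhost",
    if_exists="replace",  # drop and recreate the target table
)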


2 changes: 1 addition & 1 deletion tests/provider/dwd/mosmix/test_api_stations.py
@@ -138,7 +138,7 @@ def test_dwd_mosmix_stations_filtered(default_settings, mosmix_stations_schema):
"state": None,
},
],
- schema=mosmix_stations_schema
+ schema=mosmix_stations_schema,
)
# expected_df.from_date = pd.to_datetime(expected_df.from_date, utc=True)
# expected_df.to_date = pd.to_datetime(expected_df.to_date, utc=True)
21 changes: 12 additions & 9 deletions tests/provider/noaa/ghcn/test_api_data.py
@@ -2,8 +2,9 @@

import numpy as np
import pandas as pd
+ import polars as pl
import pytest
- from pandas._testing import assert_frame_equal
+ from polars.testing import assert_frame_equal

from wetterdienst.provider.noaa.ghcn import NoaaGhcnParameter, NoaaGhcnRequest

@@ -25,18 +26,20 @@ def test_api_amsterdam(start_date, end_date, default_settings):
settings=default_settings,
).filter_by_name("DE BILT")
given_df = request.values.all().df
- expected_df = pd.DataFrame(
+
+ expected_df = pl.DataFrame(
{
- "station_id": pd.Categorical(["NLM00006260"]),
- "dataset": pd.Categorical(["daily"]),
- "parameter": pd.Categorical(["temperature_air_mean_200"]),
- "date": [pd.Timestamp("2015-04-15 22:00:00+0000", tz="UTC")],
+ "station_id": ["NLM00006260"],
+ "dataset": ["daily"],
+ "parameter": ["temperature_air_mean_200"],
+ "date": [dt.datetime(2015, 4, 15, 23, tzinfo=dt.timezone.utc)],
"value": [282.75],
- "quality": [np.nan],
+ "quality": [None],
}
)
assert_frame_equal(
- given_df[given_df["date"] == pd.Timestamp("2015-04-15 22:00:00+00:00")].reset_index(drop=True),
+ given_df.filter(
+ pl.col("date").eq(dt.datetime(2015, 4, 15, 22, tzinfo=dt.timezone.utc))
+ ),
expected_df,
- check_categorical=False,
)
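
Note: row selection moves from pandas boolean indexing (plus reset_index) to a polars filter expression, and frame comparison moves to polars.testing.assert_frame_equal. A self-contained sketch of the pattern, with the single-row frame made up:

import datetime as dt

import polars as pl
from polars.testing import assert_frame_equal

df = pl.DataFrame(
    {
        "date": [dt.datetime(2015, 4, 15, 22, tzinfo=dt.timezone.utc)],
        "value": [282.75],
    }
)
subset = df.filter(pl.col("date").eq(dt.datetime(2015, 4, 15, 22, tzinfo=dt.timezone.utc)))
assert_frame_equal(subset, df)  # filtering the only row returns the frame unchanged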
