diff --git a/hat/clock.py b/hat/clock.py deleted file mode 100644 index 409fa97..0000000 --- a/hat/clock.py +++ /dev/null @@ -1,71 +0,0 @@ -import datetime -import functools -import time - -import humanize - - -def digital_clock(func): - "A quiet decorator for timing functions" - - @functools.wraps(func) - def clocked(*args, **kwargs): - # time - t0 = time.perf_counter() - result = func(*args, **kwargs) - elapsed = time.perf_counter() - t0 - human_readable_time = humanize.naturaldelta(datetime.timedelta(seconds=elapsed)) - - # name - name = func.__name__ - - print(f"{name}() took {human_readable_time}") - return result - - # return function with timing decorator - return clocked - - -def clock(func): - "A verbose decorator for timing functions" - - @functools.wraps(func) - def clocked(*args, **kwargs): - # time - t0 = time.perf_counter() - result = func(*args, **kwargs) - elapsed = time.perf_counter() - t0 - human_readable_time = humanize.naturaldelta(datetime.timedelta(seconds=elapsed)) - - # name - name = func.__name__ - - # arguments - arg_str = ", ".join(repr(arg) for arg in args) - if arg_str == "": - arg_str = "no arguments" - - # keywords - pairs = [f"{k}={w}" for k, w in sorted(kwargs.items())] - key_str = ", ".join(pairs) - if key_str == "": - key_str = "no keywords" - - print( - f"""{name}() took {human_readable_time} to run - with following inputs {arg_str} and {key_str}""" - ) - return result - - # return function with timing decorator - return clocked - - -if __name__ == "__main__": - - @clock - def test(arg1, keyword=False): - time.sleep(0.1) - pass - - test(1, keyword="hello") diff --git a/hat/filters.py b/hat/filters.py index cae74c8..bc0aebb 100644 --- a/hat/filters.py +++ b/hat/filters.py @@ -164,7 +164,7 @@ def filter_timeseries(sims_ds: xr.DataArray, obs_ds: xr.DataArray, threshold=80) obs_ds = obs_ds.sel(station=matching_stations) obs_ds = obs_ds.sel(time=sims_ds.time) - obs_ds = obs_ds.dropna(dim='station', how='all') + obs_ds = obs_ds.dropna(dim="station", how="all") sims_ds = sims_ds.sel(station=obs_ds.station) # Only keep observations in the same time period as the simulations diff --git a/hat/graphs.py b/hat/graphs.py deleted file mode 100644 index 8f6ba39..0000000 --- a/hat/graphs.py +++ /dev/null @@ -1,30 +0,0 @@ -import pandas as pd -import plotly.express as px - - -def graph_sims_and_obs( - sims, - obs, - ID, - sims_data_name="simulation_timeseries", - obs_data_name="obsdis", - height=500, - width=1200, -): - # observations, simulations, time - o = obs.sel(station=ID)[obs_data_name].values - s = sims.sel(station=ID)[sims_data_name].values - t = obs.sel(station=ID).time.values - - df = pd.DataFrame({"time": t, "simulations": s, "observations": o}) - fig = px.line( - df, - x="time", - y=["simulations", "observations"], - title="Simulations & Observations", - ) - fig.data[0].line.color = "#34eb7d" - fig.data[1].line.color = "#3495eb" - fig.update_layout(height=height, width=width) - fig.update_yaxes(title_text="discharge") - fig.show() diff --git a/hat/hydrostats.py b/hat/hydrostats.py index 67cc14b..c2712d3 100644 --- a/hat/hydrostats.py +++ b/hat/hydrostats.py @@ -3,7 +3,6 @@ import folium import geopandas as gpd -import pandas as pd import numpy as np import xarray as xr from branca.colormap import linear @@ -17,7 +16,9 @@ def run_analysis( sims_ds: xr.DataArray, obs_ds: xr.DataArray, ) -> xr.Dataset: - """Run statistical analysis on simulation and observation timeseries""" + """ + Run statistical analysis on simulation and observation timeseries + """ # list 
of stations stations = sims_ds.coords["station"].values @@ -35,7 +36,7 @@ def run_analysis( for station in stations: sims = sims_ds.sel(station=station).to_numpy() obs = obs_ds.sel(station=station).to_numpy() - + stat = func(sims, obs) if stat is None: print(f"Warning! All NaNs for station {station}") @@ -44,7 +45,7 @@ def run_analysis( statistics = np.array(statistics) # Add the Series to the DataFrame - ds[name] = xr.DataArray(statistics, coords={'station': stations}) + ds[name] = xr.DataArray(statistics, coords={"station": stations}) return ds diff --git a/hat/images.py b/hat/images.py deleted file mode 100644 index 14ebeae..0000000 --- a/hat/images.py +++ /dev/null @@ -1,29 +0,0 @@ -import matplotlib -import numpy as np -from quicklook import quicklook - - -def arr_to_image(arr: np.array) -> np.array: - """modify array so that it is optimized for viewing""" - - # image array - img = np.array(arr) - - img = quicklook.replace_nan(img) - img = quicklook.percentile_clip(img, 2) - img = quicklook.bytescale(img) - img = quicklook.reshape_array(img) - - return img - - -def numpy_to_png( - arr: np.array, dim="time", index="somedate", fpath="image.png" -) -> None: - """Save numpy array to png""" - - # image from array - img = arr_to_image(arr) - - # save to file - matplotlib.image.imsave(fpath, img) diff --git a/hat/interactive/__init__.py b/hat/interactive/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hat/interactive/explorers.py b/hat/interactive/explorers.py new file mode 100644 index 0000000..4026318 --- /dev/null +++ b/hat/interactive/explorers.py @@ -0,0 +1,258 @@ +import os + +import ipywidgets +import pandas as pd +import xarray as xr +from IPython.core.display import display + +from hat.interactive.leaflet import LeafletMap, PyleafletColormap +from hat.interactive.widgets import ( + MetaDataWidget, + PlotlyWidget, + StatisticsWidget, + WidgetsManager, +) +from hat.observations import read_station_metadata_file + + +def prepare_simulations_data(simulations, sims_var_name): + """ + process simulations raw datasets to a standard dataframe + """ + + # If simulations is a dictionary, load data for each experiment + sim_ds = {} + for exp, path in simulations.items(): + # Expanding the tilde + expanded_path = os.path.expanduser(path) + + if os.path.isfile(expanded_path): # Check if it's a file + ds = xr.open_dataset(expanded_path) + elif os.path.isdir(expanded_path): # Check if it's a directory + # Handle the case when it's a directory; + # assume all .nc files in the directory need to be combined + files = [f for f in os.listdir(expanded_path) if f.endswith(".nc")] + ds = xr.open_mfdataset( + [os.path.join(expanded_path, f) for f in files], combine="by_coords" + ) + else: + raise ValueError(f"Invalid path: {expanded_path}") + sim_ds[exp] = ds[sims_var_name] + + return sim_ds + + +def prepare_observations_data(observations, sim_ds, obs_var_name): + """ + process observation raw dataset to a standard dataframe + """ + file_extension = os.path.splitext(observations)[-1].lower() + + if file_extension == ".csv": + obs_df = pd.read_csv(observations, parse_dates=["Timestamp"]) + obs_melted = obs_df.melt( + id_vars="Timestamp", var_name="station", value_name=obs_var_name + ) + # Convert the melted DataFrame to xarray Dataset + obs_ds = obs_melted.set_index(["Timestamp", "station"]).to_xarray() + obs_ds = obs_ds.rename({"Timestamp": "time"}) + elif file_extension == ".nc": + obs_ds = xr.open_dataset(observations) + else: + raise ValueError("Unsupported file format for 
observations.") + + # Subset obs_ds based on sim_ds time values + if isinstance(sim_ds, xr.Dataset): + time_values = sim_ds["time"].values + elif isinstance(sim_ds, dict): + # Use the first dataset in the dictionary to determine time values + first_dataset = next(iter(sim_ds.values())) + time_values = first_dataset["time"].values + else: + raise ValueError("Unexpected type for sim_ds") + + obs_ds = obs_ds[obs_var_name].sel(time=time_values) + return obs_ds + + +def find_common_station(station_index, stations_metadata, statistics, sim_ds, obs_ds): + """ + find common station between observation and simulation and station metadata + """ + ids = [] + ids += [list(obs_ds["station"].values)] + ids += [list(ds["station"].values) for ds in sim_ds.values()] + ids += [stations_metadata[station_index]] + if statistics: + ids += [list(ds["station"].values) for ds in statistics.values()] + + common_ids = None + for id in ids: + if common_ids is None: + common_ids = set(id) + else: + common_ids = set(id) & common_ids + return list(common_ids) + + +class TimeSeriesExplorer: + """ + Initialize the interactive map with configurations and data sources. + """ + + def __init__(self, config, stations, observations, simulations, stats=None): + self.config = config + self.stations_metadata = read_station_metadata_file( + fpath=stations, + coord_names=config["station_coordinates"], + epsg=config["station_epsg"], + filters=config["station_filters"], + ) + + # Use the external functions to prepare data + sim_ds = prepare_simulations_data(simulations, config["sims_var_name"]) + obs_ds = prepare_observations_data(observations, sim_ds, config["obs_var_name"]) + + # set station index + self.station_index = config["station_id_column_name"] + + # Retrieve statistics from the statistics netcdf input + self.statistics = {} + if stats: + for name, path in stats.items(): + self.statistics[name] = xr.open_dataset(path) + + # Ensure the keys of self.statistics match the keys of self.sim_ds + assert set(self.statistics.keys()) == set( + sim_ds.keys() + ), "Mismatch between statistics and simulations keys." + + # find common station ids between metadata, observation and simulations + common_ids = find_common_station( + self.station_index, self.stations_metadata, self.statistics, sim_ds, obs_ds + ) + + print(f"Found {len(common_ids)} common stations") + self.stations_metadata = self.stations_metadata.loc[ + self.stations_metadata[self.station_index].isin(common_ids) + ] + obs_ds = obs_ds.sel(station=common_ids) + for sim, ds in sim_ds.items(): + sim_ds[sim] = ds.sel(station=common_ids) + + # Create loading widget + self.loading_widget = ipywidgets.Label(value="") + + # Title label + self.title_label = ipywidgets.Label( + "Interactive Map Visualisation for Hydrological Model Performance", + layout=ipywidgets.Layout(justify_content="center"), + style={"font_weight": "bold", "font_size": "24px", "font_family": "Arial"}, + ) + + # Create the interactive widgets + datasets = sim_ds + datasets["obs"] = obs_ds + widgets = {} + widgets["plot"] = PlotlyWidget(datasets) + widgets["stats"] = StatisticsWidget(self.statistics) + widgets["meta"] = MetaDataWidget(self.stations_metadata, self.station_index) + self.widgets = WidgetsManager( + widgets, config["station_id_column_name"], self.loading_widget + ) + + # Create the main leaflet map + self.leafletmap = LeafletMap() + + def create_frame(self): + """ + Initialize the layout elements for the map visualization. 
+ """ + + # # Layouts 1 + # main_layout = ipywidgets.Layout( + # justify_content='space-around', + # align_items='stretch', + # spacing='2px', + # width='1000px' + # ) + # half_layout = ipywidgets.Layout( + # justify_content='space-around', + # align_items='center', + # spacing='2px', + # width='50%' + # ) + + # # Frames + # stats_frame = ipywidgets.HBox( + # [self.widgets['plot'].output, self.widgets['stats'].output], + # # layout=main_layout + # ) + # main_frame = ipywidgets.VBox( + # [ + # self.title_label, + # self.loading_widget, + # self.leafletmap.output(main_layout), + # self.widgets['meta'].output, stats_frame + # ], + # layout=main_layout + # ) + + # Layouts 2 + main_layout = ipywidgets.Layout( + justify_content="space-around", + align_items="stretch", + spacing="2px", + width="1000px", + ) + left_layout = ipywidgets.Layout( + justify_content="space-around", + align_items="center", + spacing="2px", + width="40%", + ) + right_layout = ipywidgets.Layout( + justify_content="center", align_items="center", spacing="2px", width="60%" + ) + + # Frames + top_left_frame = self.leafletmap.output(left_layout) + top_right_frame = ipywidgets.VBox( + [self.widgets["plot"].output, self.widgets["stats"].output], + layout=right_layout, + ) + main_top_frame = ipywidgets.HBox([top_left_frame, top_right_frame]) + + # Main layout + main_frame = ipywidgets.VBox( + [self.title_label, main_top_frame, self.widgets["meta"].output], + layout=main_layout, + ) + return main_frame + + def mapplot(self, colorby=None, sim=None, limits=None, mp_colormap="viridis"): + """Plot the map with stations colored by a given metric. + input example: + colorby = "kge" this should be the objective functions of the statistics + limits = [, ] min and max values of the color bar + mp_colormap = "viridis" colormap name to be used based on matplotlib colormap + """ + # create colormap from statistics + stats = None + if self.statistics and colorby is not None and sim is not None: + stats = self.statistics[sim][colorby] + colormap = PyleafletColormap(self.config, stats, mp_colormap, limits) + + # add layer to the leaflet map + self.leafletmap.add_geolayer( + self.stations_metadata, + colormap, + self.widgets, + self.config["station_coordinates"], + ) + + # Initialize frame elements + frame = self.create_frame() + + # Display the main layout + display(frame) diff --git a/hat/interactive/leaflet.py b/hat/interactive/leaflet.py new file mode 100644 index 0000000..9585d35 --- /dev/null +++ b/hat/interactive/leaflet.py @@ -0,0 +1,150 @@ +import json + +import ipyleaflet +import ipywidgets +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np + + +def _compute_bounds(stations_metadata, coord_names): + """Compute the bounds of the map based on the stations metadata.""" + + lon_column = coord_names[0] + lat_column = coord_names[1] + + lons = stations_metadata[lon_column].values + lats = stations_metadata[lat_column].values + + min_lat, max_lat = min(lats), max(lats) + min_lon, max_lon = min(lons), max(lons) + + return [(float(min_lat), float(min_lon)), (float(max_lat), float(max_lon))] + + +class LeafletMap: + def __init__( + self, + basemap=ipyleaflet.basemaps.OpenStreetMap.Mapnik, + ): + self.map = ipyleaflet.Map( + basemap=basemap, layout=ipywidgets.Layout(width="100%", height="600px") + ) + self.legend_widget = ipywidgets.Output() + + def _set_boundaries(self, stations_metadata, coord_names): + """Compute the boundaries of the map based on the stations metadata.""" + + lon_column = coord_names[0] + lat_column 
= coord_names[1] + + lons = stations_metadata[lon_column].values + lats = stations_metadata[lat_column].values + + min_lat, max_lat = min(lats), max(lats) + min_lon, max_lon = min(lons), max(lons) + + bounds = [(float(min_lat), float(min_lon)), (float(max_lat), float(max_lon))] + self.map.fit_bounds(bounds) + + def add_geolayer(self, geodata, colormap, widgets, coord_names=None): + geojson = ipyleaflet.GeoJSON( + data=json.loads(geodata.to_json()), + style={ + "radius": 7, + "opacity": 0.5, + "weight": 1.9, + "dashArray": "2", + "fillOpacity": 0.5, + }, + hover_style={"radius": 10, "fillOpacity": 1}, + point_style={"radius": 5}, + style_callback=colormap.style_callback(), + ) + geojson.on_click(widgets.update) + self.map.add_layer(geojson) + + if coord_names is not None: + self._set_boundaries(geodata, coord_names) + + self.legend_widget = colormap.legend() + + def output(self, layout): + output = ipywidgets.VBox([self.map, self.legend_widget], layout=layout) + return output + + +class PyleafletColormap: + def __init__(self, config, stats=None, colormap_style="viridis", range=None): + self.config = config + self.stats = stats + print(self.stats) + if self.stats is not None: + # Normalize the data for coloring + if range is None: + self.min_val = self.stats.values.min() + self.max_val = self.stats.values.max() + else: + self.min_val = range[0] + self.max_val = range[1] + else: + self.min_val = 0 + self.max_val = 1 + + self.colormap = plt.cm.get_cmap(colormap_style) + + def style_callback(self): + if self.stats is not None: + norm = plt.Normalize(self.min_val, self.max_val) + + def map_color(feature): + station_id = feature["properties"][ + self.config["station_id_column_name"] + ] + color = mpl.colors.rgb2hex( + self.colormap(norm(self.stats.sel(station=station_id).values)) + ) + return { + "color": "black", + "fillColor": color, + } + + else: + + def map_color(feature): + return { + "color": "black", + "fillColor": "blue", + } + + return map_color + + def legend(self): + """Generate an HTML legend for the map.""" + # Convert the colormap to a list of RGB values + rgb_values = [ + mpl.colors.rgb2hex(self.colormap(i)) for i in np.linspace(0, 1, 256) + ] + + # Create a gradient style using the RGB values + gradient_style = ", ".join(rgb_values) + gradient_html = f""" +
+        <div style="background: linear-gradient(to right, {gradient_style}); width: 100%; height: 20px;"></div>
+        """
+
+        # Create labels
+        labels_html = f"""
+        <div style="display: flex; justify-content: space-between; width: 100%;">
+            Low: {self.min_val:.1f}
+            High: {self.max_val:.1f}
+        </div>
+ """ + # Combine gradient and labels + legend_html = gradient_html + labels_html + + return ipywidgets.HTML(legend_html) diff --git a/hat/interactive/widgets.py b/hat/interactive/widgets.py new file mode 100644 index 0000000..f4ce147 --- /dev/null +++ b/hat/interactive/widgets.py @@ -0,0 +1,319 @@ +import time + +import numpy as np +import pandas as pd +import plotly.graph_objs as go +from IPython.core.display import display +from IPython.display import clear_output +from ipywidgets import HTML, DatePicker, HBox, Label, Layout, Output, VBox + + +class ThrottledClick: + """ + Initialize a click throttler with a given delay. + to prevent user from swift multiple events clicking that results in crashing + """ + + def __init__(self, delay=1.0): + self.delay = delay + self.last_call = 0 + + def should_process(self): + """ + Determine if a click should be processed based on the delay. + """ + current_time = time.time() + if current_time - self.last_call > self.delay: + self.last_call = current_time + return True + return False + + +class WidgetsManager: + def __init__(self, widgets, index_column, loading_widget=None): + self.widgets = widgets + self.index_column = index_column + self.throttler = self._initialize_throttler() + self.loading_widget = loading_widget + + def _initialize_throttler(self, delay=1.0): + """Initialize the throttler for click events.""" + return ThrottledClick(delay) + + def update(self, feature, **kwargs): + """Handle the selection of a marker on the map.""" + + # Check if we should process the click + if not self.throttler.should_process(): + return + + if self.loading_widget is not None: + self.loading_widget.value = ( + "Loading..." # Indicate that data is being loaded + ) + + # Extract station_id from the selected feature + metadata = feature["properties"] + index = metadata[self.index_column] + + # update widgets + for wgt in self.widgets.values(): + wgt.update(index, metadata) + + if self.loading_widget is not None: + self.loading_widget.value = "" # Clear the loading message + + def __getitem__(self, item): + return self.widgets[item] + + +class Widget: + def __init__(self, output): + self.output = output + + def update(self, index, metadata): + raise NotImplementedError + + +def filter_nan_values(dates, data_values): + """Filters out NaN values and their associated dates.""" + valid_dates = [date for date, val in zip(dates, data_values) if not np.isnan(val)] + valid_data = [val for val in data_values if not np.isnan(val)] + + return valid_dates, valid_data + + +class PlotlyWidget(Widget): + """Plotly widget to display timeseries.""" + + def __init__(self, datasets): + self.datasets = datasets + # initial_title = { + # 'text': "Click on your desired station location", + # 'y':0.9, + # 'x':0.5, + # 'xanchor': 'center', + # 'yanchor': 'top' + # } + self.figure = go.FigureWidget( + layout=go.Layout( + # title = initial_title, + height=350, + margin=dict(l=120), + legend=dict( + orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 + ), + xaxis_title="Date", + xaxis_tickformat="%d-%m-%Y", + yaxis_title="Discharge [m3/s]", + ) + ) + ds_time = datasets["obs"]["time"].values.astype("datetime64[D]") + self.ds_time_str = [dt.isoformat() for dt in pd.to_datetime(ds_time)] + + # Add date pickers for start and end dates + self.start_date_picker = DatePicker(description="Start") + self.end_date_picker = DatePicker(description="End") + + # Observe changes in the date pickers to update the plot + self.start_date_picker.observe(self._update_plot_dates, names="value") + 
self.end_date_picker.observe(self._update_plot_dates, names="value")
+
+        # Date picker label
+        date_label = Label(
+            "Select start and end dates to adjust the time axis of the plot"
+        )
+        date_picker_box = HBox([self.start_date_picker, self.end_date_picker])
+
+        layout = Layout(justify_content="center", align_items="center")
+        output = VBox([self.figure, date_label, date_picker_box], layout=layout)
+        super().__init__(output)
+
+    def _update_plot_dates(self, change):
+        # Ignore updates until both pickers have a value
+        if self.start_date_picker.value is None or self.end_date_picker.value is None:
+            return
+        start_date = self.start_date_picker.value.strftime("%Y-%m-%d")
+        end_date = self.end_date_picker.value.strftime("%Y-%m-%d")
+        self.figure.update_layout(xaxis_range=[start_date, end_date])
+
+    def update_data(self, station_id):
+        """Update the timeseries traces for the given station ID."""
+        for name, ds in self.datasets.items():
+            if station_id in ds["station"].values:
+                ds_time_series_data = ds.sel(station=station_id).values
+                valid_dates_ds, valid_data_ds = filter_nan_values(
+                    self.ds_time_str, ds_time_series_data
+                )
+                self._update_trace(valid_dates_ds, valid_data_ds, name)
+            else:
+                print(f"Station ID: {station_id} not found in dataset {name}.")
+
+    def _update_trace(self, x_data, y_data, name):
+        """Update or add a trace to the Plotly figure."""
+        trace_exists = any(trace.name == name for trace in self.figure.data)
+        if trace_exists:
+            for trace in self.figure.data:
+                if trace.name == name:
+                    trace.x = x_data
+                    trace.y = y_data
+        else:
+            self.figure.add_trace(
+                go.Scatter(x=x_data, y=y_data, mode="lines", name=name)
+            )
+
+    def update_title(self, metadata):
+        station_id = metadata["station_id"]
+        station_name = metadata["StationName"]
+        updated_title = (
+            f"Selected station: <br>ID: {station_id}, name: {station_name}<br>
" + ) + self.figure.update_layout( + title={ + "text": updated_title, + "y": 0.9, + "x": 0.5, + "xanchor": "center", + "yanchor": "top", + "font": {"color": "black", "size": 16}, + } + ) + + def update(self, index, metadata): + """Update the overall plot with new data for the given station ID.""" + # self.update_title(metadata) + self.update_data(index) + + +class HTMLTableWidget(Widget): + def __init__(self, dataframe, title): + """ + Initialize the table object for displaying statistics and station properties. + """ + self.dataframe = dataframe + self.title = title + super().__init__(Output()) + + # Define the styles for the statistics table + self.table_style = """ + + """ + self.stat_title_style = ( + "style='font-size: 18px; font-weight: bold; text-align: center;'" + ) + # Initialize the stat_table_html and station_table_html with empty tables + empty_df = pd.DataFrame() + self.display_dataframe_with_scroll(empty_df, title=self.title) + + def display_dataframe_with_scroll(self, df, title=""): + """Display a DataFrame with a scrollable view.""" + table_html = df.to_html(classes="custom-table") + content = f"{self.table_style}

<div {self.stat_title_style}>{title}</div><div style='overflow-x: auto; max-height: 300px;'>{table_html}</div>
" # noqa: E501 + with self.output: + clear_output(wait=True) # Clear any previous plots or messages + display(HTML(content)) + + def update(self, index, metadata): + dataframe = self.extract_dataframe(index) + self.display_dataframe_with_scroll(dataframe, title=self.title) + + +class DataFrameWidget(Widget): + def __init__(self, dataframe, title): + """ + Initialize the table object for displaying statistics and station properties. + """ + self.dataframe = dataframe + self.title = title + super().__init__(output=Output(title=self.title)) + + # Initialize the stat_table_html and station_table_html with empty tables + empty_df = pd.DataFrame() + with self.output: + clear_output(wait=True) # Clear any previous plots or messages + display(empty_df) + + def update(self, index, metadata): + dataframe = self.extract_dataframe(index) + with self.output: + clear_output(wait=True) # Clear any previous plots or messages + display(dataframe) + + +class MetaDataWidget(HTMLTableWidget): + def __init__(self, dataframe, station_index): + title = "Station Metadata" + self.station_index = station_index + super().__init__(dataframe, title) + + def extract_dataframe(self, station_id): + """Generate a station property table for the given station ID.""" + stations_df = self.dataframe + selected_station_df = stations_df[stations_df[self.station_index] == station_id] + + return selected_station_df + + +class StatisticsWidget(HTMLTableWidget): + def __init__(self, dataframe): + title = "Model Performance Statistics Overview" + super().__init__(dataframe, title) + + def extract_dataframe(self, station_id): + """Generate a statistics table for the given station ID.""" + data = [] + + # Check if statistics is None or empty + if not self.dataframe: + print("No statistics data provided.") + return pd.DataFrame() # Return an empty dataframe + + # Loop through each simulation and get the statistics for the given station_id + for exp_name, stats in self.dataframe.items(): + if station_id in stats["station"].values: + row = [exp_name] + [ + round(stats[var].sel(station=station_id).values.item(), 2) + for var in stats.data_vars + if var not in ["longitude", "latitude"] + ] + data.append(row) + + # Check if data has any items + if not data: + print(f"No statistics data found for station ID: {station_id}.") + return pd.DataFrame() # Return an empty dataframe + + # Convert the data to a DataFrame for display + columns = ["Exp. name"] + list(stats.data_vars.keys()) + statistics_df = pd.DataFrame(data, columns=columns) + + # Round the numerical columns to 2 decimal places + numerical_columns = [col for col in statistics_df.columns if col != "Exp. 
name"] + statistics_df[numerical_columns] = statistics_df[numerical_columns].round(2) + + return statistics_df diff --git a/hat/networking.py b/hat/networking.py deleted file mode 100644 index a5b04ab..0000000 --- a/hat/networking.py +++ /dev/null @@ -1,44 +0,0 @@ -import os -import platform -import socket - - -def get_host(): - """Local host on Mac and network host on HPC - (note the network address on HPC is not constant)""" - - # get the hostname - hostname = socket.gethostname() - - # get the IP address(es) associated with the hostname - ip_addresses = socket.getaddrinfo( - hostname, None, socket.AF_INET, socket.SOCK_STREAM - ) - - # return first valid address - for ip_address in ip_addresses: - network_host = ip_address[4][0] - return network_host - - -def mac_or_hpc(): - """Is this running on a Mac or the HPC or other?""" - - if platform.system() == "Darwin": - return "mac" - elif platform.system() == "Linux" and os.environ.get("ECPLATFORM"): - return "hpc" - else: - return "other" - - -def host_and_port(host="127.0.0.1", port=8000): - """return network host and port for tiler app to use""" - - computer = mac_or_hpc() - - if computer == "hpc": - host = get_host() - port = 8700 - - return (host, port) diff --git a/hat/parsers.py b/hat/parsers.py deleted file mode 100644 index 0c50b64..0000000 --- a/hat/parsers.py +++ /dev/null @@ -1,17 +0,0 @@ -import dateutil.parser -import streamlit as st - - -@st.cache_data -def datetime_from_cftime(cftimes): - """parse CFTimeIndex to python datetime, - e.g. from a NetCDF file ds.indexes['time']""" - return [dateutil.parser.parse(x.isoformat()) for x in cftimes] - - -def simulation_timeperiod(sim): - # simulation timeperiod - min_time = min(sim.indexes["time"]) - max_time = max(sim.indexes["time"]) - - return (min_time, max_time) diff --git a/hat/plots.py b/hat/plots.py deleted file mode 100644 index 3581409..0000000 --- a/hat/plots.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import Union - -import numpy as np -import pandas as pd -import plotly.express as px -import streamlit as st -from matplotlib import pyplot as plt - - -# PLOTLY (interactive) -def plotly_timeseries(t, y): - df = pd.DataFrame({"time": t, "discharge": y}) - return px.line(df, x="time", y="discharge", title="Discharge Timeseries") - - -# MATPLOTLIB (not interactive) -def plot_timeseries(t, y, jupyter=False): - fig, ax1 = plt.subplots() - fig.set_size_inches(14, 6) - ax1.plot(t, y, "dodgerblue") - ax1.set_xlabel("time (s)") - ax1.set_ylabel("discharge", color="b") - ax1.tick_params("y", colors="b") - - if jupyter: - return fig - - st.write(fig) - - -def histogram( - arr: Union[np.array, np.ma.MaskedArray], - bins=10, - clip=None, - title="Histogram", - figsize=(6, 4), -): - """plot histogram of a numpy array or masked numpy array""" - - # apply mask (if one exists) - if isinstance(arr, np.ma.MaskedArray): - arr = arr.compressed() - - # return if not numpy - if not isinstance(arr, np.ndarray): - print("histogram() requires a numpy array or masked numpy array") - return - - # remove flat dimensions - arr = arr.squeeze() - - # remove nans - arr = arr[~np.isnan(arr)] - - # histogram range (percentile clip or minmax) - if clip: - histogram_range = ( - round(np.percentile(arr, clip)), - round(np.percentile(arr, 100 - clip)), - ) - else: - histogram_range = (np.min(arr), np.max(arr)) - - # count number of values in each bin - counts, bins = np.histogram(arr, bins=bins, range=histogram_range) - - _ = plt.figure(figsize=figsize) - plt.hist(bins[:-1], bins, weights=counts) - 
plt.title(title) - - # show plot - plt.show() diff --git a/hat/tools/hydrostats_cli.py b/hat/tools/hydrostats_cli.py index 7aae11f..204f92a 100644 --- a/hat/tools/hydrostats_cli.py +++ b/hat/tools/hydrostats_cli.py @@ -4,10 +4,10 @@ import xarray as xr from hat import hydrostats_functions +from hat.data import find_main_var from hat.exceptions import UserError from hat.filters import filter_timeseries from hat.hydrostats import run_analysis -from hat.data import find_main_var def check_inputs(functions, sims, obs): @@ -80,7 +80,6 @@ def hydrostats_cli( if not functions: return - # simulations sims_ds = xr.open_dataset(sims) var = find_main_var(sims_ds, min_dim=2) diff --git a/hat/visualisation.py b/hat/visualisation.py deleted file mode 100644 index daedd39..0000000 --- a/hat/visualisation.py +++ /dev/null @@ -1,593 +0,0 @@ -""" -Python module for visualising geospatial content using jupyter notebook, -for both spatial and temporal, e.g. netcdf, vector, raster, with time series etc -""" - -import json -import os -import time -from typing import Dict -import geopandas as gpd -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import plotly.graph_objs as go -import xarray as xr -from ipyleaflet import CircleMarker, Map, GeoJSON -from IPython.core.display import display -from IPython.display import clear_output -from ipywidgets import HTML, HBox, Label, Layout, Output, VBox, DatePicker -import matplotlib as mpl -from hat.observations import read_station_metadata_file - - -class InteractiveElements: - def __init__(self, bounds, map_instance, statistics=None): - """Initialize the interactive elements for the map.""" - self.map = self._initialize_map(bounds) - self.loading_label = Label(value="") - self.throttler = self._initialize_throttler() - self.plotly_obj = PlotlyObject() - self.table_obj = TableObject(self) - self.statistics = statistics - self.output_widget = Output() - self.map_instance = map_instance - self.statistics = map_instance.statistics - self.stations_metadata = map_instance.stations_metadata - self.station_index = map_instance.station_index - - - def _initialize_map(self, bounds): - """Initialize the map widget.""" - map_widget = Map(layout=Layout(width="100%", height="600px")) - map_widget.fit_bounds(bounds) - return map_widget - - def generate_html_legend(self, colormap, min_val, max_val): - """Generate an HTML legend for the map.""" - # Convert the colormap to a list of RGB values - rgb_values = [mpl.colors.rgb2hex(colormap(i)) for i in np.linspace(0, 1, 256)] - - # Create a gradient style using the RGB values - gradient_style = ', '.join(rgb_values) - gradient_html = f""" -
-        <div style="background: linear-gradient(to right, {gradient_style}); width: 100%; height: 20px;"></div>
-        """
-
-        # Create labels
-        labels_html = f"""
-        <div style="display: flex; justify-content: space-between; width: 100%;">
-            Low: {min_val:.1f}
-            High: {max_val:.1f}
-        </div>
- """ - # Combine gradient and labels - legend_html = gradient_html + labels_html - - return HTML(legend_html) - - def _initialize_throttler(self, delay=1.0): - """Initialize the throttler for click events.""" - return ThrottledClick(delay) - - - def add_plotly_object(self, plotly_obj): - """Add a PlotlyObject to the IPyLeaflet class.""" - self.plotly_obj = plotly_obj - - def add_table_object(self, table_obj): - """Add a PlotlyObject to the IPyLeaflet class.""" - self.table_obj = table_obj - - def handle_marker_selection(self, feature, **kwargs): - """Handle the selection of a marker on the map.""" - # Extract station_id from the selected feature - station_id = feature["properties"]["station_id"] - - # Call the action handler with the extracted station_id - self.handle_marker_action(station_id) - - def handle_marker_action(self, station_id): - '''Define a callback to handle marker clicks and add the plot figure and statistics and station property.''' - - # Check if we should process the click - if not self.throttler.should_process(): - return - - self.loading_label.value = "Loading..." # Indicate that data is being loaded - station_id = str(station_id) # Convert station ID to string for consistency - station_name = self.stations_metadata.loc[self.stations_metadata[self.station_index] == station_id, "StationName"].values[0] - updated_title = f"Selected station:
<br>ID: {station_id}, name: {station_name}<br>
" - - # Update the plot with simulation and observation data - self.plotly_obj.update(station_id) - self.plotly_obj.figure.update_layout(title={ - 'text': updated_title, - 'y': 0.9, - 'x': 0.5, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': { - 'color': 'black', - 'size': 16 - } - } - ) - - # Generate and display the statistics table for the clicked station - if self.statistics: - self.table_obj.update(station_id) - - # Update the table in the layout - children_list = list(self.map_instance.top_right_frame.children) - - # Find the index of the old table and replace it with the new table - for i, child in enumerate(children_list): - if isinstance(child, type(self.table_obj.stat_table_html)): - children_list[i] = self.table_obj.stat_table_html - break - else: - # If the old table wasn't found, append the new table - children_list.append(self.table_obj.stat_table_html) - - self.map_instance.top_right_frame.children = tuple(children_list) - - - # Update the station table in the layout - children_list_main = list(self.map_instance.layout.children) - - # Find the index of the old station table and replace it with the new table - for i, child in enumerate(children_list_main): - if isinstance(child, type(self.table_obj.station_table_html)): - children_list_main[i] = self.table_obj.station_table_html - break - else: - # If the old station table wasn't found, append the new table - children_list_main.append(self.table_obj.station_table_html) - - self.map_instance.layout.children = tuple(children_list_main) - - self.table_obj.update(station_id) - - self.loading_label.value = "" # Clear the loading message - - with self.output_widget: - clear_output(wait=True) # Clear any previous plots or messages - - -class PlotlyObject: - def __init__(self): - """Initialize the Plotly object for visualization.""" - initial_title = "Click on your desired station location!" - self.figure = go.FigureWidget( - layout=go.Layout( - title = initial_title, - height=350, - margin=dict(l=100), - legend=dict( - orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 - ), - xaxis_title="Date", - xaxis_tickformat="%d-%m-%Y", - yaxis_title="Discharge [m3/s]" - ) - ) - - def update_simulation_data(self, station_id): - """Update the simulation data for the given station ID.""" - for name, ds in self.sim_ds.items(): - if station_id in ds["station"].values: - ds_time_series_data = ds["dis"].sel(station=station_id).values - valid_dates_ds, valid_data_ds = filter_nan_values( - self.ds_time_str, ds_time_series_data - ) - self._update_trace(valid_dates_ds, valid_data_ds, name) - else: - print(f"Station ID: {station_id} not found in dataset {name}.") - - def update_observation_data(self, station_id): - """Update the observation data for the given station ID.""" - if station_id in self.obs_ds["station"].values: - obs_time_series = self.obs_ds["obsdis"].sel(station=station_id).values - valid_dates_obs, valid_data_obs = filter_nan_values( - self.ds_time_str, obs_time_series - ) - self._update_trace(valid_dates_obs, valid_data_obs, "Obs. 
Data") - else: - print(f"Station ID: {station_id} not found in obs_df.") - - def _update_trace(self, x_data, y_data, name): - """Update or add a trace to the Plotly figure.""" - trace_exists = any([trace.name == name for trace in self.figure.data]) - if trace_exists: - for trace in self.figure.data: - if trace.name == name: - trace.x = x_data - trace.y = y_data - else: - self.figure.add_trace( - go.Scatter(x=x_data, y=y_data, mode="lines", name=name) - ) - - def update(self, station_id): - """Update the overall plot with new data for the given station ID.""" - self.update_simulation_data(station_id) - self.update_observation_data(station_id) - - -class TableObject: - def __init__(self, map_instance): - """Initialize the table object for displaying statistics and station properties.""" - self.map_instance = map_instance - # Define the styles for the statistics table - self.table_style = """ - - """ - self.stat_title_style = "style='font-size: 18px; font-weight: bold; text-align: center;'" - # Initialize the stat_table_html and station_table_html with empty tables - empty_df = pd.DataFrame() - self.stat_table_html = self.display_dataframe_with_scroll(empty_df, title="Model Performance Statistics Overview") - self.station_table_html = self.display_dataframe_with_scroll(empty_df, title="Station Property") - - - def generate_statistics_table(self, station_id): - """Generate a statistics table for the given station ID.""" - data = [] - - # Check if statistics is None or empty - if not self.map_instance.statistics: - print("No statistics data provided.") - return pd.DataFrame() # Return an empty dataframe - - # Loop through each simulation and get the statistics for the given station_id - for exp_name, stats in self.map_instance.statistics.items(): - if station_id in stats["station"].values: - row = [exp_name] + [ - round(stats[var].sel(station=station_id).values.item(), 2) - for var in stats.data_vars - if var not in ["longitude", "latitude"] - ] - data.append(row) - - # Check if data has any items - if not data: - print(f"No statistics data found for station ID: {station_id}.") - return pd.DataFrame() # Return an empty dataframe - - # Convert the data to a DataFrame for display - columns = ["Exp. name"] + list(stats.data_vars.keys()) - statistics_df = pd.DataFrame(data, columns=columns) - - # Round the numerical columns to 2 decimal places - numerical_columns = [col for col in statistics_df.columns if col != "Exp. name"] - statistics_df[numerical_columns] = statistics_df[numerical_columns].round(2) - - return statistics_df - - def generate_station_table(self, station_id): - """Generate a station property table for the given station ID.""" - stations_df = self.map_instance.stations_metadata - selected_station_df = stations_df[stations_df[self.map_instance.station_index] == station_id] - - return selected_station_df - - def display_dataframe_with_scroll(self, df, title=""): - """Display a DataFrame with a scrollable view.""" - table_html = df.to_html(classes="custom-table") - content = f"{self.table_style}

<div {self.stat_title_style}>{title}</div><div style='overflow-x: auto; max-height: 300px;'>{table_html}</div>
" - return(HTML(content)) - - def update(self, station_id): - """Update the tables with new data for the given station ID.""" - df_stat = self.generate_statistics_table(station_id) - self.stat_table_html = self.display_dataframe_with_scroll(df_stat, title="Model Performance Statistics Overview") - - df_station = self.generate_station_table(station_id) - self.station_table_html = self.display_dataframe_with_scroll(df_station, title="Station Property") - - - -class InteractiveMap: - def __init__(self, config, stations, observations, simulations, stats=None): - """Initialize the interactive map with configurations and data sources.""" - self.config = config - self.stations_metadata = read_station_metadata_file( - fpath=stations, - coord_names=config["station_coordinates"], - epsg=config["station_epsg"], - filters=config["station_filters"], - ) - - obs_var_name = config["obs_var_name"] - - # Use the external functions to prepare data - self.sim_ds = prepare_simulations_data(simulations) - self.obs_ds = prepare_observations_data(observations, self.sim_ds, obs_var_name) - - # Convert ds 'time' to datetime format for alignment with external_df - self.ds_time = self.obs_ds["time"].values.astype("datetime64[D]") - - # set station index - self.station_index = config["station_id_column_name"] - self.stations_metadata[self.station_index] = self.stations_metadata[self.station_index].astype(str) - - # Retrieve statistics from the statistics netcdf input - self.statistics = {} - if stats: - for name, path in stats.items(): - self.statistics[name] = xr.open_dataset(path) - - # Ensure the keys of self.statistics match the keys of self.sim_ds - assert set(self.statistics.keys()) == set(self.sim_ds.keys()), "Mismatch between statistics and simulations keys." - - # find common station ids between metadata, observation and simulations - self.common_id = self.find_common_station() - - print(f"Found {len(self.common_id)} common stations") - self.stations_metadata = self.stations_metadata.loc[self.stations_metadata[self.station_index].isin(self.common_id)] - self.obs_ds = self.obs_ds.sel(station=self.common_id) - for sim, ds in self.sim_ds.items(): - self.sim_ds[sim] = ds.sel(station=self.common_id) - - # Pass the map bound to the interactive elements - self.bounds = compute_bounds(self.stations_metadata, self.common_id, self.station_index, self.config["station_coordinates"]) - # self.interactive_elements = InteractiveElements(self.bounds) - self.interactive_elements = InteractiveElements(self.bounds, self) - - # Pass the necessary data to the interactive elements - self.interactive_elements.plotly_obj.sim_ds = self.sim_ds - self.interactive_elements.plotly_obj.obs_ds = self.obs_ds - self.interactive_elements.plotly_obj.ds_time_str = [dt.isoformat() for dt in pd.to_datetime(self.ds_time)] - - - def find_common_station(self): - """find common station between observation and simulation and station metadata""" - ids = [] - ids += [list(self.obs_ds["station"].values)] - ids += [list(ds["station"].values) for ds in self.sim_ds.values()] - ids += [self.stations_metadata[self.station_index]] - if self.statistics: - ids += [list(ds["station"].values) for ds in self.statistics.values()] - - common_ids = None - for id in ids: - if common_ids is None: - common_ids = set(id) - else: - common_ids = set(id) & common_ids - return list(common_ids) - - - def _update_plot_dates(self, change): - start_date = self.start_date_picker.value.strftime('%Y-%m-%d') - end_date = self.end_date_picker.value.strftime('%Y-%m-%d') - 
self.interactive_elements.plotly_obj.figure.update_layout(xaxis_range=[start_date, end_date]) - - - def mapplot(self, colorby="kge", sim=None, range=None, colormap=None): - """Plot the map with stations colored by a given metric. - input example: - colorby = "kge" # this should be the objective functions of the statistics - range = [, ] #min and max values of the color bar - colormap = plt.cm.get_cmap("RdYlGn") # color map to be used based on matplotlib colormap - """ - - # Retrieve the statistics data of simulation choice/ by default - stat_data = self.statistics[sim][colorby] - - # Normalize the data for coloring - if range is None: - min_val, max_val = stat_data.values.min(), stat_data.values.max() - else: - min_val, max_val = range[0], range[1] - - norm = plt.Normalize(min_val, max_val) - - if colormap is None: - colormap = plt.cm.get_cmap("viridis") - - def map_color(feature): - station_id = feature['properties'][self.config["station_id_column_name"]] - color = mpl.colors.rgb2hex( - colormap(norm(stat_data.sel(station=station_id).values)) - ) - return { - 'color': 'black', - 'fillColor': color, - } - - geo_data = GeoJSON( - data=json.loads(self.stations_metadata.to_json()), - style={'radius': 7, 'opacity':0.5, 'weight':1.9, 'dashArray':'2', 'fillOpacity':0.5}, - hover_style={'radius': 10, 'fillOpacity': 1}, - point_style={'radius': 5}, - style_callback=map_color, - ) - self.legend_widget = self.interactive_elements.generate_html_legend(colormap, min_val, max_val) - - - geo_data.on_click(self.interactive_elements.handle_marker_selection) - self.interactive_elements.map.add_layer(geo_data) - - # Add date pickers for start and end dates - self.start_date_picker = DatePicker(description='Start') - self.end_date_picker = DatePicker(description='End') - - # Observe changes in the date pickers to update the plot - self.start_date_picker.observe(self._update_plot_dates, names='value') - self.end_date_picker.observe(self._update_plot_dates, names='value') - - # Initialize layout elements - self._initialize_layout_elements() - - # Display the main layout - display(self.layout) - - - def _initialize_layout_elements(self): - """Initialize the layout elements for the map visualization.""" - # Title label - self.title_label = Label( - "Interactive Map Visualisation for Hydrological Model Performance", - layout=Layout(justify_content='center'), - style={'font_weight': 'bold', 'font_size': '24px', 'font_family': 'Arial'} - ) - - # Layouts - main_layout = Layout(justify_content='space-around', align_items='stretch', spacing='2px', width='1000px') - left_layout = Layout(justify_content='space-around', align_items='center', spacing='2px', width='40%') - right_layout = Layout(justify_content='center', align_items='center', spacing='2px', width='60%') - - # Date picker box and label - self.date_label = Label("Please select the date to accurately change the date axis of the plot") - self.date_picker_box = HBox([self.start_date_picker, self.end_date_picker]) - - # Frames - self.top_right_frame = VBox([self.interactive_elements.plotly_obj.figure, - self.date_label, self.date_picker_box, self.interactive_elements.table_obj.stat_table_html - ], layout=right_layout) - self.top_left_frame = VBox([self.interactive_elements.map, self.legend_widget], layout=left_layout) - self.main_top_frame = HBox([self.top_left_frame, self.top_right_frame]) - - # Main layout - self.layout = VBox([self.title_label, - self.interactive_elements.loading_label, - self.main_top_frame, - 
self.interactive_elements.table_obj.station_table_html], layout=main_layout) - - -class ThrottledClick: - """Initialize a click throttler with a given delay. to prevent user from swift multiple events clicking that results in crashing""" - def __init__(self, delay=1.0): - self.delay = delay - self.last_call = 0 - - def should_process(self): - """Determine if a click should be processed based on the delay.""" - current_time = time.time() - if current_time - self.last_call > self.delay: - self.last_call = current_time - return True - return False - - -def prepare_simulations_data(simulations): - """process simulations raw datasets to a standard dataframe""" - - # If simulations is a dictionary, load data for each experiment - sim_ds = {} - for exp, path in simulations.items(): - # Expanding the tilde - expanded_path = os.path.expanduser(path) - - if os.path.isfile(expanded_path): # Check if it's a file - ds = xr.open_dataset(expanded_path) - elif os.path.isdir(expanded_path): # Check if it's a directory - # Handle the case when it's a directory; - # assume all .nc files in the directory need to be combined - files = [f for f in os.listdir(expanded_path) if f.endswith(".nc")] - ds = xr.open_mfdataset( - [os.path.join(expanded_path, f) for f in files], combine="by_coords" - ) - else: - raise ValueError(f"Invalid path: {expanded_path}") - sim_ds[exp] = ds - - return sim_ds - - -def prepare_observations_data(observations, sim_ds, obs_var_name): - """process observation raw dataset to a standard dataframe""" - file_extension = os.path.splitext(observations)[-1].lower() - - if file_extension == ".csv": - obs_df = pd.read_csv(observations, parse_dates=["Timestamp"]) - obs_melted = obs_df.melt( - id_vars="Timestamp", var_name="station", value_name=obs_var_name - ) - - # Convert the melted DataFrame to xarray Dataset - obs_ds = obs_melted.set_index(["Timestamp", "station"]).to_xarray() - obs_ds = obs_ds.rename({"Timestamp": "time"}) - - elif file_extension == ".nc": - obs_ds = xr.open_dataset(observations) - - else: - raise ValueError("Unsupported file format for observations.") - - # Subset obs_ds based on sim_ds time values - if isinstance(sim_ds, xr.Dataset): - time_values = sim_ds["time"].values - elif isinstance(sim_ds, dict): - # Use the first dataset in the dictionary to determine time values - first_dataset = next(iter(sim_ds.values())) - time_values = first_dataset["time"].values - else: - raise ValueError("Unexpected type for sim_ds") - - obs_ds = obs_ds.sel(time=time_values) - return obs_ds - - -def properties_to_dataframe(properties: Dict) -> pd.DataFrame: - """Convert feature properties to a DataFrame for display.""" - return pd.DataFrame([properties]) - - -def filter_nan_values(dates, data_values): - """Filters out NaN values and their associated dates.""" - valid_dates = [date for date, val in zip(dates, data_values) if not np.isnan(val)] - valid_data = [val for val in data_values if not np.isnan(val)] - - return valid_dates, valid_data - -def compute_bounds(stations_metadata, common_ids, station_index, coord_names): - """Compute the bounds of the map based on the stations metadata.""" - - # Filter the metadata to only include stations with common IDs - filtered_stations = stations_metadata[stations_metadata[station_index].isin(common_ids)] - - lon_column = coord_names[0] - lat_column = coord_names[1] - - lons = filtered_stations[lon_column].values - lats = filtered_stations[lat_column].values - - min_lat, max_lat = min(lats), max(lats) - min_lon, max_lon = min(lons), max(lons) - - 
return [(float(min_lat), float(min_lon)), (float(max_lat), float(max_lon))] \ No newline at end of file diff --git a/hat/visualisation_v2.py b/hat/visualisation_v2.py deleted file mode 100644 index 9cf723d..0000000 --- a/hat/visualisation_v2.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Python module for visualising geospatial content using jupyter notebook, -for both spatial and temporal, e.g. netcdf, vector, raster, with time series etc -""" -import os -import pandas as pd -from typing import Dict, Union, List - -import geopandas as gpd -from shapely.geometry import Point -import xarray as xr - -from hat.hydrostats import run_analysis -from hat.filters import filter_timeseries - -class NotebookMap: - def __init__(self, config: Dict, stations_metadata: str, observations: str, simulations: Union[Dict, str], stats=None): - self.config = config - - # Prepare Station Metadata - self.stations_metadata = self.prepare_station_metadata( - fpath=stations_metadata, - station_id_column_name=config["station_id_column_name"], - coord_names=config['station_coordinates'], - epsg=config['station_epsg'], - filters=config['station_filters'] - ) - - # Prepare Observations Data - self.observation = self.prepare_observations_data(observations) - - # Prepare Simulations Data - self.simulations = self.prepare_simulations_data(simulations) - - # Ensure stations in obs and sims are present in metadata - valid_stations = set(self.stations_metadata[self.config["station_id_column_name"]].values) - - # Filter data based on valid stations - self.observation = self.filter_stations_by_metadata(self.observation, valid_stations) - self.simulations = {exp: self.filter_stations_by_metadata(ds, valid_stations) for exp, ds in self.simulations.items()} - - self.stats_input = stats - self.stat_threshold = 70 # default for now, may need to be added as option - self.stats_output = {} - if self.stats_input: - self.stats_output = self.calculate_statistics() - - def filter_stations_by_metadata(self, ds, valid_stations): - """Filter the stations in the dataset to only include those in valid_stations.""" - return ds.sel(station=[s for s in ds.station.values if s in valid_stations]) - - - def prepare_station_metadata(self, fpath: str, station_id_column_name: str, coord_names: List[str], epsg: int, filters=None) -> xr.Dataset: - # Read the station metadata file - df = pd.read_csv(fpath) - - # Convert to a GeoDataFrame - geometry = [Point(xy) for xy in zip(df[coord_names[0]], df[coord_names[1]])] - gdf = gpd.GeoDataFrame(df, crs=f"EPSG:{epsg}", geometry=geometry) - - # Apply filters if provided - if filters: - for column, value in filters.items(): - gdf = gdf[gdf[column] == value] - - return gdf - - def prepare_observations_data(self, observations: str) -> xr.Dataset: - """ - Load and preprocess observations data. - - Parameters: - - observations: Path to the observations data file. - - Returns: - - obs_ds: An xarray Dataset containing the observations data. 
- """ - file_extension = os.path.splitext(observations)[-1].lower() - station_id_column_name = self.config.get('station_id_column_name', 'station_id_num') - - if file_extension == '.csv': - obs_df = pd.read_csv(observations, parse_dates=["Timestamp"]) - obs_melted = obs_df.melt(id_vars="Timestamp", var_name="station", value_name="obsdis") - - # Convert melted DataFrame to xarray Dataset - obs_ds = obs_melted.set_index(["Timestamp", "station"]).to_xarray() - obs_ds = obs_ds.rename({"Timestamp": "time"}) - - elif file_extension == '.nc': - obs_ds = xr.open_dataset(observations) - - # Check for necessary attributes - if 'obsdis' not in obs_ds or 'time' not in obs_ds.coords: - raise ValueError("The NetCDF file lacks the expected variables or coordinates.") - - # Rename the station_id to station and set it as an index - obs_ds = obs_ds.rename({station_id_column_name: "station"}) - obs_ds = obs_ds.set_index(station="station") - else: - raise ValueError("Unsupported file format for observations.") - - return obs_ds - - - def prepare_simulations_data(self, simulations: Union[Dict, str]) -> Dict[str, xr.Dataset]: - """ - Load and preprocess simulations data. - - Parameters: - - simulations: Either a string path to the simulations data file or a dictionary mapping - experiment names to file paths. - - Returns: - - datasets: A dictionary mapping experiment names to their respective xarray Datasets. - """ - sim_ds = {} - - # Handle the case where simulations is a single string path - if isinstance(simulations, str): - sim_ds["default"] = xr.open_dataset(simulations) - - # Handle the case where simulations is a dictionary of experiment names to paths - elif isinstance(simulations, dict): - for exp, path in simulations.items(): - expanded_path = os.path.expanduser(path) - - if os.path.isfile(expanded_path): # If it's a file - ds = xr.open_dataset(expanded_path) - - elif os.path.isdir(expanded_path): # If it's a directory - # Assume all .nc files in the directory need to be combined - files = [f for f in os.listdir(expanded_path) if f.endswith('.nc')] - ds = xr.open_mfdataset([os.path.join(expanded_path, f) for f in files], combine='by_coords') - - else: - raise ValueError(f"Invalid path: {expanded_path}") - - sim_ds[exp] = ds - else: - raise TypeError("Expected simulations to be either str or dict.") - - return sim_ds - - - - def calculate_statistics(self) -> Dict[str, xr.Dataset]: - """ - Calculate statistics for the simulations against the observations. - - Returns: - - statistics: A dictionary mapping experiment names to their respective statistics xarray Datasets. 
- """ - stats_output = {} - - if isinstance(self.simulations, xr.Dataset): - # For a single simulation dataset - sim_filtered, obs_filtered = filter_timeseries(self.simulations, self.observation, self.stat_threshold) - stats_output["default"] = run_analysis(self.stats_input, sim_filtered, obs_filtered) - elif isinstance(self.simulations, dict): - # For multiple simulation datasets - for exp, ds in self.simulations.items(): - # print(f"Processing experiment: {exp}") - # print("Simulation dataset stations:", ds.station.values) - # print("Observation dataset stations:", self.observation.station.values) - - sim_filtered, obs_filtered = filter_timeseries(ds, self.observation, self.stat_threshold) - # stats_output[exp] = run_analysis(self.stats_input, sim_filtered, obs_filtered) - - stations_series = pd.Series(ds.station.values) - duplicates = stations_series.value_counts().loc[lambda x: x > 1] - print(exp, ":", duplicates) - - else: - raise ValueError("Unexpected type for self.simulations") - - return stats_output - -# Utility Functions - -def properties_to_dataframe(properties: Dict) -> pd.DataFrame: - """Convert feature properties to a DataFrame for display.""" - return pd.DataFrame([properties]) - -def filter_nan_values(dates, data_values): - """ - Filters out NaN values and their associated dates. - - Parameters: - - dates: List of dates. - - data_values: List of data values corresponding to the dates. - - Returns: - - valid_dates: List of dates without NaN values. - - valid_data: List of non-NaN data values. - """ - valid_dates = [date for date, val in zip(dates, data_values) if not np.isnan(val)] - valid_data = [val for val in data_values if not np.isnan(val)] - - return valid_dates, valid_data \ No newline at end of file diff --git a/notebooks/examples/4_visualisation_interactive.ipynb b/notebooks/examples/4_visualisation_interactive.ipynb index 1c498cf..26b7365 100644 --- a/notebooks/examples/4_visualisation_interactive.ipynb +++ b/notebooks/examples/4_visualisation_interactive.ipynb @@ -13,7 +13,6 @@ "metadata": {}, "outputs": [], "source": [ - "from hat.visualisation import NotebookMap\n", "stations = \"~/git/hat/data/outlets_v4.0_20230726_withEFAS.csv\"\n", "observations = \"~/git/hat/data/observations/destine_observations.nc\"\n", "simulations = {\n", @@ -45,7 +44,8 @@ "metadata": {}, "outputs": [], "source": [ - "map = NotebookMap(config, stations, observations, simulations, stats=statistics)" + "from hat.interactive.explorers import TimeSeriesExplorer\n", + "map = TimeSeriesExplorer(config, stations, observations, simulations, stats=statistics)" ] }, { @@ -73,8 +73,22 @@ } ], "metadata": { + "kernelspec": { + "display_name": "hat-venv", + "language": "python", + "name": "hat-venv" + }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" } }, "nbformat": 4, diff --git a/tests/test_hydrostats_decorators.py b/tests/test_hydrostats_decorators.py index 3a53a38..c3839aa 100644 --- a/tests/test_hydrostats_decorators.py +++ b/tests/test_hydrostats_decorators.py @@ -63,10 +63,8 @@ def test_filter_nan(): assert np.allclose(decorated(nan3, nan3), np.array([2, 4])) # all nans - with pytest.raises(ValueError): - decorated(arr, nans) - with pytest.raises(ValueError): - decorated(nans, arr) + assert decorated(arr, nans) is None + assert decorated(nans, arr) is None # def 
test_handle_divide_by_zero_error(): @@ -124,8 +122,7 @@ def test_hydrostat(): # # all zero division # with pytest.raises(ZeroDivisionError): - # decorated_divide(ones, zeros) + # print(decorated_divide(ones, zeros)) # all nans - with pytest.raises(ValueError): - decorated_divide(nans, nans) + assert decorated_divide(nans, nans) is None
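For reviewers, here is a minimal usage sketch of the new `hat.interactive` entry point added by this diff, following the updated `4_visualisation_interactive.ipynb`. The config keys are the ones read by `TimeSeriesExplorer` and its helper functions above; the file paths, station ID column, and coordinate column names are placeholders, not values shipped with the repo.

```python
from hat.interactive.explorers import TimeSeriesExplorer

# Keys consumed by TimeSeriesExplorer, prepare_simulations_data and
# prepare_observations_data (values below are placeholder assumptions).
config = {
    "station_coordinates": ["StationLon", "StationLat"],  # placeholder column names
    "station_epsg": 4326,
    "station_filters": {},  # optional metadata filters, assumed {column: value}
    "station_id_column_name": "station_id",  # placeholder ID column
    "sims_var_name": "dis",  # variable holding simulated discharge
    "obs_var_name": "obsdis",  # variable holding observed discharge
}

# One entry per experiment; each value may be a .nc file or a directory of .nc files.
simulations = {"exp1": "~/data/simulations/exp1/"}
# Statistics keys must match the simulation keys (asserted in the constructor).
statistics = {"exp1": "~/data/statistics/exp1.nc"}

explorer = TimeSeriesExplorer(
    config,
    stations="~/data/stations.csv",
    observations="~/data/observations.nc",
    simulations=simulations,
    stats=statistics,
)

# Colour stations by a statistic and render the interactive frame.
explorer.mapplot(colorby="kge", sim="exp1", limits=[0.0, 1.0], mp_colormap="viridis")
```

`mapplot()` builds the widget frame and calls `display()`, so it is meant to run inside a Jupyter notebook.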