# Source code for micromet.report.graphs

import plotly.graph_objects as go
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy
import logging

from micromet.utils import logger_check


def energy_sankey(df, date_text="2024-06-19 12:00", logger: logging.Logger = None):
    """
    Create a Sankey diagram of the surface energy balance at one timestamp.

    Incoming shortwave and longwave radiation flow into a "Total Incoming
    Radiation" node. That node splits into outgoing shortwave, outgoing
    longwave and net radiation. Net radiation then splits into ground heat
    flux, sensible heat, latent heat and the residual
    ``NETRAD - (G + H + LE)``. The title reports the energy balance ratio
    ``(H + LE) / (NETRAD - G)``.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame indexed by timestamp with columns 'SW_IN', 'LW_IN',
        'SW_OUT', 'LW_OUT', 'NETRAD', 'G', 'LE' and 'H'.
    date_text : str, optional
        Timestamp (anything accepted by ``pd.to_datetime``) of the row to
        plot. Defaults to "2024-06-19 12:00".
    logger : logging.Logger, optional
        Logger for debug output, resolved through ``logger_check``.
        Defaults to None.

    Returns
    -------
    go.Figure
        A Plotly Figure containing the Sankey diagram. It is not shown.
    """
    select_date = pd.to_datetime(date_text)
    swi = df.loc[select_date, "SW_IN"]
    lwi = df.loc[select_date, "LW_IN"]
    swo = df.loc[select_date, "SW_OUT"]
    lwo = df.loc[select_date, "LW_OUT"]
    nr = df.loc[select_date, "NETRAD"]
    shf = df.loc[select_date, "G"]
    le = df.loc[select_date, "LE"]
    h = df.loc[select_date, "H"]

    # Node labels; a node's position in this list is its Sankey index.
    labels = [
        "Incoming Shortwave Radiation",  # 0
        "Incoming Longwave Radiation",  # 1
        "Total Incoming Radiation",  # 2
        "Outgoing Shortwave Radiation",  # 3
        "Outgoing Longwave Radiation",  # 4
        "Net Radiation",  # 5
        "Ground Heat Flux",  # 6
        "Sensible Heat",  # 7
        "Latent Heat",  # 8
        "Residual",  # 9
    ]

    logger = logger_check(logger)
    logger.debug(f"Sensible Heat: {h}")

    # Energy not accounted for by the measured fluxes.
    rem = nr - (shf + h + le)
    # Energy balance ratio. Guard the denominator so a zero value gives NaN
    # in the title instead of raising for plain Python floats.
    available = nr - shf
    ebr = (h + le) / available if available != 0 else np.nan

    # Each link is (source node, target node, value). Keeping them as
    # triples ensures every value sits next to the nodes it connects.
    links = [
        (0, 2, swi),
        (1, 2, lwi),
        (2, 5, nr),
        (2, 3, swo),
        (2, 4, lwo),
        (5, 6, shf),
        (5, 7, h),
        (5, 8, le),
        (5, 9, rem),
    ]
    source = [src for src, _, _ in links]
    target = [dst for _, dst, _ in links]
    values = [val for _, _, val in links]

    fig = go.Figure(
        data=[
            go.Sankey(
                node=dict(
                    pad=15,
                    thickness=20,
                    line=dict(color="black", width=0.5),
                    label=labels,
                ),
                link=dict(
                    source=source,
                    target=target,
                    value=values,
                ),
            )
        ]
    )

    fig.update_layout(
        title_text=f"Energy Balance {ebr:0.2f} on {select_date:%Y-%m-%d}",
        font_size=10,
    )
    return fig
def scatterplot_instrument_comparison(edmet, compare_dict, station, logger: logging.Logger = None):
    """
    Generate a scatter plot comparing two instrument measurements.

    The two instrument columns are cleaned (``-9999`` treated as missing,
    incomplete rows dropped), resampled to hourly means, linearly
    interpolated, and stripped of any remaining NaNs. The plot shows the
    points, a 1:1 reference line and an ordinary least-squares fit of the
    second instrument against the first.

    Parameters
    ----------
    edmet : pd.DataFrame
        DataFrame with a DatetimeIndex containing the measurement data.
    compare_dict : dict
        Maps instrument column names to metadata. The first key is plotted
        on x and the second on y. Each value is indexed at [0], [1] and [2]
        for axis labels and the title. Presumably these are instrument
        name, variable name and units; confirm with callers.
    station : str
        Station identifier used in the plot title.
    logger : logging.Logger, optional
        Logger for regression statistics, resolved through
        ``logger_check``. Defaults to None.

    Returns
    -------
    tuple
        (slope, intercept, r_squared, p_value, std_err, fig, ax).

    Raises
    ------
    ValueError
        If ``compare_dict`` names fewer than two instruments.
    """
    # Import the submodule explicitly: a bare ``import scipy`` only exposes
    # ``scipy.stats`` lazily on recent SciPy releases.
    from scipy import stats

    instruments = list(compare_dict.keys())
    if len(instruments) < 2:
        raise ValueError("compare_dict must name at least two instrument columns.")

    df = edmet[instruments].replace(-9999, np.nan).dropna()
    df = df.resample("1h").mean().interpolate(method="linear")
    df = df.dropna()

    x = df[instruments[0]]
    y = df[instruments[1]]
    xinfo = compare_dict[instruments[0]]
    yinfo = compare_dict[instruments[1]]

    slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
    r_squared = r_value**2

    # The 1:1 line spans the combined range of both instruments. A fixed
    # point count keeps it drawable for any value range; the old 0.1 step
    # produced one point for narrow ranges and millions for large ones.
    lo = min(x.min(), y.min())
    hi = max(x.max(), y.max())
    one_to_one = np.linspace(lo, hi, 100)

    # The fit is a straight line, so its two endpoints are enough. Drawing
    # through time-ordered x would trace the line back and forth.
    fit_x = np.array([x.min(), x.max()])
    fit_y = slope * fit_x + intercept

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(x, y, alpha=0.5, s=1, label="Data points")
    ax.set_title(f"{xinfo[1]} Comparison: {station}")
    ax.plot(one_to_one, one_to_one, label="1:1 line", color="green", linestyle="--")
    ax.plot(
        fit_x,
        fit_y,
        color="red",
        label=f"Fit: y = {slope:.2f}x + {intercept:.2f}\n$R^2$ = {r_squared:.3f}",
    )
    ax.legend()
    ax.grid(True)
    ax.set_xlabel(f"{xinfo[0]} {xinfo[1]} ({xinfo[2]})")
    ax.set_ylabel(f"{yinfo[0]} {yinfo[1]} ({yinfo[2]})")
    plt.show()

    logger = logger_check(logger)
    logger.info(f"Slope: {slope:.3f}")
    logger.info(f"Intercept: {intercept:.3f}")
    logger.info(f"R-squared: {r_squared:.3f}")

    return slope, intercept, r_squared, p_value, std_err, fig, ax
def mean_squared_error(series1: pd.Series, series2: pd.Series) -> float:
    """
    Calculate the Mean Squared Error (MSE) between two series.

    The series are compared position by position, which is what the
    equal-length requirement implies. Their index labels are discarded, so
    two equal-length Series with different indices are still paired
    element-wise. Plain subtraction would align them by label and give
    NaNs. Pairs where either value is NaN are skipped, as in ``pd.Series.mean``.

    Parameters
    ----------
    series1 : pd.Series
        The first data series. Any 1-D array-like is accepted.
    series2 : pd.Series
        The second data series. Any 1-D array-like is accepted.

    Returns
    -------
    float
        The Mean Squared Error between the two series. Returns NaN if no
        valid pairs remain.

    Raises
    ------
    ValueError
        If the input series are not of the same length.
    """
    if len(series1) != len(series2):
        raise ValueError("Input Series must be of the same length.")
    # Drop index labels so subtraction pairs values by position.
    first = pd.Series(series1).reset_index(drop=True)
    second = pd.Series(series2).reset_index(drop=True)
    return float(((first - second) ** 2).mean())
def mean_diff_plot(
    m1,
    m2,
    sd_limit=1.96,
    ax=None,
    scatter_kwds=None,
    mean_line_kwds=None,
    limit_lines_kwds=None,
):
    """
    Construct a Tukey/Bland-Altman Mean Difference Plot.

    Each pair of measurements is plotted as its mean (x) against its
    difference ``m1 - m2`` (y). A horizontal line marks the mean
    difference, and, when ``sd_limit > 0``, further lines mark
    ``mean_diff +/- sd_limit * std(diffs)``, where std uses ``np.std``
    (ddof=0).

    Parameters
    ----------
    m1 : array_like
        A 1-D array of measurements. Compared with ``m2`` by position.
    m2 : array_like
        A 1-D array of measurements, the same length as ``m1``.
    sd_limit : float, optional
        Number of standard deviations for the limits of agreement. If 0, no
        limit lines are drawn and the y-range is set to +/- 3 SD.
        Defaults to 1.96.
    ax : plt.Axes, optional
        Existing Axes to draw on. If None, a new 10x8 figure is created.
    scatter_kwds : dict, optional
        Keyword arguments for ``Axes.scatter``. It is not modified.
    mean_line_kwds : dict, optional
        Keyword arguments for the mean-difference ``axhline``. Defaults to
        gray, width 1, dashed. It is not modified.
    limit_lines_kwds : dict, optional
        Keyword arguments for the limit ``axhline`` calls. Defaults to
        gray, width 1, dotted. It is not modified.

    Returns
    -------
    plt.Figure
        The figure containing the Axes that was drawn on.

    Raises
    ------
    ValueError
        If the inputs differ in length or ``sd_limit`` is negative.
    """
    # Convert to arrays so lists work and pandas inputs are paired by
    # position rather than aligned by index label.
    m1 = np.asarray(m1, dtype=float)
    m2 = np.asarray(m2, dtype=float)

    # Validate before creating a figure so bad input does not leak one.
    if len(m1) != len(m2):
        raise ValueError("m1 does not have the same length as m2.")
    if sd_limit < 0:
        raise ValueError(f"sd_limit ({sd_limit}) is less than 0.")

    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 8))
    else:
        fig = ax.figure

    means = np.mean([m1, m2], axis=0)
    diffs = m1 - m2
    mean_diff = np.mean(diffs)
    std_diff = np.std(diffs, axis=0)

    # Copy the caller's dicts so the defaults below do not leak into them.
    scatter_kwds = dict(scatter_kwds or {})
    scatter_kwds.setdefault("s", 20)
    mean_line_kwds = dict(mean_line_kwds or {})
    limit_lines_kwds = dict(limit_lines_kwds or {})
    for kwds in (mean_line_kwds, limit_lines_kwds):
        kwds.setdefault("color", "gray")
        kwds.setdefault("linewidth", 1)
    # Linestyle defaults differ per line, so set them outside the loop.
    mean_line_kwds.setdefault("linestyle", "--")
    limit_lines_kwds.setdefault("linestyle", ":")

    ax.scatter(means, diffs, **scatter_kwds)
    ax.axhline(mean_diff, **mean_line_kwds)
    ax.annotate(
        f"mean diff:\n{np.round(mean_diff, 2)}",
        xy=(0.99, 0.5),
        horizontalalignment="right",
        verticalalignment="center",
        fontsize=14,
        xycoords="axes fraction",
    )

    if sd_limit > 0:
        half_ylim = (1.5 * sd_limit) * std_diff
        ax.set_ylim(mean_diff - half_ylim, mean_diff + half_ylim)
        limit_of_agreement = sd_limit * std_diff
        lower = mean_diff - limit_of_agreement
        upper = mean_diff + limit_of_agreement
        for lim in (lower, upper):
            ax.axhline(lim, **limit_lines_kwds)
        ax.annotate(
            f"-{sd_limit} SD: {lower:0.2g}",
            xy=(0.99, 0.07),
            horizontalalignment="right",
            verticalalignment="bottom",
            fontsize=14,
            xycoords="axes fraction",
        )
        ax.annotate(
            f"+{sd_limit} SD: {upper:0.2g}",
            xy=(0.99, 0.92),
            horizontalalignment="right",
            fontsize=14,
            xycoords="axes fraction",
        )
    elif sd_limit == 0:
        half_ylim = 3 * std_diff
        ax.set_ylim(mean_diff - half_ylim, mean_diff + half_ylim)

    ax.set_ylabel("Difference", fontsize=15)
    ax.set_xlabel("Means", fontsize=15)
    ax.tick_params(labelsize=13)
    fig.tight_layout()
    return fig
def bland_alt_plot(edmet, compare_dict, station, alpha=0.5, logger: logging.Logger = None):
    """
    Create a Bland-Altman plot to assess agreement between two instruments.

    The two instrument columns are cleaned (``-9999`` treated as missing,
    incomplete rows dropped), resampled to hourly means, linearly
    interpolated, and stripped of remaining NaNs. RMSE, bias (mean of
    ``x - y``) and spread (sample std of ``x - y``) are logged. The plot is
    drawn by ``mean_diff_plot`` and then titled and labelled.

    Parameters
    ----------
    edmet : pd.DataFrame
        DataFrame with a DatetimeIndex containing measurement data.
    compare_dict : dict
        Maps instrument column names to metadata. The first two keys are
        compared as ``first - second``. Each value is indexed at [0], [1]
        and [2] for titles and labels. Presumably these are instrument
        name, variable name and units; confirm with callers.
    station : str
        Station identifier used in the plot title.
    alpha : float, optional
        Transparency of the scatter points. Defaults to 0.5.
    logger : logging.Logger, optional
        Logger for statistics, resolved through ``logger_check``.
        Defaults to None.

    Returns
    -------
    tuple[plt.Figure, plt.Axes]
        The figure and the Axes containing the plot.
    """
    instruments = list(compare_dict.keys())
    df = edmet[instruments].replace(-9999, np.nan).dropna()
    df = df.resample("1h").mean().interpolate(method="linear")
    df = df.dropna()

    x = df[instruments[0]]
    y = df[instruments[1]]

    logger = logger_check(logger)
    rmse = np.sqrt(mean_squared_error(x, y))
    logger.info(f"RMSE: {rmse:.3f}")

    mean_vals = df[instruments].mean(axis=1)
    diff_vals = x - y
    bias = diff_vals.mean()
    spread = diff_vals.std()
    logger.info(f"Bias = {bias:.3f}, Spread = {spread:.3f}")

    # Positions for the instrument-name labels at the 1.96-SD limits.
    top = bias + 1.96 * spread
    bottom = bias - 1.96 * spread

    f, ax = plt.subplots(1, figsize=(8, 5))
    # Pass alpha to the scatter so it actually controls point
    # transparency. On plt.subplots it only set a figure property.
    drawn = mean_diff_plot(x, y, ax=ax, scatter_kwds={"alpha": alpha})
    if drawn is not f:
        # Defensive: annotate whichever figure mean_diff_plot actually drew
        # on, and discard the unused empty one.
        plt.close(f)
        f = drawn
        ax = drawn.axes[0]

    ax.text(
        mean_vals.mean(),
        top,
        s=compare_dict[instruments[0]][0],
        verticalalignment="top",
        fontweight="bold",
    )
    ax.text(
        mean_vals.mean(),
        bottom,
        s=compare_dict[instruments[1]][0],
        verticalalignment="bottom",
        fontweight="bold",
    )
    ax.set_title(
        f"{compare_dict[instruments[0]][0]} vs {compare_dict[instruments[1]][0]} at {station}"
    )
    ax.set_xlabel(
        f"Mean {compare_dict[instruments[0]][1]} ({compare_dict[instruments[0]][2]})"
    )
    ax.set_ylabel(
        f"Difference ({compare_dict[instruments[0]][2]})\n({compare_dict[instruments[0]][0]} - {compare_dict[instruments[1]][0]})",
        fontsize=10,
    )

    return f, ax
# Plot one station's variable filtered to a date range
def plot_timeseries_daterange(
    input_df, selected_station, selected_field, start_date, end_date
) -> None:
    """
    Plot a single station variable between two dates and display it.

    The station's rows are selected with ``input_df.loc[selected_station]``
    and sliced by ``.loc[start_date:end_date]``. The chosen column then has
    ``-9999`` replaced with NaN before it is drawn. The new figure and axes
    are stored in the module-level globals ``fig`` and ``ax`` so that
    ``save_plot`` can save the figure later.

    Parameters
    ----------
    input_df : pd.DataFrame
        DataFrame with a MultiIndex ('station', 'timestamp').
    selected_station : str
        Identifier of the station to plot.
    selected_field : str
        Name of the column (variable) to plot.
    start_date : str or pd.Timestamp
        Start of the date range.
    end_date : str or pd.Timestamp
        End of the date range.
    """
    # Publish the figure/axes at module level for save_plot.
    global fig, ax
    fig, ax = plt.subplots(figsize=(10, 8))

    station_rows = input_df.loc[selected_station]
    in_range = station_rows.loc[start_date:end_date]
    series = in_range.loc[:, selected_field].replace(-9999, np.nan)

    ax.plot(series.index, series, label=selected_station, linewidth=2)

    ax.set_title(f"{selected_station} {selected_field}\n{start_date} to {end_date}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Value")
    ax.legend()
    ax.grid(True)
    plt.xticks(rotation=45)
    fig.tight_layout()
    plt.show()
def save_plot(b, filename="plot.png") -> None:
    """
    Save the most recent time-series figure to a file.

    Intended as a callback for an interactive widget, such as a Jupyter
    button. It saves the module-level ``fig`` that
    ``plot_timeseries_daterange`` assigns through its ``global``
    statement.

    Parameters
    ----------
    b : object
        The triggering widget event. It is not used.
    filename : str, optional
        Output path. The extension selects the format (e.g. ``.pdf``).
        Defaults to "plot.png".

    Raises
    ------
    RuntimeError
        If no figure has been created yet in this module.
    """
    figure = globals().get("fig")
    if figure is None:
        raise RuntimeError(
            "No figure to save: call plot_timeseries_daterange() first."
        )
    figure.savefig(filename)