Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update README.md with forum #116

Merged
merged 7 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ For starters, check out some of our walk-through notebooks:
7. [Model comparison for cognitive models](examples/Model_Comparison_MPT.ipynb)
8. [Hierarchical model comparison for cognitive models](examples/Hierarchical_Model_Comparison_MPT.ipynb)

## Project Documentation
## Documentation & Help

The project documentation is available at <https://bayesflow.org>
The project documentation is available at <https://bayesflow.org>. Please use the [BayesFlow Forums](https://discuss.bayesflow.org/) for any BayesFlow-related questions and discussions, and [GitHub Issues](https://github.com/stefanradev93/BayesFlow/issues) for bug reports and feature requests.

## Installation

Expand Down
23 changes: 14 additions & 9 deletions bayesflow/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def plot_recovery(
n_row=None,
xlabel="Ground truth",
ylabel="Estimated",
**kwargs
**kwargs,
):
"""Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
Expand Down Expand Up @@ -110,7 +110,7 @@ def plot_recovery(
**kwargs : optional
Additional keyword arguments passed to ax.errorbar or ax.scatter.
Example: `rasterized=True` to reduce PDF file size with many dots

Returns
-------
f : plt.Figure - the figure instance for optional saving
Expand Down Expand Up @@ -240,7 +240,7 @@ def plot_z_score_contraction(
tick_fontsize=12,
color="#8f2727",
n_col=None,
n_row=None
n_row=None,
):
"""Implements a graphical check for global model sensitivity by plotting the posterior
z-score over the posterior contraction for each set of posterior samples in ``post_samples``
Expand Down Expand Up @@ -567,7 +567,7 @@ def plot_sbc_histograms(
tick_fontsize=12,
hist_color="#a34f4f",
n_row=None,
n_col=None
n_col=None,
):
"""Creates and plots publication-ready histograms of rank statistics for simulation-based calibration
(SBC) checks according to [1].
Expand Down Expand Up @@ -910,7 +910,7 @@ def plot_losses(
for i, ax in enumerate(looper):
# Plot train curve
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
if moving_average:
if moving_average and train_losses.columns[i] == "Loss":
moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
Expand All @@ -929,7 +929,7 @@ def plot_losses(
)
# Schmuck
ax.set_xlabel("Training step #", fontsize=label_fontsize)
ax.set_ylabel("Loss value", fontsize=label_fontsize)
ax.set_ylabel("Value", fontsize=label_fontsize)
sns.despine(ax=ax)
ax.grid(alpha=grid_alpha)
ax.set_title(train_losses.columns[i], fontsize=title_fontsize)
Expand Down Expand Up @@ -1061,7 +1061,7 @@ def plot_calibration_curves(
fig_size=None,
color="#8f2727",
n_row=None,
n_col=None
n_col=None,
):
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
Expand Down Expand Up @@ -1114,7 +1114,6 @@ def plot_calibration_curves(
elif n_row is not None and n_col is None:
n_col = int(np.ceil(num_models / n_row))


# Compute calibration
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)

Expand Down Expand Up @@ -1273,7 +1272,13 @@ def plot_confusion_matrix(
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(
j, i, format(cm[i, j], fmt), fontsize=value_fontsize, ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
j,
i,
format(cm[i, j], fmt),
fontsize=value_fontsize,
ha="center",
va="center",
color="white" if cm[i, j] > thresh else "black",
)
if title:
ax.set_title("Confusion Matrix", fontsize=title_fontsize)
Expand Down
5 changes: 3 additions & 2 deletions bayesflow/summary_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
# Construct final attention layer, which will perform cross-attention
# between the outputs of the self-attention layers and the dynamic template
if bidirectional:
final_input_dim = template_dim*2
final_input_dim = template_dim * 2
else:
final_input_dim = template_dim
self.output_attention = MultiHeadAttentionBlock(
Expand Down Expand Up @@ -184,7 +184,8 @@ def call(self, x, **kwargs):

class SetTransformer(tf.keras.Model):
"""Implements the set transformer architecture from [1] which ultimately represents
a learnable permutation-invariant function.
a learnable permutation-invariant function. Designed to naturally model interactions in
the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.

[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
Set transformer: A framework for attention-based permutation-invariant neural networks.
Expand Down