Discrepant DKL with clusterMAF #38
Hi @ThomasGesseyJones, thanks for spotting this! After some offline discussion, I agree that there is an error in the clusterMAF log-probability calculation. I had not clocked that the factor of ln(2) above is actually a factor of ln(n_cluster), as you highlighted offline. The log-probability for the cluster MAF was being calculated as the sum of the probability of the sample evaluated for each cluster, in a similar way to the probability of a KDE being calculated as the sum over the kernel probabilities. However, in a KDE the probability is the mean over the kernels, not the sum, and we should be taking the mean in the cluster MAF as well. We take the mean to renormalise. This may also help to tackle #37, which I initially thought was due to poor training. To fix this, I believe you just need to take the mean over the clusters rather than the sum, i.e. renormalise by the number of clusters. Super good spot, thank you!

Running your example locally with the above fix gives, for the MAF,

Statistic        Value     Lower Bound  Upper Bound
KL Divergence    0.178811  0.174948     0.182488
BMD              0.477987  0.468832     0.488763

analytic DKL = 0.1931471805599453
analytic BMD = 0.5

and for the clusterMAF,

Statistic        Value     Lower Bound  Upper Bound
KL Divergence    0.174118  0.165567     0.182208
BMD              0.472542  0.455731     0.496967

analytic DKL = 0.1931471805599453
analytic BMD = 0.5
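As a minimal sketch of the renormalisation being described, assuming a list of per-cluster flow objects each exposing a log_prob method (the names here are illustrative, not margarine's actual internals), and showing the simple mean over clusters discussed at this point in the thread (the weighting is refined further down):

import numpy as np
from scipy.special import logsumexp

def cluster_log_prob(x, flows):
    # Evaluate each cluster flow's log-probability at x
    logps = np.array([flow.log_prob(x) for flow in flows])
    # Mean (not sum) of the probabilities, computed stably in log space:
    # log((1/n) * sum_i exp(logp_i)) = logsumexp(logp_i) - log(n)
    return logsumexp(logps, axis=0) - np.log(len(flows))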
Hi @htjb, thanks for looking at this so quickly. I will put together a PR with your suggested solution, and the casting changes we discussed offline. The posterior probability being off by a constant factor of n_cluster shifts the DKL by a constant ln(n_cluster), which matches the offset I was seeing. If we do the same exercise for the BMD, the constant factor cancels (the BMD depends only on the variance of the log-posterior), which is why the BMD values were unaffected.

Out of curiosity, in your reply you said the probability should be the mean over the cluster flows. How come it is the simple mean over the cluster flows, and not a mean weighted by, say, the weights of the samples in the clusters?
Thanks for the explanation of the BMD consistency! This looks good.

So when we draw samples we want to draw them with a probability relative to the cluster size, hence the weights. However, when calculating probabilities you want to sum over the clusters (note that each flow in the clusterMAF returns log-probabilities, hence the sum of exp of the logs).
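For concreteness, a rough sketch of the sampling side of this, with the flow objects, their sample method, and the per-cluster weights all assumed for illustration: the number of draws allocated to each cluster flow is proportional to that cluster's weight.

import numpy as np

rng = np.random.default_rng()

def sample_clustered(flows, cluster_weights, n):
    # Allocate the n draws across cluster flows in proportion to cluster weight
    probs = np.asarray(cluster_weights, dtype=float)
    probs /= probs.sum()
    counts = rng.multinomial(n, probs)
    # Draw the allocated number of samples from each cluster flow
    return np.concatenate([flow.sample(int(c)) for flow, c in zip(flows, counts)])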
After some offline discussion and experimentation, I think we have come to the conclusion that the probability should be a weighted mean over the cluster flows, with each cluster weighted by the total sample weight assigned to it, or the equivalent expression in log space. Note: the equations originally given here contained a typo, see the discussion below. The need for this weighting is illustrated in the following example of a distribution composed of two unequally weighted Gaussians:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from margarine.maf import MAF
from margarine.clustered import clusterMAF
from margarine.marginal_stats import calculate
min_allowed = -5
max_allowed = 5
# Setup distribution
first_gaussian_data = np.array(norm.rvs(loc=-2, scale=1, size=7500))
second_gaussian_data = np.array(norm.rvs(loc=2, scale=1, size=7500))
data = np.concatenate([first_gaussian_data, second_gaussian_data])
weights = np.ones(len(data), dtype=np.float32)
weights[:len(first_gaussian_data)] *= 3
weights = weights[(data > min_allowed) & (data < max_allowed)]
data = data[(data > min_allowed) & (data < max_allowed)]
data = data[:, np.newaxis]
# Plot distribution of samples
plt.hist(data, weights=weights, bins=100)
plt.xlabel('x')
plt.ylabel('Samples')
plt.show()
# Clustered MAF
bij = clusterMAF(data, weights=weights)
bij.train(10000, early_stop=True)
x = bij.sample(5000)
stats = calculate(bij).statistics()
print(stats)
plt.hist(np.squeeze(x.numpy()), bins=100)
plt.xlabel('x')
plt.ylabel('Samples')
plt.show()
xs = np.linspace(min_allowed, max_allowed, 1000)
xs = xs[:, np.newaxis]
posterior = np.exp(bij.log_prob(xs))
posterior[np.isnan(posterior)] = 0
print(f'Posterior Integral {np.trapz(posterior, dx=(max_allowed-min_allowed)/1000)}')
plt.plot(xs, posterior)
plt.xlabel('x')
plt.ylabel('P(x)')
plt.show()

Without the weighting, the recovered posterior does not reproduce the unequal weighting of the two Gaussians. A similar weighting is already implemented in the sampling step. I will implement these weightings in my fork and create a PR after some testing verifies what I have done is working as expected.
Thanks @ThomasGesseyJones, this sounds good! I am curious though: should the total probability not have the sum of the cluster weights in the denominator? Here each cluster flow's probability is weighted by that cluster's total sample weight. P.S. I added some extra brackets because I was getting confused and separated out the sum in the denominator haha.

Yes please, it would be good to have a look over the code! Thanks.
Hi @htjb. You are quite right, there is an error in the equations I typed up. They should have read as you suggested, with the sum of the weights in the denominator (and the equivalent form in log space). I have edited my earlier comment to add a warning for anyone reading through this thread, and to direct them to this discussion. I just checked, and while the equations I typed were wrong, the implementation seems to be correct: https://github.com/ThomasGesseyJones/margarine/commit/c77c95ad5beadbc563636d2902dc70c1958f5cb5
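As a sketch of the weighted combination being discussed, with P_i the probability from the i-th cluster flow and w_i the total sample weight assigned to cluster i (notation assumed, not a verbatim reproduction of the original equations):

P(x) = \frac{\sum_i w_i \, P_i(x)}{\sum_j w_j}

or in log space,

\ln P(x) = \ln\left(\sum_i \exp\bigl(\ln w_i + \ln P_i(x)\bigr)\right) - \ln\left(\sum_j w_j\right)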
Yeah I scanned the code before and it looked good! Please make a PR when you are ready 😄
In the process of upgrading some existing code that used MAF to use clusterMAF, I found the DKL divergences newly computed using clusterMAF were significantly higher than those previously found using MAF. For a subset of my examples I had Polychord DKL values, and so was able to determine this was not caused by the old MAF version of the code being poorly converged, as it was in agreement with Polychord whereas the new clusterMAF version was not.

A minimum working example of this problem, with comparison to analytic BMDs and DKLs, can be constructed for a triangle distribution.
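A minimal sketch of such an example, assuming the same margarine API used elsewhere in this thread (MAF, clusterMAF, calculate) and taking a symmetric triangle distribution on [0, 1] with a uniform prior, for which the analytic DKL is ln(2) - 1/2 and the analytic BMD is 0.5, would look something like:

import numpy as np
from scipy.stats import triang
from margarine.maf import MAF
from margarine.clustered import clusterMAF
from margarine.marginal_stats import calculate

# Samples from a symmetric triangle distribution on [0, 1]
data = triang.rvs(c=0.5, loc=0, scale=1, size=10000)[:, np.newaxis]
weights = np.ones(len(data), dtype=np.float32)

# Plain MAF (constructor arguments assumed to mirror the clusterMAF usage above)
maf = MAF(data, weights=weights)
maf.train(10000, early_stop=True)
print(calculate(maf).statistics())

# Clustered MAF
cmaf = clusterMAF(data, weights=weights)
cmaf.train(10000, early_stop=True)
print(calculate(cmaf).statistics())

print(f'analytic DKL = {np.log(2) - 0.5}')
print('analytic BMD = 0.5')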
The output I find for the MAF portion is in agreement with the analytic results, whereas for clusterMAF the DKL comes out about 0.69 (~ln(2)) larger than expected.