Discrepant DKL with clusterMAF #38

Closed
ThomasGesseyJones opened this issue Aug 21, 2023 · 7 comments · Fixed by #39

Comments

@ThomasGesseyJones
Collaborator

In the process of upgrading some existing code that used MAF to use clusterMAF, I found that the KL divergences (DKL) newly computed using clusterMAF were significantly higher than those previously found using MAF. For a subset of my examples I had Polychord DKL values, so I was able to rule out poor convergence of the old MAF version of the code as the cause: the MAF results agreed with Polychord, whereas the new clusterMAF results did not.

A minimum working example of this problem, with comparison to analytic BMDs and DKLs, can be constructed for a triangle distribution:

```python
# Imports
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import triang
from margarine.maf import MAF
from margarine.clustered import clusterMAF
from margarine.marginal_stats import calculate

# Setup distribution
ANALYTIC_DKL = np.log(2) - 0.5
ANALYTIC_BMD = 0.5
triangle_data = np.array(triang.rvs(.5, loc=-1, scale=2, size=10000))
uniform_data = np.array(np.random.uniform(0, 1, size=10000))
data = np.vstack([triangle_data, uniform_data]).T

# MAF
weights = np.ones(len(data), dtype=np.float32)
bij = MAF(data, weights=weights)
bij.train(10000, early_stop=True)
x = bij.sample(5000)
stats = calculate(bij).statistics()
print(stats)
print(f'analytic DKL = {ANALYTIC_DKL}')
print(f'analytic BMD = {ANALYTIC_BMD}')

# Clustered MAF
bij = clusterMAF(data, weights=weights)
bij.train(10000, early_stop=True)
x = bij.sample(5000)
stats = calculate(bij).statistics()
print(stats)
print(f'analytic DKL = {ANALYTIC_DKL}')
print(f'analytic BMD = {ANALYTIC_BMD}')
```

The output I find for the MAF portion is in agreement with the analytic results:

```
                  Value  Lower Bound  Upper Bound
Statistic                                        
KL Divergence  0.188605     0.186660     0.190233
BMD            0.501265     0.488464     0.519707
analytic DKL = 0.1931471805599453
analytic BMD = 0.5
```

Whereas for clusterMAF I find:

```
                  Value  Lower Bound  Upper Bound
Statistic                                        
KL Divergence  0.880778     0.879759     0.880913
BMD            0.560372     0.537915     0.582491
analytic DKL = 0.1931471805599453
analytic BMD = 0.5
```

with the DKL coming out about 0.69 (≈ ln(2)) larger than expected.
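For reference, the analytic values can be reproduced with a short Monte Carlo sketch. It assumes the prior is uniform on [-1, 1] for the triangle-distributed parameter (the uniform second parameter contributes nothing to either statistic) and takes the BMD $d$ from $d/2 = {\rm Var}_{\mathcal{P}}[\log(\mathcal{P}/\pi)]$. It also checks that the clusterMAF offset reported above is close to ln(2).

```python
import numpy as np

rng = np.random.default_rng(0)

# Triangle distribution on [-1, 1] with its mode at 0
# (the same shape as scipy.stats.triang(0.5, loc=-1, scale=2)).
samples = rng.triangular(-1.0, 0.0, 1.0, size=1_000_000)

# log(P / pi), assuming a uniform prior pi = 1/2 on [-1, 1]
log_ratio = np.log(1.0 - np.abs(samples)) - np.log(0.5)

dkl_mc = log_ratio.mean()        # D_KL = <log(P/pi)> over posterior samples
bmd_mc = 2.0 * log_ratio.var()   # d/2 = variance of log(P/pi)

ANALYTIC_DKL = np.log(2) - 0.5
ANALYTIC_BMD = 0.5
print(f"MC DKL = {dkl_mc:.4f}, analytic = {ANALYTIC_DKL:.4f}")
print(f"MC BMD = {bmd_mc:.4f}, analytic = {ANALYTIC_BMD:.4f}")

# Offset between the clusterMAF and MAF KL divergences in the tables above
offset = 0.880778 - 0.188605
print(f"offset = {offset:.4f}, ln(2) = {np.log(2):.4f}")
```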

@htjb
Owner

htjb commented Aug 21, 2023

Hi @ThomasGesseyJones, thanks for spotting this! After some offline discussion, I agree that there is an error in the clusterMAF log-probability calculation, which leads to the error in the KL divergence.

I had not clocked that the factor of ln(2) above is actually a factor of ln(n_cluster), as you highlighted offline. The log-probability for the cluster MAF was being calculated from the sum of the probabilities of the sample evaluated under each cluster, in a similar way to the probability of a KDE being calculated as the sum over the kernel probabilities. However, in a KDE the probability is the mean over the kernels not the sum and we should be taking the mean in the cluster MAF as well. Taking the mean renormalises the combined probability.
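To see why the sum overcounts, here is a minimal numerical sketch. Normalised densities stand in for the per-cluster flows (two Gaussians, an assumption purely for illustration): their sum integrates to the number of components, while their mean integrates to one, just as for a KDE.

```python
import numpy as np
from scipy.stats import norm

# Two normalised densities standing in for the per-cluster flows
# (Gaussians purely for illustration; the real clusters are MAFs).
components = [norm(loc=-2.0, scale=1.0), norm(loc=2.0, scale=1.0)]
n_cluster = len(components)

x = np.linspace(-12.0, 12.0, 20001)
dx = x[1] - x[0]
densities = np.array([c.pdf(x) for c in components])

summed = densities.sum(axis=0)     # sum over clusters (the buggy combination)
averaged = densities.mean(axis=0)  # mean over clusters (KDE-style)

# Riemann-sum integrals over the grid; the tails beyond +/-12 are negligible
integral_sum = summed.sum() * dx
integral_mean = averaged.sum() * dx
print(f"integral of sum  = {integral_sum:.4f}")   # ≈ n_cluster
print(f"integral of mean = {integral_mean:.4f}")  # ≈ 1
```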

This may also help to tackle #37 which I initially thought was due to poor training.

To fix this, I believe you just need to append `- np.log(self.cluster_number)` to line 281 in `clustered.py`. Please make a PR for this!
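The suggested fix amounts to the following, sketched in plain NumPy/SciPy rather than margarine's actual code (the helper name `cluster_log_prob` and the array layout are hypothetical): combine the per-cluster log-probabilities with a log-sum-exp, then subtract log(n_cluster).

```python
import numpy as np
from scipy.special import logsumexp


def cluster_log_prob(per_cluster_log_probs):
    """Equal-weight mixture log-probability (hypothetical helper).

    per_cluster_log_probs has shape (n_cluster, n_samples): entry [k, i]
    is log P_k(x_i) from the flow trained on cluster k.
    """
    n_cluster = per_cluster_log_probs.shape[0]
    # log(sum_k exp(log P_k)), computed stably along the cluster axis
    log_sum = logsumexp(per_cluster_log_probs, axis=0)
    # Dividing by n_cluster turns the sum into a mean, restoring normalisation
    return log_sum - np.log(n_cluster)


# Made-up per-cluster probabilities for three samples and two clusters
log_probs = np.log(np.array([[0.20, 0.05, 0.01],
                             [0.10, 0.30, 0.01]]))
fixed = cluster_log_prob(log_probs)
buggy = logsumexp(log_probs, axis=0)  # sum without renormalisation
print(np.exp(fixed))  # mean of the two cluster probabilities per sample
print(buggy - fixed)  # each entry equals log(2)
```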

Super good spot thank you!

Running your example locally with the above fix gives:

```
                  Value  Lower Bound  Upper Bound
Statistic                                        
KL Divergence  0.178811     0.174948     0.182488
BMD            0.477987     0.468832     0.488763
analytic DKL = 0.1931471805599453
analytic BMD = 0.5
```

for the MAF and

```
                  Value  Lower Bound  Upper Bound
Statistic                                        
KL Divergence  0.174118     0.165567     0.182208
BMD            0.472542     0.455731     0.496967
analytic DKL = 0.1931471805599453
analytic BMD = 0.5
```

for the clusterMAF.

@ThomasGesseyJones
Collaborator Author

Hi @htjb, thanks for looking at this so quickly. I will put together a PR with your suggested solution, and the casting changes we discussed offline.

The posterior probability being off by a constant factor $n_{\rm cluster}$ nicely explains why the BMDs were accurate while the KL divergence values were discrepant. The KL divergence is defined as
$$D_{\rm KL} \equiv \int \mathcal{P} \log \left(\frac{\mathcal{P}}{\pi} \right) d \theta,$$
which in practice is evaluated by sampling from the posterior and averaging over those samples,
$$D_{\rm KL} = \left\langle \log \left(\frac{\mathcal{P}}{\pi} \right) \right \rangle_{\rm samples}.$$
So if $\mathcal{P}$ has a constant multiplicative error such that $\mathcal{P} \rightarrow C\mathcal{P}$, then
$$D_{\rm KL}^{\rm calc} \rightarrow \left\langle \log \left(\frac{C\mathcal{P}}{\pi} \right) \right \rangle_{\rm samples} = \left\langle \log(C) + \log \left(\frac{\mathcal{P}}{\pi} \right) \right \rangle_{\rm samples} = \log(C) + D_{\rm KL}^{\rm true},$$
which also explains why the error in $D_{\rm KL}$ was $+\log(n_{\rm cluster})$.

If we do the same for the BMD $d$, defined via
$$\frac{d}{2} = \int \mathcal{P} \left( \log \left(\frac{\mathcal{P}}{\pi} \right) - D_{\rm KL} \right)^2 d \theta,$$
and implemented using the same sampling technique,
$$\frac{d}{2} = \left\langle \left( \log \left(\frac{\mathcal{P}}{\pi} \right) - D_{\rm KL} \right)^2 \right \rangle_{\rm samples},$$
then the error cancels out:
$$\frac{d^{\rm calc}}{2} \rightarrow \left\langle \left( \log \left(\frac{C\mathcal{P}}{\pi} \right) - \left(\log(C) + D_{\rm KL}^{\rm true}\right) \right)^2 \right \rangle_{\rm samples} = \left\langle \left( \log \left(\frac{\mathcal{P}}{\pi} \right) - D_{\rm KL}^{\rm true} \right)^2 \right \rangle_{\rm samples} = \frac{d^{\rm true}}{2}.$$
So it seems all previously calculated BMD results can still be trusted, and KL divergence values are easily corrected by subtracting $\log(n_{\rm cluster})$.
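The argument can also be checked numerically. In this sketch, samples from the triangle distribution on [-1, 1] stand in for posterior samples, the prior is assumed uniform on the same interval, and $C = 2$ mimics the two-cluster case. Shifting every $\log(\mathcal{P}/\pi)$ by $\log C$ moves the estimated $D_{\rm KL}$ by exactly $\log C$ and leaves the BMD unchanged.

```python
import numpy as np

rng = np.random.default_rng(1)
samples = rng.triangular(-1.0, 0.0, 1.0, size=200_000)

# log(P/pi) for the triangle posterior, assuming a uniform prior on [-1, 1]
log_ratio = np.log(1.0 - np.abs(samples)) - np.log(0.5)


def dkl_and_bmd(log_ratio):
    """Sample estimates: D_KL = <log(P/pi)>, d/2 = <(log(P/pi) - D_KL)^2>."""
    dkl = log_ratio.mean()
    bmd = 2.0 * np.mean((log_ratio - dkl) ** 2)
    return dkl, bmd


C = 2.0  # constant multiplicative error, e.g. n_cluster = 2
dkl_true, bmd_true = dkl_and_bmd(log_ratio)
dkl_calc, bmd_calc = dkl_and_bmd(log_ratio + np.log(C))  # P -> C P

print(f"DKL shift  = {dkl_calc - dkl_true:.6f}, log(C) = {np.log(C):.6f}")
print(f"BMD change = {bmd_calc - bmd_true:.2e}")  # zero up to rounding
```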

Out of curiosity, in your reply you said:

> However, in a KDE the probability is the mean over the kernels not the sum and we should be taking the mean in the cluster MAF as well.

How come it is the simple mean over the cluster flows, and not a mean weighted by, say, the weights of the samples in the clusters? In `clusterMAF.__call__` the flow choices seem to be weighted by the number of samples in each cluster, so I am a bit confused about the difference between these cases. Sorry if these are rather simple questions; I am not very familiar with MAFs.

@htjb
Owner

htjb commented Aug 22, 2023

Thanks for the explanation of the BMD consistency! This looks good.

> How come it is the simple mean over the cluster flows, and not a mean weighted by, say, the weights of the samples in the clusters?

So when we draw samples, we want to draw them from each cluster with a probability proportional to the cluster size, hence the weights. However, when calculating probabilities you want to sum over the clusters (note that each flow in the clusterMAF returns log-probabilities, hence the exp of the log):
$$P(x) = \sum_{k=1}^{N_{\rm cluster}} \exp \log P_k(x).$$
But each cluster probability is normalised, so we need to renormalise the sum by dividing by $N_{\rm cluster}$:
$$P(x) = \frac{1}{N_{\rm cluster}} \sum_{k=1}^{N_{\rm cluster}} \exp \log P_k(x),$$
or in log space
$$\log P(x) = \log \left( \sum_{k=1}^{N_{\rm cluster}} \exp \log P_k(x) \right) - \log N_{\rm cluster}.$$

@ThomasGesseyJones
Collaborator Author

ThomasGesseyJones commented Aug 23, 2023

After some offline discussion and experimentation, I think we have come to the conclusion that the $P_{\rm k}(x)$ from each flow should not be weighted equally, but instead in proportion to the total weight of the samples in the corresponding cluster, $\sum_i w^{k}_{i}$. Hence, the total probability equation should read

$$ P(x) = \frac{1}{\sum_{k, i} w^k_i}\sum_k \left( \frac{1}{\sum_i w^k_i} \exp \log P_{k} (x)\right), $$

or in log space

$$ \log P(x) = \log \left( \sum_k \exp \left(\log P_{k} (x) - \log \sum_i w^k_i\right)\right) - \log \left( \sum_{k, i} w^k_i \right). $$

Note: these equations contain a typo; see the discussion below.

The need for this weighting is illustrated in the following example of a distribution composed of two unequally weighted Gaussians:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from margarine.maf import MAF
from margarine.clustered import clusterMAF
from margarine.marginal_stats import calculate

min_allowed = -5
max_allowed = 5

# Setup distribution: two Gaussians with equal sample counts
first_gaussian_data = np.array(norm.rvs(loc=-2, scale=1, size=7500))
second_gaussian_data = np.array(norm.rvs(loc=2, scale=1, size=7500))
data = np.concatenate([first_gaussian_data, second_gaussian_data])
weights = np.ones(len(data), dtype=np.float32)
weights[:len(first_gaussian_data)] *= 3  # first Gaussian carries 3x the weight
in_range = (data > min_allowed) & (data < max_allowed)
weights = weights[in_range]
data = data[in_range]
data = data[:, np.newaxis]

# Plot distribution of samples
plt.hist(data, weights=weights, bins=100)
plt.xlabel('x')
plt.ylabel('Samples')
plt.show()

# Clustered MAF
bij = clusterMAF(data, weights=weights)
bij.train(10000, early_stop=True)
x = bij.sample(5000)
stats = calculate(bij).statistics()
print(stats)
plt.hist(np.squeeze(x.numpy()), bins=100)
plt.xlabel('x')
plt.ylabel('Samples')
plt.show()

# Evaluate the learned posterior on a grid and check its normalisation
xs = np.linspace(min_allowed, max_allowed, 1000)
xs = xs[:, np.newaxis]
posterior = np.exp(bij.log_prob(xs))
posterior[np.isnan(posterior)] = 0
# np.linspace with N points has spacing (max - min) / (N - 1)
dx = (max_allowed - min_allowed) / (len(xs) - 1)
print(f'Posterior Integral {np.trapz(posterior, dx=dx)}')
plt.plot(xs, posterior)
plt.xlabel('x')
plt.ylabel('P(x)')
plt.show()
```

Without the weighting, `clusterMAF.log_prob` combines the two Gaussian clusters' probabilities equally, so the posterior distribution it returns is erroneously symmetric. With the weighting, the input distribution, with one Gaussian weighted 3 times more heavily than the other, is recovered.

A similar weighting is already implemented in `clusterMAF.__call__` to determine the chance that a sample is drawn from each flow. However, it is proportional to the number of samples in each cluster rather than their weights. If the samples are equally weighted the two are equivalent, but if they are not, this can again lead to erroneous results, with clusters containing many low-weight samples treated as more significant than they actually are. The same example code illustrates this too: sampling comes out erroneously symmetric, because the number of samples is the same for both Gaussians and only their weights differ. Changing the weighting in `clusterMAF.__call__` to use sample weights rather than sample numbers corrects this issue.
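The difference between the two sampling schemes can be sketched in NumPy. This is not margarine's implementation: the cluster labels and weights are hypothetical, chosen to mirror the two-Gaussian example (equal sample counts, with the first cluster weighted 3 times more heavily).

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical cluster labels and sample weights: equal counts per cluster,
# but every sample in cluster 0 carries 3x the weight of those in cluster 1.
labels = np.repeat([0, 1], 7500)
weights = np.where(labels == 0, 3.0, 1.0)
n_cluster = 2

counts = np.bincount(labels, minlength=n_cluster)
weight_sums = np.bincount(labels, weights=weights, minlength=n_cluster)

p_by_count = counts / counts.sum()             # [0.5, 0.5]
p_by_weight = weight_sums / weight_sums.sum()  # [0.75, 0.25]

# Choose which cluster's flow each of 5000 new samples is drawn from
choice = rng.choice(n_cluster, size=5000, p=p_by_weight)
fractions = np.bincount(choice, minlength=n_cluster) / choice.size
print(p_by_count, p_by_weight, fractions)
```

Weighting by counts would pick each flow half the time, reproducing the erroneously symmetric samples, whereas weighting by summed sample weights picks the first flow about three quarters of the time.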

I will implement these weightings in my fork and create a PR after some testing verifies what I have done is working as expected.

@htjb
Owner

htjb commented Aug 24, 2023

Thanks @ThomasGesseyJones, this sounds good! I am curious: should the total probability not be
$$P(x) = \frac{1}{\sum_{k}(\sum_i w^k_i)}\sum_k \left( (\sum_i w^k_i) \exp (\log P_{k} (x))\right),$$
since a weighted average of cluster probabilities is given by
$$\overline{x} = \frac{\sum_k w_k x_k}{\sum_k w_k}.$$

Here $\sum_i w^k_i$ is the cluster weight for cluster $k$ equivalent to $w_k$ in the above. $x_k$ is the cluster probability for the sample.

P.S. I added some extra brackets because I was getting confused, and separated out the sum in the denominator hahaa.

> I will implement these weightings in my fork and create a PR

Yes please it would be good to have a look over the code! Thanks.

@ThomasGesseyJones
Collaborator Author

Hi @htjb. You are quite right, there is an error in the equations I typed up. They should have read

$$ P(x) = \frac{1}{\sum_{k} \sum_{i} w^k_i}\sum_k \left( \left( \sum_i w^k_i \right) \exp \log P_{k} (x)\right), $$

or in log space

$$ \log P(x) = \log \left( \sum_k \exp \left(\log P_{k} (x) + \log \left(\sum_i w^k_i\right)\right)\right) - \log \left( \sum_{k} \sum_{i} w^k_i \right). $$

as you suggested. I have edited my earlier comment to add a warning for anyone reading through this thread, and to direct them to this discussion.

I just checked, and while the equations I typed were wrong, the implementation seems to be correct: https://github.com/ThomasGesseyJones/margarine/commit/c77c95ad5beadbc563636d2902dc70c1958f5cb5
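A sketch of the corrected log-space expression, with Gaussians standing in for trained cluster flows and a hypothetical helper name (this is not the code in the linked commit). With $W_k = \sum_i w^k_i$, the result stays normalised and matches the weighted average $\sum_k W_k P_k(x) / \sum_k W_k$.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def weighted_cluster_log_prob(per_cluster_log_probs, cluster_weights):
    """log P(x) = log(sum_k exp(log P_k(x) + log W_k)) - log(sum_k W_k).

    per_cluster_log_probs: shape (n_cluster, n_samples), entries log P_k(x_i).
    cluster_weights: shape (n_cluster,), entries W_k = sum_i w^k_i.
    Hypothetical helper for illustration, not margarine's API.
    """
    log_w = np.log(cluster_weights)[:, np.newaxis]
    return (logsumexp(per_cluster_log_probs + log_w, axis=0)
            - np.log(cluster_weights.sum()))


# Gaussian stand-ins for the two cluster flows; cluster 0 carries 3x the weight
components = [norm(loc=-2.0, scale=1.0), norm(loc=2.0, scale=1.0)]
cluster_weights = np.array([3.0 * 7500, 1.0 * 7500])

x = np.linspace(-12.0, 12.0, 20001)
log_pk = np.array([c.logpdf(x) for c in components])
log_p = weighted_cluster_log_prob(log_pk, cluster_weights)

# Still normalised, and equal to the weighted average of the cluster densities
integral = np.exp(log_p).sum() * (x[1] - x[0])
direct = 0.75 * components[0].pdf(x) + 0.25 * components[1].pdf(x)
print(integral, np.allclose(np.exp(log_p), direct))
```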

@htjb
Owner

htjb commented Aug 25, 2023

Yeah I scanned the code before and it looked good! Please make a PR when you are ready 😄
