
MMD metric occasionally returns nan #855

Closed
gw265981 opened this issue Nov 8, 2024 · 4 comments · Fixed by #862

gw265981 (Contributor) commented on Nov 8, 2024:

What's the problem?

The MMD metric between a dataset and a coreset sometimes returns nan when the dataset is relatively large (on the order of 10^5 points).

This is likely a precision error: kernel_xx_mean + kernel_yy_mean - 2 * kernel_xy_mean in MMD.compute evaluates to a small negative number in these cases, with a magnitude larger than the precision threshold that was put in place to catch this. The square root then turns it into nan.

The issue goes away when double precision is used: from jax import config; config.update("jax_enable_x64", True).

The simplest solution is to increase the precision tolerance threshold to catch these cases, but I am not sure what an appropriate value would be.

We could simply truncate the expression at 0, but that could mask a bug down the line. Also, the coreset may be relatively small (<10% of the data) when this happens, and an MMD of exactly 0 would be strange in that case.

Another solution is to switch the computation to float64, but I am not sure about the performance and other implications of doing so.
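
For illustration, a minimal sketch of the failure mode with made-up kernel means (not values from an actual run): in float32 the expression under the square root comes out slightly negative, and the square root then returns nan.

import jax.numpy as jnp

# Hypothetical values standing in for kernel_xx_mean, kernel_yy_mean and
# kernel_xy_mean when the dataset and coreset distributions are very close.
kernel_xx_mean = jnp.float32(0.20802966)
kernel_yy_mean = jnp.float32(0.20802812)
kernel_xy_mean = jnp.float32(0.20802896)

squared_mmd = kernel_xx_mean + kernel_yy_mean - 2 * kernel_xy_mean
print(squared_mmd)            # a small negative number, roughly -1.4e-07
print(jnp.sqrt(squared_mmd))  # the square root of a negative float is nan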

How can we reproduce the issue?

# Imports assumed for this example -- the coreax import paths below are a
# best guess for coreax 0.3.0 and may need adjusting
import equinox as eqx
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import make_blobs

from coreax.data import Data
from coreax.kernels import SquaredExponentialKernel, median_heuristic
from coreax.metrics import MMD
from coreax.solvers import KernelHerding

num_data_points = 10_000
num_features = 2
num_cluster_centers = 10
random_seed = 123
x, *_ = make_blobs(
    num_data_points,
    n_features=num_features,
    centers=num_cluster_centers,
    random_state=random_seed,
    return_centers=True,
    cluster_std=2,
)
data = Data(x)
coreset_size = 1000

num_samples_length_scale = min(num_data_points, 1_000)
generator = np.random.default_rng(random_seed)
idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
length_scale = median_heuristic(x[idx])
kernel = SquaredExponentialKernel(length_scale=length_scale)

herding_solver = KernelHerding(coreset_size=coreset_size, kernel=kernel)
herding_coreset, herding_state = eqx.filter_jit(herding_solver.reduce)(data)

mmd_kernel = SquaredExponentialKernel(
    length_scale=length_scale,
    output_scale=1.0 / (length_scale * jnp.sqrt(2.0 * jnp.pi)),
)

mmd_metric = MMD(kernel=mmd_kernel)

print(herding_coreset.compute_metric(mmd_metric))
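
As mentioned above, the nan disappears when double precision is enabled; for the script above that means adding the following before any JAX arrays are created:

from jax import config

# Enable 64-bit precision globally; this must run before any JAX computation.
config.update("jax_enable_x64", True)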

Python version

3.12

Package version

0.3.0

Operating system

Windows 10

Other packages

No response

Relevant log output

No response

gw265981 added the bug (Something isn't working) and new (Something yet to be discussed by development team) labels on Nov 8, 2024
pc532627 removed the new (Something yet to be discussed by development team) label on Nov 11, 2024
gw265981 self-assigned this on Nov 12, 2024
gw265981 (Contributor, Author) commented:

It appears the root of the problem is catastrophic cancellation, which happens when subtracting two approximate values that are very close. The result of the subtraction can have a very large relative error even if the individual approximations are good. For example, take two numbers with 8 significant figures that agree in the first 5: 0.20802966 and 0.20802812. Their difference is 0.00000154, so we have gone from 8 significant figures to 3. If the original numbers are approximations, the computed difference can therefore vary a lot.
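
A quick numerical check of the example above (float32 used purely for illustration):

import numpy as np

a = np.float32(0.20802966)
b = np.float32(0.20802812)

# Both inputs carry roughly 7 significant decimal digits in float32, but the
# subtraction keeps only the digits in which they disagree.
print(a - b)  # ~1.54e-06 -- far fewer significant figures than a or b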

In our case the squared MMD is computed as K_xx + K_yy - 2*K_xy, and the MMD is its square root. MMD measures a distance between distributions, so it will be close to 0 when the distributions are close, which is exactly the regime where catastrophic cancellation occurs. To make matters worse, the MMD seems to decrease as the number of data points N grows, even when the coreset stays at a constant proportion of N:

[Figure: MMD against N, showing the MMD decreasing as N increases]

I am still not sure what the best solution here would be. Some ideas:

  1. Rewrite the MMD computation so that the difference is evaluated more directly. I don't know how feasible or easy this would be, so it would likely take a bit of time to research.

  2. Increase the precision of the computation. Not very feasible, as it appears precision has to be set globally in JAX (see "Possible to set float32 as default, but use float64 for some parts of calculation?", jax-ml/jax#19443). We could warn users to enable this config in scripts where precision is important.

  3. Increase the precision threshold / floor the MMD at 0 (see the sketch after this list). A temporary hack if we want to postpone thinking about the full solution for now.
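
A minimal sketch of option 3, assuming the expression lives in something like MMD.compute (the function below is illustrative, not the actual coreax internals):

import jax.numpy as jnp

def mmd_with_zero_floor(kernel_xx_mean, kernel_yy_mean, kernel_xy_mean):
    """Floor the squared-MMD expression at zero before taking the square root."""
    squared_mmd = kernel_xx_mean + kernel_yy_mean - 2 * kernel_xy_mean
    # Any negative value here can only come from floating-point error, since the
    # squared MMD is non-negative by definition, so clamp it to zero.
    return jnp.sqrt(jnp.maximum(squared_mmd, 0.0))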

Logs for the example above:

K_xx:
32-bit: float32, mean time: 0.6503s, std: 0.2882
64-bit: float64, mean time: 1.3937s, std: 0.3300
Relative difference: 0.000273%
K_xy:
32-bit: float32, mean time: 0.0779s, std: 0.0117
64-bit: float64, mean time: 0.0956s, std: 0.0179
Relative difference: 0.000242%
K_yy:
32-bit: float32, mean time: 0.0615s, std: 0.0022
64-bit: float64, mean time: 0.0583s, std: 0.0080
Relative difference: 0.000027%

K_xx + K_yy - 2*K_xy:
float32: -1.341104507446289e-07
float64: 6.73454663216444e-08
Relative difference: 299.138053%
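
A rough NumPy sketch of how such a float32 vs float64 comparison can be made (this is not the script that produced the numbers above; the kernel_mean helper, the length scale, and the random-subsample "coreset" are placeholders):

import numpy as np
from sklearn.datasets import make_blobs

def kernel_mean(x, y, length_scale, dtype):
    """Mean of the squared-exponential Gram matrix at the given precision."""
    x, y = x.astype(dtype), y.astype(dtype)
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2 * length_scale**2)).mean()

x, _ = make_blobs(2_000, n_features=2, centers=10, random_state=123, cluster_std=2)
# Stand-in "coreset": a random subsample (the issue uses kernel herding).
y = x[np.random.default_rng(123).choice(len(x), 200, replace=False)]
length_scale = 1.0  # placeholder; the issue uses the median heuristic

for dtype in (np.float32, np.float64):
    k_xx = kernel_mean(x, x, length_scale, dtype)
    k_yy = kernel_mean(y, y, length_scale, dtype)
    k_xy = kernel_mean(x, y, length_scale, dtype)
    # The squared-MMD expression; float32 and float64 results can differ
    # noticeably once the true value is close to zero.
    print(dtype.__name__, k_xx + k_yy - 2 * k_xy)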

tp832944 (Contributor) commented:

Setting the floor to zero (option 3) seems like the obvious option to me when numerical analysis explains the behaviour - we can't just legislate against floating-point physics with a threshold setting. We may still want some safety check in place to guard against a coding bug.

gw265981 (Contributor, Author) commented:

I agree that if we go that route, just truncating to 0 is the best option. I assume the threshold was there originally to allow some precision tolerance but detect if something weird is happening - and I think we understand what is happening now.

pc532627 (Contributor) commented:

@tp832944 @gw265981 I agree with setting the floor to zero - happy for code updates to be made to implement this.

gw265981 linked pull request #862 on Nov 13, 2024 that will close this issue