
Unified Wald constructor #85

Open
DominiqueMakowski opened this issue Jul 7, 2024 · 13 comments

Comments

@DominiqueMakowski
Contributor

New issue to track and discuss the streamlining of the Wald API (merging of Wald and Mixture Wald) (#83 (comment))

@itsdfish
Owner

itsdfish commented Jul 7, 2024

The benchmarks below show that the mixture model is about 6 times slower than the Wald model due to sampling the drift rate and the extra terms. On a development branch, I show that conditionally executing the Wald code when eta is zero achieves the desired speedup without type instability problems.

My plan is to drop WaldMixture in favor of a general Wald model. I will not merge this until later. I would prefer to make some breaking changes in bulk. In the meantime, you can just use WaldMixture with the parametric constraint that eta = 0 to achieve the special case.

using BenchmarkTools
using SequentialSamplingModels

wald_mixture = WaldMixture(;ν=3.0, η=.2, α=.5, τ=.130)
wald = Wald(;ν=3.0, α=.5, τ=.130)

@benchmark rand($wald_mixture, $1000)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  110.102 μs …  1.529 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     111.269 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   116.936 μs ± 20.719 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▇▂▂▁    ▁        ▁                                          ▁
  ███████▆▇██▇▇█▆▆▆██▇▇█▇▆▅▇▇█▇▆▆▆▆▅▇█▆▆▆▆▆▇▆▆▆▆▇▇▇▇▅▆▇▇▇▇▆▆▆▆ █
  110 μs        Histogram: log(frequency) by time       178 μs <

 Memory estimate: 7.94 KiB, allocs estimate: 1.

 
@benchmark rand($wald, $1000)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  19.509 μs … 361.751 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     20.727 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   21.309 μs ±   4.358 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▄▆██▇▆▄▂                                                    ▂
  ██████████▇▇▆▆▆▆▆▆▆▄▅▅▅▃▅▇▇███▇▆▆▄▅▅▇▇█▇▇▇▇▇▇▇▆▅▄▃▁▄▃▁▃▃▅▃▄▄ █
  19.5 μs       Histogram: log(frequency) by time      35.7 μs <

 Memory estimate: 15.88 KiB, allocs estimate: 2.
 
rts = rand(wald_mixture, 1000)
@benchmark logpdf.($wald_mixture, $rts)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  114.675 μs …  1.301 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     114.859 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   119.654 μs ± 18.190 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █ ▂     ▁       ▁                                            ▁
  ███▇▅▆▆▇██▇▇▇▆▅▆██▇▇▇▆▆▆▇▇▆▇▇▆▅▆█▇▆▆▆▅▅▄▆▇▆▆▆▅▄▄▅▅▅▄▄▃▃▄▃▄▅▄ █
  115 μs        Histogram: log(frequency) by time       185 μs <

 Memory estimate: 7.94 KiB, allocs estimate: 1.


rts = rand(wald, 1000)
@benchmark logpdf.($wald, $rts)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  17.573 μs … 321.755 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     17.740 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   18.959 μs ±   5.142 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▅▂                       ▁   ▁▁                             ▁
  ████▅▅▆▆▅▅▅▅▄▄▅▅▄▄▄▅▅▅▃▁▃▆██████▇▅▅▅▅▅▆▆▇▇▇▆▅▅▆▅▅▄▄▄▄▆▅▆▇▇▆▅ █
  17.6 μs       Histogram: log(frequency) by time      38.1 μs <

 Memory estimate: 7.94 KiB, allocs estimate: 1.
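The planned unification could be sketched roughly as follows. This is a hypothetical illustration, not the package's actual implementation: `sample_wald` is an invented name, and it assumes the standard Wald-as-shifted-InverseGaussian parameterization.

```julia
using Distributions, Random

# Hypothetical sketch of conditionally executing the plain Wald code when η == 0.
# The first-passage time of a Wiener process with drift ν and boundary α is
# InverseGaussian(α/ν, α²), shifted by the non-decision time τ.
function sample_wald(ν, η, α, τ)
    if η == 0
        # special case: no drift-rate variability, so skip the extra sampling step
        return τ + rand(InverseGaussian(α / ν, α^2))
    else
        # mixture case: draw a positive drift rate from a truncated normal first
        νᵢ = rand(truncated(Normal(ν, η); lower = 0.0))
        return τ + rand(InverseGaussian(α / νᵢ, α^2))
    end
end
```

With this shape, a single constructor could cover both models and the η == 0 branch would recover the plain Wald's speed.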

@DominiqueMakowski
Contributor Author

Two more semi-related things:

rand(ExGaussian(0.3, 0.2, 0))

  • ExGaussian with tau=0 currently errors, where one would expect it to behave as a Normal(). Falling back on Normal() when tau=0 would remove the need for extra care in prior specification (where one otherwise needs to exclude 0)

rand(ExGaussian(0.3, 0.0, 0.1), 100)

  • I would have expected that, when sigma=0, it returns identical values (as LogNormal or Normal would), but it still has some variability. Maybe something to clarify, as it's a bit unexpected

@itsdfish
Owner

itsdfish commented Jul 7, 2024

  • Falling back on Normal() if tau=0

That is a good idea. I'll make that change.

  • but it still has some variability. Maybe something to clarify as it's a bit unexpected

I think this behavior is to be expected. X ~ Normal(0.3, 0) + Exponential(0.1) simplifies to a shifted exponential with var(X) = τ² = 0.01:

using SequentialSamplingModels
 var(rand(ExGaussian(0.3, 0.0, 0.1), 10_000))
0.010246564388110698
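The same point can be checked against a plain shifted exponential, since Normal(0.3, 0) collapses to the constant 0.3. A small illustration using Distributions.jl directly:

```julia
using Distributions, Random, Statistics

Random.seed!(2024)
μ, τ = 0.3, 0.1
# ExGaussian(μ, 0, τ) reduces to μ + Exponential(τ), whose variance is τ² = 0.01
x = μ .+ rand(Exponential(τ), 10_000)
var(x)  # ≈ 0.01, matching the ExGaussian sample above
```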

@DominiqueMakowski
Contributor Author

I think this behavior is to be expected

Indeed, that makes sense: it would still have variability from the exponential distribution

@itsdfish
Owner

itsdfish commented Jul 7, 2024

I'll have to think about circumventing the error when tau = 0. Upon further thought, I think it might not be a good idea, because the exponential distribution is not defined when tau = 0, which, I think, would imply the ExGaussian is not defined either. Along the same lines, the logpdf is not defined with tau = 0 because it is used as a divisor. Right now I think the current implementation is correct.
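For reference, in the standard parameterization the ex-Gaussian log-density uses τ both as a divisor and inside the exponent, so τ = 0 falls outside the parameter space. A sketch of that standard form (not the package's internal code):

```julia
using Distributions

# Standard ex-Gaussian log-density: note the divisions by τ throughout,
# which is why τ = 0 cannot be evaluated.
function exgauss_logpdf(x, μ, σ, τ)
    return -log(τ) + (μ - x) / τ + σ^2 / (2 * τ^2) +
           logcdf(Normal(), (x - μ) / σ - σ / τ)
end
```

As σ → 0, this converges to the logpdf of a shifted Exponential(τ), consistent with the variance discussion above.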

@DominiqueMakowski
Contributor Author

which, I think, would imply the exguassian is not defined

Although one could argue that if the exponential is "null", then only the Normal remains of Normal + Exp.
When the occasion arises, we could check how other implementations handle it (I just tried to test using brms, but it doesn't give me access to the distribution constructor)

@itsdfish
Owner

itsdfish commented Jul 7, 2024

I verified with gamlss in R:

> dexGAUS(1, mu = 1, sigma = 1, nu = 0, log = FALSE)
Error in dexGAUS(1, mu = 1, sigma = 1, nu = 0, log = FALSE) : 
  nu must be greater than 0  
 
> rexGAUS(1, mu = 1, sigma = 1, nu = 0, log = FALSE)
Error in rexGAUS(1, mu = 1, sigma = 1, nu = 0, log = FALSE) : 
  nu must be positive 

@DominiqueMakowski
Contributor Author

For ExGaussian, the DomainError thrown when τ=0 means that sampling often fails, even when specifying τ on a log link and feeding exp(τ) to ExGaussian() (see also #93 and #81).
My guess is that, due to some numeric imprecision, Turing explores very large negative values which get turned into 0 (instead of 0.00...) and make the whole thing error.

Would it make sense to return -Inf for the logpdf when τ=0 (mostly for convenience when used in Turing)? Otherwise it currently requires adding safeguards to the model and forcing the logprob to be -Inf if exp(τ)==0
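For what it's worth, the underflow path is easy to demonstrate, and the model-side safeguard mentioned above can be sketched generically (`guarded_logpdf` is an invented helper, not part of the package):

```julia
# Underflow demo: exp of a large negative unconstrained value is exactly 0.0,
# which is how τ = 0 can reach the constructor despite the log link.
exp(-800.0)  # == 0.0

# Hypothetical model-side guard: reject the proposal with -Inf instead of erroring.
guarded_logpdf(compute_logpdf, rt, τ) = τ > 0 ? compute_logpdf(rt) : -Inf
```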

@itsdfish
Owner

@DominiqueMakowski, please add the proposed fix via add SequentialSamplingModels#inf_fix to your environment to see if it solves your problem. If so, I will merge into main.

@DominiqueMakowski
Contributor Author

I've run it a couple of times, and it seems to have fixed the issue!

@itsdfish
Owner

Awesome. I will merge the fix and release a new version shortly. Hopefully, it continues to work well with Pigeons.

@DominiqueMakowski
Contributor Author

On a side note:
It seems like Wald is particularly prone to errors in Turing under fairly normal conditions, with

ERROR: DomainError with 0.0:
InverseGaussian: the condition λ > zero(λ) is not satisfied.

Despite having a link function on alpha that should prevent it from being 0. I'm not sure what the cause is, though.

Also, would it make sense (in terms of efficiency) to use the LocationScale() wrapper to create the distribution?

LocationScale(τ, 1, InverseGaussian(μ,λ))

@itsdfish
Owner

itsdfish commented Dec 4, 2024

Numerical errors can be frustrating. It sounds like there are two potential points of error: (1) your link function, and (2) the reparameterization from Wald parameters to InverseGaussian parameters. One thing you could do is set up print statements to see where the problem is occurring, and you could consider adding a check like x = max(x, eps()). It's worth noting that this will also mask errors other than numerical ones (e.g., a problem with your link function that produces large negative values).
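The suggested clamp might look like this in a model block (`log_α` is a hypothetical unconstrained parameter name, not one from the package):

```julia
# Hypothetical guard: clamp the transformed parameter away from zero before
# constructing the distribution. Note this masks upstream problems (e.g. a
# broken link function producing huge negative values) rather than fixing them.
log_α = -800.0                # unconstrained value explored by the sampler
α = max(exp(log_α), eps())    # eps() ≈ 2.2e-16 keeps the shape parameter > 0
```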
