Cost Matrix function #104
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,3 +36,8 @@ sinkhorn_unbalanced2 | |
```@docs | ||
quadreg | ||
``` | ||
|
||
## Utilities | ||
```@docs | ||
cost_matrix | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -106,3 +106,66 @@ end | |
|
||
Distributions.support(d::FiniteDiscreteMeasure) = d.support | ||
Distributions.probs(d::FiniteDiscreteMeasure) = d.p | ||
|
||
""" | ||
cost_matrix( | ||
c, | ||
μ::Union{FiniteDiscreteMeasure, DiscreteNonParametric}, | ||
ν::Union{FiniteDiscreteMeasure, DiscreteNonParametric} | ||
) | ||
|
||
Compute cost matrix from Finite Discrete Measures `μ` and `ν` using cost function `c`. | ||
|
||
Note that functions such as `SqEuclidean()` from `Distances.jl` have | ||
better performance than generic functions. Thus, it's preferred to use | ||
`cost_matrix(SqEuclidean(), μ, ν)`, instead of `cost_matrix((x,y)->sum((x-y).^2), μ, ν)` | ||
or even `cost_matrix(sqeuclidean, μ, ν)`. | ||
|
||
For custom cost functions, it is necessary to guarantee that the function `c` works | ||
on vectors, i.e., if one wants to compute the squared Euclidean distance, | ||
then one must define `c(x,y) = sum((x - y).^2)`. | ||
|
||
# Example | ||
```julia | ||
μ = discretemeasure(rand(10),normalize!(rand(10),1)) | ||
ν = discretemeasure(rand(8)) | ||
c = TotalVariation() | ||
C = cost_matrix(c, μ, ν) | ||
``` | ||
""" | ||
function cost_matrix( | ||
c, | ||
μ::Union{FiniteDiscreteMeasure,DiscreteNonParametric}, | ||
ν::Union{FiniteDiscreteMeasure,DiscreteNonParametric}, | ||
) | ||
if typeof(c) <: PreMetric && length(μ.support[1]) == 1 | ||
return pairwise(c, vcat(μ.support...), vcat(ν.support...)) | ||
elseif typeof(c) <: PreMetric && length(μ.support[1]) > 1 | ||
return pairwise(c, vcat(μ.support'...), vcat(ν.support'...); dims=1) | ||
else | ||
return pairwise(c, μ.support, ν.support) | ||
end | ||
end | ||
Comment on lines +136 to +148

I think we would want to restrict this to pairs of marginals with the same type of the support. Or at least the dimension should match in the case of arrays and scalars. So we would want a more fine-grained function signature. In general, it would be better to avoid the type checks in the function definition since it makes it more difficult to extend and specialize the method. I think it would be better to just

Also one should avoid the splatting of support, it will lead to massive compile times and inference problems for larger arrays. The same comments apply to the implementation below.

Hmm. So the splatter is inefficient. Now, how should one efficiently construct matrix

As mentioned above, if you work with `ColVecs` or `RowVecs` you actually don't want to construct a matrix at all. But if you deal with an actual vector of vectors, then usually you would use e.g.

I agree. But since we didn't require
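
The two review points above (dispatch on the support type instead of `typeof` checks, and no splatting) can be illustrated with a minimal sketch in plain Julia. The names `cost_matrix_sketch` and `sqdist` are hypothetical and are not part of OptimalTransport.jl; the sketch takes raw support vectors rather than measure objects and is not the signature the reviewer proposed.

```julia
# Hypothetical helpers for illustration only; not part of OptimalTransport.jl.

# Scalar supports: every support point is a real number.
function cost_matrix_sketch(c, xs::AbstractVector{<:Real}, ys::AbstractVector{<:Real})
    return [c(x, y) for x in xs, y in ys]
end

# Vector supports: every support point is a vector, and all lengths must agree.
function cost_matrix_sketch(
    c,
    xs::AbstractVector{<:AbstractVector{<:Real}},
    ys::AbstractVector{<:AbstractVector{<:Real}},
)
    d = length(first(xs))
    all(p -> length(p) == d, xs) && all(p -> length(p) == d, ys) ||
        throw(DimensionMismatch("all support points must have length $d"))
    return [c(x, y) for x in xs, y in ys]
end

# A generic cost that works for scalars and for vectors.
sqdist(x, y) = sum(abs2, x .- y)

C1 = cost_matrix_sketch(sqdist, [0.0, 1.0], [0.0, 2.0, 3.0])
C2 = cost_matrix_sketch(sqdist, [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0]])

# Mixing a scalar support with a vector support matches neither method.
mixed_fails = try
    cost_matrix_sketch(sqdist, [0.0, 1.0], [[1.0, 0.0]])
    false
catch err
    err isa MethodError
end

# Building a matrix from a vector of vectors without splatting.
support = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
X = reduce(hcat, support)  # same entries as hcat(support...)
```

With separate methods, a mismatched pair of supports fails at dispatch, and a package that defines its own support type can add a specialized method without touching these.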
||
|
||
""" | ||
cost_matrix( | ||
c, | ||
μ::Union{FiniteDiscreteMeasure, DiscreteNonParametric}, | ||
symmetric = false | ||
) | ||
|
||
Compute cost matrix from Finite Discrete Measures `μ` to itself using cost function `c`. | ||
If the cost function is symmetric, set the argument `symmetric` to `true` in order | ||
to increase performance. | ||
""" | ||
function cost_matrix( | ||
c, μ::Union{FiniteDiscreteMeasure,DiscreteNonParametric}; symmetric=false | ||
) | ||
if typeof(c) <: PreMetric && length(μ.support[1]) == 1 | ||
return pairwise(c, vcat(μ.support...)) | ||
elseif typeof(c) <: PreMetric && length(μ.support[1]) > 1 | ||
return pairwise(c, vcat(μ.support'...); dims=1) | ||
else | ||
return pairwise(c, μ.support; symmetric=symmetric) | ||
end | ||
end |
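
Why a `symmetric` flag can help is easiest to see in a small sketch in plain Julia. The helper `symmetric_cost_matrix_sketch` and the call counter are hypothetical, the sketch assumes `Float64` costs, and it is not how `pairwise` is implemented.

```julia
# Hypothetical helper for illustration only; assumes c(x, y) == c(y, x)
# and that costs are Float64.
function symmetric_cost_matrix_sketch(c, xs::AbstractVector)
    n = length(xs)
    C = Matrix{Float64}(undef, n, n)
    for j in 1:n, i in 1:j
        C[i, j] = c(xs[i], xs[j])  # evaluate the upper triangle only
        C[j, i] = C[i, j]          # mirror into the lower triangle
    end
    return C
end

# Count how often the cost function is evaluated.
const CALLS = Ref(0)
counted_cost(x, y) = (CALLS[] += 1; abs2(x - y))

C = symmetric_cost_matrix_sketch(counted_cost, [0.0, 1.0, 3.0])
```

For `n` support points this evaluates the cost `n(n+1)/2` times instead of `n^2` times.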
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
using LinearAlgebra: symmetric | ||
using OptimalTransport | ||
using Pkg: Pkg | ||
using SafeTestsets | ||
|
I am not sure if this should be exposed to users.

Yeah, I remember you said this. Can you expand on why you think so? As a user, I'd like to have access to this function, since if I wanted to create the cost matrix, it would save me some time (I always misuse `pairwise` at first). I don't see the downside in making this available.

Also, do you think the `cost_matrix` function is useful (even if only for internal use)? Cause I remember you were not sure about it.

The problem is that the best implementation would be to just forward everything to `pairwise` since the type of the support and the cost function should know best what to do. E.g., if you use `ColVecs` or `RowVecs` then there is no need to concatenate vectors, you can just call `pairwise` with the underlying matrix. This is also how it is implemented in KernelFunctions and it is much much more efficient than extracting and combining all columns or rows.

Yeah, but if I construct a `FiniteDiscreteMeasure` without `KernelFunctions`, then there is no matrix attribute, this is why I wrote like that. Sorry, I don't understand your point that the best implementation would be to just forward everything to `pairwise`. I mean, the reason for `cost_matrix` is that I wouldn't have to deal with the variations. For example, if my cost function is a personalized function, then, for example, `pairwise(sqeuclidean, mu.support, nu.support)` behaves differently than `pairwise(SqEuclidean(), mu.support, nu.support)`, which would require, for example, adding `dims=1` and transforming the support to a matrix.

What I'd like to do is to write a function that deals with all these varying cases.

That's exactly my point, you don't know what type of vectors is used (e.g., whether it is a `ColVecs` or a `RowVecs`), so often `reduce(hcat, mu.support)` can be a very inefficient and suboptimal choice. If instead you would just use `pairwise(c, mu.support, nu.support)` then you could make use of the optimizations in packages such as KernelFunctions. So my main point is just that probably this is handled here on the wrong level - the packages that define e.g. `ColVecs` and `RowVecs` should define how `pairwise` is handled and make sure that it is efficient since we can't handle all possible types here.

In general though I am a bit surprised about the problems you mention with `SqEuclidean`. All the desired cases work automatically due to https://github.com/JuliaStats/Distances.jl/blob/b52f0a10017553b311a9c9eed6f96e34a5629c2f/src/generic.jl#L333-L351 (even though it is not optimized for `ColVecs` but IMO that's an issue of KernelFunctions or the separate package where they should be moved):

(`sqeuclidean` works as well but uses the fallback in StatsBase - IMO this should be changed in Distances)
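
As an illustration of the behaviour discussed in this thread (not the snippet from the original comment), the following sketch compares `pairwise` on a vector of vectors with the matrix form. It assumes an installed version of Distances.jl that includes the generic `pairwise` method linked above; the data are made up.

```julia
using Distances  # assumption: a Distances.jl version with the generic iterator method

# Support points stored as vectors of vectors (made-up data).
xs = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
ys = [[1.0, 1.0], [2.0, 2.0]]

# No `dims` keyword and no conversion to a matrix: one row per point of xs,
# one column per point of ys.
C_iter = pairwise(SqEuclidean(), xs, ys)

# Matrix form for comparison: points are columns, hence dims=2.
X = reduce(hcat, xs)
Y = reduce(hcat, ys)
C_mat = pairwise(SqEuclidean(), X, Y; dims=2)
```

Both calls should agree up to floating-point rounding. The `dims` keyword belongs only to the matrix form, which is consistent with the error described in the thread when `dims` was passed together with `mu.support`.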

Interesting. I was getting an error when using `mu.support` without the splatter, but I was using the `dims` argument. So perhaps that was the issue. If that is so, then I agree with you that the `cost_matrix` function is not necessary.
I'll close the PR. Thanks for the inputs.
BTW I made a PR to Distances that would fix the SqEuclidean/sqeuclidean discrepancy: JuliaStats/Distances.jl#224