Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pairwise convenience method for tables #123

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Random", "Test"]

[compat]
Tables = ">= 0.1.15"
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ For performance reasons, it is recommended to use matrices with observations in
the ``Array`` type in Julia is column-major, making it more efficient to access memory column by column. However,
matrices with observations stored in rows are also supported via the argument ``dims=1``.

A convenience method is provided to compute pairwise distances between observations stored as rows in
any type of tabular data structure supported by the [Tables.jl](https://github.com/JuliaData/Tables.jl)
interface. Here is an example using a [`DataFrame`](https://github.com/JuliaData/DataFrames.jl):
```julia
using DataFrames
df = DataFrame(x = [1, 2, 3], y = [2, 5, 3])
pairwise(Euclidean(), df)
```

#### Computing column-wise and pairwise distances inplace

If the vector/matrix to store the results are pre-allocated, you may use the storage (without creating a new array) using the following syntax (``i`` being either ``1`` or ``2``):
Expand Down
3 changes: 2 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
julia 0.7-
julia 0.7-
Tables 0.1.15
1 change: 1 addition & 0 deletions src/Distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module Distances

using LinearAlgebra
using Statistics
using Tables

export
# generic types/functions
Expand Down
33 changes: 33 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ function deprecated_dims(dims::Union{Nothing,Integer})
end
end

"""
pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix=a; dims)

Compute distances between each pair of rows (if `dims=1`) or columns (if `dims=2`)
in `a` and `b` according to distance `metric`, and store the result in `r`.
If a single matrix `a` is provided, compute distances between its rows or columns.

`a` and `b` must have the same numbers of columns if `dims=1`, or of rows if `dims=2`.
`r` must be a square matrix with size `size(a, dims) == size(b, dims)`.
"""
function pairwise!(r::AbstractMatrix, metric::PreMetric,
a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
Expand Down Expand Up @@ -165,6 +176,15 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix;
end
end

"""
pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix=a; dims)

Compute distances between each pair of rows (if `dims=1`) or columns (if `dims=2`)
in `a` and `b` according to distance `metric`. If a single matrix `a` is provided,
compute distances between its rows or columns.

`a` and `b` must have the same numbers of columns if `dims=1`, or of rows if `dims=2`.
"""
function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
Expand All @@ -183,3 +203,16 @@ function pairwise(metric::PreMetric, a::AbstractMatrix;
r = Matrix{result_type(metric, a, a)}(undef, n, n)
pairwise!(r, metric, a, dims=dims)
end

"""
pairwise(metric::PreMetric, t)

Compute distances between each pair of observations (i.e. rows) in table `t`
according to distance `metric`. `t` can be any type of table supported by
the [Tables.jl](https://github.com/JuliaData/Tables.jl) interface.
"""
function pairwise(metric::PreMetric, t::Any)
# TODO: avoid permuting using https://github.com/JuliaData/Tables.jl/pull/66
a = permutedims(Tables.matrix(t))
pairwise(metric, a, dims=2)
end
8 changes: 8 additions & 0 deletions test/test_dists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,14 @@ end
test_colwise(Mahalanobis(Q), X, Y, T)
end

@testset "pairwise Tables.jl interface" begin
t = [(a=1, b=2), (a=2, b=3), (a=0, b=5)]
a = [1 2; 2 3; 0 5]
@test pairwise(Euclidean(), t) == pairwise(Euclidean(), a, dims=1)

@test_throws ArgumentError pairwise(Euclidean(), [1])
end

function test_pairwise(dist, x, y, T)
@testset "Pairwise test for $(typeof(dist))" begin
nx = size(x, 2)
Expand Down