Skip to content

Commit

Permalink
Merge pull request #160 from QuantEcon/sl/update_LinInterp
Browse files Browse the repository at this point in the history
ENH: LinInterp for mulitple functions
  • Loading branch information
sglyon authored Jul 7, 2017
2 parents e305bae + d814043 commit 961c4c7
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 7 deletions.
4 changes: 4 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ language: julia
sudo: false
julia:
- 0.5
- 0.6
- nightly
matrix:
allow_failures:
- julia: nightly
notifications:
email: false
#script: # use the default script setting which is equivalent to the following
Expand Down
93 changes: 87 additions & 6 deletions src/interp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ li.([0.1, 0.2, 0.3])
```
"""
immutable LinInterp{TB<:AbstractVector,TV<:AbstractVector}
immutable LinInterp{TV<:AbstractArray,TB<:AbstractVector}
breaks::TB
vals::TV
_n::Int
_ncol::Int

function (::Type{LinInterp{TB,TV}}){TB,TV}(b::TB, v::TV)
function (::Type{LinInterp{TV,TB}}){TB,TV}(b::TB, v::TV)
if size(b, 1) != size(v, 1)
m = "breaks and vals must have same number of elements"
throw(DimensionMismatch(m))
Expand All @@ -38,15 +39,19 @@ immutable LinInterp{TB<:AbstractVector,TV<:AbstractVector}
m = "breaks must be sorted"
throw(ArgumentError(m))
end
new{TB,TV}(b, v, length(b))
new{TV,TB}(b, v, length(b), size(v, 2))
end
end

function LinInterp{TB<:AbstractVector,TV<:AbstractVector}(b::TB, v::TV)
LinInterp{TB,TV}(b, v)
function Base.:(==)(li1::LinInterp, li2::LinInterp)
all(getfield(li1, f) == getfield(li2, f) for f in fieldnames(li1))
end

@compat function (li::LinInterp)(xp::Number)
function LinInterp{TV<:AbstractArray,TB<:AbstractVector}(b::TB, v::TV)
LinInterp{TV,TB}(b, v)
end

@compat function (li::LinInterp{<:AbstractVector})(xp::Number)
ix = searchsortedfirst(li.breaks, xp)

# handle corner cases
Expand All @@ -61,6 +66,73 @@ end
end
end

@compat function (li::LinInterp{<:AbstractMatrix})(xp::Number, col::Int)
ix = searchsortedfirst(li.breaks, xp)
@boundscheck begin
if col > li._ncol || col < 1
msg = "col must be beteween 1 and $(li._ncol), found $col"
throw(BoundsError(msg))
end
end

@inbounds begin
# handle corner cases
ix == 1 && return li.vals[1, col]
ix == li._n + 1 && return li.vals[end, col]

# now get on to the real work...
z = (li.breaks[ix] - xp)/(li.breaks[ix] - li.breaks[ix-1])

return (1-z) * li.vals[ix, col] + z * li.vals[ix-1, col]
end
end

_out_eltype{TV,TB}(li::LinInterp{TV,TB}) = promote_type(eltype(TV), eltype(TB))

@compat function (li::LinInterp{<:AbstractMatrix})(
xp::Number, cols::AbstractVector{<:Integer}
)
ix = searchsortedfirst(li.breaks, xp)
@boundscheck begin
for col in cols
if col > li._ncol || col < 1
msg = "all cols must be beteween 1 and $(li._ncol), found $col"
throw(BoundsError(msg))
end
end
end

out = Array{_out_eltype(li)}(length(cols))

@inbounds begin
# handle corner cases
if ix == 1
for col in cols
out[col] = li.vals[1, col]
end
return out
end

if ix == li._n + 1
for col in cols
out[col] = li.vals[end, col]
end
return out
end

# now get on to the real work...
z = (li.breaks[ix] - xp)/(li.breaks[ix] - li.breaks[ix-1])

for col in cols
out[col] = (1-z) * li.vals[ix, col] + z * li.vals[ix-1, col]
end

return out
end
end

@compat (li::LinInterp{<:AbstractMatrix})(xp::Number) = li(xp, 1:li._ncol)

"""
interp(grid::AbstractVector, function_vals::AbstractVector)
Expand Down Expand Up @@ -88,3 +160,12 @@ function interp(grid::AbstractVector, function_vals::AbstractVector)
return LinInterp(grid, function_vals)
end
end

function interp(grid::AbstractVector, function_vals::AbstractMatrix)
if !issorted(grid)
inds = sortperm(grid)
return LinInterp(grid[inds], function_vals[inds, :])
else
return LinInterp(grid, function_vals)
end
end
33 changes: 32 additions & 1 deletion test/test_interp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,77 @@
# uniform interpolation
breaks = linspace(-3, 3, 100)
vals = exp.(breaks)
vals2 = [vals sin.(breaks)]

li = interp(breaks, vals)
li2 = LinInterp(breaks, vals)

li_mat = interp(breaks, vals2)
li_mat2 = LinInterp(breaks, vals2)

# test constructor
@test li == li2
@test li_mat == li_mat2

# make sure evaluation is inferrable
for T in (Float64, Float32, Float16, Int64, Int32, Int16)
@inferred li(one(T))
@test begin
@inferred li(one(T))
true
end
@test begin
@inferred li_mat(one(T))
true
end
end

# on grid is exact
for i in 1:length(breaks)
@test abs(li(breaks[i]) - vals[i]) < 1e-15
@test all(abs.(li_mat(breaks[i]) - vals2[i, :] .< 1e-15))
end

# off grid is close
for x in linspace(-3, 3, 300)
@test abs(li(x) - exp(x)) < 1e-2
@test all(abs.(li_mat(x) .- [exp(x), sin(x)]) .< 1e-2)
@test li(x) li_mat(x, 1)
end

# test errors for col spec for li_mat being wrong
@test_throws BoundsError li_mat(0.5, 0)
@test_throws BoundsError li_mat(0.5, 3)
@test_throws BoundsError li_mat(0.5, [0, 1])
@test_throws BoundsError li_mat(0.5, [2, 3])


# non-uniform
breaks = cumsum(0.1 .* rand(20))
vals = 0.1 .* map(sin, breaks)
li = interp(breaks, vals)
li_mat = interp(breaks, [vals vals+1])

# on grid is exact
for i in 1:length(breaks)
@test abs(li(breaks[i]) - vals[i]) < 1e-15
@test all(abs.(li_mat(breaks[i]) .- [vals[i], vals[i]+1]) .< 1e-15)
end

# off grid is close
for x in linspace(extrema(breaks)..., 30)
@test abs(li(x) - 0.1*sin(x)) < 1e-2
@test all(abs.(li_mat(x) - [0.1*sin(x), 0.1*sin(x)+1]) .< 1e-2)

end

# un-sorted works for `interp` function, but not `LinInterp`
breaks = rand(10)
vals = map(sin, breaks)

@inferred interp(breaks, vals)
@inferred interp(breaks, [vals vals+1])
@test_throws ArgumentError LinInterp(breaks, vals)
@test_throws ArgumentError LinInterp(breaks, [vals vals+1])

# dimension mismatch
breaks = cumsum(rand(10))
Expand All @@ -54,4 +82,7 @@
@test_throws DimensionMismatch interp(breaks, vals)
@test_throws DimensionMismatch LinInterp(breaks, vals)

@test_throws DimensionMismatch interp(breaks, [vals vals+1])
@test_throws DimensionMismatch LinInterp(breaks, [vals vals+1])

end # @testset

0 comments on commit 961c4c7

Please sign in to comment.