Skip to content

Commit

Permalink
Handle Dates-array arithmetic with promote_op
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Jul 29, 2015
1 parent 6933adb commit 62b1753
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 73 deletions.
57 changes: 24 additions & 33 deletions base/dates/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,41 +66,32 @@ end
(+)(y::Period,x::TimeType) = x + y
(-)(y::Period,x::TimeType) = x - y

for op in (:.+, :.-)
op_ = symbol(string(op)[2:end])
@eval begin
# GeneralPeriod, AbstractArray{TimeType}
($op){T<:TimeType}(x::AbstractArray{T}, y::GeneralPeriod) =
reshape(T[($op_)(i,y) for i in x], size(x))
($op){T<:TimeType}(y::GeneralPeriod, x::AbstractArray{T}) = ($op)(x,y)
($op_){T<:TimeType}(x::AbstractArray{T}, y::GeneralPeriod) = ($op)(x,y)
($op_){T<:TimeType}(y::GeneralPeriod, x::AbstractArray{T}) = ($op)(x,y)
# To avoid a switch to CompoundPeriod, we collect ranges
(-){T<:TimeType}(x::OrdinalRange{T}, y::OrdinalRange{T}) = collect(x) - collect(y)
(-){T<:TimeType}(x::Range{T}, y::Range{T}) = collect(x) - collect(y)

# TimeType, StridedArray{GeneralPeriod}
($op){T<:TimeType,P<:GeneralPeriod}(x::StridedArray{P}, y::T) =
reshape(T[($op_)(i,y) for i in x], size(x))
($op){P<:GeneralPeriod}(y::TimeType, x::StridedArray{P}) = ($op)(x,y)
($op_){T<:TimeType,P<:GeneralPeriod}(x::StridedArray{P}, y::T) = ($op)(x,y)
($op_){P<:GeneralPeriod}(y::TimeType, x::StridedArray{P}) = ($op)(x,y)
# promotion rules

# AbstractArray{TimeType}, StridedArray{GeneralPeriod}
($op_){T<:TimeType,P<:GeneralPeriod}(x::Range{T}, y::StridedArray{P}) = ($op_)(collect(x),y)
($op_){T<:TimeType,P<:GeneralPeriod}(x::AbstractArray{T}, y::StridedArray{P}) =
reshape(TimeType[($op_)(x[i],y[i]) for i in eachindex(x, y)], promote_shape(size(x),size(y)))
($op_){T<:TimeType,P<:GeneralPeriod}(y::StridedArray{P}, x::AbstractArray{T}) = ($op_)(x,y)
for (op,F) in ((:+, Base.AddFun),
(:-, Base.SubFun),
(:.+, Base.DotAddFun),
(:.-, Base.DotSubFun))
@eval begin
Base.promote_op{P<:Period}(::$F, ::Type{P}, ::Type{P}) = P
Base.promote_op{P1<:Period,P2<:Period}(::$F, ::Type{P1}, ::Type{P2}) = CompoundPeriod
Base.promote_op{D<:Date}(::$F, ::Type{D}, ::Type{D}) = Day
Base.promote_op{D<:DateTime}(::$F, ::Type{D}, ::Type{D}) = Millisecond
end
end

# TimeType, AbstractArray{TimeType}
(.-){T<:TimeType}(x::AbstractArray{T}, y::T) = reshape(Period[i - y for i in x], size(x))
(.-){T<:TimeType}(y::T, x::AbstractArray{T}) = -(x .- y)
(-){T<:TimeType}(x::AbstractArray{T}, y::T) = x .- y
(-){T<:TimeType}(y::T, x::AbstractArray{T}) = -(x .- y)

# AbstractArray{TimeType}, AbstractArray{TimeType}
(-){T<:TimeType}(x::OrdinalRange{T}, y::OrdinalRange{T}) = collect(x) - collect(y)
(-){T<:TimeType}(x::Range{T}, y::Range{T}) = collect(x) - collect(y)
(-){T<:TimeType}(x::AbstractArray{T}, y::Range{T}) = y - collect(x)
(-){T<:TimeType}(x::Range{T}, y::AbstractArray{T}) = collect(x) - y
(-){T<:TimeType}(x::AbstractArray{T}, y::AbstractArray{T}) =
reshape(Period[x[i] - y[i] for i in eachindex(x, y)], promote_shape(size(x),size(y)))
for (op,F) in ((:/, Base.RDivFun),
(:%, Base.RemFun),
(:div, Base.IDivFun),
(:mod, Base.ModFun),
(:./, Base.DotRDivFun),
(:.%, Base.DotRemFun))
@eval begin
Base.promote_op{P<:Period}(::$F, ::Type{P}, ::Type{P}) = typeof($op(1,1))
Base.promote_op{P<:Period,R<:Real}(::$F, ::Type{P}, ::Type{R}) = P
end
end
43 changes: 4 additions & 39 deletions base/dates/periods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,34 +43,14 @@ for op in (:+,:-,:lcm,:gcd)
@eval ($op){P<:Period}(x::P,y::P) = P(($op)(value(x),value(y)))
end

for op in (:/,:%,:div,:mod)
for op in (:/,:%,:div,:mod,:./,:.%)
@eval begin
($op){P<:Period}(x::P,y::P) = ($op)(value(x),value(y))
($op){P<:Period}(x::P,y::Real) = P(($op)(value(x),Int64(y)))
end
end
/{P<:Period}(X::StridedArray{P}, y::P) = X ./ y
%{P<:Period}(X::StridedArray{P}, y::P) = X .% y
*{P<:Period}(x::P,y::Real) = P(value(x) * Int64(y))
*(y::Real,x::Period) = x * y
.*{P<:Period}(y::Real, X::StridedArray{P}) = X .* y
for (op,Ty,Tz) in ((:.*,Real,:P),
(:./,:P,Float64), (:./,Real,:P),
(:.%,:P,Int64), (:.%,Integer,:P),
(:div,:P,Int64), (:div,Integer,:P),
(:mod,:P,Int64), (:mod,Integer,:P))
sop = string(op)
op_ = sop[1] == '.' ? symbol(sop[2:end]) : op
@eval begin
function ($op){P<:Period}(X::StridedArray{P},y::$Ty)
Z = similar(X, $Tz)
for i = 1:length(X)
@inbounds Z[i] = ($op_)(X[i],y)
end
return Z
end
end
end

# intfuncs
Base.gcdx{T<:Period}(a::T,b::T) = ((g,x,y)=gcdx(value(a),value(b)); return T(g),x,y)
Expand Down Expand Up @@ -186,6 +166,9 @@ function Base.string(x::CompoundPeriod)
end
Base.show(io::IO,x::CompoundPeriod) = print(io,string(x))

*(x::CompoundPeriod,y::Real) = CompoundPeriod(Period[p*Int64(y) for p in x.periods])
*(y::Real,x::CompoundPeriod) = x * y

# E.g. Year(1) + Day(1)
(+)(x::Period,y::Period) = CompoundPeriod(Period[x,y])
(+)(x::CompoundPeriod,y::Period) = CompoundPeriod(vcat(x.periods,y))
Expand All @@ -201,24 +184,6 @@ GeneralPeriod = Union{Period,CompoundPeriod}
(+)(x::GeneralPeriod) = x
(+){P<:GeneralPeriod}(x::StridedArray{P}) = x

for op in (:.+, :.-)
op_ = symbol(string(op)[2:end])
@eval begin
function ($op){P<:GeneralPeriod}(X::StridedArray{P},y::GeneralPeriod)
Z = similar(X, CompoundPeriod)
for i = 1:length(X)
@inbounds Z[i] = ($op_)(X[i],y)
end
return Z
end
($op){P<:GeneralPeriod}(x::GeneralPeriod,Y::StridedArray{P}) = ($op)(Y,x) |> ($op_)
($op_){P<:GeneralPeriod}(x::GeneralPeriod,Y::StridedArray{P}) = ($op)(Y,x) |> ($op_)
($op_){P<:GeneralPeriod}(Y::StridedArray{P},x::GeneralPeriod) = ($op)(Y,x)
($op_){P<:GeneralPeriod, Q<:GeneralPeriod}(X::StridedArray{P}, Y::StridedArray{Q}) =
reshape(CompoundPeriod[($op_)(X[i],Y[i]) for i in eachindex(X, Y)], promote_shape(size(X),size(Y)))
end
end

(==)(x::CompoundPeriod, y::Period) = x == CompoundPeriod(y)
(==)(y::Period, x::CompoundPeriod) = x == y
(==)(x::CompoundPeriod, y::CompoundPeriod) = x.periods == y.periods
Expand Down
2 changes: 1 addition & 1 deletion base/dates/types.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is a part of Julia. License is MIT: http://julialang.org/license

abstract AbstractTime
abstract AbstractTime <: AbstractScalar

abstract Period <: AbstractTime
abstract DatePeriod <: Period
Expand Down
3 changes: 3 additions & 0 deletions base/functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ call(::DotMulFun, x, y) = x .* y
immutable RDivFun <: Func{2} end
call(::RDivFun, x, y) = x / y

immutable DotRDivFun <: Func{2} end
call(::RDivFun, x, y) = x ./ y

immutable LDivFun <: Func{2} end
call(::LDivFun, x, y) = x \ y

Expand Down

0 comments on commit 62b1753

Please sign in to comment.