-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: ReshapedArrays #10507
WIP: ReshapedArrays #10507
Changes from all commits
190e27b
febd683
c69dffe
e74832d
1d77ae4
2624ed8
4c17ea3
f5905b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -514,3 +514,140 @@ checked_mul(x::Int128, y::Int128) = x * y | |
checked_add(x::UInt128, y::UInt128) = x + y | ||
checked_sub(x::UInt128, y::UInt128) = x - y | ||
checked_mul(x::UInt128, y::UInt128) = x * y | ||
|
||
|
||
## A faster alternative to div if the dividend is reused many times | ||
|
||
unsigned_type(::Int8) = UInt8 | ||
unsigned_type(::Int16) = UInt16 | ||
unsigned_type(::Int32) = UInt32 | ||
unsigned_type(::Int64) = UInt64 | ||
unsigned_type(::Int128) = UInt128 | ||
|
||
abstract FastDivInteger{T} | ||
|
||
immutable SignedFastDivInteger{T<:Signed} <: FastDivInteger{T} | ||
divisor::T | ||
multiplier::T | ||
addmul::Int8 | ||
shift::UInt8 | ||
|
||
function SignedFastDivInteger(d::T) | ||
ut = unsigned_type(d) | ||
signedmin = reinterpret(ut, typemin(d)) | ||
|
||
ad::ut = abs(d) | ||
ad <= 1 && error("cannot compute magic for d == $d") | ||
t::ut = signedmin + signbit(d) | ||
anc::ut = t - 1 - rem(t, ad) # absolute value of nc | ||
p = sizeof(d)*8 - 1 # initialize p | ||
q1::ut, r1::ut = divrem(signedmin, anc) | ||
q2::ut, r2::ut = divrem(signedmin, ad) | ||
while true | ||
p += 1 | ||
q1 *= 2 # update q1 = 2p/abs(nc) | ||
r1 *= 2 # update r1 = rem(2p/abs(nc)) | ||
if r1 >= anc # must be unsigned comparison | ||
q1 += 1 | ||
r1 -= anc | ||
end | ||
q2 *= 2 # update q2 = 2p/abs(d) | ||
r2 *= 2 # update r2 = rem(2p/abs(d)) | ||
if r2 >= ad # must be unsigned comparison | ||
q2 += 1 | ||
r2 -= ad | ||
end | ||
delta::ut = ad - r2 | ||
(q1 < delta || (q1 == delta && r1 == 0)) || break | ||
end | ||
|
||
m = flipsign((q2 + 1) % T, d) # resulting magic number | ||
s = p - sizeof(d)*8 # resulting shift | ||
new(d, m, d > 0 && m < 0 ? Int8(1) : d < 0 && m > 0 ? Int8(-1) : Int8(0), UInt8(s)) | ||
end | ||
end | ||
SignedFastDivInteger(x::Signed) = SignedFastDivInteger{typeof(x)}(x) | ||
|
||
immutable UnsignedFastDivInteger{T<:Unsigned} <: FastDivInteger{T} | ||
divisor::T | ||
multiplier::T | ||
add::Bool | ||
shift::UInt8 | ||
|
||
function UnsignedFastDivInteger(d::T) | ||
(d == 0 || d == 1) && error("cannot compute magic for d == $d") | ||
u2 = convert(T, 2) | ||
add = false | ||
signedmin::typeof(d) = one(d) << (sizeof(d)*8-1) | ||
signedmax::typeof(d) = signedmin - 1 | ||
allones = (zero(d) - 1) % T | ||
|
||
nc::typeof(d) = allones - rem(convert(T, allones - d), d) | ||
p = 8*sizeof(d) - 1 # initialize p | ||
q1::typeof(d), r1::typeof(d) = divrem(signedmin, nc) | ||
q2::typeof(d), r2::typeof(d) = divrem(signedmax, d) | ||
while true | ||
p += 1 | ||
if r1 >= convert(T, nc - r1) | ||
q1 = q1 + q1 + T(1) # update q1 | ||
r1 = r1 + r1 - nc # update r1 | ||
else | ||
q1 = q1 + q1 # update q1 | ||
r1 = r1 + r1 # update r1 | ||
end | ||
if convert(T, r2 + T(1)) >= convert(T, d - r2) | ||
add |= q2 >= signedmax | ||
q2 = q2 + q2 + 1 # update q2 | ||
r2 = r2 + r2 + T(1) - d # update r2 | ||
else | ||
add |= q2 >= signedmin | ||
q2 = q2 + q2 # update q2 | ||
r2 = r2 + r2 + T(1) # update r2 | ||
end | ||
delta::typeof(d) = d - 1 - r2 | ||
(p < sizeof(d)*16 && (q1 < delta || (q1 == delta && r1 == 0))) || break | ||
end | ||
m = q2 + 1 # resulting magic number | ||
s = p - sizeof(d)*8 - add # resulting shift | ||
new(d, m % T, add, s % UInt8) | ||
end | ||
end | ||
UnsignedFastDivInteger(x::Unsigned) = UnsignedFastDivInteger{typeof(x)}(x) | ||
|
||
# Special type to handle div by 1 | ||
immutable FastDivInteger1{T} <: FastDivInteger{T} end | ||
|
||
(*)(a::FastDivInteger, b::FastDivInteger) = a.divisor*b.divisor | ||
(*)(a::Number, b::FastDivInteger) = a*b.divisor | ||
(*)(a::FastDivInteger, b::Number) = a.divisor*b | ||
|
||
# div{T}(a::Integer, b::SignedFastDivInteger{T}) = div(convert(T, a), b) | ||
# div{T}(a::Integer, b::UnsignedFastDivInteger{T}) = div(convert(T, a), b) | ||
# rem{T}(a::Integer, b::FastDivInteger{T}) = rem(convert(T, a), b) | ||
# divrem{T}(a::Integer, b::FastDivInteger{T}) = divrem(convert(T, a), b) | ||
|
||
function div{T}(a::T, b::SignedFastDivInteger{T}) | ||
x = ((widen(a)*b.multiplier) >>> sizeof(a)*8) % T | ||
x += (a*b.addmul) % T | ||
(signbit(x) + (x >> b.shift)) % T | ||
end | ||
function div{T}(a::T, b::UnsignedFastDivInteger{T}) | ||
x = ((widen(a)*b.multiplier) >>> sizeof(a)*8) % T | ||
x = ifelse(b.add, convert(T, convert(T, (convert(T, a - x) >>> 1)) + x), x) | ||
x >>> b.shift | ||
end | ||
div{T}(a, ::FastDivInteger1{T}) = convert(T, a) | ||
|
||
rem{T}(a::T, b::FastDivInteger{T}) = | ||
a - div(a, b)*b.divisor | ||
rem{T}(a, ::FastDivInteger1{T}) = zero(T) | ||
|
||
function divrem{T}(a::T, b::FastDivInteger{T}) | ||
d = div(a, b) | ||
(d, a - d*b.divisor) | ||
end | ||
divrem{T}(a, ::FastDivInteger1{T}) = (convert(T, a), zero(T)) | ||
|
||
# Type unstable! | ||
call(::Type{FastDivInteger}, x::Signed) = x == 1 ? FastDivInteger1{typeof(x)}() : SignedFastDivInteger(x) | ||
call(::Type{FastDivInteger}, x::Unsigned) = x == 1 ? FastDivInteger1{typeof(x)}() : UnsignedFastDivInteger(x) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there no type stable solution, e.g. constructing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was hoping to avoid an extra But if we expect this to be used in other contexts, then perhaps the type instability is something we need to fix, even if it slows down the usage. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
module Reshaped | ||
|
||
import Base: getindex, ind2sub, linearindexing, reshape, similar, size | ||
# just using, not overloading: | ||
import Base: LinearFast, LinearSlow, FastDivInteger, FastDivInteger1, tail | ||
|
||
# Remapping a block of In dimensions of the parent array to Out dimensions of the view | ||
immutable IndexMD{In,Out,MI<:(FastDivInteger...)} | ||
invstride_parent::MI | ||
dims_view::NTuple{Out,Int} | ||
end | ||
IndexMD(mi, dims) = IndexMD{length(mi),length(dims),typeof(mi)}(mi, dims) | ||
|
||
typealias ReshapeIndex Union(Colon, IndexMD) | ||
|
||
immutable ReshapedArray{T,N,P<:AbstractArray,I<:(ReshapeIndex...)} <: AbstractArray{T,N} | ||
parent::P | ||
indexes::I | ||
dims::NTuple{N,Int} | ||
end | ||
|
||
function ReshapedArray(parent::AbstractArray, indexes::(ReshapeIndex...), dims) | ||
ReshapedArray{eltype(parent),length(dims),typeof(parent),typeof(indexes)}(parent, indexes, dims) | ||
end | ||
|
||
# Since ranges are immutable and are frequently used to build arrays, treat them as a special case. | ||
function reshape(a::Range, dims::Dims) | ||
if prod(dims) != length(a) | ||
throw(DimensionMismatch("new dimensions $(dims) must be consistent with array size $(length(a))")) | ||
end | ||
A = Array(eltype(a), dims) | ||
k = 0 | ||
for item in a | ||
A[k+=1] = item | ||
end | ||
A | ||
end | ||
|
||
reshape(parent::AbstractArray, dims::Dims) = reshape((parent, linearindexing(parent)), dims) | ||
|
||
function reshape(p::(AbstractArray,LinearSlow), dims::Dims) | ||
parent = p[1] | ||
# Split on dimensions where the strides line up | ||
stridep = dimstrides(size(parent)) | ||
stridev = dimstrides(dims) | ||
stridep[end] == stridev[end] || throw(DimensionMismatch("Must have the same number of elements")) | ||
indexes = Any[] | ||
iplast = ivlast = 1 | ||
ip = iv = 2 | ||
while ip <= length(stridep) || iv <= length(stridev) | ||
if stridep[ip] == stridev[iv] | ||
if ip-iplast <= 1 && iv-ivlast == 1 | ||
push!(indexes, Colon()) | ||
else | ||
mi = map(FastDivInteger, size(parent)[iplast:ip-1]) | ||
imd = IndexMD(mi, dims[ivlast:iv-1]) | ||
push!(indexes, imd) | ||
end | ||
iplast = ip | ||
ivlast = iv | ||
ip += (ip < length(stridep) || iv == length(stridev)) # for consuming trailing 1s in dims | ||
iv += 1 | ||
elseif stridep[ip] < stridev[iv] | ||
ip += 1 | ||
else | ||
iv += 1 | ||
end | ||
end | ||
ReshapedArray(parent, tuple(indexes...), dims) | ||
end | ||
|
||
function reshape(p::(AbstractArray,LinearFast), dims::Dims) | ||
parent = p[1] | ||
prod(dims) == length(parent) || throw(DimensionMismatch("Must have the same number of elements")) | ||
ReshapedArray(parent, (IndexMD((FastDivInteger1{Int}(),), dims),), dims) | ||
end | ||
|
||
reshape(parent::ReshapedArray, dims::Dims) = reshape(parent.parent, dims) | ||
|
||
size(A::ReshapedArray) = A.dims | ||
similar{T}(A::ReshapedArray, ::Type{T}, dims::Dims) = Array(T, dims) | ||
linearindexing(A::ReshapedArray) = linearindexing(A.parent) | ||
|
||
size(index::IndexMD) = index.dims_view | ||
|
||
@inline function getindex(indx::IndexMD{1}, indexes::Int...) | ||
sub2ind(indx.dims_view, indexes...) | ||
end | ||
@inline function getindex(indx::IndexMD, indexes::Int...) | ||
ind2sub(indx.invstride_parent, sub2ind(indx.dims_view, indexes...)) | ||
end | ||
|
||
consumes{In,Out,MI}(::Type{IndexMD{In,Out,MI}}) = Out | ||
consumes(::Type{Colon}) = 1 | ||
produces{In,Out,MI}(::Type{IndexMD{In,Out,MI}}) = In | ||
produces(::Type{Colon}) = 1 | ||
|
||
getindex(A::ReshapedArray) = A.parent[1] | ||
getindex(A::ReshapedArray, indx::Real) = A.parent[indx] | ||
|
||
stagedfunction getindex{T,N,P,I}(A::ReshapedArray{T,N,P,I}, indexes::Real...) | ||
length(indexes) == N || throw(DimensionMismatch("Must index with all $N indexes")) | ||
c = map(consumes, I) | ||
breaks = [0;cumsum([c...])] | ||
argbreaks = Any[] | ||
for i = 1:length(c) | ||
ab = Expr[] | ||
for j = breaks[i]+1:breaks[i+1] | ||
push!(ab, :(indexes[$j])) | ||
end | ||
push!(argbreaks, ab) | ||
end | ||
argsin = Expr[:(getindex(A.indexes[$i], $(argbreaks[i]...))) for i = 1:length(c)] | ||
np = map(produces, I) | ||
npc = [0;cumsum([np...])] | ||
argsout = Expr[] | ||
if length(argsin) > 1 | ||
for i = 1:npc[end] | ||
j = findlast(npc .< i) | ||
di = i - npc[j] | ||
push!(argsout, :(tindex[$j][$di])) | ||
end | ||
else | ||
for i = 1:npc[end] | ||
push!(argsout, :(tindex[$i])) | ||
end | ||
end | ||
meta = Expr(:meta, :inline) | ||
ex = length(argsin) == 1 ? | ||
quote | ||
$meta | ||
tindex = $(argsin...) | ||
getindex(A.parent, $(argsout...)) | ||
end : | ||
quote | ||
$meta | ||
tindex = tuple($(argsin...)) | ||
getindex(A.parent, $(argsout...)) | ||
end | ||
ex | ||
end | ||
|
||
# Like strides, but operates on a Dims tuple, and returns one extra element (the total size) | ||
dimstrides(::()) = () | ||
dimstrides(s::Dims) = dimstrides((1,), s) | ||
dimstrides(t::Tuple, ::()) = t | ||
@inline dimstrides(t::Tuple, sz::Dims) = dimstrides(tuple(t..., t[end]*sz[1]), tail(sz)) | ||
|
||
ind2sub(dims::(FastDivInteger,), ind::Integer) = ind | ||
@inline ind2sub(dims::(FastDivInteger,FastDivInteger), ind::Integer) = begin | ||
dv, rm = divrem(ind-1,dims[1]) | ||
rm+1, dv+1 | ||
end | ||
@inline ind2sub(dims::(FastDivInteger,FastDivInteger,FastDivInteger...), ind::Integer) = begin | ||
dv, rm = divrem(ind-1,dims[1]) | ||
tuple(rm+1, ind2sub(tail(dims),dv+1)...) | ||
end | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess that once the improved version of the generic definition of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. |
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe when a and b are both
FastDivInteger
there product should also be one?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a reasonable suggestion.