Skip to content

Commit

Permalink
Optimize HessianFIxedkappa when kappa=0; prevent crash on 0-length wt…
Browse files Browse the repository at this point in the history
… or FE matrix argument values
  • Loading branch information
droodman committed Feb 25, 2023
1 parent fe0a6c7 commit 87c6170
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 114 deletions.
130 changes: 25 additions & 105 deletions src/WRE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,121 +464,41 @@ function HessianFixedkappa!(o::StrBootTest{T}, dest::AbstractMatrix{T}, is::Vect
end

@inbounds for (row,i) enumerate(is)
if !iszero(κ)
if iszero(κ)
dest[row,:] .= o.ȲȲ[i+1,j+1]; t✻plus!(view(dest,row:row,:), view(o.S✻ȲUfold,i+1,:,j+1)', o.v);
coldotplus!(dest, row, o.v, o.S✻UU[i+1, j+1], o.v)
if o.Repl.Yendog[i+1] && o.Repl.Yendog[j+1]
mul!(o.invZperpZperpS✻ZperpUv, o.invZperpZperpS✻ZperpU[i+1], o.v)
coldotminus!(dest, row, o.invZperpZperpS✻ZperpUv, o.S✻ZperpUv)
end

if o.NFE>0 && !o.FEboot
mul!(o.CT✻FEUv, o.CT✻FEU[i+1], o.v)
coldotminus!(dest, row, o.CT✻FEUv, o.invFEwtCT✻FEUv)
end
else
o.T1L .= o.XȲ[:,i+1] # X_∥^' Y_(∥i)
o.Repl.Yendog[i+1] &&
t✻plus!(o.T1L, o.S✻XU[i+1], o.v)
coldot!(dest, row, o.T1L, o.T1R) # multiply in the left-side linear term
end

if !isone(κ)
_dest = t✻(view(o.S✻ȲUfold,i+1,:,j+1)', o.v); _dest .+= o.ȲȲ[i+1,j+1]
coldotplus!(_dest, 1, o.v, o.S✻UU[i+1, j+1], o.v)
if o.Repl.Yendog[i+1] && o.Repl.Yendog[j+1]
mul!(o.invZperpZperpS✻ZperpUv, o.invZperpZperpS✻ZperpU[i+1], o.v)
coldotminus!(_dest, 1, o.invZperpZperpS✻ZperpUv, o.S✻ZperpUv)
end
if !isone(κ)
_dest = t✻(view(o.S✻ȲUfold,i+1,:,j+1)', o.v); _dest .+= o.ȲȲ[i+1,j+1]
coldotplus!(_dest, 1, o.v, o.S✻UU[i+1, j+1], o.v)
if o.Repl.Yendog[i+1] && o.Repl.Yendog[j+1]
mul!(o.invZperpZperpS✻ZperpUv, o.invZperpZperpS✻ZperpU[i+1], o.v)
coldotminus!(_dest, 1, o.invZperpZperpS✻ZperpUv, o.S✻ZperpUv)
end

if o.NFE>0 && !o.FEboot
mul!(o.CT✻FEUv, o.CT✻FEU[i+1], o.v)
coldotminus!(_dest, 1, o.CT✻FEUv, o.invFEwtCT✻FEUv)
end
if o.NFE>0 && !o.FEboot
mul!(o.CT✻FEUv, o.CT✻FEU[i+1], o.v)
coldotminus!(_dest, 1, o.CT✻FEUv, o.invFEwtCT✻FEUv)
end

if iszero(κ)
dest[row:row,:] .= _dest
else
dest[row:row,:] .*= κ; dest[row:row,:] .+= (1-κ) .* _dest
dest[row:row,:] .*= κ; dest[row:row,:] .+= (1-κ) .* _dest
end
end

# if o.Repl.Yendog[i+1] && o.Repl.Yendog[j+1]
# mul!(o.invZperpZperpS✻ZperpUv, o.invZperpZperpS✻ZperpU[i+1], o.v)
# o.NFE>0 && !o.FEboot &&
# mul!(o.CT✻FEUv, o.CT✻FEU[i+1], o.v)
# if iszero(κ)
# dest[row,:] .= o.ȲȲ[i+1,j+1]; t✻plus!(view(dest,row:row,:), view(o.S✻ȲUfold,i+1,:,j+1)', o.v)
# coldotminus!(dest, row, o.invZperpZperpS✻ZperpUv, o.S✻ZperpUv)
# coldotplus!(dest, row, o.v, o.S✻UU[i+1,j+1], o.v)
# o.NFE>0 && !o.FEboot &&
# coldotminus!(dest, row, o.CT✻FEUv, o.invFEwtCT✻FEUv)
# elseif isone(κ)
# t✻!(view(dest,row,:), view(o.S✻ȲUfold,i+1,:,j+1)', o.v); dest[row,:] .+= o.ȲȲ[i+1,j+1]
# coldotminus!(dest, row, o.invZperpZperpS✻ZperpUv, o.S✻ZperpUv)
# coldotplus!(dest, row, o.v, o.S✻UU[i+1, j+1], o.v)
# o.NFE>0 && !o.FEboot &&
# coldotminus!(dest, row, o.CT✻FEUv, o.invFEwtCT✻FEUv)
# else
# _dest = t✻(view(o.S✻ȲUfold,i+1,:,j+1)', o.v); _dest .+= o.ȲȲ[i+1,j+1]
# coldotminus!(_dest, 1, o.invZperpZperpS✻ZperpUv, o.S✻ZperpUv)
# coldotplus!(_dest, 1, o.v, o.S✻UU[i+1, j+1], o.v)
# o.NFE>0 && !o.FEboot &&
# coldotminus!(_dest, 1, o.CT✻FEUv, o.invFEwtCT✻FEUv)
# dest[row,:] .*= κ; dest[row,:] .+= (1 - κ) .* _dest
# end




# if !o.Repl.Yendog[i+1] && !o.Repl.Yendog[j+1] # if both vars exog, result = order-0 term only, same for all draws
# !iszero(κ) &&
# (dest[row,:] .= dot(view(o.XȲ,:,i+1), j>0 ? view(o.invXXXZ̄,:,j) : view(o.DGP.γ⃛,:,1)))
# if !isone(κ)
# if iszero(κ)
# dest[row,:] .= o.ȲȲ[i+1,j+1]
# else
# dest[row,:] .= κ .* dest[row,:] .+ (1 - κ) .* o.ȲȲ[i+1,j+1]
# end
# end
# else
# if !iszero(κ) # repetitiveness in this section to maintain type stability
# if o.Repl.Yendog[i+1]
# mul!(o.T1L, o.S✻XU[i+1], o.v)
# o.T1L .+= view(o.XȲ,:,i+1)
# if o.Repl.Yendog[j+1]
# coldot!(dest, row, o.T1L, o.T1R)
# else
# mul!(view(dest,row,:), o.T1L', view(o.invXXXZ̄,:,j))
# end
# else
# if o.Repl.Yendog[j+1]
# mul!(view(dest,row,:), o.T1R', view(o.XȲ,:,i+1))
# else
# dest[row,:] .= dot(view(o.invXXXZ̄,:,j), view(o.XȲ,:,i+1))
# end
# end
# end
# if !isone(κ)
# if o.Repl.Yendog[i+1] && o.Repl.Yendog[j+1]
# mul!(o.invZperpZperpS✻ZperpUv, o.invZperpZperpS✻ZperpU[i+1], o.v)
# o.NFE>0 && !o.FEboot &&
# mul!(o.CT✻FEUv, o.CT✻FEU[i+1], o.v)
# if iszero(κ)
# dest[row,:] .= o.ȲȲ[i+1,j+1]; t✻plus!(view(dest,row:row,:), view(o.S✻ȲUfold,i+1,:,j+1)', o.v)
# coldotminus!(dest, row, o.invZperpZperpS✻ZperpUv, o.S✻ZperpUv)
# coldotplus!(dest, row, o.v, o.S✻UU[i+1,j+1], o.v)
# o.NFE>0 && !o.FEboot &&
# coldotminus!(dest, row, o.CT✻FEUv, o.invFEwtCT✻FEUv)
# else
# _dest = t✻(view(o.S✻ȲUfold,i+1,:,j+1)', o.v); _dest .+= o.ȲȲ[i+1,j+1]
# coldotminus!(_dest, 1, o.invZperpZperpS✻ZperpUv, o.S✻ZperpUv)
# coldotplus!(_dest, 1, o.v, o.S✻UU[i+1, j+1], o.v)
# o.NFE>0 && !o.FEboot &&
# coldotminus!(_dest, 1, o.CT✻FEUv, o.invFEwtCT✻FEUv)
# dest[row,:] .= κ .* dest[row,:] .+ (1 - κ) .* _dest
# end
# elseif iszero(κ)
# dest[row,:] .= o.ȲȲ[i+1,j+1]; t✻plus!(view(dest,row:row,:), view(o.S✻ȲUfold,i+1,:,j+1)', o.v)
# else
# _dest = t✻(view(o.S✻ȲUfold,i+1,:,j+1)', o.v); _dest .+= o.ȲȲ[i+1,j+1]
# dest[row,:] .= κ .* dest[row,:] .+ (1 - κ) .* _dest
# end
# elseif iszero(κ)
# dest[row,:] .= o.ȲȲ[i+1,j+1]
# else
# dest[row,:] .= κ .* dest[row,:] .+ (1 - κ) .* o.ȲȲ[i+1,j+1]
# end
# end

if _jk
!iszero(κ) &&
(dest[row,1] = dot(i>0 ? o.Repl.XZ[:,i] : o.Repl.Xy₁par,
Expand Down
3 changes: 2 additions & 1 deletion src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,14 @@ function InitFEs(o::StrBootTest{T}) where T
resize!(o.invFEwt, o.NFE)
resize!(o.FEs , o.NFE)
end
o.FEdfadj==-1 && (o.FEdfadj = o.NFE)

if o.FEboot # are all of this FE's obs in same bootstrapping cluster?
tmp = o.ID[is, 1:o.NBootClustVar]
o.FEboot = all(tmp .== view(tmp,1,:)')
end

o.FEdfadj==-1 && (o.FEdfadj = o.NFE)

if o.robust && o.B>0 && o.bootstrapt && !o.FEboot && o.granular < o.NErrClustVar
o.infoBootAll = panelsetup(o.ID✻⋂, 1:o.NBootClustVar) # info for bootstrapping clusters wrt data collapsed to intersections of all bootstrapping && error clusters
end
Expand Down
6 changes: 3 additions & 3 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ function __wildboottest(
getauxweights && reps>0 ? getv(o) : nothing #=, o=#)
end

vecconvert(T::DataType, X) = isa(X, Vector{T}) ? X : Vector{T}(reshape(X, size(X,1)))
vecconvert(T::DataType, X) = iszero(length(X)) ? T[] : isa(X, Vector{T}) ? X : Vector{T}(reshape(X, size(X,1)))
matconvert(T::DataType, X) = isa(X, Matrix{T}) ? X : Matrix{T}(reshape(X, size(X,1), size(X,2)))

function _wildboottest(T::DataType,
Expand Down Expand Up @@ -347,9 +347,9 @@ Function to perform wild-bootstrap-based hypothesis test
* `nerrclustvar::Integer=nbootclustvar`: number of error-clustering variables
* `issorted:Bool=false`: time-saving flag: data matrices are already sort by column types 2, then 3, then 1 (see notes)
* `hetrobust::Bool=true`: true unless errors are treated as iid
* `nfe::Integer=0`: number of fixed-effect groups; if 0 yet `feid` is provided, will be computed
* `nfe::Integer=0`: number of fixed-effect groups; if 0 yet `feid` is provided, will be computed, at small speed penalty
* `feid::AbstractVector{<:Integer}`: data vector for one-way fixed effect group identifier
* `fedfadj::Integer=nfe`: degrees of freedom that fixed effects (if any) consume
* `fedfadj::Integer`: degrees of freedom that fixed effects (if any) consume; defaults to number of FEs
* `obswt::AbstractVector=[]`: observation weight vector; default is equal weighting
* `fweights::Bool=false`: true for frequency weights
* `maxmatsize::Number`: maximum size of auxilliary weight matrix (v), in gigabytes
Expand Down
4 changes: 2 additions & 2 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ function t✻plus!(A::AbstractVector{T}, B::AbstractMatrix{T}, C::AbstractVector
end
function t✻minus!(A::AbstractMatrix{T}, B::AbstractVecOrMat{T}, C::AbstractMatrix{T}) where T # add B*C to A in place
if length(B)>0 && length(C)>0
@tturbo warn_check_args=false for i eachindex(axes(A,1),axes(B,1)), k eachindex(axes(A,2), axes(C,2))
@tturbo warn_check_args=false for i indices((A,B),1), k indices((A,C),2)
Aᵢₖ = zero(T)
for j eachindex(axes(B,2),axes(C,1))
for j indices((B,C),(2,1))
Aᵢₖ += B[i,j] * C[j,k]
end
A[i,k] -= Aᵢₖ
Expand Down
6 changes: 3 additions & 3 deletions test/unittests.log
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ CI = [-0.01093 0.123]
areg wage ttl_exp collgrad tenure if occupation<. & grade<., robust absorb(industry)
boottest tenure
t(1833) = 1.6602
p = 0.0971
CI = [-0.004733 0.07348]
p = 0.0951
CI = [-0.005038 0.07348]


areg wage ttl_exp collgrad tenure [aw=hours] if occupation<. & grade<., robust absorb(industry)
Expand Down Expand Up @@ -313,7 +313,7 @@ areg n w k, absorb(ind)
boottest k, cluster(id year)
t(8) = 28.5012
p = 0.0000
CI = [0.7823 0.9412]
CI = [0.7823 0.946]


areg n w k [aw=ys], absorb(ind)
Expand Down

0 comments on commit 87c6170

Please sign in to comment.