Commit 849a28b — Merge pull request #96 from TensorBFS/jg/fix-sampling

Fix sampling algorithm

mroavi authored Jul 30, 2024 · 2 parents 9986d3f + 8a250cd

Showing 10 changed files with 236 additions and 106 deletions.

Project.toml — 2 changes: 2 additions & 0 deletions

```diff
@@ -12,6 +12,7 @@
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
+PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
@@ -25,6 +26,7 @@
 LinearAlgebra = "1"
 OMEinsum = "0.8"
 Pkg = "1"
 PrecompileTools = "1"
+PrettyTables = "2"
 Requires = "1"
 StatsBase = "0.34"
 TropicalNumbers = "0.5.4, 0.6"
```

docs/src/api/public.md — 1 change: 1 addition & 0 deletions

````diff
@@ -68,4 +68,5 @@ read_td_file
 sample
 update_evidence!
 update_temperature
+random_matrix_product_state
 ```
````

src/Core.jl — 3 changes: 2 additions & 1 deletion

```diff
@@ -204,7 +204,8 @@
 Get the cardinalities of variables in this tensor network.
 """
 function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector
     vars = get_vars(tn)
-    [fixedisone && haskey(tn.evidence, vars[k]) ? 1 : length(tn.tensors[k]) for k in eachindex(vars)]
+    size_dict = OMEinsum.get_size_dict(getixsv(tn.code), tn.tensors)
+    [fixedisone && haskey(tn.evidence, vars[k]) ? 1 : size_dict[vars[k]] for k in eachindex(vars)]
 end

 chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence)
```

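The old `length(tn.tensors[k])` returned the total element count of the k-th tensor, which equals the k-th variable's cardinality only for rank-one tensors; the fix reads each variable's size off the einsum code instead. A minimal sketch of that lookup in plain OMEinsum (labels and sizes are illustrative, not from the model):

```julia
using OMEinsum

code = ein"ij,jk->ik"                    # two factors sharing label j
xs = (rand(2, 3), rand(3, 4))
# Map each index label to its dimension from the input tensors.
size_dict = OMEinsum.get_size_dict(getixsv(code), xs)
cards = [size_dict[v] for v in ('i', 'j', 'k')]   # [2, 3, 4]
```
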
src/RescaledArray.jl — 6 changes: 3 additions & 3 deletions

```diff
@@ -23,12 +23,12 @@
 $(TYPEDSIGNATURES)

 Returns a rescaled array that is equivalent to the input tensor.
 """
 function rescale_array(tensor::AbstractArray{T})::RescaledArray where {T}
-    maxf = maximum(tensor)
+    maxf = maximum(abs, tensor)
     if iszero(maxf)
         @warn("The maximum value of the array to rescale is 0!")
         return RescaledArray(zero(T), tensor)
     end
-    return RescaledArray(log(maxf), OMEinsum.asarray(tensor ./ maxf, tensor))
+    return RescaledArray(T(log(maxf)), OMEinsum.asarray(tensor ./ maxf, tensor))
 end

 for CT in [:DynamicEinCode, :StaticEinCode]
@@ -46,4 +46,4 @@ end

 Base.size(arr::RescaledArray) = size(arr.normalized_value)
 Base.size(arr::RescaledArray, i::Int) = size(arr.normalized_value, i)

-match_arraytype(::Type{<:RescaledArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = rescale_array(target)
+match_arraytype(::Type{<:RescaledArray{T, N, AT}}, target::AbstractArray{T, N}) where {T, N, AT} = rescale_array(match_arraytype(AT, target))
```

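For intuition, a sketch of the invariant `rescale_array` maintains, and why `maximum(abs, ...)` matters. The field name `log_factor` is an assumption inferred from the `RescaledArray(zero(T), tensor)` construction above; `normalized_value` is confirmed by the `Base.size` definitions:

```julia
# Field-name assumptions: log_factor (first constructor argument),
# normalized_value (as used in Base.size above).
t = [-3.0 1.5; 0.5 -0.25]
r = rescale_array(t)                          # maxf = maximum(abs, t) = 3.0
exp(r.log_factor) .* r.normalized_value ≈ t   # true: the pair reproduces t
# The old maximum(t) would pick 1.5 here, leaving normalized entries of
# magnitude > 1; an all-negative tensor would even take log of a negative.
```
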
src/TensorInference.jl — 4 changes: 4 additions & 0 deletions

```diff
@@ -11,6 +11,7 @@
 using OMEinsum, LinearAlgebra
 using DocStringExtensions, TropicalNumbers
 # The Tropical GEMM support
 using StatsBase
+using PrettyTables
 import Pkg

 # reexport OMEinsum functions
@@ -34,6 +35,9 @@ export sample
 # MMAP
 export MMAPModel

+# utils
+export random_matrix_product_state
+
 include("Core.jl")
 include("RescaledArray.jl")
 include("utils.jl")
```

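The diff shows only the new export, not the definition in `src/utils.jl`, so the following usage is a hypothetical sketch: the argument meanings (number of sites, bond dimension) are guesses from the name, not the documented signature.

```julia
# Hypothetical sketch only -- signature assumed, not shown in this diff.
using TensorInference

mps = random_matrix_product_state(8, 4)   # assumed: 8 sites, bond dimension 4
```
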
src/mar.jl — 22 changes: 8 additions & 14 deletions

```diff
@@ -16,10 +16,10 @@
 # `CacheTree` stores intermediate `NestedEinsum` contraction results.
 # It is a tree structure that is isomorphic to the contraction tree;
 # `content` is the cached intermediate contraction result.
-# `siblings` are the siblings of current node.
-struct CacheTree{T}
+# `children` are the children of the current node, i.e. the tensors that are contracted to produce `content`.
+mutable struct CacheTree{T}
     content::AbstractArray{T}
-    siblings::Vector{CacheTree{T}}
+    const children::Vector{CacheTree{T}}
 end

 function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
```

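A small sketch of how the renamed field mirrors a contraction; `CacheTree` is internal to `mar.jl`, so this is for intuition rather than public API, and the tensors are illustrative:

```julia
using OMEinsum

A, B = rand(2, 3), rand(3, 2)
leafA = CacheTree(A, CacheTree{Float64}[])   # leaves cache the input tensors
leafB = CacheTree(B, CacheTree{Float64}[])
y = ein"ij,jk->ik"(A, B)                     # the contraction at this node
node = CacheTree(y, [leafA, leafB])          # children = the contracted operands
```
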
```diff
@@ -62,7 +62,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
     if OMEinsum.isleaf(code)
         return CacheTree(dy, CacheTree{T}[])
     else
-        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
+        xs = ntuple(i -> cache.children[i].content, length(cache.children))
         # `einsum_grad` is the back-propagation rule for the einsum function.
         # If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`,
         # then the back-propagation pass is
@@ -73,7 +73,7 @@
         # ```
         # Let `L` be the loss; we will have `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`, ...
         dxs = einsum_backward_rule(code.eins, xs, cache.content, size_dict, dy)
-        return CacheTree(dy, generate_gradient_tree.(code.args, cache.siblings, dxs, Ref(size_dict)))
+        return CacheTree(dy, generate_gradient_tree.(code.args, cache.children, dxs, Ref(size_dict)))
     end
 end
```
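
The comment's rule, made concrete for matrix multiplication: if `y = A*B`, then `Ā = ȳ*B'` and `B̄ = A'*ȳ`. A short check in plain OMEinsum notation — this sketches the math, not the internal `einsum_backward_rule` signature:

```julia
using OMEinsum, LinearAlgebra

A, B = rand(3, 4), rand(4, 5)
ȳ = ones(3, 5)                      # stand-in for ∂L/∂y
Ā = ein"ik,jk->ij"(ȳ, B)            # contract ȳ with B over k => ∂L/∂A
B̄ = ein"ij,ik->jk"(A, ȳ)            # contract A with ȳ over i => ∂L/∂B
@assert Ā ≈ ȳ * B' && B̄ ≈ A' * ȳ    # matches the matrix-calculus identities
```
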

```diff
@@ -116,7 +116,7 @@ function extract_leaves!(code, cache, res)
         res[code.tensorindex] = cache.content
     else
         # recurse deeper
-        extract_leaves!.(code.args, cache.siblings, Ref(res))
+        extract_leaves!.(code.args, cache.children, Ref(res))
     end
     return res
 end
```

````diff
@@ -145,10 +145,7 @@ The following example is taken from [`examples/asia-network/main.jl`](https://te
 ```jldoctest; setup = :(using TensorInference, Random; Random.seed!(0))
 julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia-network", "model.uai"));

-julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0))
-TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
-variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
-contraction time = 2^6.022, space = 2^2.0, read-write = 2^7.077
+julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0));

 julia> marginals(tn)
 Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
@@ -161,10 +158,7 @@
   [7] => [0.145092, 0.854908]
   [2] => [0.05, 0.95]

-julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]])
-TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
-variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
-contraction time = 2^7.781, space = 2^5.0, read-write = 2^8.443
+julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]]);

 julia> marginals(tn2)
 Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:
```
````
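
The headline fix targets `sample`, whose diff in `src/sampling.jl` is among the files that did not load on this page. Here is a hedged usage sketch built from the doctest's own setup; `sample(tn, n)` drawing `n` configurations is assumed from the exported API, not confirmed by this diff:

```julia
using TensorInference, Random
Random.seed!(0)
model = read_model_file(pkgdir(TensorInference, "examples", "asia-network", "model.uai"))
tn = TensorNetworkModel(model; evidence = Dict(1 => 0))
configs = sample(tn, 100)   # assumed signature: 100 samples given the evidence
# Every returned configuration should keep variable 1 fixed at 0.
```
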
(Diffs for the remaining 4 changed files did not load.)
