Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Refactor sites methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed May 20, 2024
1 parent 6d6e8e4 commit 0c67e56
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ for f in [
:noutputs,
:inputs,
:outputs,
:sites,
:nsites,
:nlanes,
:socket,
Expand All @@ -46,6 +45,8 @@ function Base.summary(io::IO, tn::A) where {A<:Ansatz}
end
Base.show(io::IO, tn::A) where {A<:Ansatz} = summary(io, tn)

sites(tn::Ansatz; kwargs...) = sites(Quantum(tn); kwargs...)

function Tenet.inds(tn::Ansatz; kwargs...)
if keys(kwargs) === (:bond,)
inds(tn, Val(:bond), kwargs[:bond]...)
Expand Down
10 changes: 9 additions & 1 deletion src/Quantum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,15 @@ Base.show(io::IO, q::Quantum) = print(io, "Quantum (inputs=$(ninputs(q)), output
Returns the sites of a [`Quantum`](@ref) Tensor Network.
"""
sites(tn::Quantum) = collect(keys(tn.sites))
function sites(tn::Quantum; kwargs...)
if isempty(kwargs)
collect(keys(tn.sites))
elseif keys(kwargs) === (:at,)
findfirst(i -> i === kwargs[:at], tn.sites)
else
throw(MethodError(sites, (Quantum,), kwargs))
end
end

"""
nsites(q::Quantum)
Expand Down

0 comments on commit 0c67e56

Please sign in to comment.