Skip to content

Implement varname prefix / unprefix #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.10.1"
version = "0.11.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ vsym
@vsym
```

## VarName prefixing and unprefixing

```@docs
prefix
unprefix
```

## VarName serialisation

```@docs
Expand Down
4 changes: 3 additions & 1 deletion src/AbstractPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ export VarName,
index_to_dict,
dict_to_index,
varname_to_string,
string_to_varname
string_to_varname,
prefix,
unprefix

# Abstract model functions
export AbstractProbabilisticProgram,
Expand Down
192 changes: 192 additions & 0 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,8 @@
end
end

### Serialisation to JSON / string

# String constants for each index type that we support serialisation /
# deserialisation of
const _BASE_INTEGER_TYPE = "Base.Integer"
Expand Down Expand Up @@ -936,3 +938,193 @@
should have been generated by `varname_to_string`.
"""
string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str))

### Prefixing and unprefixing

"""
_strip_identity(optic)

Remove identity lenses from composed optics.
"""
_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity
function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
return _strip_identity(o.outer)

Check warning on line 951 in src/varname.jl

View check run for this annotation

Codecov / codecov/patch

src/varname.jl#L949-L951

Added lines #L949 - L951 were not covered by tests
end
function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner}
return _strip_identity(o.inner)

Check warning on line 954 in src/varname.jl

View check run for this annotation

Codecov / codecov/patch

src/varname.jl#L953-L954

Added lines #L953 - L954 were not covered by tests
end
_strip_identity(o::Base.ComposedFunction) = o
_strip_identity(o::Accessors.PropertyLens) = o
_strip_identity(o::Accessors.IndexLens) = o
_strip_identity(o::typeof(identity)) = o

Check warning on line 959 in src/varname.jl

View check run for this annotation

Codecov / codecov/patch

src/varname.jl#L959

Added line #L959 was not covered by tests

"""
_inner(optic)

Get the innermost (non-identity) layer of an optic.

```jldoctest; setup=:(using Accessors)
julia> AbstractPPL._inner(Accessors.@o _.a.b.c)
(@o _.a)

julia> AbstractPPL._inner(Accessors.@o _[1][2][3])
(@o _[1])

julia> AbstractPPL._inner(Accessors.@o _)
identity (generic function with 1 method)
```
"""
_inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner
_inner(o::Accessors.PropertyLens) = o
_inner(o::Accessors.IndexLens) = o
_inner(o::typeof(identity)) = o

"""
_outer(optic)

Get the outer layer of an optic.

```jldoctest; setup=:(using Accessors)
julia> AbstractPPL._outer(Accessors.@o _.a.b.c)
(@o _.b.c)

julia> AbstractPPL._outer(Accessors.@o _[1][2][3])
(@o _[2][3])

julia> AbstractPPL._outer(Accessors.@o _.a)
identity (generic function with 1 method)

julia> AbstractPPL._outer(Accessors.@o _[1])
identity (generic function with 1 method)

julia> AbstractPPL._outer(Accessors.@o _)
identity (generic function with 1 method)
```
"""
_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer
_outer(::Accessors.PropertyLens) = identity
_outer(::Accessors.IndexLens) = identity
_outer(::typeof(identity)) = identity

"""
optic_to_vn(optic)

Convert an Accessors optic to a VarName. This is best explained through
examples.

```jldoctest; setup=:(using Accessors)
julia> AbstractPPL.optic_to_vn(Accessors.@o _.a)
a

julia> AbstractPPL.optic_to_vn(Accessors.@o _.a.b)
a.b

julia> AbstractPPL.optic_to_vn(Accessors.@o _.a[1])
a[1]
```

The outermost layer of the optic (technically, what Accessors.jl calls the
'innermost') must be a `PropertyLens`, or else it will fail. This is because a
VarName needs to have a symbol.

```jldoctest; setup=:(using Accessors)
julia> AbstractPPL.optic_to_vn(Accessors.@o _[1])
ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarName
[...]
```
"""
function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym}
return VarName{sym}()
end
function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
return optic_to_vn(o.outer)

Check warning on line 1040 in src/varname.jl

View check run for this annotation

Codecov / codecov/patch

src/varname.jl#L1039-L1040

Added lines #L1039 - L1040 were not covered by tests
end
function optic_to_vn(
o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}}
) where {Outer,sym}
return VarName{sym}(o.outer)
end
function optic_to_vn(@nospecialize(o))
msg = "optic_to_vn: could not convert optic `$o` to a VarName"
throw(ArgumentError(msg))
end

unprefix_optic(o, ::typeof(identity)) = o # Base case
function unprefix_optic(optic, optic_prefix)
# strip one layer of the optic and check for equality
inner = _inner(_strip_identity(optic))
inner_prefix = _inner(_strip_identity(optic_prefix))
if inner != inner_prefix
msg = "could not remove prefix $(optic_prefix) from optic $(optic)"
throw(ArgumentError(msg))
end
# recurse
return unprefix_optic(
_outer(_strip_identity(optic)), _outer(_strip_identity(optic_prefix))
)
end

"""
unprefix(vn::VarName, prefix::VarName)

Remove a prefix from a VarName.

```jldoctest
julia> AbstractPPL.unprefix(@varname(y.x), @varname(y))
x

julia> AbstractPPL.unprefix(@varname(y.x.a), @varname(y))
x.a

julia> AbstractPPL.unprefix(@varname(y[1].x), @varname(y[1]))
x

julia> AbstractPPL.unprefix(@varname(y), @varname(n))
ERROR: ArgumentError: could not remove prefix n from VarName y
[...]
```
"""
function unprefix(
vn::VarName{sym_vn}, prefix::VarName{sym_prefix}
) where {sym_vn,sym_prefix}
if sym_vn != sym_prefix
msg = "could not remove prefix $(prefix) from VarName $(vn)"
throw(ArgumentError(msg))
end
optic_vn = getoptic(vn)
optic_prefix = getoptic(prefix)
return optic_to_vn(unprefix_optic(optic_vn, optic_prefix))
end

"""
prefix(vn::VarName, prefix::VarName)

Add a prefix to a VarName.

```jldoctest
julia> AbstractPPL.prefix(@varname(x), @varname(y))
y.x

julia> AbstractPPL.prefix(@varname(x.a), @varname(y))
y.x.a

julia> AbstractPPL.prefix(@varname(x.a), @varname(y[1]))
y[1].x.a
```
"""
function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix}
optic_vn = getoptic(vn)
optic_prefix = getoptic(prefix)
# Special case `identity` to avoid having ComposedFunctions with identity
if optic_vn == identity
new_inner_optic_vn = PropertyLens{sym_vn}()
else
new_inner_optic_vn = optic_vn ∘ PropertyLens{sym_vn}()
end
if optic_prefix == identity
new_optic_vn = new_inner_optic_vn
else
new_optic_vn = new_inner_optic_vn ∘ optic_prefix
end
return VarName{sym_prefix}(new_optic_vn)
end
15 changes: 15 additions & 0 deletions test/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,19 @@ end
# Serialisation should now work
@test string_to_varname(varname_to_string(vn)) == vn
end

@testset "prefix and unprefix" begin
@test prefix(@varname(y), @varname(x)) == @varname(x.y)
@test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y)
@test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y)
@test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1])
@test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a)

@test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1])
@test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y)
@test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y)
@test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a)
@test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n))
@test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1]))
end
end
Loading