Skip to content

Commit

Permalink
Improvements in maps (#676)
Browse files Browse the repository at this point in the history
* start fixing triggers maps

* cleanup

* fix issues with 'contains' in maps and with trigger generation

* backup

* backup

* doc

* Apply suggestions from code review

Co-authored-by: Felix Wolf <60103963+Felalolf@users.noreply.github.com>

* feedback from Felix

---------

Co-authored-by: Felix Wolf <60103963+Felalolf@users.noreply.github.com>
  • Loading branch information
jcp19 and Felalolf authored Sep 22, 2023
1 parent 17f510b commit 8895ec3
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 24 deletions.
1 change: 1 addition & 0 deletions src/main/scala/viper/gobra/frontend/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4306,6 +4306,7 @@ object Desugar extends LazyLogging {
} yield underlyingType(dright.typ) match {
case _: in.SequenceT | _: in.SetT => in.Contains(dleft, dright)(src)
case _: in.MultisetT => in.LessCmp(in.IntLit(0)(src), in.Contains(dleft, dright)(src))(src)
case _: in.MapT => in.Contains(dleft, dright)(src)
case t => violation(s"expected a sequence or (multi)set type, but got $t")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ trait GhostExprTyping extends BaseTyping { this: TypeInfoImpl =>
case PIn(left, right) => isExpr(left).out ++ isExpr(right).out ++ {
underlyingType(exprType(right)) match {
case t : GhostCollectionType => ghostComparableTypes.errors(exprType(left), t.elem)(expr)
case t : MapT => ghostComparableTypes.errors(exprType(left), t.key)(expr)
case _ : AdtT => noMessages
case t => error(right, s"expected a ghost collection, but got $t")
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/viper/gobra/translator/context/Context.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ trait Context {

def expression(x: in.Expr): CodeWriter[vpr.Exp] = typeEncoding.finalExpression(this)(x)

def triggerExpr(x: in.TriggerExpr): CodeWriter[vpr.Exp] = typeEncoding.triggerExpr(this)(x)

def assertion(x: in.Assertion): CodeWriter[vpr.Exp] = typeEncoding.finalAssertion(this)(x)

def invariant(x: in.Assertion): (CodeWriter[Unit], vpr.Exp) = typeEncoding.invariant(this)(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import viper.gobra.translator.encodings.closures.ClosureEncoding
import viper.gobra.translator.encodings.combinators.{DefaultEncoding, FinalTypeEncoding, SafeTypeEncodingCombiner, TypeEncoding}
import viper.gobra.translator.encodings.interfaces.InterfaceEncoding
import viper.gobra.translator.encodings.maps.{MapEncoding, MathematicalMapEncoding}
import viper.gobra.translator.encodings.defaults.{DefaultGlobalVarEncoding, DefaultMethodEncoding, DefaultPredicateEncoding, DefaultPureMethodEncoding}
import viper.gobra.translator.encodings.defaults.{DefaultGlobalVarEncoding, DefaultMethodEncoding, DefaultPredicateEncoding, DefaultPureMethodEncoding, DefaultTriggerExprEncoding}
import viper.gobra.translator.encodings.options.OptionEncoding
import viper.gobra.translator.encodings.preds.PredEncoding
import viper.gobra.translator.encodings.sequences.SequenceEncoding
Expand Down Expand Up @@ -61,6 +61,7 @@ class DfltTranslatorConfig(
val pureMethodEncoding = new DefaultPureMethodEncoding
val predicateEncoding = new DefaultPredicateEncoding
val globalVarEncoding = new DefaultGlobalVarEncoding
val triggerExprEncoding = new DefaultTriggerExprEncoding

val typeEncoding: TypeEncoding = new FinalTypeEncoding(
new SafeTypeEncodingCombiner(Vector(
Expand All @@ -73,7 +74,7 @@ class DfltTranslatorConfig(
new TerminationEncoding, new BuiltInEncoding, new OutlineEncoding, new DeferEncoding,
new GlobalEncoding, new Comments,
), Vector(
methodEncoding, pureMethodEncoding, predicateEncoding, globalVarEncoding
methodEncoding, pureMethodEncoding, predicateEncoding, globalVarEncoding, triggerExprEncoding
))
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class FinalTypeEncoding(te: TypeEncoding) extends TypeEncoding {
override def equal(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = te.equal(ctx) orElse expectedMatch("equal")
override def goEqual(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = te.goEqual(ctx) orElse expectedMatch("equal")
override def expression(ctx: Context): in.Expr ==> CodeWriter[vpr.Exp] = te.expression(ctx) orElse expectedMatch("expression")
override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = te.triggerExpr(ctx) orElse expectedMatch("trigger expression")
override def assertion(ctx: Context): in.Assertion ==> CodeWriter[vpr.Exp] = te.assertion(ctx) orElse expectedMatch("assertion")
override def reference(ctx: Context): in.Location ==> CodeWriter[vpr.Exp] = te.reference(ctx) orElse expectedMatch("reference")
override def addressFootprint(ctx: Context): (in.Location, in.Expr) ==> CodeWriter[vpr.Exp] = te.addressFootprint(ctx) orElse expectedMatch("addressFootprint")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,13 @@ trait TypeEncoding extends Generator {
case in.Conversion(t2, expr :: t) if typ(ctx).isDefinedAt(t) && typ(ctx).isDefinedAt(t2) => ctx.expression(expr)
}

/**
* Encodes expressions when they occur as the top-level expression in a trigger.
* The default implements an encoding for predicate instances and defers the
* encoding of all expressions to the expression encoding.
*/
def triggerExpr(@unused ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = PartialFunction.empty

/**
* Encodes assertions.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ abstract class TypeEncodingCombiner(encodings: Vector[TypeEncoding], defaults: V
override def equal(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = combiner(_.equal(ctx))
override def goEqual(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = combiner(_.goEqual(ctx))
override def expression(ctx: Context): in.Expr ==> CodeWriter[vpr.Exp] = combiner(_.expression(ctx))
override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = combiner(_.triggerExpr(ctx))
override def assertion(ctx: Context): in.Assertion ==> CodeWriter[vpr.Exp] = combiner(_.assertion(ctx))
override def reference(ctx: Context): in.Location ==> CodeWriter[vpr.Exp] = combiner(_.reference(ctx))
override def addressFootprint(ctx: Context): (in.Location, in.Expr) ==> CodeWriter[vpr.Exp] = combiner(_.addressFootprint(ctx))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
// Copyright (c) 2011-2023 ETH Zurich.

package viper.gobra.translator.encodings.defaults

import org.bitbucket.inkytonik.kiama.==>
import viper.gobra.ast.{internal => in}
import viper.gobra.translator.context.Context
import viper.gobra.translator.encodings.combinators.Encoding
import viper.silver.{ast => vpr}

class DefaultTriggerExprEncoding extends Encoding {
import viper.gobra.translator.util.ViperWriter._

override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = {
// use predicate access encoding but then take just the predicate access, i.e. without the access predicate:
case in.Accessible.Predicate(op) =>
for {
v <- ctx.assertion(in.Access(in.Accessible.Predicate(op), in.FullPerm(op.info))(op.info))
pap = v.asInstanceOf[vpr.PredicateAccessPredicate]
} yield pap.loc
case e: in.Expr => ctx.expression(e)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ class MapEncoding extends LeafTypeEncoding {
* Encodes expressions as values that do not occupy some identifiable location in memory.
* R[ nil(map[K]V°) ] -> null
* R[ dflt(map[K]V°) ] -> null
* R[ len(e: map[K]V) ] -> [e] == null? 0 : | getCorrespondingMap([e]) |
* R[ (e: map[K]V)[idx] ] -> [e] == null? [ dflt(V) ] : goMapLookup(e[idx])
* R[ keySet(e: map[K]V) ] -> [e] == null? 0 : MapDomain(getCorrespondingMap(e))
* R[ valueSet(e: map[K]V) ] -> [e] == null? 0 : MapRange(getCorrespondingMap(e))
* R[ len(e: map[K]V) ] -> [ e ] == null? 0 : | getCorrespondingMap([ e ]) |
* R[ (e: map[K]V)[idx] ] -> [ e ] == null? [ dflt(V) ] : goMapLookup(e[idx])
* R[ keySet(e: map[K]V) ] -> [ e ] == null? 0 : MapDomain(getCorrespondingMap([ e ]))
* R[ valueSet(e: map[K]V) ] -> [ e ] == null? 0 : MapRange(getCorrespondingMap([ e ]))
* R[ k in (e: map[K]V) ] -> [ e ] == null? false : MapContains([ k ], getCorrespondingMap([ e ]))
*/
override def expression(ctx: Context): in.Expr ==> CodeWriter[vpr.Exp] = {
def goE(x: in.Expr): CodeWriter[vpr.Exp] = ctx.expression(x)
Expand Down Expand Up @@ -85,6 +86,21 @@ class MapEncoding extends LeafTypeEncoding {

case l@in.IndexedExp(_ :: ctx.Map(_, _), _, _) => for {(res, _) <- goMapLookup(l)(ctx)} yield res

case l@in.Contains(key, exp :: ctx.Map(keys, values)) =>
for {
mapVpr <- goE(exp)
keyVpr <- goE(key)
isComp <- MapEncoding.checkKeyComparability(key)(ctx)
correspondingMap <- getCorrespondingMap(exp, keys, values)(ctx)
containsExp =
withSrc(vpr.CondExp(
withSrc(vpr.EqCmp(mapVpr, withSrc(vpr.NullLit(), l)), l),
withSrc(vpr.FalseLit(), l),
withSrc(vpr.MapContains(keyVpr, correspondingMap), l)
), l)
checkCompAndContains <- assert(isComp, containsExp, comparabilityErrorT)(ctx)
} yield checkCompAndContains

case k@in.MapKeys(mapExp :: ctx.Map(keys, values), _) =>
for {
vprMap <- goE(mapExp)
Expand All @@ -111,6 +127,50 @@ class MapEncoding extends LeafTypeEncoding {
}
}

/**
* Encodes expressions when they occur as the top-level expression in a trigger.
* Notice that using the expression encoding for the following triggers,
* results in ill-formed triggers at the Viper level (e.g., because
* they have ternary operations).
* { m[i] } -> { getCorrespondingMap([ m ])[ [ i ] ] }
* { k in m } -> { [ k ] in getCorrespondingMap([ m ]) }
* { k in domain(m) } -> { [ k ] in domain(getCorrespondingMap([ m ])) }
* { k in range(m) } -> { [ k ] in range(getCorrespondingMap([ m ])) }
*/
override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = {
default(super.triggerExpr(ctx)) {
case l@in.IndexedExp(m :: ctx.Map(keys, values), idx, _) =>
for {
vIdx <- ctx.expression(idx)
correspondingMap <- getCorrespondingMap(m, keys, values)(ctx)
lookupRes = withSrc(vpr.MapLookup(correspondingMap, vIdx), l)
} yield lookupRes

case l@in.Contains(key, m :: ctx.Map(keys, values)) =>
for {
vKey <- ctx.expression(key)
correspondingMap <- getCorrespondingMap(m, keys, values)(ctx)
contains = withSrc(vpr.MapContains(vKey, correspondingMap), l)
} yield contains

case l@in.Contains(key, in.MapKeys(m :: ctx.Map(keys, values), _)) =>
for {
vKey <- ctx.expression(key)
correspondingMap <- getCorrespondingMap(m, keys, values)(ctx)
vDomainMap = withSrc(vpr.MapDomain(correspondingMap), l)
contains = withSrc(vpr.AnySetContains(vKey, vDomainMap), l)
} yield contains

case l@in.Contains(key, in.MapValues(m :: ctx.Map(keys, values), _)) =>
for {
vKey <- ctx.expression(key)
correspondingMap <- getCorrespondingMap(m, keys, values)(ctx)
vRangeMap = withSrc(vpr.MapRange(correspondingMap), l)
contains = withSrc(vpr.AnySetContains(vKey, vRangeMap), l)
} yield contains
}
}

/**
* Encodes the allocation of a new map
* [r := make(map[T1]T2, n)] ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,10 @@ class AssertionEncoding extends Encoding {

def trigger(trigger: in.Trigger)(ctx: Context) : CodeWriter[vpr.Trigger] = {
val (pos, info, errT) = trigger.vprMeta
for { expr <- sequence(trigger.exprs map (triggerExpr(_)(ctx))) }
for { expr <- sequence(trigger.exprs map ctx.triggerExpr)}
yield vpr.Trigger(expr)(pos, info, errT)
}

def triggerExpr(expr: in.TriggerExpr)(ctx: Context): CodeWriter[vpr.Exp] = expr match {
// use predicate access encoding but then take just the predicate access, i.e. remove `acc` and the permission amount:
case in.Accessible.Predicate(op) =>
for {
v <- ctx.assertion(in.Access(in.Accessible.Predicate(op), in.FullPerm(op.info))(op.info))
pap = v.asInstanceOf[vpr.PredicateAccessPredicate]
} yield pap.loc
case e: in.Expr => ctx.expression(e)
}

def quantifier(vars: Vector[in.BoundVar], triggers: Vector[in.Trigger], body: in.Expr)(ctx: Context) : CodeWriter[(Seq[vpr.LocalVarDecl], Seq[vpr.Trigger], vpr.Exp)] = {
val newVars = vars map ctx.variable

Expand Down
9 changes: 7 additions & 2 deletions src/main/scala/viper/gobra/translator/util/ViperWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,13 @@ object ViperWriter {
/* Can be used in expressions. */
def assert(cond: vpr.Exp, exp: vpr.Exp, reasonT: (Source.Verifier.Info, ErrorReason) => VerificationError)(ctx: Context): Writer[vpr.Exp] = {
// In the future, this might do something more sophisticated
val (res, errT) = ctx.condition.assert(cond, exp, reasonT)
errorT(errT).map(_ => res)
cond match {
case vpr.TrueLit() =>
unit(exp)
case _ =>
val (res, errT) = ctx.condition.assert(cond, exp, reasonT)
errorT(errT).map(_ => res)
}
}

/* Emits Viper statements. */
Expand Down
25 changes: 20 additions & 5 deletions src/test/resources/regressions/features/maps/maps-simple1.gobra
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func test6() {
m[3] = 10
v3, ok3 := m[3]
assert ok3 && v3 == 10

// check if key exists in the map
assert 3 in m
}

type T struct {
Expand Down Expand Up @@ -110,24 +113,24 @@ func test11() {
requires acc(m, _)
requires "key" in domain(m)
func test12(m map[string]string) (r string){
return m["key"]
return m["key"]
}

requires acc(m, _)
requires "value" in range(m)
func test13(m map[string]string) {
assert exists k string :: m[k] == "value"
assert exists k string :: m[k] == "value"
}

func test14() (res map[int]int) {
x := 1
y := 2
m := map[int]int{x: y, y: x}
m := map[int]int{x: y, y: x}
return m
}

func test15() (res map[int]int) {
m := map[int]int{C1: C2, C2: C1}
m := map[int]int{C1: C2, C2: C1}
return m
}

Expand All @@ -137,4 +140,16 @@ func test16() {
assert x == 0
x, contained := m[2]
assert x == 0 && !contained
}
}

requires m != nil ==> acc(m)
requires forall s string :: { s in domain(m) } s in domain(m) ==> acc(m[s])
func test17(m map[string]*int) {}

requires m != nil ==> acc(m)
requires forall s string :: { s in m } s in m ==> acc(m[s])
func test18(m map[string]*int) {}

requires m != nil ==> acc(m)
requires forall i int :: { i in range(m) } i in range(m) ==> 0 < i
func test19(m map[string]int) {}

0 comments on commit 8895ec3

Please sign in to comment.