Skip to content

Commit

Permalink
Fix #1529
Browse files Browse the repository at this point in the history
  • Loading branch information
mario-bucev committed May 17, 2024
1 parent ace5cad commit 7557fa5
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 16 deletions.
10 changes: 5 additions & 5 deletions core/src/main/scala/stainless/ast/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,23 +251,23 @@ trait Expressions extends inox.ast.Expressions with Types { self: Trees =>

/** $encodingof `array(index)` */
sealed case class ArraySelect(array: Expr, index: Expr) extends Expr with CachingTyped {
override protected def computeType(using s: Symbols): Type = (array.getType, index.getType) match {
case (ArrayType(base), Int32Type()) => base
override protected def computeType(using s: Symbols): Type = getArrayType(array) match {
case ArrayType(base) => checkParamType(index, Int32Type(), base)
case _ => Untyped
}
}

/** $encodingof `array.updated(index, value)` */
sealed case class ArrayUpdated(array: Expr, index: Expr, value: Expr) extends Expr with CachingTyped {
override protected def computeType(using s: Symbols): Type = (array.getType, index.getType) match {
case (ArrayType(base), Int32Type()) => unveilUntyped(ArrayType(s.leastUpperBound(base, value.getType)))
override protected def computeType(using s: Symbols): Type = getArrayType(array) match {
case at @ ArrayType(base) => checkParamTypes(Seq(index, value), Seq(Int32Type(), base), getArrayType(at, ArrayType(s.leastUpperBound(base, value.getType))))
case _ => Untyped
}
}

/** $encodingof `array.length` */
sealed case class ArrayLength(array: Expr) extends Expr with CachingTyped {
override protected def computeType(using s: Symbols): Type = array.getType match {
override protected def computeType(using s: Symbols): Type = getArrayType(array) match {
case ArrayType(_) => Int32Type()
case _ => Untyped
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scala/stainless/ast/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@ trait Types extends inox.ast.Types { self: Trees =>

sealed case class ArrayType(base: Type) extends Type

protected def getArrayType(tpe: Typed, tpes: Typed*)(using Symbols): Type = tpe.getType match {
case at: ArrayType => checkAllTypes(tpes, at, at)
case _ => Untyped
}
}
25 changes: 15 additions & 10 deletions core/src/main/scala/stainless/extraction/imperative/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ trait Trees extends oo.Trees with Definitions { self =>
val cellClassDef = s.lookup.get[ClassDef]("stainless.lang.Cell")
(cell1.getType, cell2.getType) match {
case (ClassType(id1, tps1), ClassType(id2, tps2)) if cellClassDef.isDefined && id1 == cellClassDef.get.id && id1 == id2 && tps1 == tps2 => {
UnitType()
UnitType()
}
case _ =>
Untyped
Expand Down Expand Up @@ -96,8 +96,8 @@ trait Trees extends oo.Trees with Definitions { self =>

/** $encodingof `array(index) = value` */
case class ArrayUpdate(array: Expr, index: Expr, value: Expr) extends Expr with CachingTyped {
protected def computeType(using Symbols): Type = array.getType match {
case ArrayType(base) => checkParamTypes(Seq(index, value), Seq(Int32Type(), base), UnitType())
protected def computeType(using Symbols): Type = getArrayType(array) match {
case at @ ArrayType(base) => checkParamTypes(Seq(index, value), Seq(Int32Type(), base), UnitType())
case _ => Untyped
}
}
Expand All @@ -110,34 +110,34 @@ trait Trees extends oo.Trees with Definitions { self =>
/** $encodingof `MutableMap.withDefaultValue[From,To](default)` */
sealed case class MutableMapWithDefault(from: Type, to: Type, default: Expr) extends Expr with CachingTyped {
override protected def computeType(using Symbols): Type = {
checkParamType(default, FunctionType(Seq(), to), unveilUntyped(MutableMapType(from, to)))
checkParamType(default, FunctionType(Seq(), to), getMutableMapType(MutableMapType(from, to)))
}
}

/** $encodingof `map.apply(key)` (or `map(key)`) */
sealed case class MutableMapApply(map: Expr, key: Expr) extends Expr with CachingTyped {
override protected def computeType(using Symbols): Type = map.getType match {
override protected def computeType(using Symbols): Type = getMutableMapType(map) match {
case MutableMapType(from, to) => checkParamType(key, from, to)
case _ => Untyped
}
}

/** $encodingof `map.updated(key, value)` (or `map + (key -> value)`) */
sealed case class MutableMapUpdated(map: Expr, key: Expr, value: Expr) extends Expr with CachingTyped {
override protected def computeType(using Symbols): Type = map.getType match {
case mmt @ MutableMapType(from, to) => checkParamTypes(Seq(key, value), Seq(from, to), MutableMapType(from, to))
override protected def computeType(using Symbols): Type = getMutableMapType(map) match {
case mmt @ MutableMapType(from, to) => checkParamType(key, from, getMutableMapType(mmt, MutableMapType(from, value.getType)))
case _ => Untyped
}
}

/** $encodingof `map.duplicate()` */
sealed case class MutableMapDuplicate(map: Expr) extends Expr with CachingTyped {
override protected def computeType(using Symbols): Type = map.getType
override protected def computeType(using Symbols): Type = getMutableMapType(map)
}

/** $encodingof `map.update(key, value)` (or `map(key) = value`) */
sealed case class MutableMapUpdate(map: Expr, key: Expr, value: Expr) extends Expr with CachingTyped {
override protected def computeType(using Symbols): Type = map.getType match {
override protected def computeType(using Symbols): Type = getMutableMapType(map) match {
case mmt @ MutableMapType(from, to) => checkParamTypes(Seq(key, value), Seq(from, to), UnitType())
case _ => Untyped
}
Expand Down Expand Up @@ -240,6 +240,11 @@ trait Trees extends oo.Trees with Definitions { self =>
new ExprOpsImpl(self)
}

protected def getMutableMapType(tpe: Typed, tpes: Typed*)(using s: Symbols): Type =
widenTypeParameter(s.leastUpperBound(tpe +: tpes map (_.getType))) match {
case mt: MutableMapType => mt
case _ => Untyped
}

/* ========================================
* EXTRACTORS
Expand Down Expand Up @@ -293,7 +298,7 @@ trait Printer extends oo.Printer {
case Swap(array1, index1, array2, index2) =>
p"swap($array1, $index1, $array2, $index2)"

case CellSwap(cell1, cell2) =>
case CellSwap(cell1, cell2) =>
p"swap($cell1, $cell2)"

case LetVar(vd, value, expr) =>
Expand Down
6 changes: 5 additions & 1 deletion core/src/main/scala/stainless/extraction/oo/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package stainless
package extraction
package oo

import scala.collection.mutable.{Map => MutableMap}
import scala.collection.mutable.Map as MutableMap

trait Trees extends innerfuns.Trees with Definitions { self =>

Expand Down Expand Up @@ -192,6 +192,10 @@ trait Trees extends innerfuns.Trees with Definitions { self =>
case _ => Untyped
}

override protected def getArrayType(tpe: Typed, tpes: Typed*)(using Symbols): Type =
super.getArrayType(widenTypeParameter(tpe), tpes: _*)


/* ========================================
* EXTRACTORS
* ======================================== */
Expand Down
12 changes: 12 additions & 0 deletions frontends/benchmarks/imperative/valid/i1529a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
object i1529a {
type ByteArray = Array[Byte]

case class ByteArrayWrapper(ba: ByteArray)

def test(baw: ByteArrayWrapper): Unit = {
require(baw.ba.length == 64)
baw.ba(0) = 3
val ba2 = baw.ba.updated(0, 4.toByte)
assert(ba2(0) == 4)
}
}
16 changes: 16 additions & 0 deletions frontends/benchmarks/imperative/valid/i1529b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import stainless.lang.*

object i1529b {
case class Key(i: BigInt)
case class Value(i: BigInt)
type MutableMapAlias = MutableMap[Key, Value]

case class MutableMapWrapper(mm: MutableMapAlias)

def test(mmw: MutableMapWrapper): Unit = {
require(mmw.mm(Key(42)) == Value(24))
mmw.mm(Key(2)) = Value(4)
val mmw2 = MutableMapWrapper(MutableMap.withDefaultValue(() => Value(123)))
assert(mmw2.mm(Key(1)) == Value(123))
}
}
13 changes: 13 additions & 0 deletions frontends/benchmarks/verification/valid/SetTypeAlias.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import stainless.lang.*

object SetTypeAlias {
type SetAlias = Set[Byte]

case class SetAliasWrapper(sa: SetAlias)

def test(saw: SetAliasWrapper, six: SetAliasWrapper): Unit = {
require(saw.sa.contains(42))
val saucisse = saw.sa ++ six.sa
assert(saucisse.contains(42))
}
}

0 comments on commit 7557fa5

Please sign in to comment.