@@ -497,7 +497,7 @@ class Traverser(
497
497
val declaringClass = field.declaringClass
498
498
499
499
val updates = if (declaringClass.isEnum) {
500
- makeConcreteUpdatesForEnums (fieldId, declaringClass, stmt)
500
+ makeConcreteUpdatesForEnumsWithStmt (fieldId, declaringClass, stmt)
501
501
} else {
502
502
makeConcreteUpdatesForNonEnumStaticField(field, fieldId, declaringClass, stmt)
503
503
}
@@ -518,13 +518,10 @@ class Traverser(
518
518
return true
519
519
}
520
520
521
- @Suppress(" UnnecessaryVariable" )
522
- private fun makeConcreteUpdatesForEnums (
523
- fieldId : FieldId ,
524
- declaringClass : SootClass ,
525
- stmt : Stmt
526
- ): SymbolicStateUpdate {
527
- val type = declaringClass.type
521
+ private fun makeConcreteUpdatesForEnum (
522
+ type : RefType ,
523
+ fieldId : FieldId ? = null
524
+ ): Pair <SymbolicStateUpdate , SymbolicValue ?> {
528
525
val jClass = type.id.jClass
529
526
530
527
// symbolic value for enum class itself
@@ -545,7 +542,7 @@ class Traverser(
545
542
546
543
val (staticFieldUpdates, curFieldSymbolicValueForLocalVariable) = makeEnumStaticFieldsUpdates(
547
544
staticFields,
548
- declaringClass ,
545
+ type.sootClass ,
549
546
enumConstantSymbolicResultsByName,
550
547
enumConstantSymbolicValues,
551
548
enumClassValue,
@@ -564,14 +561,25 @@ class Traverser(
564
561
565
562
val initializedStaticFieldsMemoryUpdate = MemoryUpdate (
566
563
initializedStaticFields = staticFields.associate { it.first.fieldId to it.second.single() }.toPersistentMap(),
567
- meaningfulStaticFields = meaningfulStaticFields.map { it.first.fieldId }.toPersistentSet()
564
+ meaningfulStaticFields = meaningfulStaticFields.map { it.first.fieldId }.toPersistentSet(),
565
+ symbolicEnumValues = enumConstantSymbolicValues.toPersistentList()
568
566
)
569
567
570
- val allUpdates = staticFieldUpdates +
571
- nonStaticFieldsUpdates +
572
- initializedStaticFieldsMemoryUpdate +
573
- createConcreteLocalValueUpdate(stmt, curFieldSymbolicValueForLocalVariable)
568
+ return Pair (
569
+ staticFieldUpdates + nonStaticFieldsUpdates + initializedStaticFieldsMemoryUpdate,
570
+ curFieldSymbolicValueForLocalVariable
571
+ )
572
+ }
574
573
574
+ @Suppress(" UnnecessaryVariable" )
575
+ private fun makeConcreteUpdatesForEnumsWithStmt (
576
+ fieldId : FieldId ,
577
+ declaringClass : SootClass ,
578
+ stmt : Stmt
579
+ ): SymbolicStateUpdate {
580
+ val (enumUpdates, curFieldSymbolicValueForLocalVariable) =
581
+ makeConcreteUpdatesForEnum(declaringClass.type, fieldId)
582
+ val allUpdates = enumUpdates + createConcreteLocalValueUpdate(stmt, curFieldSymbolicValueForLocalVariable)
575
583
return allUpdates
576
584
}
577
585
@@ -1373,6 +1381,39 @@ class Traverser(
1373
1381
queuedSymbolicStateUpdates = queuedSymbolicStateUpdates.copy(memoryUpdates = MemoryUpdate ())
1374
1382
}
1375
1383
1384
+ /* *
1385
+ * Return a symbolic value of the ordinal corresponding to the enum value with the given address.
1386
+ */
1387
+ private fun findEnumOrdinal (type : RefType , addr : UtAddrExpression ): PrimitiveValue {
1388
+ val array = memory.findArray(MemoryChunkDescriptor (ENUM_ORDINAL , type, IntType .v()))
1389
+ return array.select(addr).toIntValue()
1390
+ }
1391
+
1392
+ /* *
1393
+ * Initialize enum class: create symbolic values for static enum values and generate constraints
1394
+ * that restrict the new instance to match one of enum values.
1395
+ */
1396
+ private fun initEnum (type : RefType , addr : UtAddrExpression , ordinal : PrimitiveValue ) {
1397
+ val classId = type.id
1398
+ var predefinedEnumValues = memory.getSymbolicEnumValues(classId)
1399
+ if (predefinedEnumValues.isEmpty()) {
1400
+ val (enumValuesUpdate, _) = makeConcreteUpdatesForEnum(type)
1401
+ queuedSymbolicStateUpdates + = enumValuesUpdate
1402
+ predefinedEnumValues = enumValuesUpdate.memoryUpdates.getSymbolicEnumValues(classId)
1403
+ }
1404
+
1405
+ val enumValueConstraints = mkOr(
1406
+ listOf (addrEq(addr, nullObjectAddr)) + predefinedEnumValues.map {
1407
+ mkAnd(
1408
+ addrEq(addr, it.addr),
1409
+ mkEq(ordinal, findEnumOrdinal(it.type, it.addr))
1410
+ )
1411
+ }
1412
+ )
1413
+
1414
+ queuedSymbolicStateUpdates + = mkOr(enumValueConstraints).asHardConstraint()
1415
+ }
1416
+
1376
1417
private fun arrayInstanceOf (value : ArrayValue , checkType : Type ): PrimitiveValue {
1377
1418
val notNullConstraint = mkNot(addrEq(value.addr, nullObjectAddr))
1378
1419
@@ -1530,6 +1571,11 @@ class Traverser(
1530
1571
queuedSymbolicStateUpdates + = typeConstraint.asHardConstraint()
1531
1572
queuedSymbolicStateUpdates + = typeRegistry.zeroDimensionConstraint(objectValue.addr).asHardConstraint()
1532
1573
1574
+ // If we are casting to an enum class, we should initialize enum values and add value equality constraints
1575
+ if (typeAfterCast.sootClass?.isEnum == true ) {
1576
+ initEnum(typeAfterCast, castedObject.addr, findEnumOrdinal(typeAfterCast, castedObject.addr))
1577
+ }
1578
+
1533
1579
// TODO add memory constraints JIRA:1523
1534
1580
return castedObject
1535
1581
}
@@ -1978,13 +2024,13 @@ class Traverser(
1978
2024
1979
2025
queuedSymbolicStateUpdates + = typeRegistry.typeConstraint(addr, typeStorage).all().asHardConstraint()
1980
2026
1981
- val array = memory.findArray(MemoryChunkDescriptor (ENUM_ORDINAL , type, IntType .v()))
1982
- val ordinal = array.select(addr).toIntValue()
2027
+ val ordinal = findEnumOrdinal(type, addr)
1983
2028
val enumSize = classLoader.loadClass(type.sootClass.name).enumConstants.size
1984
2029
1985
2030
queuedSymbolicStateUpdates + = mkOr(Ge (ordinal, 0 ), addrEq(addr, nullObjectAddr)).asHardConstraint()
1986
2031
queuedSymbolicStateUpdates + = mkOr(Lt (ordinal, enumSize), addrEq(addr, nullObjectAddr)).asHardConstraint()
1987
2032
2033
+ initEnum(type, addr, ordinal)
1988
2034
touchAddress(addr)
1989
2035
1990
2036
return ObjectValue (typeStorage, addr)
0 commit comments