Skip to content

Commit

Permalink
Deprecate the misk.crypto.CiphertextFormat class (#2232)
Browse files Browse the repository at this point in the history
* Deprecate the misk.crypto.CiphertextFormat class

* deprecated in a more sensible way

* removed serializeEncryptionContext from the .api file
  • Loading branch information
yoavamit authored Dec 9, 2021
1 parent c3e62c6 commit a68c546
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 185 deletions.
1 change: 0 additions & 1 deletion misk-crypto/api/misk-crypto.api
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ public final class misk/crypto/CiphertextFormat$Companion {
public final fun deserialize ([BLjava/util/Map;)Lkotlin/Pair;
public final fun deserializeFleFormat ([B)Lkotlin/Pair;
public final fun serialize ([B[B)[B
public final fun serializeEncryptionContext (Ljava/util/Map;)[B
}

public final class misk/crypto/CiphertextFormat$EncryptionContextMismatchException : java/security/GeneralSecurityException {
Expand Down
37 changes: 15 additions & 22 deletions misk-crypto/src/main/kotlin/misk/crypto/CiphertextFormat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ class CiphertextFormat private constructor() {
/**
* Serializes the given [ciphertext] and associated encryption context to a [ByteArray]
*/
@Deprecated(
message = "This function has been moved to its own library. " +
"See https://github.com/squareup/cash-ciphertext-format",
level = DeprecationLevel.WARNING,
replaceWith = ReplaceWith("CiphertextFormat.serialize",
"com.squareup.cash.crypto.format.CiphertextFormat")
)
fun serialize(ciphertext: ByteArray, aad: ByteArray?): ByteArray {
val outputStream = ByteStreams.newDataOutput()
outputStream.writeByte(CURRENT_VERSION)
Expand All @@ -76,6 +83,13 @@ class CiphertextFormat private constructor() {
* This method also compares the given [context] to the serialized AAD
* and will throw an exception if they do not match.
*/
@Deprecated(
message = "This function has been moved to its own library. " +
"See https://github.com/squareup/cash-ciphertext-format",
level = DeprecationLevel.WARNING,
replaceWith = ReplaceWith("CiphertextFormat.deserialize",
"com.squareup.cash.crypto.format.CiphertextFormat")
)
fun deserialize(
serialized: ByteArray,
context: Map<String, String>?
Expand Down Expand Up @@ -109,7 +123,7 @@ class CiphertextFormat private constructor() {
* Serializes the encryption context to a [ByteArray] so it could be passed to Tink's
* encryption/decryption methods.
*/
fun serializeEncryptionContext(context: Map<String, String>?): ByteArray? {
private fun serializeEncryptionContext(context: Map<String, String>?): ByteArray? {
if (context == null || context.isEmpty()) {
return null
}
Expand Down Expand Up @@ -149,27 +163,6 @@ class CiphertextFormat private constructor() {
return aad
}

@VisibleForTesting
internal fun deserializeEncryptionContext(aad: ByteArray?): Map<String, String>? {
if (aad == null) {
return null
}
val src = DataInputStream(ByteArrayInputStream(aad))
val entries = decodeVarInt(src)
if (entries == 0) {
return null
}
return (1..entries).map {
val keySize = decodeVarInt(src)
val keyBytes = ByteArray(keySize)
src.readFully(keyBytes)
val valueSize = decodeVarInt(src)
val valueBytes = ByteArray(valueSize)
src.readFully(valueBytes)
keyBytes.toString(Charsets.UTF_8) to valueBytes.toString(Charsets.UTF_8)
}.toMap()
}

private fun readCiphertext(src: DataInputStream): ByteArray {
val ciphertextStream = ByteArrayOutputStream()
var readByte = src.read()
Expand Down
162 changes: 0 additions & 162 deletions misk-crypto/src/test/kotlin/misk/crypto/CiphertextFormatTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,171 +13,9 @@ import java.util.UUID
class CiphertextFormatTest {

companion object {
private const val VERSION_INDEX = 0
private const val EC_LENGTH_INDEX = 1
private val fauxCiphertext = byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 0)
}

@Test
fun testBasicEncryptionContextSerialization() {
val context = mapOf(
"table_name" to "unimportant",
"database_name" to "unimportant",
"key" to "value"
)
val serialized = CiphertextFormat.serializeEncryptionContext(context)
assertThat(CiphertextFormat.deserializeEncryptionContext(serialized))
.isNotNull
.isEqualTo(context)
}

@Test
fun testEncryptionContextSerializationSortsKeys() {
val context = mapOf(
"table_name" to "unimportant",
"database_name" to "unimportant",
"key" to "value"
)
val context2 = linkedMapOf(
"key" to "value",
"database_name" to "unimportant",
"table_name" to "unimportant"
)

assertThat(CiphertextFormat.serializeEncryptionContext(context))
.isEqualTo(CiphertextFormat.serializeEncryptionContext(context2))
}

@Test
fun testEncryptionContextSerializationWithVarInts() {
val context = mutableMapOf<String, String>()
(0..300).forEach { context["$it"] = UUID.randomUUID().toString() }
val serialized = CiphertextFormat.serializeEncryptionContext(context)
assertThat(CiphertextFormat.deserializeEncryptionContext(serialized))
.isNotNull
.isEqualTo(context)
}

@Test
fun testEmptyEncryptionContext() {
assertThat(CiphertextFormat.serializeEncryptionContext(mapOf())).isNull()
}

@Test
fun testNullEncryptionContext() {
assertThat(CiphertextFormat.serializeEncryptionContext(null)).isNull()
}

@Test
fun testEncryptionContextValueTooLong() {
val value = (0..Short.MAX_VALUE).joinToString("") { "a" }
val context = mapOf("key" to value)
assertThatThrownBy { CiphertextFormat.serializeEncryptionContext(context) }
.hasMessage("value is too long")
}

@Test
fun testEncryptionContextKeyTooLong() {
val key = (0..Short.MAX_VALUE).joinToString("") { "a" }
val context = mapOf(key to "value")
assertThatThrownBy { CiphertextFormat.serializeEncryptionContext(context) }
.hasMessage("key is too long")
}

@Test
fun testEncryptionContextTooLong() {
val key = (100..Short.MAX_VALUE).joinToString("") { "a" }
val value = (100..Short.MAX_VALUE).joinToString("") { "a" }
assertThatThrownBy { CiphertextFormat.serializeEncryptionContext(mapOf(key to value)) }
.hasMessage("encryption context is too long")
}

@Test
fun testEncryptionContextWithForbiddenCharacters() {
var context = mapOf("" to "value")
assertThatThrownBy { CiphertextFormat.serializeEncryptionContext(context) }
.hasMessage("empty key or value")
context = mapOf("key" to "")
assertThatThrownBy { CiphertextFormat.serializeEncryptionContext(context) }
.hasMessage("empty key or value")
}

@Test
fun testFromByteArrayWithNoContext() {
val aad = CiphertextFormat.serializeEncryptionContext(null)
val serialized = CiphertextFormat.serialize(fauxCiphertext, aad)
assertThatCode { CiphertextFormat.deserialize(serialized, null) }
.doesNotThrowAnyException()
assertThatCode { CiphertextFormat.deserialize(serialized, mapOf()) }
.doesNotThrowAnyException()
assertThatThrownBy { CiphertextFormat.deserialize(serialized, mapOf("key" to "value")) }
.isInstanceOf(CiphertextFormat.UnexpectedEncryptionContextException::class.java)
}

@Test
fun testFromByteArrayWithContext() {
val context = mapOf("key" to "value")
val aad = CiphertextFormat.serializeEncryptionContext(context)
val serialized = CiphertextFormat.serialize(fauxCiphertext, aad)
assertThatCode { CiphertextFormat.deserialize(serialized, context) }
.doesNotThrowAnyException()
assertThatThrownBy {
CiphertextFormat.deserialize(
serialized,
mapOf("wrong_key" to "wrong_value")
)
}
.isInstanceOf(CiphertextFormat.EncryptionContextMismatchException::class.java)
assertThatThrownBy { CiphertextFormat.deserialize(serialized, null) }
.isInstanceOf(CiphertextFormat.MissingEncryptionContextException::class.java)
assertThatThrownBy { CiphertextFormat.deserialize(serialized, emptyMap()) }
.isInstanceOf(CiphertextFormat.MissingEncryptionContextException::class.java)
}

@Test
fun testFromByteArrayWithLongContext() {
val context = mutableMapOf<String, String>()
(0..300).forEach { context["$it"] = UUID.randomUUID().toString() }
val aad = CiphertextFormat.serializeEncryptionContext(context)
val serialized = CiphertextFormat.serialize(fauxCiphertext, aad)
val (ciphertext, ciphertextAad) = CiphertextFormat.deserialize(serialized, context)
assertThat(ciphertext).isEqualTo(fauxCiphertext)
assertThat(ciphertextAad).isEqualTo(aad)
}

@Test
fun testFromByteArrayWithEmptyContext() {
val context = mapOf<String, String>()
val aad = CiphertextFormat.serializeEncryptionContext(context)
val serialized = CiphertextFormat.serialize(fauxCiphertext, aad)
assertThatCode { CiphertextFormat.deserialize(serialized, context) }
.doesNotThrowAnyException()
assertThatCode { CiphertextFormat.deserialize(serialized, null) }
.doesNotThrowAnyException()
assertThatThrownBy { CiphertextFormat.deserialize(serialized, mapOf("key" to "value")) }
.isInstanceOf(CiphertextFormat.UnexpectedEncryptionContextException::class.java)
}

@Test
fun testUnsupportedSchemaVersion() {
val context = mapOf("key" to "value")
val aad = CiphertextFormat.serializeEncryptionContext(context)
val serialized = CiphertextFormat.serialize(fauxCiphertext, aad)
serialized[VERSION_INDEX] = 3
assertThatThrownBy { CiphertextFormat.deserialize(serialized, context) }
.hasMessage("invalid version: 3")
}

@Test
fun testWrongEncryptionContextSize() {
val context = mapOf("key" to "value", "key2" to "value2")
val aad = CiphertextFormat.serializeEncryptionContext(context)
val serialized = CiphertextFormat.serialize(fauxCiphertext, aad)
serialized[EC_LENGTH_INDEX + 1] = 1
assertThatThrownBy { CiphertextFormat.deserialize(serialized, context) }
.hasMessage("encryption context doesn't match")
}

@Test
fun testFromByteArrayV1() {
val context = mapOf("key" to "value")
Expand Down

0 comments on commit a68c546

Please sign in to comment.