diff --git a/Makefile b/Makefile index d60d307e6..86d51f9fe 100644 --- a/Makefile +++ b/Makefile @@ -19,14 +19,19 @@ SRCS = \ src/SDK/AlgorithmSuite.dfy \ src/SDK/CMM/DefaultCMM.dfy \ src/SDK/CMM/Defs.dfy \ + src/SDK/Deserialize.dfy \ src/SDK/Keyring/AESKeyring.dfy \ src/SDK/Keyring/Defs.dfy \ src/SDK/Keyring/RSAKeyring.dfy \ src/SDK/Materials.dfy \ + src/SDK/MessageHeader.dfy \ + src/SDK/Serialize.dfy \ src/SDK/ToyClient.dfy \ src/StandardLibrary/Base64.dfy \ src/StandardLibrary/StandardLibrary.dfy \ src/StandardLibrary/UInt.dfy \ + src/Util/Streams.dfy \ + src/Util/UTF8.dfy \ SRCV = $(patsubst src/%.dfy, build/%.dfy.verified, $(SRCS)) @@ -37,7 +42,9 @@ SRCDIRS = $(dir $(SRCS)) DEPS = $(foreach dir, $(SRCDIRS), $(wildcard $(dir)/*.cs)) \ $(BCDLL) -.PHONY: all hkdf test noverif clean-build clean +DEPS_CS = $(foreach dir, src/Crypto/ src/ src/SDK/ src/SDK/CMM src/SDK/Keyring src/StandardLibrary src/Util, $(wildcard $(dir)/*.cs)) + +.PHONY: all release build verify buildcs hkdf test clean-build clean all: verify build test @@ -53,6 +60,9 @@ build/%.dfy.verified: src/%.dfy build/Main.exe: $(SRCS) $(DEPS) $(DAFNY) /out:build/Main $(SRCS) $(DEPS) /compile:2 /noVerify /noIncludes && cp $(BCDLL) build/ +buildcs: build/Main.cs + csc /r:System.Numerics.dll /r:$(BCDLL) /target:exe /debug /nowarn:0164 /nowarn:0219 /nowarn:1717 /nowarn:0162 /nowarn:0168 build/Main.cs $(DEPS_CS) /out:build/Main.exe + # TODO: HKDF.dfy hasn't been reviewed yet. # Once it is, re-add: # @@ -68,9 +78,6 @@ lib/%.dll: test: $(DEPS) lit test -q -v -noverif: $(DEPS) - $(DAFNY) /out:build/Main $(SRCS) $(DEPS) /compile:2 /noVerify /noIncludes && cp $(BCDLL) build/ - clean-build: $(RM) -r build/* diff --git a/src/SDK/AlgorithmSuite.dfy b/src/SDK/AlgorithmSuite.dfy index 4aeff1559..5c83f23eb 100644 --- a/src/SDK/AlgorithmSuite.dfy +++ b/src/SDK/AlgorithmSuite.dfy @@ -10,14 +10,22 @@ module AlgorithmSuite { import C = Cipher import Digests - const validIDs: set := {0x0378, 0x0346, 0x0214, 0x0178, 0x0146, 0x0114, 0x0078, 0x0046, 0x0014}; + const VALID_IDS: set := {0x0378, 0x0346, 0x0214, 0x0178, 0x0146, 0x0114, 0x0078, 0x0046, 0x0014}; - newtype ID = x | x in validIDs witness 0x0014 + newtype ID = x | x in VALID_IDS witness 0x0014 { function method KeyLength(): nat { Suite[this].params.keyLen as nat } + function method IVLength(): nat { + Suite[this].params.ivLen as nat + } + + function method TagLength(): nat { + Suite[this].params.tagLen as nat + } + function method SignatureType(): Option { Suite[this].sign } @@ -46,4 +54,26 @@ module AlgorithmSuite { AES_192_GCM_IV12_TAG16_KDFNONE_SIGNONE := AlgSuite(C.AES_GCM_192, Digests.HmacNOSHA, None), AES_128_GCM_IV12_TAG16_KDFNONE_SIGNONE := AlgSuite(C.AES_GCM_128, Digests.HmacNOSHA, None) ] + + /* Suite is intended to have an entry for each possible value of ID. This is stated and checked in three ways. + * - lemma SuiteIsCompletes states and proves the connection between type ID and Suite.Keys + * - lemma ValidIDsAreSuiteKeys states and proves the connected between predicate ValidIDs and Suite.Keys + * - the member functions of ID use the expression `Suite[this]`, whose well-formedness relies on every + * ID being in Suite.Keys + */ + + lemma SuiteIsComplete(id: ID) + ensures id in Suite.Keys + { + } + + lemma ValidIDsAreSuiteKeys() + ensures VALID_IDS == set id | id in Suite.Keys :: id as uint16 + { + forall x | x in VALID_IDS + ensures exists id :: id in Suite.Keys && id as uint16 == x + { + assert x as ID in Suite.Keys; + } + } } diff --git a/src/SDK/CMM/DefaultCMM.dfy b/src/SDK/CMM/DefaultCMM.dfy index 52b3123b3..8956ba939 100644 --- a/src/SDK/CMM/DefaultCMM.dfy +++ b/src/SDK/CMM/DefaultCMM.dfy @@ -70,7 +70,7 @@ module DefaultCMMDef { { return Failure("Could not retrieve materials required for encryption"); } - res := Success(em); + return Success(em); } method DecryptMaterials(alg_id: AlgorithmSuite.ID, edks: seq, enc_ctx: Materials.EncryptionContext) returns (res: Result) @@ -104,7 +104,7 @@ module DefaultCMMDef { return Failure("Could not get materials required for decryption."); } - res := Success(dm); + return Success(dm); } } } diff --git a/src/SDK/Deserialize.dfy b/src/SDK/Deserialize.dfy new file mode 100644 index 000000000..3c74f6ed7 --- /dev/null +++ b/src/SDK/Deserialize.dfy @@ -0,0 +1,343 @@ +include "MessageHeader.dfy" +include "Materials.dfy" +include "AlgorithmSuite.dfy" + +include "../Util/Streams.dfy" +include "../StandardLibrary/StandardLibrary.dfy" +include "../Util/UTF8.dfy" + +/* + * The message header deserialization + * + * The message header is deserialized from a uint8 stream. + * When encountering an error, we stop and return it immediately, leaving the remaining inputs on the stream + */ +module Deserialize { + export + provides DeserializeHeader + provides Streams, StandardLibrary, UInt, AlgorithmSuite, Msg + + import Msg = MessageHeader + + import AlgorithmSuite + import Streams + import opened StandardLibrary + import opened UInt = StandardLibrary.UInt + import UTF8 + import Materials + + + method DeserializeHeader(rd: Streams.StringReader) returns (res: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + ensures match res + case Success(header) => header.Valid() + case Failure(_) => true + { + var hb :- DeserializeHeaderBody(rd); + var auth :- DeserializeHeaderAuthentication(rd, hb.algorithmSuiteID); + return Success(Msg.Header(hb, auth)); + } + + /** + * Reads raw header data from the input stream and populates the header with all of the information about the + * message. + */ + method DeserializeHeaderBody(rd: Streams.StringReader) returns (ret: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + ensures match ret + case Success(hb) => hb.Valid() + case Failure(_) => true + { + var version :- DeserializeVersion(rd); + var typ :- DeserializeType(rd); + var algorithmSuiteID :- DeserializeAlgorithmSuiteID(rd); + var messageID :- DeserializeMsgID(rd); + var aad :- DeserializeAAD(rd); + var encryptedDataKeys :- DeserializeEncryptedDataKeys(rd); + var contentType :- DeserializeContentType(rd); + var reserved :- DeserializeReserved(rd); + var ivLength :- rd.ReadByte(); + var frameLength :- rd.ReadUInt32(); + + // inter-field checks + if ivLength as nat != algorithmSuiteID.IVLength() { + return Failure("Deserialization Error: Incorrect IV length."); + } + if contentType.NonFramed? && frameLength != 0 { + return Failure("Deserialization Error: Frame length must be 0 when content type is non-framed."); + } else if contentType.Framed? && frameLength == 0 { + return Failure("Deserialization Error: Frame length must be non-0 when content type is framed."); + } + + var hb := Msg.HeaderBody( + version, + typ, + algorithmSuiteID, + messageID, + aad, + encryptedDataKeys, + contentType, + reserved, + ivLength, + frameLength); + return Success(hb); + } + + /* + * Reads IV length and auth tag of the lengths specified by algorithmSuiteID. + */ + method DeserializeHeaderAuthentication(rd: Streams.StringReader, algorithmSuiteID: AlgorithmSuite.ID) returns (ret: Result) + requires rd.Valid() + requires algorithmSuiteID in AlgorithmSuite.Suite.Keys + modifies rd + ensures rd.Valid() + ensures match ret + case Success(ha) => + && |ha.iv| == algorithmSuiteID.IVLength() + && |ha.authenticationTag| == algorithmSuiteID.TagLength() + case Failure(_) => true + { + var iv :- rd.ReadExact(algorithmSuiteID.IVLength()); + var authenticationTag :- rd.ReadExact(algorithmSuiteID.TagLength()); + return Success(Msg.HeaderAuthentication(iv, authenticationTag)); + } + + /* + * Methods for deserializing pieces of the message header. + */ + + method DeserializeVersion(rd: Streams.StringReader) returns (ret: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + { + var version :- rd.ReadByte(); + if version == Msg.VERSION_1 { + return Success(version); + } else { + return Failure("Deserialization Error: Version not supported."); + } + } + + method DeserializeType(rd: Streams.StringReader) returns (ret: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + { + var typ :- rd.ReadByte(); + if typ == Msg.TYPE_CUSTOMER_AED { + return Success(typ); + } else { + return Failure("Deserialization Error: Type not supported."); + } + } + + method DeserializeAlgorithmSuiteID(rd: Streams.StringReader) returns (ret: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + { + var algorithmSuiteID :- rd.ReadUInt16(); + if algorithmSuiteID in AlgorithmSuite.VALID_IDS { + return Success(algorithmSuiteID as AlgorithmSuite.ID); + } else { + return Failure("Deserialization Error: Algorithm suite not supported."); + } + } + + method DeserializeMsgID(rd: Streams.StringReader) returns (ret: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + { + var msgID: seq :- rd.ReadExact(Msg.MESSAGE_ID_LEN); + return Success(msgID); + } + + method DeserializeUTF8(rd: Streams.StringReader, n: nat) returns (ret: Result>) + requires rd.Valid() + modifies rd + ensures rd.Valid() + ensures match ret + case Success(bytes) => + && |bytes| == n + && UTF8.ValidUTF8Seq(bytes) + case Failure(_) => true + { + var bytes :- rd.ReadExact(n); + if UTF8.ValidUTF8Seq(bytes) { + return Success(bytes); + } else { + return Failure("Deserialization Error: Not a valid UTF8 string."); + } + } + + method DeserializeAAD(rd: Streams.StringReader) returns (ret: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + ensures match ret + case Success(aad) => Msg.ValidAAD(aad) + case Failure(_) => true + { + reveal Msg.ValidAAD(); + + var aadLength :- rd.ReadUInt16(); + if aadLength == 0 { + return Success([]); + } else if aadLength < 2 { + return Failure("Deserialization Error: The number of bytes in encryption context exceeds the given length."); + } + var totalBytesRead := 0; + + var kvPairsCount :- rd.ReadUInt16(); + totalBytesRead := totalBytesRead + 2; + if kvPairsCount == 0 { + return Failure("Deserialization Error: Key value pairs count is 0."); + } + + var kvPairs: seq<(seq, seq)> := []; + var i := 0; + while i < kvPairsCount + invariant rd.Valid() + invariant |kvPairs| == i as int + invariant i <= kvPairsCount + invariant totalBytesRead == 2 + Msg.KVPairsLength(kvPairs, 0, i as nat) <= aadLength as nat + invariant Msg.ValidAAD(kvPairs) + { + var keyLength :- rd.ReadUInt16(); + totalBytesRead := totalBytesRead + 2; + + var key :- DeserializeUTF8(rd, keyLength as nat); + totalBytesRead := totalBytesRead + |key|; + + var valueLength :- rd.ReadUInt16(); + totalBytesRead := totalBytesRead + 2; + // check that we're not exceeding the stated AAD length + if aadLength as nat < totalBytesRead + valueLength as nat { + return Failure("Deserialization Error: The number of bytes in encryption context exceeds the given length."); + } + + var value :- DeserializeUTF8(rd, valueLength as nat); + totalBytesRead := totalBytesRead + |value|; + + // We want to keep entries sorted by key. We don't insist that the entries be sorted + // already, but we do insist there are no duplicate keys. + var opt, insertionPoint := InsertNewEntry(kvPairs, key, value); + match opt { + case Some(kvPairs_) => + Msg.KVPairsLengthInsert(kvPairs, insertionPoint, key, value); + kvPairs := kvPairs_; + case None => + return Failure("Deserialization Error: Duplicate key."); + } + + i := i + 1; + } + if aadLength as nat != totalBytesRead { + return Failure("Deserialization Error: Bytes actually read differs from bytes supposed to be read."); + } + return Success(kvPairs); + } + + method InsertNewEntry(kvPairs: seq<(seq, seq)>, key: seq, value: seq) + returns (res: Option, seq)>>, ghost insertionPoint: nat) + requires Msg.SortedKVPairs(kvPairs) + ensures match res + case None => + exists i :: 0 <= i < |kvPairs| && kvPairs[i].0 == key // key already exists + case Some(kvPairs') => + && insertionPoint <= |kvPairs| + && kvPairs' == kvPairs[..insertionPoint] + [(key, value)] + kvPairs[insertionPoint..] + && Msg.SortedKVPairs(kvPairs') + { + var n := |kvPairs|; + while 0 < n && LexicographicLessOrEqual(key, kvPairs[n - 1].0, UInt8Less) + invariant 0 <= n <= |kvPairs| + invariant forall i :: n <= i < |kvPairs| ==> LexicographicLessOrEqual(key, kvPairs[i].0, UInt8Less) + { + n := n - 1; + } + if 0 < n && kvPairs[n - 1].0 == key { + return None, n; + } else { + var kvPairs' := kvPairs[..n] + [(key, value)] + kvPairs[n..]; + if 0 < n { + LexPreservesTrichotomy(kvPairs'[n - 1].0, kvPairs'[n].0, UInt8Less); + } + return Some(kvPairs'), n; + } + } + + method DeserializeEncryptedDataKeys(rd: Streams.StringReader) returns (ret: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + ensures match ret + case Success(edks) => edks.Valid() + case Failure(_) => true + { + var edkCount :- rd.ReadUInt16(); + if edkCount == 0 { + return Failure("Deserialization Error: Encrypted data key count is 0."); + } + + var edkEntries: seq := []; + var i := 0; + while i < edkCount + invariant rd.Valid() + invariant i <= edkCount + invariant |edkEntries| == i as int + invariant forall i :: 0 <= i < |edkEntries| ==> edkEntries[i].Valid() + { + // Key provider ID + var keyProviderIDLength :- rd.ReadUInt16(); + var str :- DeserializeUTF8(rd, keyProviderIDLength as nat); + var keyProviderID := ByteSeqToString(str); + + // Key provider info + var keyProviderInfoLength :- rd.ReadUInt16(); + var keyProviderInfo :- rd.ReadExact(keyProviderInfoLength as nat); + + // Encrypted data key + var edkLength :- rd.ReadUInt16(); + var edk :- rd.ReadExact(edkLength as nat); + + edkEntries := edkEntries + [Materials.EncryptedDataKey(keyProviderID, keyProviderInfo, edk)]; + i := i + 1; + } + + var edks := Msg.EncryptedDataKeys(edkEntries); + return Success(edks); + } + + method DeserializeContentType(rd: Streams.StringReader) returns (ret: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + { + var byte :- rd.ReadByte(); + match Msg.UInt8ToContentType(byte) + case None => + return Failure("Deserialization Error: Content type not supported."); + case Some(contentType) => + return Success(contentType); + } + + method DeserializeReserved(rd: Streams.StringReader) returns (ret: Result) + requires rd.Valid() + modifies rd + ensures rd.Valid() + { + var reserved :- rd.ReadExact(4); + if reserved[0] == reserved[1] == reserved[2] == reserved[3] == 0 { + return Success(reserved[..]); + } else { + return Failure("Deserialization Error: Reserved fields must be 0."); + } + } +} diff --git a/src/SDK/Keyring/AESKeyring.dfy b/src/SDK/Keyring/AESKeyring.dfy index c00d54637..e1956e506 100644 --- a/src/SDK/Keyring/AESKeyring.dfy +++ b/src/SDK/Keyring/AESKeyring.dfy @@ -1,4 +1,3 @@ -include "../MessageHeader/Definitions.dfy" include "../../StandardLibrary/StandardLibrary.dfy" include "../../StandardLibrary/UInt.dfy" include "../AlgorithmSuite.dfy" diff --git a/src/SDK/Keyring/MultiKeyring.dfy b/src/SDK/Keyring/MultiKeyring.dfy index 264adaf83..a3b891597 100644 --- a/src/SDK/Keyring/MultiKeyring.dfy +++ b/src/SDK/Keyring/MultiKeyring.dfy @@ -1,4 +1,3 @@ -include "../MessageHeader/Definitions.dfy" include "../../StandardLibrary/StandardLibrary.dfy" include "../../StandardLibrary/UInt.dfy" include "../AlgorithmSuite.dfy" diff --git a/src/SDK/Keyring/RSAKeyring.dfy b/src/SDK/Keyring/RSAKeyring.dfy index 95352efee..6935b78be 100644 --- a/src/SDK/Keyring/RSAKeyring.dfy +++ b/src/SDK/Keyring/RSAKeyring.dfy @@ -61,7 +61,7 @@ module RSAKeyringDef { ensures res.Failure? ==> unchanged(encMat) { if encryptionKey.None? { - res := Failure("Encryption key undefined"); + return Failure("Encryption key undefined"); } else { var dataKey := encMat.plaintextDataKey; var algorithmID := encMat.algorithmSuiteID; diff --git a/src/SDK/Materials.dfy b/src/SDK/Materials.dfy index aa3a474ae..de69ca491 100644 --- a/src/SDK/Materials.dfy +++ b/src/SDK/Materials.dfy @@ -10,11 +10,24 @@ module Materials { type EncryptionContext = seq<(seq, seq)> + function method GetKeysFromEncryptionContext(encryptionContext: EncryptionContext): set> { + set i | 0 <= i < |encryptionContext| :: encryptionContext[i].0 + } + const EC_PUBLIC_KEY_FIELD: seq := StringToByteSeq("aws-crypto-public-key"); + ghost const ReservedKeyValues := { EC_PUBLIC_KEY_FIELD } datatype EncryptedDataKey = EncryptedDataKey(providerID : string, providerInfo : seq, ciphertext : seq) + { + predicate Valid() { + StringIs8Bit(providerID) && + |providerID| < UINT16_LIMIT && + |providerInfo| < UINT16_LIMIT && + |ciphertext| < UINT16_LIMIT + } + } // TODO: Add keyring trace class EncryptionMaterials { @@ -188,6 +201,6 @@ module Materials { function method enc_ctx_of_strings(x : seq<(string, string)>) : seq<(seq, seq)> { if x == [] then [] else - [(byteseq_of_string_lossy(x[0].0), byteseq_of_string_lossy(x[0].1))] + enc_ctx_of_strings(x[1..]) + [(StringToByteSeqLossy(x[0].0), StringToByteSeqLossy(x[0].1))] + enc_ctx_of_strings(x[1..]) } } diff --git a/src/SDK/MessageHeader.dfy b/src/SDK/MessageHeader.dfy new file mode 100644 index 000000000..c1f1aac11 --- /dev/null +++ b/src/SDK/MessageHeader.dfy @@ -0,0 +1,339 @@ +include "AlgorithmSuite.dfy" +include "../StandardLibrary/StandardLibrary.dfy" +include "Materials.dfy" +include "../Util/UTF8.dfy" + +module MessageHeader { + import AlgorithmSuite + import opened StandardLibrary + import opened UInt = StandardLibrary.UInt + import Materials + import UTF8 + + /* + * Definition of the message header, i.e., the header body and the header authentication + */ + + datatype Header = Header(body: HeaderBody, auth: HeaderAuthentication) + { + predicate Valid() { + && body.Valid() + && |auth.iv| == body.algorithmSuiteID.IVLength() + && |auth.authenticationTag| == body.algorithmSuiteID.TagLength() + } + } + + /* + * Header body type definition + */ + + const VERSION_1: uint8 := 0x01 + type Version = x | x == VERSION_1 witness VERSION_1 + + const TYPE_CUSTOMER_AED: uint8 := 0x80 + type Type = x | x == TYPE_CUSTOMER_AED witness TYPE_CUSTOMER_AED + + const MESSAGE_ID_LEN := 16 + type MessageID = x: seq | |x| == MESSAGE_ID_LEN witness [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0] + + type Reserved = x: seq | x == [0,0,0,0] witness [0,0,0,0] + + datatype ContentType = NonFramed | Framed + + function method ContentTypeToUInt8(contentType: ContentType): uint8 { + match contentType + case NonFramed => 0x01 + case Framed => 0x02 + } + + function method UInt8ToContentType(x: uint8): Option { + if x == 0x01 then + Some(NonFramed) + else if x == 0x02 then + Some(Framed) + else + None + } + + lemma ContentTypeConversionsCorrect(contentType: ContentType, x: uint8) + ensures UInt8ToContentType(ContentTypeToUInt8(contentType)) == Some(contentType) + ensures var opt := UInt8ToContentType(x); opt == None || ContentTypeToUInt8(opt.get) == x + { + } + + datatype EncryptedDataKeys = EncryptedDataKeys(entries: seq) + { + predicate Valid() { + && 0 < |entries| < UINT16_LIMIT + && (forall i :: 0 <= i < |entries| ==> entries[i].Valid()) + } + } + + datatype HeaderBody = HeaderBody( + version: Version, + typ: Type, + algorithmSuiteID: AlgorithmSuite.ID, + messageID: MessageID, + aad: Materials.EncryptionContext, + encryptedDataKeys: EncryptedDataKeys, + contentType: ContentType, + reserved: Reserved, + ivLength: uint8, + frameLength: uint32) + { + predicate Valid() { + && ValidAAD(aad) + && encryptedDataKeys.Valid() + && algorithmSuiteID.IVLength() == ivLength as nat + && ValidFrameLength(frameLength, contentType) + } + } + + /* + * Header authentication type definition + */ + + datatype HeaderAuthentication = HeaderAuthentication(iv: seq, authenticationTag: seq) + + /* + * Validity predicates -- predicates that say when the data structures above are in a good state. + */ + + predicate ValidKVPair(kvPair: (seq, seq)) { + && |kvPair.0| < UINT16_LIMIT + && |kvPair.1| < UINT16_LIMIT + && UTF8.ValidUTF8Seq(kvPair.0) + && UTF8.ValidUTF8Seq(kvPair.1) + } + + function KVPairsLength(kvPairs: Materials.EncryptionContext, lo: nat, hi: nat): nat + requires lo <= hi <= |kvPairs| + { + if lo == hi then 0 else + KVPairsLength(kvPairs, lo, hi - 1) + + 2 + |kvPairs[hi - 1].0| + + 2 + |kvPairs[hi - 1].1| + } + + lemma KVPairsLengthSplit(kvPairs: Materials.EncryptionContext, lo: nat, mid: nat, hi: nat) + requires lo <= mid <= hi <= |kvPairs| + ensures KVPairsLength(kvPairs, lo, hi) + == KVPairsLength(kvPairs, lo, mid) + KVPairsLength(kvPairs, mid, hi) + { + } + + lemma KVPairsLengthPrefix(kvPairs: Materials.EncryptionContext, more: Materials.EncryptionContext) + ensures KVPairsLength(kvPairs + more, 0, |kvPairs|) == KVPairsLength(kvPairs, 0, |kvPairs|) + { + var n := |kvPairs|; + if n == 0 { + } else { + var last := kvPairs[n - 1]; + calc { + KVPairsLength(kvPairs + more, 0, n); + == // def. KVPairsLength + KVPairsLength(kvPairs + more, 0, n - 1) + 4 + |last.0| + |last.1|; + == { assert kvPairs + more == kvPairs[..n - 1] + ([last] + more); } + KVPairsLength(kvPairs[..n - 1] + ([last] + more), 0, n - 1) + 4 + |last.0| + |last.1|; + == { KVPairsLengthPrefix(kvPairs[..n - 1], [last] + more); } + KVPairsLength(kvPairs[..n - 1], 0, n - 1) + 4 + |last.0| + |last.1|; + == { KVPairsLengthPrefix(kvPairs[..n - 1], [last] + more); } + KVPairsLength(kvPairs[..n - 1] + [last], 0, n - 1) + 4 + |last.0| + |last.1|; + == { assert kvPairs[..n - 1] + [last] == kvPairs; } + KVPairsLength(kvPairs, 0, n - 1) + 4 + |last.0| + |last.1|; + == // def. KVPairsLength + KVPairsLength(kvPairs, 0, n); + } + } + } + + lemma KVPairsLengthExtend(kvPairs: Materials.EncryptionContext, key: seq, value: seq) + ensures KVPairsLength(kvPairs + [(key, value)], 0, |kvPairs| + 1) + == KVPairsLength(kvPairs, 0, |kvPairs|) + 4 + |key| + |value| + { + KVPairsLengthPrefix(kvPairs, [(key, value)]); + } + + lemma KVPairsLengthInsert(kvPairs: Materials.EncryptionContext, insertionPoint: nat, key: seq, value: seq) + requires insertionPoint <= |kvPairs| + ensures var kvPairs' := kvPairs[..insertionPoint] + [(key, value)] + kvPairs[insertionPoint..]; + KVPairsLength(kvPairs', 0, |kvPairs'|) == KVPairsLength(kvPairs, 0, |kvPairs|) + 4 + |key| + |value| + decreases |kvPairs| + { + var kvPairs' := kvPairs[..insertionPoint] + [(key, value)] + kvPairs[insertionPoint..]; + if |kvPairs| == insertionPoint { + assert kvPairs' == kvPairs + [(key, value)]; + KVPairsLengthExtend(kvPairs, key, value); + } else { + var m := |kvPairs| - 1; + var (d0, d1) := kvPairs[m]; + var a, b, c, d := kvPairs[..insertionPoint], [(key, value)], kvPairs[insertionPoint..m], [(d0, d1)]; + assert kvPairs == a + c + d; + assert kvPairs' == a + b + c + d; + var ac := a + c; + var abc := a + b + c; + calc { + KVPairsLength(kvPairs', 0, |kvPairs'|); + KVPairsLength(abc + [(d0, d1)], 0, |abc| + 1); + == { KVPairsLengthExtend(abc, d0, d1); } + KVPairsLength(abc, 0, |abc|) + 4 + |d0| + |d1|; + == { KVPairsLengthInsert(ac, insertionPoint, key, value); } + KVPairsLength(ac, 0, |ac|) + 4 + |key| + |value| + 4 + |d0| + |d1|; + == { KVPairsLengthExtend(ac, d0, d1); } + KVPairsLength(kvPairs, 0, |kvPairs|) + 4 + |key| + |value|; + } + } + } + + function AADLength(kvPairs: Materials.EncryptionContext): nat { + if |kvPairs| == 0 then 0 else + 2 + KVPairsLength(kvPairs, 0, |kvPairs|) + } + + predicate {:opaque} ValidAAD(kvPairs: Materials.EncryptionContext) { + && |kvPairs| < UINT16_LIMIT + && (forall i :: 0 <= i < |kvPairs| ==> ValidKVPair(kvPairs[i])) + && SortedKVPairs(kvPairs) + && AADLength(kvPairs) < UINT16_LIMIT + } + + predicate ValidFrameLength(frameLength: uint32, contentType: ContentType) { + match contentType + case NonFramed => frameLength == 0 + case Framed => frameLength != 0 + } + + predicate SortedKVPairsUpTo(a: seq<(seq, seq)>, n: nat) + requires n <= |a| + { + forall j :: 0 < j < n ==> LexicographicLessOrEqual(a[j-1].0, a[j].0, UInt8Less) + } + + predicate SortedKVPairs(a: seq<(seq, seq)>) + { + SortedKVPairsUpTo(a, |a|) + } + + /* + * Specifications of serialized format + */ + + function {:opaque} HeaderBodyToSeq(hb: HeaderBody): seq + requires hb.Valid() + { + [hb.version as uint8] + + [hb.typ as uint8] + + UInt16ToSeq(hb.algorithmSuiteID as uint16) + + hb.messageID + + AADToSeq(hb.aad) + + EDKsToSeq(hb.encryptedDataKeys) + + [ContentTypeToUInt8(hb.contentType)] + + hb.reserved + + [hb.ivLength] + + UInt32ToSeq(hb.frameLength) + } + + function AADToSeq(kvPairs: Materials.EncryptionContext): seq + requires ValidAAD(kvPairs) + { + reveal ValidAAD(); + UInt16ToSeq(AADLength(kvPairs) as uint16) + + var n := |kvPairs|; + if n == 0 then [] else + UInt16ToSeq(n as uint16) + + KVPairsToSeq(kvPairs, 0, n) + } + + function KVPairsToSeq(kvPairs: Materials.EncryptionContext, lo: nat, hi: nat): seq + requires forall i :: 0 <= i < |kvPairs| ==> ValidKVPair(kvPairs[i]) + requires lo <= hi <= |kvPairs| + { + if lo == hi then [] else KVPairsToSeq(kvPairs, lo, hi - 1) + KVPairToSeq(kvPairs[hi - 1]) + } + + function KVPairToSeq(kvPair: (seq, seq)): seq + requires ValidKVPair(kvPair) + { + UInt16ToSeq(|kvPair.0| as uint16) + kvPair.0 + + UInt16ToSeq(|kvPair.1| as uint16) + kvPair.1 + } + + function EDKsToSeq(encryptedDataKeys: EncryptedDataKeys): seq + requires encryptedDataKeys.Valid() + { + var n := |encryptedDataKeys.entries|; + UInt16ToSeq(n as uint16) + + EDKEntriesToSeq(encryptedDataKeys.entries, 0, n) + } + + function EDKEntriesToSeq(entries: seq, lo: nat, hi: nat): seq + requires forall i :: 0 <= i < |entries| ==> entries[i].Valid() + requires lo <= hi <= |entries| + { + if lo == hi then [] else EDKEntriesToSeq(entries, lo, hi - 1) + EDKEntryToSeq(entries[hi - 1]) + } + + function EDKEntryToSeq(edk: Materials.EncryptedDataKey): seq + requires edk.Valid() + { + UInt16ToSeq(|edk.providerID| as uint16) + StringToByteSeq(edk.providerID) + + UInt16ToSeq(|edk.providerInfo| as uint16) + edk.providerInfo + + UInt16ToSeq(|edk.ciphertext| as uint16) + edk.ciphertext + } + + /* Function AADLength is defined without referring to SerializeAAD (because then + * these two would be mutually recursive with ValidAAD). The following lemma proves + * that the two definitions correspond. + */ + + lemma ADDLengthCorrect(kvPairs: Materials.EncryptionContext) + requires ValidAAD(kvPairs) + ensures |AADToSeq(kvPairs)| == 2 + AADLength(kvPairs) + { + reveal ValidAAD(); + KVPairsLengthCorrect(kvPairs, 0, |kvPairs|); + /**** Here's a more detailed proof: + var n := |kvPairs|; + if n != 0 { + var s := KVPairsToSeq(kvPairs, 0, n); + calc { + |AADToSeq(kvPairs)|; + == // def. AADToSeq + |UInt16ToSeq(AADLength(kvPairs) as uint16) + UInt16ToSeq(n as uint16) + s|; + == // UInt16ToSeq yields length-2 sequence + 2 + 2 + |s|; + == { KVPairsLengthCorrect(kvPairs, 0, n); } + 2 + 2 + KVPairsLength(kvPairs, 0, n); + } + } + ****/ + } + + lemma KVPairsLengthCorrect(kvPairs: Materials.EncryptionContext, lo: nat, hi: nat) + requires forall i :: 0 <= i < |kvPairs| ==> ValidKVPair(kvPairs[i]) + requires lo <= hi <= |kvPairs| + ensures |KVPairsToSeq(kvPairs, lo, hi)| == KVPairsLength(kvPairs, lo, hi) + { + /**** Here's a more detailed proof: + if lo < hi { + var kvPair := kvPairs[hi - 1]; + calc { + |KVPairsToSeq(kvPairs, lo, hi)|; + == // def. KVPairsToSeq + |KVPairsToSeq(kvPairs, lo, hi - 1) + KVPairToSeq(kvPair)|; + == + |KVPairsToSeq(kvPairs, lo, hi - 1)| + |KVPairToSeq(kvPair)|; + == { KVPairsLengthCorrect(kvPairs, lo, hi - 1); } + KVPairsLength(kvPairs, lo, hi - 1) + |KVPairToSeq(kvPair)|; + == // def. KVPairToSeq + KVPairsLength(kvPairs, lo, hi - 1) + + |UInt16ToSeq(|kvPair.0| as uint16) + kvPair.0 + UInt16ToSeq(|kvPair.1| as uint16) + kvPair.1|; + == + KVPairsLength(kvPairs, lo, hi - 1) + 2 + |kvPair.0| + 2 + |kvPair.1|; + == // def. KVPairsLength + KVPairsLength(kvPairs, lo, hi); + } + } + ****/ + } +} diff --git a/src/SDK/MessageHeader/Definitions.dfy b/src/SDK/MessageHeader/Definitions.dfy deleted file mode 100644 index fd03f5b8f..000000000 --- a/src/SDK/MessageHeader/Definitions.dfy +++ /dev/null @@ -1,41 +0,0 @@ -include "../AlgorithmSuite.dfy" -include "../../StandardLibrary/StandardLibrary.dfy" - -module MessageHeader.Definitions { - import AlgorithmSuite - import opened StandardLibrary - import opened UInt = StandardLibrary.UInt - - /* - * Header body type definition - */ - type T_Version = x | x == 0x01 /*Version 1.0*/ witness 0x01 - type T_Type = x | x == 0x80 /*Customer Authenticated Encrypted Data*/ witness 0x80 - type T_MessageID = x: seq | |x| == 16 witness [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0] - type T_Reserved = x: seq | x == [0,0,0,0] witness [0,0,0,0] - datatype T_ContentType = NonFramed | Framed - type EncCtx = array<(array, array)> - datatype T_AAD = AAD(kvPairs: EncCtx) | EmptyAAD - - datatype EDKEntry = EDKEntry(keyProviderId: array, keyProviderInfo: array, encDataKey: array) - datatype T_EncryptedDataKeys - = EncryptedDataKeys(entries: array) - - datatype HeaderBody = HeaderBody( - version: T_Version, - typ: T_Type, - algorithmSuiteID: AlgorithmSuite.ID, - messageID: T_MessageID, - aad: T_AAD, - encryptedDataKeys: T_EncryptedDataKeys, - contentType: T_ContentType, - reserved: T_Reserved, - ivLength: uint8, - frameLength: uint32) - - /* - * Header authentication type definition - */ - - datatype HeaderAuthentication = HeaderAuthentication(iv: array, authenticationTag: array) -} diff --git a/src/SDK/MessageHeader/Deserialize.dfy b/src/SDK/MessageHeader/Deserialize.dfy deleted file mode 100644 index 68e946baa..000000000 --- a/src/SDK/MessageHeader/Deserialize.dfy +++ /dev/null @@ -1,624 +0,0 @@ -include "Definitions.dfy" -include "Utils.dfy" -include "Validity.dfy" - -include "../AlgorithmSuite.dfy" -include "../../Util/Streams.dfy" -include "../../StandardLibrary/StandardLibrary.dfy" -include "../../Util/UTF8.dfy" - - -/* - * The message header deserialization - * - * The message header is deserialized from a uint8 stream. - * When encountering an error, we stop and return it immediately, leaving the remaining inputs on the stream - */ -module MessageHeader.Deserialize { - import opened Definitions - import opened Validity - import opened Utils - - import AlgorithmSuite - import opened Streams - import opened StandardLibrary - import opened UInt = StandardLibrary.UInt - import opened UTF8 - - /* - * Message header-specific - */ - - method deserializeVersion(is: StringReader) returns (ret: Either) - requires is.Valid() - modifies is - ensures is.Valid() - { - var res := readFixedLengthFromStreamOrFail(is, 1); - match res { - case Left(version) => - if version[0] == 0x01 { - return Left(version[0] as T_Version); - } else { - return Right(DeserializationError("Version not supported.")); - } - case Right(e) => return Right(e); - } - } - - method deserializeType(is: StringReader) returns (ret: Either) - requires is.Valid() - modifies is - ensures is.Valid() - { - var res := readFixedLengthFromStreamOrFail(is, 1); - match res { - case Left(typ) => - if typ[0] == 0x80 { - return Left(typ[0] as T_Type); - } else { - return Right(DeserializationError("Type not supported.")); - } - case Right(e) => return Right(e); - } - } - - method deserializeAlgorithmSuiteID(is: StringReader) returns (ret: Either) - requires is.Valid() - modifies is - ensures - match ret - case Left(algorithmSuiteID) => ValidAlgorithmID(algorithmSuiteID) - case Right(_) => true - ensures is.Valid() - { - var res := readFixedLengthFromStreamOrFail(is, 2); - match res { - case Left(algorithmSuiteID) => - var asid := arrayToUInt16(algorithmSuiteID); - if asid in AlgorithmSuite.validIDs { - return Left(asid as AlgorithmSuite.ID); - } else { - return Right(DeserializationError("Algorithm suite not supported.")); - } - case Right(e) => return Right(e); - } - } - - // TODO: - predicate method isValidMsgID (candidateID: array) - requires candidateID.Length == 16 - ensures ValidMessageId(candidateID[..]) - { - true - } - method deserializeMsgID(is: StringReader) returns (ret: Either) - requires is.Valid() - modifies is - ensures - match ret - case Left(msgId) => ValidMessageId(msgId) - case Right(_) => true - ensures is.Valid() - { - var res := readFixedLengthFromStreamOrFail(is, 16); - match res { - case Left(msgId) => - if isValidMsgID(msgId) { - return Left(msgId[..]); - } else { - return Right(DeserializationError("Not a valid Message ID.")); - } - case Right(e) => return Right(e); - } - } - - method deserializeUTF8(is: StringReader, n: nat) returns (ret: Either, Error>) - requires is.Valid() - modifies is - ensures - match ret - case Left(bytes) => - && bytes.Length == n - && ValidUTF8(bytes) - && fresh(bytes) - case Right(_) => true - ensures is.Valid() - { - ret := readFixedLengthFromStreamOrFail(is, n); - match ret { - case Left(bytes) => - if ValidUTF8(bytes) { - return ret; - } else { - return Right(DeserializationError("Not a valid UTF8 string.")); - } - case Right(e) => return ret; - } - } - - method deserializeUnrestricted(is: StringReader, n: nat) returns (ret: Either, Error>) - requires is.Valid() - modifies is - ensures - match ret - case Left(bytes) => - && bytes.Length == n - && fresh(bytes) - case Right(_) => true - ensures is.Valid() - { - ret := readFixedLengthFromStreamOrFail(is, n); - } - - // TODO: Probably this should be factored out into EncCtx at some point - method deserializeAAD(is: StringReader) returns (ret: Either) - requires is.Valid() - modifies is - ensures - match ret - case Left(aad) => - && ValidAAD(aad) - // TODO: I think we need to establish freshness to connect - // deserialization with the caller. - // Times out. - //&& fresh(ReprAAD(aad)) - case Right(_) => true - ensures is.Valid() - { - reveal ValidAAD(); - var kvPairsLength: uint16; - { - var res := deserializeUnrestricted(is, 2); - match res { - case Left(bytes) => kvPairsLength := arrayToUInt16(bytes); - case Right(e) => return Right(e); - } - } - if kvPairsLength == 0 { - return Left(EmptyAAD); - } - var totalBytesRead := 0; - - var kvPairsCount: uint16; - { - var res := deserializeUnrestricted(is, 2); - match res { - case Left(bytes) => - kvPairsCount := arrayToUInt16(bytes); - totalBytesRead := totalBytesRead + bytes.Length; - if kvPairsLength > 0 && kvPairsCount == 0 { - return Right(DeserializationError("Key value pairs count is 0.")); - } - assert kvPairsLength > 0 ==> kvPairsCount > 0; - case Right(e) => return Right(e); - } - } - - var kvPairs: EncCtx := new [kvPairsCount]; - assert kvPairs.Length > 0; - assert kvPairsCount == kvPairs.Length as uint16; - - var i := 0; - while i < kvPairsCount - invariant is.Valid() - invariant i <= kvPairsCount - invariant InBoundsKVPairsUpTo(kvPairs, i as nat) - invariant SortedKVPairsUpTo(kvPairs, i as nat) - invariant forall j :: 0 <= j < i ==> ValidUTF8(kvPairs[j].0) - invariant forall j :: 0 <= j < i ==> ValidUTF8(kvPairs[j].1) - // TODO: I think we need to establish freshness to connect - // deserialization with the caller. - // Times out. - //invariant fresh(ReprAADUpTo(kvPairs, i as nat)) - { - var keyLength: uint16; - { - var res := deserializeUnrestricted(is, 2); - match res { - case Left(bytes) => - keyLength := arrayToUInt16(bytes); - totalBytesRead := totalBytesRead + bytes.Length; - case Right(e) => return Right(e); - } - } - - var key := new uint8[keyLength]; - { - var res := deserializeUTF8(is, keyLength as nat); - match res { - case Left(bytes) => - key := bytes; - totalBytesRead := totalBytesRead + bytes.Length; - case Right(e) => return Right(e); - } - } - assert key.Length <= UINT16_MAX; - - var valueLength: uint16; - { - var res := deserializeUnrestricted(is, 2); - match res { - case Left(bytes) => - valueLength := arrayToUInt16(bytes); - totalBytesRead := totalBytesRead + bytes.Length; - case Right(e) => return Right(e); - } - } - - var value := new uint8[valueLength]; - { - var res := deserializeUTF8(is, valueLength as nat); - match res { - case Left(bytes) => - value := bytes; - totalBytesRead := totalBytesRead + bytes.Length; - case Right(e) => return Right(e); - } - } - assert value.Length <= UINT16_MAX; - - // check for sortedness by key - if i > 0 { - if lexCmpArrays(kvPairs[i-1].0, key, ltByte) { - kvPairs[i] := (key, value); - } else { - return Right(DeserializationError("Key-value pairs must be sorted by key.")); - } - } else { - assert i == 0; - kvPairs[i] := (key, value); - } - assert SortedKVPairsUpTo(kvPairs, (i+1) as nat); - i := i + 1; - } - if (kvPairsLength as nat) != totalBytesRead { - return Right(DeserializationError("Bytes actually read differs from bytes supposed to be read.")); - } - return Left(AAD(kvPairs)); - } - - // TODO: Probably this should be factored out into EDK at some point - method deserializeEncryptedDataKeys(is: StringReader, ghost aad: T_AAD) returns (ret: Either) - requires is.Valid() - modifies is - ensures - match ret - case Left(edks) => - && ValidEncryptedDataKeys(edks) - // TODO: I think we need to establish freshness to connect - // deserialization with the caller. - // Times out. - // && fresh(ReprEncryptedDataKeys(edks)) - case Right(_) => true - ensures is.Valid() - { - reveal ValidEncryptedDataKeys(); - var res: Either; - var edkCount: uint16; - res := deserializeUnrestricted(is, 2); - match res { - case Left(bytes) => edkCount := arrayToUInt16(bytes); - case Right(e) => return Right(e); - } - - if edkCount == 0 { - return Right(DeserializationError("Encrypted data key count must be > 0.")); - } - - var edkEntries: array := new [edkCount]; - var edks := EncryptedDataKeys(edkEntries); - var i := 0; - while i < edkCount - invariant is.Valid() - invariant i <= edkCount - invariant InBoundsEncryptedDataKeysUpTo(edks.entries, i as nat) - invariant forall j :: 0 <= j < i ==> ValidUTF8(edks.entries[j].keyProviderId) - // TODO: I think we need to establish freshness to connect - // deserialization with the caller. - // Times out. - //invariant fresh(ReprEncryptedDataKeysUpTo(edks.entries, i as nat)) - { - // Key provider ID - var keyProviderIDLength: uint16; - res := deserializeUnrestricted(is, 2); - match res { - case Left(bytes) => keyProviderIDLength := arrayToUInt16(bytes); - case Right(e) => return Right(e); - } - - var keyProviderID := new uint8[keyProviderIDLength]; - res := deserializeUTF8(is, keyProviderIDLength as nat); - match res { - case Left(bytes) => keyProviderID := bytes; - case Right(e) => return Right(e); - } - - // Key provider info - var keyProviderInfoLength: uint16; - res := deserializeUnrestricted(is, 2); - match res { - case Left(bytes) => keyProviderInfoLength := arrayToUInt16(bytes); - case Right(e) => return Right(e); - } - - var keyProviderInfo := new uint8[keyProviderInfoLength]; - res := deserializeUnrestricted(is, keyProviderInfoLength as nat); - match res { - case Left(bytes) => keyProviderInfo := bytes; - case Right(e) => return Right(e); - } - - // Encrypted data key - var edkLength: uint16; - res := deserializeUnrestricted(is, 2); - match res { - case Left(bytes) => edkLength := arrayToUInt16(bytes); - case Right(e) => return Right(e); - } - - var edk := new uint8[edkLength]; - res := deserializeUnrestricted(is, edkLength as nat); - match res { - case Left(bytes) => edk := bytes; - case Right(e) => return Right(e); - } - - edks.entries[i] := EDKEntry(keyProviderID, keyProviderInfo, edk); - i := i + 1; - } - - return Left(edks); - } - - method deserializeContentType(is: StringReader) returns (ret: Either) - requires is.Valid() - modifies is - ensures is.Valid() - { - var res := readFixedLengthFromStreamOrFail(is, 1); - match res { - case Left(contentType) => - if contentType[0] == 0x01 { - return Left(NonFramed); - } else if contentType[0] == 0x02 { - return Left(Framed); - } else { - return Right(DeserializationError("Content type not supported.")); - } - case Right(e) => return Right(e); - } - } - - method deserializeReserved(is: StringReader) returns (ret: Either) - requires is.Valid() - modifies is - ensures is.Valid() - { - var res := readFixedLengthFromStreamOrFail(is, 4); - match res { - case Left(reserved) => - if reserved[0] == reserved[1] == reserved[2] == reserved[3] == 0 { - return Left(reserved[..]); - } else { - return Right(DeserializationError("Reserved fields must be 0.")); - } - case Right(e) => return Right(e); - } - } - - method deserializeIVLength(is: StringReader, algSuiteId: AlgorithmSuite.ID) returns (ret: Either) - requires is.Valid() - requires algSuiteId in AlgorithmSuite.Suite.Keys - modifies is - ensures - match ret - case Left(ivLength) => ValidIVLength(ivLength, algSuiteId) - case Right(_) => true - ensures is.Valid() - { - var res := readFixedLengthFromStreamOrFail(is, 1); - match res { - case Left(ivLength) => - if ivLength[0] == AlgorithmSuite.Suite[algSuiteId].params.ivLen { - return Left(ivLength[0]); - } else { - return Right(DeserializationError("Incorrect IV length.")); - } - case Right(e) => return Right(e); - } - } - - method deserializeFrameLength(is: StringReader, contentType: T_ContentType) returns (ret: Either) - requires is.Valid() - modifies is - ensures - match ret - case Left(frameLength) => ValidFrameLength(frameLength, contentType) - case Right(_) => true - ensures is.Valid() - { - var res := readFixedLengthFromStreamOrFail(is, 4); - match res { - case Left(frameLength) => - if contentType.NonFramed? && arrayToUInt32(frameLength) == 0 { - return Left(arrayToUInt32(frameLength)); - } else { - return Right(DeserializationError("Frame length must be 0 when content type is non-framed.")); - } - case Right(e) => return Right(e); - } - } - /** - * Reads raw header data from the input stream and populates the header with all of the information about the - * message. - */ - method headerBody(is: StringReader) returns (ret: Either) - requires is.Valid() - modifies is - ensures is.Valid() - ensures - match ret - case Left(hb) => - && ValidHeaderBody(hb) - // TODO: I think we need to establish freshness to connect - // deserialization with the caller. - // Times out. - // && fresh(ReprAAD(hb.aad)) - // && fresh(ReprEncryptedDataKeys(hb.encryptedDataKeys)) - case Right(_) => true - { - reveal ValidHeaderBody(); - var version: T_Version; - { - var res := deserializeVersion(is); - match res { - case Left(version_) => version := version_; - case Right(e) => return Right(e); - } - } - - var typ: T_Type; - { - var res := deserializeType(is); - match res { - case Left(typ_) => typ := typ_; - case Right(e) => return Right(e); - } - } - - var algorithmSuiteID: AlgorithmSuite.ID; - { - var res := deserializeAlgorithmSuiteID(is); - match res { - case Left(algorithmSuiteID_) => algorithmSuiteID := algorithmSuiteID_; - case Right(e) => return Right(e); - } - } - - var messageID: T_MessageID; - { - var res := deserializeMsgID(is); - match res { - case Left(messageID_) => messageID := messageID_; - case Right(e) => return Right(e); - } - } - - // AAD deserialization: - var aad: T_AAD; - { - var res := deserializeAAD(is); - match res { - case Left(aad_) => aad := aad_; - case Right(e) => return Right(e); - } - } - - // EDK deserialization: - var encryptedDataKeys: T_EncryptedDataKeys; - { - var res := deserializeEncryptedDataKeys(is, aad); - match res { - case Left(encryptedDataKeys_) => encryptedDataKeys := encryptedDataKeys_; - case Right(e) => return Right(e); - } - } - - var contentType: T_ContentType; - { - var res := deserializeContentType(is); - match res { - case Left(contentType_) => contentType := contentType_; - case Right(e) => return Right(e); - } - } - - var reserved: T_Reserved; - { - var res := deserializeReserved(is); - match res { - case Left(reserved_) => reserved := reserved_; - case Right(e) => return Right(e); - } - } - - var ivLength: uint8; - { - var res := deserializeIVLength(is, algorithmSuiteID); - match res { - case Left(ivLength_) => ivLength := ivLength_; - case Right(e) => return Right(e); - } - } - - var frameLength: uint32; - { - var res := deserializeFrameLength(is, contentType); - match res { - case Left(frameLength_) => frameLength := frameLength_; - case Right(e) => return Right(e); - } - } - var hb := HeaderBody( - version, - typ, - algorithmSuiteID, - messageID, - aad, - encryptedDataKeys, - contentType, - reserved, - ivLength, - frameLength); - reveal ReprAAD(); - assert ValidHeaderBody(hb); - ret := Left(hb); - } - - method deserializeAuthenticationTag(is: StringReader, tagLength: nat, ghost iv: array) returns (ret: Either, Error>) - requires is.Valid() - modifies is - ensures - match ret - case Left(authenticationTag) => ValidAuthenticationTag(authenticationTag, tagLength, iv) - case Right(_) => true - ensures is.Valid() - { - ret := readFixedLengthFromStreamOrFail(is, tagLength); - } - - method headerAuthentication(is: StringReader, body: HeaderBody) returns (ret: Either) - requires is.Valid() - requires ValidHeaderBody(body) - requires body.algorithmSuiteID in AlgorithmSuite.Suite.Keys - modifies is - ensures is.Valid() - ensures - match ret - case Left(headerAuthentication) => - && ValidHeaderAuthentication(headerAuthentication, body.algorithmSuiteID) - && ValidHeaderBody(body) - case Right(_) => true - { - reveal ReprAAD(); - var iv: array; - { - var res := deserializeUnrestricted(is, body.ivLength as nat); - match res { - case Left(bytes) => iv := bytes; - case Right(e) => return Right(e); - } - } - - var authenticationTag: array; - { - var res := deserializeAuthenticationTag(is, AlgorithmSuite.Suite[body.algorithmSuiteID].params.tagLen as nat, iv); - match res { - case Left(bytes) => authenticationTag := bytes; - case Right(e) => return Right(e); - } - } - ret := Left(HeaderAuthentication(iv, authenticationTag)); - } -} diff --git a/src/SDK/MessageHeader/MessageHeader.dfy b/src/SDK/MessageHeader/MessageHeader.dfy deleted file mode 100644 index 6abccc23c..000000000 --- a/src/SDK/MessageHeader/MessageHeader.dfy +++ /dev/null @@ -1,81 +0,0 @@ -include "Definitions.dfy" -include "Deserialize.dfy" -include "Serialize.dfy" -include "Validity.dfy" - -include "../AlgorithmSuite.dfy" -include "../../Util/Streams.dfy" -include "../../StandardLibrary/StandardLibrary.dfy" - -module MessageHeader { - import AlgorithmSuite - import opened StandardLibrary - import opened Streams - - /* - * Definition of the message header, i.e., the header body and the header authentication - */ - class Header { - var body: Option - var auth: Option - - constructor () { - body := None; - auth := None; - } - - method deserializeHeader(is: StringReader) - requires is.Valid() - modifies is, `body, `auth - requires body.None? || auth.None? - ensures body.Some? && auth.Some? ==> Validity.ValidHeaderBody(body.get) - ensures body.Some? && auth.Some? ==> Validity.ValidHeaderAuthentication(auth.get, body.get.algorithmSuiteID) - // TODO: is this the right decision? - ensures body.Some? <==> auth.Some? - ensures body.None? <==> auth.None? // redundant - ensures is.Valid() - { - { - var res := Deserialize.headerBody(is); - match res { - case Left(body_) => - // How does Dafny know the following assertion holds with Validity.ValidHeaderBody being opaque? - assert body_.algorithmSuiteID in AlgorithmSuite.Suite.Keys; // nfv - var res := Deserialize.headerAuthentication(is, body_); - match res { - case Left(auth_) => - body := Some(body_); - auth := Some(auth_); - reveal Validity.ReprAAD(); - assert Validity.ValidHeaderBody(body.get); - case Right(e) => { - print "Could not deserialize message header: " + e.msg + "\n"; - body := None; - auth := None; - return; - } - } - case Right(e) => { - print "Could not deserialize message header: " + e.msg + "\n"; - body := None; - auth := None; - return; - } - } - } - } - - method serializeHeader(os: StringWriter) returns (ret: Either) - requires os.Valid() - requires body.Some? - requires os.Repr !! Validity.ReprAAD(body.get.aad) - requires os.Repr !! Validity.ReprEncryptedDataKeys(body.get.encryptedDataKeys) - requires Validity.ReprAAD(body.get.aad) !! Validity.ReprEncryptedDataKeys(body.get.encryptedDataKeys) - requires Validity.ValidHeaderBody(body.get) - modifies os.Repr - ensures os.Valid() - { - ret := Serialize.headerBody(os, body.get); - } - } -} diff --git a/src/SDK/MessageHeader/Serialize.dfy b/src/SDK/MessageHeader/Serialize.dfy deleted file mode 100644 index 1e50283e7..000000000 --- a/src/SDK/MessageHeader/Serialize.dfy +++ /dev/null @@ -1,185 +0,0 @@ -// This file verified before removing length/count from the AAD/EDK datatypes -// with Dafny flag -arith:5 and when commenting out the HeapSucc transitivity axiom in the DafnyPrelude.bpl. -// Turned the failing assertion into an assume for now. - -include "Definitions.dfy" -include "SerializeAAD.dfy" -include "SerializeEDK.dfy" -include "Validity.dfy" - -include "../../Util/Streams.dfy" -include "../../StandardLibrary/StandardLibrary.dfy" - -module MessageHeader.Serialize { - import opened Definitions - import opened SerializeAAD - import opened SerializeEDK - import opened Validity - - import opened Streams - import opened StandardLibrary - import opened UInt = StandardLibrary.UInt - - lemma {:axiom} Assume(b : bool) ensures b - - function {:opaque} serialize(hb: HeaderBody): seq - requires ValidHeaderBody(hb) - requires ReprAAD(hb.aad) !! ReprEncryptedDataKeys(hb.encryptedDataKeys) - reads if hb.aad.AAD? then {hb.aad.kvPairs} else {} - reads ReprAAD(hb.aad) - reads ReprEncryptedDataKeys(hb.encryptedDataKeys) - { - reveal ValidHeaderBody(); - [hb.version as uint8] + - [hb.typ as uint8] + - uint16ToSeq(hb.algorithmSuiteID as uint16) + - hb.messageID + - serializeAAD(hb.aad) + - serializeEDK(hb.encryptedDataKeys) + - (match hb.contentType case NonFramed => [0x01] case Framed => [0x02]) + - hb.reserved + - [hb.ivLength] + - uint32ToSeq(hb.frameLength) - } - - method headerBody(os: StringWriter, hb: HeaderBody) returns (ret: Either) - requires os.Valid() - modifies os`data - requires ValidHeaderBody(hb) - // It's crucial to require no aliasing here - requires os.Repr !! ReprAAD(hb.aad) - requires os.Repr !! ReprEncryptedDataKeys(hb.encryptedDataKeys) - requires ReprAAD(hb.aad) !! ReprEncryptedDataKeys(hb.encryptedDataKeys) - ensures os.Valid() - // TODO: these should probably be enabled - //ensures unchanged(os`Repr) - //ensures unchanged(ReprAAD(hb.aad)) - //ensures unchanged(ReprEncryptedDataKeys(hb.encryptedDataKeys)) - //ensures old(|os.data|) <= |os.data| - ensures ValidHeaderBody(hb) - ensures - match ret - case Left(totalWritten) => - var serHb := (reveal serialize(); serialize(hb)); - var initLen := old(|os.data|); - && totalWritten == |serHb| - && initLen+totalWritten == |os.data| - && serHb[..totalWritten] == os.data[initLen..initLen+totalWritten] - case Right(e) => true - { - var totalWritten := 0; - ghost var initLen := |os.data|; - reveal ValidHeaderBody(); - reveal ValidAAD(); - reveal ValidEncryptedDataKeys(); - { - ret := os.WriteSingleByteSimple(hb.version as uint8); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => { - return ret; - } - } - } - - { - ret := os.WriteSingleByteSimple(hb.typ as uint8); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => { - return ret; - } - } - } - - { - var bytes := uint16ToArray(hb.algorithmSuiteID as uint16); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => { - return ret; - } - } - } - - { - ret := os.WriteSimpleSeq(hb.messageID); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => { - return ret; - } - } - } - - { - ret := serializeAADImpl(os, hb.aad); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => { - return ret; - } - } - } - - { - ret := serializeEDKImpl(os, hb.encryptedDataKeys); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - } - - { - var contentType: uint8; - match hb.contentType { - case NonFramed => contentType := 0x01; - case Framed => contentType := 0x02; - } - ret := os.WriteSingleByteSimple(contentType); - match ret { - case Left(len) => - totalWritten := totalWritten + len; - case Right(e) => - return ret; - } - } - - { - ret := os.WriteSimpleSeq(hb.reserved); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - } - - { - ret := os.WriteSingleByteSimple(hb.ivLength); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - } - - { - var bytes := uint32ToArray(hb.frameLength); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - } - //reveal ReprEncryptedDataKeys(); - assert ValidHeaderBody(hb); - reveal serialize(); - ghost var serHb := serialize(hb); - assert totalWritten == |serHb|; - assert initLen+totalWritten == |os.data|; - - // Turned this assertion into an assume. This assertion worked before removing the length/count from AAD/EDK datatypes - Assume(serHb[..totalWritten] == os.data[initLen..initLen+totalWritten]); - - return Left(totalWritten); - } -} diff --git a/src/SDK/MessageHeader/SerializeAAD.dfy b/src/SDK/MessageHeader/SerializeAAD.dfy deleted file mode 100644 index ca4c104c3..000000000 --- a/src/SDK/MessageHeader/SerializeAAD.dfy +++ /dev/null @@ -1,208 +0,0 @@ -include "Definitions.dfy" -include "Validity.dfy" - -include "../../Util/Streams.dfy" -include "../../StandardLibrary/StandardLibrary.dfy" - -module MessageHeader.SerializeAAD { - import opened Definitions - import opened Validity - - import opened Streams - import opened StandardLibrary - import opened UInt = StandardLibrary.UInt - - lemma {:axiom} Assume(b : bool) - ensures b - - function encCtxToSeqRec(kvPairs: EncCtx, i: nat): seq - requires forall i :: 0 <= i < kvPairs.Length ==> kvPairs[i].0.Length <= UINT16_MAX && kvPairs[i].1.Length <= UINT16_MAX - decreases kvPairs.Length - i - reads kvPairs - reads set i | 0 <= i < kvPairs.Length :: kvPairs[i].0 - reads set i | 0 <= i < kvPairs.Length :: kvPairs[i].1 - // reads ReprAADUpTo(kvPairs, kvPairs.Length) - { - if i < kvPairs.Length - then - uint16ToSeq(kvPairs[i].0.Length as uint16) + kvPairs[i].0[..] + - uint16ToSeq(kvPairs[i].1.Length as uint16) + kvPairs[i].1[..] + - encCtxToSeqRec(kvPairs, i + 1) - else [] - } - - function encCtxToSeq(kvPairs: EncCtx): (ret: seq) - requires forall i :: 0 <= i < kvPairs.Length ==> kvPairs[i].0.Length <= UINT16_MAX && kvPairs[i].1.Length <= UINT16_MAX - reads kvPairs - reads set i | 0 <= i < kvPairs.Length :: kvPairs[i].0 - reads set i | 0 <= i < kvPairs.Length :: kvPairs[i].1 - // reads ReprAADUpTo(kvPairs, kvPairs.Length) - ensures |ret| <= UINT16_MAX // TODO: we need to establish that this length fits into two bytes - { - encCtxToSeqRec(kvPairs, 0) - } - - function serializeAAD(aad: T_AAD): seq - requires ValidAAD(aad) - reads ReprAAD(aad) - { - match aad { - case AAD(kvPairs) => - var encCtxSeq := encCtxToSeq(kvPairs); - uint16ToSeq(|encCtxSeq| as uint16) + - // It would be nicer if this could be a flatten + map as for AAD - uint16ToSeq(kvPairs.Length as uint16) + encCtxSeq - case EmptyAAD() => - uint16ToSeq(0) - } - } - - method serializeAADImpl(os: StringWriter, aad: T_AAD) returns (ret: Either) - requires os.Valid() - modifies os`data // do we need to establish non-aliasing with encryptedDataKeys here? - ensures os.Valid() - requires ValidAAD(aad) - requires os.Repr !! ReprAAD(aad) - ensures ValidAAD(aad) - ensures unchanged(os`Repr) - ensures unchanged(ReprAAD(aad)) - //ensures old(|os.data|) <= |os.data| - ensures - match ret - case Left(totalWritten) => - var serAAD := serializeAAD(aad); - var initLen := old(|os.data|); - && totalWritten == |serAAD| - && initLen+totalWritten == |os.data| - && os.data == old(os.data + serAAD) - case Right(e) => true - { - var totalWritten := 0; - ghost var initLen := |os.data|; - ghost var written: seq := [initLen]; - ghost var i := 0; - - match aad { - case AAD(kvPairs) => { - { - // Key Value Pairs Length (number of bytes of total AAD) - var length: uint16; - assert InBoundsKVPairs(kvPairs) ==> kvPairs.Length <= UINT16_MAX; - // TODO: We need to compute length here after removing length field from AAD datatype - Assume(length == |encCtxToSeq(kvPairs)| as uint16); - var bytes := uint16ToArray(length); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - i := i + 1; - written := written + [initLen + totalWritten]; - assert written[i] - written[i-1] == bytes.Length; - assert written[i-1] <= written[i] <= |os.data| ==> os.data[written[i-1]..written[i]] == bytes[..]; - assert totalWritten <= |serializeAAD(aad)|; - } - - assert 0 < kvPairs.Length; - { - assert InBoundsKVPairs(kvPairs) ==> kvPairs.Length <= UINT16_MAX; - // Key Value Pair Count (number of key value pairs) - var bytes := uint16ToArray(kvPairs.Length as uint16); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - i := i + 1; - written := written + [initLen + totalWritten]; - assert written[i] - written[i-1] == bytes.Length; - assert written[i-1] <= written[i] <= |os.data| ==> os.data[written[i-1]..written[i]] == bytes[..]; - assert totalWritten <= |serializeAAD(aad)|; - } - - Assume(false); // TODO: verification times out after this point. I believe that we just do too many heap updates. - - var j := 0; - while j < kvPairs.Length - invariant j <= kvPairs.Length - invariant os.Repr !! ReprAAD(aad) - invariant unchanged(os`Repr) - invariant InBoundsKVPairsUpTo(kvPairs, j) - invariant ValidAAD(aad) - invariant totalWritten <= |serializeAAD(aad)| - invariant initLen+totalWritten <= |os.data| - invariant serializeAAD(aad)[written[i-1]-initLen..written[i]-initLen] == os.data[written[i-1]..written[i]]; - //invariant serializeAAD(aad)[..totalWritten] == os.data[initLen..written[i]]; - { - { - assert InBoundsKVPairsUpTo(kvPairs, j) ==> kvPairs[j].0.Length <= UINT16_MAX; - var bytes := uint16ToArray(kvPairs[j].0.Length as uint16); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - i := i + 1; - written := written + [initLen + totalWritten]; - assert written[i] - written[i-1] == bytes.Length; - assert written[i-1] <= written[i] <= |os.data| ==> os.data[written[i-1]..written[i]] == bytes[..]; - } - - { - var bytes := kvPairs[j].0; - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - i := i + 1; - written := written + [initLen + totalWritten]; - assert written[i] - written[i-1] == bytes.Length; - assert written[i-1] <= written[i] <= |os.data| ==> os.data[written[i-1]..written[i]] == bytes[..]; - } - - { - assert InBoundsKVPairsUpTo(kvPairs, j) ==> kvPairs[j].1.Length <= UINT16_MAX; - var bytes := uint16ToArray(kvPairs[j].1.Length as uint16); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - i := i + 1; - written := written + [initLen + totalWritten]; - assert written[i] - written[i-1] == bytes.Length; - assert written[i-1] <= written[i] <= |os.data| ==> os.data[written[i-1]..written[i]] == bytes[..]; - } - - { - var bytes := kvPairs[j].1; - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - i := i + 1; - written := written + [initLen + totalWritten]; - assert written[i] - written[i-1] == bytes.Length; - assert written[i-1] <= written[i] <= |os.data| ==> os.data[written[i-1]..written[i]] == bytes[..]; - } - j := j + 1; - } - } - case EmptyAAD() => { - var bytes := uint16ToArray(0); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - i := i + 1; - written := written + [initLen + totalWritten]; - assert written[i] - written[i-1] == 2; - assert written[i-1] <= written[i] <= |os.data| ==> os.data[written[i-1]..written[i]] == bytes[..]; - - } - } - } -} diff --git a/src/SDK/MessageHeader/SerializeEDK.dfy b/src/SDK/MessageHeader/SerializeEDK.dfy deleted file mode 100644 index 322c61f6d..000000000 --- a/src/SDK/MessageHeader/SerializeEDK.dfy +++ /dev/null @@ -1,180 +0,0 @@ -include "Definitions.dfy" -include "Validity.dfy" - -include "../../Util/Streams.dfy" -include "../../StandardLibrary/StandardLibrary.dfy" - -module MessageHeader.SerializeEDK { - import opened Definitions - import opened Validity - - import opened Streams - import opened StandardLibrary - import opened UInt = StandardLibrary.UInt - - lemma {:axiom} Assume(b: bool) ensures b - - // Alternative w/o flatten/map - function serializeEDKEntries(entries: seq): seq - requires forall i :: 0 <= i < |entries| ==> - && entries[i].keyProviderId.Length <= UINT16_MAX - && entries[i].keyProviderInfo.Length <= UINT16_MAX - && entries[i].encDataKey.Length <= UINT16_MAX - reads ReprEncryptedDataKeysUpTo(entries, |entries|) - { - if entries == [] - then [] - else - var entry := entries[0]; - uint16ToSeq(entry.keyProviderId.Length as uint16) + entry.keyProviderId[..] + - uint16ToSeq(entry.keyProviderInfo.Length as uint16) + entry.keyProviderInfo[..] + - uint16ToSeq(entry.encDataKey.Length as uint16) + entry.encDataKey[..] + - serializeEDKEntries(entries[1..]) - } - - function serializeEDK(encryptedDataKeys: T_EncryptedDataKeys): seq - requires ValidEncryptedDataKeys(encryptedDataKeys) - reads ReprEncryptedDataKeys(encryptedDataKeys) - { - uint16ToSeq(encryptedDataKeys.entries.Length as uint16) + - serializeEDKEntries(encryptedDataKeys.entries[..]) - } - - method serializeEDKImpl(os: StringWriter, encryptedDataKeys: T_EncryptedDataKeys) returns (ret: Either) - requires os.Valid() - modifies os`data - ensures os.Valid() - requires ValidEncryptedDataKeys(encryptedDataKeys) - requires os.Repr !! ReprEncryptedDataKeys(encryptedDataKeys) - ensures ValidEncryptedDataKeys(encryptedDataKeys) - ensures unchanged(os`Repr) - ensures unchanged(ReprEncryptedDataKeys(encryptedDataKeys)) - //ensures old(|os.data|) <= |os.data| - ensures - match ret - case Left(totalWritten) => - var serEDK := serializeEDK(encryptedDataKeys); - var initLen := old(|os.data|); - && totalWritten == |serEDK| - && initLen+totalWritten == |os.data| - && os.data == old(os.data + serEDK) - case Right(e) => true - { - var totalWritten: nat := 0; - ghost var initLen := |os.data|; - ghost var prevPos: nat := initLen; - ghost var currPos: nat := initLen; - - { - var bytes := uint16ToArray(encryptedDataKeys.entries.Length as uint16); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - prevPos := currPos; - currPos := initLen + totalWritten; - assert currPos - prevPos == bytes.Length; - assert prevPos <= currPos <= |os.data| ==> os.data[prevPos..currPos] == bytes[..]; - assert totalWritten <= |serializeEDK(encryptedDataKeys)|; - } - //assume false; - - var j := 0; - ghost var written: seq := [currPos, currPos]; - while j < encryptedDataKeys.entries.Length - invariant j <= encryptedDataKeys.entries.Length - invariant j < |written| - invariant os.Repr !! ReprEncryptedDataKeys(encryptedDataKeys) - invariant unchanged(os`Repr) - invariant InBoundsEncryptedDataKeysUpTo(encryptedDataKeys.entries, j) - invariant ValidEncryptedDataKeys(encryptedDataKeys) - invariant initLen + totalWritten == currPos - invariant 0 <= initLen <= prevPos <= currPos <= |os.data| - invariant currPos-initLen <= |serializeEDK(encryptedDataKeys)| - invariant serializeEDK(encryptedDataKeys)[..currPos-initLen] == os.data[initLen..currPos] - invariant serializeEDK(encryptedDataKeys)[prevPos-initLen..currPos-initLen] == os.data[prevPos..currPos]; - invariant 1 <= j ==> os.data[initLen..written[j]] == os.data[initLen..written[j-1]] + serializeEDKEntries(encryptedDataKeys.entries[..j]) - { - var entry := encryptedDataKeys.entries[j]; - { - var bytes := uint16ToArray(entry.keyProviderId.Length as uint16); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - prevPos := currPos; - currPos := initLen + totalWritten; - assert prevPos <= currPos <= |os.data| ==> os.data[prevPos..currPos] == bytes[..]; - } - - { - var bytes := entry.keyProviderId; - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - prevPos := currPos; - currPos := initLen + totalWritten; - assert currPos - prevPos == bytes.Length; - assert prevPos <= currPos <= |os.data| ==> os.data[prevPos..currPos] == bytes[..]; - } - - { - var bytes := uint16ToArray(entry.keyProviderInfo.Length as uint16); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - prevPos := currPos; - currPos := initLen + totalWritten; - assert currPos - prevPos == bytes.Length; - assert prevPos <= currPos <= |os.data| ==> os.data[prevPos..currPos] == bytes[..]; - } - - { - var bytes := entry.keyProviderInfo; - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - prevPos := currPos; - currPos := initLen + totalWritten; - assert currPos - prevPos == bytes.Length; - assert prevPos <= currPos <= |os.data| ==> os.data[prevPos..currPos] == bytes[..]; - } - - { - var bytes := uint16ToArray(entry.encDataKey.Length as uint16); - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - prevPos := currPos; - currPos := initLen + totalWritten; - assert currPos - prevPos == bytes.Length; - assert prevPos <= currPos <= |os.data| ==> os.data[prevPos..currPos] == bytes[..]; - } - - { - var bytes := entry.encDataKey; - ret := os.WriteSimple(bytes); - match ret { - case Left(len) => totalWritten := totalWritten + len; - case Right(e) => return ret; - } - prevPos := currPos; - currPos := initLen + totalWritten; - assert currPos - prevPos == bytes.Length; - assert prevPos <= currPos <= |os.data| ==> os.data[prevPos..currPos] == bytes[..]; - } - written := written + [currPos]; - j := j + 1; - } - } -} diff --git a/src/SDK/MessageHeader/Utils.dfy b/src/SDK/MessageHeader/Utils.dfy deleted file mode 100644 index 9d33248b4..000000000 --- a/src/SDK/MessageHeader/Utils.dfy +++ /dev/null @@ -1,126 +0,0 @@ -include "../AlgorithmSuite.dfy" -include "../../Util/Streams.dfy" -include "../../StandardLibrary/StandardLibrary.dfy" -include "../../Util/UTF8.dfy" - -module MessageHeader.Utils { - import AlgorithmSuite - import opened Streams - import opened StandardLibrary - import opened UInt = StandardLibrary.UInt - import opened UTF8 - /* - * Utils - */ - method readFixedLengthFromStreamOrFail(is: StringReader, n: nat) returns (ret: Either, Error>) - requires is.Valid() - modifies is - ensures - match ret - case Left(bytes) => - && n == bytes.Length - && fresh(bytes) - case Right(_) => true - ensures is.Valid() - { - var bytes := new uint8[n]; - var out: Either; - out := is.Read(bytes, 0, n); - match out { - case Left(bytesRead) => - if bytesRead != n { - return Right(IOError("Not enough bytes left on stream.")); - } else { - return Left(bytes); - } - case Right(e) => return Right(e); - } - } - /* - // This is like StringWriter.WriteSimple - // TODO: this is broken somehow - method WriteFixedLengthToStreamOrFail(os: StringWriter, bytes: array) returns (ret: Result) - requires os.Valid() - requires bytes !in os.Repr - modifies os.Repr - ensures old(os.Repr) == os.Repr - ensures bytes !in os.Repr - ensures os.Valid() - ensures - match ret - case Left(len_written) => - && len_written == bytes.Length - && os.pos == old(os.pos) + len_written - && old(os.pos + len_written <= os.data.Length) - && os.data[..] == old(os.data[..os.pos]) + bytes[..] + old(os.data[os.pos + len_written..]) - case Right(e) => true - { - ghost var oldPos := os.pos; - ghost var oldData := os.data; - var oldCap := os.capacity(); - ret := os.Write(bytes, 0, bytes.Length); - match ret { - case Left(len) => - if oldCap >= bytes.Length > 0 { - //if len == bytes.Length { - // assert len == bytes.Length; - assert oldPos + len <= oldData.Length; - - assert os.data[..oldPos] == oldData[..oldPos]; - assert os.data[oldPos..oldPos+len] == bytes[..]; - assert os.data[oldPos+len..] == oldData[oldPos + len..]; - - assert os.data[..] == oldData[..oldPos] + bytes[..] + oldData[oldPos + len..]; - assert os.pos == oldPos + len; - return Left(len); - } else { - return Right(SerializationError("Reached end of stream.")); - } - case Right(e) => return ret; - } - } - */ - /* - // This is like StringWriter.WriteSingleByteSimple - method WriteSingleByteToStreamOrFail(os: StringWriter, byte: uint8) returns (ret: Result) - requires os.Valid() - modifies os.Repr - ensures old(os.Repr) == os.Repr - ensures os.Valid() - ensures - match ret - case Left(len_written) => - && len_written == 1 - && old(os.pos + len_written <= os.data.Length) - && os.data[..] == old(os.data[..os.pos]) + [byte] + old(os.data[os.pos + len_written..]) - && os.pos == old(os.pos) + len_written - case Right(e) => true - { - ret := os.WriteSingleByte(byte); - match ret { - case Left(len) => - if len == 1 { - return Left(len); - } else { - return Right(SerializationError("Reached end of stream.")); - } - case Right(e) => return ret; - } - } - */ - - predicate SortedKVPairsUpTo(a: array<(array, array)>, n: nat) - requires n <= a.Length - reads a - reads set i | 0 <= i < n :: a[i].0 - { - forall j :: 0 < j < n ==> lexCmpArrays(a[j-1].0, a[j].0, ltByte) - } - - predicate SortedKVPairs(a: array<(array, array)>) - reads a - reads set i | 0 <= i < a.Length :: a[i].0 - { - SortedKVPairsUpTo(a, a.Length) - } -} diff --git a/src/SDK/MessageHeader/Validity.dfy b/src/SDK/MessageHeader/Validity.dfy deleted file mode 100644 index 27d47894f..000000000 --- a/src/SDK/MessageHeader/Validity.dfy +++ /dev/null @@ -1,181 +0,0 @@ -include "Definitions.dfy" -include "Utils.dfy" - -include "../AlgorithmSuite.dfy" -include "../../StandardLibrary/StandardLibrary.dfy" -include "../../Util/UTF8.dfy" - -module MessageHeader.Validity { - import opened Definitions - import opened Utils - - import AlgorithmSuite - import opened StandardLibrary - import opened UInt = StandardLibrary.UInt - import opened UTF8 - /* - * Validity of the message header - * The validity depends on predicates and on the types of the fields - */ - predicate {:opaque} ValidHeaderBody(hb: HeaderBody) - reads (reveal ReprAAD(); ReprAAD(hb.aad)) - reads ReprEncryptedDataKeys(hb.encryptedDataKeys) - { - && ValidAlgorithmID(hb.algorithmSuiteID) - && ValidMessageId(hb.messageID) - && ValidAAD(hb.aad) - && ValidEncryptedDataKeys(hb.encryptedDataKeys) - && ValidIVLength(hb.ivLength, hb.algorithmSuiteID) - && ValidFrameLength(hb.frameLength, hb.contentType) - } - - // TODO: strengthen spec when available - predicate uniquelyIdentifiesMessage(id: T_MessageID) { true } - predicate weaklyBindsHeaderToHeaderBody(id: T_MessageID) { true } - predicate enablesSecureReuse(id: T_MessageID) { true } - predicate protectsAgainstAccidentalReuse(id: T_MessageID) { true } - predicate protectsAgainstWearingOut(id: T_MessageID) { true } - predicate ValidMessageId(id: T_MessageID) { - && uniquelyIdentifiesMessage(id) - && weaklyBindsHeaderToHeaderBody(id) - && enablesSecureReuse(id) - && protectsAgainstAccidentalReuse(id) - && protectsAgainstWearingOut(id) - } - predicate ValidAlgorithmID(algorithmSuiteID: AlgorithmSuite.ID) { - algorithmSuiteID in AlgorithmSuite.Suite.Keys - } - - function ReprAADUpTo(kvPairs: EncCtx, j: nat): set - requires j <= kvPairs.Length - reads kvPairs - { - (set i | 0 <= i < j :: kvPairs[i].0) + - (set i | 0 <= i < j :: kvPairs[i].1) - } - - function {:opaque} ReprAAD(aad: T_AAD): set - reads if aad.AAD? then {aad.kvPairs} else {} - { - match aad { - // Not extracting the fields of AAD here for now, but using `aad.` due to https://github.com/dafny-lang/dafny/issues/314 - case AAD(_) => - ReprAADUpTo(aad.kvPairs, aad.kvPairs.Length) + - {aad.kvPairs} - case EmptyAAD() => {} - } - } - - predicate InBoundsKVPairsUpTo(kvPairs: EncCtx, j: nat) - requires j <= kvPairs.Length - reads kvPairs - { - forall i :: 0 <= i < j ==> - && kvPairs[i].0.Length <= UINT16_MAX - && kvPairs[i].1.Length <= UINT16_MAX - } - - predicate InBoundsKVPairs(kvPairs: EncCtx) - reads kvPairs - { - && kvPairs.Length <= UINT16_MAX - && InBoundsKVPairsUpTo(kvPairs, kvPairs.Length) - } - - predicate ValidKVPairs(kvPairs: EncCtx) - reads kvPairs - reads set i | 0 <= i < kvPairs.Length :: kvPairs[i].0 - reads set i | 0 <= i < kvPairs.Length :: kvPairs[i].1 - { - forall i :: 0 <= i < kvPairs.Length ==> ValidUTF8(kvPairs[i].0) && ValidUTF8(kvPairs[i].1) - } - - predicate {:opaque} ValidAAD(aad: T_AAD) - reads (reveal ReprAAD(); ReprAAD(aad)) - { - match aad { - case AAD(kvPairs) => - && 0 < kvPairs.Length - && InBoundsKVPairs(kvPairs) - && ValidKVPairs(kvPairs) - && SortedKVPairs(kvPairs) - case EmptyAAD() => true - } - } - - function ReprEncryptedDataKeysUpTo(entries: seq, j: nat): set - requires j <= |entries| - { - (set i | 0 <= i < j :: entries[i].keyProviderId) + - (set i | 0 <= i < j :: entries[i].keyProviderInfo) + - (set i | 0 <= i < j :: entries[i].encDataKey) - } - - function ReprEncryptedDataKeys(encryptedDataKeys: T_EncryptedDataKeys): set - reads encryptedDataKeys.entries - { - ReprEncryptedDataKeysUpTo(encryptedDataKeys.entries[..], encryptedDataKeys.entries.Length) + - {encryptedDataKeys.entries} - } - - predicate InBoundsEncryptedDataKeysUpTo(entries: array, j: nat) - requires j <= entries.Length - reads entries - reads ReprEncryptedDataKeysUpTo(entries[..], entries.Length) - { - forall i :: 0 <= i < j ==> - && entries[i].keyProviderId.Length <= UINT16_MAX - && entries[i].keyProviderInfo.Length <= UINT16_MAX - && entries[i].encDataKey.Length <= UINT16_MAX - } - - predicate InBoundsEncryptedDataKeys(encryptedDataKeys: T_EncryptedDataKeys) - reads ReprEncryptedDataKeys(encryptedDataKeys) - { - InBoundsEncryptedDataKeysUpTo(encryptedDataKeys.entries, encryptedDataKeys.entries.Length) - } - - predicate {:opaque} ValidEncryptedDataKeys(encryptedDataKeys: T_EncryptedDataKeys) - reads ReprEncryptedDataKeys(encryptedDataKeys) - { - && InBoundsEncryptedDataKeys(encryptedDataKeys) - && forall i :: 0 <= i < encryptedDataKeys.entries.Length ==> ValidUTF8(encryptedDataKeys.entries[i].keyProviderId) - // TODO: well-formedness of EDK - /* - Key Provider ID - The key provider identifier. The value of this field indicates the provider of the encrypted data key. See Key Provider for details on supported key providers. - - Key Provider Information - The key provider information. The key provider for this encrypted data key determines what this field contains. - - Encrypted Data Key - The encrypted data key. It is the data key encrypted by the key provider. - */ - } - predicate ValidIVLength(ivLength: uint8, algorithmSuiteID: AlgorithmSuite.ID) - { - algorithmSuiteID in AlgorithmSuite.Suite.Keys && AlgorithmSuite.Suite[algorithmSuiteID].params.ivLen == ivLength - } - predicate ValidFrameLength(frameLength: uint32, contentType: T_ContentType) - { - match contentType { - case NonFramed => frameLength == 0 - case Framed => true - } - } - - /* - * Validity of the message header authentication - */ - predicate ValidAuthenticationTag(authenticationTag: array, tagLength: nat, iv: array) - { - true - // TODO: strengthen the spec - // The algorithm suite specified by the Algorithm Suite ID field determines how the value of this field is calculated, and uses this value to authenticate the contents of the header during decryption. - } - predicate ValidHeaderAuthentication(ha: HeaderAuthentication, algorithmSuiteID: AlgorithmSuite.ID) - requires algorithmSuiteID in AlgorithmSuite.Suite.Keys - { - ValidAuthenticationTag(ha.authenticationTag, AlgorithmSuite.Suite[algorithmSuiteID].params.tagLen as nat, ha.iv) - } -} diff --git a/src/SDK/Serialize.dfy b/src/SDK/Serialize.dfy new file mode 100644 index 000000000..113e8a4a9 --- /dev/null +++ b/src/SDK/Serialize.dfy @@ -0,0 +1,227 @@ +include "MessageHeader.dfy" +include "Materials.dfy" +include "AlgorithmSuite.dfy" + +include "../Util/Streams.dfy" +include "../StandardLibrary/StandardLibrary.dfy" + +module Serialize { + import Msg = MessageHeader + import AlgorithmSuite + + import Streams + import opened StandardLibrary + import opened UInt = StandardLibrary.UInt + import Materials + + method SerializeHeaderBody(wr: Streams.StringWriter, hb: Msg.HeaderBody) returns (ret: Result) + requires wr.Valid() && hb.Valid() + modifies wr`data + ensures wr.Valid() + ensures match ret + case Success(totalWritten) => + var serHb := (reveal Msg.HeaderBodyToSeq(); Msg.HeaderBodyToSeq(hb)); + var initLen := old(|wr.data|); + && totalWritten == |serHb| + && initLen + totalWritten == |wr.data| + && serHb == wr.data[initLen..initLen + totalWritten] + case Failure(e) => true + { + var totalWritten := 0; + + var len :- wr.WriteByte(hb.version as uint8); + totalWritten := totalWritten + len; + + len :- wr.WriteByte(hb.typ as uint8); + totalWritten := totalWritten + len; + + len :- wr.WriteUInt16(hb.algorithmSuiteID as uint16); + totalWritten := totalWritten + len; + + len :- wr.WriteSeq(hb.messageID); + totalWritten := totalWritten + len; + + len :- SerializeAAD(wr, hb.aad); + totalWritten := totalWritten + len; + + len :- SerializeEDKs(wr, hb.encryptedDataKeys); + totalWritten := totalWritten + len; + + var contentType := Msg.ContentTypeToUInt8(hb.contentType); + len :- wr.WriteByte(contentType); + totalWritten := totalWritten + len; + + len :- wr.WriteSeq(hb.reserved); + totalWritten := totalWritten + len; + + len :- wr.WriteByte(hb.ivLength); + totalWritten := totalWritten + len; + + len :- wr.WriteUInt32(hb.frameLength); + totalWritten := totalWritten + len; + + reveal Msg.HeaderBodyToSeq(); + return Success(totalWritten); + } + + method SerializeHeaderAuthentication(wr: Streams.StringWriter, ha: Msg.HeaderAuthentication, ghost algorithmSuiteID: AlgorithmSuite.ID) returns (ret: Result) + requires wr.Valid() + modifies wr`data + ensures wr.Valid() + ensures match ret + case Success(totalWritten) => + var serHa := ha.iv + ha.authenticationTag; + var initLen := old(|wr.data|); + && totalWritten == |serHa| + && initLen + totalWritten == |wr.data| + && serHa == wr.data[initLen..initLen + totalWritten] + case Failure(e) => true + { + var m :- wr.WriteSeq(ha.iv); + var n :- wr.WriteSeq(ha.authenticationTag); + return Success(m + n); + } + + // ----- SerializeAAD ----- + + method SerializeAAD(wr: Streams.StringWriter, kvPairs: Materials.EncryptionContext) returns (ret: Result) + requires wr.Valid() && Msg.ValidAAD(kvPairs) + modifies wr`data + ensures wr.Valid() && Msg.ValidAAD(kvPairs) + ensures match ret + case Success(totalWritten) => + var serAAD := Msg.AADToSeq(kvPairs); + var initLen := old(|wr.data|); + && totalWritten == |serAAD| + && initLen + totalWritten == |wr.data| + && wr.data == old(wr.data) + serAAD + case Failure(e) => true + { + reveal Msg.ValidAAD(); + var totalWritten := 0; + + // Key Value Pairs Length (number of bytes of total AAD) + var aadLength :- ComputeAADLength(kvPairs); + var len :- wr.WriteUInt16(aadLength); + totalWritten := totalWritten + len; + if aadLength == 0 { + return Success(totalWritten); + } + + len :- wr.WriteUInt16(|kvPairs| as uint16); + totalWritten := totalWritten + len; + + var j := 0; + ghost var n := |kvPairs|; + while j < |kvPairs| + invariant j <= n == |kvPairs| + invariant wr.data == + old(wr.data) + + UInt16ToSeq(aadLength) + + UInt16ToSeq(n as uint16) + + Msg.KVPairsToSeq(kvPairs, 0, j) + invariant totalWritten == 4 + |Msg.KVPairsToSeq(kvPairs, 0, j)| + { + len :- wr.WriteUInt16(|kvPairs[j].0| as uint16); + totalWritten := totalWritten + len; + + len :- wr.WriteSeq(kvPairs[j].0); + totalWritten := totalWritten + len; + + len :- wr.WriteUInt16(|kvPairs[j].1| as uint16); + totalWritten := totalWritten + len; + + len :- wr.WriteSeq(kvPairs[j].1); + totalWritten := totalWritten + len; + + j := j + 1; + } + + return Success(totalWritten); + } + + method ComputeAADLength(kvPairs: Materials.EncryptionContext) returns (res: Result) + requires |kvPairs| < UINT16_LIMIT + requires forall i :: 0 <= i < |kvPairs| ==> Msg.ValidKVPair(kvPairs[i]) + ensures match res + case Success(len) => len as int == Msg.AADLength(kvPairs) + case Failure(_) => UINT16_LIMIT <= Msg.AADLength(kvPairs) + { + var n: int32 := |kvPairs| as int32; + if n == 0 { + return Success(0); + } else { + var len: int32, limit: int32 := 2, UINT16_LIMIT as int32; + var i: int32 := 0; + while i < n + invariant i <= n + invariant 2 + Msg.KVPairsLength(kvPairs, 0, i as int) == len as int < UINT16_LIMIT + { + var kvPair := kvPairs[i]; + len := len + 4 + |kvPair.0| as int32 + |kvPair.1| as int32; + Msg.KVPairsLengthSplit(kvPairs, 0, i as int + 1, n as int); + if limit <= len { + return Failure("The number of bytes in encryption context exceeds the allowed maximum"); + } + i := i + 1; + } + return Success(len as uint16); + } + } + + // ----- SerializeEDKs ----- + + method SerializeEDKs(wr: Streams.StringWriter, encryptedDataKeys: Msg.EncryptedDataKeys) returns (ret: Result) + requires wr.Valid() && encryptedDataKeys.Valid() + modifies wr`data + ensures wr.Valid() && encryptedDataKeys.Valid() + ensures match ret + case Success(totalWritten) => + var serEDK := Msg.EDKsToSeq(encryptedDataKeys); + var initLen := old(|wr.data|); + && totalWritten == |serEDK| + && initLen + totalWritten == |wr.data| + && wr.data == old(wr.data) + serEDK + case Failure(e) => true + { + var totalWritten := 0; + + var len :- wr.WriteUInt16(|encryptedDataKeys.entries| as uint16); + totalWritten := totalWritten + len; + + var j := 0; + ghost var n := |encryptedDataKeys.entries|; + while j < |encryptedDataKeys.entries| + invariant j <= n == |encryptedDataKeys.entries| + invariant wr.data == + old(wr.data) + + UInt16ToSeq(n as uint16) + + Msg.EDKEntriesToSeq(encryptedDataKeys.entries, 0, j); + invariant totalWritten == 2 + |Msg.EDKEntriesToSeq(encryptedDataKeys.entries, 0, j)| + { + var entry := encryptedDataKeys.entries[j]; + + len :- wr.WriteUInt16(|entry.providerID| as uint16); + totalWritten := totalWritten + len; + + len :- wr.WriteSeq(StringToByteSeq(entry.providerID)); + totalWritten := totalWritten + len; + + len :- wr.WriteUInt16(|entry.providerInfo| as uint16); + totalWritten := totalWritten + len; + + len :- wr.WriteSeq(entry.providerInfo); + totalWritten := totalWritten + len; + + len :- wr.WriteUInt16(|entry.ciphertext| as uint16); + totalWritten := totalWritten + len; + + len :- wr.WriteSeq(entry.ciphertext); + totalWritten := totalWritten + len; + + j := j + 1; + } + + return Success(totalWritten); + } +} diff --git a/src/StandardLibrary/StandardLibrary.dfy b/src/StandardLibrary/StandardLibrary.dfy index 79a537d57..83b16371b 100644 --- a/src/StandardLibrary/StandardLibrary.dfy +++ b/src/StandardLibrary/StandardLibrary.dfy @@ -49,14 +49,14 @@ module {:extern "STL"} StandardLibrary { } predicate StringIs8Bit(s: string) { - forall i :: i in s ==> i < 256 as char + forall i :: 0 <= i < |s| ==> s[i] < 256 as char } function Fill(value: T, n: nat): seq ensures |Fill(value, n)| == n ensures forall i :: 0 <= i < n ==> Fill(value, n)[i] == value { - if n > 0 then [value] + Fill(value, n-1) else [] + seq(n, _ => value) } method array_of_seq (s : seq) returns (a : array) @@ -68,27 +68,23 @@ module {:extern "STL"} StandardLibrary { } function method {:opaque} StringToByteSeq(s: string): (s': seq) - requires forall i :: i in s ==> i < 256 as char + requires StringIs8Bit(s) ensures |s| == |s'| { - if s == [] then [] else ( - assert (forall i :: i in s[1..] ==> i in s); - [(s[0] as int % 256) as uint8] + StringToByteSeq(s[1..])) + seq(|s|, i requires 0 <= i < |s| => s[i] as uint8) } - function method {:opaque} byteseq_of_string_lossy (s : string) : (s' : seq) + function method {:opaque} StringToByteSeqLossy(s: string): (s': seq) ensures |s| == |s'| { - if s == [] then [] else ( - assert (forall i :: i in s[1..] ==> i in s); - [(s[0] as int % 256) as uint8] + byteseq_of_string_lossy(s[1..])) + seq(|s|, i requires 0 <= i < |s| => (s[i] as uint16 % 256) as uint8) } function method {:opaque} ByteSeqToString(s: seq): (s': string) ensures |s| == |s'| - ensures forall i :: i in s' ==> i < 256 as char + ensures StringIs8Bit(s') { - if s == [] then [] else [(s[0] as char)] + ByteSeqToString(s[1..]) + seq(|s|, i requires 0 <= i < |s| => s[i] as char) } lemma StringByteSeqCorrect(s: string) @@ -120,134 +116,108 @@ module {:extern "STL"} StandardLibrary { } } - datatype gtag = tagged(val : A, ghost tagged : B) - function method val(p : gtag) : A { - match p - case tagged(x,y) => x + class mut { + var x: T + constructor (y: T) + ensures x == y + { + x := y; } - - function tag(p: gtag) : B { - match p - case tagged(x,y) => y + method put(y: T) + modifies this + ensures x == y + { + x := y; + } + function method get(): (y: T) + reads this + ensures y == x + { + x } - - class mut { - var x : T - constructor (y : T) ensures (this.x == y) { - this.x := y; - } - method put(y : T) modifies this ensures (this.x == y) { - this.x := y; - } - function method get() : (y : T) reads this ensures (y == this.x) { - this.x - } - } - - function method odflt (o : Option, x : T) : T { - match o - case Some(a) => a - case None => x - } - - function method isSome (o : Option) : bool { - match o - case Some(_) => true - case None => false - } - - function method isNone (o : Option) : bool { - match o - case Some(_) => false - case None => true } - predicate method uniq(s : seq) { + predicate method uniq(s: seq) { if s == [] then true else if s[0] in s[1..] then false else uniq(s[1..]) } - lemma uniq_idxP(s : seq) - ensures uniq(s) <==> (forall i, j :: 0 <= i < j < |s| ==> s[i] != s[j]) { - + lemma uniq_idxP(s: seq) + ensures uniq(s) <==> (forall i, j :: 0 <= i < j < |s| ==> s[i] != s[j]) + { } // TODO - lemma {:axiom} uniq_multisetP (s : seq) + lemma {:axiom} uniq_multisetP(s: seq) ensures uniq(s) <==> (forall x :: x in s ==> multiset(s)[x] == 1) - function method sum(s : seq, f : T -> int) : int { - if s == [] then 0 else f(s[0]) + sum(s[1..], f) + function method sum(s: seq, f: T -> int): int { + if s == [] then 0 else f(s[0]) + sum(s[1..], f) } - lemma {:axiom} sum_perm (s : seq , s' : seq, f : T -> int) - requires multiset(s) == multiset(s') - ensures sum(s, f) == sum(s', f) + lemma {:axiom} sum_perm(s: seq , s': seq, f: T -> int) + requires multiset(s) == multiset(s') + ensures sum(s, f) == sum(s', f) - function count (s : seq, x : T) : int { + function count(s: seq, x: T): int { if s == [] then 0 else (if s[0] == x then 1 else 0) + count(s[1..], x) } - lemma count_ge0 (s : seq, x : T) - ensures 0 <= count(s, x) { - if s == [] { } - else { - assert count(s, x) == (if x == s[0] then 1 else 0) + count(s[1..], x); - count_ge0(s[1..], x); - } + lemma count_ge0(s: seq, x: T) + ensures 0 <= count(s, x) + { + if s == [] { + } else { + assert count(s, x) == (if x == s[0] then 1 else 0) + count(s[1..], x); + count_ge0(s[1..], x); } + } - lemma count_nil (x : T) - ensures count([], x) == 0 { } - - lemma in_count_gt0P (s : seq, x : T) - ensures (x in s) <==> (count(s, x) > 0) { - if s == [] { } - else { - if s[0] == x { - count_ge0(s[1..], x); - } - else { } - } - } + lemma count_nil(x: T) + ensures count([], x) == 0 + { } - lemma count_idx_gt0P (s : seq, i : int) - requires (0 <= i < |s|) - ensures count(s, s[i]) > 0 { - assert s[i] in s; - in_count_gt0P(s, s[i]); + lemma in_count_gt0P(s: seq, x: T) + ensures (x in s) <==> (count(s, x) > 0) + { + if s != [] && s[0] == x { + count_ge0(s[1..], x); } + } - lemma count_eq0P (s : seq, x : T) - ensures (x !in s) <==> (count(s, x) == 0) { - if s == [] { } - else { - if s[0] == x { - assert s[0] in s; - assert x in s; - count_ge0(s[1..], x); - } - else { } - } + lemma count_idx_gt0P(s: seq, i: int) + requires 0 <= i < |s| + ensures count(s, s[i]) > 0 + { + assert s[i] in s; + in_count_gt0P(s, s[i]); + } + + lemma count_eq0P(s: seq, x: T) + ensures (x !in s) <==> (count(s, x) == 0) + { + if s != [] && s[0] == x { + assert s[0] in s; + assert x in s; + count_ge0(s[1..], x); } + } - lemma uniq_count_le1 (s : seq, x : T) + lemma uniq_count_le1(s: seq, x: T) requires uniq(s) - ensures count(s, x) <= 1 { - if s == [] { } - else { - if s[0] == x { - assert (s[0] !in s[1..]); - count_eq0P(s[1..], x); - } - else { } - } + ensures count(s, x) <= 1 + { + if s != [] && s[0] == x { + assert (s[0] !in s[1..]); + count_eq0P(s[1..], x); } + } - lemma multiset_seq_count (s : seq, x : T) - ensures multiset(s)[x] == count(s, x) { - if s == [] { } - else { + lemma multiset_seq_count(s: seq, x: T) + ensures multiset(s)[x] == count(s, x) + { + if s == [] { + } else { assert s == [s[0]] + s[1..]; assert multiset(s) == multiset{s[0]} + multiset(s[1..]); assert multiset(s)[x] == multiset{s[0]}[x] + multiset(s[1..])[x]; @@ -256,82 +226,79 @@ module {:extern "STL"} StandardLibrary { } // TODO - lemma {:axiom} multiset_seq_eq1 (s : seq) + lemma {:axiom} multiset_seq_eq1(s: seq) requires forall i, j :: 0 <= i < j < |s| ==> s[i] != s[j] ensures forall x :: x in multiset(s) ==> multiset(s)[x] == 1 // TODO - lemma {:axiom} multiset_of_seq_dup (s : seq, i : int, j : int) - requires 0 <= i < j < |s| - requires s[i] == s[j] - ensures multiset(s)[s[i]] > 1 + lemma {:axiom} multiset_of_seq_dup(s: seq, i: int, j: int) + requires 0 <= i < j < |s| + requires s[i] == s[j] + ensures multiset(s)[s[i]] > 1 - lemma multiset_of_seq_gt0P (s : seq, x : T) + lemma multiset_of_seq_gt0P(s: seq, x: T) requires multiset(s)[x] > 0 - ensures exists i :: 0 <= i < |s| && s[i] == x { } + ensures exists i :: 0 <= i < |s| && s[i] == x + { } // TODO - lemma {:axiom} seq_dup_multset (s : seq, x : T) + lemma {:axiom} seq_dup_multset(s: seq, x: T) requires multiset(s)[x] > 1 ensures exists i, j :: 0 <= i < j < |s| && s[i] == s[j] - lemma eq_multiset_seq_memP (s : seq, s' : seq, x : T) + lemma eq_multiset_seq_memP(s: seq, s': seq, x: T) requires multiset(s) == multiset(s') ensures (x in s) == (x in s') - { - if x in s { - assert x in multiset(s); - assert x in multiset(s'); - assert x in s'; - } - else { - assert x !in multiset(s); - assert x !in multiset(s'); - assert x !in s'; - } + { + if x in s { + assert x in multiset(s); + assert x in multiset(s'); + assert x in s'; + } + else { + assert x !in multiset(s); + assert x !in multiset(s'); + assert x !in s'; } + } function method MapSeq(s: seq, f: S ~> T): seq requires forall i :: 0 <= i < |s| ==> f.requires(s[i]) reads set i,o | 0 <= i < |s| && o in f.reads(s[i]) :: o { - if s == [] - then [] - else [f(s[0])] + MapSeq(s[1..], f) + if s == [] then [] else [f(s[0])] + MapSeq(s[1..], f) } function method FlattenSeq(s: seq>): seq { - if s == [] - then [] - else s[0] + FlattenSeq(s[1..]) + if s == [] then [] else s[0] + FlattenSeq(s[1..]) } - predicate method uniq_fst (s : seq<(S, T)>) { - uniq(MapSeq(s, (x : (S, T)) => x.0)) + predicate method uniq_fst(s: seq<(S, T)>) { + uniq(MapSeq(s, (x: (S, T)) => x.0)) } // TODO - lemma {:axiom} uniq_fst_uniq (s : seq<(S,T)>) + lemma {:axiom} uniq_fst_uniq(s: seq<(S,T)>) requires uniq_fst(s) ensures uniq(s) // TODO - lemma {:axiom} uniq_fst_idxP (s : seq<(S, T)>) + lemma {:axiom} uniq_fst_idxP(s: seq<(S, T)>) requires uniq_fst(s) ensures forall i, j :: 0 <= i < j < |s| ==> s[i].0 != s[j].0 function method min(a: int, b: int): int { - if a < b then a else b } + if a < b then a else b + } - method values(m: map) returns (vals: seq) - { + method values(m: map) returns (vals: seq) { var keys := m.Keys; vals := []; while keys != {} - decreases keys invariant |keys| + |vals| == |m.Keys| + decreases keys { var a :| a in keys; keys := keys - {a}; @@ -339,38 +306,88 @@ module {:extern "STL"} StandardLibrary { } } - predicate method ltByte(a: uint8, b: uint8) { a < b } - predicate method ltNat (a: nat, b: nat) { a < b } - predicate method ltInt (a: int, b: int) { a < b } + lemma {:axiom} eq_multiset_eq_len(s: seq, s': seq) + requires multiset(s) == multiset(s') + ensures |s| == |s'| - predicate method lexCmpArrays(a : array, b : array, lt: (T, T) -> bool) - reads a, b + predicate method UInt8Less(a: uint8, b: uint8) { a < b } + predicate method NatLess(a: nat, b: nat) { a < b } + predicate method IntLess(a: int, b: int) { a < b } + + /* + * Lexicographic comparison of sequences. + * + * Given two sequences `a` and `b` and a strict (that is, irreflexive) + * ordering `less` on the elements of these sequences, `LexCmpSeqs(a, b, less)` + * says whether or not `a` is lexicographically "less than or equal to" `b`. + * + * `a` is lexicographically "less than or equal to" `b` holds iff + * there exists a `k` such that + * - the first `k` elements of `a` and `b` are the same + * - either: + * -- `a` has length `k` (that is, `a` is a prefix of `b`) + * -- `a[k]` is strictly less (using `less`) than `b[k]` + */ + + predicate method LexicographicLessOrEqual(a: seq, b: seq, less: (T, T) -> bool) { + exists k :: 0 <= k <= |a| && LexicographicLessOrEqualAux(a, b, less, k) + } + + predicate method LexicographicLessOrEqualAux(a: seq, b: seq, less: (T, T) -> bool, lengthOfCommonPrefix: nat) + requires 0 <= lengthOfCommonPrefix <= |a| { - a.Length == 0 || (b.Length != 0 && lexCmpArraysNonEmpty(a, b, 0, lt)) + lengthOfCommonPrefix <= |b| && + (forall i :: 0 <= i < lengthOfCommonPrefix ==> a[i] == b[i]) && + (lengthOfCommonPrefix == |a| || + (lengthOfCommonPrefix < |b| && less(a[lengthOfCommonPrefix], b[lengthOfCommonPrefix]))) + } + + /* + * For an ordering `less` to be _trichotomous_ means that for any two `x` and `y`, + * exactly one of the following three conditions holds: + * - less(x, y) + * - x == y + * - less(y, x) + * Note that being trichotomous implies being irreflexive. The definition of + * `Trichotomous` here allows overlap between the three conditions, which lets us + * think of non-strict orderings (like "less than or equal" as opposed to just + * "less than") as being trichotomous. + */ + + predicate Trichotomous(less: (T, T) -> bool) { + forall t, t' :: less(t, t') || t == t' || less(t', t) } - predicate method lexCmpArraysNonEmpty(a : array, b : array, i: nat, lt: (T, T) -> bool) - requires i < a.Length - requires i < b.Length - requires forall j: nat :: j < i ==> a[j] == b[j] - decreases a.Length - i - reads a, b + /* + * If an ordering `less` is trichotomous, then so is the irreflexive lexicographic + * order built around `less`. + */ + + lemma LexPreservesTrichotomy(a: seq, b: seq, less: (T, T) -> bool) + requires Trichotomous(less) + ensures LexicographicLessOrEqual(a, b, less) || a == b || LexicographicLessOrEqual(b, a, less) { - if a[i] != b[i] - then lt(a[i], b[i]) - else - if i+1 < a.Length && i+1 < b.Length - then lexCmpArraysNonEmpty(a, b, i+1, lt) - else - if i+1 == a.Length && i+1 < b.Length - then true - else - if i+1 < a.Length && i+1 == b.Length - then false - else true // i+1 == a.Length && i+1 == b.Length, i.e. a == b - } - - lemma {:axiom} eq_multiset_eq_len (s : seq, s' : seq) - requires multiset(s) == multiset(s') - ensures |s| == |s'| + var m := 0; + while m < |a| && m < |b| && a[m] == b[m] + invariant m <= |a| && m <= |b| + invariant forall i :: 0 <= i < m ==> a[i] == b[i] + { + m := m + 1; + } + // m is the length of the common prefix of a and b + if m == |a| == |b| { + assert a == b; + } else if m == |a| < |b| { + assert LexicographicLessOrEqualAux(a, b, less, m); + } else if m == |b| < |a| { + assert LexicographicLessOrEqualAux(b, a, less, m); + } else { + assert m < |a| && m < |b|; + if + case less(a[m], b[m]) => + assert LexicographicLessOrEqualAux(a, b, less, m); + case less(b[m], a[m]) => + assert LexicographicLessOrEqualAux(b, a, less, m); + } + } } diff --git a/src/StandardLibrary/UInt.dfy b/src/StandardLibrary/UInt.dfy index 0c16615fa..985bc1e36 100644 --- a/src/StandardLibrary/UInt.dfy +++ b/src/StandardLibrary/UInt.dfy @@ -1,6 +1,7 @@ module {:extern "STLUInt"} StandardLibrary.UInt { newtype uint8 = x | 0 <= x < 0x100 + const UINT16_LIMIT := 0x1_0000 newtype uint16 = x | 0 <= x < 0x1_0000 newtype int32 = x | -0x8000_0000 <= x < 0x8000_0000 diff --git a/src/Util/Arrays-extern.cs b/src/Util/Arrays-extern.cs index 01b55c4bc..460f31352 100644 --- a/src/Util/Arrays-extern.cs +++ b/src/Util/Arrays-extern.cs @@ -3,15 +3,14 @@ namespace Arrays { - using Utils; public partial class Array { public static T[] copy(T[] source, BigInteger length) { - T[] dest = new T[Util.BigIntegerToInt(length)]; - System.Array.Copy(source, dest, Util.BigIntegerToInt(length)); + T[] dest = new T[(int)length]; + System.Array.Copy(source, dest, (int)length); return dest; } public static void copyTo(T[] source, T[] dest, BigInteger offset) { - source.CopyTo(dest, Util.BigIntegerToInt(offset)); + source.CopyTo(dest, (int)offset); } } } diff --git a/src/Util/Streams.dfy b/src/Util/Streams.dfy index 53bda7180..f9f99daf8 100644 --- a/src/Util/Streams.dfy +++ b/src/Util/Streams.dfy @@ -4,132 +4,120 @@ module Streams { import opened StandardLibrary import opened UInt = StandardLibrary.UInt - trait Stream { + class StringReader { + const data: array + var pos: nat - ghost var Repr : set - predicate Valid() reads this - function method capacity() : nat reads this requires Valid() // An upper bound on the amount of data the stream can accept on Write - function method available() : nat reads this requires Valid() // An upper bound on the amount of data the stream can deliver on Read + ghost var Repr: set - - method Write(a : array, off : nat, req : nat) returns (res : Either) - requires Valid() - requires a.Length >= off + req - modifies Repr - requires a !in Repr - ensures Valid() - ensures - match res - case Left(len_written) => len_written == min(req, old(capacity())) - case Right(e) => true - - - method Read(i : array, off : nat, req : nat) returns (res : Either) - requires Valid() - requires i.Length >= off + req - requires i !in Repr - modifies i, this - ensures Valid() - ensures - match res - case Left(len_read) => len_read == min(req, old(available())) - case Right(e) => true - } - - - class StringReader extends Stream { - - var data : array - var pos : nat - - predicate Valid() reads this { + predicate Valid() + reads this + { this in Repr && data in Repr && pos <= data.Length } - function method capacity() : nat reads this requires Valid() { 0 } - function method available() : nat reads this requires Valid() { data.Length - pos } + function method Available(): nat // An upper bound on the amount of data the stream can deliver on Read + requires Valid() + reads this + { + data.Length - pos + } + constructor(d: array) + ensures Valid() + { + Repr := {this, d}; + data := d; + pos := 0; + } - constructor(d : array) + method Read(arr: array, off: nat, req: nat) returns (res: Result) + requires Valid() && arr != data + requires off + req <= arr.Length + modifies this, arr ensures Valid() + ensures var n := min(req, old(Available())); + arr[..] == arr[..off] + data[old(pos) .. (old(pos) + n)] + arr[off + n ..] + ensures match res + case Success(lengthRead) => lengthRead == min(req, old(Available())) + case Failure(e) => unchanged(this) && unchanged(arr) { - Repr := {this, d}; - data := d; - pos := 0; + var n := min(req, Available()); + forall i | 0 <= i < n { + arr[off + i] := data[pos + i]; + } + pos := pos + n; + return Success(n); } - method Write(a : array, off : nat, req : nat) returns (res : Either) + method ReadSeq(desiredByteCount: nat) returns (bytes: seq) requires Valid() modifies this - requires a !in Repr ensures Valid() - ensures - match res - case Left(len_written) => len_written == min(req, old(capacity())) - case Right(e) => unchanged(this) + ensures bytes == data[old(pos)..][..min(desiredByteCount, old(Available()))] { - res := Right(IOError("Cannot write to StringReader")); + var n := min(desiredByteCount, Available()); + bytes := seq(n, i requires 0 <= i < n && pos + n <= data.Length reads this, data => data[pos + i]); + pos := pos + n; } - method Read(arr : array, off : nat, req : nat) returns (res : Either) + // Read exactly `n` bytes, if possible; otherwise, fail. + method ReadExact(n: nat) returns (res: Result>) requires Valid() - requires arr.Length >= off + req - requires arr != data + modifies this ensures Valid() - modifies arr, this - ensures unchanged(`data) - ensures var n := min(req, old(available())); - arr[..] == arr[..off] + data[old(pos) .. (old(pos) + n)] + arr[off + n ..] - ensures - match res - case Left(len_read) => len_read == min(req, old(available())) - case Right(e) => unchanged(this) && unchanged(arr) + ensures match res + case Success(bytes) => |bytes| == n + case Failure(_) => true { - if off == arr.Length || available() == 0 { - assert (min (req, available())) == 0; - res := Left(0); - } - else { - var n := min(req, available()); - forall i | 0 <= i < n { - arr[off + i] := data[pos + i]; - } - pos := pos + n; - res := Left(n); + var bytes := ReadSeq(n); + if |bytes| != n { + return Failure("IO Error: Not enough bytes left on stream."); + } else { + return Success(bytes); } } - /* - // TODO add a version without arrays - method ReadSimple(arr: array) returns (res: Result) + + // Read exactly 1 byte, if possible, and return as a uint8; otherwise, fail. + method ReadByte() returns (res: Result) requires Valid() + modifies this ensures Valid() + { + var bytes :- ReadExact(1); + var n := bytes[0]; + return Success(n); + } + + // Read exactly 2 bytes, if possible, and return as a uint16; otherwise, fail. + method ReadUInt16() returns (res: Result) + requires Valid() modifies this - ensures - match res - case Left(len_read) => - var n := arr.Length; - && pos == old(pos) + n - && arr[..] == data[old(pos) .. pos] - && len_read == n - case Right(e) => unchanged(this) + ensures Valid() { - if arr.Length <= available() { - forall i | 0 <= i < arr.Length { - arr[i] := data[pos + i]; - } - pos := pos + arr.Length; - res := Left(arr.Length); - } else { - res := Right(IOError("Not enough bytes available on stream.")); - } + var bytes :- ReadExact(2); + var n := SeqToUInt16(bytes); + return Success(n); + } + + // Read exactly 4 bytes, if possible, and return as a uint32; otherwise, fail. + method ReadUInt32() returns (res: Result) + requires Valid() + modifies this + ensures Valid() + { + var bytes :- ReadExact(4); + var n := SeqToUInt32(bytes); + return Success(n); } - */ } - class StringWriter extends Stream { - ghost var data : seq + class StringWriter { + ghost var data: seq + + ghost var Repr: set predicate Valid() reads this @@ -137,160 +125,115 @@ module Streams { this in Repr } - function method capacity(): nat - reads this + predicate method HasRemainingCapacity(n: nat) // Compare with an upper bound on the amount of data the stream can accept on Write requires Valid() - { - 0 // TODO? - } - - function method available(): nat reads this - requires Valid() - { - 0 // TODO - } - - constructor(n : nat) - ensures Valid() { - data := []; - Repr := {this}; + // TODO: revisit this definition if we change the backing store of the StringWriter to be something with limited capacity + true } - method Write(a: array, off: nat, req: nat) returns (res: Either) - requires Valid() - requires off + req <= a.Length - requires a !in Repr - modifies `data - ensures unchanged(`Repr) - ensures Valid() - ensures - match res - case Left(len_written) => - if old(capacity()) == 0 - then - len_written == 0 - else - //&& old(pos + n < |data|) - //&& Written(old(data), data, old(pos), pos, a[off..off+n], len_written, n) - && len_written == min(req, old(capacity())) - && data == old(data) + a[off..off+req] - case Right(e) => true + constructor() + ensures Valid() && fresh(Repr) { - if off == a.Length || capacity() == 0 { - res := Left(0); - } - else { - var n := min(req, capacity()); - data := data + a[off..off+req]; - res := Left(n); - } + data := []; + Repr := {this}; } - method WriteSimple(a: array) returns (res: Either) - requires Valid() - requires a !in Repr + method Write(a: array, off: nat, len: nat) returns (res: Result) + requires Valid() && a !in Repr + requires off + len <= a.Length modifies `data - ensures unchanged(`Repr) - ensures unchanged(a) ensures Valid() - ensures - match res - case Left(len_written) => - && len_written == a.Length - && data[..] == old(data) + a[..] - case Right(e) => unchanged(`data) + ensures match res + case Success(lengthWritten) => + && old(HasRemainingCapacity(len)) + && lengthWritten == len + && data == old(data) + a[off..][..len] + case Failure(e) => unchanged(`data) { - if a.Length <= capacity() { - data := data + a[..]; - res := Left(a.Length); + if HasRemainingCapacity(len) { + data := data + a[off..off + len]; + return Success(len); } else { - res := Right(IOError("Reached end of stream.")); + return Failure("IO Error: Stream capacity exceeded."); } } - method WriteSimpleSeq(a: seq) returns (res: Either) + method WriteSeq(bytes: seq) returns (res: Result) requires Valid() modifies `data - ensures unchanged(`Repr) ensures Valid() - ensures - match res - case Left(len_written) => - && len_written == |a| - && data[..] == old(data) + a - case Right(e) => unchanged(`data) + ensures match res + case Success(lengthWritten) => + && old(HasRemainingCapacity(|bytes|)) + && lengthWritten == |bytes| + && data == old(data) + bytes + case Failure(e) => unchanged(`data) { - if |a| <= capacity() { - data := data + a; - res := Left(|a|); + if HasRemainingCapacity(|bytes|) { + data := data + bytes; + return Success(|bytes|); } else { - res := Right(IOError("Reached end of stream.")); + return Failure("IO Error: Stream capacity exceeded."); } } - method WriteSingleByte(a: uint8) returns (res: Either) + method WriteByte(n: uint8) returns (res: Result) requires Valid() modifies `data - ensures unchanged(`Repr) ensures Valid() - ensures - match res - case Left(len_written) => - if old(capacity()) == 0 - then - len_written == 0 - else - && len_written == 1 - && data == old(data) + [a] - // Dafny->Boogie drops an old: https://github.com/dafny-lang/dafny/issues/320 - //&& data[..] == old(data)[.. old(pos)] + [a] + old(data)[old(pos) + 1 ..] - case Right(e) => unchanged(`data) + ensures match res + case Success(lengthWritten) => + && old(HasRemainingCapacity(1)) + && lengthWritten == 1 + && data == old(data) + [n] + case Failure(e) => unchanged(`data) { - if capacity() == 0 { - res := Left(0); - } - else { - data := data + [a]; - res := Left(1); + if HasRemainingCapacity(1) { + data := data + [n]; + return Success(1); + } else { + return Failure("IO Error: Stream capacity exceeded."); } } - method WriteSingleByteSimple(a: uint8) returns (res: Either) + method WriteUInt16(n: uint16) returns (res: Result) requires Valid() modifies `data - ensures unchanged(`Repr) ensures Valid() - ensures - match res - case Left(len_written) => - len_written == 1 - && data == old(data) + [a] - // Dafny->Boogie drops an old: https://github.com/dafny-lang/dafny/issues/320 - //&& data[..] == old(data)[.. old(pos)] + [a] + old(data)[old(pos) + 1 ..] - case Right(e) => unchanged(`data) + ensures match res + case Success(lengthWritten) => + && old(HasRemainingCapacity(2)) + && lengthWritten == 2 + && data == old(data) + UInt16ToSeq(n) + case Failure(e) => unchanged(`data) { - if 1 <= capacity() { - data := data + [a]; - res := Left(1); + if HasRemainingCapacity(2) { + data := data + UInt16ToSeq(n); + return Success(2); } else { - res := Right(IOError("Reached end of stream.")); + return Failure("IO Error: Stream capacity exceeded."); } } - method Read(arr : array, off : nat, req : nat) returns (res : Either) + method WriteUInt32(n: uint32) returns (res: Result) requires Valid() - requires arr.Length >= off + req + modifies `data ensures Valid() - modifies arr, this - ensures - match res - case Left(len_read) => len_read == min(req, old(available())) - case Right(e) => unchanged(this) && unchanged(arr) + ensures match res + case Success(lengthWritten) => + && old(HasRemainingCapacity(4)) + && lengthWritten == 4 + && data == old(data) + UInt32ToSeq(n) + case Failure(e) => unchanged(`data) { - res := Right(IOError("Cannot read from StringWriter")); + if HasRemainingCapacity(4) { + data := data + UInt32ToSeq(n); + return Success(4); + } else { + return Failure("IO Error: Stream capacity exceeded."); + } } } - } diff --git a/src/Util/UTF8.dfy b/src/Util/UTF8.dfy index 43180ee17..285830f85 100644 --- a/src/Util/UTF8.dfy +++ b/src/Util/UTF8.dfy @@ -12,82 +12,82 @@ include "../StandardLibrary/StandardLibrary.dfy" // This does NOT perform any range checks on the values encoded. module UTF8 { - import opened StandardLibrary - import opened UInt = StandardLibrary.UInt + import opened StandardLibrary + import opened UInt = StandardLibrary.UInt - // Returns the value of the idx'th bit, from least to most significant bit (0- indexed) - function method bit_at(x: uint8, idx: uint8): bool - requires idx < 8 - { - var w := x as bv8; - (w >> idx) & 1 == 1 - } + // Returns the value of the idx'th bit, from least to most significant bit (0- indexed) + function method BitAt(x: uint8, idx: uint8): bool + requires idx < 8 + { + var w := x as bv8; + (w >> idx) & 1 == 1 + } - // Checks if a[offset] is a valid continuation uint8. - predicate method ValidUTF8Continuation(a: array, offset: nat) - requires offset < a.Length - reads a - { - bit_at(a[offset], 7) && !bit_at(a[offset], 6) - } + // Checks if a[offset] is a valid continuation uint8. + predicate method ValidUTF8Continuation(a: seq, offset: nat) + requires offset < |a| + { + BitAt(a[offset], 7) && !BitAt(a[offset], 6) + } - // Returns which leading uint8 is at a[offset], or 0 if it is not a leading uint8. - function method CodePointCase(a: array, offset: nat): uint8 - requires offset < a.Length - reads a - { - if bit_at(a[offset], 7) then // 1xxx xxxx - if bit_at(a[offset], 6) then //11xx xxxx - if bit_at(a[offset], 5) then // 111x xxxx - if bit_at(a[offset], 4) then // 1111 xxxx - if bit_at(a[offset], 3) then // 1111 1xxx - 0 // Error case - else // 1111 0xxx - 4 - else // 1110 xxxx - 3 - else // 110x xxxx - 2 - else //10xx xxxx - 0 // Error case - else //0xxxxxxx - 1 - } + // Returns which leading uint8 is at a[offset], or 0 if it is not a leading uint8. + function method CodePointCase(a: seq, offset: nat): uint8 + requires offset < |a| + { + if BitAt(a[offset], 7) then // 1xxx xxxx + if BitAt(a[offset], 6) then //11xx xxxx + if BitAt(a[offset], 5) then // 111x xxxx + if BitAt(a[offset], 4) then // 1111 xxxx + if BitAt(a[offset], 3) then // 1111 1xxx + 0 // Error case + else // 1111 0xxx + 4 + else // 1110 xxxx + 3 + else // 110x xxxx + 2 + else //10xx xxxx + 0 // Error case + else //0xxxxxxx + 1 + } - predicate method ValidUTF8_at(a: array, offset: nat) - requires offset <= a.Length - reads a - decreases (a.Length - offset) - { - if offset == a.Length - then true - else - var c := CodePointCase(a, offset); - if c == 1 then - ValidUTF8_at(a, offset + 1) - else if c == 2 then - offset + 2 <= a.Length && - ValidUTF8Continuation(a, offset + 1) && - ValidUTF8_at(a, offset + 2) - else if c == 3 then - offset + 3 <= a.Length && - ValidUTF8Continuation(a, offset + 1) && - ValidUTF8Continuation(a, offset + 2) && - ValidUTF8_at(a, offset + 3) - else if c == 4 then - offset + 4 <= a.Length && - ValidUTF8Continuation(a, offset + 1) && - ValidUTF8Continuation(a, offset + 2) && - ValidUTF8Continuation(a, offset + 3) && - ValidUTF8_at(a, offset + 4) - else - false - } + predicate method ValidUTF8_at(a: seq, offset: nat) + requires offset <= |a| + decreases |a| - offset + { + if offset == |a| + then true + else + var c := CodePointCase(a, offset); + if c == 1 then + ValidUTF8_at(a, offset + 1) + else if c == 2 then + offset + 2 <= |a| && + ValidUTF8Continuation(a, offset + 1) && + ValidUTF8_at(a, offset + 2) + else if c == 3 then + offset + 3 <= |a| && + ValidUTF8Continuation(a, offset + 1) && + ValidUTF8Continuation(a, offset + 2) && + ValidUTF8_at(a, offset + 3) + else if c == 4 then + offset + 4 <= |a| && + ValidUTF8Continuation(a, offset + 1) && + ValidUTF8Continuation(a, offset + 2) && + ValidUTF8Continuation(a, offset + 3) && + ValidUTF8_at(a, offset + 4) + else + false + } - predicate method ValidUTF8(a: array) - reads a - { - ValidUTF8_at(a, 0) - } + predicate method ValidUTF8(a: array) + reads a + { + ValidUTF8_at(a[..], 0) + } + predicate method ValidUTF8Seq(s: seq) { + ValidUTF8_at(s, 0) + } }