diff --git a/src/Asn1Decode.sol b/src/Asn1Decode.sol index 40bf8ce..b8e8feb 100644 --- a/src/Asn1Decode.sol +++ b/src/Asn1Decode.sol @@ -24,21 +24,49 @@ pragma solidity ^0.8.15; // adapted from https://github.com/JonahGroendal/asn1-decode/tree/master -import {NodePtr, LibNodePtr} from "./NodePtr.sol"; import {LibBytes} from "./LibBytes.sol"; +type Asn1Ptr is uint256; + +library LibAsn1Ptr { + using LibAsn1Ptr for Asn1Ptr; + + // First byte index of the header + function header(Asn1Ptr self) internal pure returns (uint256) { + return uint80(Asn1Ptr.unwrap(self)); + } + + // First byte index of the content + function content(Asn1Ptr self) internal pure returns (uint256) { + return uint80(Asn1Ptr.unwrap(self) >> 80); + } + + // Content length + function length(Asn1Ptr self) internal pure returns (uint256) { + return uint80(Asn1Ptr.unwrap(self) >> 160); + } + + // Total length (header length + content length) + function totalLength(Asn1Ptr self) internal pure returns (uint256) { + return self.length() + self.content() - self.header(); + } + + // Pack 3 uint80s into a uint256 + function toAsn1Ptr(uint256 _header, uint256 _content, uint256 _length) internal pure returns (Asn1Ptr) { + return Asn1Ptr.wrap(_header | _content << 80 | _length << 160); + } +} + library Asn1Decode { - using LibNodePtr for NodePtr; + using LibAsn1Ptr for Asn1Ptr; using LibBytes for bytes; - bytes1 public constant NULL_VALUE = 0xF6; - /* * @dev Get the root node. First step in traversing an ASN1 structure * @param der The DER-encoded ASN1 structure * @return A pointer to the outermost node */ - function root(bytes memory der) internal pure returns (NodePtr) { + function root(bytes memory der) internal pure returns (Asn1Ptr) { return readNodeLength(der, 0); } @@ -48,7 +76,7 @@ library Asn1Decode { * @param ptr Pointer to the current node * @return A pointer to the child root node */ - function rootOf(bytes memory der, NodePtr ptr) internal pure returns (NodePtr) { + function rootOf(bytes memory der, Asn1Ptr ptr) internal pure returns (Asn1Ptr) { return readNodeLength(der, ptr.content()); } @@ -58,7 +86,7 @@ library Asn1Decode { * @param ptr Points to the indices of the current node * @return A pointer to the next sibling node */ - function nextSiblingOf(bytes memory der, NodePtr ptr) internal pure returns (NodePtr) { + function nextSiblingOf(bytes memory der, Asn1Ptr ptr) internal pure returns (Asn1Ptr) { return readNodeLength(der, ptr.content() + ptr.length()); } @@ -68,7 +96,7 @@ library Asn1Decode { * @param ptr Points to the indices of the current node * @return A pointer to the first child node */ - function firstChildOf(bytes memory der, NodePtr ptr) internal pure returns (NodePtr) { + function firstChildOf(bytes memory der, Asn1Ptr ptr) internal pure returns (Asn1Ptr) { require(der[ptr.header()] & 0x20 == 0x20, "Not a constructed type"); return readNodeLength(der, ptr.content()); } @@ -79,11 +107,11 @@ library Asn1Decode { * @param ptr Points to the indices of the current node * @return A pointer to a bitstring */ - function bitstring(bytes memory der, NodePtr ptr) internal pure returns (NodePtr) { + function bitstring(bytes memory der, Asn1Ptr ptr) internal pure returns (Asn1Ptr) { require(der[ptr.header()] == 0x03, "Not type BIT STRING"); // Only 00 padded bitstr can be converted to bytestr! require(der[ptr.content()] == 0x00, "Non-0-padded BIT STRING"); - return LibNodePtr.toNodePtr(ptr.header(), ptr.content() + 1, ptr.length() - 1); + return LibAsn1Ptr.toAsn1Ptr(ptr.header(), ptr.content() + 1, ptr.length() - 1); } /* @@ -92,7 +120,7 @@ library Asn1Decode { * @param ptr Points to the indices of the current node * @return A bitstring encoded in a uint256 */ - function bitstringUintAt(bytes memory der, NodePtr ptr) internal pure returns (uint256) { + function bitstringUintAt(bytes memory der, Asn1Ptr ptr) internal pure returns (uint256) { require(der[ptr.header()] == 0x03, "Not type BIT STRING"); uint256 len = ptr.length() - 1; return uint256(readBytesN(der, ptr.content() + 1, len) >> ((32 - len) * 8)); @@ -104,7 +132,7 @@ library Asn1Decode { * @param ptr Points to the indices of the current node * @return A pointer to a octet string */ - function octetString(bytes memory der, NodePtr ptr) internal pure returns (NodePtr) { + function octetString(bytes memory der, Asn1Ptr ptr) internal pure returns (Asn1Ptr) { require(der[ptr.header()] == 0x04, "Not type OCTET STRING"); return readNodeLength(der, ptr.content()); } @@ -115,7 +143,7 @@ library Asn1Decode { * @param ptr Points to the indices of the current node * @return Uint value of node */ - function uintAt(bytes memory der, NodePtr ptr) internal pure returns (uint256) { + function uintAt(bytes memory der, Asn1Ptr ptr) internal pure returns (uint256) { require(der[ptr.header()] == 0x02, "Not type INTEGER"); require(der[ptr.content()] & 0x80 == 0, "Not positive"); uint256 len = ptr.length(); @@ -128,7 +156,7 @@ library Asn1Decode { * @param ptr Points to the indices of the current node * @return 384-bit uint encoded in uint128 and uint256 */ - function uint384At(bytes memory der, NodePtr ptr) internal pure returns (uint128, uint256) { + function uint384At(bytes memory der, Asn1Ptr ptr) internal pure returns (uint128, uint256) { require(der[ptr.header()] == 0x02, "Not type INTEGER"); require(der[ptr.content()] & 0x80 == 0, "Not positive"); uint256 valueLength = ptr.length(); @@ -149,7 +177,7 @@ library Asn1Decode { * @param ptr Points to the indices of the current node * @return UNIX timestamp (seconds since 1970/01/01) */ - function timestampAt(bytes memory der, NodePtr ptr) internal pure returns (uint256) { + function timestampAt(bytes memory der, Asn1Ptr ptr) internal pure returns (uint256) { uint16 _years; uint256 offset = ptr.content(); uint256 length = ptr.length(); @@ -170,7 +198,7 @@ library Asn1Decode { return timestampFromDateTime(_years, _months, _days, _hours, _mins, _secs); } - function readNodeLength(bytes memory der, uint256 ix) private pure returns (NodePtr) { + function readNodeLength(bytes memory der, uint256 ix) private pure returns (Asn1Ptr) { uint256 length; uint80 ixFirstContentByte; if ((der[ix + 1] & 0x80) == 0) { @@ -187,7 +215,7 @@ library Asn1Decode { } ixFirstContentByte = uint80(ix + 2 + lengthbytesLength); } - return LibNodePtr.toNodePtr(ix, ixFirstContentByte, uint80(length)); + return LibAsn1Ptr.toAsn1Ptr(ix, ixFirstContentByte, uint80(length)); } function readBytesN(bytes memory self, uint256 idx, uint256 len) private pure returns (bytes32 ret) { diff --git a/src/CborDecode.sol b/src/CborDecode.sol new file mode 100644 index 0000000..027966c --- /dev/null +++ b/src/CborDecode.sol @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.15; + +import {LibBytes} from "./LibBytes.sol"; + +type CborElement is uint256; + +library LibCborElement { + // Cbor element type + function cborType(CborElement self) internal pure returns (uint8) { + return uint8(CborElement.unwrap(self)); + } + + // First byte index of the content + function start(CborElement self) internal pure returns (uint256) { + return uint80(CborElement.unwrap(self) >> 80); + } + + // First byte index of the next element (exclusive end of content) + function end(CborElement self) internal pure returns (uint256) { + return start(self) + length(self); + } + + // Content length (0 for non-string types) + function length(CborElement self) internal pure returns (uint256) { + uint8 _type = cborType(self); + if (_type == 0x40 || _type == 0x60) { + // length is non-zero only for byte strings and text strings + return value(self); + } + return 0; + } + + // Value of the element (length for string/map/array types, value for others) + function value(CborElement self) internal pure returns (uint64) { + return uint64(CborElement.unwrap(self) >> 160); + } + + // Returns true if the element is null + function isNull(CborElement self) internal pure returns (bool) { + return cborType(self) == 0xF6; + } + + // Pack 3 uint80s into a uint256 + function toCborElement(uint256 _type, uint256 _start, uint256 _length) internal pure returns (CborElement) { + return CborElement.wrap(_type | _start << 80 | _length << 160); + } +} + +library CborDecode { + using LibBytes for bytes; + using LibCborElement for CborElement; + + // Calculate the keccak256 hash of the given cbor element + function keccak(bytes memory cbor, CborElement ptr) internal pure returns (bytes32) { + return cbor.keccak(ptr.start(), ptr.length()); + } + + // Take a slice of the given cbor element + function slice(bytes memory cbor, CborElement ptr) internal pure returns (bytes memory) { + return cbor.slice(ptr.start(), ptr.length()); + } + + function byteStringAt(bytes memory cbor, uint256 ix) internal pure returns (CborElement) { + return elementAt(cbor, ix, 0x40, true); + } + + function nextByteString(bytes memory cbor, CborElement ptr) internal pure returns (CborElement) { + return elementAt(cbor, ptr.end(), 0x40, true); + } + + function nextByteStringOrNull(bytes memory cbor, CborElement ptr) internal pure returns (CborElement) { + return elementAt(cbor, ptr.end(), 0x40, false); + } + + function nextTextString(bytes memory cbor, CborElement ptr) internal pure returns (CborElement) { + return elementAt(cbor, ptr.end(), 0x60, true); + } + + function nextPositiveInt(bytes memory cbor, CborElement ptr) internal pure returns (CborElement) { + return elementAt(cbor, ptr.end(), 0x00, true); + } + + function mapAt(bytes memory cbor, uint256 ix) internal pure returns (CborElement) { + return elementAt(cbor, ix, 0xa0, true); + } + + function nextMap(bytes memory cbor, CborElement ptr) internal pure returns (CborElement) { + return mapAt(cbor, ptr.end()); + } + + function nextArray(bytes memory cbor, CborElement ptr) internal pure returns (CborElement) { + return elementAt(cbor, ptr.end(), 0x80, true); + } + + function elementAt(bytes memory cbor, uint256 ix, uint8 expectedType, bool required) internal pure returns (CborElement) { + uint8 _type = uint8(cbor[ix] & 0xe0); + uint8 ai = uint8(cbor[ix] & 0x1f); + if (_type == 0xe0) { + require(!required || ai != 22, "null value for required element"); + // primitive type, retain the additional information + return LibCborElement.toCborElement(_type | ai, ix + 1, 0); + } + require(_type == expectedType, "unexpected type"); + if (ai == 24) { + return LibCborElement.toCborElement(_type, ix + 2, uint8(cbor[ix + 1])); + } else if (ai == 25) { + return LibCborElement.toCborElement(_type, ix + 3, cbor.readUint16(ix + 1)); + } else if (ai == 26) { + return LibCborElement.toCborElement(_type, ix + 5, cbor.readUint32(ix + 1)); + } else if (ai == 27) { + return LibCborElement.toCborElement(_type, ix + 9, cbor.readUint64(ix + 1)); + } + return LibCborElement.toCborElement(_type, ix + 1, ai); + } +} diff --git a/src/CertManager.sol b/src/CertManager.sol index adb5887..ea98732 100644 --- a/src/CertManager.sol +++ b/src/CertManager.sol @@ -2,17 +2,16 @@ pragma solidity ^0.8.15; import {Sha2Ext} from "./Sha2Ext.sol"; -import {Asn1Decode} from "./Asn1Decode.sol"; +import {Asn1Decode, Asn1Ptr, LibAsn1Ptr} from "./Asn1Decode.sol"; import {ECDSA384} from "./ECDSA384.sol"; import {LibBytes} from "./LibBytes.sol"; -import {NodePtr, LibNodePtr} from "./NodePtr.sol"; // adapted from https://github.com/marlinprotocol/NitroProver/blob/f1d368d1f172ad3a55cd2aaaa98ad6a6e7dcde9d/src/CertManager.sol contract CertManager { using Asn1Decode for bytes; + using LibAsn1Ptr for Asn1Ptr; using LibBytes for bytes; - using LibNodePtr for NodePtr; // @dev download the root CA cert for AWS nitro enclaves from https://aws-nitro-enclaves.amazonaws.com/AWS_NitroEnclaves_Root-G1.zip // @dev convert the base64 encoded pub key into hex to get the cert below @@ -104,8 +103,8 @@ contract CertManager { return cache; } - NodePtr root = certificate.root(); - NodePtr tbsCertPtr = certificate.firstChildOf(root); + Asn1Ptr root = certificate.root(); + Asn1Ptr tbsCertPtr = certificate.firstChildOf(root); (uint256 notAfter, int256 maxPathLen, bytes memory pubKey) = _parseTbs(certificate, tbsCertPtr, clientCert); if (parentCache.pubKey.length != 0 || certHash != ROOT_CA_CERT_HASH) { @@ -121,15 +120,15 @@ contract CertManager { return cache; } - function _parseTbs(bytes memory certificate, NodePtr ptr, bool clientCert) + function _parseTbs(bytes memory certificate, Asn1Ptr ptr, bool clientCert) internal view returns (uint256 notAfter, int256 maxPathLen, bytes memory pubKey) { - NodePtr versionPtr = certificate.firstChildOf(ptr); - NodePtr vPtr = certificate.firstChildOf(versionPtr); - NodePtr serialPtr = certificate.nextSiblingOf(versionPtr); - NodePtr sigAlgoPtr = certificate.nextSiblingOf(serialPtr); + Asn1Ptr versionPtr = certificate.firstChildOf(ptr); + Asn1Ptr vPtr = certificate.firstChildOf(versionPtr); + Asn1Ptr serialPtr = certificate.nextSiblingOf(versionPtr); + Asn1Ptr sigAlgoPtr = certificate.nextSiblingOf(serialPtr); require(certificate.keccak(sigAlgoPtr.content(), sigAlgoPtr.length()) == CERT_ALGO_OID, "invalid cert sig algo"); uint256 version = certificate.uintAt(vPtr); @@ -139,32 +138,32 @@ contract CertManager { (notAfter, maxPathLen, pubKey) = _parseTbsInner(certificate, sigAlgoPtr, clientCert); } - function _parseTbsInner(bytes memory certificate, NodePtr sigAlgoPtr, bool clientCert) + function _parseTbsInner(bytes memory certificate, Asn1Ptr sigAlgoPtr, bool clientCert) internal view returns (uint256 notAfter, int256 maxPathLen, bytes memory pubKey) { - NodePtr issuerPtr = certificate.nextSiblingOf(sigAlgoPtr); - NodePtr validityPtr = certificate.nextSiblingOf(issuerPtr); - NodePtr subjectPtr = certificate.nextSiblingOf(validityPtr); - NodePtr subjectPublicKeyInfoPtr = certificate.nextSiblingOf(subjectPtr); - NodePtr extensionsPtr = certificate.nextSiblingOf(subjectPublicKeyInfoPtr); + Asn1Ptr issuerPtr = certificate.nextSiblingOf(sigAlgoPtr); + Asn1Ptr validityPtr = certificate.nextSiblingOf(issuerPtr); + Asn1Ptr subjectPtr = certificate.nextSiblingOf(validityPtr); + Asn1Ptr subjectPublicKeyInfoPtr = certificate.nextSiblingOf(subjectPtr); + Asn1Ptr extensionsPtr = certificate.nextSiblingOf(subjectPublicKeyInfoPtr); notAfter = _verifyValidity(certificate, validityPtr); maxPathLen = _verifyExtensions(certificate, extensionsPtr, clientCert); pubKey = _parsePubKey(certificate, subjectPublicKeyInfoPtr); } - function _parsePubKey(bytes memory certificate, NodePtr subjectPublicKeyInfoPtr) + function _parsePubKey(bytes memory certificate, Asn1Ptr subjectPublicKeyInfoPtr) internal pure returns (bytes memory subjectPubKey) { - NodePtr pubKeyAlgoPtr = certificate.firstChildOf(subjectPublicKeyInfoPtr); - NodePtr pubKeyAlgoIdPtr = certificate.firstChildOf(pubKeyAlgoPtr); - NodePtr algoParamsPtr = certificate.nextSiblingOf(pubKeyAlgoIdPtr); - NodePtr subjectPublicKeyPtr = certificate.nextSiblingOf(pubKeyAlgoPtr); - NodePtr subjectPubKeyPtr = certificate.bitstring(subjectPublicKeyPtr); + Asn1Ptr pubKeyAlgoPtr = certificate.firstChildOf(subjectPublicKeyInfoPtr); + Asn1Ptr pubKeyAlgoIdPtr = certificate.firstChildOf(pubKeyAlgoPtr); + Asn1Ptr algoParamsPtr = certificate.nextSiblingOf(pubKeyAlgoIdPtr); + Asn1Ptr subjectPublicKeyPtr = certificate.nextSiblingOf(pubKeyAlgoPtr); + Asn1Ptr subjectPubKeyPtr = certificate.bitstring(subjectPublicKeyPtr); require( certificate.keccak(pubKeyAlgoIdPtr.content(), pubKeyAlgoIdPtr.length()) == EC_PUB_KEY_OID, @@ -179,9 +178,9 @@ contract CertManager { subjectPubKey = certificate.slice(end - 96, 96); } - function _verifyValidity(bytes memory certificate, NodePtr validityPtr) internal view returns (uint256 notAfter) { - NodePtr notBeforePtr = certificate.firstChildOf(validityPtr); - NodePtr notAfterPtr = certificate.nextSiblingOf(notBeforePtr); + function _verifyValidity(bytes memory certificate, Asn1Ptr validityPtr) internal view returns (uint256 notAfter) { + Asn1Ptr notBeforePtr = certificate.firstChildOf(validityPtr); + Asn1Ptr notAfterPtr = certificate.nextSiblingOf(notBeforePtr); uint256 notBefore = certificate.timestampAt(notBeforePtr); notAfter = certificate.timestampAt(notAfterPtr); @@ -190,25 +189,25 @@ contract CertManager { require(notAfter >= block.timestamp, "certificate not valid anymore"); } - function _verifyExtensions(bytes memory certificate, NodePtr extensionsPtr, bool clientCert) + function _verifyExtensions(bytes memory certificate, Asn1Ptr extensionsPtr, bool clientCert) internal pure returns (int256 maxPathLen) { require(certificate[extensionsPtr.header()] == 0xa3, "invalid extensions"); extensionsPtr = certificate.firstChildOf(extensionsPtr); - NodePtr extensionPtr = certificate.firstChildOf(extensionsPtr); + Asn1Ptr extensionPtr = certificate.firstChildOf(extensionsPtr); uint256 end = extensionsPtr.content() + extensionsPtr.length(); bool basicConstraintsFound = false; bool keyUsageFound = false; maxPathLen = -1; while (true) { - NodePtr oidPtr = certificate.firstChildOf(extensionPtr); + Asn1Ptr oidPtr = certificate.firstChildOf(extensionPtr); bytes32 oid = certificate.keccak(oidPtr.content(), oidPtr.length()); if (oid == BASIC_CONSTRAINTS_OID || oid == KEY_USAGE_OID) { - NodePtr valuePtr = certificate.nextSiblingOf(oidPtr); + Asn1Ptr valuePtr = certificate.nextSiblingOf(oidPtr); if (certificate[valuePtr.header()] == 0x01) { // skip optional critical bool @@ -238,13 +237,13 @@ contract CertManager { require(!clientCert || maxPathLen == -1, "maxPathLen must be undefined for client cert"); } - function _verifyBasicConstraintsExtension(bytes memory certificate, NodePtr valuePtr) + function _verifyBasicConstraintsExtension(bytes memory certificate, Asn1Ptr valuePtr) internal pure returns (int256 maxPathLen) { maxPathLen = -1; - NodePtr basicConstraintsPtr = certificate.firstChildOf(valuePtr); + Asn1Ptr basicConstraintsPtr = certificate.firstChildOf(valuePtr); if (certificate[basicConstraintsPtr.header()] == 0x01) { // skip optional isCA bool require(basicConstraintsPtr.length() == 1, "invalid isCA bool value"); @@ -255,7 +254,7 @@ contract CertManager { } } - function _verifyKeyUsageExtension(bytes memory certificate, NodePtr valuePtr, bool clientCert) internal pure { + function _verifyKeyUsageExtension(bytes memory certificate, Asn1Ptr valuePtr, bool clientCert) internal pure { uint256 value = certificate.bitstringUintAt(valuePtr); // bits are reversed (DigitalSignature 0x01 => 0x80, CertSign 0x32 => 0x04) if (clientCert) { @@ -265,17 +264,17 @@ contract CertManager { } } - function _verifyCertSignature(bytes memory certificate, NodePtr ptr, bytes memory pubKey) internal view { - NodePtr sigAlgoPtr = certificate.nextSiblingOf(ptr); + function _verifyCertSignature(bytes memory certificate, Asn1Ptr ptr, bytes memory pubKey) internal view { + Asn1Ptr sigAlgoPtr = certificate.nextSiblingOf(ptr); require(certificate.keccak(sigAlgoPtr.content(), sigAlgoPtr.length()) == CERT_ALGO_OID, "invalid cert sig algo"); bytes memory hash = Sha2Ext.sha384(certificate, ptr.header(), ptr.totalLength()); - NodePtr sigPtr = certificate.nextSiblingOf(sigAlgoPtr); - NodePtr sigBPtr = certificate.bitstring(sigPtr); - NodePtr sigRoot = certificate.rootOf(sigBPtr); - NodePtr sigRPtr = certificate.firstChildOf(sigRoot); - NodePtr sigSPtr = certificate.nextSiblingOf(sigRPtr); + Asn1Ptr sigPtr = certificate.nextSiblingOf(sigAlgoPtr); + Asn1Ptr sigBPtr = certificate.bitstring(sigPtr); + Asn1Ptr sigRoot = certificate.rootOf(sigBPtr); + Asn1Ptr sigRPtr = certificate.firstChildOf(sigRoot); + Asn1Ptr sigSPtr = certificate.nextSiblingOf(sigRPtr); (uint128 rhi, uint256 rlo) = certificate.uint384At(sigRPtr); (uint128 shi, uint256 slo) = certificate.uint384At(sigSPtr); bytes memory sigPacked = abi.encodePacked(rhi, rlo, shi, slo); diff --git a/src/NitroValidator.sol b/src/NitroValidator.sol index 9e67346..b4c2fec 100644 --- a/src/NitroValidator.sol +++ b/src/NitroValidator.sol @@ -3,16 +3,19 @@ pragma solidity ^0.8.15; import {CertManager} from "./CertManager.sol"; import {Sha2Ext} from "./Sha2Ext.sol"; +import {CborDecode, CborElement, LibCborElement} from "./CborDecode.sol"; import {Asn1Decode} from "./Asn1Decode.sol"; import {ECDSA384} from "./ECDSA384.sol"; import {LibBytes} from "./LibBytes.sol"; -import {NodePtr, LibNodePtr} from "./NodePtr.sol"; + +import {console} from "forge-std/console.sol"; // adapted from https://github.com/marlinprotocol/NitroProver/blob/f1d368d1f172ad3a55cd2aaaa98ad6a6e7dcde9d/src/NitroProver.sol contract NitroValidator { using LibBytes for bytes; - using LibNodePtr for NodePtr; + using CborDecode for bytes; + using LibCborElement for CborElement; bytes32 public constant ATTESTATION_TBS_PREFIX = keccak256(hex"846a5369676e61747572653144a101382240"); bytes32 public constant ATTESTATION_DIGEST = keccak256("SHA384"); @@ -44,15 +47,15 @@ contract NitroValidator { hex"7fffffffffffffffffffffffffffffffffffffffffffffffe3b1a6c0fa1b96efac0d06d9245853bd76760cb5666294b9"; struct Ptrs { - NodePtr moduleID; + CborElement moduleID; uint64 timestamp; - NodePtr digest; - NodePtr[] pcrs; - NodePtr cert; - NodePtr[] cabundle; - NodePtr publicKey; - NodePtr userData; - NodePtr nonce; + CborElement digest; + CborElement[] pcrs; + CborElement cert; + CborElement[] cabundle; + CborElement publicKey; + CborElement userData; + CborElement nonce; } CertManager public immutable certManager; @@ -71,20 +74,46 @@ contract NitroValidator { offset = 2; } - NodePtr protectedPtr = _readNextElement(attestation, offset); - NodePtr unprotectedPtr = _readNextElement(attestation, protectedPtr.content() + protectedPtr.length()); - NodePtr payloadPtr = _readNextElement(attestation, unprotectedPtr.content() + unprotectedPtr.length()); - NodePtr signaturePtr = _readNextElement(attestation, payloadPtr.content() + payloadPtr.length()); + CborElement protectedPtr = attestation.byteStringAt(offset); + CborElement unprotectedPtr = attestation.nextMap(protectedPtr); + CborElement payloadPtr = attestation.nextByteString(unprotectedPtr); + CborElement signaturePtr = attestation.nextByteString(payloadPtr); - uint256 rawProtectedLength = protectedPtr.content() + protectedPtr.length() - offset; - uint256 rawPayloadLength = - payloadPtr.content() + payloadPtr.length() - unprotectedPtr.content() - unprotectedPtr.length(); + uint256 rawProtectedLength = protectedPtr.end() - offset; + uint256 rawPayloadLength = payloadPtr.end() - unprotectedPtr.end(); bytes memory rawProtectedBytes = attestation.slice(offset, rawProtectedLength); - bytes memory rawPayloadBytes = - attestation.slice(unprotectedPtr.content() + unprotectedPtr.length(), rawPayloadLength); - signature = attestation.slice(signaturePtr.content(), signaturePtr.length()); + bytes memory rawPayloadBytes = attestation.slice(unprotectedPtr.end(), rawPayloadLength); attestationTbs = _constructAttestationTbs(rawProtectedBytes, rawProtectedLength, rawPayloadBytes, rawPayloadLength); + signature = attestation.slice(signaturePtr.start(), signaturePtr.length()); + } + + function validateAttestation(bytes memory attestationTbs, bytes memory signature) public returns (Ptrs memory) { + Ptrs memory ptrs = _parseAttestation(attestationTbs); + + require(ptrs.moduleID.length() > 0, "no module id"); + require(ptrs.timestamp > 0, "no timestamp"); + require(ptrs.cabundle.length > 0, "no cabundle"); + require(attestationTbs.keccak(ptrs.digest) == ATTESTATION_DIGEST, "invalid digest"); + require(1 <= ptrs.pcrs.length && ptrs.pcrs.length <= 32, "invalid pcrs"); + require( + ptrs.publicKey.isNull() || (1 <= ptrs.publicKey.length() && ptrs.publicKey.length() <= 1024), + "invalid pub key" + ); + require(ptrs.userData.isNull() || (ptrs.userData.length() <= 512), "invalid user data"); + require(ptrs.nonce.isNull() || (ptrs.nonce.length() <= 512), "invalid nonce"); + + bytes memory cert = attestationTbs.slice(ptrs.cert); + bytes[] memory cabundle = new bytes[](ptrs.cabundle.length); + for (uint256 i = 0; i < ptrs.cabundle.length; i++) { + cabundle[i] = attestationTbs.slice(ptrs.cabundle[i]); + } + + CertManager.CachedCert memory parent = certManager.verifyCertBundle(cert, cabundle); + bytes memory hash = Sha2Ext.sha384(attestationTbs, 0, attestationTbs.length); + _verifySignature(parent.pubKey, hash, signature); + + return ptrs; } function _constructAttestationTbs( @@ -115,107 +144,55 @@ contract NitroValidator { LibBytes.memcpy(dest + 13 + rawProtectedLength, payloadSrc, rawPayloadLength); } - function validateAttestation(bytes memory attestationTbs, bytes memory signature) public returns (Ptrs memory) { - Ptrs memory ptrs = _parseAttestation(attestationTbs); - - require(ptrs.moduleID.length() > 0, "no module id"); - require(ptrs.timestamp > 0, "no timestamp"); - require(ptrs.cabundle.length > 0, "no cabundle"); - require( - attestationTbs.keccak(ptrs.digest.content(), ptrs.digest.length()) == ATTESTATION_DIGEST, "invalid digest" - ); - require(1 <= ptrs.pcrs.length && ptrs.pcrs.length <= 32, "invalid pcrs"); - require( - attestationTbs[ptrs.publicKey.header()] == Asn1Decode.NULL_VALUE - || (1 <= ptrs.publicKey.length() && ptrs.publicKey.length() <= 1024), - "invalid pub key" - ); - require( - attestationTbs[ptrs.userData.header()] == Asn1Decode.NULL_VALUE || (ptrs.userData.length() <= 512), - "invalid user data" - ); - require( - attestationTbs[ptrs.nonce.header()] == Asn1Decode.NULL_VALUE || (ptrs.nonce.length() <= 512), - "invalid nonce" - ); - - bytes memory cert = attestationTbs.slice(ptrs.cert.content(), ptrs.cert.length()); - bytes[] memory cabundle = new bytes[](ptrs.cabundle.length); - for (uint256 i = 0; i < ptrs.cabundle.length; i++) { - cabundle[i] = attestationTbs.slice(ptrs.cabundle[i].content(), ptrs.cabundle[i].length()); - } - - CertManager.CachedCert memory parent = certManager.verifyCertBundle(cert, cabundle); - bytes memory hash = Sha2Ext.sha384(attestationTbs, 0, attestationTbs.length); - _verifySignature(parent.pubKey, hash, signature); - - return ptrs; - } - function _parseAttestation(bytes memory attestationTbs) internal pure returns (Ptrs memory) { require(attestationTbs.keccak(0, 18) == ATTESTATION_TBS_PREFIX, "invalid attestation prefix"); - NodePtr payload = _readNextElement(attestationTbs, 18); - require(payload.header() == 0x40, "invalid attestation payload type"); - NodePtr payloadMap = _readNextElement(attestationTbs, payload.content()); - require(payloadMap.header() == 0xa0, "invalid attestation payload map type"); + CborElement payload = attestationTbs.byteStringAt(18); + CborElement current = attestationTbs.mapAt(payload.start()); Ptrs memory ptrs; - uint256 offset = payloadMap.content(); - uint256 end = payload.content() + payload.length(); - while (offset < end) { - NodePtr key = _readNextElement(attestationTbs, offset); - require(key.header() == 0x60, "invalid attestation key type"); - bytes32 keyHash = attestationTbs.keccak(key.content(), key.length()); - NodePtr value = _readNextElement(attestationTbs, key.content() + key.length()); + uint256 end = payload.end(); + while (current.end() < end) { + current = attestationTbs.nextTextString(current); + bytes32 keyHash = attestationTbs.keccak(current); if (keyHash == MODULE_ID_KEY) { - require(value.header() == 0x60, "invalid module_id type"); - ptrs.moduleID = value; - offset = value.content() + value.length(); + current = attestationTbs.nextTextString(current); + ptrs.moduleID = current; } else if (keyHash == DIGEST_KEY) { - require(value.header() == 0x60, "invalid digest type"); - ptrs.digest = value; - offset = value.content() + value.length(); + current = attestationTbs.nextTextString(current); + ptrs.digest = current; } else if (keyHash == CERTIFICATE_KEY) { - require(value.header() == 0x40, "invalid cert type"); - ptrs.cert = value; - offset = value.content() + value.length(); + current = attestationTbs.nextByteString(current); + ptrs.cert = current; } else if (keyHash == PUBLIC_KEY_KEY) { - ptrs.publicKey = value; - offset = value.content() + value.length(); + current = attestationTbs.nextByteStringOrNull(current); + ptrs.publicKey = current; } else if (keyHash == USER_DATA_KEY) { - ptrs.userData = value; - offset = value.content() + value.length(); + current = attestationTbs.nextByteStringOrNull(current); + ptrs.userData = current; } else if (keyHash == NONCE_KEY) { - ptrs.nonce = value; - offset = value.content() + value.length(); + current = attestationTbs.nextByteStringOrNull(current); + ptrs.nonce = current; } else if (keyHash == TIMESTAMP_KEY) { - require(value.header() == 0x00, "invalid timestamp type"); - ptrs.timestamp = uint64(value.length()); - offset = value.content(); + current = attestationTbs.nextPositiveInt(current); + ptrs.timestamp = uint64(current.value()); } else if (keyHash == CABUNDLE_KEY) { - require(value.header() == 0x80, "invalid cabundle type"); - offset = value.content(); - ptrs.cabundle = new NodePtr[](value.length()); - for (uint256 i = 0; i < value.length(); i++) { - NodePtr cert = _readNextElement(attestationTbs, offset); - require(cert.header() == 0x40, "invalid cert type"); - ptrs.cabundle[i] = cert; - offset = cert.content() + cert.length(); + current = attestationTbs.nextArray(current); + ptrs.cabundle = new CborElement[](current.value()); + for (uint256 i = 0; i < ptrs.cabundle.length; i++) { + current = attestationTbs.nextByteString(current); + ptrs.cabundle[i] = current; } } else if (keyHash == PCRS_KEY) { - require(value.header() == 0xa0, "invalid pcrs type"); - offset = value.content(); - ptrs.pcrs = new NodePtr[](value.length()); - for (uint256 i = 0; i < value.length(); i++) { - key = _readNextElement(attestationTbs, offset); - require(key.header() == 0x00, "invalid pcr key type"); - require(key.length() < value.length(), "invalid pcr key value"); - require(NodePtr.unwrap(ptrs.pcrs[key.length()]) == 0, "duplicate pcr key"); - NodePtr pcr = _readNextElement(attestationTbs, key.content()); - require(pcr.header() == 0x40, "invalid pcr type"); - ptrs.pcrs[key.length()] = pcr; - offset = pcr.content() + pcr.length(); + current = attestationTbs.nextMap(current); + ptrs.pcrs = new CborElement[](current.value()); + for (uint256 i = 0; i < ptrs.pcrs.length; i++) { + current = attestationTbs.nextPositiveInt(current); + uint256 key = current.value(); + require(key < ptrs.pcrs.length, "invalid pcr key value"); + require(CborElement.unwrap(ptrs.pcrs[key]) == 0, "duplicate pcr key"); + current = attestationTbs.nextByteString(current); + ptrs.pcrs[key] = current; } } else { revert("invalid attestation key"); @@ -225,26 +202,6 @@ contract NitroValidator { return ptrs; } - function _readNextElement(bytes memory cbor, uint256 ix) internal pure returns (NodePtr) { - uint256 _type = uint256(uint8(cbor[ix] & 0xe0)); - uint256 length = uint256(uint8(cbor[ix] & 0x1f)); - uint256 header = 1; - if (length == 24) { - length = uint8(cbor[ix + 1]); - header = 2; - } else if (length == 25) { - length = cbor.readUint16(ix + 1); - header = 3; - } else if (length == 26) { - length = cbor.readUint32(ix + 1); - header = 5; - } else if (length == 27) { - length = cbor.readUint64(ix + 1); - header = 9; - } - return LibNodePtr.toNodePtr(_type, ix + header, length); - } - function _verifySignature(bytes memory pubKey, bytes memory hash, bytes memory sig) internal view { ECDSA384.Parameters memory CURVE_PARAMETERS = ECDSA384.Parameters({ a: CURVE_A, diff --git a/src/NodePtr.sol b/src/NodePtr.sol deleted file mode 100644 index c9659e3..0000000 --- a/src/NodePtr.sol +++ /dev/null @@ -1,33 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.15; - -type NodePtr is uint256; - -library LibNodePtr { - using LibNodePtr for NodePtr; - - // First byte index of the header - function header(NodePtr self) internal pure returns (uint256) { - return uint80(NodePtr.unwrap(self)); - } - - // First byte index of the content - function content(NodePtr self) internal pure returns (uint256) { - return uint80(NodePtr.unwrap(self) >> 80); - } - - // Content length - function length(NodePtr self) internal pure returns (uint256) { - return uint80(NodePtr.unwrap(self) >> 160); - } - - // Total length (header length + content length) - function totalLength(NodePtr self) internal pure returns (uint256) { - return self.length() + self.content() - self.header(); - } - - // Pack 3 uint80s into a uint256 - function toNodePtr(uint256 _header, uint256 _content, uint256 _length) internal pure returns (NodePtr) { - return NodePtr.wrap(_header | _content << 80 | _length << 160); - } -}