@@ -298,6 +298,7 @@ def __init__(self, version=None, header_size=IMAGE_HEADER_SIZE,
298298 self .enctlv_len = 0
299299 self .max_align = max (DEFAULT_MAX_ALIGN , align ) if max_align is None else int (max_align )
300300 self .non_bootable = non_bootable
301+ self .key_ids = None
301302
302303 if self .max_align == DEFAULT_MAX_ALIGN :
303304 self .boot_magic = bytes ([
@@ -471,32 +472,40 @@ def ecies_hkdf(self, enckey, plainkey):
471472 format = PublicFormat .Raw )
472473 return cipherkey , ciphermac , pubk
473474
474- def create (self , key , public_key_format , enckey , dependencies = None ,
475+ def create (self , keys , public_key_format , enckey , dependencies = None ,
475476 sw_type = None , custom_tlvs = None , compression_tlvs = None ,
476477 compression_type = None , encrypt_keylen = 128 , clear = False ,
477478 fixed_sig = None , pub_key = None , vector_to_sign = None ,
478479 user_sha = 'auto' , is_pure = False , keep_comp_size = False , dont_encrypt = False ):
479480 self .enckey = enckey
480481
481- # key decides on sha, then pub_key; of both are none default is used
482- check_key = key if key is not None else pub_key
482+ # key decides on sha, then pub_key; if both are none default is used
483+ check_key = keys [ 0 ] if keys [ 0 ] is not None else pub_key
483484 hash_algorithm , hash_tlv = key_and_user_sha_to_alg_and_tlv (check_key , user_sha , is_pure )
484485
485486 # Calculate the hash of the public key
486- if key is not None :
487- pub = key .get_public_bytes ()
488- sha = hash_algorithm ()
489- sha .update (pub )
490- pubbytes = sha .digest ()
491- elif pub_key is not None :
492- if hasattr (pub_key , 'sign' ):
493- print (os .path .basename (__file__ ) + ": sign the payload" )
494- pub = pub_key .get_public_bytes ()
495- sha = hash_algorithm ()
496- sha .update (pub )
497- pubbytes = sha .digest ()
487+ pub_digests = []
488+ pub_list = []
489+
490+ if keys is None :
491+ if pub_key is not None :
492+ if hasattr (pub_key , 'sign' ):
493+ print (os .path .basename (__file__ ) + ": sign the payload" )
494+ pub = pub_key .get_public_bytes ()
495+ sha = hash_algorithm ()
496+ sha .update (pub )
497+ pubbytes = sha .digest ()
498+ else :
499+ pubbytes = bytes (hashlib .sha256 ().digest_size )
498500 else :
499- pubbytes = bytes (hashlib .sha256 ().digest_size )
501+ for key in keys or []:
502+ pub = key .get_public_bytes ()
503+ sha = hash_algorithm ()
504+ sha .update (pub )
505+ pubbytes = sha .digest ()
506+ pub_digests .append (pubbytes )
507+ pub_list .append (pub )
508+
500509
501510 protected_tlv_size = 0
502511
@@ -524,10 +533,14 @@ def create(self, key, public_key_format, enckey, dependencies=None,
524533 # value later.
525534 digest = bytes (hash_algorithm ().digest_size )
526535
536+ if pub_digests :
537+ boot_pub_digest = pub_digests [0 ]
538+ else :
539+ boot_pub_digest = pubbytes
527540 # Create CBOR encoded boot record
528541 boot_record = create_sw_component_data (sw_type , image_version ,
529542 hash_tlv , digest ,
530- pubbytes )
543+ boot_pub_digest )
531544
532545 protected_tlv_size += TLV_SIZE + len (boot_record )
533546
@@ -646,33 +659,39 @@ def create(self, key, public_key_format, enckey, dependencies=None,
646659 print (os .path .basename (__file__ ) + ': export digest' )
647660 return
648661
649- if self . key_ids is not None :
650- self . _add_key_id_tlv_to_unprotected ( tlv , self . key_ids [ 0 ] )
662+ if fixed_sig is not None and keys is not None :
663+ raise click . UsageError ( "Can not sign using key and provide fixed-signature at the same time" )
651664
652- if key is not None or fixed_sig is not None :
653- if public_key_format == 'hash' :
654- tlv .add ('KEYHASH' , pubbytes )
655- else :
656- tlv .add ('PUBKEY' , pub )
665+ if fixed_sig is not None :
666+ tlv .add (pub_key .sig_tlv (), fixed_sig ['value' ])
667+ self .signatures [0 ] = fixed_sig ['value' ]
668+ else :
669+ # Multi-signature handling: iterate through each provided key and sign.
670+ self .signatures = []
671+ for i , key in enumerate (keys ):
672+ # If key IDs are provided, and we have enough for this key, add it first.
673+ if self .key_ids is not None and len (self .key_ids ) > i :
674+ # Convert key id (an integer) to 4-byte defined endian bytes.
675+ kid_bytes = self .key_ids [i ].to_bytes (4 , self .endian )
676+ tlv .add ('KEYID' , kid_bytes ) # Using the TLV tag that corresponds to key IDs.
677+
678+ if public_key_format == 'hash' :
679+ tlv .add ('KEYHASH' , pub_digests [i ])
680+ else :
681+ tlv .add ('PUBKEY' , pub_list [i ])
657682
658- if key is not None and fixed_sig is None :
659683 # `sign` expects the full image payload (hashing done
660684 # internally), while `sign_digest` expects only the digest
661685 # of the payload
662-
663686 if hasattr (key , 'sign' ):
664687 print (os .path .basename (__file__ ) + ": sign the payload" )
665688 sig = key .sign (bytes (self .payload ))
666689 else :
667690 print (os .path .basename (__file__ ) + ": sign the digest" )
668691 sig = key .sign_digest (message )
669692 tlv .add (key .sig_tlv (), sig )
670- self .signature = sig
671- elif fixed_sig is not None and key is None :
672- tlv .add (pub_key .sig_tlv (), fixed_sig ['value' ])
673- self .signature = fixed_sig ['value' ]
674- else :
675- raise click .UsageError ("Can not sign using key and provide fixed-signature at the same time" )
693+ self .signatures .append (sig )
694+
676695
677696 # At this point the image was hashed + signed, we can remove the
678697 # protected TLVs from the payload (will be re-added later)
@@ -721,7 +740,7 @@ def get_struct_endian(self):
721740 return STRUCT_ENDIAN_DICT [self .endian ]
722741
723742 def get_signature (self ):
724- return self .signature
743+ return self .signatures
725744
726745 def get_infile_data (self ):
727746 return self .infile_data
@@ -831,75 +850,99 @@ def verify(imgfile, key):
831850 if magic != IMAGE_MAGIC :
832851 return VerifyResult .INVALID_MAGIC , None , None , None
833852
853+ # Locate the first TLV info header
834854 tlv_off = header_size + img_size
835855 tlv_info = b [tlv_off :tlv_off + TLV_INFO_SIZE ]
836856 magic , tlv_tot = struct .unpack ('HH' , tlv_info )
857+
858+ # If it's the protected-TLV block, skip it
837859 if magic == TLV_PROT_INFO_MAGIC :
838- tlv_off += tlv_tot
860+ tlv_off += TLV_INFO_SIZE + tlv_tot
839861 tlv_info = b [tlv_off :tlv_off + TLV_INFO_SIZE ]
840862 magic , tlv_tot = struct .unpack ('HH' , tlv_info )
841863
842864 if magic != TLV_INFO_MAGIC :
843865 return VerifyResult .INVALID_TLV_INFO_MAGIC , None , None , None
844866
845- # This is set by existence of TLV SIG_PURE
846- is_pure = False
867+ # Define the unprotected-TLV window
868+ unprot_off = tlv_off + TLV_INFO_SIZE
869+ unprot_end = unprot_off + tlv_tot
847870
848- prot_tlv_size = tlv_off
849- hash_region = b [:prot_tlv_size ]
850- tlv_end = tlv_off + tlv_tot
851- tlv_off += TLV_INFO_SIZE # skip tlv info
871+ # Region up to the start of unprotected TLVs is hashed
872+ prot_tlv_end = unprot_off - TLV_INFO_SIZE
873+ hash_region = b [:prot_tlv_end ]
852874
853- # First scan all TLVs in search of SIG_PURE
854- while tlv_off < tlv_end :
855- tlv = b [tlv_off :tlv_off + TLV_SIZE ]
875+ # This is set by existence of TLV SIG_PURE
876+ is_pure = False
877+ scan_off = unprot_off
878+ while scan_off < unprot_end :
879+ tlv = b [scan_off :scan_off + TLV_SIZE ]
856880 tlv_type , _ , tlv_len = struct .unpack ('BBH' , tlv )
857881 if tlv_type == TLV_VALUES ['SIG_PURE' ]:
858882 is_pure = True
859883 break
860- tlv_off += TLV_SIZE + tlv_len
884+ scan_off += TLV_SIZE + tlv_len
861885
886+ if key is not None and not isinstance (key , list ):
887+ key = [key ]
888+
889+ verify_results = []
890+ scan_off = unprot_off
862891 digest = None
863- tlv_off = prot_tlv_size
864- tlv_end = tlv_off + tlv_tot
865- tlv_off += TLV_INFO_SIZE # skip tlv info
866- while tlv_off < tlv_end :
867- tlv = b [tlv_off : tlv_off + TLV_SIZE ]
892+ prot_tlv_size = unprot_off - TLV_INFO_SIZE
893+
894+ # Verify hash and signatures
895+ while scan_off < unprot_end :
896+ tlv = b [scan_off : scan_off + TLV_SIZE ]
868897 tlv_type , _ , tlv_len = struct .unpack ('BBH' , tlv )
869898 if is_sha_tlv (tlv_type ):
870- if not tlv_matches_key_type (tlv_type , key ):
899+ if not tlv_matches_key_type (tlv_type , key [ 0 ] ):
871900 return VerifyResult .KEY_MISMATCH , None , None , None
872- off = tlv_off + TLV_SIZE
901+ off = scan_off + TLV_SIZE
873902 digest = get_digest (tlv_type , hash_region )
874- if digest == b [off :off + tlv_len ]:
875- if key is None :
876- return VerifyResult .OK , version , digest , None
877- else :
878- return VerifyResult .INVALID_HASH , None , None , None
879- elif not is_pure and key is not None and tlv_type == TLV_VALUES [key .sig_tlv ()]:
880- off = tlv_off + TLV_SIZE
881- tlv_sig = b [off :off + tlv_len ]
882- payload = b [:prot_tlv_size ]
883- try :
884- if hasattr (key , 'verify' ):
885- key .verify (tlv_sig , payload )
886- else :
887- key .verify_digest (tlv_sig , digest )
888- return VerifyResult .OK , version , digest , None
889- except InvalidSignature :
890- # continue to next TLV
891- pass
903+ if digest != b [off :off + tlv_len ]:
904+ verify_results .append (("Digest" , "INVALID_HASH" ))
905+
906+ elif not is_pure and key is not None and tlv_type == TLV_VALUES [key [0 ].sig_tlv ()]:
907+ for idx , k in enumerate (key ):
908+ if tlv_type == TLV_VALUES [k .sig_tlv ()]:
909+ off = scan_off + TLV_SIZE
910+ tlv_sig = b [off :off + tlv_len ]
911+ payload = b [:prot_tlv_size ]
912+ try :
913+ if hasattr (k , 'verify' ):
914+ k .verify (tlv_sig , payload )
915+ else :
916+ k .verify_digest (tlv_sig , digest )
917+ verify_results .append ((f"Key { idx } " , "OK" ))
918+ break
919+ except InvalidSignature :
920+ # continue to next TLV
921+ verify_results .append ((f"Key { idx } " , "INVALID_SIGNATURE" ))
922+ continue
923+
892924 elif is_pure and key is not None and tlv_type in ALLOWED_PURE_SIG_TLVS :
893- off = tlv_off + TLV_SIZE
925+ # pure signature verification
926+ off = scan_off + TLV_SIZE
894927 tlv_sig = b [off :off + tlv_len ]
928+ k = key [0 ]
895929 try :
896- key .verify_digest (tlv_sig , hash_region )
930+ k .verify_digest (tlv_sig , hash_region )
897931 return VerifyResult .OK , version , None , tlv_sig
898932 except InvalidSignature :
899- # continue to next TLV
900- pass
901- tlv_off += TLV_SIZE + tlv_len
902- return VerifyResult .INVALID_SIGNATURE , None , None , None
933+ return VerifyResult .INVALID_SIGNATURE , None , None , None
934+
935+ scan_off += TLV_SIZE + tlv_len
936+ # Now print out the verification results:
937+ for k , result in verify_results :
938+ print (f"{ k } : { result } " )
939+
940+ # Decide on a final return (for example, OK only if at least one signature is valid)
941+ if any (result == "OK" for _ , result in verify_results ):
942+ return VerifyResult .OK , version , digest , None
943+ else :
944+ return VerifyResult .INVALID_SIGNATURE , None , None , None
945+
903946
904947 def set_key_ids (self , key_ids ):
905948 """Set list of key IDs (integers) to be inserted before each signature."""
0 commit comments