From 08b1cd9a985ab0b18ccc47ad2cf1f2c42c1d0c45 Mon Sep 17 00:00:00 2001 From: Stepan Snigirev Date: Fri, 4 Aug 2023 03:30:08 +0200 Subject: [PATCH] full taptree psbt[view] signing --- src/embit/descriptor/arguments.py | 8 +- src/embit/liquid/descriptor.py | 5 +- src/embit/liquid/transaction.py | 1 + src/embit/psbt.py | 232 ++++++++++++-------- src/embit/psbtview.py | 353 +++++++++++++++++++++--------- src/embit/transaction.py | 25 ++- tests/tests/test_descriptor.py | 5 + tests/tests/test_liquid.py | 10 +- tests/tests/test_psbt.py | 1 - tests/tests/test_psbtview.py | 6 +- 10 files changed, 428 insertions(+), 218 deletions(-) diff --git a/src/embit/descriptor/arguments.py b/src/embit/descriptor/arguments.py index 4aa3913..018dc6c 100644 --- a/src/embit/descriptor/arguments.py +++ b/src/embit/descriptor/arguments.py @@ -27,9 +27,13 @@ def __str__(self): class AllowedDerivation(DescriptorBase): # xpub/<0;1>/* - <0;1> is a set of allowed branches, wildcard * is stored as None def __init__(self, indexes=[[0, 1], None]): - # check only one wildcard and only one set is in the derivation - if len([i for i in indexes if i is None]) > 1: + # check only one wildcard + if len([i + for i in indexes + if i is None or (isinstance(i, list) and None in i) + ]) > 1: raise ArgumentError("Only one wildcard is allowed") + # check only one set is in the derivation if len([i for i in indexes if isinstance(i, list)]) > 1: raise ArgumentError("Only one set of branches is allowed") self.indexes = indexes diff --git a/src/embit/liquid/descriptor.py b/src/embit/liquid/descriptor.py index a3e8427..a6e4251 100644 --- a/src/embit/liquid/descriptor.py +++ b/src/embit/liquid/descriptor.py @@ -1,9 +1,10 @@ +from .. import ec from ..descriptor.descriptor import * from .networks import NETWORKS from .addresses import address from . import slip77 from ..hashes import tagged_hash, sha256 -from ..ec import PrivateKey, PublicKey, secp256k1 +from ..ec import secp256k1 class LDescriptor(Descriptor): """Liquid descriptor that supports blinded() wrapper""" @@ -191,7 +192,7 @@ def sec(self): if self._pubkey is None: pubs = [secp256k1.ec_pubkey_parse(k.sec()) for k in self.keys] pub = musig_combine_pubs(pubs) - self._pubkey = PublicKey(pub) + self._pubkey = ec.PublicKey(pub) return self._pubkey.sec() def musig_combine_privs(privs, sort=True): diff --git a/src/embit/liquid/transaction.py b/src/embit/liquid/transaction.py index 3f737ce..7bd3442 100644 --- a/src/embit/liquid/transaction.py +++ b/src/embit/liquid/transaction.py @@ -1,3 +1,4 @@ +import sys import io from .. import compact from ..script import Script, Witness diff --git a/src/embit/psbt.py b/src/embit/psbt.py index b9c1bd8..8146225 100644 --- a/src/embit/psbt.py +++ b/src/embit/psbt.py @@ -4,8 +4,8 @@ from . import bip32 from . import ec from . import hashes -from .script import Script, Witness from . import script +from .script import Script, Witness from .base import EmbitBase, EmbitError from binascii import b2a_base64, a2b_base64, hexlify, unhexlify @@ -788,7 +788,13 @@ def sighash(self, i, sighash=SIGHASH.ALL, **kwargs): if inp.is_taproot: values = [inp.utxo.value for inp in self.inputs] scripts = [inp.utxo.script_pubkey for inp in self.inputs] - return self.sighash_taproot(i, script_pubkeys=scripts, values=values, sighash=sighash, **kwargs) + return self.sighash_taproot( + i, + script_pubkeys=scripts, + values=values, + sighash=sighash, + **kwargs, + ) value = inp.utxo.value sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey @@ -812,17 +818,74 @@ def sighash(self, i, sighash=SIGHASH.ALL, **kwargs): h = self.sighash_legacy(i, sc, sighash=sighash) return h + def sign_input_with_tapkey( + self, + key: ec.PrivateKey, + input_index: int, + inp = None, + sighash = SIGHASH.DEFAULT, + ) -> int: + """Sign taproot input with key. Signs with internal or leaf key.""" + # get input ourselves if not provided + inp = inp or self.inputs[input_index] + if not inp.is_taproot: + return 0 + # check if key is internal key + pk = key.taproot_tweak(inp.taproot_merkle_root or b"") + if pk.xonly() in inp.utxo.script_pubkey.data: + h = self.sighash( + input_index, + sighash=sighash, + ) + sig = pk.schnorr_sign(h) + wit = sig.serialize() + if sighash != SIGHASH.DEFAULT: + wit += bytes([sighash]) + # TODO: maybe better to put into internal key sig field + inp.final_scriptwitness = Witness([wit]) + # no need to sign anything else + return 1 + counter = 0 + # negate if necessary + pub = ec.PublicKey.from_xonly(key.xonly()) + # iterate over leafs and sign + for ctrl, sc in inp.taproot_scripts.items(): + if pub.xonly() not in sc: + continue + leaf_version = sc[-1] + script = Script(sc[:-1]) + h = self.sighash( + input_index, + sighash=sighash, + ext_flag=1, + script=script, + leaf_version=leaf_version, + ) + sig = key.schnorr_sign(h) + leaf = hashes.tagged_hash( + "TapLeaf", + bytes([leaf_version])+script.serialize() + ) + sigdata = sig.serialize() + # append sighash if necessary + if sighash != SIGHASH.DEFAULT: + sigdata += bytes([sighash]) + inp.taproot_sigs[(pub,leaf)] = sigdata + counter += 1 + return counter + def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int: """ Signs psbt with root key (HDKey or similar). Returns number of signatures added to PSBT. - Sighash kwarg is set to SIGHASH.DEFAULT, for segwit and legacy it's replaced to SIGHASH.ALL + Sighash kwarg is set to SIGHASH.DEFAULT, + for segwit and legacy it's replaced to SIGHASH.ALL so if PSBT is asking to sign with a different sighash this function won't sign. If you want to sign with sighashes provided in the PSBT - set sighash=None. """ - # check if it's a descriptor + counter = 0 # sigs counter + # check if it's a descriptor, and sign with all private keys in this descriptor if hasattr(root, "keys"): - counter = 0 for k in root.keys: if hasattr(k, "is_private") and k.is_private: counter += self.sign_with(k, sighash) @@ -841,119 +904,96 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int: # if HDKey if not fingerprint and hasattr(root, "my_fingerprint"): fingerprint = root.my_fingerprint - if not fingerprint: - pub = root.get_public_key() - sec = pub.sec() - pkh = hashes.hash160(sec) + + rootpub = root.get_public_key() + sec = rootpub.sec() + pkh = hashes.hash160(sec) counter = 0 for i, inp in enumerate(self.inputs): - # SIGHASH.DEFAULT is only for taproot, fallback to SIGHASH.ALL for other inputs + # SIGHASH.DEFAULT is only for taproot, fallback + # to SIGHASH.ALL for other inputs required_sighash = sighash if not inp.is_taproot and required_sighash == SIGHASH.DEFAULT: required_sighash = SIGHASH.ALL # check which sighash to use - inp_sighash = inp.sighash_type or required_sighash or SIGHASH.DEFAULT + inp_sighash = inp.sighash_type + if inp_sighash is None: + inp_sighash = required_sighash or SIGHASH.DEFAULT if not inp.is_taproot and inp_sighash == SIGHASH.DEFAULT: inp_sighash = SIGHASH.ALL - # if input sighash is set and is different from kwarg - don't sign this input + # if input sighash is set and is different from required sighash + # we don't sign this input + # except DEFAULT is functionally the same as ALL if required_sighash is not None and inp_sighash != required_sighash: - continue - - h = self.sighash(i, sighash=inp_sighash) - - sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey - - # taproot is special - # currently works only for single key - if inp.is_taproot: - # individual private key - # TODO: what if it's hdkey but root key was used for signing? - if not fingerprint: - # check if key is internal key - # TODO: process hashes - pk = root.taproot_tweak(inp.taproot_merkle_root or b"") - if pk.xonly() not in sc.data: - continue - sig = pk.schnorr_sign(h) - wit = sig.serialize() - if inp_sighash != SIGHASH.DEFAULT: - wit += bytes([inp_sighash]) - inp.final_scriptwitness = Witness([wit]) - counter += 1 + if (inp_sighash not in {SIGHASH.DEFAULT, SIGHASH.ALL} + or required_sighash not in {SIGHASH.DEFAULT, SIGHASH.ALL}): continue - # if we use HDKey - bip32_derivations = set() + + # get all possible derivations with matching fingerprint + bip32_derivations = set() + if fingerprint: + # if taproot derivations are present add them for pub in inp.taproot_bip32_derivations: - leaf_hashes, derivation = inp.taproot_bip32_derivations[pub] - # TODO: also sign with leaf_hashes - if derivation.fingerprint == fingerprint and len(leaf_hashes) == 0: + (_leafs, derivation) = inp.taproot_bip32_derivations[pub] + if derivation.fingerprint == fingerprint: bip32_derivations.add((pub, derivation)) - # "Legacy" support for workaround when BIP-371 Taproot psbt fields aren't available. - # TODO: Remove this (and refactor above) when workaround has been phased out. + # segwit and legacy derivations for pub in inp.bip32_derivations: derivation = inp.bip32_derivations[pub] if derivation.fingerprint == fingerprint: bip32_derivations.add((pub, derivation)) - for pub, derivation in bip32_derivations: - der = derivation.derivation - if hasattr(root, "origin"): - # for descriptor key remove origin part - if root.origin: - if root.origin.derivation != der[:len(root.origin.derivation)]: - continue - der = der[len(root.origin.derivation):] - hdkey = root.key.derive(der) - else: - hdkey = root.derive(der) - - if hdkey.xonly() != pub.xonly(): - raise PSBTError("Derivation path doesn't look right") - - pk = hdkey.taproot_tweak(inp.taproot_merkle_root or b"") - if pk.xonly() in sc.data: - sig = pk.schnorr_sign(h) - # sig plus sighash flag - wit = sig.serialize() - if inp_sighash != SIGHASH.DEFAULT: - wit += bytes([inp_sighash]) - inp.final_scriptwitness = Witness([wit]) - counter += 1 - continue + # get derived keys for signing + derived_keypairs = set() # (prv, pub) + for pub, derivation in bip32_derivations: + der = derivation.derivation + # descriptor key has origin derivation that we take into account + if hasattr(root, "origin"): + if root.origin: + if root.origin.derivation != der[:len(root.origin.derivation)]: + # derivation doesn't match - go to next input + continue + der = der[len(root.origin.derivation):] + hdkey = root.key.derive(der) + else: + hdkey = root.derive(der) - # if we have individual private key - if not fingerprint: - # check if we are included in the script - if sec in sc.data or pkh in sc.data: - sig = root.sign(h) - # sig plus sighash flag - inp.partial_sigs[pub] = sig.serialize() + bytes([inp_sighash]) - counter += 1 + if hdkey.xonly() != pub.xonly(): + raise PSBTError("Derivation path doesn't look right") + derived_keypairs.add((hdkey.key, pub)) + + # sign with taproot key + if inp.is_taproot: + # try to sign with individual private key (WIF) + # or with root without derivations + counter += self.sign_input_with_tapkey( + root, i, inp, sighash=inp_sighash, + ) + # sign with all derived keys + for prv, pub in derived_keypairs: + counter += self.sign_input_with_tapkey( + prv, i, inp, sighash=inp_sighash, + ) continue - # if we use HDKey - for pub in inp.bip32_derivations: - # check if it is root key - if inp.bip32_derivations[pub].fingerprint == fingerprint: - der = inp.bip32_derivations[pub].derivation - if hasattr(root, "origin"): - # for descriptor key remove origin part - if root.origin: - if root.origin.derivation != der[:len(root.origin.derivation)]: - continue - der = der[len(root.origin.derivation):] - hdkey = root.key.derive(der) - else: - hdkey = root.derive(der) - mypub = hdkey.key.get_public_key() - if mypub != pub: - raise PSBTError("Derivation path doesn't look right") - sig = hdkey.key.sign(h) - # sig plus sighash flag - inp.partial_sigs[mypub] = sig.serialize() + bytes([inp_sighash]) - counter += 1 + # hash can be reused + h = self.sighash(i, sighash=inp_sighash) + sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey + + # check if root itself is included in the script + if sec in sc.data or pkh in sc.data: + sig = root.sign(h) + # sig plus sighash flag + inp.partial_sigs[rootpub] = sig.serialize() + bytes([inp_sighash]) + counter += 1 + + for prv, pub in derived_keypairs: + sig = prv.sign(h) + # sig plus sighash flag + inp.partial_sigs[pub] = sig.serialize() + bytes([inp_sighash]) + counter += 1 return counter diff --git a/src/embit/psbtview.py b/src/embit/psbtview.py index 5bd8b38..10aee1a 100644 --- a/src/embit/psbtview.py +++ b/src/embit/psbtview.py @@ -1,15 +1,41 @@ """ -PSBTView class is RAM-friendly implementation of PSBT that reads required data from a stream on request. -The PSBT transaction itself passed to the class is a readable stream - it can be a file stream or a BytesIO object. -When using files make sure they are in trusted storage - when using SD card or other untrusted source make sure -to copy the file to a trusted media (flash, QSPI or SPIRAM for example). -Otherwise you expose yourself to time-of-check-time-of-use style of attacks where SD card MCU can trick you -to sign a wrong transactions. +PSBTView class is RAM-friendly implementation of PSBT +that reads required data from a stream on request. + +The PSBT transaction itself passed to the class +is a readable stream - it can be a file stream or a BytesIO object. +When using files make sure they are in trusted storage - when using SD card +or other untrusted source make sure to copy the file to a trusted media +(flash, QSPI or SPIRAM for example). + +Otherwise you expose yourself to time-of-check-time-of-use style of attacks +where SD card MCU can trick you to sign a wrong transactions. + Makes sense to run gc.collect() after processing of each scope to free memory. """ -from .psbt import * -from .transaction import hash_amounts, hash_script_pubkeys +# TODO: refactor, a lot of code is duplicated here from transaction.py import hashlib +from . import compact +from . import ec +from . import script +from .script import Script, Witness +from . import hashes +from .psbt import ( + PSBTError, + CompressMode, + InputScope, + OutputScope, + read_string, + ser_string, + skip_string, +) +from .transaction import ( + TransactionOutput, + TransactionInput, + SIGHASH, + hash_amounts, + hash_script_pubkeys, +) def read_write(sin, sout, l=None, chunk_size=32) -> int: """Reads l or all bytes from sin and writes to sout""" @@ -79,13 +105,21 @@ def num_vout(self): @property def vin0_offset(self): if self._vin0_offset is None: - self._vin0_offset = self.offset + self.NUM_VIN_OFFSET + len(compact.to_bytes(self.num_vin)) + self._vin0_offset = ( + self.offset + + self.NUM_VIN_OFFSET + + len(compact.to_bytes(self.num_vin)) + ) return self._vin0_offset @property def vout0_offset(self): if self._vout0_offset is None: - self._vout0_offset = self.vin0_offset + self.LEN_VIN * self.num_vin + len(compact.to_bytes(self.num_vout)) + self._vout0_offset = ( + self.vin0_offset + + self.LEN_VIN * self.num_vin + + len(compact.to_bytes(self.num_vout)) + ) return self._vout0_offset @property @@ -139,7 +173,9 @@ def __init__(self, stream, compress=CompressMode.KEEP_ALL, ): if version != 2 and tx_offset is None: - raise PSBTError("Global tx is not found, but PSBT version is %d" % version) + raise PSBTError( + "Global tx is not found, but PSBT version is %d" % version + ) self.version = version self.stream = stream # by default we use provided offset, tell() or 0 as default value @@ -390,8 +426,19 @@ def hash_script_pubkeys(self, script_pubkeys): self._hash_script_pubkeys = hash_script_pubkeys(script_pubkeys) return self._hash_script_pubkeys - def sighash_taproot(self, input_index, script_pubkeys, values, sighash=SIGHASH.DEFAULT): + def sighash_taproot(self, + input_index, + script_pubkeys, + values, + sighash=SIGHASH.DEFAULT, + ext_flag=0, + annex=None, + script=None, + leaf_version=0xc0, + codeseparator_pos=None, + ): """check out bip-341""" + # TODO: refactor, it's almost a complete copy of tx.sighash_taproot if input_index < 0 or input_index >= self.num_inputs: raise PSBTError("Invalid input index") if len(values) != self.num_inputs: @@ -409,7 +456,7 @@ def sighash_taproot(self, input_index, script_pubkeys, values, sighash=SIGHASH.D if sh not in [SIGHASH.SINGLE, SIGHASH.NONE]: h.update(self.hash_outputs()) # data about this input - h.update(b"\x00") # ext_flags and annex are not supported + h.update(bytes([2*ext_flag+int(annex is not None)])) if anyonecanpay: vin = self.vin(input_index) h.update(vin.serialize()) @@ -418,9 +465,20 @@ def sighash_taproot(self, input_index, script_pubkeys, values, sighash=SIGHASH.D h.update(vin.sequence.to_bytes(4, "little")) else: h.update(input_index.to_bytes(4, "little")) - # annex is not supported + if annex is not None: + h.update(hashes.sha256(compact.to_bytes(len(annex))+annex)) if sh == SIGHASH.SINGLE: h.update(self.vout(input_index).serialize()) + if script is not None: + h.update( + hashes.tagged_hash("TapLeaf", bytes([leaf_version])+script.serialize()) + ) + h.update(b"\x00") + h.update( + b"\xff\xff\xff\xff" + if codeseparator_pos is None + else codeseparator_pos.to_bytes(4,'little') + ) return h.digest() def sighash_segwit(self, input_index, script_pubkey, value, sighash=SIGHASH.ALL): @@ -448,7 +506,7 @@ def sighash_segwit(self, input_index, script_pubkey, value, sighash=SIGHASH.ALL) h.update(script_pubkey.serialize()) h.update(int(value).to_bytes(8, "little")) h.update(inp.sequence.to_bytes(4, "little")) - if not (sh in [SIGHASH.NONE, SIGHASH.SINGLE]): + if sh not in {SIGHASH.NONE, SIGHASH.SINGLE}: h.update(hashlib.sha256(self.hash_outputs()).digest()) elif sh == SIGHASH.SINGLE and input_index < self.num_outputs: h.update(hashlib.sha256( @@ -509,18 +567,24 @@ def sighash_legacy(self, input_index, script_pubkey, sighash=SIGHASH.ALL): h.update(sighash.to_bytes(4, "little")) return hashlib.sha256(h.digest()).digest() - def sighash(self, i, sighash=SIGHASH.ALL, input_scope=None): + def sighash(self, i, sighash=SIGHASH.ALL, input_scope=None, **kwargs): inp = self.input(i) if input_scope is None else input_scope if inp.is_taproot: - # TODO: not very optimal. Maybe use build_cache() or something? values = [] scripts = [] + # TODO: not very optimal. Maybe use build_cache() or something? for idx in range(self.num_inputs): inp = self.input(idx) values.append(inp.utxo.value) scripts.append(inp.utxo.script_pubkey) - return self.sighash_taproot(i, script_pubkeys=scripts, values=values, sighash=sighash) + return self.sighash_taproot( + i, + script_pubkeys=scripts, + values=values, + sighash=sighash, + **kwargs, + ) value = inp.utxo.value sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey @@ -544,21 +608,97 @@ def sighash(self, i, sighash=SIGHASH.ALL, input_scope=None): h = self.sighash_legacy(i, sc, sighash=sighash) return h - def sign_input(self, i, root, sig_stream, sighash=SIGHASH.DEFAULT, extra_scope_data=None) -> int: + def sign_input_with_tapkey( + self, + key: ec.PrivateKey, + input_index: int, + inp = None, + sighash = SIGHASH.DEFAULT, + ) -> int: + """Sign taproot input with key. Signs with internal or leaf key.""" + # get input ourselves if not provided + inp = inp or self.input(input_index) + if not inp.is_taproot: + return 0 + # check if key is internal key + pk = key.taproot_tweak(inp.taproot_merkle_root or b"") + if pk.xonly() in inp.utxo.script_pubkey.data: + h = self.sighash( + input_index, + sighash=sighash, + ) + sig = pk.schnorr_sign(h) + wit = sig.serialize() + if sighash != SIGHASH.DEFAULT: + wit += bytes([sighash]) + # TODO: maybe better to put into internal key sig field + inp.final_scriptwitness = Witness([wit]) + # no need to sign anything else + return 1 + counter = 0 + # negate if necessary + pub = ec.PublicKey.from_xonly(key.xonly()) + # iterate over leafs and sign + for ctrl, sc in inp.taproot_scripts.items(): + if pub.xonly() not in sc: + continue + leaf_version = sc[-1] + script = Script(sc[:-1]) + h = self.sighash( + input_index, + sighash=sighash, + ext_flag=1, + script=script, + leaf_version=leaf_version, + ) + sig = key.schnorr_sign(h) + leaf = hashes.tagged_hash( + "TapLeaf", + bytes([leaf_version])+script.serialize() + ) + sigdata = sig.serialize() + # append sighash if necessary + if sighash != SIGHASH.DEFAULT: + sigdata += bytes([sighash]) + inp.taproot_sigs[(pub,leaf)] = sigdata + counter += 1 + return counter + + def sign_input(self, + i, + root, + sig_stream, + sighash=SIGHASH.DEFAULT, + extra_scope_data=None + ) -> int: """ - Signs input taking into account additional derivation information for this input. + Signs input taking into account additional + derivation information for this input. + It's helpful if your wallet knows more than provided in PSBT. - As PSBTView is read-only it can't change anything in PSBT, that's why you may need extra_scope_data + As PSBTView is read-only it can't change anything in PSBT, + that's why you may need extra_scope_data. """ if i < 0 or i >= self.num_inputs: raise PSBTError("Invalid input number") # if WIF - fingerprint is None - fingerprint = None if not hasattr(root, "my_fingerprint") else root.my_fingerprint - if not fingerprint: - pub = root.get_public_key() - sec = pub.sec() - pkh = hashes.hash160(sec) + fingerprint = None + # if descriptor key + if hasattr(root, "origin"): + if not root.is_private: # pubkey can't sign + return 0 + if root.is_extended: # use fingerprint only for HDKey + fingerprint = root.fingerprint + else: + root = root.key # WIF key + # if HDKey + if not fingerprint and hasattr(root, "my_fingerprint"): + fingerprint = root.my_fingerprint + + rootpub = root.get_public_key() + sec = rootpub.sec() + pkh = hashes.hash160(sec) inp = self.input(i) if extra_scope_data is not None: @@ -570,96 +710,93 @@ def sign_input(self, i, root, sig_stream, sighash=SIGHASH.DEFAULT, extra_scope_d required_sighash = SIGHASH.ALL # check which sighash to use - inp_sighash = inp.sighash_type or required_sighash or SIGHASH.DEFAULT + inp_sighash = inp.sighash_type + if inp_sighash is None: + inp_sighash = required_sighash or SIGHASH.DEFAULT if not inp.is_taproot and inp_sighash == SIGHASH.DEFAULT: inp_sighash = SIGHASH.ALL - # if input sighash is set and is different from kwarg - don't sign this input + # if input sighash is set and is different from required sighash + # we don't sign this input + # except DEFAULT is functionally the same as ALL if required_sighash is not None and inp_sighash != required_sighash: - return 0 - - h = self.sighash(i, sighash=inp_sighash, input_scope=inp) + if (inp_sighash not in {SIGHASH.DEFAULT, SIGHASH.ALL} + or required_sighash not in {SIGHASH.DEFAULT, SIGHASH.ALL}): + return 0 + + # get all possible derivations with matching fingerprint + bip32_derivations = set() + if fingerprint: + # if taproot derivations are present add them + for pub in inp.taproot_bip32_derivations: + (_leafs, derivation) = inp.taproot_bip32_derivations[pub] + if derivation.fingerprint == fingerprint: + bip32_derivations.add((pub, derivation)) + + # segwit and legacy derivations + for pub in inp.bip32_derivations: + derivation = inp.bip32_derivations[pub] + if derivation.fingerprint == fingerprint: + bip32_derivations.add((pub, derivation)) + + # get derived keys for signing + derived_keypairs = set() # (prv, pub) + for pub, derivation in bip32_derivations: + der = derivation.derivation + # descriptor key has origin derivation that we take into account + if hasattr(root, "origin"): + if root.origin: + if root.origin.derivation != der[:len(root.origin.derivation)]: + # derivation doesn't match - go to next input + continue + der = der[len(root.origin.derivation):] + hdkey = root.key.derive(der) + else: + hdkey = root.derive(der) - sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey + if hdkey.xonly() != pub.xonly(): + raise PSBTError("Derivation path doesn't look right") + derived_keypairs.add((hdkey.key, pub)) counter = 0 - partial_sigs = OrderedDict() - - # taproot is special - # currently works only for single key + # sign with taproot key if inp.is_taproot: - # individual private key - if not fingerprint: - # TODO: tweak using taproot psbt fields - pk = root.taproot_tweak(b"") - if pk.xonly() in sc.data: - sig = pk.schnorr_sign(h) - wit = sig.serialize() - if inp_sighash != SIGHASH.DEFAULT: - wit += bytes([inp_sighash]) - inp.final_scriptwitness = Witness([wit]) - counter = 1 - # if we use HDKey - else: - bip32_derivations = [] - for pub in inp.taproot_bip32_derivations: - leaf_hashes, derivation = inp.taproot_bip32_derivations[pub] - if derivation.fingerprint == fingerprint: - bip32_derivations.append((pub, derivation)) - - # "Legacy" support for workaround when BIP-371 Taproot psbt fields aren't available - for pub in inp.bip32_derivations: - derivation = inp.bip32_derivations[pub] - if derivation.fingerprint == fingerprint: - bip32_derivations.append((pub, derivation)) - - for pub, derivation in bip32_derivations: - hdkey = root.derive(derivation.derivation) - - # Taproot BIP32 derivations use X-only pubkeys - xonly_pub = hdkey.key.xonly() - mypub = ec.PublicKey.from_xonly(xonly_pub) - - if mypub != pub: - raise PSBTError("Derivation path doesn't look right") - pk = hdkey.taproot_tweak(b"") - if pk.xonly() in sc.data: - sig = pk.schnorr_sign(h) - # sig plus sighash flag - wit = sig.serialize() - if inp_sighash != SIGHASH.DEFAULT: - wit += bytes([inp_sighash]) - inp.final_scriptwitness = Witness([wit]) - counter = 1 - if counter: + # try to sign with individual private key (WIF) + # or with root without derivations + counter += self.sign_input_with_tapkey( + root, i, inp, sighash=inp_sighash, + ) + # sign with all derived keys + for prv, pub in derived_keypairs: + counter += self.sign_input_with_tapkey( + prv, i, inp, sighash=inp_sighash, + ) + if inp.final_scriptwitness: ser_string(sig_stream, b"\x08") ser_string(sig_stream, inp.final_scriptwitness.serialize()) + + for (pub, leaf) in inp.taproot_sigs: + ser_string(sig_stream, b"\x14" + pub.xonly() + leaf) + ser_string(sig_stream, inp.taproot_sigs[(pub, leaf)]) return counter - # if we have individual private key - if not fingerprint: - # check if we are included in the script - if sec in sc.data or pkh in sc.data: - sig = root.sign(h) - # sig plus sighash flag - partial_sigs[pub] = sig.serialize() + bytes([inp_sighash]) - counter += 1 - # if we use HDKey - else: - for pub in inp.bip32_derivations: - # check if it is root key - if inp.bip32_derivations[pub].fingerprint == fingerprint: - hdkey = root.derive(inp.bip32_derivations[pub].derivation) - mypub = hdkey.key.get_public_key() - if mypub != pub: - raise PSBTError("Derivation path doesn't look right") - sig = hdkey.key.sign(h) - # sig plus sighash flag - partial_sigs[mypub] = sig.serialize() + bytes([inp_sighash]) - counter += 1 - for pub in partial_sigs: + h = self.sighash(i, sighash=inp_sighash, input_scope=inp) + sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey + + # check if root is included in the script + if sec in sc.data or pkh in sc.data: + sig = root.sign(h) + # sig plus sighash flag + inp.partial_sigs[rootpub] = sig.serialize() + bytes([inp_sighash]) + counter += 1 + for prv, pub in derived_keypairs: + sig = prv.sign(h) + # sig plus sighash flag + inp.partial_sigs[pub] = sig.serialize() + bytes([inp_sighash]) + counter += 1 + for pub in inp.partial_sigs: ser_string(sig_stream, b"\x02" + pub.serialize()) - ser_string(sig_stream, partial_sigs[pub]) + ser_string(sig_stream, inp.partial_sigs[pub]) return counter def sign_with(self, root, sig_stream, sighash=SIGHASH.DEFAULT) -> int: @@ -673,7 +810,15 @@ def sign_with(self, root, sig_stream, sighash=SIGHASH.DEFAULT) -> int: """ counter = 0 for i in range(self.num_inputs): - counter += self.sign_input(i, root, sig_stream, sighash=sighash) + # check if it's a descriptor, and sign with + # all private keys in this descriptor + if hasattr(root, "keys"): + for k in root.keys: + if hasattr(k, "is_private") and k.is_private: + counter += self.sign_input(i, k, sig_stream, sighash=sighash) + else: + # just sign with the key + counter += self.sign_input(i, root, sig_stream, sighash=sighash) # add separator sig_stream.write(b"\x00") return counter diff --git a/src/embit/transaction.py b/src/embit/transaction.py index d858100..de5f6da 100644 --- a/src/embit/transaction.py +++ b/src/embit/transaction.py @@ -1,11 +1,10 @@ -import sys -import io import hashlib from . import compact from .script import Script, Witness from . import hashes from .base import EmbitBase, EmbitError + class TransactionError(EmbitError): pass @@ -200,8 +199,10 @@ def sighash_taproot(self, values, sighash=SIGHASH.DEFAULT, ext_flag=0, - annex_present=False, - extra=b"", + annex=None, + script=None, + leaf_version=0xc0, + codeseparator_pos=None, ): """check out bip-341""" if input_index < 0 or input_index >= len(self.vin): @@ -221,7 +222,7 @@ def sighash_taproot(self, if sh not in [SIGHASH.SINGLE, SIGHASH.NONE]: h.update(self.hash_outputs()) # data about this input - h.update(bytes([2*ext_flag+int(annex_present)])) + h.update(bytes([2*ext_flag+int(annex is not None)])) if anyonecanpay: h.update(self.vin[input_index].serialize()) h.update(values[input_index].to_bytes(8, "little")) @@ -229,10 +230,20 @@ def sighash_taproot(self, h.update(self.vin[input_index].sequence.to_bytes(4, "little")) else: h.update(input_index.to_bytes(4, "little")) - # annex is not supported + if annex is not None: + h.update(hashes.sha256(compact.to_bytes(len(annex))+annex)) if sh == SIGHASH.SINGLE: h.update(self.vout[input_index].serialize()) - h.update(extra) + if script is not None: + h.update( + hashes.tagged_hash("TapLeaf", bytes([leaf_version])+script.serialize()) + ) + h.update(b"\x00") + h.update( + b"\xff\xff\xff\xff" + if codeseparator_pos is None + else codeseparator_pos.to_bytes(4,'little') + ) return h.digest() def sighash_segwit(self, input_index, script_pubkey, value, sighash=SIGHASH.ALL): diff --git a/tests/tests/test_descriptor.py b/tests/tests/test_descriptor.py index bc22ddd..5250737 100644 --- a/tests/tests/test_descriptor.py +++ b/tests/tests/test_descriptor.py @@ -150,6 +150,11 @@ def test_branch_mixing(self): "[f45912ab/44h/12/32h]xprvA1BtcqnJTKdjRQJ4K2874WTDyPCvgT7bCte7cXi4XrZ5csfoVqgWAL61U9dSf3xE9GUDrFL6RnxPRGvHMn85MHbuKSHDp4vqmJ7PK1Eewug/<*;1>/34h/*", ] for k in keys: + try: + Key.from_string(k) + print(k) + except: + pass self.assertRaises( Exception, Key.from_string, k diff --git a/tests/tests/test_liquid.py b/tests/tests/test_liquid.py index 3691c1d..c398353 100644 --- a/tests/tests/test_liquid.py +++ b/tests/tests/test_liquid.py @@ -83,14 +83,14 @@ def test_rangeproof(self): self.assertEqual(extra, b"") def test_descriptors(self): - multi = "wsh(sortedmulti(1,[12345678/44h/12]xpub6BwcvdstHTJtLpp1WxUiQCYERWSB66XY5JrCpw71GAJxcJ6s2AiUoEK4Nzt6UDaTmanUiSe6TY2RoFturKNLXeWBhwBF6WBNghr8cr7qnjk/{0,1}/*,[abcdef12/84h/22h]xpub6F6wWxm8F64iBHNhyaoh3QKCuuMUY5pfPPr1H1WuZXUXeXtZ21qjFN5ykaqnLL1jtPEFB9d94CyZrcYWKVdSiJKQ6mLGEB5sfrGFBpg6wgA/{0,1}/*))" + multi = "wsh(sortedmulti(1,[12345678/44h/12]xpub6BwcvdstHTJtLpp1WxUiQCYERWSB66XY5JrCpw71GAJxcJ6s2AiUoEK4Nzt6UDaTmanUiSe6TY2RoFturKNLXeWBhwBF6WBNghr8cr7qnjk/<0;1>/*,[abcdef12/84h/22h]xpub6F6wWxm8F64iBHNhyaoh3QKCuuMUY5pfPPr1H1WuZXUXeXtZ21qjFN5ykaqnLL1jtPEFB9d94CyZrcYWKVdSiJKQ6mLGEB5sfrGFBpg6wgA/<0;1>/*))" descs = [ - "wpkh([abcdef12/84h/22h]xpub6F6wWxm8F64iBHNhyaoh3QKCuuMUY5pfPPr1H1WuZXUXeXtZ21qjFN5ykaqnLL1jtPEFB9d94CyZrcYWKVdSiJKQ6mLGEB5sfrGFBpg6wgA/{0,1}/*)", + "wpkh([abcdef12/84h/22h]xpub6F6wWxm8F64iBHNhyaoh3QKCuuMUY5pfPPr1H1WuZXUXeXtZ21qjFN5ykaqnLL1jtPEFB9d94CyZrcYWKVdSiJKQ6mLGEB5sfrGFBpg6wgA/<0;1>/*)", multi, "blinded(slip77(L2t59TFgKmc83tPJD1rTy2KxJt44CMMQYsECXdz75xSqVv1X9Tvr),%s)" % multi, - "blinded(xprvA18YC5Aog5LxHgMrSv5t9QaHyfh5DU8Pr8zFTP5QhJSTjdg3mSpEyxLZfNQaEc8sALUtsHeDJYsp8YnobhjJT9D7JADoEV4wXiMuNMYDLZ2/{0,1}/*,%s)" % multi, - "blinded(musig(xprvA18YC5Aog5LxHgMrSv5t9QaHyfh5DU8Pr8zFTP5QhJSTjdg3mSpEyxLZfNQaEc8sALUtsHeDJYsp8YnobhjJT9D7JADoEV4wXiMuNMYDLZ2/{0,1}/*,xprv9ybbsYg8NKhDxDrSdmWPWih2AVjyDYxvTYvjaqNLmSpQcaLhmXeXUcHDEK99MiPDJwteBF2EzZkhfwwQDycrTgdxWGAgyWVpVJxrgZF5eCT/{0,1}/*),%s)" % multi, - "blinded(musig(xpub6E7tbahhWSuFWASKYwctWYX2XhXZcvrFDMurFmV2FdyScS1CJz8VXkf3WchmYnBmC8uMVgENPLYd8uWjXYjxFFwFXD6unhFXs6VBjHTAb9e/{0,1}/*,xprv9ybbsYg8NKhDxDrSdmWPWih2AVjyDYxvTYvjaqNLmSpQcaLhmXeXUcHDEK99MiPDJwteBF2EzZkhfwwQDycrTgdxWGAgyWVpVJxrgZF5eCT/{0,1}/*),%s)" % multi, + "blinded(xprvA18YC5Aog5LxHgMrSv5t9QaHyfh5DU8Pr8zFTP5QhJSTjdg3mSpEyxLZfNQaEc8sALUtsHeDJYsp8YnobhjJT9D7JADoEV4wXiMuNMYDLZ2/<0;1>/*,%s)" % multi, + "blinded(musig(xprvA18YC5Aog5LxHgMrSv5t9QaHyfh5DU8Pr8zFTP5QhJSTjdg3mSpEyxLZfNQaEc8sALUtsHeDJYsp8YnobhjJT9D7JADoEV4wXiMuNMYDLZ2/<0;1>/*,xprv9ybbsYg8NKhDxDrSdmWPWih2AVjyDYxvTYvjaqNLmSpQcaLhmXeXUcHDEK99MiPDJwteBF2EzZkhfwwQDycrTgdxWGAgyWVpVJxrgZF5eCT/<0;1>/*),%s)" % multi, + "blinded(musig(xpub6E7tbahhWSuFWASKYwctWYX2XhXZcvrFDMurFmV2FdyScS1CJz8VXkf3WchmYnBmC8uMVgENPLYd8uWjXYjxFFwFXD6unhFXs6VBjHTAb9e/<0;1>/*,xprv9ybbsYg8NKhDxDrSdmWPWih2AVjyDYxvTYvjaqNLmSpQcaLhmXeXUcHDEK99MiPDJwteBF2EzZkhfwwQDycrTgdxWGAgyWVpVJxrgZF5eCT/<0;1>/*),%s)" % multi, ] for d in descs: desc = LDescriptor.from_string(d) diff --git a/tests/tests/test_psbt.py b/tests/tests/test_psbt.py index 077a92e..6192bf7 100644 --- a/tests/tests/test_psbt.py +++ b/tests/tests/test_psbt.py @@ -3,7 +3,6 @@ from binascii import hexlify, unhexlify from embit.bip32 import HDKey -from embit.ec import PublicKey from embit.psbt import PSBT from unittest import TestCase diff --git a/tests/tests/test_psbtview.py b/tests/tests/test_psbtview.py index bdff18f..73da2be 100644 --- a/tests/tests/test_psbtview.py +++ b/tests/tests/test_psbtview.py @@ -64,7 +64,7 @@ def test_scopes(self): def test_sign(self): """Test if we can sign psbtview and get the same as from signing psbt""" for compress in [CompressMode.KEEP_ALL, CompressMode.CLEAR_ALL, CompressMode.PARTIAL]: - for b64 in PSBTS: + for i, b64 in enumerate(PSBTS): psbt = PSBT.from_string(b64, compress=compress) stream = BytesIO(a2b_base64(b64)) psbtv = PSBTView.view(stream, compress=compress) @@ -100,6 +100,10 @@ def test_sign(self): self.assertEqual(len(signed_inputs), len(psbt.inputs)) for i, inp in enumerate(signed_inputs): inp2 = psbt.inputs[i] + if inp.partial_sigs != inp2.partial_sigs: + print(compress) + print(i) + print(inp.partial_sigs, inp2.partial_sigs) self.assertEqual(inp.partial_sigs, inp2.partial_sigs) # check serialization with signatures sigs_stream.seek(0)