Skip to content

Commit

Permalink
sign with internal key when taptree is present
Browse files Browse the repository at this point in the history
  • Loading branch information
stepansnigirev committed Aug 2, 2023
1 parent cd0b247 commit a9d3983
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 53 deletions.
98 changes: 48 additions & 50 deletions src/embit/psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def update(self, other):
self.bip32_derivations.update(other.bip32_derivations)
self.taproot_bip32_derivations.update(other.taproot_bip32_derivations)
self.taproot_internal_key = other.taproot_internal_key
self.taproot_merkle_root = other.taproot_merkle.root or self.taproot_merkle_root
self.taproot_merkle_root = other.taproot_merkle_root or self.taproot_merkle_root
self.taproot_sigs.update(other.taproot_sigs)
self.taproot_scripts.update(other.taproot_scripts)
self.final_scriptsig = other.final_scriptsig or self.final_scriptsig
Expand Down Expand Up @@ -782,13 +782,13 @@ def parse_unknowns(self):
raise PSBTError("Outputs already initialized")
self.outputs = [self.PSBTOUT_CLS() for _ in range(compact.from_bytes(self.unknown.pop(k)))]

def sighash(self, i, sighash=SIGHASH.ALL):
def sighash(self, i, sighash=SIGHASH.ALL, **kwargs):
inp = self.inputs[i]

if inp.is_taproot:
values = [inp.utxo.value for inp in self.inputs]
scripts = [inp.utxo.script_pubkey for inp in self.inputs]
return self.sighash_taproot(i, script_pubkeys=scripts, values=values, sighash=sighash)
return self.sighash_taproot(i, script_pubkeys=scripts, values=values, sighash=sighash, **kwargs)

value = inp.utxo.value
sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey
Expand Down Expand Up @@ -870,61 +870,59 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int:
# currently works only for single key
if inp.is_taproot:
# individual private key
# TODO: what if it's hdkey but root key was used for signing?
if not fingerprint:
# TODO: tweak using taproot psbt fields
pk = root.taproot_tweak(b"")
# check if key is internal key
# TODO: process hashes
pk = root.taproot_tweak(inp.taproot_merkle_root or b"")
if pk.xonly() not in sc.data:
continue
sig = pk.schnorr_sign(h)
wit = sig.serialize()
if inp_sighash != SIGHASH.DEFAULT:
wit += bytes([inp_sighash])
inp.final_scriptwitness = Witness([wit])
counter += 1
continue
# if we use HDKey
bip32_derivations = set()
for pub in inp.taproot_bip32_derivations:
leaf_hashes, derivation = inp.taproot_bip32_derivations[pub]
# TODO: also sign with leaf_hashes
if derivation.fingerprint == fingerprint and len(leaf_hashes) == 0:
bip32_derivations.add((pub, derivation))

# "Legacy" support for workaround when BIP-371 Taproot psbt fields aren't available.
# TODO: Remove this (and refactor above) when workaround has been phased out.
for pub in inp.bip32_derivations:
derivation = inp.bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.add((pub, derivation))

for pub, derivation in bip32_derivations:
der = derivation.derivation
if hasattr(root, "origin"):
# for descriptor key remove origin part
if root.origin:
if root.origin.derivation != der[:len(root.origin.derivation)]:
continue
der = der[len(root.origin.derivation):]
hdkey = root.key.derive(der)
else:
hdkey = root.derive(der)

if hdkey.xonly() != pub.xonly():
raise PSBTError("Derivation path doesn't look right")

pk = hdkey.taproot_tweak(inp.taproot_merkle_root or b"")
if pk.xonly() in sc.data:
sig = pk.schnorr_sign(h)
# sig plus sighash flag
wit = sig.serialize()
if inp_sighash != SIGHASH.DEFAULT:
wit += bytes([inp_sighash])
inp.final_scriptwitness = Witness([wit])
counter += 1
# if we use HDKey
else:
bip32_derivations = []
for pub in inp.taproot_bip32_derivations:
leaf_hashes, derivation = inp.taproot_bip32_derivations[pub]
# TODO: also sign with leaf_hashes
if derivation.fingerprint == fingerprint and len(leaf_hashes) == 0:
bip32_derivations.append((pub, derivation))

# "Legacy" support for workaround when BIP-371 Taproot psbt fields aren't available.
# TODO: Remove this (and refactor above) when workaround has been phased out.
for pub in inp.bip32_derivations:
derivation = inp.bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.append((pub, derivation))

for pub, derivation in bip32_derivations:
der = derivation.derivation
if hasattr(root, "origin"):
# for descriptor key remove origin part
if root.origin:
if root.origin.derivation != der[:len(root.origin.derivation)]:
continue
der = der[len(root.origin.derivation):]
hdkey = root.key.derive(der)
else:
hdkey = root.derive(der)

# Taproot BIP32 derivations use X-only pubkeys
xonly_pub = hdkey.key.xonly()
mypub = ec.PublicKey.from_xonly(xonly_pub)

if mypub != pub:
raise PSBTError("Derivation path doesn't look right")

# TODO: Support signing for keys within leaves
pk = hdkey.taproot_tweak(b"")
if pk.xonly() in sc.data:
sig = pk.schnorr_sign(h)
# sig plus sighash flag
wit = sig.serialize()
if inp_sighash != SIGHASH.DEFAULT:
wit += bytes([inp_sighash])
inp.final_scriptwitness = Witness([wit])
counter += 1
continue

# if we have individual private key
Expand Down
2 changes: 2 additions & 0 deletions src/embit/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def sighash_taproot(self,
sighash=SIGHASH.DEFAULT,
ext_flag=0,
annex_present=False,
extra=b"",
):
"""check out bip-341"""
if input_index < 0 or input_index >= len(self.vin):
Expand Down Expand Up @@ -231,6 +232,7 @@ def sighash_taproot(self,
# annex is not supported
if sh == SIGHASH.SINGLE:
h.update(self.vout[input_index].serialize())
h.update(extra)
return h.digest()

def sighash_segwit(self, input_index, script_pubkey, value, sighash=SIGHASH.ALL):
Expand Down
4 changes: 1 addition & 3 deletions tests/tests/test_taproot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from unittest import TestCase
from embit import bip32
from embit.bip32 import HDKey
from embit.networks import NETWORKS
from embit.script import p2tr, address_to_scriptpubkey
Expand All @@ -8,9 +7,8 @@
from embit.psbtview import PSBTView
from embit.ec import SchnorrSig, PublicKey
from embit.transaction import SIGHASH
from embit.psbtview import PSBTView
from io import BytesIO
from binascii import unhexlify, hexlify
from binascii import unhexlify

KEY = "tprv8ZgxMBicQKsPf27gmh4DbQqN2K6xnXA7m7AeceqQVGkRYny3X49sgcufzbJcq4k5eaGZDMijccdDzvQga2Saqd78dKqN52QwLyqgY8apX3j"
ROOT = HDKey.from_string(KEY)
Expand Down

0 comments on commit a9d3983

Please sign in to comment.