Skip to content

Commit

Permalink
full taptree psbt[view] signing
Browse files Browse the repository at this point in the history
  • Loading branch information
stepansnigirev committed Aug 4, 2023
1 parent a9d3983 commit 08b1cd9
Show file tree
Hide file tree
Showing 10 changed files with 428 additions and 218 deletions.
8 changes: 6 additions & 2 deletions src/embit/descriptor/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ def __str__(self):
class AllowedDerivation(DescriptorBase):
# xpub/<0;1>/* - <0;1> is a set of allowed branches, wildcard * is stored as None
def __init__(self, indexes=[[0, 1], None]):
# check only one wildcard and only one set is in the derivation
if len([i for i in indexes if i is None]) > 1:
# check only one wildcard
if len([i
for i in indexes
if i is None or (isinstance(i, list) and None in i)
]) > 1:
raise ArgumentError("Only one wildcard is allowed")
# check only one set is in the derivation
if len([i for i in indexes if isinstance(i, list)]) > 1:
raise ArgumentError("Only one set of branches is allowed")
self.indexes = indexes
Expand Down
5 changes: 3 additions & 2 deletions src/embit/liquid/descriptor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .. import ec
from ..descriptor.descriptor import *
from .networks import NETWORKS
from .addresses import address
from . import slip77
from ..hashes import tagged_hash, sha256
from ..ec import PrivateKey, PublicKey, secp256k1
from ..ec import secp256k1

class LDescriptor(Descriptor):
"""Liquid descriptor that supports blinded() wrapper"""
Expand Down Expand Up @@ -191,7 +192,7 @@ def sec(self):
if self._pubkey is None:
pubs = [secp256k1.ec_pubkey_parse(k.sec()) for k in self.keys]
pub = musig_combine_pubs(pubs)
self._pubkey = PublicKey(pub)
self._pubkey = ec.PublicKey(pub)
return self._pubkey.sec()

def musig_combine_privs(privs, sort=True):
Expand Down
1 change: 1 addition & 0 deletions src/embit/liquid/transaction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import io
from .. import compact
from ..script import Script, Witness
Expand Down
232 changes: 136 additions & 96 deletions src/embit/psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from . import bip32
from . import ec
from . import hashes
from .script import Script, Witness
from . import script
from .script import Script, Witness
from .base import EmbitBase, EmbitError

from binascii import b2a_base64, a2b_base64, hexlify, unhexlify
Expand Down Expand Up @@ -788,7 +788,13 @@ def sighash(self, i, sighash=SIGHASH.ALL, **kwargs):
if inp.is_taproot:
values = [inp.utxo.value for inp in self.inputs]
scripts = [inp.utxo.script_pubkey for inp in self.inputs]
return self.sighash_taproot(i, script_pubkeys=scripts, values=values, sighash=sighash, **kwargs)
return self.sighash_taproot(
i,
script_pubkeys=scripts,
values=values,
sighash=sighash,
**kwargs,
)

value = inp.utxo.value
sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey
Expand All @@ -812,17 +818,74 @@ def sighash(self, i, sighash=SIGHASH.ALL, **kwargs):
h = self.sighash_legacy(i, sc, sighash=sighash)
return h

def sign_input_with_tapkey(
self,
key: ec.PrivateKey,
input_index: int,
inp = None,
sighash = SIGHASH.DEFAULT,
) -> int:
"""Sign taproot input with key. Signs with internal or leaf key."""
# get input ourselves if not provided
inp = inp or self.inputs[input_index]
if not inp.is_taproot:
return 0
# check if key is internal key
pk = key.taproot_tweak(inp.taproot_merkle_root or b"")
if pk.xonly() in inp.utxo.script_pubkey.data:
h = self.sighash(
input_index,
sighash=sighash,
)
sig = pk.schnorr_sign(h)
wit = sig.serialize()
if sighash != SIGHASH.DEFAULT:
wit += bytes([sighash])
# TODO: maybe better to put into internal key sig field
inp.final_scriptwitness = Witness([wit])
# no need to sign anything else
return 1
counter = 0
# negate if necessary
pub = ec.PublicKey.from_xonly(key.xonly())
# iterate over leafs and sign
for ctrl, sc in inp.taproot_scripts.items():
if pub.xonly() not in sc:
continue
leaf_version = sc[-1]
script = Script(sc[:-1])
h = self.sighash(
input_index,
sighash=sighash,
ext_flag=1,
script=script,
leaf_version=leaf_version,
)
sig = key.schnorr_sign(h)
leaf = hashes.tagged_hash(
"TapLeaf",
bytes([leaf_version])+script.serialize()
)
sigdata = sig.serialize()
# append sighash if necessary
if sighash != SIGHASH.DEFAULT:
sigdata += bytes([sighash])
inp.taproot_sigs[(pub,leaf)] = sigdata
counter += 1
return counter

def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int:
"""
Signs psbt with root key (HDKey or similar).
Returns number of signatures added to PSBT.
Sighash kwarg is set to SIGHASH.DEFAULT, for segwit and legacy it's replaced to SIGHASH.ALL
Sighash kwarg is set to SIGHASH.DEFAULT,
for segwit and legacy it's replaced to SIGHASH.ALL
so if PSBT is asking to sign with a different sighash this function won't sign.
If you want to sign with sighashes provided in the PSBT - set sighash=None.
"""
# check if it's a descriptor
counter = 0 # sigs counter
# check if it's a descriptor, and sign with all private keys in this descriptor
if hasattr(root, "keys"):
counter = 0
for k in root.keys:
if hasattr(k, "is_private") and k.is_private:
counter += self.sign_with(k, sighash)
Expand All @@ -841,119 +904,96 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int:
# if HDKey
if not fingerprint and hasattr(root, "my_fingerprint"):
fingerprint = root.my_fingerprint
if not fingerprint:
pub = root.get_public_key()
sec = pub.sec()
pkh = hashes.hash160(sec)

rootpub = root.get_public_key()
sec = rootpub.sec()
pkh = hashes.hash160(sec)

counter = 0
for i, inp in enumerate(self.inputs):
# SIGHASH.DEFAULT is only for taproot, fallback to SIGHASH.ALL for other inputs
# SIGHASH.DEFAULT is only for taproot, fallback
# to SIGHASH.ALL for other inputs
required_sighash = sighash
if not inp.is_taproot and required_sighash == SIGHASH.DEFAULT:
required_sighash = SIGHASH.ALL

# check which sighash to use
inp_sighash = inp.sighash_type or required_sighash or SIGHASH.DEFAULT
inp_sighash = inp.sighash_type
if inp_sighash is None:
inp_sighash = required_sighash or SIGHASH.DEFAULT
if not inp.is_taproot and inp_sighash == SIGHASH.DEFAULT:
inp_sighash = SIGHASH.ALL

# if input sighash is set and is different from kwarg - don't sign this input
# if input sighash is set and is different from required sighash
# we don't sign this input
# except DEFAULT is functionally the same as ALL
if required_sighash is not None and inp_sighash != required_sighash:
continue

h = self.sighash(i, sighash=inp_sighash)

sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey

# taproot is special
# currently works only for single key
if inp.is_taproot:
# individual private key
# TODO: what if it's hdkey but root key was used for signing?
if not fingerprint:
# check if key is internal key
# TODO: process hashes
pk = root.taproot_tweak(inp.taproot_merkle_root or b"")
if pk.xonly() not in sc.data:
continue
sig = pk.schnorr_sign(h)
wit = sig.serialize()
if inp_sighash != SIGHASH.DEFAULT:
wit += bytes([inp_sighash])
inp.final_scriptwitness = Witness([wit])
counter += 1
if (inp_sighash not in {SIGHASH.DEFAULT, SIGHASH.ALL}
or required_sighash not in {SIGHASH.DEFAULT, SIGHASH.ALL}):
continue
# if we use HDKey
bip32_derivations = set()

# get all possible derivations with matching fingerprint
bip32_derivations = set()
if fingerprint:
# if taproot derivations are present add them
for pub in inp.taproot_bip32_derivations:
leaf_hashes, derivation = inp.taproot_bip32_derivations[pub]
# TODO: also sign with leaf_hashes
if derivation.fingerprint == fingerprint and len(leaf_hashes) == 0:
(_leafs, derivation) = inp.taproot_bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.add((pub, derivation))

# "Legacy" support for workaround when BIP-371 Taproot psbt fields aren't available.
# TODO: Remove this (and refactor above) when workaround has been phased out.
# segwit and legacy derivations
for pub in inp.bip32_derivations:
derivation = inp.bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.add((pub, derivation))

for pub, derivation in bip32_derivations:
der = derivation.derivation
if hasattr(root, "origin"):
# for descriptor key remove origin part
if root.origin:
if root.origin.derivation != der[:len(root.origin.derivation)]:
continue
der = der[len(root.origin.derivation):]
hdkey = root.key.derive(der)
else:
hdkey = root.derive(der)

if hdkey.xonly() != pub.xonly():
raise PSBTError("Derivation path doesn't look right")

pk = hdkey.taproot_tweak(inp.taproot_merkle_root or b"")
if pk.xonly() in sc.data:
sig = pk.schnorr_sign(h)
# sig plus sighash flag
wit = sig.serialize()
if inp_sighash != SIGHASH.DEFAULT:
wit += bytes([inp_sighash])
inp.final_scriptwitness = Witness([wit])
counter += 1
continue
# get derived keys for signing
derived_keypairs = set() # (prv, pub)
for pub, derivation in bip32_derivations:
der = derivation.derivation
# descriptor key has origin derivation that we take into account
if hasattr(root, "origin"):
if root.origin:
if root.origin.derivation != der[:len(root.origin.derivation)]:
# derivation doesn't match - go to next input
continue
der = der[len(root.origin.derivation):]
hdkey = root.key.derive(der)
else:
hdkey = root.derive(der)

# if we have individual private key
if not fingerprint:
# check if we are included in the script
if sec in sc.data or pkh in sc.data:
sig = root.sign(h)
# sig plus sighash flag
inp.partial_sigs[pub] = sig.serialize() + bytes([inp_sighash])
counter += 1
if hdkey.xonly() != pub.xonly():
raise PSBTError("Derivation path doesn't look right")
derived_keypairs.add((hdkey.key, pub))

# sign with taproot key
if inp.is_taproot:
# try to sign with individual private key (WIF)
# or with root without derivations
counter += self.sign_input_with_tapkey(
root, i, inp, sighash=inp_sighash,
)
# sign with all derived keys
for prv, pub in derived_keypairs:
counter += self.sign_input_with_tapkey(
prv, i, inp, sighash=inp_sighash,
)
continue

# if we use HDKey
for pub in inp.bip32_derivations:
# check if it is root key
if inp.bip32_derivations[pub].fingerprint == fingerprint:
der = inp.bip32_derivations[pub].derivation
if hasattr(root, "origin"):
# for descriptor key remove origin part
if root.origin:
if root.origin.derivation != der[:len(root.origin.derivation)]:
continue
der = der[len(root.origin.derivation):]
hdkey = root.key.derive(der)
else:
hdkey = root.derive(der)
mypub = hdkey.key.get_public_key()
if mypub != pub:
raise PSBTError("Derivation path doesn't look right")
sig = hdkey.key.sign(h)
# sig plus sighash flag
inp.partial_sigs[mypub] = sig.serialize() + bytes([inp_sighash])
counter += 1
# hash can be reused
h = self.sighash(i, sighash=inp_sighash)
sc = inp.witness_script or inp.redeem_script or inp.utxo.script_pubkey

# check if root itself is included in the script
if sec in sc.data or pkh in sc.data:
sig = root.sign(h)
# sig plus sighash flag
inp.partial_sigs[rootpub] = sig.serialize() + bytes([inp_sighash])
counter += 1

for prv, pub in derived_keypairs:
sig = prv.sign(h)
# sig plus sighash flag
inp.partial_sigs[pub] = sig.serialize() + bytes([inp_sighash])
counter += 1
return counter
Loading

0 comments on commit 08b1cd9

Please sign in to comment.