Skip to content

Commit

Permalink
python/trezorctl: implement common client and exception handling (fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
matejcik committed Mar 25, 2020
1 parent 6213201 commit 61aef25
Show file tree
Hide file tree
Showing 22 changed files with 255 additions and 250 deletions.
2 changes: 1 addition & 1 deletion python/src/trezorlib/btc.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def verify_message(client, coin_name, address, signature, message):
coin_name=coin_name,
)
)
except exceptions.TrezorFailure as e:
except exceptions.TrezorFailure:
return False
return isinstance(resp, messages.Success)

Expand Down
62 changes: 60 additions & 2 deletions python/src/trezorlib/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,72 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

import functools
import sys

import click

from .. import exceptions
from ..client import TrezorClient
from ..transport import get_transport
from ..ui import ClickUI


class ChoiceType(click.Choice):
def __init__(self, typemap):
super(ChoiceType, self).__init__(typemap.keys())
super().__init__(typemap.keys())
self.typemap = typemap

def convert(self, value, param, ctx):
value = super(ChoiceType, self).convert(value, param, ctx)
value = super().convert(value, param, ctx)
return self.typemap[value]


class TrezorConnection:
def __init__(self, path, session_id, passphrase_on_host):
self.path = path
self.session_id = session_id
self.passphrase_on_host = passphrase_on_host

def get_transport(self):
try:
# look for transport without prefix search
return get_transport(self.path, prefix_search=False)
except Exception:
# most likely not found. try again below.
pass

# look for transport with prefix search
# if this fails, we want the exception to bubble up to the caller
return get_transport(self.path, prefix_search=True)

def get_ui(self):
return ClickUI(passphrase_on_host=self.passphrase_on_host)

def get_client(self):
transport = self.get_transport()
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id)


def with_client(func):
@click.pass_obj
@functools.wraps(func)
def trezorctl_command_with_client(obj, *args, **kwargs):
try:
client = obj.get_client()
except Exception:
click.echo("Failed to find a Trezor device.")
if obj.path is not None:
click.echo("Using path: {}".format(obj.path))
sys.exit(1)

try:
return func(client, *args, **kwargs)
except exceptions.Cancelled:
click.echo("Action was cancelled.")
sys.exit(1)
except exceptions.TrezorException as e:
raise click.ClickException(str(e)) from e

return trezorctl_command_with_client
19 changes: 7 additions & 12 deletions python/src/trezorlib/cli/binance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import click

from .. import binance, tools
from . import with_client

PATH_HELP = "BIP-32 path to key, e.g. m/44'/714'/0'/0/0"

Expand All @@ -31,24 +32,20 @@ def cli():
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, show_display):
@with_client
def get_address(client, address, show_display):
"""Get Binance address for specified path."""
client = connect()
address_n = tools.parse_path(address)

return binance.get_address(client, address_n, show_display)


@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_public_key(connect, address, show_display):
@with_client
def get_public_key(client, address, show_display):
"""Get Binance public key."""
client = connect()
address_n = tools.parse_path(address)

return binance.get_public_key(client, address_n, show_display).hex()


Expand All @@ -61,10 +58,8 @@ def get_public_key(connect, address, show_display):
required=True,
help="Transaction in JSON format",
)
@click.pass_obj
def sign_tx(connect, address, file):
@with_client
def sign_tx(client, address, file):
"""Sign Binance transaction"""
client = connect()
address_n = tools.parse_path(address)

return binance.sign_tx(client, address_n, json.load(file))
32 changes: 15 additions & 17 deletions python/src/trezorlib/cli/btc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import click

from .. import btc, messages, protobuf, tools
from . import ChoiceType
from . import ChoiceType, with_client

INPUT_SCRIPTS = {
"address": messages.InputScriptType.SPENDADDRESS,
Expand Down Expand Up @@ -52,13 +52,13 @@ def cli():
@click.option("-n", "--address", required=True, help="BIP-32 path")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, coin, address, script_type, show_display):
@with_client
def get_address(client, coin, address, script_type, show_display):
"""Get address for specified path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address)
return btc.get_address(
connect(), coin, address_n, show_display, script_type=script_type
client, coin, address_n, show_display, script_type=script_type
)


Expand All @@ -68,13 +68,13 @@ def get_address(connect, coin, address, script_type, show_display):
@click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_public_node(connect, coin, address, curve, script_type, show_display):
@with_client
def get_public_node(client, coin, address, curve, script_type, show_display):
"""Get public node of given path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address)
result = btc.get_public_node(
connect(),
client,
address_n,
ecdsa_curve_name=curve,
show_display=show_display,
Expand All @@ -101,8 +101,8 @@ def get_public_node(connect, coin, address, curve, script_type, show_display):
@cli.command()
@click.option("-c", "--coin", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("json_file", type=click.File())
@click.pass_obj
def sign_tx(connect, json_file):
@with_client
def sign_tx(client, json_file):
"""Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for
Expand All @@ -111,8 +111,6 @@ def sign_tx(connect, json_file):
$ python3 tools/build_tx.py | trezorctl btc sign-tx -
"""
client = connect()

data = json.load(json_file)
coin = data.get("coin_name", DEFAULT_COIN)
details = protobuf.dict_to_proto(messages.SignTx, data.get("details", {}))
Expand Down Expand Up @@ -145,12 +143,12 @@ def sign_tx(connect, json_file):
@click.option("-n", "--address", required=True, help="BIP-32 path")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS), default="address")
@click.argument("message")
@click.pass_obj
def sign_message(connect, coin, address, message, script_type):
@with_client
def sign_message(client, coin, address, message, script_type):
"""Sign message using address of given path."""
coin = coin or DEFAULT_COIN
address_n = tools.parse_path(address)
res = btc.sign_message(connect(), coin, address_n, message, script_type)
res = btc.sign_message(client, coin, address_n, message, script_type)
return {
"message": message,
"address": res.address,
Expand All @@ -163,12 +161,12 @@ def sign_message(connect, coin, address, message, script_type):
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@click.pass_obj
def verify_message(connect, coin, address, signature, message):
@with_client
def verify_message(client, coin, address, signature, message):
"""Verify message."""
signature = base64.b64decode(signature)
coin = coin or DEFAULT_COIN
return btc.verify_message(connect(), coin, address, signature, message)
return btc.verify_message(client, coin, address, signature, message)


#
Expand Down
19 changes: 7 additions & 12 deletions python/src/trezorlib/cli/cardano.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import click

from .. import cardano, tools
from . import with_client

PATH_HELP = "BIP-32 path to key, e.g. m/44'/1815'/0'/0/0"

Expand All @@ -37,11 +38,9 @@ def cli():
help="Transaction in JSON format",
)
@click.option("-N", "--network", type=int, default=1)
@click.pass_obj
def sign_tx(connect, file, network):
@with_client
def sign_tx(client, file, network):
"""Sign Cardano transaction."""
client = connect()

transaction = json.load(file)

inputs = [cardano.create_input(input) for input in transaction["inputs"]]
Expand All @@ -59,21 +58,17 @@ def sign_tx(connect, file, network):
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.pass_obj
def get_address(connect, address, show_display):
@with_client
def get_address(client, address, show_display):
"""Get Cardano address."""
client = connect()
address_n = tools.parse_path(address)

return cardano.get_address(client, address_n, show_display)


@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.pass_obj
def get_public_key(connect, address):
@with_client
def get_public_key(client, address):
"""Get Cardano public key."""
client = connect()
address_n = tools.parse_path(address)

return cardano.get_public_key(client, address_n)
11 changes: 5 additions & 6 deletions python/src/trezorlib/cli/cosi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import click

from .. import cosi, tools
from . import with_client

PATH_HELP = "BIP-32 path, e.g. m/44'/0'/0'/0/0"

Expand All @@ -29,10 +30,9 @@ def cli():
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("data")
@click.pass_obj
def commit(connect, address, data):
@with_client
def commit(client, address, data):
"""Ask device to commit to CoSi signing."""
client = connect()
address_n = tools.parse_path(address)
return cosi.commit(client, address_n, bytes.fromhex(data))

Expand All @@ -42,10 +42,9 @@ def commit(connect, address, data):
@click.argument("data")
@click.argument("global_commitment")
@click.argument("global_pubkey")
@click.pass_obj
def sign(connect, address, data, global_commitment, global_pubkey):
@with_client
def sign(client, address, data, global_commitment, global_pubkey):
"""Ask device to sign using CoSi."""
client = connect()
address_n = tools.parse_path(address)
return cosi.sign(
client,
Expand Down
20 changes: 9 additions & 11 deletions python/src/trezorlib/cli/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import click

from .. import misc, tools
from . import with_client


@click.group(name="crypto")
Expand All @@ -26,32 +27,29 @@ def cli():

@cli.command()
@click.argument("size", type=int)
@click.pass_obj
def get_entropy(connect, size):
@with_client
def get_entropy(client, size):
"""Get random bytes from device."""
return misc.get_entropy(connect(), size).hex()
return misc.get_entropy(client, size).hex()


@cli.command()
@click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/10016'/0")
@click.argument("key")
@click.argument("value")
@click.pass_obj
def encrypt_keyvalue(connect, address, key, value):
@with_client
def encrypt_keyvalue(client, address, key, value):
"""Encrypt value by given key and path."""
client = connect()
address_n = tools.parse_path(address)
res = misc.encrypt_keyvalue(client, address_n, key, value.encode())
return res.hex()
return misc.encrypt_keyvalue(client, address_n, key, value.encode()).hex()


@cli.command()
@click.option("-n", "--address", required=True, help="BIP-32 path, e.g. m/10016'/0")
@click.argument("key")
@click.argument("value")
@click.pass_obj
def decrypt_keyvalue(connect, address, key, value):
@with_client
def decrypt_keyvalue(client, address, key, value):
"""Decrypt value by given key and path."""
client = connect()
address_n = tools.parse_path(address)
return misc.decrypt_keyvalue(client, address_n, key, bytes.fromhex(value))
Loading

0 comments on commit 61aef25

Please sign in to comment.