Skip to content

Commit

Permalink
fmt: formatting code
Browse files Browse the repository at this point in the history
Signed-off-by: Vincenzo Palazzo <vincenzopalazzodev@gmail.com>
  • Loading branch information
vincenzopalazzo committed Mar 3, 2022
1 parent e0cce0a commit 4b4823c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 75 deletions.
81 changes: 42 additions & 39 deletions lnprototest/clightning/clightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Runner(lnprototest.Runner):
def __init__(self, config: Any):
super().__init__(config)
self.running = False
self.rpc = None
self.cleanup_callbacks: List[Callable[[], None]] = []
self.fundchannel_future: Optional[Any] = None
self.is_fundchannel_kill = False
Expand All @@ -81,8 +82,8 @@ def __init__(self, config: Any):
stdout=subprocess.PIPE,
check=True,
)
.stdout.decode("utf-8")
.splitlines()
.stdout.decode("utf-8")
.splitlines()
)
self.options: Dict[str, str] = {}
for o in opts:
Expand All @@ -91,6 +92,7 @@ def __init__(self, config: Any):
else:
k, v = o.split("/")
self.options[k] = v
self.start()

def get_keyset(self) -> KeySet:
return KeySet(
Expand All @@ -111,7 +113,7 @@ def is_running(self) -> bool:
return self.running

def start(self) -> None:
if self.running:
if self.is_running():
return
self.proc = subprocess.Popen(
[
Expand Down Expand Up @@ -139,6 +141,7 @@ def start(self) -> None:
self.rpc = pyln.client.LightningRpc(
os.path.join(self.lightning_dir, "regtest", "lightning-rpc")
)

def node_ready(rpc: pyln.client.LightningRpc) -> bool:
try:
rpc.getinfo()
Expand All @@ -152,22 +155,12 @@ def node_ready(rpc: pyln.client.LightningRpc) -> bool:
for i in range(5):
self.rpc.newaddr()

def kill_fundchannel(self) -> None:
fut = self.fundchannel_future
self.fundchannel_future = None
self.is_fundchannel_kill = True
if fut:
try:
fut.result(0)
except (SpecFileError, futures.TimeoutError):
pass

def shutdown(self) -> None:
for cb in self.cleanup_callbacks:
cb()

def stop(self) -> None:
if self.running is False:
if not self.running:
return
self.shutdown()
self.rpc.stop()
Expand All @@ -176,6 +169,16 @@ def stop(self) -> None:
for c in self.conns.values():
cast(CLightningConn, c).connection.connection.close()

def kill_fundchannel(self) -> None:
fut = self.fundchannel_future
self.fundchannel_future = None
self.is_fundchannel_kill = True
if fut:
try:
fut.result(0)
except (SpecFileError, futures.TimeoutError):
pass

def connect(self, event: Event, connprivkey: str) -> None:
self.add_conn(CLightningConn(connprivkey, self.lightning_port))

Expand Down Expand Up @@ -223,12 +226,12 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None:
raise EventError(event, "Connection closed")

def fundchannel(
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
) -> None:
"""
event - the event which cause this, for error logging
Expand All @@ -248,11 +251,11 @@ def fundchannel(
self.fundchannel_future = None

def _fundchannel(
runner: Runner,
conn: Conn,
amount: int,
feerate: int,
expect_fail: bool = False,
runner: Runner,
conn: Conn,
amount: int,
feerate: int,
expect_fail: bool = False,
) -> str:
peer_id = conn.pubkey.format().hex()
# Need to supply feerate here, since regtest cannot estimate fees
Expand All @@ -276,14 +279,14 @@ def _done(fut: Any) -> None:
self.cleanup_callbacks.append(self.kill_fundchannel)

def init_rbf(
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
) -> None:

if self.fundchannel_future:
Expand Down Expand Up @@ -356,7 +359,7 @@ def addhtlc(self, event: Event, conn: Conn, amount: int, preimage: str) -> None:
self.rpc.sendpay([routestep], payhash)

def get_output_message(
self, conn: Conn, event: Event, timeout: int = TIMEOUT
self, conn: Conn, event: Event, timeout: int = TIMEOUT
) -> Optional[bytes]:
fut = self.executor.submit(cast(CLightningConn, conn).connection.read_message)
try:
Expand All @@ -373,11 +376,11 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]:
return msg.hex()

def check_final_error(
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
) -> None:
if not expected:
# Inject raw packet to ensure it hangs up *after* processing all previous ones.
Expand Down
65 changes: 32 additions & 33 deletions lnprototest/dummyrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


class DummyRunner(Runner):

def __init__(self, config: Any):
super().__init__(config)

Expand Down Expand Up @@ -87,12 +86,12 @@ def recv(self, event: Event, conn: Conn, outbuf: bytes) -> None:
print("[RECV {} {}]".format(event, outbuf.hex()))

def fundchannel(
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
self,
event: Event,
conn: Conn,
amount: int,
feerate: int = 253,
expect_fail: bool = False,
) -> None:
if self.config.getoption("verbose"):
print(
Expand All @@ -102,14 +101,14 @@ def fundchannel(
)

def init_rbf(
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
self,
event: Event,
conn: Conn,
channel_id: str,
amount: int,
utxo_txid: str,
utxo_outnum: int,
feerate: int,
) -> None:
if self.config.getoption("verbose"):
print(
Expand Down Expand Up @@ -144,21 +143,21 @@ def fake_field(ftype: FieldType) -> str:
if ftype.elemtype.name == "byte":
return "00" * ftype.arraysize
return (
"["
+ ",".join([DummyRunner.fake_field(ftype.elemtype)] * ftype.arraysize)
+ "]"
"["
+ ",".join([DummyRunner.fake_field(ftype.elemtype)] * ftype.arraysize)
+ "]"
)
elif ftype.name in (
"byte",
"u8",
"u16",
"u32",
"u64",
"tu16",
"tu32",
"tu64",
"bigsize",
"varint",
"byte",
"u8",
"u16",
"u32",
"u64",
"tu16",
"tu32",
"tu64",
"bigsize",
"varint",
):
return "0"
elif ftype.name in ("chain_hash", "channel_id", "sha256"):
Expand Down Expand Up @@ -201,11 +200,11 @@ def check_error(self, event: Event, conn: Conn) -> Optional[str]:
return "Dummy error"

def check_final_error(
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
self,
event: Event,
conn: Conn,
expected: bool,
must_not_events: List[MustNotMsg],
) -> None:
pass

Expand Down
4 changes: 3 additions & 1 deletion lnprototest/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def __init__(self, config: Any):
self.stash: Dict[str, Dict[str, Any]] = {}

def __enter__(self) -> "Runner":
"""Call the method when enter inside the class the first time"""
"""Call the method when enter inside the class the first time.
doc: https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers"""
self.start()
return self

Expand Down Expand Up @@ -233,6 +234,7 @@ def close_channel(self, channel_id: str) -> bool:
a boolean value if it succeeded with success"""
pass


def remote_revocation_basepoint() -> Callable[[Runner, Event, str], str]:
"""Get the remote revocation basepoint"""

Expand Down
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def pytest_addoption(parser: Any) -> None:
@pytest.fixture() # type: ignore
def runner(pytestconfig: Any) -> Any:
parts = pytestconfig.getoption("runner").rpartition(".")
runner = importlib.import_module(parts[0]).__dict__[parts[2]](pytestconfig)
yield runner
yield importlib.import_module(parts[0]).__dict__[parts[2]](pytestconfig)


@pytest.fixture()
Expand Down

0 comments on commit 4b4823c

Please sign in to comment.