Skip to content

Commit

Permalink
refactor: put comment_allowed into LnurlPayResponse
Browse files Browse the repository at this point in the history
instead of having an extra class for it
  • Loading branch information
dni committed Jul 22, 2024
1 parent e3e589a commit 635f365
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 23 deletions.
2 changes: 0 additions & 2 deletions lnurl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
LnurlHostedChannelResponse,
LnurlPayActionResponse,
LnurlPayResponse,
LnurlPayResponseComment,
LnurlResponse,
LnurlSuccessResponse,
LnurlWithdrawResponse,
Expand All @@ -29,7 +28,6 @@
"LnurlHostedChannelResponse",
"LnurlPayActionResponse",
"LnurlPayResponse",
"LnurlPayResponseComment",
"LnurlResponse",
"LnurlSuccessResponse",
"LnurlWithdrawResponse",
Expand Down
31 changes: 14 additions & 17 deletions lnurl/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import math
from typing import List, Literal, Optional, Union

Expand Down Expand Up @@ -51,11 +53,11 @@ class Config:

def dict(self, **kwargs):
kwargs.setdefault("by_alias", True)
return super().dict(**kwargs)
return super().dict(**kwargs, exclude_none=True)

def json(self, **kwargs):
kwargs.setdefault("by_alias", True)
return super().json(**kwargs)
return super().json(**kwargs, exclude_none=True)

@property
def ok(self) -> bool:
Expand Down Expand Up @@ -106,8 +108,16 @@ class LnurlPayResponse(LnurlResponseModel):
max_sendable: MilliSatoshi = Field(..., alias="maxSendable", gt=0)
metadata: LnurlPayMetadata

# Adds the optional comment_allowed field to the LnurlPayResponse
# ref LUD-12: Comments in payRequest.
comment_allowed: Optional[int] = Field(
None,
description="Length of comment which can be sent",
alias="commentAllowed",
)

@validator("max_sendable")
def max_less_than_min(cls, value, values, **kwargs): # noqa
def max_less_than_min(cls, value, values): # noqa
if "min_sendable" in values and value < values["min_sendable"]:
raise ValueError("`max_sendable` cannot be less than `min_sendable`.")
return value
Expand All @@ -121,19 +131,6 @@ def max_sats(self) -> int:
return int(math.floor(self.max_sendable / 1000))


class LnurlPayResponseComment(LnurlPayResponse):
"""
Adds the optional comment_allowed field to the LnurlPayResponse
ref LUD-12: Comments in payRequest.
"""

comment_allowed: int = Field(
1000,
description="Length of comment which can be sent",
alias="commentAllowed",
)


class LnurlPayActionResponse(LnurlResponseModel):
pr: LightningInvoice
success_action: Optional[Union[MessageAction, UrlAction, AesAction]] = Field(None, alias="successAction")
Expand All @@ -150,7 +147,7 @@ class LnurlWithdrawResponse(LnurlResponseModel):
default_description: str = Field("", alias="defaultDescription")

@validator("max_withdrawable")
def max_less_than_min(cls, value, values, **kwargs): # noqa
def max_less_than_min(cls, value, values): # noqa
if "min_withdrawable" in values and value < values["min_withdrawable"]:
raise ValueError("`max_withdrawable` cannot be less than `min_withdrawable`.")
return value
Expand Down
7 changes: 3 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
LnurlErrorResponse,
LnurlHostedChannelResponse,
LnurlPayResponse,
LnurlPayResponseComment,
LnurlSuccessResponse,
LnurlWithdrawResponse,
)
Expand Down Expand Up @@ -62,7 +61,7 @@ class TestLnurlHostedChannelResponse:
def test_channel_response(self, d):
res = LnurlHostedChannelResponse(**d)
assert res.ok
assert res.dict() == {**{"tag": "hostedChannelRequest", "alias": None}, **d}
assert res.dict() == {**{"tag": "hostedChannelRequest"}, **d}

@pytest.mark.parametrize(
"d", [{"uri": "invalid", "k1": "c3RyaW5n"}, {"uri": "node_key@ip_address:port_number", "k1": None}]
Expand Down Expand Up @@ -145,7 +144,7 @@ class TestLnurlPayResponseComment:
],
)
def test_success_response(self, d):
res = LnurlPayResponseComment(**d)
res = LnurlPayResponse(**d)
assert res.ok
assert (
res.json() == res.json(by_alias=True) == '{"tag": "payRequest", "callback": "https://service.io/pay", '
Expand Down Expand Up @@ -192,7 +191,7 @@ def test_success_response(self, d):
)
def test_invalid_data(self, d):
with pytest.raises(ValidationError):
LnurlPayResponseComment(**d)
LnurlPayResponse(**d)


class TestLnurlWithdrawResponse:
Expand Down

0 comments on commit 635f365

Please sign in to comment.