Skip to content

Commit

Permalink
fix: add timeout to node requests to prevent jobs from getting stuck
Browse files Browse the repository at this point in the history
  • Loading branch information
SaintShit committed Feb 29, 2024
1 parent c2a23fa commit 1818a2a
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 55 deletions.
4 changes: 2 additions & 2 deletions app/jobs/0_xray_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def core_health_check():
if node.connected:
try:
assert node.started
node.api.get_sys_stats()
except (ConnectionError, xray_exc.ConnectionError, xray_exc.UnknownError, AssertionError):
node.api.get_sys_stats(timeout=2)
except (ConnectionError, xray_exc.XrayError, AssertionError):
if not config:
config = xray.config.include_db_users()
xray.operations.restart_node(node_id, config)
Expand Down
8 changes: 4 additions & 4 deletions app/jobs/record_usages.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,20 @@ def record_node_stats(params: dict, node_id: Union[int, None]):
def get_users_stats(api: XRayAPI):
try:
params = defaultdict(int)
for stat in filter(attrgetter('value'), api.get_users_stats(reset=True)):
for stat in filter(attrgetter('value'), api.get_users_stats(reset=True, timeout=30)):
params[stat.name.split('.', 1)[0]] += stat.value
params = list({"uid": uid, "value": value} for uid, value in params.items())
return params
except (xray_exc.ConnectionError, xray_exc.UnknownError):
except xray_exc.XrayError:
return []


def get_outbounds_stats(api: XRayAPI):
try:
params = [{"up": stat.value, "down": 0} if stat.link == "uplink" else {"up": 0, "down": stat.value}
for stat in filter(attrgetter('value'), api.get_outbounds_stats(reset=True))]
for stat in filter(attrgetter('value'), api.get_outbounds_stats(reset=True, timeout=10))]
return params
except (xray_exc.ConnectionError, xray_exc.UnknownError):
except xray_exc.XrayError:
return []


Expand Down
21 changes: 11 additions & 10 deletions app/xray/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ def _prepare_config(self, config: XRayConfig):

return config

def make_request(self, path, **params):
def make_request(self, path: str, timeout: int, **params):
try:
res = self.session.post(self._rest_api_url + path, json={"session_id": self._session_id, **params})
res = self.session.post(self._rest_api_url + path, timeout=timeout,
json={"session_id": self._session_id, **params})
data = res.json()
except Exception as e:
exc = NodeAPIError(0, str(e))
Expand All @@ -118,14 +119,14 @@ def connected(self):
if not self._session_id:
return False
try:
self.make_request("/ping")
self.make_request("/ping", timeout=3)
return True
except NodeAPIError:
return False

@property
def started(self):
res = self.make_request("/")
res = self.make_request("/", timeout=3)
return res.get('started', False)

@property
Expand All @@ -151,15 +152,15 @@ def connect(self):
self._node_certfile = string_to_temp_file(self._node_cert)
self.session.verify = self._node_certfile.name

res = self.make_request("/connect")
res = self.make_request("/connect", timeout=3)
self._session_id = res['session_id']

def disconnect(self):
self.make_request("/disconnect")
self.make_request("/disconnect", timeout=3)
self._session_id = None

def get_version(self):
res = self.make_request("/")
res = self.make_request("/", timeout=3)
return res.get('core_version')

def start(self, config: XRayConfig):
Expand All @@ -170,7 +171,7 @@ def start(self, config: XRayConfig):
json_config = config.to_json()

try:
res = self.make_request("/start", config=json_config)
res = self.make_request("/start", timeout=10, config=json_config)
except NodeAPIError as exc:
if exc.detail == 'Xray is started already':
return self.restart(config)
Expand All @@ -197,7 +198,7 @@ def stop(self):
if not self.connected:
self.connect()

self.make_request('/stop')
self.make_request('/stop', timeout=5)
self._api = None
self._started = False

Expand All @@ -208,7 +209,7 @@ def restart(self, config: XRayConfig):
config = self._prepare_config(config)
json_config = config.to_json()

res = self.make_request("/restart", config=json_config)
res = self.make_request("/restart", timeout=10, config=json_config)

self._started = True

Expand Down
13 changes: 8 additions & 5 deletions app/xray/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,27 @@ def get_tls():
@threaded_function
def _add_user_to_inbound(api: XRayAPI, inbound_tag: str, account: Account):
try:
api.add_inbound_user(tag=inbound_tag, user=account)
api.add_inbound_user(tag=inbound_tag, user=account, timeout=30)
except (xray.exc.EmailExistsError, xray.exc.ConnectionError):
pass


@threaded_function
def _remove_user_from_inbound(api: XRayAPI, inbound_tag: str, email: str):
try:
api.remove_inbound_user(tag=inbound_tag, email=email)
api.remove_inbound_user(tag=inbound_tag, email=email, timeout=30)
except (xray.exc.EmailNotFoundError, xray.exc.ConnectionError):
pass


@threaded_function
def _alter_inbound_user(api: XRayAPI, inbound_tag: str, account: Account):
try:
api.remove_inbound_user(tag=inbound_tag, email=account.email)
api.remove_inbound_user(tag=inbound_tag, email=account.email, timeout=30)
except (xray.exc.EmailNotFoundError, xray.exc.ConnectionError):
pass
try:
api.add_inbound_user(tag=inbound_tag, user=account)
api.add_inbound_user(tag=inbound_tag, user=account, timeout=30)
except (xray.exc.EmailExistsError, xray.exc.ConnectionError):
pass

Expand Down Expand Up @@ -152,7 +152,10 @@ def remove_node(node_id: int):
except Exception:
pass
finally:
del xray.nodes[node_id]
try:
del xray.nodes[node_id]
except KeyError:
pass


def add_node(dbnode: "DBNode"):
Expand Down
19 changes: 13 additions & 6 deletions xray_api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,14 @@ def __init__(self, details, tag):
class ConnectionError(XrayError):
REGEXP = re.compile(r"Failed to connect to remote host|Socket closed|Broken pipe")

def __init__(self, details, tag):
self.tag = tag
def __init__(self, details):
super().__init__(details)


class TimeoutError(XrayError):
REGEXP = re.compile(r"Deadline Exceeded")

def __init__(self, details):
super().__init__(details)


Expand All @@ -50,11 +56,12 @@ def __init__(self, details=''):
class RelatedError(XrayError):
def __new__(cls, error: grpc.RpcError):
details = error.details()
for e in (EmailExistsError, EmailNotFoundError, TagNotFoundError, ConnectionError):
args = e.REGEXP.findall(details)
if not args:

for e in (EmailExistsError, EmailNotFoundError, TagNotFoundError, ConnectionError, TimeoutError):
m = e.REGEXP.search(details)
if not m:
continue

return e(details, *args)
return e(details, *m.groups())

return UnknownError(details)
24 changes: 12 additions & 12 deletions xray_api/proxyman.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,25 @@


class Proxyman(XRayBase):
def alter_inbound(self, tag: str, operation: TypedMessage) -> bool:
def alter_inbound(self, tag: str, operation: TypedMessage, timeout: int = None) -> bool:
stub = command_pb2_grpc.HandlerServiceStub(self._channel)
try:
stub.AlterInbound(command_pb2.AlterInboundRequest(tag=tag, operation=operation))
stub.AlterInbound(command_pb2.AlterInboundRequest(tag=tag, operation=operation), timeout=timeout)
return True

except grpc.RpcError as e:
raise RelatedError(e)

def alter_outbound(self, tag: str, operation: TypedMessage) -> bool:
def alter_outbound(self, tag: str, operation: TypedMessage, timeout: int = None) -> bool:
stub = command_pb2_grpc.HandlerServiceStub(self._channel)
try:
stub.AlterInbound(command_pb2.AlterOutboundRequest(tag=tag, operation=operation))
stub.AlterInbound(command_pb2.AlterOutboundRequest(tag=tag, operation=operation), timeout=timeout)
return True

except grpc.RpcError as e:
raise RelatedError(e)

def add_inbound_user(self, tag: str, user: Account) -> bool:
def add_inbound_user(self, tag: str, user: Account, timeout: int = None) -> bool:
return self.alter_inbound(
tag=tag,
operation=Message(
Expand All @@ -43,18 +43,18 @@ def add_inbound_user(self, tag: str, user: Account) -> bool:
account=user.message
)
)
))
), timeout=timeout)

def remove_inbound_user(self, tag: str, email: str) -> bool:
def remove_inbound_user(self, tag: str, email: str, timeout: int = None) -> bool:
return self.alter_inbound(
tag=tag,
operation=Message(
command_pb2.RemoveUserOperation(
email=email
)
))
), timeout=timeout)

def add_outbound_user(self, tag: str, user: Account) -> bool:
def add_outbound_user(self, tag: str, user: Account, timeout: int = None) -> bool:
return self.alter_outbound(
tag=tag,
operation=Message(
Expand All @@ -65,16 +65,16 @@ def add_outbound_user(self, tag: str, user: Account) -> bool:
account=user.message
)
)
))
), timeout=timeout)

def remove_outbound_user(self, tag: str, email: str) -> bool:
def remove_outbound_user(self, tag: str, email: str, timeout: int = None) -> bool:
return self.alter_outbound(
tag=tag,
operation=Message(
command_pb2.RemoveUserOperation(
email=email
)
))
), timeout=timeout)

def add_inbound(self):
raise NotImplementedError
Expand Down
32 changes: 16 additions & 16 deletions xray_api/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ class OutboundStatsResponse:


class Stats(XRayBase):
def get_sys_stats(self) -> SysStatsResponse:
def get_sys_stats(self, timeout: int = None) -> SysStatsResponse:
try:
stub = command_pb2_grpc.StatsServiceStub(self._channel)
r = stub.GetSysStats(command_pb2.SysStatsRequest())
r = stub.GetSysStats(command_pb2.SysStatsRequest(), timeout=timeout)

except grpc.RpcError as e:
raise RelatedError(e)
Expand All @@ -73,10 +73,10 @@ def get_sys_stats(self) -> SysStatsResponse:
uptime=r.Uptime
)

def query_stats(self, pattern: str, reset: bool = False) -> typing.Iterable[StatResponse]:
def query_stats(self, pattern: str, reset: bool = False, timeout: int = None) -> typing.Iterable[StatResponse]:
try:
stub = command_pb2_grpc.StatsServiceStub(self._channel)
r = stub.QueryStats(command_pb2.QueryStatsRequest(pattern=pattern, reset=reset))
r = stub.QueryStats(command_pb2.QueryStatsRequest(pattern=pattern, reset=reset), timeout=timeout)

except grpc.RpcError as e:
raise RelatedError(e)
Expand All @@ -85,37 +85,37 @@ def query_stats(self, pattern: str, reset: bool = False) -> typing.Iterable[Stat
type, name, _, link = stat.name.split('>>>')
yield StatResponse(name, type, link, stat.value)

def get_users_stats(self, reset: bool = False) -> typing.Iterable[StatResponse]:
return self.query_stats("user>>>", reset=reset)
def get_users_stats(self, reset: bool = False, timeout: int = None) -> typing.Iterable[StatResponse]:
return self.query_stats("user>>>", reset=reset, timeout=timeout)

def get_inbounds_stats(self, reset: bool = False) -> typing.Iterable[StatResponse]:
return self.query_stats("inbound>>>", reset=reset)
def get_inbounds_stats(self, reset: bool = False, timeout: int = None) -> typing.Iterable[StatResponse]:
return self.query_stats("inbound>>>", reset=reset, timeout=timeout)

def get_outbounds_stats(self, reset: bool = False) -> typing.Iterable[StatResponse]:
return self.query_stats("outbound>>>", reset=reset)
def get_outbounds_stats(self, reset: bool = False, timeout: int = None) -> typing.Iterable[StatResponse]:
return self.query_stats("outbound>>>", reset=reset, timeout=timeout)

def get_user_stats(self, email: str, reset: bool = False) -> typing.Iterable[StatResponse]:
def get_user_stats(self, email: str, reset: bool = False, timeout: int = None) -> typing.Iterable[StatResponse]:
uplink, downlink = 0, 0
for stat in self.query_stats(f"user>>>{email}>>>", reset=reset):
for stat in self.query_stats(f"user>>>{email}>>>", reset=reset, timeout=timeout):
if stat.link == 'uplink':
uplink = stat.value
if stat.link == 'downlink':
downlink = stat.value

return UserStatsResponse(email=email, uplink=uplink, downlink=downlink)

def get_inbound_stats(self, tag: str, reset: bool = False) -> typing.Iterable[StatResponse]:
def get_inbound_stats(self, tag: str, reset: bool = False, timeout: int = None) -> typing.Iterable[StatResponse]:
uplink, downlink = 0, 0
for stat in self.query_stats(f"inbound>>>{tag}>>>", reset=reset):
for stat in self.query_stats(f"inbound>>>{tag}>>>", reset=reset, timeout=timeout):
if stat.link == 'uplink':
uplink = stat.value
if stat.link == 'downlink':
downlink = stat.value
return InboundStatsResponse(tag=tag, uplink=uplink, downlink=downlink)

def get_outbound_stats(self, tag: str, reset: bool = False) -> typing.Iterable[StatResponse]:
def get_outbound_stats(self, tag: str, reset: bool = False, timeout: int = None) -> typing.Iterable[StatResponse]:
uplink, downlink = 0, 0
for stat in self.query_stats(f"outbound>>>{tag}>>>", reset=reset):
for stat in self.query_stats(f"outbound>>>{tag}>>>", reset=reset, timeout=timeout):
if stat.link == 'uplink':
uplink = stat.value
if stat.link == 'downlink':
Expand Down

0 comments on commit 1818a2a

Please sign in to comment.