From 32e469e042ed0a961100f44f6a3b8472522e538b Mon Sep 17 00:00:00 2001 From: zengbin93 Date: Mon, 14 Oct 2024 23:12:29 +0800 Subject: [PATCH] =?UTF-8?q?0.9.60=20=E6=96=B0=E5=A2=9E=20timeout=5Fdecorat?= =?UTF-8?q?or=20=E8=A3=85=E9=A5=B0=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- czsc/__init__.py | 1 + czsc/utils/__init__.py | 30 ++++++++++++++++++++++++++++++ test/test_utils.py | 21 +++++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/czsc/__init__.py b/czsc/__init__.py index 73e0691b5..edc1a4ded 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -47,6 +47,7 @@ ExitsOptimize, ) from czsc.utils import ( + timeout_decorator, mac_address, overlap, to_arrow, diff --git a/czsc/utils/__init__.py b/czsc/utils/__init__.py index 4d93a3000..c91105b16 100644 --- a/czsc/utils/__init__.py +++ b/czsc/utils/__init__.py @@ -1,7 +1,9 @@ # coding: utf-8 import os +import functools import pandas as pd from typing import List, Union +from loguru import logger from . import qywx from . import ta @@ -226,3 +228,31 @@ def to_arrow(df: pd.DataFrame): with pa.ipc.new_file(sink, table.schema) as writer: writer.write_table(table) return sink.getvalue() + + +def timeout_decorator(timeout): + """超时装饰器 + + :param timeout: int, 超时时间,单位秒 + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + from concurrent.futures import ThreadPoolExecutor, TimeoutError + + with ThreadPoolExecutor() as executor: + future = executor.submit(func, *args, **kwargs) + try: + result = future.result(timeout=timeout) + return result + except TimeoutError: + # print(f"{func.__name__} timed out after {timeout} seconds") + logger.warning( + f"{func.__name__} timed out after {timeout} seconds;" f"args: {args}; kwargs: {kwargs}" + ) + raise ValueError(f"{func.__name__} timed out after {timeout} seconds") + + return wrapper + + return decorator diff --git a/test/test_utils.py b/test/test_utils.py index 9c984dd1e..05bc49024 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,9 +7,11 @@ """ import sys import pytest +import time import pandas as pd import numpy as np from czsc import utils +from czsc.utils import timeout_decorator def test_x_round(): @@ -344,3 +346,22 @@ def test_overlap(): # 验证结果 assert result["col_overlap"].tolist() == [1, 2, 1, 2, 1] + + +def test_timeout_decorator_success(): + @timeout_decorator(2) + def fast_function(): + time.sleep(1) + return "Completed" + + assert fast_function() == "Completed" + + +def test_timeout_decorator_timeout(): + @timeout_decorator(1) + def slow_function(): + time.sleep(2) + return "Completed" + + with pytest.raises(ValueError, match="timed out after 1 seconds"): + slow_function()