Skip to content

Commit

Permalink
0.9.61 优化 cross_sectional_strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Dec 5, 2024
1 parent 4d58748 commit fc8be60
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 19 deletions.
45 changes: 30 additions & 15 deletions czsc/eda.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,46 +83,61 @@ def remove_beta_effects(df, **kwargs):
return dfr


def cross_sectional_strategy(df, factor, **kwargs):
def cross_sectional_strategy(df, factor, weight="weight", long=0.3, short=0.3, **kwargs):
"""根据截面因子值构建多空组合
:param df: pd.DataFrame, 包含因子列的数据, 必须包含 dt, symbol, factor 列
:param factor: str, 因子列名称
:param weight: str, 权重列名称,默认为 weight
:param long: float, 多头持仓比例/数量,默认为 0.3, 取值范围为 [0, n_symbols], 0~1 表示比例,大于等于1表示数量
:param short: float, 空头持仓比例/数量,默认为 0.3, 取值范围为 [0, n_symbols], 0~1 表示比例,大于等于1表示数量
:param kwargs:
- factor_direction: str, 因子方向,positive 或 negative
- long_num: int, 多头持仓数量
- short_num: int, 空头持仓数量
- logger: loguru.logger, 日志记录器
- norm: bool, 是否对 weight 进行截面持仓标准化,默认为 False
:return: pd.DataFrame, 包含 weight 列的数据
"""
factor_direction = kwargs.get("factor_direction", "positive")
long_num = kwargs.get("long_num", 5)
short_num = kwargs.get("short_num", 5)
logger = kwargs.get("logger", loguru.logger)
norm = kwargs.get("norm", True)

assert long >= 0 and short >= 0, "long 和 short 参数必须大于等于0"
assert factor in df.columns, f"{factor} 不在 df 中"
assert factor_direction in ["positive", "negative"], f"factor_direction 参数错误"

df = df.copy()
if factor_direction == "negative":
df[factor] = -df[factor]

df['weight'] = 0
df[weight] = 0
rows = []

for dt, dfg in df.groupby("dt"):
if len(dfg) < long_num + short_num:
logger.warning(f"{dt} 截面数据量过小,跳过;仅有 {len(dfg)} 条数据,需要 {long_num + short_num} 条数据")
long_num = long if long >= 1 else int(len(dfg) * long)
short_num = short if short >= 1 else int(len(dfg) * short)

if long_num == 0 and short_num == 0:
logger.warning(f"{dt} 多空目前持仓数量都为0; long: {long}, short: {short}")
rows.append(dfg)
continue

dfa = dfg.sort_values(factor, ascending=False).head(long_num)
dfb = dfg.sort_values(factor, ascending=True).head(short_num)
if long_num > 0:
df.loc[dfa.index, "weight"] = 1 / long_num
if short_num > 0:
df.loc[dfb.index, "weight"] = -1 / short_num
long_symbols = dfg.sort_values(factor, ascending=False).head(long_num)['symbol'].tolist()
short_symbols = dfg.sort_values(factor, ascending=True).head(short_num)['symbol'].tolist()

return df
union_symbols = set(long_symbols) & set(short_symbols)
if union_symbols:
logger.warning(f"{dt} 存在同时在多头和空头的品种:{union_symbols}")
long_symbols = list(set(long_symbols) - union_symbols)
short_symbols = list(set(short_symbols) - union_symbols)

dfg.loc[dfg['symbol'].isin(long_symbols), weight] = 1 / long_num if norm else 1
dfg.loc[dfg['symbol'].isin(short_symbols), weight] = -1 / short_num if norm else -1
rows.append(dfg)

dfx = pd.concat(rows, ignore_index=True)
return dfx


def judge_factor_direction(df: pd.DataFrame, factor, target='n1b', by='symbol', **kwargs):
Expand Down
13 changes: 9 additions & 4 deletions examples/develop/weight_backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@ def test_ensemble_weight():

def test_rust_weight_backtest():
"""从持仓权重样例数据中回测"""
from rs_czsc import PyBacktest as WeightBacktest
from rs_czsc import WeightBacktest

# from rs_czsc import daily_performance
# from czsc import daily_performance

# stats = daily_performance([0.01, 0.02, -0.03, 0.04, 0.05])
dfw = pd.read_feather(r"C:\Users\zengb\Downloads\weight_example.feather")

wb = WeightBacktest(czsc.to_arrow(dfw), digits=2, fee_rate=0.0002, n_jobs=1)
# wb = WeightBacktest(czsc.to_arrow(dfw), digit=2, fee_rate=0.0002, n_jobs=1)
wb = WeightBacktest(dfw, digits=2, fee_rate=0.0002, n_jobs=1)

# ss = sorted(wb.stats.items())
# print(ss)
ss = sorted(wb.stats.items())
print(ss)
68 changes: 68 additions & 0 deletions test/test_cross_sectional_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# tests/test_cross_sectional_strategy.py
import pytest
import pandas as pd
from czsc.eda import cross_sectional_strategy


@pytest.fixture
def sample_data():
data = {
"dt": [
"2023-01-01",
"2023-01-02",
"2023-01-03",
"2023-01-04",
"2023-01-05",
"2023-01-06",
"2023-01-07",
"2023-01-08",
"2023-01-09",
"2023-01-10",
]
* 5,
"symbol": ["A"] * 10 + ["B"] * 10 + ["C"] * 10 + ["D"] * 10 + ["E"] * 10,
"factor": list(range(1, 51)),
}
return pd.DataFrame(data)


def test_cross_sectional_strategy_positive(sample_data):
result = cross_sectional_strategy(sample_data, factor="factor", long=0.5, short=0.5, factor_direction="positive")
assert "weight" in result.columns
assert result["weight"].sum() == 0 # Long and short positions should balance out


def test_cross_sectional_strategy_negative(sample_data):
result = cross_sectional_strategy(sample_data, factor="factor", long=0.5, short=0.5, factor_direction="negative")
assert "weight" in result.columns
assert result["weight"].sum() == 0 # Long and short positions should balance out
print(result)


def test_cross_sectional_strategy_negative_norm(sample_data):
result = cross_sectional_strategy(
sample_data, factor="factor", long=0.5, short=0.5, factor_direction="negative", norm=False
)
assert "weight" in result.columns
assert result["weight"].sum() == 0 # Long and short positions should balance out
print(result)


def test_cross_sectional_strategy_no_positions(sample_data):
result = cross_sectional_strategy(sample_data, factor="factor", long=0, short=0)
assert "weight" in result.columns
assert result["weight"].sum() == 0 # No positions should be taken


def test_cross_sectional_strategy_invalid_factor(sample_data):
with pytest.raises(AssertionError):
cross_sectional_strategy(sample_data, factor="invalid_factor", long=0.5, short=0.5)


def test_cross_sectional_strategy_invalid_factor_direction(sample_data):
with pytest.raises(AssertionError):
cross_sectional_strategy(sample_data, factor="factor", long=0.5, short=0.5, factor_direction="invalid")


if __name__ == "__main__":
pytest.main()

0 comments on commit fc8be60

Please sign in to comment.