Skip to content

Commit

Permalink
0.9.58 fix bug: remove_beta_effects
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Aug 11, 2024
1 parent 8bc1cd6 commit fbc30f4
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions czsc/eda.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,17 @@ def remove_beta_effects(df, **kwargs):
  - factor: str, 因子列名
  - betas: list, beta 列名列表
  - linear_model: str, 线性模型,可选 ridge、linear 或 lasso
+ - linear_model_params: dict, 线性模型参数, 默认为空, 需要传入字典,根据模型不同参数不同
  :return: DataFrame
  """

  linear_model = kwargs.get("linear_model", "ridge")
+ linear_model_params = kwargs.get("linear_model_params", {})
  linear = {
- "ridge": Ridge(),
- "linear": LinearRegression(),
- "lasso": Lasso(),
+ "ridge": Ridge,
+ "linear": LinearRegression,
+ "lasso": Lasso,
  }
  assert linear_model in linear.keys(), "linear_model 参数必须为 ridge、linear 或 lasso"
  Model = linear[linear_model]
Expand All @@ -71,7 +73,7 @@ def remove_beta_effects(df, **kwargs):

  x = dfg[betas].values
  y = dfg[factor].values
- model = Model().fit(x, y)
+ model = Model(**linear_model_params).fit(x, y)
  dfg[factor] = y - model.predict(x)
  rows.append(dfg)

Expand Down Expand Up @@ -113,7 +115,9 @@ def cross_sectional_strategy(df, factor, **kwargs):

  dfa = dfg.sort_values(factor, ascending=False).head(long_num)
  dfb = dfg.sort_values(factor, ascending=True).head(short_num)
- df.loc[dfa.index, "weight"] = 1 / long_num
- df.loc[dfb.index, "weight"] = -1 / short_num
+ if long_num > 0:
+ df.loc[dfa.index, "weight"] = 1 / long_num
+ if short_num > 0:
+ df.loc[dfb.index, "weight"] = -1 / short_num

  return df

0 comments on commit fbc30f4

Please sign in to comment.