Skip to content

Commit

Permalink
0.9.27 统一ensemble_method回调函数的输入
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Aug 15, 2023
1 parent 28d229d commit a21d233
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions czsc/traders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ def __init__(self, bg: BarGenerator = None, positions: List[Position] = None,
vote - 投票表决,pos = 1
max - 取最大,pos = 1
对于传入回调函数的情况,输入是 self.positions
对于传入回调函数的情况,函数的输入为 dict,key 为 position.name,value 为 position.pos, 样例输入:
{'多头策略A': 1, '多头策略B': 1, '空头策略A': -1}
"""
self.positions = positions
if self.positions:
Expand Down Expand Up @@ -403,7 +404,7 @@ def get_ensemble_pos(self, method: Union[AnyStr, Callable] = None) -> float:
raise ValueError

else:
pos = method(self.positions)
pos = method({x.name: x.pos for x in self.positions})

return pos

Expand Down
4 changes: 2 additions & 2 deletions test/test_trader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ def __create_sma20_pos():
assert [x.pos for x in ct.positions] == [0, 0, 0]

# 测试自定义仓位集成
def _weighted_ensemble(positions: List[Position]):
return 0.5 * positions[0].pos + 0.5 * positions[1].pos
def _weighted_ensemble(poss):
return 0.5 * poss['测试A'] + 0.5 * poss['测试B']

assert ct.get_ensemble_pos(_weighted_ensemble) == 0
assert ct.get_ensemble_pos('vote') == 0
Expand Down

0 comments on commit a21d233

Please sign in to comment.