@@ -450,15 +450,16 @@ def next(self):
450450
451451
452452class TestStrategy (TestCase ):
453- def _Backtest (self , strategy_coroutine , ** kwargs ):
453+ @staticmethod
454+ def _Backtest (strategy_coroutine , data = SHORT_DATA , ** kwargs ):
454455 class S (Strategy ):
455456 def init (self ):
456457 self .step = strategy_coroutine (self )
457458
458459 def next (self ):
459460 try_ (self .step .__next__ , None , StopIteration )
460461
461- return Backtest (SHORT_DATA , S , ** kwargs )
462+ return Backtest (data , S , ** kwargs )
462463
463464 def test_position (self ):
464465 def coroutine (self ):
@@ -1032,12 +1033,8 @@ def next(self):
10321033 if self .data .Close [- 1 ] == 100 :
10331034 self .buy (size = 1 , sl = 90 )
10341035
1035- df = pd .DataFrame ({
1036- 'Open' : [100 , 100 , 100 , 50 , 50 ],
1037- 'High' : [100 , 100 , 100 , 50 , 50 ],
1038- 'Low' : [100 , 100 , 100 , 50 , 50 ],
1039- 'Close' : [100 , 100 , 100 , 50 , 50 ],
1040- })
1036+ arr = np .r_ [100 , 100 , 100 , 50 , 50 ]
1037+ df = pd .DataFrame ({'Open' : arr , 'High' : arr , 'Low' : arr , 'Close' : arr })
10411038 with self .assertWarnsRegex (UserWarning , 'index is not datetime' ):
10421039 bt = Backtest (df , S , cash = 100 , trade_on_close = True )
10431040 self .assertEqual (bt .run ()._trades ['ExitPrice' ][0 ], 50 )
@@ -1059,3 +1056,44 @@ def next(self):
10591056 order .cancel ()
10601057
10611058 Backtest (SHORT_DATA , S ).run ()
1059+
1060+ def test_trade_on_close_closes_trades_on_close (self ):
1061+ def coro (strat ):
1062+ yield strat .buy (size = 1 , sl = 90 ) and strat .buy (size = 1 , sl = 80 )
1063+ assert len (strat .trades ) == 2
1064+ yield strat .trades [0 ].close ()
1065+ yield
1066+
1067+ arr = np .r_ [100 , 101 , 102 , 50 , 51 ]
1068+ df = pd .DataFrame ({
1069+ 'Open' : arr - 10 ,
1070+ 'Close' : arr , 'High' : arr , 'Low' : arr })
1071+ with self .assertWarnsRegex (UserWarning , 'index is not datetime' ):
1072+ trades = TestStrategy ._Backtest (coro , df , cash = 250 , trade_on_close = True ).run ()._trades
1073+ # trades = Backtest(df, S, cash=250, trade_on_close=True).run()._trades
1074+ self .assertEqual (trades ['EntryBar' ][0 ], 1 )
1075+ self .assertEqual (trades ['ExitBar' ][0 ], 2 )
1076+ self .assertEqual (trades ['EntryPrice' ][0 ], 101 )
1077+ self .assertEqual (trades ['ExitPrice' ][0 ], 102 )
1078+ self .assertEqual (trades ['EntryBar' ][1 ], 1 )
1079+ self .assertEqual (trades ['ExitBar' ][1 ], 3 )
1080+ self .assertEqual (trades ['EntryPrice' ][1 ], 101 )
1081+ self .assertEqual (trades ['ExitPrice' ][1 ], 40 )
1082+
1083+ with self .assertWarnsRegex (UserWarning , 'index is not datetime' ):
1084+ trades = TestStrategy ._Backtest (coro , df , cash = 250 , trade_on_close = False ).run ()._trades
1085+ # trades = Backtest(df, S, cash=250, trade_on_close=False).run()._trades
1086+ self .assertEqual (trades ['EntryBar' ][0 ], 2 )
1087+ self .assertEqual (trades ['ExitBar' ][0 ], 3 )
1088+ self .assertEqual (trades ['EntryPrice' ][0 ], 92 )
1089+ self .assertEqual (trades ['ExitPrice' ][0 ], 40 )
1090+ self .assertEqual (trades ['EntryBar' ][1 ], 2 )
1091+ self .assertEqual (trades ['ExitBar' ][1 ], 3 )
1092+ self .assertEqual (trades ['EntryPrice' ][1 ], 92 )
1093+ self .assertEqual (trades ['ExitPrice' ][1 ], 40 )
1094+
1095+ def test_trades_dates_match_prices (self ):
1096+ bt = Backtest (EURUSD , SmaCross , trade_on_close = True )
1097+ trades = bt .run ()._trades
1098+ self .assertEqual (EURUSD .Close [trades ['ExitTime' ]].tolist (),
1099+ trades ['ExitPrice' ].tolist ())
0 commit comments