diff --git a/df_trade_simulator/__init__.py b/df_trade_simulator/__init__.py index b0657c8..2dda09d 100644 --- a/df_trade_simulator/__init__.py +++ b/df_trade_simulator/__init__.py @@ -6,6 +6,5 @@ MarketDFTradeSimulator, StopLimitDFTradeSimulator, ) -from df_trade_simulator.df_trade_simulator import DFTradeConfig, StopLimitDFTradeConfig from df_trade_simulator.df_trade_simulator import Side from df_trade_simulator.df_trade_simulator import BUY, SELL, HODL diff --git a/df_trade_simulator/df_trade_simulator.py b/df_trade_simulator/df_trade_simulator.py index 1dfa6ab..c25c36c 100644 --- a/df_trade_simulator/df_trade_simulator.py +++ b/df_trade_simulator/df_trade_simulator.py @@ -35,25 +35,15 @@ class DFTrade: wallet_fee: float = INIT_WALLET -@dataclass -class DFTradeConfig(ABC): - """Basic representation of a DFTradeSimulator configuration object.""" - - ... - - class DFTradeSimulator(ABC): df: DataFrame - config: DFTradeConfig trades: list[DFTrade] current_row: Series = None last_row: Series = None x_fee: float columns: tuple[str] - def __init__(self, df: DataFrame, config: DFTradeConfig = None, **kwargs) -> None: - # store config - self.config = config + def __init__(self, df: DataFrame, **kwargs) -> None: # read kwargs self.fee = kwargs.get("fee", 0.1) self.price_col = kwargs.get("price_col", PRICE_COL) @@ -72,8 +62,6 @@ def __init__(self, df: DataFrame, config: DFTradeConfig = None, **kwargs) -> Non self.set_df(df) # set fee self.x_fee = (100 - self.fee) / 100 - # validate - self.validate() def set_df(self, df: DataFrame): """Sets the current data frame and cleans up existing trades.""" @@ -88,9 +76,6 @@ def set_df(self, df: DataFrame): # clean up self.trades = [] - def validate(self) -> None: - """Validates the trade simulator object.""" - def add_signals(self, buy: Series, sell: Series) -> None: """Adds signal by a Series selection (boolean Series).""" self.df.loc[buy, self.signal_col] = BUY @@ -297,21 +282,14 @@ def should_trade(self): return self.side != self.signal and self.signal is not None -@dataclass -class StopLimitDFTradeConfig(DFTradeConfig): - """StopLimitDFTradeSimulator configuration object.""" - - treshold: float - - class StopLimitDFTradeSimulator(MarketDFTradeSimulator): """Stop-limit trade simulator which activates a stop loss trade if a given treshold reached.""" - def validate(self): - if not isinstance(self.config, StopLimitDFTradeConfig): - raise ValueError( - "StopLimitDFTradeSimulator configuration object is required." - ) + treshold: float + + def __init__(self, df: DataFrame, treshold: float, **kwargs): + self.treshold = treshold + super().__init__(df, **kwargs) def should_trade(self): # get prices @@ -322,7 +300,7 @@ def should_trade(self): self.current_row["stop-limit"] = "[ ]" # cause a stop limit if above the treshold - if change > self.config.treshold and self.signal is None: + if change > self.treshold and self.signal is None: self.current_row[self.signal_col] = SELL if self.side == BUY else BUY self.current_row["stop-limit"] = "[x]" return True @@ -335,6 +313,6 @@ def should_trade(self): df = pd.read_csv(TEST_CSV, index_col="time", parse_dates=["time"]) - trade_sim = StopLimitDFTradeSimulator(df, StopLimitDFTradeConfig(treshold=0.1)) + trade_sim = StopLimitDFTradeSimulator(df, treshold=0.1) trade_sim.simulate() print(trade_sim.df) diff --git a/tests/test_df_trade_simulator.py b/tests/test_df_trade_simulator.py index c269736..c793c43 100644 --- a/tests/test_df_trade_simulator.py +++ b/tests/test_df_trade_simulator.py @@ -13,7 +13,6 @@ DFTradeSimulator, MarketDFTradeSimulator, StopLimitDFTradeSimulator, - StopLimitDFTradeConfig, ) from df_trade_simulator import Side from df_trade_simulator import BUY, SELL, HODL @@ -329,26 +328,23 @@ def setUp(self) -> None: def test_market_df_trade_simulator(self) -> None: # NOTE: just for coverage right now # TODO: make more proper tests - trade_sim = MarketDFTradeSimulator(self.df, None) + trade_sim = MarketDFTradeSimulator(self.df) trade_sim.simulate() + self.assertAlmostEqual(trade_sim.wallet, 0.666667, 6) + class StopLimitDFTradeSimulator_TestCase(unittest.TestCase): def setUp(self) -> None: self.csv = prepare_test_env.TEST_SMALL_RESULTS_CSV self.df = prepare_test_env.READ_TEST_CSV(self.csv) - def test_validate(self) -> None: - with self.assertRaises(ValueError): - ts = StopLimitDFTradeSimulator(self.df, None) - def test_stop_limit_df_trade_simulator(self) -> None: - # NOTE: just for coverage right now - # TODO: make more proper tests - config = StopLimitDFTradeConfig(treshold=0.1) - trade_sim = StopLimitDFTradeSimulator(self.df, config) + trade_sim = StopLimitDFTradeSimulator(self.df, treshold=0.1) trade_sim.simulate() + self.assertAlmostEqual(trade_sim.wallet, 1.851852, 6) + if __name__ == "__main__": unittest.main()