Skip to content

Commit

Permalink
Replace parameter class with init args
Browse files Browse the repository at this point in the history
  • Loading branch information
martincpt committed Oct 29, 2022
1 parent e6d2953 commit 09f843f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 41 deletions.
1 change: 0 additions & 1 deletion df_trade_simulator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@
MarketDFTradeSimulator,
StopLimitDFTradeSimulator,
)
from df_trade_simulator.df_trade_simulator import DFTradeConfig, StopLimitDFTradeConfig
from df_trade_simulator.df_trade_simulator import Side
from df_trade_simulator.df_trade_simulator import BUY, SELL, HODL
38 changes: 8 additions & 30 deletions df_trade_simulator/df_trade_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,15 @@ class DFTrade:
wallet_fee: float = INIT_WALLET


@dataclass
class DFTradeConfig(ABC):
"""Basic representation of a DFTradeSimulator configuration object."""

...


class DFTradeSimulator(ABC):
df: DataFrame
config: DFTradeConfig
trades: list[DFTrade]
current_row: Series = None
last_row: Series = None
x_fee: float
columns: tuple[str]

def __init__(self, df: DataFrame, config: DFTradeConfig = None, **kwargs) -> None:
# store config
self.config = config
def __init__(self, df: DataFrame, **kwargs) -> None:
# read kwargs
self.fee = kwargs.get("fee", 0.1)
self.price_col = kwargs.get("price_col", PRICE_COL)
Expand All @@ -72,8 +62,6 @@ def __init__(self, df: DataFrame, config: DFTradeConfig = None, **kwargs) -> Non
self.set_df(df)
# set fee
self.x_fee = (100 - self.fee) / 100
# validate
self.validate()

def set_df(self, df: DataFrame):
"""Sets the current data frame and cleans up existing trades."""
Expand All @@ -88,9 +76,6 @@ def set_df(self, df: DataFrame):
# clean up
self.trades = []

def validate(self) -> None:
"""Validates the trade simulator object."""

def add_signals(self, buy: Series, sell: Series) -> None:
"""Adds signal by a Series selection (boolean Series)."""
self.df.loc[buy, self.signal_col] = BUY
Expand Down Expand Up @@ -297,21 +282,14 @@ def should_trade(self):
return self.side != self.signal and self.signal is not None


@dataclass
class StopLimitDFTradeConfig(DFTradeConfig):
"""StopLimitDFTradeSimulator configuration object."""

treshold: float


class StopLimitDFTradeSimulator(MarketDFTradeSimulator):
"""Stop-limit trade simulator which activates a stop loss trade if a given treshold reached."""

def validate(self):
if not isinstance(self.config, StopLimitDFTradeConfig):
raise ValueError(
"StopLimitDFTradeSimulator configuration object is required."
)
treshold: float

def __init__(self, df: DataFrame, treshold: float, **kwargs):
self.treshold = treshold
super().__init__(df, **kwargs)

def should_trade(self):
# get prices
Expand All @@ -322,7 +300,7 @@ def should_trade(self):
self.current_row["stop-limit"] = "[ ]"

# cause a stop limit if above the treshold
if change > self.config.treshold and self.signal is None:
if change > self.treshold and self.signal is None:
self.current_row[self.signal_col] = SELL if self.side == BUY else BUY
self.current_row["stop-limit"] = "[x]"
return True
Expand All @@ -335,6 +313,6 @@ def should_trade(self):

df = pd.read_csv(TEST_CSV, index_col="time", parse_dates=["time"])

trade_sim = StopLimitDFTradeSimulator(df, StopLimitDFTradeConfig(treshold=0.1))
trade_sim = StopLimitDFTradeSimulator(df, treshold=0.1)
trade_sim.simulate()
print(trade_sim.df)
16 changes: 6 additions & 10 deletions tests/test_df_trade_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
DFTradeSimulator,
MarketDFTradeSimulator,
StopLimitDFTradeSimulator,
StopLimitDFTradeConfig,
)
from df_trade_simulator import Side
from df_trade_simulator import BUY, SELL, HODL
Expand Down Expand Up @@ -329,26 +328,23 @@ def setUp(self) -> None:
def test_market_df_trade_simulator(self) -> None:
# NOTE: just for coverage right now
# TODO: make more proper tests
trade_sim = MarketDFTradeSimulator(self.df, None)
trade_sim = MarketDFTradeSimulator(self.df)
trade_sim.simulate()

self.assertAlmostEqual(trade_sim.wallet, 0.666667, 6)


class StopLimitDFTradeSimulator_TestCase(unittest.TestCase):
def setUp(self) -> None:
self.csv = prepare_test_env.TEST_SMALL_RESULTS_CSV
self.df = prepare_test_env.READ_TEST_CSV(self.csv)

def test_validate(self) -> None:
with self.assertRaises(ValueError):
ts = StopLimitDFTradeSimulator(self.df, None)

def test_stop_limit_df_trade_simulator(self) -> None:
# NOTE: just for coverage right now
# TODO: make more proper tests
config = StopLimitDFTradeConfig(treshold=0.1)
trade_sim = StopLimitDFTradeSimulator(self.df, config)
trade_sim = StopLimitDFTradeSimulator(self.df, treshold=0.1)
trade_sim.simulate()

self.assertAlmostEqual(trade_sim.wallet, 1.851852, 6)


if __name__ == "__main__":
unittest.main()

0 comments on commit 09f843f

Please sign in to comment.