Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: Pytest refactoring #15

Merged
merged 5 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/zipline/examples/dual_ema_talib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
momentum).

"""

import os
from zipline.api import order, record, symbol
from zipline.finance import commission, slippage

Expand Down Expand Up @@ -127,6 +127,9 @@ def analyze(context=None, results=None):

plt.show()

if "PYTEST_CURRENT_TEST" in os.environ:
plt.close("all")


def _test_args():
"""Extra arguments to use when zipline's automated tests run this example."""
Expand Down
5 changes: 4 additions & 1 deletion src/zipline/examples/dual_moving_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
its shares once the averages cross again (indicating downwards
momentum).
"""

import os
from zipline.api import order_target, record, symbol
from zipline.finance import commission, slippage

Expand Down Expand Up @@ -114,6 +114,9 @@ def analyze(context=None, results=None):

plt.show()

if "PYTEST_CURRENT_TEST" in os.environ:
plt.close("all")


def _test_args():
"""Extra arguments to use when zipline's automated tests run this example."""
Expand Down
88 changes: 68 additions & 20 deletions src/zipline/testing/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def assert_single_position(test, zipline):
test.zipline_test_config["expected_transactions"], transaction_count
)
else:
test.assertEqual(test.zipline_test_config["order_count"], transaction_count)
test.assertEqual(
test.zipline_test_config["order_count"], transaction_count
)

# the final message is the risk report, the second to
# last is the final day's results. Positions is a list of
Expand All @@ -149,7 +151,9 @@ def assert_single_position(test, zipline):
for order in orders_by_id.value():
test.assertEqual(order["status"], ORDER_STATUS.FILLED, "")

test.assertEqual(len(closing_positions), 1, "Portfolio should have one position.")
test.assertEqual(
len(closing_positions), 1, "Portfolio should have one position."
)

sid = test.zipline_test_config["sid"]
test.assertEqual(
Expand Down Expand Up @@ -181,7 +185,8 @@ def security_list_copy():
def add_security_data(adds, deletes):
if not hasattr(security_list, "using_copy"):
raise Exception(
"add_security_data must be used within " "security_list_copy context"
"add_security_data must be used within "
"security_list_copy context"
)
directory = os.path.join(
security_list.SECURITY_LISTS_DIR, "leveraged_etf_list/20150127/20150125"
Expand Down Expand Up @@ -270,7 +275,9 @@ def make_trade_data_for_asset_info(
sids = asset_info.index

price_sid_deltas = np.arange(len(sids), dtype=float64) * price_step_by_sid
price_date_deltas = np.arange(len(dates), dtype=float64) * price_step_by_date
price_date_deltas = (
np.arange(len(dates), dtype=float64) * price_step_by_date
)
prices = (price_sid_deltas + as_column(price_date_deltas)) + price_start

volume_sid_deltas = np.arange(len(sids)) * volume_step_by_sid
Expand Down Expand Up @@ -301,7 +308,9 @@ def make_trade_data_for_asset_info(
return trade_data


def check_allclose(actual, desired, rtol=1e-07, atol=0, err_msg="", verbose=True):
def check_allclose(
actual, desired, rtol=1e-07, atol=0, err_msg="", verbose=True
):
"""
Wrapper around np.testing.assert_allclose that also verifies that inputs
are ndarrays.
Expand Down Expand Up @@ -439,10 +448,17 @@ def write_daily_data(tempdir, sim_params, sids, trading_calendar):


def create_data_portal(
asset_finder, tempdir, sim_params, sids, trading_calendar, adjustment_reader=None
asset_finder,
tempdir,
sim_params,
sids,
trading_calendar,
adjustment_reader=None,
):
if sim_params.data_frequency == "daily":
daily_path = write_daily_data(tempdir, sim_params, sids, trading_calendar)
daily_path = write_daily_data(
tempdir, sim_params, sids, trading_calendar
)

equity_daily_reader = BcolzDailyBarReader(daily_path)

Expand All @@ -458,7 +474,9 @@ def create_data_portal(
sim_params.first_open, sim_params.last_close
)

minute_path = write_minute_data(trading_calendar, tempdir, minutes, sids)
minute_path = write_minute_data(
trading_calendar, tempdir, minutes, sids
)

equity_minute_reader = BcolzMinuteBarReader(minute_path)

Expand All @@ -478,9 +496,16 @@ def write_bcolz_minute_data(trading_calendar, days, path, data):


def create_minute_df_for_asset(
trading_calendar, start_dt, end_dt, interval=1, start_val=1, minute_blacklist=None
trading_calendar,
start_dt,
end_dt,
interval=1,
start_val=1,
minute_blacklist=None,
):
asset_minutes = trading_calendar.minutes_for_sessions_in_range(start_dt, end_dt)
asset_minutes = trading_calendar.minutes_for_sessions_in_range(
start_dt, end_dt
)
minutes_count = len(asset_minutes)

if interval > 1:
Expand Down Expand Up @@ -579,7 +604,10 @@ def create_data_portal_from_trade_history(
if sim_params.data_frequency == "daily":
path = os.path.join(tempdir.path, "testdaily.bcolz")
writer = BcolzDailyBarWriter(
path, trading_calendar, sim_params.start_session, sim_params.end_session
path,
trading_calendar,
sim_params.start_session,
sim_params.end_session,
)
writer.write(
trades_by_sid_to_dfs(trades_by_sid, sim_params.sessions),
Expand Down Expand Up @@ -644,7 +672,9 @@ def create_data_portal_from_trade_history(


class FakeDataPortal(DataPortal):
def __init__(self, asset_finder, trading_calendar=None, first_trading_day=None):
def __init__(
self, asset_finder, trading_calendar=None, first_trading_day=None
):
if trading_calendar is None:
trading_calendar = get_calendar("NYSE")

Expand All @@ -665,7 +695,14 @@ def get_scalar_asset_spot_value(self, asset, field, dt, data_frequency):
return 1.0

def get_history_window(
self, assets, end_dt, bar_count, frequency, field, data_frequency, ffill=True
self,
assets,
end_dt,
bar_count,
frequency,
field,
data_frequency,
ffill=True,
):
end_idx = self.trading_calendar.all_sessions.searchsorted(end_dt)
days = self.trading_calendar.all_sessions[
Expand Down Expand Up @@ -712,7 +749,9 @@ def get_spot_value(self, asset, field, dt, data_frequency):
# XXX: These aren't actually the methods that are used by the superclasses,
# so these don't do anything, and this class will likely produce unexpected
# results for history().
def _get_daily_window_for_sid(self, asset, field, days_in_window, extra_slot=True):
def _get_daily_window_for_sid(
self, asset, field, days_in_window, extra_slot=True
):
return np.arange(days_in_window, dtype=np.float64)

def _get_minute_window_for_asset(self, asset, field, minutes_for_window):
Expand Down Expand Up @@ -740,7 +779,9 @@ class tmp_assets_db(object):

_default_equities = sentinel("_default_equities")

def __init__(self, url="sqlite:///:memory:", equities=_default_equities, **frames):
def __init__(
self, url="sqlite:///:memory:", equities=_default_equities, **frames
):
self._url = url
self._eng = None
if equities is self._default_equities:
Expand Down Expand Up @@ -1326,14 +1367,16 @@ def read_compressed(path):
return f.read()


zipline_git_root = abspath(
join(realpath(dirname(__file__)), "..", ".."),
zipline_reloaded_git_root = abspath(
join(realpath(dirname(__file__)), "..", "..", ".."),
)


# @nottest
def test_resource_path(*path_parts):
return os.path.join(zipline_git_root, "tests", "resources", *path_parts)
return os.path.join(
zipline_reloaded_git_root, "tests", "resources", *path_parts
)


@contextmanager
Expand Down Expand Up @@ -1672,7 +1715,9 @@ def simulate_minutes_for_day(

min_ = min(close, open_)
where = values < min_
values[where] = (values[where] - min_) * (low - min_) / (values.min() - min_) + min_
values[where] = (values[where] - min_) * (low - min_) / (
values.min() - min_
) + min_

if not (np.allclose(values.max(), high) and np.allclose(values.min(), low)):
return simulate_minutes_for_day(
Expand Down Expand Up @@ -1757,7 +1802,10 @@ def exchange_info_for_domains(domains):
"""
return pd.DataFrame.from_records(
[
{"exchange": domain.calendar.name, "country_code": domain.country_code}
{
"exchange": domain.calendar.name,
"country_code": domain.country_code,
}
for domain in domains
]
)
35 changes: 18 additions & 17 deletions tests/pipeline/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
)
from zipline.pipeline.factors import CustomFactor
import zipline.testing.fixtures as zf
from zipline.testing.core import parameter_space, powerset
from zipline.testing.core import powerset
from zipline.testing.predicates import assert_equal, assert_messages_equal
from zipline.utils.pandas_utils import days_at_time
import pytest
Expand Down Expand Up @@ -117,8 +117,8 @@ def run(ts):
assert_equal(result, expected)


class SpecializeTestCase(zf.ZiplineTestCase):
@parameter_space(domain=BUILT_IN_DOMAINS)
class TestSpecialize:
@pytest.mark.parametrize("domain", BUILT_IN_DOMAINS)
def test_specialize(self, domain):
class MyData(DataSet):
col1 = Column(dtype=float)
Expand Down Expand Up @@ -166,7 +166,7 @@ def do_checks(cls, colnames):
do_checks(MyData, ["col1", "col2", "col3"])
do_checks(MyDataSubclass, ["col1", "col2", "col3", "col4"])

@parameter_space(domain=BUILT_IN_DOMAINS)
@pytest.mark.parametrize("domain", BUILT_IN_DOMAINS)
def test_unspecialize(self, domain):
class MyData(DataSet):
col1 = Column(dtype=float)
Expand Down Expand Up @@ -197,7 +197,7 @@ def do_checks(cls, colnames):
do_checks(MyData, ["col1", "col2", "col3"])
do_checks(MyDataSubclass, ["col1", "col2", "col3", "col4"])

@parameter_space(domain_param=[BE_EQUITIES, CA_EQUITIES, CH_EQUITIES])
@pytest.mark.parametrize("domain_param", [BE_EQUITIES, CA_EQUITIES, CH_EQUITIES])
def test_specialized_root(self, domain_param):
different_domain = GB_EQUITIES

Expand Down Expand Up @@ -253,7 +253,7 @@ class D(DataSet):
c3 = Column(object)


class InferDomainTestCase(zf.ZiplineTestCase):
class TestInferDomain:
def check(self, inputs, expected):
result = infer_domain(inputs)
assert result is expected
Expand All @@ -274,15 +274,15 @@ def test_all_generic(self):
self.check([D.c1, D.c2, D.c3], GENERIC)
self.check([D.c1.latest, D.c2.latest, D.c3.latest], GENERIC)

@parameter_space(domain=[US_EQUITIES, GB_EQUITIES])
@pytest.mark.parametrize("domain", [US_EQUITIES, GB_EQUITIES])
def test_all_non_generic(self, domain):
D_s = D.specialize(domain)
self.check([D_s.c1], domain)
self.check([D_s.c1, D_s.c2], domain)
self.check([D_s.c1, D_s.c2, D_s.c3], domain)
self.check([D_s.c1, D_s.c2, D_s.c3.latest], domain)

@parameter_space(domain=[US_EQUITIES, GB_EQUITIES])
@pytest.mark.parametrize("domain", [US_EQUITIES, GB_EQUITIES])
def test_mix_generic_and_specialized(self, domain):
D_s = D.specialize(domain)
self.check([D.c1, D_s.c3], domain)
Expand Down Expand Up @@ -332,7 +332,7 @@ def test_ambiguous_domain_repr(self):
assert_messages_equal(result, expected)


class DataQueryCutoffForSessionTestCase(zf.ZiplineTestCase):
class TestDataQueryCutoffForSession:
def test_generic(self):
sessions = pd.date_range("2014-01-01", "2014-06-01")
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -441,7 +441,7 @@ def test_equity_calendar_domain(self):
expected_cutoff_date_offset=-7,
)

@parameter_space(domain=BUILT_IN_DOMAINS)
@pytest.mark.parametrize("domain", BUILT_IN_DOMAINS)
def test_equity_calendar_not_aligned(self, domain):
valid_sessions = domain.all_sessions()[:50]
sessions = pd.date_range(valid_sessions[0], valid_sessions[-1])
Expand All @@ -457,8 +457,9 @@ def test_equity_calendar_not_aligned(self, domain):

Case = namedtuple("Case", "time date_offset expected_timedelta")

@parameter_space(
parameters=(
@pytest.mark.parametrize(
"parameters",
(
Case(
time=datetime.time(8, 45, tzinfo=pytz.utc),
date_offset=0,
Expand Down Expand Up @@ -496,7 +497,7 @@ def test_equity_calendar_not_aligned(self, domain):
["4 hours 30 minutes"] * 93 + ["3 hours 30 minutes"] * 60,
),
),
)
),
)
def test_equity_session_domain(self, parameters):
time, date_offset, expected_timedelta = parameters
Expand All @@ -519,7 +520,7 @@ def test_equity_session_domain(self, parameters):
assert_equal(expected, actual)


class RollForwardTestCase(zf.ZiplineTestCase):
class TestRollForward:
def test_roll_forward(self):
# January 2017
# Su Mo Tu We Th Fr Sa
Expand Down Expand Up @@ -557,8 +558,8 @@ def test_roll_forward(self):

expected_msg = (
f"Date {after_last_session.date()} was past the last session "
"for domain EquityCalendarDomain('JP', 'XTKS'). "
f"The last session for this domain is {JP_EQUITIES.calendar.last_session.date()}."
"for domain EquityCalendarDomain('JP', 'XTKS'). The last session for "
f"this domain is {JP_EQUITIES.calendar.last_session.date()}."
)
with pytest.raises(ValueError, match=re.escape(expected_msg)):
JP_EQUITIES.roll_forward(after_last_session)
Expand All @@ -580,6 +581,6 @@ def test_roll_forward(self):
)


class ReprTestCase(zf.ZiplineTestCase):
class TestRepr:
def test_generic_domain_repr(self):
assert repr(GENERIC) == "GENERIC"
2 changes: 1 addition & 1 deletion tests/pipeline/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,7 +1588,7 @@ def test_no_groupby_maximum(self):
assert_equal(groupby_max.to_numpy(), pipeline_max.to_numpy())


class ResolveDomainTestCase(zf.ZiplineTestCase):
class TestResolveDomain:
def test_resolve_domain(self):
# we need to pass a get_loader and an asset_finder to construct
# SimplePipelineEngine, but do not expect to use them
Expand Down
Binary file modified tests/resources/example_data.tar.gz
Binary file not shown.
Loading