diff --git a/zipline/__main__.py b/zipline/__main__.py index 1db6d9f71c..d9c7b538f3 100644 --- a/zipline/__main__.py +++ b/zipline/__main__.py @@ -1,5 +1,7 @@ import errno import os + +from importlib import import_module from functools import wraps import click @@ -185,17 +187,16 @@ def _(*args, **kwargs): help='Should the algorithm methods be resolved in the local namespace.' )) @click.option( - '--live-trading', - is_flag=True, - default=False, - help='Live trading using IB TWS' + '--broker', + default=None, + help='Broker' ) @click.option( - '--tws-uri', + '--broker-uri', default=None, - metavar='TWS-URI', + metavar='BROKER-URI', show_default=True, - help='Connection to TWS: host:port:client-id.', + help='Connection to broker', ) @click.pass_context def run(ctx, @@ -211,29 +212,46 @@ def run(ctx, output, print_algo, local_namespace, - live_trading, - tws_uri): + broker, + broker_uri): """Run a backtest for the given algorithm. """ # check that the start and end dates are passed correctly - if not live_trading and start is None and end is None: + if not broker and start is None and end is None: # check both at the same time to avoid the case where a user # does not pass either of these and then passes the first only # to be told they need to pass the second argument also ctx.fail( "must specify dates with '-s' / '--start' and '-e' / '--end'", ) - if not live_trading and start is None: + + if not broker and start is None: ctx.fail("must specify a start date with '-s' / '--start'") - if not live_trading and end is None: + if not broker and end is None: ctx.fail("must specify an end date with '-e' / '--end'") - if live_trading and tws_uri is None: - ctx.fail("must specify tws-uri if live-trading is specified") + if broker and broker_uri is None: + ctx.fail("must specify broker-uri if broker is specified") - if live_trading and data_frequency != 'minute': + if broker and data_frequency != 'minute': ctx.fail("must use '--data-frequency minute' with live trading") + brokerobj = None + if broker: + mod_name = 'zipline.gens.brokers.%s_broker' % broker.lower() + try: + bmod = import_module(mod_name) + except ImportError: + ctx.fail("unsupported broker: can't import module %s" % mod_name) + + cl_name = '%sBroker' % broker.upper() + try: + bclass = getattr(bmod, cl_name) + except AttributeError: + ctx.fail("unsupported broker: can't import class %s from %s" % + (cl_name, mod_name)) + brokerobj = bclass(broker_uri) + if (algotext is not None) == (algofile is not None): ctx.fail( "must specify exactly one of '-f' / '--algofile' or" @@ -259,8 +277,7 @@ def run(ctx, print_algo=print_algo, local_namespace=local_namespace, environ=os.environ, - live_trading=live_trading, - tws_uri=tws_uri + broker=brokerobj, ) if output == '-': diff --git a/zipline/algorithm_live.py b/zipline/algorithm_live.py index fc2260a8f3..af61ec26fa 100644 --- a/zipline/algorithm_live.py +++ b/zipline/algorithm_live.py @@ -33,7 +33,6 @@ def __init__(self, *args, **kwargs): class LiveTradingAlgorithm(TradingAlgorithm): def __init__(self, *args, **kwargs): - self.live_trading = kwargs.pop('live_trading', False) self.broker = kwargs.pop('broker', None) super(self.__class__, self).__init__(*args, **kwargs) diff --git a/zipline/utils/run_algo.py b/zipline/utils/run_algo.py index 2e954fc1df..60a14217d1 100644 --- a/zipline/utils/run_algo.py +++ b/zipline/utils/run_algo.py @@ -29,7 +29,6 @@ from zipline.pipeline.loaders import USEquityPricingLoader from zipline.utils.calendars import get_calendar from zipline.utils.factory import create_simulation_parameters -from zipline.gens.brokers import IBBroker import zipline.utils.paths as pth @@ -72,8 +71,7 @@ def _run(handle_data, print_algo, local_namespace, environ, - live_trading, - tws_uri): + broker): """Run a backtest for the given algorithm. This is shared between the cli and :func:`zipline.run_algo`. @@ -122,8 +120,6 @@ def _run(handle_data, else: click.echo(algotext) - broker = IBBroker(tws_uri) if live_trading else None - if bundle is not None: bundle_data = load( bundle, @@ -146,7 +142,7 @@ def _run(handle_data, bundle_data.equity_minute_bar_reader.first_trading_day DataPortalClass = (partial(DataPortalLive, broker) - if live_trading + if broker else DataPortal) data = DataPortalClass( env.asset_finder, get_calendar("NYSE"), @@ -171,13 +167,12 @@ def choose_loader(column): env = None choose_loader = None - if live_trading: + if broker: start = pd.Timestamp.utcnow() end = start + pd.Timedelta('1', 'D') - TradingAlgorithmClass = (partial(LiveTradingAlgorithm, - live_trading=live_trading, - broker=broker) - if live_trading else TradingAlgorithm) + + TradingAlgorithmClass = (partial(LiveTradingAlgorithm, broker=broker) + if broker else TradingAlgorithm) perf = TradingAlgorithmClass( namespace=namespace, @@ -384,6 +379,5 @@ def run_algorithm(start, print_algo=False, local_namespace=False, environ=environ, - live_trading=live_trading, - tws_uri=tws_uri + broker=None, )