From 06f24e637e2a341093232ff5dcae5e369f51bddb Mon Sep 17 00:00:00 2001 From: Daniel Rasmussen Date: Tue, 28 Jan 2020 12:34:19 -0400 Subject: [PATCH] Handle networks without input in run_profile --- nengo_dl/benchmarks.py | 8 +++++--- nengo_dl/tests/test_benchmarks.py | 8 +++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/nengo_dl/benchmarks.py b/nengo_dl/benchmarks.py index deff7ef82..401763670 100644 --- a/nengo_dl/benchmarks.py +++ b/nengo_dl/benchmarks.py @@ -649,7 +649,7 @@ def run_profile( sim.minibatch_size * n_batches, n_steps, net.inp.size_out ) } - else: + elif hasattr(net, "inp_a"): x = { net.inp_a: np.random.randn( sim.minibatch_size * n_batches, n_steps, net.inp_a.size_out @@ -658,6 +658,8 @@ def run_profile( sim.minibatch_size * n_batches, n_steps, net.inp_b.size_out ), } + else: + x = None if train: y = { @@ -670,14 +672,14 @@ def run_profile( # run once to eliminate startup overhead start = timeit.default_timer() - sim.fit(x, y, epochs=1) + sim.fit(x, y, epochs=1, n_steps=n_steps) print("Warmup time:", timeit.default_timer() - start) for _ in range(reps): if do_profile: profiler.start() start = timeit.default_timer() - sim.fit(x, y, epochs=1) + sim.fit(x, y, epochs=1, n_steps=n_steps) exec_time = min(timeit.default_timer() - start, exec_time) if do_profile: profiler.save("profile", profiler.stop()) diff --git a/nengo_dl/tests/test_benchmarks.py b/nengo_dl/tests/test_benchmarks.py index e388c3e13..79e4c2c6b 100644 --- a/nengo_dl/tests/test_benchmarks.py +++ b/nengo_dl/tests/test_benchmarks.py @@ -101,7 +101,9 @@ def _test_random( assert all(net.inp in x for x in post_conns.values()) -@pytest.mark.parametrize("network, train", [("integrator", True), ("cconv", False)]) +@pytest.mark.parametrize( + "network, train", [("integrator", True), ("cconv", False), ("test", True)] +) def test_run_profile(network, train, pytestconfig, monkeypatch, tmpdir): monkeypatch.chdir(tmpdir) @@ -109,6 +111,10 @@ def test_run_profile(network, train, pytestconfig, monkeypatch, tmpdir): net = benchmarks.integrator(3, 2, nengo.SpikingRectifiedLinear()) elif network == "cconv": net = benchmarks.cconv(3, 10, nengo.LIF()) + elif network == "test": + with nengo.Network() as net: + ens = nengo.Ensemble(10, 1) + net.p = nengo.Probe(ens) benchmarks.run_profile( net,