Skip to content

Commit

Permalink
Handle networks without input in run_profile
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Jan 28, 2020
1 parent 8c21849 commit 06f24e6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 5 additions & 3 deletions nengo_dl/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def run_profile(
sim.minibatch_size * n_batches, n_steps, net.inp.size_out
)
}
else:
elif hasattr(net, "inp_a"):
x = {
net.inp_a: np.random.randn(
sim.minibatch_size * n_batches, n_steps, net.inp_a.size_out
Expand All @@ -658,6 +658,8 @@ def run_profile(
sim.minibatch_size * n_batches, n_steps, net.inp_b.size_out
),
}
else:
x = None

if train:
y = {
Expand All @@ -670,14 +672,14 @@ def run_profile(

# run once to eliminate startup overhead
start = timeit.default_timer()
sim.fit(x, y, epochs=1)
sim.fit(x, y, epochs=1, n_steps=n_steps)
print("Warmup time:", timeit.default_timer() - start)

for _ in range(reps):
if do_profile:
profiler.start()
start = timeit.default_timer()
sim.fit(x, y, epochs=1)
sim.fit(x, y, epochs=1, n_steps=n_steps)
exec_time = min(timeit.default_timer() - start, exec_time)
if do_profile:
profiler.save("profile", profiler.stop())
Expand Down
8 changes: 7 additions & 1 deletion nengo_dl/tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,20 @@ def _test_random(
assert all(net.inp in x for x in post_conns.values())


@pytest.mark.parametrize("network, train", [("integrator", True), ("cconv", False)])
@pytest.mark.parametrize(
"network, train", [("integrator", True), ("cconv", False), ("test", True)]
)
def test_run_profile(network, train, pytestconfig, monkeypatch, tmpdir):
monkeypatch.chdir(tmpdir)

if network == "integrator":
net = benchmarks.integrator(3, 2, nengo.SpikingRectifiedLinear())
elif network == "cconv":
net = benchmarks.cconv(3, 10, nengo.LIF())
elif network == "test":
with nengo.Network() as net:
ens = nengo.Ensemble(10, 1)
net.p = nengo.Probe(ens)

benchmarks.run_profile(
net,
Expand Down

0 comments on commit 06f24e6

Please sign in to comment.