Skip to content

Commit

Permalink
Examples: Adjust imports, update some code
Browse files Browse the repository at this point in the history
Apparently MNIST is inaccessible from lecun.com, so we try a mirror.
  • Loading branch information
nicholasjng committed Dec 3, 2024
1 parent ad1e3c5 commit afb2763
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions examples/huggingface/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@


def main() -> None:
console_reporter = nnbench.ConsoleReporter()
runner = nnbench.BenchmarkRunner()
reporter = nnbench.ConsoleReporter()
result = runner.run("benchmark.py", tags=("per-class",))
console_reporter.display(result)
reporter.display(result)


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions examples/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class MNISTTestParameters(nnbench.Parameters):


class ConvNet(nn.Module):
@nn.to_json
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
Expand Down Expand Up @@ -101,7 +101,7 @@ def load_mnist() -> ArrayMapping:

mnist: ArrayMapping = {}

baseurl = "http://yann.lecun.com/exdb/mnist/"
baseurl = "https://storage.googleapis.com/cvdf-datasets/mnist/"

for key, file in [
("x_train", "train-images-idx3-ubyte.gz"),
Expand Down Expand Up @@ -217,7 +217,7 @@ def mnist_jax():

# the nnbench portion.
runner = nnbench.BenchmarkRunner()
reporter = nnbench.reporter.FileReporter()
reporter = nnbench.FileReporter()
params = MNISTTestParameters(params=state.params, data=data)
result = runner.run(HERE, params=params)
reporter.write(result, "result.json")
Expand Down

0 comments on commit afb2763

Please sign in to comment.