Skip to content

Commit

Permalink
don't save local toml files; add tensorflow tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
pbmanis committed Jan 29, 2025
1 parent 6451913 commit 51a9c16
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
ephys/ephysanalysis/c_deriv.c
ephys/mini_analyses/clembek.c
archived/
# file history:
data/files.toml
mini_viewer_recent_files.toml

# excel files
testbridge.xlsx
untitled*.*

Expand Down
20 changes: 20 additions & 0 deletions ephys/tools/tests/test_tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# test for tensorflow install.

try:
import tensorflow as tf
except:
raise ImportError("tensorflow is not installed.")
print(f"Testing tensorflow version: {tf.__version__}")
cifar = tf.keras.datasets.cifar100
(x_train, y_train), (x_test, y_test) = cifar.load_data()
model = tf.keras.applications.ResNet50(
include_top=True,
weights=None,
input_shape=(32, 32, 3),
classes=100,)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
print(" Tensorflow - starting fit")
model.fit(x_train, y_train, epochs=5, batch_size=64)
print(" Tensorflow test complete.")
7 changes: 6 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ def main():

# Allow user to audit tests with --audit flag
import ephys.ephys_analysis
import ephys.tools
if '--audit' in sys.argv:
sys.argv.remove('--audit')
sys.argv.append('-s') # needed for cli-based user interaction
ephys.mini_analyses.AUDIT_TESTS = True

if '--tensor_flow_test' in sys.argv:
sys.argv.remove('--tensor_flow_test')
sys.argc.append('-s')
flags.append('ephys/tools/tests/test_tensorflow.py')
# generate test flags
flags = sys.argv[1:]
flags.append('-v')
Expand All @@ -39,6 +43,7 @@ def main():
flags.append('ephys/mini_analyses')
flags.append('ephys/ephys_analysis')
flags.append('ephys/psc_analysis')

print("flags: ", flags)
# ignore the an cache
# flags.append('--ignore=minis/somedir')
Expand Down

0 comments on commit 51a9c16

Please sign in to comment.