Skip to content

Commit

Permalink
Work-in-progress to fix #94
Browse files Browse the repository at this point in the history
This allows distributed tests to share a single load of the data in
session scope

	modified:   conftest.py
	modified:   reg/test_inference.py
  • Loading branch information
tallamjr committed Jun 12, 2022
1 parent bc54ca7 commit 5483eac
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 39 deletions.
75 changes: 38 additions & 37 deletions astronet/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import inspect
import json
import subprocess

import numpy as np
import pandas as pd
import pytest
import tensorflow as tf
from filelock import FileLock

from astronet.constants import ASTRONET_WORKING_DIRECTORY as asnwd
from astronet.constants import LOCAL_DEBUG
from astronet.utils import astronet_logger

log = astronet_logger(__file__)

ISA = subprocess.run(
"uname -m",
Expand All @@ -22,6 +27,25 @@
BATCH_SIZE = 64


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values to built-in Python types.

    Pass it as ``cls=NumpyEncoder`` to ``json.dump``/``json.dumps`` so that
    NumPy scalars and arrays, which the stock encoder rejects, can be
    serialised.
    """

    def default(self, obj):
        """Return a JSON-serialisable equivalent of ``obj``.

        Args:
            obj: A value the standard encoder could not serialise.

        Returns:
            ``bool`` for NumPy booleans, ``int`` for NumPy integers,
            ``float`` for NumPy floating-point scalars, or a (nested) list
            for ``np.ndarray``.

        Raises:
            TypeError: If ``obj`` is none of the supported NumPy types; this
                is raised by the base-class implementation.
        """
        # np.bool_ is not a subclass of np.integer, so it needs its own case;
        # without it NumPy booleans would fall through and raise TypeError.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Defer to the base class, which raises TypeError for unknown types.
        return super().default(obj)


def pandas_encoder(obj):
    """Serialise ``obj`` to a JSON string by way of a pandas DataFrame.

    Not fully implemented: every call logs a critical message saying so
    before attempting the conversion.

    Args:
        obj: Data handed directly to the ``pd.DataFrame`` constructor.

    Returns:
        The result of ``DataFrame.to_json(orient="values")``.
    """
    # TODO: Reshape required to fix ValueError: Must pass 2-d input. shape=(869864, 100, 6)
    # Frame 0 of the stack is this function, so the log line carries its name.
    current_function = inspect.stack()[0].function
    log.critical(f"{current_function} -- Not Fully Implemented Yet")
    frame = pd.DataFrame(obj)
    return frame.to_json(orient="values")


@pytest.fixture(scope="session")
def get_fixt_UGRIZY_wZ(tmp_path_factory, worker_id, name="fixt_UGRIZY_wZ"):
if not worker_id:
Expand All @@ -35,22 +59,20 @@ def get_fixt_UGRIZY_wZ(tmp_path_factory, worker_id, name="fixt_UGRIZY_wZ"):
fn = root_tmp_dir / "data.json"
with FileLock(str(fn) + ".lock"):
if fn.is_file():
tdata = json.loads(fn.read_text())
ldata = list(tdata)
for item, index in enumerate(ldata):
ldata[item] = tf.convert_to_tensor(index)
data = tuple(ldata)
data = json.loads(fn.read_text())
X_test = np.asarray(data["X_test"])
y_test = np.asarray(data["y_test"])
Z_test = np.asarray(data["Z_test"])
else:
tdata = fixt_UGRIZY_wZ()
import pdb

pdb.set_trace()
ldata = list(tdata)
for item, index in enumerate(ldata):
ldata[item] = item.numpy()
data = tuple(ldata)
fn.write_text(json.dumps(data))
return data
X_test, y_test, Z_test = fixt_UGRIZY_wZ()
fn.write_text(
json.dumps(
{"X_test": X_test, "y_test": y_test, "Z_test": Z_test},
cls=NumpyEncoder,
# default=pandas_encoder,
)
)
return X_test, y_test, Z_test


def fixt_UGRIZY_wZ():
Expand All @@ -65,28 +87,7 @@ def fixt_UGRIZY_wZ():
f"{asnwd}/data/plasticc/test_set/infer/Z_test.npy",
)

test_input = [X_test, Z_test]

test_ds = (
tf.data.Dataset.from_tensor_slices(
({"input_1": test_input[0], "input_2": test_input[1]}, y_test)
)
.batch(BATCH_SIZE, drop_remainder=False)
.prefetch(tf.data.AUTOTUNE)
)

y_test_ds = (
tf.data.Dataset.from_tensor_slices(y_test)
.batch(BATCH_SIZE, drop_remainder=False)
.prefetch(tf.data.AUTOTUNE)
)

if LOCAL_DEBUG is not None:
print("LOCAL_DEBUG set, reducing dataset size...")
test_ds = test_ds.take(300)
y_test_ds = y_test_ds.take(300)

return test_ds, y_test_ds, test_input
return X_test, y_test, Z_test


@pytest.fixture(scope="session")
Expand Down
34 changes: 32 additions & 2 deletions astronet/tests/reg/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from tensorflow import keras

from astronet.constants import ASTRONET_WORKING_DIRECTORY as asnwd
from astronet.constants import LOCAL_DEBUG
from astronet.metrics import WeightedLogLoss
from astronet.tests.conftest import BATCH_SIZE
from astronet.tinho.lite import LiteModel
from astronet.utils import astronet_logger

Expand Down Expand Up @@ -56,7 +58,35 @@ def test_inference_UGRIZY_wZ(
# Previous models were trained using numpy data as the inputs, newer models leverage
# tf.data.Dataset instead for faster inference. This is a legacy requirement.
# Fix ValueError of shape mismatch.
test_ds, y_test_ds, test_inputs = get_fixt_UGRIZY_wZ
X_test, y_test, Z_test = get_fixt_UGRIZY_wZ

test_input = [X_test, Z_test]

test_ds = (
tf.data.Dataset.from_tensor_slices(
({"input_1": test_input[0], "input_2": test_input[1]}, y_test)
)
.batch(BATCH_SIZE, drop_remainder=False)
.prefetch(tf.data.AUTOTUNE)
)

y_test_ds = (
tf.data.Dataset.from_tensor_slices(y_test)
.batch(BATCH_SIZE, drop_remainder=False)
.prefetch(tf.data.AUTOTUNE)
)

if LOCAL_DEBUG is not None:
log.info("LOCAL_DEBUG set, reducing dataset size...")
test_ds = test_ds.take(300)
y_test_ds = y_test_ds.take(300)

worker_id = (
os.environ.get("PYTEST_XDIST_WORKER")
if "PYTEST_CURRENT_TEST" in os.environ
else 0
)
log.info(f"Data loaded successfully on worker: {worker_id}")

model = keras.models.load_model(
f"{asnwd}/astronet/{architecture}/models/{dataset}/model-{model_name}",
Expand All @@ -65,7 +95,7 @@ def test_inference_UGRIZY_wZ(
)

wloss = WeightedLogLoss()
y_preds = model.predict(test_inputs)
y_preds = model.predict(test_input)

y_test = np.concatenate([y for y in y_test_ds], axis=0)

Expand Down

0 comments on commit 5483eac

Please sign in to comment.