diff --git a/requirements.txt b/requirements.txt index 5e65d19..7eb25cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -keras>=2.0.0 -tensorflow>=1.4.0 \ No newline at end of file +tensorflow>=2.3.0 +pytest \ No newline at end of file diff --git a/tests/wrapper_test.py b/tests/wrapper_test.py index 42debba..23b2631 100644 --- a/tests/wrapper_test.py +++ b/tests/wrapper_test.py @@ -1,11 +1,11 @@ import sys import numpy as np +import pytest -from keras.models import Model -from keras.layers import Input -from keras.layers import Lambda +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Input +from tensorflow.keras.layers import Lambda -sys.path.append('..') from tta_wrapper import tta_classification from tta_wrapper import tta_segmentation @@ -16,7 +16,27 @@ def identity_model(input_shape): identity_model = Model(inp, x) return identity_model - +# inputs +input_sample = np.arange(9).reshape((1, 3, 3, 1)) + +# outputs +tta_segmentation_output = input_sample +tta_classification_output = np.ones((1, 3, 3, 1)) * 4 + +# model +seg_identity_model = identity_model(input_sample.shape[1:]) +cls_identity_model = identity_model(input_sample.shape[1:]) + +@pytest.mark.parametrize("wrapper, base_model, inputs, outputs", + [(tta_segmentation, + seg_identity_model, + input_sample, + tta_segmentation_output), + (tta_classification, + cls_identity_model, + input_sample, + tta_classification_output) + ]) def test_wrapper(wrapper, base_model, inputs, outputs): print('[TEST] wrapping model with {} ... '.format(wrapper.__name__)) @@ -35,21 +55,4 @@ def test_wrapper(wrapper, base_model, inputs, outputs): prediction = model.predict(inputs) assert np.allclose(prediction, outputs), f"\nprediction: \n{prediction}\n\nground_truth: \n{outputs}" - print('[TEST] {} - test passed. '.format(wrapper.__name__)) - - -if __name__ == '__main__': - - # inputs - input_sample = np.arange(9).reshape((1, 3, 3, 1)) - - # outputs - tta_segmentation_output = input_sample - tta_classification_output = np.ones((1, 3, 3, 1)) * 4 - - # model - seg_identity_model = identity_model(input_sample.shape[1:]) - cls_identity_model = identity_model(input_sample.shape[1:]) - - test_wrapper(tta_segmentation, seg_identity_model, input_sample, tta_segmentation_output) - test_wrapper(tta_classification, cls_identity_model, input_sample, tta_classification_output) \ No newline at end of file + print('[TEST] {} - test passed. '.format(wrapper.__name__)) \ No newline at end of file diff --git a/tta_wrapper/__version__.py b/tta_wrapper/__version__.py index 1b095da..fd88f61 100644 --- a/tta_wrapper/__version__.py +++ b/tta_wrapper/__version__.py @@ -1,3 +1,3 @@ -VERSION = (0, 0, 1) +VERSION = (0, 0, 2) __version__ = '.'.join(map(str, VERSION)) \ No newline at end of file diff --git a/tta_wrapper/functional.py b/tta_wrapper/functional.py index 77014eb..e681694 100644 --- a/tta_wrapper/functional.py +++ b/tta_wrapper/functional.py @@ -1,4 +1,5 @@ import tensorflow as tf +from tensorflow.python.ops import manip_ops class DualTransform: @@ -81,10 +82,10 @@ class HShift(DualTransform): identity_param = 0 def forward(self, image, param): - return tf.manip.roll(image, param, axis=0) + return manip_ops.roll(image, param, axis=0) def backward(self, image, param): - return tf.manip.roll(image, -param, axis=0) + return manip_ops.roll(image, -param, axis=0) class VShift(DualTransform): @@ -92,10 +93,10 @@ class VShift(DualTransform): identity_param = 0 def forward(self, image, param): - return tf.manip.roll(image, param, axis=1) + return manip_ops.roll(image, param, axis=1) def backward(self, image, param): - return tf.manip.roll(image, -param, axis=1) + return manip_ops.roll(image, -param, axis=1) class Contrast(SingleTransform): diff --git a/tta_wrapper/layers.py b/tta_wrapper/layers.py index 3167df2..eec04ad 100644 --- a/tta_wrapper/layers.py +++ b/tta_wrapper/layers.py @@ -1,5 +1,5 @@ import tensorflow as tf -from keras.layers import Layer +from tensorflow.keras.layers import Layer from . import functional as F diff --git a/tta_wrapper/wrappers.py b/tta_wrapper/wrappers.py index 30bfe39..b752041 100644 --- a/tta_wrapper/wrappers.py +++ b/tta_wrapper/wrappers.py @@ -1,5 +1,5 @@ -from keras.models import Model -from keras.layers import Input +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Input from .layers import Repeat, TTA, Merge from .augmentation import Augmentation