Skip to content

Commit

Permalink
ENH: Update validators adding type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
jcfr committed Aug 9, 2023
1 parent ea55a96 commit b99a31b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 25 deletions.
33 changes: 19 additions & 14 deletions scripts/python/PyAutoscoper/validators.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
from __future__ import annotations

import os
from abc import ABC, abstractmethod
from typing import Any


class Validator(ABC):
"""Abstract base class for validators."""

def __set_name__(self, owner, name):
def __set_name__(self, owner, name: str) -> None:
self.private_name = f"_{name}"

def __get__(self, obj, objtype=None):
return getattr(obj, self.private_name)

def __set__(self, obj, value) -> None:
def __set__(self, obj, value: Any) -> None:
self.validate(value)
setattr(obj, self.private_name, value)

@abstractmethod
def validate(self, value):
def validate(self, value: Any):
"""Validate the value."""
pass


class Path(Validator):
"""Validate that the value is a path."""

def __init__(self, directory=False, file=False) -> None:
def __init__(self, directory: bool = False, file: bool = False) -> None:
self.directory = directory
self.file = file
if self.directory and self.file:
raise ValueError("Cannot set both directory and file to True.")

def validate(self, value):
def validate(self, value: str) -> None:
if not isinstance(value, str):
raise TypeError(f"Expected {value!r} to be a string.")
if not os.path.exists(value):
Expand All @@ -44,20 +47,22 @@ def validate(self, value):
class Boolean(Validator):
"""Validate that the value is a bool."""

def validate(self, value):
def validate(self, value: bool) -> None:
if not isinstance(value, bool):
raise TypeError(f"Expected {value!r} to be a bool.")


class Number(Validator):
"""Validate that the value is a number."""

def __init__(self, min=None, max=None, types=None) -> None:
def __init__(
self, min: float | int | None = None, max: float | int | None = None, types: list[type] | None = None
) -> None:
self.min = min
self.max = max
self.types = (int, float) if types is None else types

def validate(self, value):
def validate(self, value: float | int) -> None:
if not isinstance(value, self.types):
raise TypeError(f'Expected {value!r} to be {",".join(self.types)}.')
if self.min is not None and value < self.min:
Expand All @@ -69,25 +74,25 @@ def validate(self, value):
class Integer(Number):
"""Validate that the value is an integer."""

def __init__(self, min=None, max=None) -> None:
def __init__(self, min: int | None = None, max: int | None = None) -> None:
super().__init__(min, max, types=(int,))


class Float(Number):
"""Validate that the value is a float."""

def __init__(self, min=None, max=None) -> None:
def __init__(self, min: float | None = None, max: float | None = None) -> None:
super().__init__(min, max, types=(float,))


class List(Validator):
"""Validate that the value is a list."""

def __init__(self, size=None, types=None) -> None:
def __init__(self, size: int | None = None, types: list[type] | None = None) -> None:
self.size = size
self.types = types

def validate(self, value):
def validate(self, value: list[Any]) -> None:
if not isinstance(value, list):
raise TypeError(f"Expected {value!r} to be a list.")
if self.size is not None and len(value) != self.size:
Expand All @@ -101,12 +106,12 @@ def validate(self, value):
class FloatList(List):
"""Validate that the value is a list of floats."""

def __init__(self, size=None) -> None:
def __init__(self, size: int | None = None) -> None:
super().__init__(size, types=(float,))


class IntegerList(List):
"""Validate that the value is a list of integers."""

def __init__(self, size=None) -> None:
def __init__(self, size: int | None = None) -> None:
super().__init__(size, types=(int,))
1 change: 1 addition & 0 deletions scripts/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ select = [
]
extend-ignore = [
"ANN101", # missing-type-self
"ANN102", # missing-type-cls
"G004", # logging-f-string
"PIE790", # unnecessary-pass
"PLR0913", # too many arguments to function call
Expand Down
6 changes: 3 additions & 3 deletions scripts/python/tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@


class TestEnums(unittest.TestCase):
def test_cost_function(self):
def test_cost_function(self) -> None:
assert CostFunction.NORMALIZED_CROSS_CORRELATION.value == 0
assert CostFunction.SUM_OF_ABSOLUTE_DIFFERENCES.value == 1
assert CostFunction(0).name == "NORMALIZED_CROSS_CORRELATION"
assert CostFunction(1).name == "SUM_OF_ABSOLUTE_DIFFERENCES"
self.assertRaises(ValueError, CostFunction, 2)
self.assertRaises(ValueError, CostFunction, "0")

def test_optimization_initialization_heuristic(self):
def test_optimization_initialization_heuristic(self) -> None:
assert OptimizationInitializationHeuristic.CURRENT_FRAME.value == 0
assert OptimizationInitializationHeuristic.PREVIOUS_FRAME.value == 1
assert OptimizationInitializationHeuristic.LINEAR_EXTRAPOLATION.value == 2
Expand All @@ -24,7 +24,7 @@ def test_optimization_initialization_heuristic(self):
self.assertRaises(ValueError, OptimizationInitializationHeuristic, 4)
self.assertRaises(ValueError, OptimizationInitializationHeuristic, "0")

def test_optimization_method(self):
def test_optimization_method(self) -> None:
assert OptimizationMethod.PARTICLE_SWARM_OPTIMIZATION.value == 0
assert OptimizationMethod.DOWNHILL_SIMPLEX.value == 1
assert OptimizationMethod(0).name == "PARTICLE_SWARM_OPTIMIZATION"
Expand Down
16 changes: 8 additions & 8 deletions scripts/python/tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,50 +13,50 @@

class TestValidtors(unittest.TestCase):
@classmethod
def setUpClass(cls):
def setUpClass(cls) -> None:
cls.test_file = os.path.join(os.path.dirname(__file__), "test_file.txt")
cls.test_dir = os.path.join(os.path.dirname(__file__), "test_dir")
os.makedirs(cls.test_dir)
with open(cls.test_file, "w") as f:
f.write("Autogenerated test file.")

@classmethod
def tearDownClass(cls):
def tearDownClass(cls) -> None:
os.remove(cls.test_file)
os.rmdir(cls.test_dir)

def test_float(self):
def test_float(self) -> None:
Float().validate(1.0)
self.assertRaises(TypeError, Float().validate, 1)
self.assertRaises(TypeError, Float().validate, "1.0")
self.assertRaises(ValueError, Float(min=0.0).validate, -1.0)
self.assertRaises(ValueError, Float(max=0.0).validate, 1.0)

def test_integer(self):
def test_integer(self) -> None:
Integer().validate(1)
self.assertRaises(TypeError, Integer().validate, 1.0)
self.assertRaises(TypeError, Integer().validate, "1")
self.assertRaises(ValueError, Integer(min=0).validate, -1)
self.assertRaises(ValueError, Integer(max=0).validate, 1)

def test_float_list(self):
def test_float_list(self) -> None:
FloatList().validate([1.0, 2.0])
self.assertRaises(TypeError, FloatList().validate, [1, 2])
self.assertRaises(TypeError, FloatList().validate, ["1.0", "2.0"])
self.assertRaises(ValueError, FloatList(size=2).validate, [1.0])

def test_integer_list(self):
def test_integer_list(self) -> None:
IntegerList().validate([1, 2])
self.assertRaises(TypeError, IntegerList().validate, [1.0, 2.0])
self.assertRaises(TypeError, IntegerList().validate, ["1", "2"])
self.assertRaises(ValueError, IntegerList(size=2).validate, [1])

def test_boolean(self):
def test_boolean(self) -> None:
Boolean().validate(True)
self.assertRaises(TypeError, Boolean().validate, 1)
self.assertRaises(TypeError, Boolean().validate, "True")

def test_path(self):
def test_path(self) -> None:
Path().validate(self.test_file)
Path().validate(self.test_dir)
Path(file=True).validate(self.test_file)
Expand Down

0 comments on commit b99a31b

Please sign in to comment.