diff --git a/.github/workflows/reusable-precommit.yml b/.github/workflows/reusable-precommit.yml index b2115ef4..262a18a6 100644 --- a/.github/workflows/reusable-precommit.yml +++ b/.github/workflows/reusable-precommit.yml @@ -101,6 +101,8 @@ jobs: - "3.8" - "3.9" - "3.10" + - "3.11" + - "3.12" integration_tests: runs-on: ubuntu-latest needs: @@ -148,6 +150,8 @@ jobs: - "3.8" - "3.9" - "3.10" + - "3.11" + - "3.12" doc_tests: runs-on: ubuntu-latest needs: diff --git a/pybatfish/client/asserts.py b/pybatfish/client/asserts.py index 60efbfab..00efd00e 100644 --- a/pybatfish/client/asserts.py +++ b/pybatfish/client/asserts.py @@ -134,20 +134,11 @@ def _get_duplicate_router_ids( .frame() ) if ignore_same_node: - # Maps Router_ID to whether multiple nodes have that Router_ID - router_id_on_duplicate_nodes = ( - df.drop_duplicates(["Node", "Router_ID"]) - .value_counts(["Router_ID"]) - .map(lambda x: x > 1) + return df.groupby("Router_ID").filter( + lambda x: x["Node"].nunique() > 1 and x["Node"].nunique() != len(x) ) - df_duplicate = df[ - df.apply(lambda x: router_id_on_duplicate_nodes[x["Router_ID"]], axis=1) - ].sort_values(["Router_ID"]) else: - df_duplicate = df[df.duplicated(["Router_ID"], keep=False)].sort_values( - ["Router_ID"] - ) - return df_duplicate + return df[df.duplicated(["Router_ID"], keep=False)].sort_values(["Router_ID"]) def _is_dict_match(actual: Dict[str, Any], expected: Dict[str, Any]) -> bool: @@ -782,7 +773,7 @@ def assert_no_duplicate_router_ids( supported_protocols = {"bgp", "ospf"} protocols_to_fetch = ( - supported_protocols if protocols is None else set(map(str.lower, protocols)) + supported_protocols if protocols is None else set(p.lower() for p in protocols) ) if not protocols_to_fetch.issubset(supported_protocols): raise ValueError( diff --git a/pybatfish/client/restv2helper.py b/pybatfish/client/restv2helper.py index 702736fc..1dcf7ae2 100644 --- a/pybatfish/client/restv2helper.py +++ b/pybatfish/client/restv2helper.py @@ -630,7 +630,7 @@ def auto_complete( if session.snapshot else "", CoordConstsV2.RSC_AUTOCOMPLETE, - completion_type, + completion_type.value, ) params = {} # type: Dict[str, Any] if query: diff --git a/pybatfish/question/question.py b/pybatfish/question/question.py index c79ffad4..96619e4c 100644 --- a/pybatfish/question/question.py +++ b/pybatfish/question/question.py @@ -712,7 +712,7 @@ def _validate(questionJson): else: for i in range(0, len(value)): valueElement = value[i] - typeValid = _validateType(valueElement, variableType) + typeValid = _validate_type(valueElement, variableType) if not typeValid: valid = False errorMessage += ( @@ -750,7 +750,7 @@ def _validate(questionJson): ) else: - typeValid, typeValidErrorMessage = _validateType( + typeValid, typeValidErrorMessage = _validate_type( value, variableType ) if not typeValid: @@ -796,7 +796,9 @@ def _validate(questionJson): return True -def _validateType(value, expectedType): +def _validate_type( + value: Any, expected_type: Union[str, VariableType] +) -> Tuple[bool, Optional[str]]: """ Check if the input `value` have contents that matches the requirements specified by `expectedType`. @@ -805,28 +807,31 @@ def _validateType(value, expectedType): :raises QuestionValidationException """ - if expectedType == VariableType.BOOLEAN: + if not isinstance(expected_type, VariableType): + expected_type = VariableType(expected_type) + + if expected_type == VariableType.BOOLEAN: return isinstance(value, bool), None - elif expectedType == VariableType.COMPARATOR: - validComparators = ["<", "<=", "==", ">=", ">", "!="] - if value not in validComparators: + elif expected_type == VariableType.COMPARATOR: + valid_comparators = ["<", "<=", "==", ">=", ">", "!="] + if value not in valid_comparators: return ( False, "'{}' is not a known comparator. Valid options are: '{}'".format( - value, ", ".join(validComparators) + value, ", ".join(valid_comparators) ), ) return True, None - elif expectedType == VariableType.INTEGER: + elif expected_type == VariableType.INTEGER: INT32_MIN = -(2**32) INT32_MAX = 2**32 - 1 valid = isinstance(value, int) and INT32_MIN <= value <= INT32_MAX return valid, None - elif expectedType == VariableType.FLOAT: + elif expected_type == VariableType.FLOAT: return isinstance(value, float), None - elif expectedType == VariableType.DOUBLE: + elif expected_type == VariableType.DOUBLE: return isinstance(value, float), None - elif expectedType in [ + elif expected_type in [ VariableType.ADDRESS_GROUP_NAME, VariableType.APPLICATION_SPEC, VariableType.BGP_PEER_PROPERTY_SPEC, @@ -868,46 +873,46 @@ def _validateType(value, expectedType): VariableType.ZONE, ]: if not isinstance(value, str): - return False, f"A Batfish {expectedType} must be a string" + return False, f"A Batfish {expected_type.value} must be a string" return True, None - elif expectedType == VariableType.IP: + elif expected_type == VariableType.IP: if not isinstance(value, str): - return False, f"A Batfish {expectedType} must be a string" + return False, f"A Batfish {expected_type.value} must be a string" else: return _isIp(value) - elif expectedType == VariableType.IP_WILDCARD: + elif expected_type == VariableType.IP_WILDCARD: if not isinstance(value, str): - return False, f"A Batfish {expectedType} must be a string" + return False, f"A Batfish {expected_type.value} must be a string" else: return _isIpWildcard(value) - elif expectedType == VariableType.JSON_PATH: + elif expected_type == VariableType.JSON_PATH: return _isJsonPath(value) - elif expectedType == VariableType.LONG: + elif expected_type == VariableType.LONG: INT64_MIN = -(2**64) INT64_MAX = 2**64 - 1 valid = isinstance(value, int) and INT64_MIN <= value <= INT64_MAX return valid, None - elif expectedType == VariableType.PREFIX: + elif expected_type == VariableType.PREFIX: if not isinstance(value, str): - return False, f"A Batfish {expectedType} must be a string" + return False, f"A Batfish {expected_type.value} must be a string" else: return _isPrefix(value) - elif expectedType == VariableType.PREFIX_RANGE: + elif expected_type == VariableType.PREFIX_RANGE: if not isinstance(value, str): - return False, f"A Batfish {expectedType} must be a string" + return False, f"A Batfish {expected_type.value} must be a string" else: return _isPrefixRange(value) - elif expectedType == VariableType.QUESTION: + elif expected_type == VariableType.QUESTION: return isinstance(value, QuestionBase), None - elif expectedType == VariableType.BGP_ROUTES: + elif expected_type == VariableType.BGP_ROUTES: if not isinstance(value, list) or not all( isinstance(r, BgpRoute) for r in value ): - return False, f"A Batfish {expectedType} must be a list of BgpRoute" + return False, f"A Batfish {expected_type.value} must be a list of BgpRoute" return True, None - elif expectedType == VariableType.STRING: + elif expected_type == VariableType.STRING: return isinstance(value, str), None - elif expectedType == VariableType.SUBRANGE: + elif expected_type == VariableType.SUBRANGE: if isinstance(value, int): return True, None elif isinstance(value, str): @@ -916,12 +921,12 @@ def _validateType(value, expectedType): return ( False, "A Batfish {} must either be a string or an integer".format( - expectedType + expected_type.value ), ) - elif expectedType == VariableType.PROTOCOL: + elif expected_type == VariableType.PROTOCOL: if not isinstance(value, str): - return False, f"A Batfish {expectedType} must be a string" + return False, f"A Batfish {expected_type.value} must be a string" else: validProtocols = ["dns", "ssh", "tcp", "udp"] if not value.lower() in validProtocols: @@ -932,9 +937,9 @@ def _validateType(value, expectedType): ), ) return True, None - elif expectedType == VariableType.IP_PROTOCOL: + elif expected_type == VariableType.IP_PROTOCOL: if not isinstance(value, str): - return False, f"A Batfish {expectedType} must be a string" + return False, f"A Batfish {expected_type.value} must be a string" else: try: intValue = int(value) @@ -947,7 +952,7 @@ def _validateType(value, expectedType): except ValueError: # TODO: Should be validated at server side return True, None - elif expectedType in [ + elif expected_type in [ VariableType.ANSWER_ELEMENT, VariableType.BGP_ROUTE_CONSTRAINTS, VariableType.HEADER_CONSTRAINT, @@ -957,13 +962,13 @@ def _validateType(value, expectedType): else: logging.getLogger(__name__).warning( "WARNING: skipping validation for unknown argument type {}".format( - expectedType + expected_type.value ) ) return True, None -def _isJsonPath(value): +def _isJsonPath(value: Any) -> Tuple[bool, Optional[str]]: """ Check if the input string represents a valid jsonPath. @@ -991,7 +996,7 @@ def _isJsonPath(value): return True, None -def _isIp(value): +def _isIp(value: str) -> Tuple[bool, Optional[str]]: """ Check if the input string represents a valid IP address. @@ -1040,7 +1045,7 @@ def _isIp(value): return True, None -def _isSubRange(value): +def _isSubRange(value: str) -> Tuple[bool, Optional[str]]: """ Check if the input string represents a valid subRange. @@ -1061,7 +1066,7 @@ def _isSubRange(value): return True, None -def _isPrefix(value): +def _isPrefix(value: str) -> Tuple[bool, Optional[str]]: """ Check if the input string represents a valid prefix. @@ -1081,7 +1086,7 @@ def _isPrefix(value): return _isIp(contents[0]) -def _isPrefixRange(value): +def _isPrefixRange(value: str) -> Tuple[bool, Optional[str]]: """ Check if the input string represents a valid prefix range. @@ -1105,7 +1110,7 @@ def _isPrefixRange(value): return True, None -def _isIpWildcard(value): +def _isIpWildcard(value: str) -> Tuple[bool, Optional[str]]: """ Check if the input string represents a valid ipWildCard. diff --git a/setup.py b/setup.py index 64267e26..456b789a 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,8 @@ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], python_requires=">=3.8", # What does your project relate to? diff --git a/tests/question/test_question_additional.py b/tests/question/test_question_additional.py index ddb6f534..3991e71f 100644 --- a/tests/question/test_question_additional.py +++ b/tests/question/test_question_additional.py @@ -257,27 +257,27 @@ def testValidIpAddressIpWildcard(): # Tests for validateType def testInvalidBooleanValidateType(): - result = question._validateType(1.5, "boolean") + result = question._validate_type(1.5, "boolean") assert not result[0] def testValidBooleanValidateType(): - result = question._validateType(True, "boolean") + result = question._validate_type(True, "boolean") assert result[0] def testInvalidIntegerValidateType(): - result = question._validateType(1.5, "integer") + result = question._validate_type(1.5, "integer") assert not result[0] def testValidIntegerValidateType(): - result = question._validateType(10, "integer") + result = question._validate_type(10, "integer") assert result[0] def testInvalidComparatorValidateType(): - result = question._validateType("<==", "comparator") + result = question._validate_type("<==", "comparator") expectMessage = ( "'<==' is not a known comparator. Valid options are: '<, <=, ==, >=, >, !='" ) @@ -286,107 +286,107 @@ def testInvalidComparatorValidateType(): def testValidComparatorValidateType(): - result = question._validateType("<=", "comparator") + result = question._validate_type("<=", "comparator") assert result[0] def testInvalidFloatValidateType(): - result = question._validateType(10, "float") + result = question._validate_type(10, "float") assert not result[0] def testValidFloatValidateType(): - result = question._validateType(10.0, "float") + result = question._validate_type(10.0, "float") assert result[0] def testInvalidDoubleValidateType(): - result = question._validateType(10, "double") + result = question._validate_type(10, "double") assert not result[0] def testValidDoubleValidateType(): - result = question._validateType(10.0, "double") + result = question._validate_type(10.0, "double") assert result[0] def testInvalidLongValidateType(): - result = question._validateType(5.3, "long") + result = question._validate_type(5.3, "long") assert not result[0] - result = question._validateType(2**64, "long") + result = question._validate_type(2**64, "long") assert not result[0] def testValidLongValidateType(): - result = question._validateType(10, "long") + result = question._validate_type(10, "long") assert result[0] - result = question._validateType(2**40, "long") + result = question._validate_type(2**40, "long") assert result[0] def testInvalidJavaRegexValidateType(): - result = question._validateType(10, "javaRegex") + result = question._validate_type(10, "javaRegex") expectMessage = "A Batfish javaRegex must be a string" assert not result[0] assert expectMessage == result[1] def testInvalidNonDictionaryJsonPathValidateType(): - result = question._validateType(10, "jsonPath") + result = question._validate_type(10, "jsonPath") expectMessage = "Expected a jsonPath dictionary with elements 'path' (string) and optional 'suffix' (boolean)" assert not result[0] assert expectMessage == result[1] def testInvalidDictionaryJsonPathValidateType(): - result = question._validateType({"value": 10}, "jsonPath") + result = question._validate_type({"value": 10}, "jsonPath") expectMessage = "Missing 'path' element of jsonPath" assert not result[0] assert expectMessage == result[1] def testPathNonStringJsonPathValidateType(): - result = question._validateType({"path": 10}, "jsonPath") + result = question._validate_type({"path": 10}, "jsonPath") expectMessage = "'path' element of jsonPath dictionary should be a string" assert not result[0] assert expectMessage == result[1] def testSuffixNonBooleanJsonPathValidateType(): - result = question._validateType({"path": "I am path", "suffix": "hi"}, "jsonPath") + result = question._validate_type({"path": "I am path", "suffix": "hi"}, "jsonPath") expectMessage = "'suffix' element of jsonPath dictionary should be a boolean" assert not result[0] assert expectMessage == result[1] def testValidJsonPathValidateType(): - result = question._validateType({"path": "I am path", "suffix": True}, "jsonPath") + result = question._validate_type({"path": "I am path", "suffix": True}, "jsonPath") assert result[0] assert result[1] is None def testInvalidTypeSubRangeValidateType(): - result = question._validateType(10.0, "subrange") + result = question._validate_type(10.0, "subrange") expectMessage = "A Batfish subrange must either be a string or an integer" assert not result[0] assert expectMessage == result[1] def testValidIntegerSubRangeValidateType(): - result = question._validateType(10, "subrange") + result = question._validate_type(10, "subrange") assert result[0] assert result[1] is None def testNonStringProtocolValidateType(): - result = question._validateType(10.0, "protocol") + result = question._validate_type(10.0, "protocol") expectMessage = "A Batfish protocol must be a string" assert not result[0] assert expectMessage == result[1] def testInvalidProtocolValidateType(): - result = question._validateType("TCPP", "protocol") + result = question._validate_type("TCPP", "protocol") expectMessage = ( "'TCPP' is not a valid protocols. Valid options are: 'dns, ssh, tcp, udp'" ) @@ -395,27 +395,27 @@ def testInvalidProtocolValidateType(): def testValidProtocolValidateType(): - result = question._validateType("TCP", "protocol") + result = question._validate_type("TCP", "protocol") assert result[0] assert result[1] is None def testNonStringIpProtocolValidateType(): - result = question._validateType(10.0, "ipProtocol") + result = question._validate_type(10.0, "ipProtocol") expectMessage = "A Batfish ipProtocol must be a string" assert not result[0] assert expectMessage == result[1] def testInvalidIntegerIpProtocolValidateType(): - result = question._validateType("1000", "ipProtocol") + result = question._validate_type("1000", "ipProtocol") expectMessage = "'1000' is not in valid ipProtocol range: 0-255" assert not result[0] assert expectMessage == result[1] def testValidIntegerIpProtocolValidateType(): - result = question._validateType("10", "ipProtocol") + result = question._validate_type("10", "ipProtocol") assert result[0] assert result[1] is None @@ -423,8 +423,8 @@ def testValidIntegerIpProtocolValidateType(): def testInvalidCompletionTypes(): # TODO: simplify to COMPLETION_TYPES after VariableType.BGP_ROUTE_STATUS_SPEC is moved for completion_type in set(COMPLETION_TYPES + [VariableType.BGP_ROUTE_STATUS_SPEC]): - result = question._validateType(5, completion_type) - expectMessage = "A Batfish " + completion_type + " must be a string" + result = question._validate_type(5, completion_type) + expectMessage = f"A Batfish {completion_type.value} must be a string" assert not result[0] assert result[1] == expectMessage @@ -437,7 +437,7 @@ def testValidCompletionTypes(): } # TODO: simplify to COMPLETION_TYPES after VariableType.BGP_ROUTE_STATUS_SPEC is moved for completion_type in set(COMPLETION_TYPES + [VariableType.BGP_ROUTE_STATUS_SPEC]): - result = question._validateType( + result = question._validate_type( values.get(completion_type, ".*"), completion_type ) assert result[0]