Skip to content

Commit

Permalink
Merge pull request #73 from kannkyo/fix-flag-interface
Browse files Browse the repository at this point in the history
change: remove FLAG for userbility
  • Loading branch information
kannkyo authored Feb 12, 2023
2 parents 53f021b + 2a87089 commit b3d798c
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 54 deletions.
1 change: 0 additions & 1 deletion src/nvd_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
__version__ = "1.0.0"

# import NvdApiClient
from nvd_api.client import FLAG
from nvd_api.client import VERSION_TYPE
from nvd_api.client import CVSS_V2_SEVERITY
from nvd_api.client import CVSS_V3_SEVERITY
Expand Down
78 changes: 43 additions & 35 deletions src/nvd_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@
logger = logging.getLogger()


class FLAG(Enum):
"""Flag class"""

TRUE = ""
FALSE = None


class VERSION_TYPE(Enum):
INCLUDING = "including"
EXCLUDING = "excluding"
Expand Down Expand Up @@ -55,6 +48,11 @@ class NvdApiClient(object):
"""NVD API Client class
"""

MAX_PAGE_LIMIT_CVE_API = 2000
MAX_PAGE_LIMIT_CVE_HISTORY_API = 5000
MAX_PAGE_LIMIT_CPE_API = 10000
MAX_PAGE_LIMIT_CPE_MATCH_API = 5000

def __init__(self, wait_time: int = 6000, api_key=None):
"""Constructor # noqa: E501
Expand All @@ -78,16 +76,16 @@ def get_cves(self,
cvss_v3_metrics: str = None,
cvss_v3_severity: CVSS_V3_SEVERITY = None,
cwe_id: str = None,
has_cert_alerts: FLAG = FLAG.FALSE,
has_cert_notes: FLAG = FLAG.FALSE,
has_kev: FLAG = FLAG.FALSE,
has_oval: FLAG = FLAG.FALSE,
is_vulnerable: FLAG = FLAG.FALSE,
keyword_exact_match: FLAG = FLAG.FALSE,
has_cert_alerts: bool = False,
has_cert_notes: bool = False,
has_kev: bool = False,
has_oval: bool = False,
is_vulnerable: bool = False,
keyword_exact_match: bool = False,
keyword_search: str = None,
last_mod_start_date: datetime = None,
last_mod_end_date: datetime = None,
no_rejected: FLAG = FLAG.FALSE,
no_rejected: bool = False,
pub_start_date: datetime = None,
pub_end_date: datetime = None,
results_per_page: int = 2000,
Expand All @@ -108,16 +106,16 @@ def get_cves(self,
cvss_v3_metrics (str, optional): CVSSv3 vector string. Defaults to None.
cvss_v3_severity (CVSS_V3_SEVERITY, optional): CVSSv3 qualitative severity rating. Defaults to None.
cwe_id (str, optional): CWE ID. Defaults to None.
has_cert_alerts (FLAG, optional): contain a Technical Alert from US-CERT. Defaults to None.
has_cert_notes (FLAG, optional): contain a Vulnerability Note from CERT/CC. Defaults to None.
has_kev (FLAG, optional): appear in CISA's Known Exploited Vulnerabilities (KEV) Catalog. Defaults to None.
has_oval (FLAG, optional): contain information from MITRE's Open Vulnerability and Assessment Language (OVAL). Defaults to None.
is_vulnerable (FLAG, optional): returns only CVE associated with a specific CPE. Defaults to None.
keyword_exact_match (FLAG, optional): returns any CVE where a word or phrase. Defaults to None.
has_cert_alerts (bool, optional): contain a Technical Alert from US-CERT. Defaults to False.
has_cert_notes (bool, optional): contain a Vulnerability Note from CERT/CC. Defaults to False.
has_kev (bool, optional): appear in CISA's Known Exploited Vulnerabilities (KEV) Catalog. Defaults to False.
has_oval (bool, optional): contain information from MITRE's Open Vulnerability and Assessment Language (OVAL). Defaults to False.
is_vulnerable (bool, optional): returns only CVE associated with a specific CPE. Defaults to False.
keyword_exact_match (bool, optional): returns any CVE where a word or phrase. Defaults to False.
keyword_search (str, optional): a word or phrase is found in the current description. Defaults to None.
last_mod_start_date (datetime, optional): search by modified date. Defaults to None.
last_mod_end_date (datetime, optional): search by modified date. Defaults to None.
no_rejected (str, optional): return the CVE API includes CVE records with the REJECT or Rejected status. Defaults to None.
no_rejected (bool, optional): return the CVE API includes CVE records with the REJECT or Rejected status. Defaults to False.
pub_start_date (datetime, optional): search by published date. Defaults to None.
pub_end_date (datetime, optional): search by published date. Defaults to None.
results_per_page (int, optional): max number of records (default is 2000). Defaults to None.
Expand Down Expand Up @@ -159,23 +157,31 @@ def get_cves(self,
self._verify_version_start(version_start, version_start_type)
self._verify_version_end(version_end, version_end_type)

has_cert_alerts = "" if has_cert_alerts else None
has_cert_notes = "" if has_cert_notes else None
has_kev = "" if has_kev else None
has_oval = "" if has_oval else None
is_vulnerable = "" if is_vulnerable else None
keyword_exact_match = "" if keyword_exact_match else None
no_rejected = "" if no_rejected else None

kwargs = dict(cpe_name=cpe_name,
cve_id=cve_id,
cvss_v2_metrics=cvss_v2_metrics,
cvss_v2_severity=cvss_v2_severity,
cvss_v3_metrics=cvss_v3_metrics,
cvss_v3_severity=cvss_v3_severity,
cwe_id=cwe_id,
has_cert_alerts=has_cert_alerts.value,
has_cert_notes=has_cert_notes.value,
has_kev=has_kev.value,
has_oval=has_oval.value,
is_vulnerable=is_vulnerable.value,
keyword_exact_match=keyword_exact_match.value,
has_cert_alerts=has_cert_alerts,
has_cert_notes=has_cert_notes,
has_kev=has_kev,
has_oval=has_oval,
is_vulnerable=is_vulnerable,
keyword_exact_match=keyword_exact_match,
keyword_search=keyword_search,
last_mod_start_date=last_mod_start_date,
last_mod_end_date=last_mod_end_date,
no_rejected=no_rejected.value,
no_rejected=no_rejected,
pub_start_date=pub_start_date,
pub_end_date=pub_end_date,
results_per_page=results_per_page,
Expand Down Expand Up @@ -242,7 +248,7 @@ def get_cve_history(self,
def get_cpes(self,
cpe_name_id: str = None,
cpe_match_string: str = None,
keyword_exact_match: FLAG = FLAG.FALSE,
keyword_exact_match: bool = False,
keyword_search: str = None,
last_mod_start_date: datetime = None,
last_mod_end_date: datetime = None,
Expand All @@ -254,7 +260,7 @@ def get_cpes(self,
Args:
cpe_name_id (str, optional): specific CPE record UUID. Defaults to None.
cpe_match_string (str, optional): CPE Name. Defaults to None.
keyword_exact_match (FLAG, optional): if CPE exactly match or not. Defaults to None. Defaults to None.
keyword_exact_match (bool, optional): if CPE exactly match or not. Defaults to None. Defaults to False.
keyword_search (str, optional): a word or phrase is found in the metadata title or reference links. Defaults to None.
last_mod_start_date (datetime, optional): search CPE by modified date. Defaults to None.
last_mod_end_date (datetime, optional): search CPE by modified date. Defaults to None.
Expand All @@ -272,9 +278,11 @@ def get_cpes(self,
self._verify_last_mod_dates(last_mod_start_date, last_mod_end_date)
self._verify_keyword(keyword_exact_match, keyword_search)

keyword_exact_match = "" if keyword_exact_match else None

kwargs = dict(cpe_name_id=cpe_name_id,
cpe_match_string=cpe_match_string,
keyword_exact_match=keyword_exact_match.value,
keyword_exact_match=keyword_exact_match,
keyword_search=keyword_search,
last_mod_start_date=last_mod_start_date,
last_mod_end_date=last_mod_end_date,
Expand Down Expand Up @@ -431,15 +439,15 @@ def _verify_cvss_severity(self,
"can not use cvss_v2_severity with cvss_v3_severity")

def _verify_vulnerable(self,
is_vulnerable: FLAG,
is_vulnerable: bool,
cpe_name: str,):
if is_vulnerable is FLAG.TRUE and cpe_name is None:
if is_vulnerable is True and cpe_name is None:
raise ApiValueError("must use is_vulnerable with cpe_name")

def _verify_keyword(self,
keyword_exact_match: FLAG,
keyword_exact_match: bool,
keyword_search: str,):
if keyword_exact_match is FLAG.TRUE and keyword_search is None:
if keyword_exact_match is True and keyword_search is None:
raise ApiValueError(
"must use keyword_exact_match with keyword_search")

Expand Down
9 changes: 9 additions & 0 deletions tests/test_get_cpe_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ def test_invalid_api_key(self):
)
pprint(response)

def test_max_page_limit(self):
max_limit = NvdApiClient.MAX_PAGE_LIMIT_CPE_MATCH_API
response = self.client.get_cpe_match(results_per_page=max_limit)
assert (len(response.match_strings) > 0)

with self.assertRaises(ApiValueError):
response = self.client.get_cpe_match(
results_per_page=max_limit + 1)


if __name__ == '__main__':
unittest.main()
15 changes: 12 additions & 3 deletions tests/test_get_cpes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import unittest
from pprint import pprint

from nvd_api.client import FLAG, NvdApiClient
from nvd_api.client import NvdApiClient
from nvd_api.low_api.exceptions import ApiValueError, NotFoundException


Expand Down Expand Up @@ -48,7 +48,7 @@ def test_get_by_cpe_match_string(self):

def test_get_by_keywords(self):
response = self.client.get_cpes(
keyword_exact_match=FLAG.TRUE,
keyword_exact_match=True,
keyword_search="Microsoft Windows",
results_per_page=1,
start_index=0
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_use_mod_start_date_without_mod_end_date(self):
def test_use_keyword_exact_match_without_keyword_search(self):
with self.assertRaises(ApiValueError):
response = self.client.get_cpes(
keyword_exact_match=FLAG.TRUE
keyword_exact_match=True
)
pprint(response)

Expand Down Expand Up @@ -189,6 +189,15 @@ def test_invalid_api_key(self):
)
pprint(response)

def test_max_page_limit(self):
max_limit = NvdApiClient.MAX_PAGE_LIMIT_CPE_API
response = self.client.get_cpes(results_per_page=max_limit)
assert (len(response.products) > 0)

with self.assertRaises(ApiValueError):
response = self.client.get_cpes(
results_per_page=max_limit + 1)


if __name__ == '__main__':
unittest.main()
9 changes: 9 additions & 0 deletions tests/test_get_cve_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ def test_invalid_api_key(self):
)
pprint(response)

def test_max_page_limit(self):
max_limit = NvdApiClient.MAX_PAGE_LIMIT_CVE_HISTORY_API
response = self.client.get_cve_history(results_per_page=max_limit)
assert (len(response.cve_changes) > 0)

with self.assertRaises(ApiValueError):
response = self.client.get_cve_history(
results_per_page=max_limit + 1)


if __name__ == '__main__':
unittest.main()
39 changes: 24 additions & 15 deletions tests/test_get_cves.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import unittest
from pprint import pprint

from nvd_api.client import (CVSS_V2_SEVERITY, CVSS_V3_SEVERITY, FLAG,
VERSION_TYPE, NvdApiClient)
from nvd_api.client import (CVSS_V2_SEVERITY, CVSS_V3_SEVERITY, VERSION_TYPE,
NvdApiClient)
from nvd_api.low_api.exceptions import ApiValueError, NotFoundException


Expand All @@ -34,7 +34,7 @@ def test_get_by_cve(self):
cpe_name="cpe:2.3:a:ibm:mq:9.0.0.0:*:*:*:lts:*:*:*",
cve_id="CVE-2019-4227",
cwe_id="CWE-384",
is_vulnerable=FLAG.TRUE,
is_vulnerable=True,
source_identifier="nvd@nist.gov"
)
pprint(response)
Expand Down Expand Up @@ -70,11 +70,11 @@ def test_get_by_cvss_v3(self):

def test_get_by_has_flags(self):
response = self.client.get_cves(
has_cert_alerts=FLAG.TRUE,
has_cert_notes=FLAG.TRUE,
has_kev=FLAG.TRUE,
has_oval=FLAG.TRUE,
no_rejected=FLAG.TRUE,
has_cert_alerts=True,
has_cert_notes=True,
has_kev=True,
has_oval=True,
no_rejected=True,
results_per_page=1,
start_index=1
)
Expand All @@ -83,7 +83,7 @@ def test_get_by_has_flags(self):

def test_get_by_keywords(self):
response = self.client.get_cves(
keyword_exact_match=FLAG.TRUE,
keyword_exact_match=True,
keyword_search="CentOS",
results_per_page=1,
start_index=1
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_use_cvss_v2_and_v3_severity(self):
def test_use_keyword_exact_match_without_keyword_search(self):
with self.assertRaises(ApiValueError):
response = self.client.get_cves(
keyword_exact_match=FLAG.TRUE,
keyword_exact_match=True,
results_per_page=1,
start_index=1
)
Expand All @@ -197,7 +197,7 @@ def test_use_is_vulnerable_without_cpe_name(self):
response = self.client.get_cves(
cve_id="CVE-2019-4227",
cwe_id="CWE-384",
is_vulnerable=FLAG.TRUE,
is_vulnerable=True,
source_identifier="nvd@nist.gov"
)
pprint(response)
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_invalid_cpe_name(self):
with self.assertRaises(NotFoundException):
response = self.client.get_cves(
cpe_name="INVALID:2.3:a:ibm:mq:9.0.0.0:*:*:*:lts:*:*:*",
is_vulnerable=FLAG.TRUE
is_vulnerable=True
)
pprint(response)

Expand All @@ -276,7 +276,7 @@ def test_invalid_cve_id(self):
cpe_name="cpe:2.3:a:ibm:mq:9.0.0.0:*:*:*:lts:*:*:*",
cve_id="INVALID-2019-4227",
cwe_id="CWE-384",
is_vulnerable=FLAG.TRUE,
is_vulnerable=True,
source_identifier="nvd@nist.gov"
)
pprint(response)
Expand Down Expand Up @@ -331,7 +331,7 @@ def test_invalid_cwe_id(self):
cpe_name="cpe:2.3:a:ibm:mq:9.0.0.0:*:*:*:lts:*:*:*",
cve_id="CVE-2019-4227",
cwe_id="INVALID-384",
is_vulnerable=FLAG.TRUE,
is_vulnerable=True,
source_identifier="nvd@nist.gov"
)
pprint(response)
Expand Down Expand Up @@ -444,11 +444,20 @@ def test_invalid_api_key(self):
cpe_name="cpe:2.3:a:ibm:mq:9.0.0.0:*:*:*:lts:*:*:*",
cve_id="CVE-2019-4227",
cwe_id="CWE-384",
is_vulnerable=FLAG.TRUE,
is_vulnerable=True,
source_identifier="nvd@nist.gov"
)
pprint(response)

def test_max_page_limit(self):
max_limit = NvdApiClient.MAX_PAGE_LIMIT_CVE_API
response = self.client.get_cves(results_per_page=max_limit)
assert (len(response.vulnerabilities) > 0)

with self.assertRaises(ApiValueError):
response = self.client.get_cves(
results_per_page=max_limit + 1)


if __name__ == '__main__':
unittest.main()

0 comments on commit b3d798c

Please sign in to comment.