Skip to content

Commit

Permalink
refactor: switch to pathlib.Path in cvedb.py (#1751)
Browse files Browse the repository at this point in the history
* refactor: switch to pathlib.Path in cvedb.py

* fix: windows tests
  • Loading branch information
rhythmrx9 authored Jul 6, 2022
1 parent 6c13a9d commit 9bcdc1d
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions cve_bin_tool/cvedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import asyncio
import datetime
import logging
import os
import shutil
import sqlite3
from pathlib import Path
from typing import Any

import requests
Expand All @@ -26,12 +26,10 @@
logging.basicConfig(level=logging.DEBUG)

# database defaults
DISK_LOCATION_DEFAULT = os.path.join(os.path.expanduser("~"), ".cache", "cve-bin-tool")
DISK_LOCATION_BACKUP = os.path.join(
os.path.expanduser("~"), ".cache", "cve-bin-tool-backup"
)
DISK_LOCATION_DEFAULT = Path("~").expanduser() / ".cache" / "cve-bin-tool"
DISK_LOCATION_BACKUP = Path("~").expanduser() / ".cache" / "cve-bin-tool-backup"
DBNAME = "cve.db"
OLD_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "cvedb")
OLD_CACHE_DIR = Path("~") / ".cache" / "cvedb"


class CVEDB:
Expand All @@ -58,9 +56,11 @@ def __init__(
if sources is not None
else [x(error_mode=error_mode) for x in self.SOURCES]
)
self.cachedir = cachedir if cachedir is not None else self.CACHEDIR
self.cachedir = Path(cachedir) if cachedir is not None else self.CACHEDIR
self.backup_cachedir = (
backup_cachedir if backup_cachedir is not None else self.BACKUPCACHEDIR
Path(backup_cachedir)
if backup_cachedir is not None
else self.BACKUPCACHEDIR
)
self.error_mode = error_mode

Expand All @@ -71,7 +71,7 @@ def __init__(
self.version_check = version_check

# set up the db if needed
self.dbpath = os.path.join(self.cachedir, DBNAME)
self.dbpath = self.cachedir / DBNAME
self.connection: sqlite3.Connection | None = None

self.data = []
Expand All @@ -81,7 +81,7 @@ def __init__(
self.exploits_list = []
self.exploit_count = 0

if not os.path.exists(self.dbpath):
if not self.dbpath.exists():
self.rollback_cache_backup()

def get_cve_count(self) -> int:
Expand All @@ -91,20 +91,20 @@ def get_cve_count(self) -> int:
return self.cve_count

def check_db_exists(self) -> bool:
return os.path.isfile(self.dbpath)
return self.dbpath.is_file()

def get_db_update_date(self) -> float:
# last time when CVE data was updated
self.time_of_last_update = datetime.datetime.fromtimestamp(
os.path.getmtime(self.dbpath)
self.dbpath.stat().st_mtime
)
return os.path.getmtime(self.dbpath)
return self.dbpath.stat().st_mtime

async def refresh(self) -> None:
"""Refresh the cve database and check for new version."""
# refresh the database
if not os.path.isdir(self.cachedir):
os.makedirs(self.cachedir)
if not self.cachedir.is_dir():
self.cachedir.mkdir(parents=True)

# check for the latest version
if self.version_check:
Expand All @@ -125,15 +125,15 @@ def get_cvelist_if_stale(self) -> None:
"""Update if the local db is more than one day old.
This avoids the full slow update with every execution.
"""
if not os.path.isfile(self.dbpath) or (
if not self.dbpath.is_file() or (
datetime.datetime.today()
- datetime.datetime.fromtimestamp(os.path.getmtime(self.dbpath))
- datetime.datetime.fromtimestamp(self.dbpath.stat().st_mtime)
) > datetime.timedelta(hours=24):
self.refresh_cache_and_update_db()
self.time_of_last_update = datetime.datetime.today()
else:
self.time_of_last_update = datetime.datetime.fromtimestamp(
os.path.getmtime(self.dbpath)
self.dbpath.stat().st_mtime
)
self.LOGGER.info(
"Using cached CVE data (<24h old). Use -u now to update immediately."
Expand Down Expand Up @@ -315,11 +315,11 @@ def populate_affected(self, affected_data, cursor):

def clear_cached_data(self) -> None:
self.create_cache_backup()
if os.path.exists(self.cachedir):
if self.cachedir.exists():
self.LOGGER.warning(f"Updating cachedir {self.cachedir}")
shutil.rmtree(self.cachedir)
# Remove files associated with pre-1.0 development tree
if os.path.exists(OLD_CACHE_DIR):
if OLD_CACHE_DIR.exists():
self.LOGGER.warning(f"Deleting old cachedir {OLD_CACHE_DIR}")
shutil.rmtree(OLD_CACHE_DIR)

Expand Down Expand Up @@ -381,7 +381,7 @@ def db_close(self) -> None:

def create_cache_backup(self) -> None:
"""Creates a backup of the cachedir in case anything fails"""
if os.path.exists(self.cachedir):
if self.cachedir.exists():
self.LOGGER.debug(
f"Creating backup of cachedir {self.cachedir} at {self.backup_cachedir}"
)
Expand All @@ -397,15 +397,15 @@ def copy_db(self, filename, export=True):

def remove_cache_backup(self) -> None:
"""Removes the backup if database was successfully loaded"""
if os.path.exists(self.backup_cachedir):
if self.backup_cachedir.exists():
self.LOGGER.debug(f"Removing backup cache from {self.backup_cachedir}")
shutil.rmtree(self.backup_cachedir)

def rollback_cache_backup(self) -> None:
"""Rollback the cachedir backup in case anything fails"""
if os.path.exists(os.path.join(self.backup_cachedir, DBNAME)):
if (self.backup_cachedir / DBNAME).exists():
self.LOGGER.info("Rolling back the cache to its previous state")
if os.path.exists(self.cachedir):
if self.cachedir.exists():
shutil.rmtree(self.cachedir)
shutil.move(self.backup_cachedir, self.cachedir)

Expand Down

0 comments on commit 9bcdc1d

Please sign in to comment.