Skip to content
This repository has been archived by the owner on Nov 29, 2021. It is now read-only.

Fix scan progress in which all hosts are dead or excluded. #295

Merged
merged 4 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Fix stop scan. Wait for the scan process to be stopped before delete it from the process table. [#204](https://github.com/greenbone/ospd/pull/204)
- Fix get_scanner_details(). [#210](https://github.com/greenbone/ospd/pull/210)
- Fix thread lib leak using daemon mode for python 3.7. [#272](https://github.com/greenbone/ospd/pull/272)
- Fix scan progress in which all hosts are dead or excluded. [#295](https://github.com/greenbone/ospd/pull/295)

### Removed
- Remove support for resume task. [#266](https://github.com/greenbone/ospd/pull/266)
Expand Down
2 changes: 1 addition & 1 deletion ospd/datapickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from hashlib import sha256
from pathlib import Path
from typing import Dict, BinaryIO, Any
from typing import BinaryIO, Any

from ospd.errors import OspdCommandError

Expand Down
11 changes: 8 additions & 3 deletions ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def start_scan(self, scan_id: str) -> None:
logger.info("%s: Host scan finished.", scan_id)

is_stopped = self.get_scan_status(scan_id) == ScanStatus.STOPPED
self.set_scan_progress(scan_id)
progress = self.get_scan_progress(scan_id)

if not is_stopped and progress == ScanProgress.FINISHED:
Expand Down Expand Up @@ -619,13 +620,17 @@ def sort_host_finished(
scan_id, finished_hosts
)

def set_scan_progress(self, scan_id: str):
""" Calculate the target progress with the current host states
and stores in the scan table. """
scan_progress = self.scan_collection.calculate_target_progress(scan_id)
self.scan_collection.set_progress(scan_id, scan_progress)

def set_scan_progress_batch(
self, scan_id: str, host_progress: Dict[str, int]
):
self.scan_collection.set_host_progress(scan_id, host_progress)

scan_progress = self.scan_collection.calculate_target_progress(scan_id)
self.scan_collection.set_progress(scan_id, scan_progress)
self.set_scan_progress(scan_id)

def set_scan_host_progress(
self, scan_id: str, host: str = None, progress: int = None,
Expand Down
12 changes: 6 additions & 6 deletions ospd/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,8 @@ def calculate_target_progress(self, scan_id: str) -> int:
/ (total_hosts - exc_hosts - count_dead)
)
except ZeroDivisionError:
LOGGER.error(
"Zero division error in %s",
self.calculate_target_progress.__name__,
)
raise
# Consider the case in which all hosts are dead or excluded
t_prog = ScanProgress.FINISHED.value

return t_prog

Expand All @@ -424,7 +421,10 @@ def get_host_list(self, scan_id: str) -> Dict:
def get_host_count(self, scan_id: str) -> int:
""" Get total host count in the target. """
host = self.get_host_list(scan_id)
total_hosts = len(target_str_to_list(host))
total_hosts = 0

if host:
total_hosts = len(target_str_to_list(host))

return total_hosts

Expand Down
25 changes: 25 additions & 0 deletions tests/test_scan_and_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,31 @@ def test_progress(self):
self.daemon.scan_collection.calculate_target_progress(scan_id), 50
)

def test_progress_all_host_dead(self):

fs = FakeStream()
self.daemon.handle_command(
'<start_scan parallel="2">'
'<scanner_params />'
'<targets><target>'
'<hosts>localhost1, localhost2</hosts>'
'<ports>22</ports>'
'</target></targets>'
'</start_scan>',
fs,
)
self.daemon.start_queued_scans()
response = fs.get_response()

scan_id = response.findtext('id')
self.daemon.set_scan_host_progress(scan_id, 'localhost1', -1)
self.daemon.set_scan_host_progress(scan_id, 'localhost2', -1)

self.daemon.sort_host_finished(scan_id, ['localhost1', 'localhost2'])
self.assertEqual(
self.daemon.scan_collection.calculate_target_progress(scan_id), 100
)

@patch('ospd.ospd.os')
def test_interrupted_scan(self, mock_os):
mock_os.setsid.return_value = None
Expand Down