From dc4a27d69431197ea2e15c59d5c70d08373428e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikael=20G=C3=B6ransson?= Date: Mon, 23 May 2022 21:48:44 +0200 Subject: [PATCH 1/3] added types for additional code --- locust/env.py | 42 ++--- locust/runners.py | 166 +++++++++++--------- locust/stats.py | 307 +++++++++++++++++++++---------------- locust/user/inspectuser.py | 8 +- locust/user/task.py | 15 +- locust/web.py | 134 ++++++++++------ 6 files changed, 394 insertions(+), 278 deletions(-) diff --git a/locust/env.py b/locust/env.py index 1212d01fba..8ecb9afbef 100644 --- a/locust/env.py +++ b/locust/env.py @@ -5,19 +5,19 @@ List, Type, TypeVar, - Union, Optional, + Union, ) from configargparse import Namespace from .event import Events from .exception import RunnerAlreadyExistsError -from .stats import RequestStats +from .stats import RequestStats, StatsCSV from .runners import Runner, LocalRunner, MasterRunner, WorkerRunner from .web import WebUI from .user import User -from .user.task import TaskSet, filter_tasks_by_tags +from .user.task import filter_tasks_by_tags, TaskSet from .shape import LoadTestShape @@ -28,17 +28,17 @@ class Environment: def __init__( self, *, - user_classes: Union[List[Type[User]], None] = None, - shape_class: Union[LoadTestShape, None] = None, - tags: Union[List[str], None] = None, + user_classes: Optional[List[Type[User]]] = None, + shape_class: Optional[LoadTestShape] = None, + tags: Optional[List[str]] = None, locustfile: str = None, - exclude_tags=None, + exclude_tags: Optional[List[str]] = None, events: Events = None, host: str = None, reset_stats=False, - stop_timeout: Union[float, None] = None, + stop_timeout: Optional[float] = None, catch_exceptions=True, - parsed_options: Namespace = None, + parsed_options: Optional[Namespace] = None, ): self.runner: Optional[Runner] = None @@ -141,7 +141,7 @@ def create_master_runner(self, master_bind_host="*", master_bind_port=5557) -> M master_bind_port=master_bind_port, ) - def create_worker_runner(self, master_host, master_port) -> WorkerRunner: + def create_worker_runner(self, master_host: str, master_port: int) -> WorkerRunner: """ Create a :class:`WorkerRunner ` instance for this Environment @@ -161,12 +161,12 @@ def create_web_ui( self, host="", port=8089, - auth_credentials=None, - tls_cert=None, - tls_key=None, - stats_csv_writer=None, + auth_credentials: Optional[str] = None, + tls_cert: Optional[str] = None, + tls_key: Optional[str] = None, + stats_csv_writer: Optional[StatsCSV] = None, delayed_start=False, - ): + ) -> WebUI: """ Creates a :class:`WebUI ` instance for this Environment and start running the web server @@ -194,7 +194,7 @@ def create_web_ui( ) return self.web_ui - def _filter_tasks_by_tags(self): + def _filter_tasks_by_tags(self) -> None: """ Filter the tasks on all the user_classes recursively, according to the tags and exclude_tags attributes @@ -220,7 +220,7 @@ def _filter_tasks_by_tags(self): for user_class in self.user_classes: filter_tasks_by_tags(user_class, tags, exclude_tags) - def _remove_user_classes_with_weight_zero(self): + def _remove_user_classes_with_weight_zero(self) -> None: """ Remove user classes having a weight of zero. """ @@ -235,20 +235,20 @@ def _remove_user_classes_with_weight_zero(self): raise ValueError("There are no users with weight > 0.") self.user_classes[:] = filtered_user_classes - def assign_equal_weights(self): + def assign_equal_weights(self) -> None: """ Update the user classes such that each user runs their specified tasks with equal probability. 
""" for u in self.user_classes: u.weight = 1 - user_tasks = [] + user_tasks: List[Union[TaskSet, Callable]] = [] tasks_frontier = u.tasks while len(tasks_frontier) != 0: t = tasks_frontier.pop() - if hasattr(t, "tasks") and t.tasks: + if not callable(t) and hasattr(t, "tasks") and t.tasks: tasks_frontier.extend(t.tasks) - elif isinstance(t, Callable): + elif callable(t): if t not in user_tasks: user_tasks.append(t) else: diff --git a/locust/runners.py b/locust/runners.py index abb6db4db9..4a52deec30 100644 --- a/locust/runners.py +++ b/locust/runners.py @@ -13,12 +13,23 @@ itemgetter, methodcaller, ) +from types import TracebackType from typing import ( + TYPE_CHECKING, Dict, Iterator, List, + NoReturn, Union, ValuesView, + TypedDict, + Set, + Callable, + Optional, + Tuple, + Type, + Any, + cast, ) from uuid import uuid4 @@ -38,10 +49,14 @@ ) from .stats import ( RequestStats, + StatsError, setup_distributed_stats_event_listeners, ) from . import argument_parser +if TYPE_CHECKING: + from .env import Environment + logger = logging.getLogger(__name__) @@ -66,6 +81,13 @@ greenlet_exception_handler = greenlet_exception_logger(logger) +class ExceptionDict(TypedDict): + count: int + msg: str + traceback: str + nodes: Set[str] + + class Runner: """ Orchestrates the load test by starting and stopping the users. @@ -77,33 +99,33 @@ class Runner: desired type. """ - def __init__(self, environment): + def __init__(self, environment: "Environment") -> None: self.environment = environment self.user_greenlets = Group() self.greenlet = Group() self.state = STATE_INIT - self.spawning_greenlet = None - self.shape_greenlet = None - self.shape_last_state = None - self.current_cpu_usage = 0 - self.cpu_warning_emitted = False - self.worker_cpu_warning_emitted = False - self.current_memory_usage = 0 + self.spawning_greenlet: Optional[gevent.Greenlet] = None + self.shape_greenlet: Optional[gevent.Greenlet] = None + self.shape_last_state: Optional[Tuple[int, float]] = None + self.current_cpu_usage: int = 0 + self.cpu_warning_emitted: bool = False + self.worker_cpu_warning_emitted: bool = False + self.current_memory_usage: int = 0 self.greenlet.spawn(self.monitor_cpu_and_memory).link_exception(greenlet_exception_handler) - self.exceptions = {} + self.exceptions: Dict[int, ExceptionDict] = {} # Because of the way the ramp-up/ramp-down is implemented, target_user_classes_count # is only updated at the end of the ramp-up/ramp-down. # See https://github.com/locustio/locust/issues/1883#issuecomment-919239824 for context. self.target_user_classes_count: Dict[str, int] = {} # target_user_count is set before the ramp-up/ramp-down occurs. 
self.target_user_count: int = 0 - self.custom_messages = {} + self.custom_messages: Dict[str, Callable[["Environment", Message], None]] = {} # Only when running in standalone mode (non-distributed) self._local_worker_node = WorkerNode(id="local") self._local_worker_node.user_classes_count = self.user_classes_count - self._users_dispatcher = None + self._users_dispatcher: Optional[UsersDispatcher] = None # set up event listeners for recording requests def on_request_success(request_type, name, response_time, response_length, **_kwargs): @@ -121,7 +143,7 @@ def on_request_failure(request_type, name, response_time, response_length, excep logging.getLogger().setLevel(loglevel) self.connection_broken = False - self.final_user_classes_count = {} # just for the ratio report, fills before runner stops + self.final_user_classes_count: Dict[str, int] = {} # just for the ratio report, fills before runner stops # register listener that resets stats when spawning is complete def on_spawning_complete(user_count): @@ -138,11 +160,11 @@ def __del__(self): self.greenlet.kill(block=False) @property - def user_classes(self): + def user_classes(self) -> List[Type[User]]: return self.environment.user_classes @property - def user_classes_by_name(self): + def user_classes_by_name(self) -> Dict[str, Type[User]]: return self.environment.user_classes_by_name @property @@ -150,11 +172,11 @@ def stats(self) -> RequestStats: return self.environment.stats @property - def errors(self): + def errors(self) -> Dict[str, StatsError]: return self.stats.errors @property - def user_count(self): + def user_count(self) -> int: """ :returns: Number of currently running users """ @@ -184,7 +206,7 @@ def user_classes_count(self) -> Dict[str, int]: user_classes_count[user.__class__.__name__] += 1 return user_classes_count - def update_state(self, new_state): + def update_state(self, new_state: str) -> None: """ Updates the current state """ @@ -210,9 +232,9 @@ def spawn_users(self, user_classes_spawn_count: Dict[str, int], wait: bool = Fal % (json.dumps(user_classes_spawn_count), json.dumps(self.user_classes_count)) ) - def spawn(user_class: str, spawn_count: int): + def spawn(user_class: str, spawn_count: int) -> List[User]: n = 0 - new_users = [] + new_users: List[User] = [] while n < spawn_count: new_user = self.user_classes_by_name[user_class](self.environment) new_user.start(self.user_greenlets) @@ -223,7 +245,7 @@ def spawn(user_class: str, spawn_count: int): logger.debug(f"All users of class {user_class} spawned") return new_users - new_users = [] + new_users: List[User] = [] for user_class, spawn_count in user_classes_spawn_count.items(): new_users += spawn(user_class, spawn_count) @@ -232,7 +254,7 @@ def spawn(user_class: str, spawn_count: int): logger.info("All users stopped\n") return new_users - def stop_users(self, user_classes_stop_count: Dict[str, int]): + def stop_users(self, user_classes_stop_count: Dict[str, int]) -> None: async_calls_to_stop = Group() stop_group = Group() @@ -284,7 +306,7 @@ def stop_users(self, user_classes_stop_count: Dict[str, int]): "%g users have been stopped, %g still running", sum(user_classes_stop_count.values()), self.user_count ) - def monitor_cpu_and_memory(self): + def monitor_cpu_and_memory(self) -> NoReturn: process = psutil.Process() while True: self.current_cpu_usage = process.cpu_percent() @@ -298,7 +320,7 @@ def monitor_cpu_and_memory(self): self.cpu_warning_emitted = True gevent.sleep(CPU_MONITOR_INTERVAL) - def start(self, user_count: int, spawn_rate: float, wait: bool = 
False): + def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: """ Start running a load test @@ -373,7 +395,7 @@ def start(self, user_count: int, spawn_rate: float, wait: bool = False): self.environment.events.spawning_complete.fire(user_count=sum(self.target_user_classes_count.values())) - def start_shape(self): + def start_shape(self) -> None: """ Start running a load test with a custom LoadTestShape specified in the :meth:`Environment.shape_class ` parameter. """ @@ -387,7 +409,7 @@ def start_shape(self): self.shape_greenlet.link_exception(greenlet_exception_handler) self.environment.shape_class.reset_time() - def shape_worker(self): + def shape_worker(self) -> None: logger.info("Shape worker starting") while self.state == STATE_INIT or self.state == STATE_SPAWNING or self.state == STATE_RUNNING: new_state = self.environment.shape_class.tick() @@ -421,7 +443,7 @@ def shape_worker(self): self.start(user_count=user_count, spawn_rate=spawn_rate) self.shape_last_state = new_state - def stop(self): + def stop(self) -> None: """ Stop a running load test by stopping all running users """ @@ -451,21 +473,21 @@ def stop(self): self.cpu_log_warning() self.environment.events.test_stop.fire(environment=self.environment) - def quit(self): + def quit(self) -> None: """ Stop any running load test and kill all greenlets for the runner """ self.stop() self.greenlet.kill(block=True) - def log_exception(self, node_id, msg, formatted_tb): + def log_exception(self, node_id: str, msg: str, formatted_tb: str) -> None: key = hash(formatted_tb) row = self.exceptions.setdefault(key, {"count": 0, "msg": msg, "traceback": formatted_tb, "nodes": set()}) row["count"] += 1 row["nodes"].add(node_id) self.exceptions[key] = row - def register_message(self, msg_type, listener): + def register_message(self, msg_type: str, listener: Callable[["Environment", Message], None]) -> None: """ Register a listener for a custom message from another node @@ -480,7 +502,7 @@ class LocalRunner(Runner): Runner for running single process load test """ - def __init__(self, environment): + def __init__(self, environment) -> None: """ :param environment: Environment instance """ @@ -493,7 +515,7 @@ def on_user_error(user_instance, exception, tb): self.environment.events.user_error.add_listener(on_user_error) - def start(self, user_count: int, spawn_rate: float, wait: bool = False): + def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: if spawn_rate > 100: logger.warning( "Your selected spawn rate is very high (>100), and this is known to sometimes cause issues. Do you really need to ramp up that fast?" 
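The ExceptionDict TypedDict declared near the top of this runners.py diff is what lets mypy check the per-traceback rows that Runner.log_exception builds with setdefault. A minimal self-contained sketch of that pattern, separate from the patch (names mirror the patch; the node id and traceback strings are invented):

    from typing import Dict, Set, TypedDict

    class ExceptionDict(TypedDict):
        count: int
        msg: str
        traceback: str
        nodes: Set[str]

    exceptions: Dict[int, ExceptionDict] = {}

    def log_exception(node_id: str, msg: str, formatted_tb: str) -> None:
        key = hash(formatted_tb)
        # one row per unique traceback, as in Runner.log_exception above
        row = exceptions.setdefault(key, {"count": 0, "msg": msg, "traceback": formatted_tb, "nodes": set()})
        row["count"] += 1
        row["nodes"].add(node_id)

    log_exception("worker-1", "division by zero", "Traceback (most recent call last): ...")

Before the patch, self.exceptions was an untyped dict, so a misspelled key such as row["cout"] only failed at runtime; with the TypedDict annotation the type checker rejects it.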
@@ -507,12 +529,12 @@ def start(self, user_count: int, spawn_rate: float, wait: bool = False): ) self.spawning_greenlet.link_exception(greenlet_exception_handler) - def stop(self): + def stop(self) -> None: if self.state == STATE_STOPPED: return super().stop() - def send_message(self, msg_type, data=None): + def send_message(self, msg_type: str, data: Optional[Any] = None) -> None: """ Emulates internodal messaging by calling registered listeners @@ -529,20 +551,20 @@ def send_message(self, msg_type, data=None): class DistributedRunner(Runner): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._local_worker_node = None setup_distributed_stats_event_listeners(self.environment.events, self.stats) class WorkerNode: - def __init__(self, id: str, state=STATE_INIT, heartbeat_liveness=HEARTBEAT_LIVENESS): + def __init__(self, id: str, state=STATE_INIT, heartbeat_liveness=HEARTBEAT_LIVENESS) -> None: self.id: str = id self.state = state self.heartbeat = heartbeat_liveness - self.cpu_usage = 0 + self.cpu_usage: int = 0 self.cpu_warning_emitted = False - self.memory_usage = 0 + self.memory_usage: int = 0 # The reported users running on the worker self.user_classes_count: Dict[str, int] = {} @@ -553,7 +575,7 @@ def user_count(self) -> int: class WorkerNodes(MutableMapping): def __init__(self): - self._worker_nodes = {} + self._worker_nodes: Dict[str, WorkerNode] = {} def get_by_state(self, state) -> List[WorkerNode]: return [c for c in self.values() if c.state == state] @@ -614,7 +636,7 @@ def __init__(self, environment, master_bind_host, master_bind_port): self.worker_cpu_warning_emitted = False self.master_bind_host = master_bind_host self.master_bind_port = master_bind_port - self.spawn_rate: float = 0 + self.spawn_rate: float = 0.0 self.spawning_completed = False self.clients = WorkerNodes() @@ -632,13 +654,13 @@ def __init__(self, environment, master_bind_host, master_bind_port): else: raise - self._users_dispatcher: Union[UsersDispatcher, None] = None + self._users_dispatcher: Optional[UsersDispatcher] = None self.greenlet.spawn(self.heartbeat_worker).link_exception(greenlet_exception_handler) self.greenlet.spawn(self.client_listener).link_exception(greenlet_exception_handler) # listener that gathers info on how many users the worker has spawned - def on_worker_report(client_id, data): + def on_worker_report(client_id: str, data: Dict[str, Any]) -> None: if client_id not in self.clients: logger.info("Discarded report from unrecognized worker %s", client_id) return @@ -647,19 +669,21 @@ def on_worker_report(client_id, data): self.environment.events.worker_report.add_listener(on_worker_report) # register listener that sends quit message to worker nodes - def on_quitting(environment, **kw): + def on_quitting(environment: "Environment", **kw): self.quit() self.environment.events.quitting.add_listener(on_quitting) def rebalancing_enabled(self) -> bool: - return self.environment.parsed_options and self.environment.parsed_options.enable_rebalancing + return self.environment.parsed_options is not None and cast( + bool, self.environment.parsed_options.enable_rebalancing + ) @property def user_count(self) -> int: return sum(c.user_count for c in self.clients.values()) - def cpu_log_warning(self): + def cpu_log_warning(self) -> bool: warning_emitted = Runner.cpu_log_warning(self) if self.worker_cpu_warning_emitted: logger.warning("CPU usage threshold was exceeded on workers during the test!") @@ -795,14 +819,18 @@ def 
_wait_for_workers_report_after_ramp_up(self) -> float:
        else:
            return float(match.group("coeff")) * WORKER_REPORT_INTERVAL

-    def stop(self, send_stop_to_client: bool = True):
+    def stop(self, send_stop_to_client: bool = True) -> None:
        if self.state not in [STATE_INIT, STATE_STOPPED, STATE_STOPPING]:
            logger.debug("Stopping...")
            self.environment.events.test_stopping.fire(environment=self.environment)
            self.final_user_classes_count = {**self.reported_user_classes_count}
            self.update_state(STATE_STOPPING)

-            if self.environment.shape_class is not None and self.shape_greenlet is not greenlet.getcurrent():
+            if (
+                self.environment.shape_class is not None
+                and self.shape_greenlet is not None
+                and self.shape_greenlet is not greenlet.getcurrent()
+            ):
                self.shape_greenlet.kill(block=True)
                self.shape_greenlet = None
                self.shape_last_state = None
@@ -826,7 +854,7 @@ def stop(self, send_stop_to_client: bool = True):
                    timeout.cancel()
        self.environment.events.test_stop.fire(environment=self.environment)

-    def quit(self):
+    def quit(self) -> None:
        self.stop(send_stop_to_client=False)
        logger.debug("Quitting...")
        for client in self.clients.all:
@@ -835,7 +863,7 @@ def quit(self):
        gevent.sleep(0.5)  # wait for final stats report from all workers
        self.greenlet.kill(block=True)

-    def check_stopped(self):
+    def check_stopped(self) -> None:
        if (
            not self.state == STATE_INIT
            and not self.state == STATE_STOPPED
@@ -843,7 +871,7 @@ def check_stopped(self):
        ):
            self.update_state(STATE_STOPPED)

-    def heartbeat_worker(self):
+    def heartbeat_worker(self) -> NoReturn:
        while True:
            gevent.sleep(HEARTBEAT_INTERVAL)
            if self.connection_broken:
@@ -882,7 +910,7 @@ def heartbeat_worker(self):
                    # trigger redistribution after missing client removal
                    self.start(user_count=self.target_user_count, spawn_rate=self.spawn_rate)

-    def reset_connection(self):
+    def reset_connection(self) -> None:
        logger.info("Reset connection to worker")
        try:
            self.server.close()
@@ -891,7 +919,7 @@ def reset_connection(self):
        except RPCError as e:
            logger.error(f"Temporary failure when resetting connection: {e}, will retry later.")

-    def client_listener(self):
+    def client_listener(self) -> NoReturn:
        while True:
            try:
                client_id, msg = self.server.recv_from_client()
@@ -1003,7 +1031,7 @@ def client_listener(self):
                self.check_stopped()

    @property
-    def worker_count(self):
+    def worker_count(self) -> int:
        return len(self.clients.ready) + len(self.clients.spawning) + len(self.clients.running)

    @property
@@ -1014,7 +1042,7 @@ def reported_user_classes_count(self) -> Dict[str, int]:
                reported_user_classes_count[name] += count
        return reported_user_classes_count

-    def send_message(self, msg_type, data=None, client_id=None):
+    def send_message(self, msg_type: str, data: Optional[Dict[str, Any]] = None, client_id: Optional[str] = None):
        """
        Sends a message to attached worker node(s)

@@ -1041,7 +1069,7 @@ class WorkerRunner(DistributedRunner):
    take the stats generated by the running users and send back to the :class:`MasterRunner`.
""" - def __init__(self, environment, master_host, master_port): + def __init__(self, environment: "Environment", master_host: str, master_port: int) -> None: """ :param environment: Environment instance :param master_host: Host/IP to use for connection to the master @@ -1053,7 +1081,7 @@ def __init__(self, environment, master_host, master_port): self.master_host = master_host self.master_port = master_port self.worker_cpu_warning_emitted = False - self._users_dispatcher = None + self._users_dispatcher: Optional[UsersDispatcher] = None self.client = rpc.Client(master_host, master_port, self.client_id) self.greenlet.spawn(self.heartbeat).link_exception(greenlet_exception_handler) self.greenlet.spawn(self.worker).link_exception(greenlet_exception_handler) @@ -1061,7 +1089,7 @@ def __init__(self, environment, master_host, master_port): self.greenlet.spawn(self.stats_reporter).link_exception(greenlet_exception_handler) # register listener for when all users have spawned, and report it to the master node - def on_spawning_complete(user_count): + def on_spawning_complete(user_count: int) -> None: assert user_count == sum(self.user_classes_count.values()) self.client.send( Message( @@ -1075,29 +1103,29 @@ def on_spawning_complete(user_count): self.environment.events.spawning_complete.add_listener(on_spawning_complete) # register listener that adds the current number of spawned users to the report that is sent to the master node - def on_report_to_master(client_id, data): + def on_report_to_master(client_id: str, data: Dict[str, Any]): data["user_classes_count"] = self.user_classes_count data["user_count"] = self.user_count self.environment.events.report_to_master.add_listener(on_report_to_master) # register listener that sends quit message to master - def on_quitting(environment, **kw): + def on_quitting(environment: "Environment", **kw) -> None: self.client.send(Message("quit", None, self.client_id)) self.environment.events.quitting.add_listener(on_quitting) # register listener that's sends user exceptions to master - def on_user_error(user_instance, exception, tb): + def on_user_error(user_instance: User, exception: Exception, tb: TracebackType) -> None: formatted_tb = "".join(traceback.format_tb(tb)) self.client.send(Message("exception", {"msg": str(exception), "traceback": formatted_tb}, self.client_id)) self.environment.events.user_error.add_listener(on_user_error) - def start(self, user_count, spawn_rate, wait=False): + def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: raise NotImplementedError("use start_worker") - def start_worker(self, user_classes_count: Dict[str, int], **kwargs): + def start_worker(self, user_classes_count: Dict[str, int], **kwargs) -> None: """ Start running a load test as a worker @@ -1110,8 +1138,8 @@ def start_worker(self, user_classes_count: Dict[str, int], **kwargs): if self.environment.host: user_class.host = self.environment.host - user_classes_spawn_count = {} - user_classes_stop_count = {} + user_classes_spawn_count: Dict[Type[User], int] = {} + user_classes_stop_count: Dict[Type[User], int] = {} for user_class, user_class_count in user_classes_count.items(): if self.user_classes_count[user_class] > user_class_count: @@ -1126,7 +1154,7 @@ def start_worker(self, user_classes_count: Dict[str, int], **kwargs): self.environment.events.spawning_complete.fire(user_count=sum(self.user_classes_count.values())) - def heartbeat(self): + def heartbeat(self) -> NoReturn: while True: try: self.client.send( @@ -1145,7 +1173,7 @@ def 
heartbeat(self): self.reset_connection() gevent.sleep(HEARTBEAT_INTERVAL) - def reset_connection(self): + def reset_connection(self) -> None: logger.info("Reset connection to master") try: self.client.close() @@ -1153,7 +1181,7 @@ def reset_connection(self): except RPCError as e: logger.error(f"Temporary failure when resetting connection: {e}, will retry later.") - def worker(self): + def worker(self) -> NoReturn: last_received_spawn_timestamp = 0 while True: try: @@ -1222,7 +1250,7 @@ def worker(self): else: logger.warning(f"Unknown message type received: {msg.type}") - def stats_reporter(self): + def stats_reporter(self) -> NoReturn: while True: try: self._send_stats() @@ -1230,7 +1258,7 @@ def stats_reporter(self): logger.error(f"Temporary connection lost to master server: {e}, will retry later.") gevent.sleep(WORKER_REPORT_INTERVAL) - def send_message(self, msg_type, data=None): + def send_message(self, msg_type: str, data: Optional[Dict[str, Any]] = None) -> None: """ Sends a message to master node @@ -1240,7 +1268,7 @@ def send_message(self, msg_type, data=None): logger.debug(f"Sending {msg_type} message to master") self.client.send(Message(msg_type, data, self.client_id)) - def _send_stats(self): + def _send_stats(self) -> None: data = {} self.environment.events.report_to_master.fire(client_id=self.client_id, data=data) self.client.send(Message("stats", data, self.client_id)) diff --git a/locust/stats.py b/locust/stats.py index 9d2390eddc..e865475799 100644 --- a/locust/stats.py +++ b/locust/stats.py @@ -1,5 +1,6 @@ import datetime import hashlib +from tempfile import NamedTemporaryFile import time from collections import namedtuple, OrderedDict from copy import copy @@ -8,14 +9,40 @@ import csv import signal import gevent -from typing import Dict, Tuple + +from typing import ( + TYPE_CHECKING, + Any, + Dict, + NoReturn, + Tuple, + List, + Union, + TypedDict, + Optional, + OrderedDict as OrderedDictType, + Callable, + TypeVar, + cast, +) +from types import FrameType from .exception import CatchResponseError +from .event import Events import logging +with NamedTemporaryFile(mode="w") as t: + CSVWriter = type(csv.writer(t)) + +if TYPE_CHECKING: + from .runners import Runner + from .env import Environment + console_logger = logging.getLogger("locust.stats_logger") +S = TypeVar("S", bound="StatsBase") + """Space in table for request name. 
Auto shrink it if terminal is small (<160 characters)"""
try:
    STATS_NAME_WIDTH = max(min(os.get_terminal_size()[0] - 80, 80), 0)
@@ -25,7 +52,38 @@
STATS_AUTORESIZE = True  # overwrite this if you don't want auto resize while running


-def resize_handler(signum, frame):
+class StatsBaseDict(TypedDict):
+    name: str
+    method: str
+
+
+class StatsEntryDict(StatsBaseDict):
+    last_request_timestamp: Optional[float]
+    start_time: float
+    num_requests: int
+    num_none_requests: int
+    num_failures: int
+    total_response_time: int
+    max_response_time: int
+    min_response_time: Optional[int]
+    total_content_length: int
+    response_times: Dict[int, int]
+    num_reqs_per_sec: Dict[int, int]
+    num_fail_per_sec: Dict[int, int]
+
+
+class StatsErrorDict(StatsBaseDict):
+    error: str
+    occurrences: int
+
+
+class StatsBase:
+    def __init__(self, name: str, method: str) -> None:
+        self.name = name
+        self.method = method
+
+
+def resize_handler(signum: int, frame: Optional[FrameType]):
    global STATS_NAME_WIDTH
    if STATS_AUTORESIZE:
        try:
@@ -68,7 +126,7 @@ class RequestStatsAdditionError(Exception):
    pass


-def get_readable_percentiles(percentile_list):
+def get_readable_percentiles(percentile_list: List[float]) -> List[str]:
    """
    Converts a list of percentiles from 0-1 fraction to 0%-100% view for using in console & csv reporting
    :param percentile_list: The list of percentiles in range 0-1
@@ -80,7 +138,7 @@ def get_readable_percentiles(percentile_list):
    ]


-def calculate_response_time_percentile(response_times, num_requests, percent):
+def calculate_response_time_percentile(response_times: Dict[int, int], num_requests: int, percent: float) -> int:
    """
    Get the response time that a certain number of percent of the requests finished within.
    Arguments:
@@ -101,7 +159,7 @@ def calculate_response_time_percentile(response_times, num_requests, percent):
    return 0


-def diff_response_time_dicts(latest, old):
+def diff_response_time_dicts(latest: Dict[int, int], old: Dict[int, int]) -> Dict[int, int]:
    """
    Returns the delta between two {response_times:request_count} dicts.
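The StatsBaseDict/StatsEntryDict/StatsErrorDict hierarchy above does double duty later in this patch: serialize() and to_dict() iterate the TypedDict's __annotations__ instead of maintaining a hand-written key list. A small sketch of that technique, trimmed to three fields (the real StatsEntryDict carries all the fields listed above):

    from typing import TypedDict, cast

    class StatsEntryDict(TypedDict):
        name: str
        method: str
        num_requests: int

    class StatsEntry:
        def __init__(self, name: str, method: str) -> None:
            self.name = name
            self.method = method
            self.num_requests = 0

        def serialize(self) -> StatsEntryDict:
            # the TypedDict's declared keys drive serialization, so adding a field
            # in one place keeps the dict and the class in sync
            return cast(StatsEntryDict, {key: getattr(self, key, None) for key in StatsEntryDict.__annotations__.keys()})

    print(StatsEntry("/login", "GET").serialize())

The cast is needed because mypy cannot prove that a dict comprehension produces exactly the TypedDict's shape; it is an assertion to the checker, not a runtime conversion.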
@@ -155,11 +213,11 @@ def last_request_timestamp(self):
    def start_time(self):
        return self.total.start_time

-    def log_request(self, method, name, response_time, content_length):
+    def log_request(self, method: str, name: str, response_time: int, content_length: int) -> None:
        self.total.log(response_time, content_length)
        self.get(name, method).log(response_time, content_length)

-    def log_error(self, method, name, error):
+    def log_error(self, method: str, name: str, error: Optional[Union[Exception, str]]) -> None:
        self.total.log_error(error)
        self.get(name, method).log_error(error)

@@ -171,7 +229,7 @@ def log_error(self, method, name, error):
            self.errors[key] = entry
        entry.occurred()

-    def get(self, name, method):
+    def get(self, name: str, method: str) -> "StatsEntry":
        """
        Retrieve a StatsEntry instance by name and method
        """
@@ -181,7 +239,7 @@ def get(self, name, method):
            self.entries[(name, method)] = entry
        return entry

-    def reset_all(self):
+    def reset_all(self) -> None:
        """
        Go through all stats entries and reset them to zero
        """
@@ -191,32 +249,33 @@ def reset_all(self):
            r.reset()
        self.history = []

-    def clear_all(self):
+    def clear_all(self) -> None:
        """
        Remove all stats entries and errors
        """
-        self.total = StatsEntry(self, "Aggregated", None, use_response_times_cache=self.use_response_times_cache)
+        self.total = StatsEntry(self, "Aggregated", "", use_response_times_cache=self.use_response_times_cache)
        self.entries = {}
        self.errors = {}
        self.history = []

-    def serialize_stats(self):
+    def serialize_stats(self) -> List["StatsEntryDict"]:
        return [
            self.entries[key].get_stripped_report()
            for key in self.entries.keys()
            if not (self.entries[key].num_requests == 0 and self.entries[key].num_failures == 0)
        ]

-    def serialize_errors(self):
+    def serialize_errors(self) -> Dict[str, "StatsErrorDict"]:
        return {k: e.to_dict() for k, e in self.errors.items()}


-class StatsEntry:
+class StatsEntry(StatsBase):
    """
    Represents a single stats entry (name and method)
    """

-    def __init__(self, stats: RequestStats, name: str, method: str, use_response_times_cache=False):
+    def __init__(self, stats: Optional[RequestStats], name: str, method: str, use_response_times_cache: bool = False):
+        super().__init__(name, method)
        self.stats = stats
        self.name = name
        """ Name (URL) of this stats entry """
@@ -229,17 +288,17 @@ def __init__(self, stats: RequestStats, name: str, method: str, use_response_tim
        We can use this dict to calculate the *current* median response time, as well as other response
        time percentiles.
        """
-        self.num_requests = 0
+        self.num_requests: int = 0
        """ The number of requests made """
-        self.num_none_requests = 0
+        self.num_none_requests: int = 0
        """ The number of requests made with a None response time (typically async requests) """
-        self.num_failures = 0
+        self.num_failures: int = 0
        """ Number of failed requests """
-        self.total_response_time = 0
+        self.total_response_time: int = 0
        """ Total sum of the response times """
-        self.min_response_time = None
+        self.min_response_time: Optional[int] = None
        """ Minimum response time """
-        self.max_response_time = 0
+        self.max_response_time: int = 0
        """ Maximum response time """
        self.num_reqs_per_sec: Dict[int, int] = {}
        """ A {second => request_count} dict that holds the number of requests made per second """
@@ -255,16 +314,16 @@ def __init__(self, stats: RequestStats, name: str, method: str, use_response_tim
        This dict is used to calculate the median and percentile response times.
""" - self.response_times_cache = None + self.response_times_cache: OrderedDictType[int, CachedResponseTimes] """ If use_response_times_cache is set to True, this will be a {timestamp => CachedResponseTimes()} OrderedDict that holds a copy of the response_times dict for each of the last 20 seconds. """ - self.total_content_length = 0 + self.total_content_length: int = 0 """ The sum of the content length of all the requests for this entry """ - self.start_time = 0.0 + self.start_time: float = 0.0 """ Time of the first request for this entry """ - self.last_request_timestamp = None + self.last_request_timestamp: Optional[float] = None """ Time of the last request for this entry """ self.reset() @@ -285,7 +344,7 @@ def reset(self): self.response_times_cache = OrderedDict() self._cache_response_times(int(time.time())) - def log(self, response_time, content_length): + def log(self, response_time: int, content_length: int) -> None: # get the time current_time = time.time() t = int(current_time) @@ -301,12 +360,12 @@ def log(self, response_time, content_length): # increase total content-length self.total_content_length += content_length - def _log_time_of_request(self, current_time): + def _log_time_of_request(self, current_time: float) -> None: t = int(current_time) self.num_reqs_per_sec[t] = self.num_reqs_per_sec.setdefault(t, 0) + 1 self.last_request_timestamp = current_time - def _log_response_time(self, response_time): + def _log_response_time(self, response_time: int) -> None: if response_time is None: self.num_none_requests += 1 return @@ -335,13 +394,13 @@ def _log_response_time(self, response_time): self.response_times.setdefault(rounded_response_time, 0) self.response_times[rounded_response_time] += 1 - def log_error(self, error): + def log_error(self, error: Optional[Union[Exception, str]]) -> None: self.num_failures += 1 t = int(time.time()) self.num_fail_per_sec[t] = self.num_fail_per_sec.setdefault(t, 0) + 1 @property - def fail_ratio(self): + def fail_ratio(self) -> float: try: return float(self.num_failures) / self.num_requests except ZeroDivisionError: @@ -351,14 +410,14 @@ def fail_ratio(self): return 0.0 @property - def avg_response_time(self): + def avg_response_time(self) -> float: try: return float(self.total_response_time) / (self.num_requests - self.num_none_requests) except ZeroDivisionError: - return 0 + return 0.0 @property - def median_response_time(self): + def median_response_time(self) -> int: if not self.response_times: return 0 median = median_from_dict(self.num_requests - self.num_none_requests, self.response_times) or 0 @@ -369,18 +428,18 @@ def median_response_time(self): # have one (or very few) really slow requests if median > self.max_response_time: median = self.max_response_time - elif median < self.min_response_time: + elif self.min_response_time is not None and median < self.min_response_time: median = self.min_response_time return median @property - def current_rps(self): - if self.stats.last_request_timestamp is None: + def current_rps(self) -> float: + if self.stats is None or self.stats.last_request_timestamp is None: return 0 slice_start_time = max(int(self.stats.last_request_timestamp) - 12, int(self.stats.start_time or 0)) - reqs = [ + reqs: List[Union[int, float]] = [ self.num_reqs_per_sec.get(t, 0) for t in range(slice_start_time, int(self.stats.last_request_timestamp) - 2) ] return avg(reqs) @@ -421,7 +480,7 @@ def avg_content_length(self): except ZeroDivisionError: return 0 - def extend(self, other): + def extend(self, other: "StatsEntry") -> 
None:
        """
        Extend the data from the current StatsEntry with the stats from another
        StatsEntry instance.
@@ -461,49 +520,27 @@ def extend(self, other):
        # lag behind a second or two, but since StatsEntry.current_response_time_percentile()
        # (which is what the response times cache is used for) uses an approximation of the
        # last 10 seconds anyway, it should be fine to ignore this.
-        last_time = self.last_request_timestamp and int(self.last_request_timestamp) or None
+        last_time = int(self.last_request_timestamp) if self.last_request_timestamp else None
        if last_time and last_time > (old_last_request_timestamp and int(old_last_request_timestamp) or 0):
            self._cache_response_times(last_time)

-    def serialize(self):
-        return {
-            "name": self.name,
-            "method": self.method,
-            "last_request_timestamp": self.last_request_timestamp,
-            "start_time": self.start_time,
-            "num_requests": self.num_requests,
-            "num_none_requests": self.num_none_requests,
-            "num_failures": self.num_failures,
-            "total_response_time": self.total_response_time,
-            "max_response_time": self.max_response_time,
-            "min_response_time": self.min_response_time,
-            "total_content_length": self.total_content_length,
-            "response_times": self.response_times,
-            "num_reqs_per_sec": self.num_reqs_per_sec,
-            "num_fail_per_sec": self.num_fail_per_sec,
-        }
+    def serialize(self) -> StatsEntryDict:
+        return cast(StatsEntryDict, {key: getattr(self, key, None) for key in StatsEntryDict.__annotations__.keys()})

    @classmethod
-    def unserialize(cls, data):
+    def unserialize(cls, data: StatsEntryDict) -> "StatsEntry":
+        """Return the unserialized version of the specified dict"""
        obj = cls(None, data["name"], data["method"])
-        for key in [
-            "last_request_timestamp",
-            "start_time",
-            "num_requests",
-            "num_none_requests",
-            "num_failures",
-            "total_response_time",
-            "max_response_time",
-            "min_response_time",
-            "total_content_length",
-            "response_times",
-            "num_reqs_per_sec",
-            "num_fail_per_sec",
-        ]:
-            setattr(obj, key, data[key])
+        valid_keys = StatsEntryDict.__annotations__.keys()
+
+        for key, value in data.items():
+            if key in ["name", "method"] or key not in valid_keys:
+                continue
+
+            setattr(obj, key, value)
        return obj

-    def get_stripped_report(self):
+    def get_stripped_report(self) -> StatsEntryDict:
        """
        Return the serialized version of this StatsEntry, and then clear the current stats.
        """
@@ -511,7 +548,7 @@ def get_stripped_report(self):
        self.reset()
        return report

-    def to_string(self, current=True):
+    def to_string(self, current=True) -> str:
        """
        Return the stats as a string suitable for console output. If current is True, it'll show
        the RPS and failure rate for the last 10 seconds. If it's false, it'll show the total stats
@@ -542,10 +579,10 @@ def to_string(self, current=True):
            fail_per_sec or 0,
        )

-    def __str__(self):
+    def __str__(self) -> str:
        return self.to_string(current=True)

-    def get_response_time_percentile(self, percent):
+    def get_response_time_percentile(self, percent: float) -> int:
        """
        Get the response time that a certain number of percent of the requests finished within.

@@ -554,7 +591,7 @@ def get_response_time_percentile(self, percent):
        """
        return calculate_response_time_percentile(self.response_times, self.num_requests, percent)

-    def get_current_response_time_percentile(self, percent):
+    def get_current_response_time_percentile(self, percent: float) -> Optional[int]:
        """
        Calculate the *current* response time for a certain percentile.
We use a sliding window of (approximately) the last 10 seconds (specified by CURRENT_RESPONSE_TIME_PERCENTILE_WINDOW) @@ -572,13 +609,13 @@ def get_current_response_time_percentile(self, percent): # when trying to fetch the cached response_times. We construct this list in such a way # that it's ordered by preference by starting to add t-10, then t-11, t-9, t-12, t-8, # and so on - acceptable_timestamps = [] + acceptable_timestamps: List[int] = [] acceptable_timestamps.append(t - CURRENT_RESPONSE_TIME_PERCENTILE_WINDOW) for i in range(1, 9): acceptable_timestamps.append(t - CURRENT_RESPONSE_TIME_PERCENTILE_WINDOW - i) acceptable_timestamps.append(t - CURRENT_RESPONSE_TIME_PERCENTILE_WINDOW + i) - cached = None + cached: Optional[CachedResponseTimes] = None for ts in acceptable_timestamps: if ts in self.response_times_cache: cached = self.response_times_cache[ts] @@ -597,7 +634,7 @@ def get_current_response_time_percentile(self, percent): # if time was not in response times cache window return None - def percentile(self): + def percentile(self) -> str: if not self.num_requests: raise ValueError("Can't calculate percentile on url with no successful requests") @@ -609,7 +646,7 @@ def percentile(self): + (self.num_requests,) ) - def _cache_response_times(self, t): + def _cache_response_times(self, t: int) -> None: self.response_times_cache[t] = CachedResponseTimes( response_times=copy(self.response_times), num_requests=self.num_requests, @@ -623,19 +660,20 @@ def _cache_response_times(self, t): if len(self.response_times_cache) > cache_size: # only keep the latest 20 response_times dicts - for i in range(len(self.response_times_cache) - cache_size): + for _ in range(len(self.response_times_cache) - cache_size): self.response_times_cache.popitem(last=False) -class StatsError: - def __init__(self, method, name, error, occurrences=0): +class StatsError(StatsBase): + def __init__(self, method: str, name: str, error: Optional[Union[Exception, str]], occurrences: int = 0): + super().__init__(name, method) self.method = method self.name = name self.error = error self.occurrences = occurrences @classmethod - def parse_error(cls, error): + def parse_error(cls, error: Optional[Union[Exception, str]]) -> str: string_error = repr(error) target = "object at 0x" target_index = string_error.find(target) @@ -649,14 +687,14 @@ def parse_error(cls, error): return string_error.replace(hex_address, "0x....") @classmethod - def create_key(cls, method, name, error): + def create_key(cls, method: str, name: str, error: Optional[Union[Exception, str]]) -> str: key = f"{method}.{name}.{StatsError.parse_error(error)!r}" return hashlib.md5(key.encode("utf-8")).hexdigest() - def occurred(self): + def occurred(self) -> None: self.occurrences += 1 - def to_name(self): + def to_name(self) -> str: error = self.error if isinstance(error, CatchResponseError): # standalone @@ -671,24 +709,19 @@ def to_name(self): return f"{self.method} {self.name}: {unwrapped_error}" - def to_dict(self): - return { - "method": self.method, - "name": self.name, - "error": StatsError.parse_error(self.error), - "occurrences": self.occurrences, - } + def to_dict(self) -> StatsErrorDict: + return cast(StatsErrorDict, {key: getattr(self, key, None) for key in StatsErrorDict.__annotations__.keys()}) @classmethod - def from_dict(cls, data): + def from_dict(cls, data: StatsErrorDict) -> "StatsError": return cls(data["method"], data["name"], data["error"], data["occurrences"]) -def avg(values): +def avg(values: List[Union[float, int]]) -> float: return 
sum(values, 0.0) / max(len(values), 1) -def median_from_dict(total, count): +def median_from_dict(total: int, count: Dict[int, int]) -> int: """ total is the number of requests made count is a dict {response_time: count} @@ -699,15 +732,17 @@ def median_from_dict(total, count): return k pos -= count[k] + return k + -def setup_distributed_stats_event_listeners(events, stats): - def on_report_to_master(client_id, data): +def setup_distributed_stats_event_listeners(events: Events, stats: RequestStats) -> None: + def on_report_to_master(client_id: str, data: Dict[str, Any]) -> None: data["stats"] = stats.serialize_stats() data["stats_total"] = stats.total.get_stripped_report() data["errors"] = stats.serialize_errors() stats.errors = {} - def on_worker_report(client_id, data): + def on_worker_report(client_id: str, data: Dict[str, Any]) -> None: for stats_data in data["stats"]: entry = StatsEntry.unserialize(stats_data) request_key = (entry.name, entry.method) @@ -727,7 +762,7 @@ def on_worker_report(client_id, data): events.worker_report.add_listener(on_worker_report) -def print_stats(stats, current=True): +def print_stats(stats: RequestStats, current=True) -> None: name_column_width = (STATS_NAME_WIDTH - STATS_TYPE_WIDTH) + 4 # saved characters by compacting other columns console_logger.info( ("%-" + str(STATS_TYPE_WIDTH) + "s %-" + str(name_column_width) + "s %7s %12s |%7s %7s %7s%7s | %7s %11s") @@ -743,7 +778,7 @@ def print_stats(stats, current=True): console_logger.info("") -def print_percentile_stats(stats): +def print_percentile_stats(stats: RequestStats) -> None: console_logger.info("Response time percentiles (approximated)") headers = ("Type", "Name") + tuple(get_readable_percentiles(PERCENTILES_TO_REPORT)) + ("# reqs",) console_logger.info( @@ -768,7 +803,7 @@ def print_percentile_stats(stats): console_logger.info("") -def print_error_report(stats): +def print_error_report(stats: RequestStats) -> None: if not len(stats.errors): return console_logger.info("Error report") @@ -781,8 +816,8 @@ def print_error_report(stats): console_logger.info("") -def stats_printer(stats): - def stats_printer_func(): +def stats_printer(stats: RequestStats) -> Callable[[], None]: + def stats_printer_func() -> None: while True: print_stats(stats) gevent.sleep(CONSOLE_STATS_INTERVAL_SEC) @@ -790,11 +825,11 @@ def stats_printer_func(): return stats_printer_func -def sort_stats(stats): +def sort_stats(stats: Dict[Any, S]) -> List[S]: return [stats[key] for key in sorted(stats.keys())] -def stats_history(runner): +def stats_history(runner: "Runner") -> None: """Save current stats info to history for charts of report.""" while True: stats = runner.stats @@ -816,8 +851,7 @@ def stats_history(runner): class StatsCSV: """Write statistics to csv_writer stream.""" - def __init__(self, environment, percentiles_to_report): - super().__init__() + def __init__(self, environment: "Environment", percentiles_to_report: List[float]) -> None: self.environment = environment self.percentiles_to_report = percentiles_to_report @@ -851,7 +885,7 @@ def __init__(self, environment, percentiles_to_report): "Nodes", ] - def _percentile_fields(self, stats_entry, use_current=False): + def _percentile_fields(self, stats_entry: StatsEntry, use_current: bool = False) -> Union[List[str], List[int]]: if not stats_entry.num_requests: return self.percentiles_na elif use_current: @@ -859,12 +893,12 @@ def _percentile_fields(self, stats_entry, use_current=False): else: return [int(stats_entry.get_response_time_percentile(x) or 0) for x in 
self.percentiles_to_report]

-    def requests_csv(self, csv_writer):
+    def requests_csv(self, csv_writer: CSVWriter) -> None:
        """Write requests csv with header and data rows."""
        csv_writer.writerow(self.requests_csv_columns)
        self._requests_data_rows(csv_writer)

-    def _requests_data_rows(self, csv_writer):
+    def _requests_data_rows(self, csv_writer: CSVWriter) -> None:
        """Write requests csv data row, excluding header."""
        stats = self.environment.stats
        for stats_entry in chain(sort_stats(stats.entries), [stats.total]):
@@ -887,11 +921,11 @@ def _requests_data_rows(self, csv_writer):
                )
            )

-    def failures_csv(self, csv_writer):
+    def failures_csv(self, csv_writer: CSVWriter) -> None:
        csv_writer.writerow(self.failures_columns)
        self._failures_data_rows(csv_writer)

-    def _failures_data_rows(self, csv_writer):
+    def _failures_data_rows(self, csv_writer: CSVWriter) -> None:
        for stats_error in sort_stats(self.environment.stats.errors):
            csv_writer.writerow(
                [
@@ -902,11 +936,14 @@ def _failures_data_rows(self, csv_writer):
                ]
            )

-    def exceptions_csv(self, csv_writer):
+    def exceptions_csv(self, csv_writer: CSVWriter) -> None:
        csv_writer.writerow(self.exceptions_columns)
        self._exceptions_data_rows(csv_writer)

-    def _exceptions_data_rows(self, csv_writer):
+    def _exceptions_data_rows(self, csv_writer: CSVWriter) -> None:
+        if self.environment.runner is None:
+            return
+
        for exc in self.environment.runner.exceptions.values():
            csv_writer.writerow([exc["count"], exc["msg"], exc["traceback"], ", ".join(exc["nodes"])])

@@ -914,7 +951,13 @@ class StatsCSVFileWriter(StatsCSV):
    """Write statistics to CSV files"""

-    def __init__(self, environment, percentiles_to_report, base_filepath, full_history=False):
+    def __init__(
+        self,
+        environment: "Environment",
+        percentiles_to_report: List[float],
+        base_filepath: str,
+        full_history: bool = False,
+    ):
        super().__init__(environment, percentiles_to_report)
        self.base_filepath = base_filepath
        self.full_history = full_history
@@ -927,11 +970,11 @@ def __init__(self, environment, percentiles_to_report, base_filepath, full_histo
        self.failures_csv_filehandle = open(self.base_filepath + "_failures.csv", "w")
        self.failures_csv_writer = csv.writer(self.failures_csv_filehandle)
-        self.failures_csv_data_start = 0
+        self.failures_csv_data_start: int = 0

        self.exceptions_csv_filehandle = open(self.base_filepath + "_exceptions.csv", "w")
        self.exceptions_csv_writer = csv.writer(self.exceptions_csv_filehandle)
-        self.exceptions_csv_data_start = 0
+        self.exceptions_csv_data_start: int = 0

        self.stats_history_csv_columns = [
            "Timestamp",
            "User Count",
            "Type",
            "Name",
            "Requests/s",
            "Failures/s",
            *get_readable_percentiles(self.percentiles_to_report),
            "Total Request Count",
            "Total Failure Count",
            "Total Median Response Time",
            "Total Average Response Time",
            "Total Min Response Time",
            "Total Max Response Time",
            "Total Average Content Size",
        ]

-    def __call__(self):
+    def __call__(self) -> None:
        self.stats_writer()

-    def stats_writer(self):
+    def stats_writer(self) -> NoReturn:
        """Writes all the csv files for the locust run."""

        # Write header row for all files and save position for non-append files
@@ -969,7 +1012,7 @@ def stats_writer(self):
        self.exceptions_csv_data_start = self.exceptions_csv_filehandle.tell()

        # Continuously write data rows for all files
-        last_flush_time = 0
+        last_flush_time: float = 0.0

        while True:
            now = time.time()
@@ -996,7 +1039,7 @@ def stats_writer(self):

            gevent.sleep(CSV_STATS_INTERVAL_SEC)

-    def _stats_history_data_rows(self, csv_writer, now):
+    def _stats_history_data_rows(self, csv_writer: CSVWriter, now: float) -> None:
        """
        Write CSV rows with the *current* stats.
By default only includes the Aggregated stats entry, but if self.full_history is set to True, a row for each entry will @@ -1007,7 +1050,7 @@ def _stats_history_data_rows(self, csv_writer, now): stats = self.environment.stats timestamp = int(now) - stats_entries = [] + stats_entries: List[StatsEntry] = [] if self.full_history: stats_entries = sort_stats(stats.entries) @@ -1016,7 +1059,7 @@ def _stats_history_data_rows(self, csv_writer, now): chain( ( timestamp, - self.environment.runner.user_count, + self.environment.runner.user_count if self.environment.runner is not None else 0, stats_entry.method or "", stats_entry.name, f"{stats_entry.current_rps:2f}", @@ -1035,23 +1078,23 @@ def _stats_history_data_rows(self, csv_writer, now): ) ) - def requests_flush(self): + def requests_flush(self) -> None: self.requests_csv_filehandle.flush() - def stats_history_flush(self): + def stats_history_flush(self) -> None: self.stats_history_csv_filehandle.flush() - def failures_flush(self): + def failures_flush(self) -> None: self.failures_csv_filehandle.flush() - def exceptions_flush(self): + def exceptions_flush(self) -> None: self.exceptions_csv_filehandle.flush() - def close_files(self): + def close_files(self) -> None: self.requests_csv_filehandle.close() self.stats_history_csv_filehandle.close() self.failures_csv_filehandle.close() self.exceptions_csv_filehandle.close() - def stats_history_file_name(self): + def stats_history_file_name(self) -> str: return self.base_filepath + "_stats_history.csv" diff --git a/locust/user/inspectuser.py b/locust/user/inspectuser.py index e1a5cdb279..c88b83950b 100644 --- a/locust/user/inspectuser.py +++ b/locust/user/inspectuser.py @@ -1,8 +1,10 @@ from collections import defaultdict import inspect from json import dumps +from typing import List, Type, Dict from .task import TaskSet +from .users import User def print_task_ratio(user_classes, num_users, total): @@ -47,11 +49,11 @@ def _print_task_ratio(x, level=0): _print_task_ratio(v["tasks"], level + 1) -def get_ratio(user_classes, user_spawned, total): +def get_ratio(user_classes: List[Type[User]], user_spawned: Dict[str, int], total: bool) -> Dict[str, Dict[str, float]]: user_count = sum(user_spawned.values()) or 1 - ratio_percent = {u: user_spawned.get(u.__name__, 0) / user_count for u in user_classes} + ratio_percent: Dict[Type[User], float] = {u: user_spawned.get(u.__name__, 0) / user_count for u in user_classes} - task_dict = {} + task_dict: Dict[str, Dict[str, float]] = {} for u, r in ratio_percent.items(): d = {"ratio": r} d["tasks"] = _get_task_ratio(u.tasks, total, r) diff --git a/locust/user/task.py b/locust/user/task.py index c3fcd331f2..e9cd9511b3 100644 --- a/locust/user/task.py +++ b/locust/user/task.py @@ -2,7 +2,7 @@ import random import traceback from time import time -from typing import TYPE_CHECKING, Callable, List, Union, TypeVar, Optional, Type, overload +from typing import TYPE_CHECKING, Callable, List, Union, TypeVar, Optional, Type, overload, Protocol, Dict, Set from typing_extensions import final import gevent @@ -15,11 +15,15 @@ logger = logging.getLogger(__name__) -TaskT = TypeVar("TaskT", bound=Union[Callable[..., None], Type["TaskSet"]]) +TaskT = TypeVar("TaskT", Callable[..., None], Type["TaskSet"]) LOCUST_STATE_RUNNING, LOCUST_STATE_WAITING, LOCUST_STATE_STOPPING = ["running", "waiting", "stopping"] +class TaskHolder(Protocol[TaskT]): + tasks: List[TaskT] + + @overload def task(weight: TaskT) -> TaskT: ... 
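The TaskHolder protocol above gives filter_tasks_by_tags a structural type: any object carrying a tasks list qualifies, whether or not it inherits from TaskSet. A simplified sketch of how such a protocol matches by shape alone (the generic TaskT parameter from the patch is dropped here, and the example class is invented):

    from typing import Callable, List, Protocol

    class TaskHolder(Protocol):
        tasks: List[Callable[..., None]]

    def task_count(holder: TaskHolder) -> int:
        # accepts anything that structurally has a .tasks list
        return len(holder.tasks)

    class MyUser:  # no TaskHolder inheritance needed
        tasks: List[Callable[..., None]] = [lambda self: None, lambda self: None]

    print(task_count(MyUser()))

Note also that TaskT changed from a bound TypeVar to a constrained one, TypeVar("TaskT", Callable[..., None], Type["TaskSet"]), so each use of TaskT resolves to exactly one of the two alternatives rather than to any subtype of their union.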
@@ -152,7 +156,12 @@ def get_tasks_from_base_classes(bases, class_dict): return new_tasks -def filter_tasks_by_tags(task_holder, tags=None, exclude_tags=None, checked=None): +def filter_tasks_by_tags( + task_holder: Type[TaskHolder], + tags: Optional[Set[str]] = None, + exclude_tags: Optional[Set[str]] = None, + checked: Optional[Dict[TaskT, bool]] = None, +): """ Function used by Environment to recursively remove any tasks/TaskSets from a TaskSet/User that shouldn't be executed according to the tag options diff --git a/locust/web.py b/locust/web.py index 7b774cbdfd..c13b83de49 100644 --- a/locust/web.py +++ b/locust/web.py @@ -1,4 +1,5 @@ import csv +import json import logging import os.path from functools import wraps @@ -6,27 +7,29 @@ from io import StringIO from itertools import chain from time import time +from typing import TYPE_CHECKING, Optional, Union, Any, Dict, List import gevent -from flask import Flask, make_response, jsonify, render_template, request, send_file +from flask import Flask, make_response, jsonify, render_template, request, send_file, Response from flask_basicauth import BasicAuth from gevent import pywsgi -from typing import Any, Dict from .exception import AuthCredentialsError -from .runners import MasterRunner +from .runners import MasterRunner, STATE_MISSING from .log import greenlet_exception_logger -from .stats import sort_stats +from .stats import StatsCSVFileWriter, sort_stats from . import stats as stats_module, __version__ as version, argument_parser from .stats import StatsCSV from .user.inspectuser import get_ratio from .util.cache import memoize from .util.rounding import proper_round -from .util.timespan import parse_timespan from .html import get_html_report from flask_cors import CORS from json import dumps +if TYPE_CHECKING: + from .env import Environment + logger = logging.getLogger(__name__) greenlet_exception_handler = greenlet_exception_logger(logger) @@ -41,7 +44,7 @@ class WebUI: in :attr:`environment.stats ` """ - app = None + app: Optional[Flask] = None """ Reference to the :class:`flask.Flask` app. Can be used to add additional web routes and customize the Flask app in other various ways. 
Example::

        @app.route("/my_custom_route")
        def my_custom_route():
            return "your IP is: %s" % request.remote_addr
    """

-    greenlet = None
+    greenlet: Optional[gevent.Greenlet] = None
    """
    Greenlet of the running web server
    """

-    server = None
+    server: Optional[pywsgi.WSGIServer] = None
    """Reference to the :class:`pywsgi.WSGIServer` instance"""

    template_args: Dict[str, Any]

    def __init__(
        self,
-        environment,
-        host,
-        port,
-        auth_credentials=None,
-        tls_cert=None,
-        tls_key=None,
-        stats_csv_writer=None,
+        environment: "Environment",
+        host: str,
+        port: int,
+        auth_credentials: Optional[str] = None,
+        tls_cert: Optional[str] = None,
+        tls_key: Optional[str] = None,
+        stats_csv_writer: Optional[StatsCSV] = None,
        delayed_start=False,
    ):
        """
@@ -104,9 +107,9 @@ def __init__(
        app.debug = True
        app.root_path = os.path.dirname(os.path.abspath(__file__))
        self.app.config["BASIC_AUTH_ENABLED"] = False
-        self.auth = None
-        self.greenlet = None
-        self._swarm_greenlet = None
+        self.auth: Optional[BasicAuth] = None
+        self.greenlet: Optional[gevent.Greenlet] = None
+        self._swarm_greenlet: Optional[gevent.Greenlet] = None
        self.template_args = {}

        if auth_credentials is not None:
@@ -128,7 +131,7 @@ def __init__(

        @app.route("/")
        @self.auth_required_if_enabled
-        def index():
+        def index() -> Union[str, Response]:
            if not environment.runner:
                return make_response("Error: Locust Environment does not have any runner", 500)
            self.update_template_args()

        @app.route("/swarm", methods=["POST"])
        @self.auth_required_if_enabled
-        def swarm():
+        def swarm() -> Response:
            assert request.method == "POST"

            parsed_options_dict = vars(environment.parsed_options) if environment.parsed_options else {}
@@ -153,7 +156,7 @@ def swarm():
                    # This won't work for parameters that are None
                    parsed_options_dict[key] = type(parsed_options_dict[key])(value)

-            if environment.shape_class:
+            if environment.shape_class and environment.runner is not None:
                environment.runner.start_shape()
                return jsonify(
                    {"success": True, "message": "Swarming started using shape class", "host": environment.host}
                )
@@ -162,37 +165,43 @@ def swarm():
            if self._swarm_greenlet is not None:
                self._swarm_greenlet.kill(block=True)
                self._swarm_greenlet = None
-            self._swarm_greenlet = gevent.spawn(environment.runner.start, user_count, spawn_rate)
-            self._swarm_greenlet.link_exception(greenlet_exception_handler)
-            return jsonify({"success": True, "message": "Swarming started", "host": environment.host})
+
+            if environment.runner is not None:
+                self._swarm_greenlet = gevent.spawn(environment.runner.start, user_count, spawn_rate)
+                self._swarm_greenlet.link_exception(greenlet_exception_handler)
+                return jsonify({"success": True, "message": "Swarming started", "host": environment.host})
+            else:
+                return jsonify({"success": False, "message": "No runner", "host": environment.host})

        @app.route("/stop")
        @self.auth_required_if_enabled
-        def stop():
+        def stop() -> Response:
            if self._swarm_greenlet is not None:
                self._swarm_greenlet.kill(block=True)
                self._swarm_greenlet = None
-            environment.runner.stop()
+            if environment.runner is not None:
+                environment.runner.stop()
            return jsonify({"success": True, "message": "Test stopped"})

        @app.route("/stats/reset")
        @self.auth_required_if_enabled
-        def reset_stats():
+        def reset_stats() -> str:
            environment.events.reset_stats.fire()
-            environment.runner.stats.reset_all()
-            environment.runner.exceptions = {}
+            if environment.runner is not None:
+                environment.runner.stats.reset_all()
+ 
environment.runner.exceptions = {} return "ok" @app.route("/stats/report") @self.auth_required_if_enabled - def stats_report(): + def stats_report() -> Response: res = get_html_report(self.environment, show_download_link=not request.args.get("download")) if request.args.get("download"): res = app.make_response(res) res.headers["Content-Disposition"] = f"attachment;filename=report_{time()}.html" return res - def _download_csv_suggest_file_name(suggest_filename_prefix): + def _download_csv_suggest_file_name(suggest_filename_prefix: str) -> str: """Generate csv file download attachment filename suggestion. Arguments: @@ -201,7 +210,7 @@ def _download_csv_suggest_file_name(suggest_filename_prefix): return f"{suggest_filename_prefix}_{time()}.csv" - def _download_csv_response(csv_data, filename_prefix): + def _download_csv_response(csv_data: str, filename_prefix: str) -> Response: """Generate csv file download response with 'csv_data'. Arguments: @@ -218,7 +227,7 @@ def _download_csv_response(csv_data, filename_prefix): @app.route("/stats/requests/csv") @self.auth_required_if_enabled - def request_stats_csv(): + def request_stats_csv() -> Response: data = StringIO() writer = csv.writer(data) self.stats_csv_writer.requests_csv(writer) @@ -226,9 +235,9 @@ def request_stats_csv(): @app.route("/stats/requests_full_history/csv") @self.auth_required_if_enabled - def request_stats_full_history_csv(): + def request_stats_full_history_csv() -> Response: options = self.environment.parsed_options - if options and options.stats_history_enabled: + if options and options.stats_history_enabled and isinstance(self.stats_csv_writer, StatsCSVFileWriter): return send_file( os.path.abspath(self.stats_csv_writer.stats_history_file_name()), mimetype="text/csv", @@ -244,7 +253,7 @@ def request_stats_full_history_csv(): @app.route("/stats/failures/csv") @self.auth_required_if_enabled - def failures_stats_csv(): + def failures_stats_csv() -> Response: data = StringIO() writer = csv.writer(data) self.stats_csv_writer.failures_csv(writer) @@ -253,10 +262,28 @@ def failures_stats_csv(): @app.route("/stats/requests") @self.auth_required_if_enabled @memoize(timeout=DEFAULT_CACHE_TIME, dynamic_timeout=True) - def request_stats(): - stats = [] + def request_stats() -> Response: + stats: List[Dict[str, Any]] = [] + errors: List[Dict[str, str]] = [] + + if environment.runner is None: + report = { + "stats": stats, + "errors": errors, + "total_rps": 0.0, + "fail_ratio": 0.0, + "current_response_time_percentile_95": None, + "current_response_time_percentile_50": None, + "state": STATE_MISSING, + "user_count": 0, + } + + if isinstance(environment.runner, MasterRunner): + report.update({"workers": []}) - for s in chain(sort_stats(self.environment.runner.stats.entries), [environment.runner.stats.total]): + return jsonify(report) + + for s in chain(sort_stats(environment.runner.stats.entries), [environment.runner.stats.total]): stats.append( { "method": s.method, @@ -276,7 +303,6 @@ def request_stats(): } ) - errors = [] for e in environment.runner.errors.values(): err_dict = e.to_dict() err_dict["name"] = escape(err_dict["name"]) @@ -285,9 +311,11 @@ def request_stats(): # Truncate the total number of stats and errors displayed since a large number of rows will cause the app # to render extremely slowly. Aggregate stats should be preserved. 
- report = {"stats": stats[:500], "errors": errors[:500]} + truncated_stats = stats[:500] if len(stats) > 500: - report["stats"] += [stats[-1]] + truncated_stats += [stats[-1]] + + report = {"stats": truncated_stats, "errors": errors[:500]} if stats: report["total_rps"] = stats[len(stats) - 1]["current_rps"] @@ -299,8 +327,7 @@ def request_stats(): "current_response_time_percentile_50" ] = environment.runner.stats.total.get_current_response_time_percentile(0.5) - is_distributed = isinstance(environment.runner, MasterRunner) - if is_distributed: + if isinstance(environment.runner, MasterRunner): workers = [] for worker in environment.runner.clients.values(): workers.append( @@ -322,7 +349,7 @@ def request_stats(): @app.route("/exceptions") @self.auth_required_if_enabled - def exceptions(): + def exceptions() -> Response: return jsonify( { "exceptions": [ @@ -332,14 +359,14 @@ def exceptions(): "traceback": row["traceback"], "nodes": ", ".join(row["nodes"]), } - for row in environment.runner.exceptions.values() + for row in (environment.runner.exceptions.values() if environment.runner is not None else []) ] } ) @app.route("/exceptions/csv") @self.auth_required_if_enabled - def exceptions_csv(): + def exceptions_csv() -> Response: data = StringIO() writer = csv.writer(data) self.stats_csv_writer.exceptions_csv(writer) @@ -347,10 +374,17 @@ def exceptions_csv(): @app.route("/tasks") @self.auth_required_if_enabled - def tasks(): - is_distributed = isinstance(self.environment.runner, MasterRunner) + def tasks() -> Dict[str, Dict[str, Dict[str, float]]]: runner = self.environment.runner - user_spawned = runner.reported_user_classes_count if is_distributed else runner.user_classes_count + user_spawned: Dict[str, int] + if runner is None: + user_spawned = {} + else: + user_spawned = ( + runner.reported_user_classes_count + if isinstance(runner, MasterRunner) + else runner.user_classes_count + ) task_data = { "per_class": get_ratio(self.environment.user_classes, user_spawned, False), From 6f6090b264e5b2de09c872aab3be2636b3bd17c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikael=20G=C3=B6ransson?= Date: Tue, 24 May 2022 14:55:25 +0200 Subject: [PATCH 2/3] fixed additional mypy errors --- locust/dispatch.py | 4 +- locust/env.py | 4 +- locust/runners.py | 212 +++++++++++++++++++++++--------------------- locust/stats.py | 64 ++++++++----- locust/user/task.py | 16 +++- locust/web.py | 6 +- 6 files changed, 174 insertions(+), 132 deletions(-) diff --git a/locust/dispatch.py b/locust/dispatch.py index 6e49645ebb..f4abb7b566 100644 --- a/locust/dispatch.py +++ b/locust/dispatch.py @@ -49,7 +49,7 @@ class UsersDispatcher(Iterator): from 10 to 100. 
""" - def __init__(self, worker_nodes: "List[WorkerNode]", user_classes: List[Type[User]]): + def __init__(self, worker_nodes: List["WorkerNode"], user_classes: List[Type[User]]): """ :param worker_nodes: List of worker nodes :param user_classes: The user classes @@ -397,7 +397,7 @@ def infinite_cycle_gen(users: List[Tuple[Type[User], int]]) -> itertools.cycle: current_fixed_users_count = {u: self._get_user_current_count(u) for u in fixed_users} spawned_classes: Set[str] = set() while len(spawned_classes) != len(fixed_users): - user_name = next(cycle_fixed_gen) + user_name: Optional[str] = next(cycle_fixed_gen) if not user_name: break diff --git a/locust/env.py b/locust/env.py index 8ecb9afbef..c7a12dd4a0 100644 --- a/locust/env.py +++ b/locust/env.py @@ -17,7 +17,7 @@ from .runners import Runner, LocalRunner, MasterRunner, WorkerRunner from .web import WebUI from .user import User -from .user.task import filter_tasks_by_tags, TaskSet +from .user.task import filter_tasks_by_tags, TaskSet, TaskHolder from .shape import LoadTestShape @@ -246,7 +246,7 @@ def assign_equal_weights(self) -> None: tasks_frontier = u.tasks while len(tasks_frontier) != 0: t = tasks_frontier.pop() - if not callable(t) and hasattr(t, "tasks") and t.tasks: + if isinstance(t, TaskHolder): tasks_frontier.extend(t.tasks) elif callable(t): if t not in user_tasks: diff --git a/locust/runners.py b/locust/runners.py index 4a52deec30..29fa24853a 100644 --- a/locust/runners.py +++ b/locust/runners.py @@ -7,6 +7,7 @@ import sys import time import traceback +from abc import abstractmethod from collections import defaultdict from collections.abc import MutableMapping from operator import ( @@ -20,15 +21,14 @@ Iterator, List, NoReturn, - Union, ValuesView, TypedDict, Set, - Callable, Optional, Tuple, Type, Any, + Protocol, cast, ) from uuid import uuid4 @@ -88,6 +88,12 @@ class ExceptionDict(TypedDict): nodes: Set[str] +class CustomMessageListener(Protocol): + @abstractmethod + def __call__(self, environment: "Environment", msg: Message) -> None: + ... + + class Runner: """ Orchestrates the load test by starting and stopping the users. @@ -119,11 +125,7 @@ def __init__(self, environment: "Environment") -> None: self.target_user_classes_count: Dict[str, int] = {} # target_user_count is set before the ramp-up/ramp-down occurs. 
self.target_user_count: int = 0 - self.custom_messages: Dict[str, Callable[["Environment", Message], None]] = {} - - # Only when running in standalone mode (non-distributed) - self._local_worker_node = WorkerNode(id="local") - self._local_worker_node.user_classes_count = self.user_classes_count + self.custom_messages: Dict[str, CustomMessageListener] = {} self._users_dispatcher: Optional[UsersDispatcher] = None @@ -146,7 +148,7 @@ def on_request_failure(request_type, name, response_time, response_length, excep self.final_user_classes_count: Dict[str, int] = {} # just for the ratio report, fills before runner stops # register listener that resets stats when spawning is complete - def on_spawning_complete(user_count): + def on_spawning_complete(user_count: int) -> None: self.update_state(STATE_RUNNING) if environment.reset_stats: logger.info("Resetting stats\n") @@ -154,7 +156,7 @@ def on_spawning_complete(user_count): self.environment.events.spawning_complete.add_listener(on_spawning_complete) - def __del__(self): + def __del__(self) -> None: # don't leave any stray greenlets if runner is removed if self.greenlet and len(self.greenlet) > 0: self.greenlet.kill(block=False) @@ -320,80 +322,9 @@ def monitor_cpu_and_memory(self) -> NoReturn: self.cpu_warning_emitted = True gevent.sleep(CPU_MONITOR_INTERVAL) + @abstractmethod def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: - """ - Start running a load test - - :param user_count: Total number of users to start - :param spawn_rate: Number of users to spawn per second - :param wait: If True calls to this method will block until all users are spawned. - If False (the default), a greenlet that spawns the users will be - started and the call to this method will return immediately. 
- """ - self.target_user_count = user_count - - if self.state != STATE_RUNNING and self.state != STATE_SPAWNING: - self.stats.clear_all() - self.exceptions = {} - self.cpu_warning_emitted = False - self.worker_cpu_warning_emitted = False - self.environment._filter_tasks_by_tags() - self.environment.events.test_start.fire(environment=self.environment) - - if wait and user_count - self.user_count > spawn_rate: - raise ValueError("wait is True but the amount of users to add is greater than the spawn rate") - - for user_class in self.user_classes: - if self.environment.host: - user_class.host = self.environment.host - - if self.state != STATE_INIT and self.state != STATE_STOPPED: - self.update_state(STATE_SPAWNING) - - if self._users_dispatcher is None: - self._users_dispatcher = UsersDispatcher( - worker_nodes=[self._local_worker_node], user_classes=self.user_classes - ) - - logger.info("Ramping to %d users at a rate of %.2f per second" % (user_count, spawn_rate)) - - self._users_dispatcher.new_dispatch(user_count, spawn_rate) - - try: - for dispatched_users in self._users_dispatcher: - user_classes_spawn_count = {} - user_classes_stop_count = {} - user_classes_count = dispatched_users[self._local_worker_node.id] - logger.debug(f"Ramping to {_format_user_classes_count_for_log(user_classes_count)}") - for user_class, user_class_count in user_classes_count.items(): - if self.user_classes_count[user_class] > user_class_count: - user_classes_stop_count[user_class] = self.user_classes_count[user_class] - user_class_count - elif self.user_classes_count[user_class] < user_class_count: - user_classes_spawn_count[user_class] = user_class_count - self.user_classes_count[user_class] - - if wait: - # spawn_users will block, so we need to call stop_users first - self.stop_users(user_classes_stop_count) - self.spawn_users(user_classes_spawn_count, wait) - else: - # call spawn_users before stopping the users since stop_users - # can be blocking because of the stop_timeout - self.spawn_users(user_classes_spawn_count, wait) - self.stop_users(user_classes_stop_count) - - self._local_worker_node.user_classes_count = next(iter(dispatched_users.values())) - - except KeyboardInterrupt: - # TODO: Find a cleaner way to handle that - # We need to catch keyboard interrupt. Otherwise, if KeyboardInterrupt is received while in - # a gevent.sleep inside the dispatch_users function, locust won't gracefully shutdown. - self.quit() - - logger.info(f"All users spawned: {_format_user_classes_count_for_log(self.user_classes_count)}") - - self.target_user_classes_count = self.user_classes_count - - self.environment.events.spawning_complete.fire(user_count=sum(self.target_user_classes_count.values())) + ... 
 def start_shape(self) -> None:
         """
@@ -407,12 +338,13 @@ def start_shape(self) -> None:
             self.update_state(STATE_INIT)
         self.shape_greenlet = self.greenlet.spawn(self.shape_worker)
         self.shape_greenlet.link_exception(greenlet_exception_handler)
-        self.environment.shape_class.reset_time()
+        if self.environment.shape_class is not None:
+            self.environment.shape_class.reset_time()
 
     def shape_worker(self) -> None:
         logger.info("Shape worker starting")
         while self.state == STATE_INIT or self.state == STATE_SPAWNING or self.state == STATE_RUNNING:
-            new_state = self.environment.shape_class.tick()
+            new_state = self.environment.shape_class.tick() if self.environment.shape_class is not None else None
             if new_state is None:
                 logger.info("Shape test stopping")
                 if self.environment.parsed_options and self.environment.parsed_options.headless:
@@ -487,7 +419,7 @@ def log_exception(self, node_id: str, msg: str, formatted_tb: str) -> None:
             row["nodes"].add(node_id)
         self.exceptions[key] = row
 
-    def register_message(self, msg_type: str, listener: Callable[["Environment", Message], None]) -> None:
+    def register_message(self, msg_type: str, listener: CustomMessageListener) -> None:
         """
         Register a listener for a custom message from another node
 
@@ -508,6 +440,10 @@ def __init__(self, environment) -> None:
         """
         super().__init__(environment)
 
+        # Only when running in standalone mode (non-distributed)
+        self._local_worker_node = WorkerNode(id="local")
+        self._local_worker_node.user_classes_count = self.user_classes_count
+
         # register listener that logs the exception for the local runner
         def on_user_error(user_instance, exception, tb):
             formatted_tb = "".join(traceback.format_tb(tb))
@@ -515,6 +451,85 @@ def on_user_error(user_instance, exception, tb):
 
         self.environment.events.user_error.add_listener(on_user_error)
 
+    def _start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None:
+        """
+        Start running a load test
+
+        :param user_count: Total number of users to start
+        :param spawn_rate: Number of users to spawn per second
+        :param wait: If True calls to this method will block until all users are spawned.
+            If False (the default), a greenlet that spawns the users will be
+            started and the call to this method will return immediately.
+ """ + self.target_user_count = user_count + + if self.state != STATE_RUNNING and self.state != STATE_SPAWNING: + self.stats.clear_all() + self.exceptions = {} + self.cpu_warning_emitted = False + self.worker_cpu_warning_emitted = False + self.environment._filter_tasks_by_tags() + self.environment.events.test_start.fire(environment=self.environment) + + if wait and user_count - self.user_count > spawn_rate: + raise ValueError("wait is True but the amount of users to add is greater than the spawn rate") + + for user_class in self.user_classes: + if self.environment.host: + user_class.host = self.environment.host + + if self.state != STATE_INIT and self.state != STATE_STOPPED: + self.update_state(STATE_SPAWNING) + + if self._users_dispatcher is None: + self._users_dispatcher = UsersDispatcher( + worker_nodes=[self._local_worker_node], user_classes=self.user_classes + ) + + logger.info("Ramping to %d users at a rate of %.2f per second" % (user_count, spawn_rate)) + + cast(UsersDispatcher, self._users_dispatcher).new_dispatch(user_count, spawn_rate) + + try: + for dispatched_users in self._users_dispatcher: + user_classes_spawn_count: Dict[str, int] = {} + user_classes_stop_count: Dict[str, int] = {} + user_classes_count = dispatched_users[self._local_worker_node.id] + logger.debug(f"Ramping to {_format_user_classes_count_for_log(user_classes_count)}") + for user_class_name, user_class_count in user_classes_count.items(): + if self.user_classes_count[user_class_name] > user_class_count: + user_classes_stop_count[user_class_name] = ( + self.user_classes_count[user_class_name] - user_class_count + ) + elif self.user_classes_count[user_class_name] < user_class_count: + user_classes_spawn_count[user_class_name] = ( + user_class_count - self.user_classes_count[user_class_name] + ) + + if wait: + # spawn_users will block, so we need to call stop_users first + self.stop_users(user_classes_stop_count) + self.spawn_users(user_classes_spawn_count, wait) + else: + # call spawn_users before stopping the users since stop_users + # can be blocking because of the stop_timeout + self.spawn_users(user_classes_spawn_count, wait) + self.stop_users(user_classes_stop_count) + + self._local_worker_node.user_classes_count = next(iter(dispatched_users.values())) + + except KeyboardInterrupt: + # TODO: Find a cleaner way to handle that + # We need to catch keyboard interrupt. Otherwise, if KeyboardInterrupt is received while in + # a gevent.sleep inside the dispatch_users function, locust won't gracefully shutdown. 
+ self.quit() + + logger.info(f"All users spawned: {_format_user_classes_count_for_log(self.user_classes_count)}") + + self.target_user_classes_count = self.user_classes_count + + self.environment.events.spawning_complete.fire(user_count=sum(self.target_user_classes_count.values())) + def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: if spawn_rate > 100: logger.warning( @@ -524,9 +539,7 @@ def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: if self.spawning_greenlet: # kill existing spawning_greenlet before we start a new one self.spawning_greenlet.kill(block=True) - self.spawning_greenlet = self.greenlet.spawn( - lambda: super(LocalRunner, self).start(user_count, spawn_rate, wait=wait) - ) + self.spawning_greenlet = self.greenlet.spawn(lambda: self._start(user_count, spawn_rate, wait=wait)) self.spawning_greenlet.link_exception(greenlet_exception_handler) def stop(self) -> None: @@ -553,7 +566,6 @@ def send_message(self, msg_type: str, data: Optional[Any] = None) -> None: class DistributedRunner(Runner): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self._local_worker_node = None setup_distributed_stats_event_listeners(self.environment.events, self.stats) @@ -612,8 +624,8 @@ def __getitem__(self, k: str) -> WorkerNode: def __len__(self) -> int: return len(self._worker_nodes) - def __iter__(self) -> Iterator[WorkerNode]: - return iter(self._worker_nodes) + def __iter__(self) -> Iterator[str]: + return iter(list(self._worker_nodes.keys())) class MasterRunner(DistributedRunner): @@ -681,7 +693,7 @@ def rebalancing_enabled(self) -> bool: @property def user_count(self) -> int: - return sum(c.user_count for c in self.clients.values()) + return sum([c.user_count for c in self.clients.values()]) def cpu_log_warning(self) -> bool: warning_emitted = Runner.cpu_log_warning(self) @@ -1138,14 +1150,14 @@ def start_worker(self, user_classes_count: Dict[str, int], **kwargs) -> None: if self.environment.host: user_class.host = self.environment.host - user_classes_spawn_count: Dict[Type[User], int] = {} - user_classes_stop_count: Dict[Type[User], int] = {} + user_classes_spawn_count: Dict[str, int] = {} + user_classes_stop_count: Dict[str, int] = {} - for user_class, user_class_count in user_classes_count.items(): - if self.user_classes_count[user_class] > user_class_count: - user_classes_stop_count[user_class] = self.user_classes_count[user_class] - user_class_count - elif self.user_classes_count[user_class] < user_class_count: - user_classes_spawn_count[user_class] = user_class_count - self.user_classes_count[user_class] + for user_class_name, user_class_count in user_classes_count.items(): + if self.user_classes_count[user_class_name] > user_class_count: + user_classes_stop_count[user_class_name] = self.user_classes_count[user_class_name] - user_class_count + elif self.user_classes_count[user_class_name] < user_class_count: + user_classes_spawn_count[user_class_name] = user_class_count - self.user_classes_count[user_class_name] # call spawn_users before stopping the users since stop_users # can be blocking because of the stop_timeout @@ -1269,7 +1281,7 @@ def send_message(self, msg_type: str, data: Optional[Dict[str, Any]] = None) -> self.client.send(Message(msg_type, data, self.client_id)) def _send_stats(self) -> None: - data = {} + data: Dict[str, Any] = {} self.environment.events.report_to_master.fire(client_id=self.client_id, data=data) self.client.send(Message("stats", data, self.client_id)) diff --git 
a/locust/stats.py b/locust/stats.py
index e865475799..3649e6c235 100644
--- a/locust/stats.py
+++ b/locust/stats.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
 import datetime
 import hashlib
 from tempfile import NamedTemporaryFile
@@ -14,6 +15,7 @@
     TYPE_CHECKING,
     Any,
     Dict,
+    Iterable,
     NoReturn,
     Tuple,
     List,
@@ -23,6 +25,7 @@
     Optional,
     OrderedDict as OrderedDictType,
     Callable,
     TypeVar,
+    Protocol,
     cast,
 )
 from types import FrameType
@@ -32,17 +35,12 @@
 
 import logging
 
-with NamedTemporaryFile(mode="w") as t:
-    CSVWriter = type(csv.writer(t))
-
 if TYPE_CHECKING:
     from .runners import Runner
     from .env import Environment
 
 console_logger = logging.getLogger("locust.stats_logger")
 
-S = TypeVar("S", bound="StatsBase")
-
 """Space in table for request name. Auto shrink it if terminal is small (<160 characters)"""
 try:
     STATS_NAME_WIDTH = max(min(os.get_terminal_size()[0] - 80, 80), 0)
@@ -52,6 +50,12 @@
 STATS_AUTORESIZE = True  # overwrite this if you don't want auto resize while running
 
+class CSVWriter(Protocol):
+    @abstractmethod
+    def writerow(self, columns: Iterable[Union[str, int, float]]) -> None:
+        ...
+
+
 class StatsBaseDict(TypedDict):
     name: str
     method: str
@@ -77,10 +81,12 @@ class StatsErrorDict(StatsBaseDict):
     occurrences: int
 
 
-class StatsBase:
-    def __init__(self, name: str, method: str) -> None:
-        self.name = name
-        self.method = method
+class StatsHolder(Protocol):
+    name: str
+    method: str
+
+
+S = TypeVar("S", bound=StatsHolder)
 
 
 def resize_handler(signum: int, frame: Optional[FrameType]):
@@ -266,16 +272,15 @@ def serialize_stats(self) -> List["StatsEntryDict"]:
         ]
 
     def serialize_errors(self) -> Dict[str, "StatsErrorDict"]:
-        return {k: e.to_dict() for k, e in self.errors.items()}
+        return {k: e.serialize() for k, e in self.errors.items()}
 
 
-class StatsEntry(StatsBase):
+class StatsEntry:
     """
     Represents a single stats entry (name and method)
     """
 
     def __init__(self, stats: Optional[RequestStats], name: str, method: str, use_response_times_cache: bool = False):
-        super().__init__(name, method)
         self.stats = stats
         self.name = name
         """ Name (URL) of this stats entry """
@@ -314,7 +319,7 @@ def __init__(self, stats: Optional[RequestStats], name: str, method: str, use_re
         This dict is used to calculate the median and percentile response times.
         """
 
-        self.response_times_cache: OrderedDictType[int, CachedResponseTimes]
+        self.response_times_cache: Optional[OrderedDictType[int, CachedResponseTimes]] = None
         """
         If use_response_times_cache is set to True, this will be a
         {timestamp => CachedResponseTimes()} OrderedDict that holds a copy of the response_times dict
         for each of the last 20 seconds.
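A note on the `CSVWriter` Protocol introduced above: it replaces the old trick of deriving the writer type from a throwaway `NamedTemporaryFile`, and because Protocols are matched structurally, the object returned by `csv.writer()` satisfies it with no subclassing or registration. A short sketch, not part of the patch; the `dump_row` helper is hypothetical and the Protocol body mirrors the one in this hunk::

    import csv
    import io
    from abc import abstractmethod
    from typing import Iterable, Protocol, Union

    class CSVWriter(Protocol):
        @abstractmethod
        def writerow(self, columns: Iterable[Union[str, int, float]]) -> None:
            ...

    def dump_row(writer: CSVWriter, columns: Iterable[Union[str, int, float]]) -> None:
        writer.writerow(columns)  # the only member the Protocol requires

    buf = io.StringIO()
    dump_row(csv.writer(buf), ["GET", "/", 42])  # structural match, no subclass
    print(buf.getvalue(), end="")                # prints: GET,/,42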
@@ -616,10 +621,11 @@ def get_current_response_time_percentile(self, percent: float) -> Optional[int]: acceptable_timestamps.append(t - CURRENT_RESPONSE_TIME_PERCENTILE_WINDOW + i) cached: Optional[CachedResponseTimes] = None - for ts in acceptable_timestamps: - if ts in self.response_times_cache: - cached = self.response_times_cache[ts] - break + if self.response_times_cache is not None: + for ts in acceptable_timestamps: + if ts in self.response_times_cache: + cached = self.response_times_cache[ts] + break if cached: # If we found an acceptable cached response times, we'll calculate a new response @@ -647,6 +653,9 @@ def percentile(self) -> str: ) def _cache_response_times(self, t: int) -> None: + if self.response_times_cache is None: + self.response_times_cache = OrderedDict() + self.response_times_cache[t] = CachedResponseTimes( response_times=copy(self.response_times), num_requests=self.num_requests, @@ -664,9 +673,8 @@ def _cache_response_times(self, t: int) -> None: self.response_times_cache.popitem(last=False) -class StatsError(StatsBase): +class StatsError: def __init__(self, method: str, name: str, error: Optional[Union[Exception, str]], occurrences: int = 0): - super().__init__(name, method) self.method = method self.name = name self.error = error @@ -709,11 +717,19 @@ def to_name(self) -> str: return f"{self.method} {self.name}: {unwrapped_error}" - def to_dict(self) -> StatsErrorDict: - return cast(StatsErrorDict, {key: getattr(self, key, None) for key in StatsErrorDict.__annotations__.keys()}) + def serialize(self) -> StatsErrorDict: + def _getattr(obj: "StatsError", key: str, default: Optional[Any]) -> Optional[Any]: + value = getattr(obj, key, default) + + if key in ["error"]: + value = StatsError.parse_error(value) + + return value + + return cast(StatsErrorDict, {key: _getattr(self, key, None) for key in StatsErrorDict.__annotations__.keys()}) @classmethod - def from_dict(cls, data: StatsErrorDict) -> "StatsError": + def unserialize(cls, data: StatsErrorDict) -> "StatsError": return cls(data["method"], data["name"], data["error"], data["occurrences"]) @@ -752,7 +768,7 @@ def on_worker_report(client_id: str, data: Dict[str, Any]) -> None: for error_key, error in data["errors"].items(): if error_key not in stats.errors: - stats.errors[error_key] = StatsError.from_dict(error) + stats.errors[error_key] = StatsError.unserialize(error) else: stats.errors[error_key].occurrences += error["occurrences"] @@ -931,7 +947,7 @@ def _failures_data_rows(self, csv_writer: CSVWriter) -> None: [ stats_error.method, stats_error.name, - stats_error.error, + StatsError.parse_error(stats_error.error), stats_error.occurrences, ] ) diff --git a/locust/user/task.py b/locust/user/task.py index e9cd9511b3..32aecf82c8 100644 --- a/locust/user/task.py +++ b/locust/user/task.py @@ -2,7 +2,20 @@ import random import traceback from time import time -from typing import TYPE_CHECKING, Callable, List, Union, TypeVar, Optional, Type, overload, Protocol, Dict, Set +from typing import ( + TYPE_CHECKING, + Callable, + List, + Union, + TypeVar, + Optional, + Type, + overload, + Protocol, + Dict, + Set, + runtime_checkable, +) from typing_extensions import final import gevent @@ -20,6 +33,7 @@ LOCUST_STATE_RUNNING, LOCUST_STATE_WAITING, LOCUST_STATE_STOPPING = ["running", "waiting", "stopping"] +@runtime_checkable class TaskHolder(Protocol[TaskT]): tasks: List[TaskT] diff --git a/locust/web.py b/locust/web.py index c13b83de49..7790115984 100644 --- a/locust/web.py +++ b/locust/web.py @@ -17,7 +17,7 @@ from 
.exception import AuthCredentialsError
 from .runners import MasterRunner, STATE_MISSING
 from .log import greenlet_exception_logger
-from .stats import StatsCSVFileWriter, sort_stats
+from .stats import StatsCSVFileWriter, StatsErrorDict, sort_stats
 from . import stats as stats_module, __version__ as version, argument_parser
 from .stats import StatsCSV
 from .user.inspectuser import get_ratio
@@ -264,7 +264,7 @@ def failures_stats_csv() -> Response:
     @memoize(timeout=DEFAULT_CACHE_TIME, dynamic_timeout=True)
     def request_stats() -> Response:
         stats: List[Dict[str, Any]] = []
-        errors: List[Dict[str, str]] = []
+        errors: List[StatsErrorDict] = []
 
         if environment.runner is None:
             report = {
@@ -304,7 +304,7 @@ def request_stats() -> Response:
             )
 
         for e in environment.runner.errors.values():
-            err_dict = e.to_dict()
+            err_dict = e.serialize()
             err_dict["name"] = escape(err_dict["name"])
             err_dict["error"] = escape(err_dict["error"])
             errors.append(err_dict)

From 1db3834d81c387b384643889253b324def30b407 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mikael=20G=C3=B6ransson?=
Date: Tue, 24 May 2022 15:15:29 +0200
Subject: [PATCH 3/3] Python 3.7 compatible

---
 locust/runners.py   | 8 ++++++--
 locust/stats.py     | 9 +++++++--
 locust/user/task.py | 9 ++++++---
 setup.cfg           | 2 +-
 4 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/locust/runners.py b/locust/runners.py
index 29fa24853a..8c3837734d 100644
--- a/locust/runners.py
+++ b/locust/runners.py
@@ -22,17 +22,21 @@
     List,
     NoReturn,
     ValuesView,
-    TypedDict,
     Set,
     Optional,
     Tuple,
     Type,
     Any,
-    Protocol,
     cast,
 )
 from uuid import uuid4
 
+# @TODO: typing.Protocol is in Python >= 3.8
+try:
+    from typing import Protocol, TypedDict
+except ImportError:
+    from typing_extensions import Protocol, TypedDict  # type: ignore
+
 import gevent
 import greenlet
 import psutil
diff --git a/locust/stats.py b/locust/stats.py
index 3649e6c235..fdc4f3bc1c 100644
--- a/locust/stats.py
+++ b/locust/stats.py
@@ -20,14 +20,19 @@
     Tuple,
     List,
     Union,
-    TypedDict,
     Optional,
     OrderedDict as OrderedDictType,
     Callable,
     TypeVar,
-    Protocol,
     cast,
 )
+
+# @TODO: typing.Protocol is in Python >= 3.8
+try:
+    from typing import Protocol, TypedDict
+except ImportError:
+    from typing_extensions import Protocol, TypedDict  # type: ignore
+
 from types import FrameType
 
 from .exception import CatchResponseError
diff --git a/locust/user/task.py b/locust/user/task.py
index 32aecf82c8..a3da8d98d2 100644
--- a/locust/user/task.py
+++ b/locust/user/task.py
@@ -11,12 +11,15 @@
     Optional,
     Type,
     overload,
-    Protocol,
     Dict,
     Set,
-    runtime_checkable,
 )
-from typing_extensions import final
+
+# @TODO: typing.Protocol and typing.final are in Python >= 3.8
+try:
+    from typing import Protocol, final, runtime_checkable
+except ImportError:
+    from typing_extensions import Protocol, final, runtime_checkable  # type: ignore
 
 import gevent
 from gevent import GreenletExit
diff --git a/setup.cfg b/setup.cfg
index 0205a30881..f44b4565c2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -47,7 +47,7 @@ install_requires =
     Flask-BasicAuth >=0.2.0
     Flask-Cors >=3.0.10
     roundrobin >=0.0.2
-    typing-extensions >=3.7.4.3 # This provides support for @final, and can probably be removed once we drop 3.7 support
+    typing-extensions >=3.7.4.3 # This provides support for @final, @runtime_checkable, Protocol and TypedDict, and can probably be removed once we drop 3.7 support
     pywin32;platform_system=='Windows'
 
 [options.packages.find]
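A closing note on the 3.7-compatibility pattern this last patch settles on: the stdlib names are tried first and the `typing_extensions` backport is used as a fallback, which also covers the `@runtime_checkable` decorator that makes the `isinstance(t, TaskHolder)` check in `Environment.assign_equal_weights()` legal at runtime. A condensed, self-contained sketch; the `TaskHolderLike` and `InlineTasks` names are hypothetical stand-ins, not from the patch::

    from typing import List

    # Python >= 3.8 ships these in typing; Python 3.7 needs typing_extensions.
    try:
        from typing import Protocol, runtime_checkable
    except ImportError:
        from typing_extensions import Protocol, runtime_checkable  # type: ignore

    @runtime_checkable
    class TaskHolderLike(Protocol):
        tasks: List[object]

    class InlineTasks:  # anything exposing a `tasks` attribute matches
        tasks = [print]

    # True: isinstance() on a runtime-checkable Protocol only checks that the
    # required members are present; it does not verify their types.
    print(isinstance(InlineTasks(), TaskHolderLike))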