diff --git a/wgkex/broker/app.py b/wgkex/broker/app.py index e8122cc..84366d1 100644 --- a/wgkex/broker/app.py +++ b/wgkex/broker/app.py @@ -52,6 +52,31 @@ def from_dict(cls, msg: dict) -> "KeyExchange": return cls(public_key=public_key, domain=domain) +@dataclasses.dataclass +class Gateway: + """A best Gateway message. + + Attributes: + domain: The domain for the best Gateway. + """ + + domain: str + + @classmethod + def from_dict(cls, msg: dict) -> "Gateway": + """Creates a new Gateway message from dict. + + Arguments: + msg: The message to convert. + Returns: + A Gateway object. + """ + domain = str(msg.get("domain")) + if not is_valid_domain(domain): + raise ValueError(f"Domain {domain} not in configured domains.") + return cls(domain=domain) + + def _fetch_app_config() -> Flask_app: """Creates the Flask app from configuration. @@ -160,6 +185,52 @@ def wg_api_v2_key_exchange() -> Tuple[Response | Dict, int]: return {"Endpoint": endpoint}, 200 +@app.route("/api/v2/wg/gateway/best", methods=["POST"]) +def wg_api_v2_gateway_best() -> Tuple[Response | Dict, int]: + """Retrieves a site, validates it and responds with a worker/gateway the client should connect to. + + Returns: + Status message, Endpoint with address/domain, port. + """ + try: + data = Gateway.from_dict(request.get_json(force=True)) + except Exception as ex: + return {"error": {"message": str(ex)}}, 400 + + domain = data.domain + logger.info(f"wg_api_v2_gateway_best: Domain: {domain}") + + best_worker, diff, current_peers = worker_metrics.get_best_worker(domain) + if best_worker is None: + logger.warning(f"No worker online for domain {domain}") + return { + "error": { + "message": "no gateway online for this domain, please check the domain value and try again later" + } + }, 400 + + logger.debug( + f"Should Chose worker {best_worker} with {current_peers} connected clients ({diff})" + ) + + w_data = worker_data.get((best_worker, domain), None) + if w_data is None: + logger.error(f"Couldn't get worker endpoint data for {best_worker}/{domain}") + return {"error": {"message": "could not get gateway data"}}, 500 + + # Add Code to check if we should Switch (check if Gateways are Loadbalaced in a specefiy trashhold) + shouldSwitch = True + endpoint = { + "Address": w_data.get("ExternalAddress"), + "Port": str(w_data.get("Port")), + "AllowedIPs": [w_data.get("LinkAddress")], + "PublicKey": w_data.get("PublicKey"), + "Switch": shouldSwitch, + } + + return {"Endpoint": endpoint}, 200 + + @mqtt.on_connect() def handle_mqtt_connect( client: mqtt_client.Client, userdata: bytes, flags: Any, rc: Any