Added customer config #4

Merged (9 commits) on Jul 29, 2024
4 changes: 3 additions & 1 deletion databricks/sdk/chaosgenius/__init__.py
@@ -1,5 +1,7 @@
"Library for Pulling Client Data."
from .cg_config import CGConfig
from .data_puller import DataPuller
from .handler import initiate_data_pull
from .logger import LogSparkDBHandler

__all__ = ["DataPuller", "LogSparkDBHandler"]
__all__ = ["CGConfig", "DataPuller", "LogSparkDBHandler", "initiate_data_pull"]
111 changes: 111 additions & 0 deletions databricks/sdk/chaosgenius/cg_config.py
@@ -0,0 +1,111 @@
import json
from logging import Logger
from typing import Optional

import pandas as pd
from pyspark.sql.session import SparkSession


class CGConfig:
"""
Customer config accessor backed by the chaosgenius.default.chaosgenius_config table.

Entity type: workspace, cluster, warehouse, job, etc.
Entity ID: ID of the above entity.
Include Entity: "yes" / "no".
Entity Config: JSON string, e.g. {"something": "else"}.
"""

def __init__(self, sparkSession: SparkSession, logger: Logger):
self.logger = logger
self.logger.info("Creating customer config table if not exists.")
self.sparkSession = sparkSession

try:
sparkSession.sql(
"""
CREATE TABLE IF NOT EXISTS chaosgenius.default.chaosgenius_config (
entity_type string,
entity_id string,
include_entity string,
entity_config string
)
"""
)
except Exception:
self.logger.error("Unable to create config table.", exc_info=True)

def get(
self,
entity_type: Optional[str] = None,
entity_ids: Optional[list[str]] = None,
include_entity: Optional[str] = None,
entity_config_filter: Optional[dict] = None,
) -> pd.DataFrame:
try:
where_query = ""
if entity_type is not None:
where_query += f"where entity_type = '{entity_type}'"

if entity_ids is not None:
if where_query == "":
where_query += "where "
else:
where_query += " and "
entity_ids_string = ",".join(map(lambda x: f"'{x}'", entity_ids))
where_query += f"entity_id in ({entity_ids_string})"

if include_entity is not None:
if where_query == "":
where_query += "where "
else:
where_query += " and "
where_query += f"include_entity = '{include_entity}'"

df = self.sparkSession.sql(
f"select * from chaosgenius.default.chaosgenius_config {where_query}"
).toPandas()

if df.empty:
return pd.DataFrame(
columns=[
"entity_type",
"entity_id",
"include_entity",
"entity_config",
]
)

df["entity_config"] = (
df["entity_config"].replace("", "{}").apply(lambda x: json.loads(x))
)
if entity_config_filter is not None:
df = df[
df["entity_config"].apply(
lambda x: all(
item in x.items() for item in entity_config_filter.items()
)
)
]
return df
except Exception:
self.logger.error("Unable to get config.", exc_info=True)
return pd.DataFrame(
columns=["entity_type", "entity_id", "include_entity", "entity_config"]
)

def get_ids(
self,
entity_type: Optional[str] = None,
entity_ids: Optional[list[str]] = None,
include_entity: Optional[str] = None,
entity_config_filter: Optional[dict] = None,
) -> set[str]:
return set(
self.get(
entity_type=entity_type,
entity_ids=entity_ids,
include_entity=include_entity,
entity_config_filter=entity_config_filter,
)["entity_id"].values.tolist()
)
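
For reviewers, a minimal usage sketch of the new CGConfig class. It assumes a Databricks runtime where a SparkSession named spark is already available; the inserted row and filter values are made-up examples, not data from this PR:

import logging

from databricks.sdk.chaosgenius import CGConfig

logger = logging.getLogger("cg_config_demo")

# Creates chaosgenius.default.chaosgenius_config if it does not exist yet.
config = CGConfig(spark, logger)

# Seed one illustrative row (values are hypothetical).
spark.sql(
    "INSERT INTO chaosgenius.default.chaosgenius_config VALUES "
    "('cluster', '0101-123456-abcdefgh', 'yes', '{\"workspace_id\": \"1234567890\"}')"
)

# Rows come back as a pandas DataFrame; entity_config is parsed into a dict
# and can be matched against entity_config_filter.
df = config.get(
    entity_type="cluster",
    include_entity="yes",
    entity_config_filter={"workspace_id": "1234567890"},
)

# get_ids returns just the matching entity IDs as a set of strings.
ids = config.get_ids(entity_type="cluster", include_entity="yes")

Note that get() builds its WHERE clause with plain string interpolation, so callers should only pass trusted filter values.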
179 changes: 156 additions & 23 deletions databricks/sdk/chaosgenius/data_puller.py
@@ -1,14 +1,22 @@
"""Utilities for pulling data."""

import datetime as dt
import logging
import json
from typing import Optional, Union
from typing import Callable, Optional, Union

import pandas as pd
from pyspark.sql.session import SparkSession
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.functions import explode
from databricks.sdk import WorkspaceClient
from databricks.sdk.chaosgenius.cg_config import CGConfig
from databricks.sdk.service import sql as databricks_sql
from databricks.sdk.service.compute import ClusterDetails, InstancePoolAndStats
from databricks.sdk.service.iam import User
from databricks.sdk.service.sql import EndpointInfo
from databricks.sdk.service.jobs import BaseJob


PANDAS_CHUNK_SIZE = 10000

@@ -20,12 +28,14 @@ def __init__(
self,
workspace_id: str,
workspace_client: WorkspaceClient,
customer_config: CGConfig,
spark_session: Optional[SparkSession],
save_to_csv: bool = False,
logger: Optional[logging.Logger] = None,
) -> None:
self._workspace_id = workspace_id
self._workspace_client = workspace_client
self._customer_config = customer_config
self._spark_session = spark_session
self._logger = logger if logger else logging.getLogger("data_puller")
self._pull_time = dt.datetime.now()
@@ -34,7 +44,7 @@ def __init__(
self._start_time, self._end_time = self._get_start_end_time()
self._save_to_csv = save_to_csv

logger.info(
self._logger.info(
f"Initializing data puller with workspace id: {workspace_id}, "
f"pull time: {self._pull_time}, start_time: {self._start_time}, "
f"end_time: {self._end_time}, save_to_csv: {self._save_to_csv}"
@@ -51,25 +61,28 @@ def __init__(
},
)

logger.info("Getting cluster list")
self._cluster_list = [i for i in self._workspace_client.clusters.list()]
logger.info(f"Total clusters: {len(self._cluster_list)}")
logger.info("Getting instance pools list")
self._ip_list = [i for i in self._workspace_client.instance_pools.list()]
logger.info(f"Total pools: {len(self._ip_list)}")
logger.info("Getting warehouses list")
self._wh_list = [i for i in self._workspace_client.warehouses.list()]
logger.info(f"Total warehouses: {len(self._wh_list)}")
logger.info("Getting jobs list")
self._job_list = [
i for i in self._workspace_client.jobs.list(expand_tasks=True)
]
logger.info(f"Total jobs: {len(self._job_list)}")
logger.info("Getting users list")
self._user_list = [i for i in self._workspace_client.users.list()]
logger.info(f"Total users: {len(self._user_list)}")
# TODO(KB): refactor into multiple files by clusters, wh, etc
self._logger.info("Getting cluster list")
self._cluster_list = self._get_full_cluster_list()
self._logger.info(f"Total clusters: {len(self._cluster_list)}")

self._logger.info("Getting instance pools list")
self._ip_list = self._get_full_instance_pool_info()
self._logger.info(f"Total pools: {len(self._ip_list)}")

self._logger.info("Getting warehouses list")
self._wh_list = self._get_full_warehouse_info()
self._logger.info(f"Total warehouses: {len(self._wh_list)}")

logger.info("Starting data pull")
self._logger.info("Getting jobs list")
self._job_list = self._get_full_jobs_info()
self._logger.info(f"Total jobs: {len(self._job_list)}")

self._logger.info("Getting users list")
self._user_list = self._get_full_user_info()
self._logger.info(f"Total users: {len(self._user_list)}")

self._logger.info("Starting data pull")
results = self.get_all()
success = True
for res in results:
@@ -78,7 +91,127 @@ def __init__(
break
status = "success" if success else "failed"
self._add_status_entry("overall", status=status, data={"results": results})
logger.info("Completed data pull.")
self._logger.info("Completed data pull.")

def _get_job_cluster_ids(self) -> set[str]:
job_compute_id_list = self._spark_session.sql(
"select compute_ids from system.workflow.job_task_run_timeline "
f"where period_start_time < from_unixtime({self._start_time//1000}) "
f"and period_start_time >= from_unixtime({self._end_time//1000}) "
)
job_compute_id_list = (
job_compute_id_list.select(
explode(job_compute_id_list.compute_ids).alias("compute_id")
)
.distinct()
.toPandas()["compute_id"]
.values.tolist()
)
return set(i for i in job_compute_id_list if len(i.split("-")) == 3)

def _generic_get_full_list(
self,
name: str,
root_list_getter: Callable,
root_item_getter: Callable,
id_attribute_name: str,
additional_ids: Optional[set] = None,
) -> list:
self._logger.info(f"Getting {name}.")
root_list = [i for i in root_list_getter()]
root_list_ids = set(getattr(i, id_attribute_name) for i in root_list)
self._logger.info(f"Current count: {len(root_list_ids)}")

self._logger.info(f"Adding {name} from customer config.")
config_ids = self._customer_config.get_ids(
entity_type=name,
include_entity="yes",
entity_config_filter={"workspace_id": self._workspace_id},
)
self._logger.info(f"Found {len(config_ids)} items from customer config.")

if additional_ids is not None:
self._logger.info("Adding additional IDs.")
config_ids = config_ids.union(additional_ids)
self._logger.info(f"Total count after additional items: {len(config_ids)}")

new_ids = root_list_ids.union(config_ids) - root_list_ids

for item_id in new_ids:
self._logger.info(f"New {name} ID {item_id} not in list. Getting info.")
try:
root_list.append(root_item_getter(item_id))
except Exception:
self._logger.exception(f"Failed to get {name} ID {item_id}.")

self._logger.info("Removing items from customer config.")
ids_to_remove = self._customer_config.get_ids(
entity_type="cluster",
include_entity="no",
entity_config_filter={"workspace_id": self._workspace_id},
)
self._logger.info(f"Items to be removed: {len(ids_to_remove)}.")

ids_to_remove = root_list_ids.union(config_ids).intersection(ids_to_remove)
self._logger.info(f"Actual items to be removed: {len(ids_to_remove)}.")

root_list = [
i for i in root_list if getattr(i, id_attribute_name) not in ids_to_remove
]
self._logger.info(f"Current count: {len(root_list)}")

return root_list

def _get_full_cluster_list(self) -> list[ClusterDetails]:
self._logger.info("Getting workspace clusters.")

self._logger.info("Getting job cluster IDs.")
job_cluster_ids = self._get_job_cluster_ids()
self._logger.info(f"Total job cluster IDs: {len(job_cluster_ids)}")

return self._generic_get_full_list(
"cluster",
self._workspace_client.clusters.list,
self._workspace_client.clusters.get,
"cluster_id",
additional_ids=job_cluster_ids,
)

def _get_full_instance_pool_info(self) -> list[InstancePoolAndStats]:
self._logger.info("Getting workspace instance pools.")
return self._generic_get_full_list(
"instance_pool",
self._workspace_client.instance_pools.list,
self._workspace_client.instance_pools.get,
"instance_pool_id",
)

def _get_full_warehouse_info(self) -> list[EndpointInfo]:
self._logger.info("Getting workspace warehouses.")
return self._generic_get_full_list(
"warehouse",
self._workspace_client.warehouses.list,
self._workspace_client.warehouses.get,
"id",
)

def _get_full_jobs_info(self) -> list[BaseJob]:
self._logger.info("Getting workspace jobs.")
return self._generic_get_full_list(
"job",
self._workspace_client.jobs.list,
self._workspace_client.jobs.get,
"job_id",
)

def _get_full_user_info(self) -> list[User]:
self._logger.info("Getting workspace users.")
return self._generic_get_full_list(
"user",
self._workspace_client.users.list,
self._workspace_client.users.get,
"id",
)

def _get_start_end_time(self) -> tuple[int, int]:
try:
@@ -216,7 +349,7 @@ def get_clusters_events(
self._save_iterator_in_chunks(
iterator=cluster_events,
metadata={"cluster_id": cluster.cluster_id},
table_name="clusters_events"
table_name="clusters_events",
)
results.append((cluster.cluster_id, True))
except Exception:
@@ -365,7 +498,7 @@ def get_job_runs_list(self) -> bool:
self._save_iterator_in_chunks(
iterator=job_runs,
metadata={"job_id": job.job_id},
table_name="jobs_runs_list"
table_name="jobs_runs_list",
)
job_results.append((job.job_id, True))
except Exception:
(remaining lines of the data_puller.py diff are collapsed)
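
Putting the pieces together, the new customer_config wiring can be exercised roughly as below. This is a hedged sketch: credentials are assumed to be configured for WorkspaceClient(), spark is assumed to come from the Databricks runtime, the workspace ID is made up, and the exported initiate_data_pull handler (whose signature is not shown in this diff) presumably wraps the same steps:

import logging

from databricks.sdk import WorkspaceClient
from databricks.sdk.chaosgenius import CGConfig, DataPuller

logger = logging.getLogger("data_pull_demo")

client = WorkspaceClient()  # reads credentials from the environment/config profile
customer_config = CGConfig(spark, logger)

# The DataPuller constructor drives the whole pull: it lists clusters, instance
# pools, warehouses, jobs and users, applies the include/exclude rules from the
# customer config table, then fetches and persists the data.
DataPuller(
    workspace_id="1234567890",
    workspace_client=client,
    customer_config=customer_config,
    spark_session=spark,
    save_to_csv=False,
    logger=logger,
)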