Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
lloesche committed Sep 19, 2024
1 parent 1dd5a2e commit 73f64d7
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 59 deletions.
1 change: 0 additions & 1 deletion fixattiosync/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from .fixdata import FixData, add_args as fixdata_add_args
from .attiodata import AttioData, add_args as attio_add_args
from .sync import sync_fix_to_attio
from pprint import pprint


def main() -> None:
Expand Down
37 changes: 21 additions & 16 deletions fixattiosync/attiodata.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
import os
import requests
from uuid import UUID
from typing import Union
from typing import Union, Any, Optional
from argparse import ArgumentParser
from .logger import log
from .attioresources import AttioWorkspace, AttioPerson, AttioUser


class AttioData:
def __init__(self, api_key, default_limit=500):
def __init__(self, api_key: str, default_limit: int = 500):
self.api_key = api_key
self.base_url = "https://api.attio.com/v2/"
self.default_limit = default_limit
self.hydrated = False
self.__workspaces = {}
self.__people = {}
self.__users = {}
self.__workspaces: dict[UUID, AttioWorkspace] = {}
self.__people: dict[UUID, AttioPerson] = {}
self.__users: dict[UUID, AttioUser] = {}

def _headers(self):
def _headers(self) -> dict[str, str]:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
}

def _post_data(self, endpoint, json=None, params=None):
def _post_data(
self, endpoint: str, json: Optional[dict[str, Any]] = None, params: Optional[dict[str, str]] = None
) -> dict[str, Any]:
log.debug(f"Fetching data from {endpoint}")
url = self.base_url + endpoint
response = requests.post(url, headers=self._headers(), json=json, params=params)
Expand All @@ -33,7 +35,9 @@ def _post_data(self, endpoint, json=None, params=None):
else:
raise Exception(f"Error fetching data from {url}: {response.status_code} {response.text}")

def _put_data(self, endpoint: str, json: dict = None, params: dict = None):
def _put_data(
self, endpoint: str, json: Optional[dict[str, Any]] = None, params: Optional[dict[str, str]] = None
) -> dict[str, Any]:
log.debug(f"Putting data to {endpoint}")
url = self.base_url + endpoint
response = requests.put(url, headers=self._headers(), json=json, params=params)
Expand All @@ -43,10 +47,11 @@ def _put_data(self, endpoint: str, json: dict = None, params: dict = None):
raise Exception(f"Error putting data to {url}: {response.status_code} {response.text}")

def assert_record(
self, object_id: str, matching_attribute: str, data: dict
self, object_id: str, matching_attribute: str, data: dict[str, Any]
) -> Union[AttioPerson, AttioUser, AttioWorkspace]:
endpoint = f"objects/{object_id}/records"
params = {"matching_attribute": matching_attribute}
attio_cls: Union[type[AttioPerson], type[AttioUser], type[AttioWorkspace]]
match object_id:
case "users":
attio_cls = AttioUser
Expand All @@ -70,7 +75,7 @@ def assert_record(
else:
raise RuntimeError(f"Error asserting {object_id} in Attio: {response}")

def _records(self, object_id: str):
def _records(self, object_id: str) -> list[dict[str, Any]]:
log.debug(f"Fetching {object_id}")
endpoint = f"objects/{object_id}/records/query"
all_data = []
Expand All @@ -90,32 +95,32 @@ def _records(self, object_id: str):
return all_data

@property
def workspaces(self):
def workspaces(self) -> list[AttioWorkspace]:
if not self.hydrated:
self.hydrate()
return list(self.__workspaces.values())

@property
def people(self):
def people(self) -> list[AttioPerson]:
if not self.hydrated:
self.hydrate()
return list(self.__people.values())

@property
def users(self):
def users(self) -> list[AttioUser]:
if not self.hydrated:
self.hydrate()
return list(self.__users.values())

def hydrate(self):
def hydrate(self) -> None:
log.debug("Hydrating Attio data")
self.__workspaces = self.__marshal(self._records("workspaces"), AttioWorkspace)
self.__people = self.__marshal(self._records("people"), AttioPerson)
self.__users = self.__marshal(self._records("users"), AttioUser)
self.__connect()
self.hydrated = True

def __connect(self):
def __connect(self) -> None:
for user in self.__users.values():
if user.person_id in self.__people:
person = self.__people[user.person_id]
Expand All @@ -129,7 +134,7 @@ def __connect(self):
user.workspaces.append(workspace)

def __marshal(
self, data: dict, cls: Union[AttioWorkspace, AttioPerson, AttioUser]
self, data: list[dict[str, Any]], cls: Union[type[AttioWorkspace], type[AttioPerson], type[AttioUser]]
) -> dict[UUID, Union[AttioWorkspace, AttioPerson, AttioUser]]:
ret = {}
for item in data:
Expand Down
34 changes: 19 additions & 15 deletions fixattiosync/attioresources.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from dataclasses import dataclass, field
from datetime import datetime
from uuid import UUID
from typing import Optional, Self, Type, ClassVar
from typing import Optional, Self, Type, ClassVar, Any
from .logger import log


def get_latest_value(value: list[dict]) -> dict:
def get_latest_value(value: list[dict[str, Any]]) -> dict[str, Any]:
if value and len(value) > 0:
return value[0]
return {}
Expand All @@ -33,7 +33,7 @@ class AttioResource(ABC):

@classmethod
@abstractmethod
def make(cls: Type[Self], data: dict) -> Self:
def make(cls: Type[Self], data: dict[str, Any]) -> Self:
pass


Expand All @@ -48,11 +48,13 @@ class AttioWorkspace(AttioResource):
fix_workspace_id: Optional[UUID]
users: list[AttioUser] = field(default_factory=list)

def __eq__(self, other):
return self.id == other.id and self.tier == other.tier
def __eq__(self: Self, other: Any) -> bool:
if not hasattr(other, "id") or not hasattr(other, "tier"):
return False
return bool(self.id == other.id and self.tier == other.tier)

@classmethod
def make(cls: Type[Self], data: dict) -> Self:
def make(cls: Type[Self], data: dict[str, Any]) -> Self:
object_id = UUID(data["id"]["object_id"])
record_id = UUID(data["id"]["record_id"])
workspace_id = UUID(data["id"]["workspace_id"])
Expand All @@ -70,7 +72,7 @@ def make(cls: Type[Self], data: dict) -> Self:
status = status_info.get("status", {}).get("title")

fix_workspace_id_info = get_latest_value(values.get("workspace_id", [{}]))
fix_workspace_id = optional_uuid(fix_workspace_id_info.get("value"))
fix_workspace_id = optional_uuid(str(fix_workspace_id_info.get("value")))
if fix_workspace_id is None:
log.error(f"Fix workspace ID not found for {record_id}: {data}")

Expand Down Expand Up @@ -102,7 +104,7 @@ class AttioPerson(AttioResource):
users: list[AttioUser] = field(default_factory=list)

@classmethod
def make(cls: Type[Self], data: dict) -> Self:
def make(cls: Type[Self], data: dict[str, Any]) -> Self:
object_id = UUID(data["id"]["object_id"])
record_id = UUID(data["id"]["record_id"])
workspace_id = UUID(data["id"]["workspace_id"])
Expand Down Expand Up @@ -155,11 +157,13 @@ class AttioUser(AttioResource):
person: Optional[AttioPerson] = None
workspaces: list[AttioWorkspace] = field(default_factory=list)

def __eq__(self, other):
return self.id == other.id and self.email.lower() == other.email.lower()
def __eq__(self: Self, other: Any) -> bool:
if not hasattr(other, "id") or not hasattr(other, "email"):
return False
return bool(self.id == other.id and str(self.email).lower() == str(other.email).lower())

@classmethod
def make(cls: Type[Self], data: dict) -> Self:
def make(cls: Type[Self], data: dict[str, Any]) -> Self:
object_id = UUID(data["id"]["object_id"])
record_id = UUID(data["id"]["record_id"])
workspace_id = UUID(data["id"]["workspace_id"])
Expand All @@ -179,21 +183,21 @@ def make(cls: Type[Self], data: dict) -> Self:
log.error(f"Fix user ID not found for {record_id}: {data}")

person_info = get_latest_value(values.get("person", [{}]))
person_id = optional_uuid(person_info.get("target_record_id"))
person_id = optional_uuid(str(person_info.get("target_record_id")))

workspace_refs = None
workspace_info = values.get("workspace", [])
for workspace in workspace_info:
workspace_ref = optional_uuid(workspace.get("target_record_id"))
workspace_ref = optional_uuid(str(workspace.get("target_record_id")))
if workspace_refs is None:
workspace_refs = []
workspace_refs.append(workspace_ref)

cls_data = {
"id": user_id,
"object_id": object_id,
"record_id": record_id,
"workspace_id": workspace_id,
"id": user_id,
"created_at": created_at,
"demo_workspace_viewed": None,
"email": primary_email_address,
Expand All @@ -206,7 +210,7 @@ def make(cls: Type[Self], data: dict) -> Self:

return cls(**cls_data)

def create_or_update(self) -> tuple[str, dict]:
def create_or_update(self) -> tuple[str, dict[str, Any]]:
data = {
"values": {
"primary_email_address": [{"email_address": self.email}],
Expand Down
16 changes: 9 additions & 7 deletions fixattiosync/fixdata.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import os
import psycopg
from psycopg.rows import dict_row
from uuid import UUID
from argparse import ArgumentParser
from .logger import log
from .fixresources import FixUser, FixWorkspace
from typing import Optional


class FixData:
def __init__(self, db, user, password, host="localhost", port=5432):
def __init__(self, db: str, user: str, password: str, host: str = "localhost", port: int = 5432) -> None:
self.db = db
self.user = user
self.password = password
self.host = host
self.port = port
self.conn = None
self.conn: Optional[psycopg.Connection] = None
self.hydrated = False
self.__workspaces = {}
self.__users = {}
self.__workspaces: dict[UUID, FixWorkspace] = {}
self.__users: dict[UUID, FixUser] = {}

@property
def users(self) -> list[FixUser]:
Expand All @@ -30,7 +32,7 @@ def workspaces(self) -> list[FixWorkspace]:
self.hydrate()
return list(self.__workspaces.values())

def connect(self):
def connect(self) -> None:
log.debug("Connecting to the database")
if self.conn is None:
try:
Expand All @@ -42,7 +44,7 @@ def connect(self):
log.error(f"Error connecting to the database: {e}")
self.conn = None

def hydrate(self):
def hydrate(self) -> None:
if self.conn is None:
self.connect()

Expand Down Expand Up @@ -81,7 +83,7 @@ def hydrate(self):
log.debug(f"Found {len(self.__users)} users in database")
self.hydrated = True

def close(self):
def close(self) -> None:
if self.conn is not None:
log.debug("Closing database connection")
self.conn.close()
Expand Down
26 changes: 15 additions & 11 deletions fixattiosync/fixresources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from uuid import UUID
from typing import Optional
from typing import Optional, Self, Any
from .attioresources import AttioPerson, AttioWorkspace


Expand All @@ -20,15 +20,17 @@ class FixUser:
updated_at: datetime
workspaces: list[FixWorkspace] = field(default_factory=list)

def __eq__(self, other):
return self.id == other.id and self.email.lower() == other.email.lower()
def __eq__(self: Self, other: Any) -> bool:
if not hasattr(other, "id") or not hasattr(other, "email"):
return False
return bool(self.id == other.id and str(self.email).lower() == str(other.email).lower())

def attio_data(
self, person: Optional[AttioPerson] = None, workspaces: Optional[list[AttioWorkspace]] = None
) -> dict:
) -> dict[str, Any]:
object_id = "users"
matching_attribute = "user_id"
data = {
data: dict[str, Any] = {
"data": {
"values": {
"user_id": str(self.id),
Expand Down Expand Up @@ -58,10 +60,10 @@ def attio_data(
"data": data,
}

def attio_person(self) -> dict:
def attio_person(self) -> dict[str, Any]:
object_id = "people"
matching_attribute = "email_addresses"
data = {"data": {"values": {"email_addresses": [{"email_address": self.email}]}}}
data: dict[str, Any] = {"data": {"values": {"email_addresses": [{"email_address": self.email}]}}}
return {
"object_id": object_id,
"matching_attribute": matching_attribute,
Expand All @@ -87,13 +89,15 @@ class FixWorkspace:
owner: Optional[FixUser] = None
users: list[FixUser] = field(default_factory=list)

def __eq__(self, other):
return self.id == other.id and self.name == other.name and self.tier == other.tier
def __eq__(self: Self, other: Any) -> bool:
if not hasattr(other, "id") or not hasattr(other, "name") or not hasattr(other, "tier"):
return False
return bool(self.id == other.id and self.name == other.name and self.tier == other.tier)

def attio_data(self) -> dict:
def attio_data(self) -> dict[str, Any]:
object_id = "workspaces"
matching_attribute = "workspace_id"
data = {
data: dict[str, Any] = {
"data": {
"values": {
"workspace_id": str(self.id),
Expand Down
Loading

0 comments on commit 73f64d7

Please sign in to comment.