Skip to content

Commit

Permalink
Handle ayncio.TimeoutError in OktaService._retry (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
somethingnew2-0 authored May 18, 2024
1 parent 9397c09 commit 8138bed
Showing 1 changed file with 43 additions and 55 deletions.
98 changes: 43 additions & 55 deletions api/services/okta_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import asyncio
import datetime
import logging
from datetime import UTC, datetime
from typing import Any, Callable, Optional

import dateutil.parser
Expand All @@ -17,7 +17,7 @@
REQUEST_MAX_RETRIES = 3
RETRIABLE_STATUS_CODES = [429, 500, 502, 503, 504]
HTTP_TOO_MANY_REQUESTS = 429
RATE_LIMIT_RESET_HEADER = 'X-Rate-Limit-Reset'
RATE_LIMIT_RESET_HEADER = "X-Rate-Limit-Reset"
RETRY_BACKOFF_FACTOR = 0.5
REQUEST_TIMEOUT = 30

Expand All @@ -28,13 +28,16 @@
class OktaService:
"""For interacting with the Okta API"""

def initialize(self, okta_domain: Optional[str], okta_api_token: Optional[str], use_group_owners_api: bool = False) -> None:
def initialize(
self, okta_domain: Optional[str], okta_api_token: Optional[str], use_group_owners_api: bool = False
) -> None:
# Ignore an okta domain and api token when testing
if okta_domain is None or okta_api_token is None:
return
self.okta_domain = okta_domain
self.okta_api_token = okta_api_token
self.okta_client = OktaClient({
self.okta_client = OktaClient(
{
"orgUrl": f"https://{okta_domain}",
"token": okta_api_token,
}
Expand All @@ -45,7 +48,11 @@ def initialize(self, okta_domain: Optional[str], okta_api_token: Optional[str],
async def _retry(func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Any:
"""Retry Okta API requests with specific status codes using exponential backoff."""
for attempt in range(1 + REQUEST_MAX_RETRIES):
result = await asyncio.wait_for(func(*args, **kwargs), timeout=REQUEST_TIMEOUT)
try:
result = await asyncio.wait_for(func(*args, **kwargs), timeout=REQUEST_TIMEOUT)
except asyncio.TimeoutError as e:
logger.warning("Timeout on Okta request. Retrying...")
result = (None, e)

if len(result) == 2:
response, error = result
Expand All @@ -54,27 +61,27 @@ async def _retry(func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Any:
else:
raise Exception("Unexpected result structure from Okta client.")

if (attempt == REQUEST_MAX_RETRIES or
error is None or
response is None or
(response is not None and response.get_status() not in RETRIABLE_STATUS_CODES)):
if (
attempt == REQUEST_MAX_RETRIES
or error is None
or ((response is not None) and (response.get_status() not in RETRIABLE_STATUS_CODES))
):
return result

if response is None:
logger.warning('Got None response from Okta resource. Retrying...')
else:
logger.warning(f'Got {response.get_status()} response from Okta resource {response._url}, with error:'
f' {error}. Retrying...'
if response is not None:
logger.warning(
f"Got {response.get_status()} response from Okta resource {response._url}, with error:"
f" {error}. Retrying..."
)

# If rate limit is hit, then wait until the "X-Rate-Limit-Reset" time, else backoff exponentially
if (response.get_status() == HTTP_TOO_MANY_REQUESTS):
logger.warning('Rate limit hit, waiting until reset...')
current_time = datetime.datetime.now(datetime.timezone.utc).timestamp()
if (response is not None) and (response.get_status() == HTTP_TOO_MANY_REQUESTS):
logger.warning("Rate limit hit, waiting until reset...")
current_time = datetime.now(UTC).timestamp()
rate_limit_reset = float(response.headers[RATE_LIMIT_RESET_HEADER])
wait_time = max(rate_limit_reset - current_time, 1) # Ensure wait_time is at least 1 second
else:
wait_time = RETRY_BACKOFF_FACTOR * (2 ** attempt)
wait_time = RETRY_BACKOFF_FACTOR * (2**attempt)
await asyncio.sleep(wait_time)

def get_user(self, userId: str) -> User:
Expand Down Expand Up @@ -121,8 +128,8 @@ async def _list_users() -> list[User]:
def create_group(self, name: str, description: str) -> Group:
group, _, error = asyncio.run(
OktaService._retry(
self.okta_client.create_group,
OktaGroupType({"profile": {"name": name, "description": description}}))
self.okta_client.create_group, OktaGroupType({"profile": {"name": name, "description": description}})
)
)
if error is not None:
raise Exception(error)
Expand Down Expand Up @@ -206,7 +213,8 @@ def get_group(self, groupId: str) -> Group:

return Group(group)

DEFAULT_QUERY_PARAMS = {'filter': 'type eq "BUILT_IN" or type eq "OKTA_GROUP"'}
DEFAULT_QUERY_PARAMS = {"filter": 'type eq "BUILT_IN" or type eq "OKTA_GROUP"'}

def list_groups(self, *, query_params: dict[str, str] = DEFAULT_QUERY_PARAMS) -> list[Group]:
async def _list_groups(query_params: dict[str, str]) -> list[Group]:
groups, resp, error = await OktaService._retry(self.okta_client.list_groups, query_params=query_params)
Expand All @@ -224,18 +232,17 @@ async def _list_groups(query_params: dict[str, str]) -> list[Group]:

def list_groups_with_active_rules(self) -> dict[str, list[OktaGroupRuleType]]:
group_rules = self.list_group_rules()
group_ids_with_group_rules = {} # type: dict[str, list[OktaGroupRuleType]]
group_ids_with_group_rules = {} # type: dict[str, list[OktaGroupRuleType]]
for group_rule in group_rules:
if group_rule.status == "ACTIVE":
for id in group_rule.actions.assign_user_to_groups.group_ids:
group_ids_with_group_rules.setdefault(id, []).append(group_rule)
return group_ids_with_group_rules

def list_group_rules(self, *, query_params: dict[str,str]={}) -> list[OktaGroupRuleType]:
def list_group_rules(self, *, query_params: dict[str, str] = {}) -> list[OktaGroupRuleType]:
async def _list_group_rules(query_params: dict[str, str]) -> list[OktaGroupRuleType]:
group_rules, resp, error = await OktaService._retry(
self.okta_client.list_group_rules,
query_params=query_params
self.okta_client.list_group_rules, query_params=query_params
)

if error is not None:
Expand Down Expand Up @@ -308,9 +315,7 @@ async def async_add_owner_to_group(self, groupId: str, userId: str) -> None:
_, error = await OktaService._retry(request_executor.execute, request)

# Ignore error if owner is already assigned to group
if error is not None and not error.message.endswith(
"Provided owner is already assigned to this group"
):
if error is not None and not error.message.endswith("Provided owner is already assigned to this group"):
raise Exception(error)

return
Expand All @@ -333,9 +338,7 @@ async def async_remove_owner_from_group(self, groupId: str, userId: str) -> None

request, error = await request_executor.create_request(
method="DELETE",
url="/api/v1/groups/{groupId}/owners/{userId}".format(
groupId=groupId, userId=userId
),
url="/api/v1/groups/{groupId}/owners/{userId}".format(groupId=groupId, userId=userId),
body={},
headers={},
oauth=False,
Expand Down Expand Up @@ -402,9 +405,7 @@ def update_okta_user(self, okta_user: OktaUser, user_attrs_to_titles: dict[str,
okta_user.created_at = dateutil.parser.isoparse(self.user.created)
if okta_user.updated_at is None:
okta_user.updated_at = (
dateutil.parser.isoparse(self.user.last_updated)
if self.user.last_updated is not None
else None
dateutil.parser.isoparse(self.user.last_updated) if self.user.last_updated is not None else None
)
okta_user.deleted_at = self.get_deleted_at()
okta_user.email = self.user.profile.login
Expand All @@ -416,14 +417,9 @@ def update_okta_user(self, okta_user: OktaUser, user_attrs_to_titles: dict[str,
return okta_user

def _convert_profile_keys_to_titles(self, user_attrs_to_titles: dict[str, str]) -> dict[str, str]:
return dict(
(
(user_attrs_to_titles.get(k, k), v)
for (k, v) in self.user.profile.__dict__.items()
)
)
return dict(((user_attrs_to_titles.get(k, k), v) for (k, v) in self.user.profile.__dict__.items()))

def get_deleted_at(self) -> Optional[datetime.datetime]:
def get_deleted_at(self) -> Optional[datetime]:
return (
dateutil.parser.isoparse(self.user.status_changed)
if self.user.status in ("SUSPENDED", "DEPROVISIONED")
Expand Down Expand Up @@ -472,35 +468,27 @@ def __getattr__(self, name: str) -> Any:
return getattr(self.group, name)

def update_okta_group(
self,
okta_group: OktaGroup,
group_ids_with_group_rules: dict[str, list[OktaGroupRuleType]]
self, okta_group: OktaGroup, group_ids_with_group_rules: dict[str, list[OktaGroupRuleType]]
) -> OktaGroup:
if okta_group.id is None:
okta_group.id = self.group.id
if okta_group.created_at is None:
okta_group.created_at = dateutil.parser.isoparse(self.group.created)
if okta_group.updated_at is None:
okta_group.updated_at = (
dateutil.parser.isoparse(self.group.last_updated)
if self.group.last_updated is not None
else None
dateutil.parser.isoparse(self.group.last_updated) if self.group.last_updated is not None else None
)

okta_group.name = self.group.profile.name
okta_group.description = (
self.group.profile.description
if self.group.profile.description is not None
else ""
)
okta_group.description = self.group.profile.description if self.group.profile.description is not None else ""

okta_group.is_managed = is_managed_group(self, group_ids_with_group_rules)

# Get externally managed group data
if self.group.id in group_ids_with_group_rules:
okta_group.externally_managed_data = {rule.name:rule.conditions.expression.value
for rule in group_ids_with_group_rules[self.group.id]}

okta_group.externally_managed_data = {
rule.name: rule.conditions.expression.value for rule in group_ids_with_group_rules[self.group.id]
}

return okta_group

Expand Down

0 comments on commit 8138bed

Please sign in to comment.