Skip to content

Commit

Permalink
feat HEXA-1149 dhis2 client improvements (#75)
Browse files Browse the repository at this point in the history
* feat: update client, orgunits, datasets

* feat: extends all calls with feilds

* fixes test, adds exception

* refactors api

* add test for connection initialisation
  • Loading branch information
nazarfil authored Dec 26, 2024
1 parent cf6b793 commit 4be0046
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 42 deletions.
14 changes: 10 additions & 4 deletions openhexa/toolbox/dhis2/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ class DHIS2Connection(Protocol):


class Api:
def __init__(self, connection: DHIS2Connection, cache_dir: Optional[Union[Path, str]] = None):
self.url = self.parse_api_url(connection.url)

def __init__(self, connection: DHIS2Connection = None, cache_dir: Optional[Union[Path, str]] = None, **kwargs):
self.session = requests.Session()
adapter = HTTPAdapter(
max_retries=Retry(
Expand All @@ -40,7 +38,15 @@ def __init__(self, connection: DHIS2Connection, cache_dir: Optional[Union[Path,
self.session.mount("https://", adapter)
self.session.mount("http://", adapter)

self.session = self.authenticate(connection.username, connection.password)
if connection is None and ("url" not in kwargs or "username" not in kwargs or "password" not in kwargs):
raise DHIS2Error("Connection or url, username and password must be provided")

if connection:
self.url = self.parse_api_url(connection.url)
self.session = self.authenticate(connection.username, connection.password)
else:
self.url = self.parse_api_url(kwargs["url"])
self.session = self.authenticate(kwargs["username"], kwargs["password"])

self.cache = None
if cache_dir:
Expand Down
108 changes: 70 additions & 38 deletions openhexa/toolbox/dhis2/dhis2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,22 @@


class DHIS2:
def __init__(self, connection: DHIS2Connection, cache_dir: Union[str, Path] = None):
def __init__(self, connection: DHIS2Connection = None, cache_dir: Union[str, Path] = None, **kwargs):
"""Initialize a new DHIS2 instance.
Parameters
----------
connection : openhexa DHIS2Connection
connection : openhexa DHIS2Connection, optional
An initialized openhexa dhis2 connection
kwargs:
Additional arguments to pass to initialize openhexa dhis2 connection, such as `url`, `username`, `password`
cache_dir : str, optional
Cache directory. Actual cache data will be stored under a sub-directory
named after the DHIS2 instance domain.
"""
if isinstance(cache_dir, str):
cache_dir = Path(cache_dir)
self.api = Api(connection, cache_dir)
self.api = Api(connection, cache_dir, **kwargs)
self.meta = Metadata(self)
self.version = self.meta.system_info().get("version")
self.data_value_sets = DataValueSets(self)
Expand Down Expand Up @@ -69,11 +71,13 @@ def organisation_unit_levels(self) -> List[dict]:
)
return levels

def organisation_units(self, filter: str = None) -> List[dict]:
def organisation_units(self, fields: str = "id,name,level,path,geometry", filter: str = None) -> List[dict]:
"""Get organisation units metadata.
Parameters
----------
fields: str, optional
DHIS2 fields to include in the response, where default value is "id,name,level,path,geometry"
filter: str, optional
DHIS2 query filter
Expand All @@ -82,7 +86,7 @@ def organisation_units(self, filter: str = None) -> List[dict]:
list of dict
Id, name, level, path and geometry of all org units.
"""
params = {"fields": "id,name,level,path,geometry"}
params = {"fields": fields}
if filter:
params["filter"] = filter
org_units = []
Expand All @@ -93,18 +97,22 @@ def organisation_units(self, filter: str = None) -> List[dict]:
for ou in page["organisationUnits"]:
org_units.append(
{
"id": ou.get("id"),
"name": ou.get("name"),
"level": ou.get("level"),
"path": ou.get("path"),
"geometry": json.dumps(ou.get("geometry")) if ou.get("geometry") else None,
key: ou.get(key)
if key != "geometry"
else json.dumps(ou.get("geometry"))
if ou.get("geometry")
else None
for key in fields.split(",")
}
)
return org_units

def organisation_unit_groups(self) -> List[dict]:
def organisation_unit_groups(self, fields: str = "id,name,organisationUnits") -> List[dict]:
"""Get organisation unit groups metadata.
Parameters
----------
fields: str, optional
DHIS2 fields to include in the response, where default value is "id,name,organisationUnits"
Return
------
list of dict
Expand All @@ -113,23 +121,29 @@ def organisation_unit_groups(self) -> List[dict]:
org_unit_groups = []
for page in self.client.api.get_paged(
"organisationUnitGroups",
params={"fields": "id,name,organisationUnits"},
params={"fields": fields},
):
groups = []
for group in page.get("organisationUnitGroups"):
groups.append(
{
"id": group.get("id"),
"name": group.get("name"),
"organisation_units": [ou.get("id") for ou in group["organisationUnits"]],
key: group.get(key)
if key != "organisationUnits"
else [ou.get("id") for ou in group["organisationUnits"]]
for key in fields.split(",")
}
)
org_unit_groups += groups
return groups

def datasets(self) -> List[dict]:
def datasets(self, fields: str = "id,name,dataSetElements,indicators,organisationUnits") -> List[dict]:
"""Get datasets metadata.
Parameters
----------
fields: str, optional
DHIS2 fields to include in the response, where default value is
"id,name,dataSetElements,indicators,organisationUnits"
Return
------
list of dict
Expand All @@ -139,23 +153,35 @@ def datasets(self) -> List[dict]:
for page in self.client.api.get_paged(
"dataSets",
params={
"fields": "id,name,dataSetElements,indicators,organisationUnits",
"fields": fields,
"pageSize": 10,
},
):
for ds in page["dataSets"]:
row = {"id": ds.get("id"), "name": ds.get("name")}
row["data_elements"] = [dx["dataElement"]["id"] for dx in ds["dataSetElements"]]
row["indicators"] = [indicator["id"] for indicator in ds["indicators"]]
row["organisation_units"] = [ou["id"] for ou in ds["organisationUnits"]]
fields_list = fields.split(",")
row = {}
if "data_elements" in fields_list:
row["data_elements"] = [dx["dataElement"]["id"] for dx in ds["dataSetElements"]]
fields_list.remove("data_elements")
if "indicators" in fields_list:
row["indicators"] = [indicator["id"] for indicator in ds["indicators"]]
fields_list.remove("indicators")
if "organisation_units" in fields_list:
row["organisation_units"] = [ou["id"] for ou in ds["organisationUnits"]]
fields_list.remove("organisation_units")
row.update({key: ds.get(key) for key in fields_list})
datasets.append(row)
return datasets

def data_elements(self, filter: str = None) -> List[dict]:
def data_elements(
self, fields: str = "id,name,aggregationType,zeroIsSignificant", filter: str = None
) -> List[dict]:
"""Get data elements metadata.
Parameters
----------
fields: str, optional
DHIS2 fields to include in the response, where default value is "id,name,aggregationType,zeroIsSignificant"
filter: str, optional
DHIS2 query filter
Expand All @@ -164,20 +190,24 @@ def data_elements(self, filter: str = None) -> List[dict]:
list of dict
Id, name, and aggregation type of all data elements.
"""
params = {"fields": "id,name,aggregationType,zeroIsSignificant"}
params = {"fields": fields}
if filter:
params["filter"] = filter
elements = []
for page in self.client.api.get_paged(
"dataElements",
params=params,
):
elements += page["dataElements"]
for element in page["dataElements"]:
elements.append({key: element.get(key) for key in params["fields"].split(",")})
return elements

def data_element_groups(self) -> List[dict]:
def data_element_groups(self, fields: str = "id,name,dataElements") -> List[dict]:
"""Get data element groups metadata.
Parameters
----------
fields: str, optional
DHIS2 fields to include in the response, where default value is "id,name,dataElements"
Return
------
list of dict
Expand All @@ -186,15 +216,14 @@ def data_element_groups(self) -> List[dict]:
de_groups = []
for page in self.client.api.get_paged(
"dataElementGroups",
params={"fields": "id,name,dataElements"},
params={"fields": fields},
):
groups = []
for group in page.get("dataElementGroups"):
groups.append(
{
"id": group.get("id"),
"name": group.get("name"),
"data_elements": [ou.get("id") for ou in group["dataElements"]],
key: group.get(key) if key != "dataElements" else [de.get("id") for de in group["dataElements"]]
for key in fields.split(",")
}
)
de_groups += groups
Expand All @@ -213,7 +242,7 @@ def category_option_combos(self) -> List[dict]:
combos += page.get("categoryOptionCombos")
return combos

def indicators(self, filter: str = None) -> List[dict]:
def indicators(self, fields: str = "id,name,numerator,denominator", filter: str = None) -> List[dict]:
"""Get indicators metadata.
Parameters
Expand All @@ -226,18 +255,20 @@ def indicators(self, filter: str = None) -> List[dict]:
list of dict
Id, name, numerator and denominator of all indicators.
"""
params = {"fields": "id,name,numerator,denominator"}
params = {"fields": fields}
if filter:
params["filter"] = filter
indicators = []
for page in self.client.api.get_paged(
"indicators",
params=params,
):
for indicator in page["indicators"]:
indicators.append({key: indicator.get(key) for key in fields.split(",")})
indicators += page["indicators"]
return indicators

def indicator_groups(self) -> List[dict]:
def indicator_groups(self, fields: str = "id,name,indicators") -> List[dict]:
"""Get indicator groups metadata.
Return
Expand All @@ -248,15 +279,16 @@ def indicator_groups(self) -> List[dict]:
ind_groups = []
for page in self.client.api.get_paged(
"indicatorGroups",
params={"fields": "id,name,indicators"},
params={"fields": fields},
):
groups = []
for group in page.get("indicatorGroups"):
groups.append(
{
"id": group.get("id"),
"name": group.get("name"),
"indicators": [ou.get("id") for ou in group["indicators"]],
key: group.get(key)
if key != "indicators"
else [indicator.get("id") for indicator in group["indicators"]]
for key in fields.split(",")
}
)
ind_groups += groups
Expand Down
27 changes: 27 additions & 0 deletions tests/dhis2/test_dhis2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,33 @@ def con():
return Connection("http://localhost:8080", "admin", "district")


@responses.activate
@pytest.mark.parametrize("version", VERSIONS)
def test_connection_from_object(version, con):
responses_dir = Path("tests", "dhis2", "responses", version)
responses._add_from_file(Path(responses_dir, "dhis2_init.yaml"))
api = DHIS2(con, cache_dir=None)
assert api is not None


@responses.activate
@pytest.mark.parametrize("version", VERSIONS)
def test_connection_from_kwargs(version):
responses_dir = Path("tests", "dhis2", "responses", version)
responses._add_from_file(Path(responses_dir, "dhis2_init.yaml"))
api = DHIS2(url="http://localhost:8080", username="admin", password="district", cache_dir=None)
assert api is not None


@responses.activate
@pytest.mark.parametrize("version", VERSIONS)
def test_connection_from_kwargs_fails(version):
responses_dir = Path("tests", "dhis2", "responses", version)
responses._add_from_file(Path(responses_dir, "dhis2_init.yaml"))
with pytest.raises(DHIS2Error):
DHIS2(url="http://localhost:8080", cache_dir=None)


@responses.activate
@pytest.mark.parametrize("version", VERSIONS)
def test_data_elements(version, con):
Expand Down

0 comments on commit 4be0046

Please sign in to comment.