diff --git a/openhexa/toolbox/dhis2/api.py b/openhexa/toolbox/dhis2/api.py index 7fcbbc4..fe51405 100644 --- a/openhexa/toolbox/dhis2/api.py +++ b/openhexa/toolbox/dhis2/api.py @@ -25,9 +25,7 @@ class DHIS2Connection(Protocol): class Api: - def __init__(self, connection: DHIS2Connection, cache_dir: Optional[Union[Path, str]] = None): - self.url = self.parse_api_url(connection.url) - + def __init__(self, connection: DHIS2Connection = None, cache_dir: Optional[Union[Path, str]] = None, **kwargs): self.session = requests.Session() adapter = HTTPAdapter( max_retries=Retry( @@ -40,7 +38,15 @@ def __init__(self, connection: DHIS2Connection, cache_dir: Optional[Union[Path, self.session.mount("https://", adapter) self.session.mount("http://", adapter) - self.session = self.authenticate(connection.username, connection.password) + if connection is None and ("url" not in kwargs or "username" not in kwargs or "password" not in kwargs): + raise DHIS2Error("Connection or url, username and password must be provided") + + if connection: + self.url = self.parse_api_url(connection.url) + self.session = self.authenticate(connection.username, connection.password) + else: + self.url = self.parse_api_url(kwargs["url"]) + self.session = self.authenticate(kwargs["username"], kwargs["password"]) self.cache = None if cache_dir: diff --git a/openhexa/toolbox/dhis2/dhis2.py b/openhexa/toolbox/dhis2/dhis2.py index 2bd2a1c..7846b27 100644 --- a/openhexa/toolbox/dhis2/dhis2.py +++ b/openhexa/toolbox/dhis2/dhis2.py @@ -14,20 +14,22 @@ class DHIS2: - def __init__(self, connection: DHIS2Connection, cache_dir: Union[str, Path] = None): + def __init__(self, connection: DHIS2Connection = None, cache_dir: Union[str, Path] = None, **kwargs): """Initialize a new DHIS2 instance. Parameters ---------- - connection : openhexa DHIS2Connection + connection : openhexa DHIS2Connection, optional An initialized openhexa dhis2 connection + kwargs: + Additional arguments to pass to initialize openhexa dhis2 connection, such as `url`, `username`, `password` cache_dir : str, optional Cache directory. Actual cache data will be stored under a sub-directory named after the DHIS2 instance domain. """ if isinstance(cache_dir, str): cache_dir = Path(cache_dir) - self.api = Api(connection, cache_dir) + self.api = Api(connection, cache_dir, **kwargs) self.meta = Metadata(self) self.version = self.meta.system_info().get("version") self.data_value_sets = DataValueSets(self) @@ -69,11 +71,13 @@ def organisation_unit_levels(self) -> List[dict]: ) return levels - def organisation_units(self, filter: str = None) -> List[dict]: + def organisation_units(self, fields: str = "id,name,level,path,geometry", filter: str = None) -> List[dict]: """Get organisation units metadata. Parameters ---------- + fields: str, optional + DHIS2 fields to include in the response, where default value is "id,name,level,path,geometry" filter: str, optional DHIS2 query filter @@ -82,7 +86,7 @@ def organisation_units(self, filter: str = None) -> List[dict]: list of dict Id, name, level, path and geometry of all org units. """ - params = {"fields": "id,name,level,path,geometry"} + params = {"fields": fields} if filter: params["filter"] = filter org_units = [] @@ -93,18 +97,22 @@ def organisation_units(self, filter: str = None) -> List[dict]: for ou in page["organisationUnits"]: org_units.append( { - "id": ou.get("id"), - "name": ou.get("name"), - "level": ou.get("level"), - "path": ou.get("path"), - "geometry": json.dumps(ou.get("geometry")) if ou.get("geometry") else None, + key: ou.get(key) + if key != "geometry" + else json.dumps(ou.get("geometry")) + if ou.get("geometry") + else None + for key in fields.split(",") } ) return org_units - def organisation_unit_groups(self) -> List[dict]: + def organisation_unit_groups(self, fields: str = "id,name,organisationUnits") -> List[dict]: """Get organisation unit groups metadata. - + Parameters + ---------- + fields: str, optional + DHIS2 fields to include in the response, where default value is "id,name,organisationUnits" Return ------ list of dict @@ -113,23 +121,29 @@ def organisation_unit_groups(self) -> List[dict]: org_unit_groups = [] for page in self.client.api.get_paged( "organisationUnitGroups", - params={"fields": "id,name,organisationUnits"}, + params={"fields": fields}, ): groups = [] for group in page.get("organisationUnitGroups"): groups.append( { - "id": group.get("id"), - "name": group.get("name"), - "organisation_units": [ou.get("id") for ou in group["organisationUnits"]], + key: group.get(key) + if key != "organisationUnits" + else [ou.get("id") for ou in group["organisationUnits"]] + for key in fields.split(",") } ) org_unit_groups += groups return groups - def datasets(self) -> List[dict]: + def datasets(self, fields: str = "id,name,dataSetElements,indicators,organisationUnits") -> List[dict]: """Get datasets metadata. + Parameters + ---------- + fields: str, optional + DHIS2 fields to include in the response, where default value is + "id,name,dataSetElements,indicators,organisationUnits" Return ------ list of dict @@ -139,23 +153,35 @@ def datasets(self) -> List[dict]: for page in self.client.api.get_paged( "dataSets", params={ - "fields": "id,name,dataSetElements,indicators,organisationUnits", + "fields": fields, "pageSize": 10, }, ): for ds in page["dataSets"]: - row = {"id": ds.get("id"), "name": ds.get("name")} - row["data_elements"] = [dx["dataElement"]["id"] for dx in ds["dataSetElements"]] - row["indicators"] = [indicator["id"] for indicator in ds["indicators"]] - row["organisation_units"] = [ou["id"] for ou in ds["organisationUnits"]] + fields_list = fields.split(",") + row = {} + if "data_elements" in fields_list: + row["data_elements"] = [dx["dataElement"]["id"] for dx in ds["dataSetElements"]] + fields_list.remove("data_elements") + if "indicators" in fields_list: + row["indicators"] = [indicator["id"] for indicator in ds["indicators"]] + fields_list.remove("indicators") + if "organisation_units" in fields_list: + row["organisation_units"] = [ou["id"] for ou in ds["organisationUnits"]] + fields_list.remove("organisation_units") + row.update({key: ds.get(key) for key in fields_list}) datasets.append(row) return datasets - def data_elements(self, filter: str = None) -> List[dict]: + def data_elements( + self, fields: str = "id,name,aggregationType,zeroIsSignificant", filter: str = None + ) -> List[dict]: """Get data elements metadata. Parameters ---------- + fields: str, optional + DHIS2 fields to include in the response, where default value is "id,name,aggregationType,zeroIsSignificant" filter: str, optional DHIS2 query filter @@ -164,7 +190,7 @@ def data_elements(self, filter: str = None) -> List[dict]: list of dict Id, name, and aggregation type of all data elements. """ - params = {"fields": "id,name,aggregationType,zeroIsSignificant"} + params = {"fields": fields} if filter: params["filter"] = filter elements = [] @@ -172,12 +198,16 @@ def data_elements(self, filter: str = None) -> List[dict]: "dataElements", params=params, ): - elements += page["dataElements"] + for element in page["dataElements"]: + elements.append({key: element.get(key) for key in params["fields"].split(",")}) return elements - def data_element_groups(self) -> List[dict]: + def data_element_groups(self, fields: str = "id,name,dataElements") -> List[dict]: """Get data element groups metadata. - + Parameters + ---------- + fields: str, optional + DHIS2 fields to include in the response, where default value is "id,name,dataElements" Return ------ list of dict @@ -186,15 +216,14 @@ def data_element_groups(self) -> List[dict]: de_groups = [] for page in self.client.api.get_paged( "dataElementGroups", - params={"fields": "id,name,dataElements"}, + params={"fields": fields}, ): groups = [] for group in page.get("dataElementGroups"): groups.append( { - "id": group.get("id"), - "name": group.get("name"), - "data_elements": [ou.get("id") for ou in group["dataElements"]], + key: group.get(key) if key != "dataElements" else [de.get("id") for de in group["dataElements"]] + for key in fields.split(",") } ) de_groups += groups @@ -213,7 +242,7 @@ def category_option_combos(self) -> List[dict]: combos += page.get("categoryOptionCombos") return combos - def indicators(self, filter: str = None) -> List[dict]: + def indicators(self, fields: str = "id,name,numerator,denominator", filter: str = None) -> List[dict]: """Get indicators metadata. Parameters @@ -226,7 +255,7 @@ def indicators(self, filter: str = None) -> List[dict]: list of dict Id, name, numerator and denominator of all indicators. """ - params = {"fields": "id,name,numerator,denominator"} + params = {"fields": fields} if filter: params["filter"] = filter indicators = [] @@ -234,10 +263,12 @@ def indicators(self, filter: str = None) -> List[dict]: "indicators", params=params, ): + for indicator in page["indicators"]: + indicators.append({key: indicator.get(key) for key in fields.split(",")}) indicators += page["indicators"] return indicators - def indicator_groups(self) -> List[dict]: + def indicator_groups(self, fields: str = "id,name,indicators") -> List[dict]: """Get indicator groups metadata. Return @@ -248,15 +279,16 @@ def indicator_groups(self) -> List[dict]: ind_groups = [] for page in self.client.api.get_paged( "indicatorGroups", - params={"fields": "id,name,indicators"}, + params={"fields": fields}, ): groups = [] for group in page.get("indicatorGroups"): groups.append( { - "id": group.get("id"), - "name": group.get("name"), - "indicators": [ou.get("id") for ou in group["indicators"]], + key: group.get(key) + if key != "indicators" + else [indicator.get("id") for indicator in group["indicators"]] + for key in fields.split(",") } ) ind_groups += groups diff --git a/tests/dhis2/test_dhis2.py b/tests/dhis2/test_dhis2.py index 165abbb..50e63ce 100644 --- a/tests/dhis2/test_dhis2.py +++ b/tests/dhis2/test_dhis2.py @@ -17,6 +17,33 @@ def con(): return Connection("http://localhost:8080", "admin", "district") +@responses.activate +@pytest.mark.parametrize("version", VERSIONS) +def test_connection_from_object(version, con): + responses_dir = Path("tests", "dhis2", "responses", version) + responses._add_from_file(Path(responses_dir, "dhis2_init.yaml")) + api = DHIS2(con, cache_dir=None) + assert api is not None + + +@responses.activate +@pytest.mark.parametrize("version", VERSIONS) +def test_connection_from_kwargs(version): + responses_dir = Path("tests", "dhis2", "responses", version) + responses._add_from_file(Path(responses_dir, "dhis2_init.yaml")) + api = DHIS2(url="http://localhost:8080", username="admin", password="district", cache_dir=None) + assert api is not None + + +@responses.activate +@pytest.mark.parametrize("version", VERSIONS) +def test_connection_from_kwargs_fails(version): + responses_dir = Path("tests", "dhis2", "responses", version) + responses._add_from_file(Path(responses_dir, "dhis2_init.yaml")) + with pytest.raises(DHIS2Error): + DHIS2(url="http://localhost:8080", cache_dir=None) + + @responses.activate @pytest.mark.parametrize("version", VERSIONS) def test_data_elements(version, con):