Skip to content

Commit

Permalink
feat: add set_labels method (#21)
Browse files Browse the repository at this point in the history
* feat: add set_labels method

* feat: lazy loading queries

* feat: finally impl set_page_labels_by_labels

* feat: Clean up labels before running tests

* fix: Refactor test_omnivoreql.py for better readability and maintainability
  • Loading branch information
yazdipour authored Jun 18, 2024
1 parent 074129d commit a9a5cec
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 58 deletions.
110 changes: 81 additions & 29 deletions omnivoreql/omnivoreql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import uuid
import os
import glob
from typing import List, Optional
from gql.transport.requests import RequestsHTTPTransport
from gql import gql, Client
Expand All @@ -26,38 +25,35 @@ def __init__(
use_json=True,
)
self.client = Client(transport=transport, fetch_schema_from_transport=False)
self.queries = self._load_queries("queries")

def _load_queries(self, queries_path: str) -> dict:
current_dir = os.path.dirname(os.path.abspath(__file__))
queries_path = os.path.join(current_dir, queries_path)
queries = {}
for file in glob.glob(f"{queries_path}/*.graphql"):
with open(file, "r") as f:
queries[os.path.basename(file).replace(".graphql", "")] = (
f.read().replace("\n", " ")
)
return queries
self.queries = {}

def _get_query(self, query_name: str) -> str:
if query_name not in self.queries:
current_dir = os.path.dirname(os.path.abspath(__file__))
query_file_path = os.path.join(current_dir, f"queries/{query_name}.graphql")
with open(query_file_path, "r") as file:
self.queries[query_name] = gql(file.read())
return self.queries[query_name]

def save_url(
self,
url: str,
labels: Optional[List[str]] = None,
clientRequestId: str = str(uuid.uuid4()),
client_request_id: str = str(uuid.uuid4()),
):
"""
Save a URL to Omnivore.
:param url: The URL to save.
:param labels: The labels to assign to the item.
:param clientRequestId: The client request ID.
:param client_request_id: The client request ID.
"""
labels = [] if labels is None else [{"name": x} for x in labels]
return self.client.execute(
gql(self.queries["SaveUrl"]),
self._get_query("SaveUrl"),
variable_values={
"input": {
"clientRequestId": clientRequestId,
"clientRequestId": client_request_id,
"source": "api",
"url": url,
"labels": labels,
Expand All @@ -75,7 +71,7 @@ def save_page(self, url: str, original_content: str, labels: List[str] = None):
"""
labels = [] if labels is None else [{"name": x} for x in labels]
return self.client.execute(
gql(self.queries["SavePage"]),
self._get_query("SavePage"),
variable_values={
"input": {
"clientRequestId": str(uuid.uuid4()),
Expand All @@ -91,19 +87,19 @@ def get_profile(self):
"""
Get the profile of the current user.
"""
return self.client.execute(gql(self.queries["Viewer"]))
return self.client.execute(self._get_query("Viewer"))

def get_labels(self):
"""
Get the labels of the current user.
"""
return self.client.execute(gql(self.queries["Labels"]))
return self.client.execute(self._get_query("Labels"))

def get_subscriptions(self):
"""
Get the subscriptions of the current user.
"""
return self.client.execute(gql(self.queries["GetSubscriptions"]))
return self.client.execute(self._get_query("GetSubscriptions"))

def get_articles(
self,
Expand All @@ -123,7 +119,7 @@ def get_articles(
:param include_content: Whether to include the content of the articles.
"""
return self.client.execute(
gql(self.queries["Search"]),
self._get_query("Search"),
variable_values={
"first": limit,
"after": cursor,
Expand All @@ -142,7 +138,7 @@ def get_article(self, username: str, slug: str, format: str = None):
:param format: The format of the article to return.
"""
return self.client.execute(
gql(self.queries["ArticleContent"]),
self._get_query("ArticleContent"),
variable_values={
"username": username,
"slug": slug,
Expand All @@ -158,7 +154,7 @@ def archive_article(self, article_id: str, to_archive: bool = True):
:param to_archive: Whether to archive or unarchive the article.
"""
return self.client.execute(
gql(self.queries["ArchiveSavedItem"]),
self._get_query("ArchiveSavedItem"),
variable_values={"input": {"linkId": article_id, "archived": to_archive}},
)

Expand All @@ -176,9 +172,8 @@ def delete_article(self, article_id: str):
:param article_id: The ID of the article to delete.
"""
q = self.queries["DeleteSavedItem"]
return self.client.execute(
gql(q),
self._get_query("DeleteSavedItem"),
variable_values={"input": {"articleID": article_id, "bookmark": False}},
)

Expand All @@ -189,7 +184,7 @@ def create_label(self, label: CreateLabelInput):
:param label: An instance of LabelInput with the label data.
"""
return self.client.execute(
gql(self.queries["CreateLabel"]),
self._get_query("CreateLabel"),
variable_values={"input": asdict(label)},
)

Expand All @@ -205,7 +200,7 @@ def update_label(
:param description: The description of the label.
"""
return self.client.execute(
gql(self.queries["UpdateLabel"]),
self._get_query("UpdateLabel"),
variable_values={
"input": {
"labelId": label_id,
Expand All @@ -223,6 +218,63 @@ def delete_label(self, label_id: str):
:param label_id: The ID of the label to delete.
"""
return self.client.execute(
gql(self.queries["DeleteLabel"]),
self._get_query("DeleteLabel"),
variable_values={"id": label_id},
)

def set_page_labels_by_create_label_inputs(
self, page_id: str, labels: List[CreateLabelInput]
) -> dict:
"""
Set labels for a page.
:param page_id: The ID of the page to set labels for.
:param labels: The labels to set.
"""
return self.set_page_labels_by_labels(page_id, parsed_labels)

def set_page_labels_by_labels(self, page_id: str, labels: List[dict]) -> dict:
"""
Set labels for a page.
:param page_id: The ID of the page to set labels for.
:param labels: The labels to set.
"""
parsed_labels = []
for label in labels:
if isinstance(label, CreateLabelInput):
label = asdict(label)
parsed_labels.append(
{
"name": label["name"],
"color": label["color"],
"description": label["description"],
}
)

return self.client.execute(
self._get_query("ApplyLabels"),
variable_values={
"input": {
"pageId": page_id,
"labels": parsed_labels,
}
},
)

def set_page_labels_by_label_ids(self, page_id: str, label_ids: List[str]) -> dict:
"""
Set labels for a page.
:param page_id: The ID of the page to set labels for.
:param label_ids: The IDs of the labels to set.
"""
return self.client.execute(
self._get_query("ApplyLabels"),
variable_values={
"input": {
"pageId": page_id,
"labelIds": label_ids,
}
},
)
89 changes: 60 additions & 29 deletions tests/test_omnivoreql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
To run the tests, execute the following command:
python -m unittest discover -s tests
"""


class TestOmnivoreQL(unittest.TestCase):
client = None

@classmethod
def setUpClass(cls):
Expand All @@ -33,6 +34,21 @@ def setUpClass(cls):
raise ValueError("OMNIVORE_API_TOKEN is not set")
print(f"OMNIVORE_API_TOKEN: {api_token[:4]}")
cls.client = OmnivoreQL(api_token)
cls.sample_label = None
# clean_up_created_labels from previous tests
try:
labels = cls.client.get_labels()["labels"]
for label in labels["labels"]:
if not cls.sample_label:
cls.sample_label = label
continue
cls.client.delete_label(label["id"])
except Exception as e:
print(f"Error cleaning up labels: {e}")
if not cls.sample_label:
cls.sample_label = cls.client.create_label(
CreateLabelInput(str(hash("test_update_label")), "#FF0000")
)["createLabel"]["label"]

@staticmethod
def getEnvVariable(variable_name):
Expand All @@ -51,26 +67,19 @@ def test_get_profile(self):

def test_save_url(self):
# When
result = self.client.save_url("https://github.com/yazdipour/OmnivoreQL")
result = self.client.save_url("https://github.com/yazdipour/OmnivoreQL", ["testLabel"])
# Then
self.assertIsNotNone(result)
self.assertNotIn("errorCodes", result["saveUrl"])
self.assertTrue(result["saveUrl"]["url"].startswith("http"))

def test_save_page(self):
# When
result = self.client.save_page("http://example.com", "Example")
result = self.client.save_page("http://example.com", "Example", ["label1"])
# Then
self.assertIsNotNone(result)
self.assertNotIn("errorCodes", result["savePage"])

def test_save_url_with_labels(self):
# When
result = self.client.save_url("https://www.google.com", ["test", "google"])
# Then
self.assertIsNotNone(result)
self.assertFalse("errorCodes" in result["saveUrl"])

def test_get_articles(self):
# When
articles = self.client.get_articles()
Expand Down Expand Up @@ -128,9 +137,8 @@ def test_delete_article(self):

def test_create_label(self):
# Given
label_name = hash("TestLabel") # create random label name to avoid conflicts
label_input = CreateLabelInput(
name=str(label_name), color="#FF0000", description="label description"
name=str(hash("test_create_label")), color="#FF0000"
)
# When
result = self.client.create_label(label_input)
Expand All @@ -145,12 +153,11 @@ def test_create_label(self):

def test_update_label(self):
# Given
label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000")
created_label = self.client.create_label(label_input)
label_sample = self.sample_label
# When
new_label_name = f"UpdatedLabel-{label_input.name}"
new_label_name = f"UpdatedLabel-{hash(label_sample['name'])}"
result = self.client.update_label(
label_id=created_label["createLabel"]["label"]["id"],
label_id=label_sample["id"],
name=new_label_name,
color="#0000FF",
description="An updated TestLabel",
Expand All @@ -166,27 +173,51 @@ def test_update_label(self):

def test_delete_label(self):
# Given
label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000")
created_label = self.client.create_label(label_input)
label_sample = self.client.create_label(
CreateLabelInput(str(hash("test_update_label")), "#FF0000")
)["createLabel"]["label"]
# When
result = self.client.delete_label(
created_label["createLabel"]["label"]["id"]
)
result = self.client.delete_label(label_sample["id"])
# Then
self.assertIsNotNone(result)
self.assertNotIn("errorCodes", result["deleteLabel"])
self.assertEqual(
result["deleteLabel"]["label"]["id"],
created_label["createLabel"]["label"]["id"],
label_sample["id"],
)

def test_clean_up_created_labels(self):
try:
labels = self.client.get_labels()["labels"]
for label in labels["labels"]:
self.client.delete_label(label["id"])
except Exception as e:
print(f"Error cleaning up labels: {e}")
def test_set_page_labels_by_labels(self):
# Given
page = self.client.get_articles(limit=1)["search"]["edges"][0]["node"]
label_sample = self.sample_label
created_label_input = CreateLabelInput(
label_sample["name"],
label_sample["color"],
label_sample["description"],
)
# When
result = self.client.set_page_labels_by_labels(page["id"], [created_label_input])
# Then
self.assertIsNotNone(result)
self.assertNotIn("errorCodes", result["setLabels"])
self.assertEqual(
result["setLabels"]["labels"][0]["id"], label_sample["id"]
)

def test_set_page_labels_by_label_ids(self):
# Given
page = self.client.get_articles(limit=1)["search"]["edges"][0]["node"]
label_sample = self.sample_label
# When
result = self.client.set_page_labels_by_label_ids(
page["id"], label_ids=[label_sample["id"]]
)
# Then
self.assertIsNotNone(result)
self.assertNotIn("errorCodes", result["setLabels"])
self.assertEqual(
result["setLabels"]["labels"][0]["id"], label_sample["id"]
)


if __name__ == "__main__":
Expand Down

0 comments on commit a9a5cec

Please sign in to comment.