Skip to content

Commit

Permalink
discover
Browse files Browse the repository at this point in the history
  • Loading branch information
sehnem committed Sep 20, 2023
1 parent 3e4d28d commit fde2114
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 267 deletions.
99 changes: 56 additions & 43 deletions tap_shopify/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,49 @@

from __future__ import annotations

import simplejson
from datetime import datetime
from functools import cached_property
from http import HTTPStatus
from inspect import stack
from time import sleep
from typing import Any, Dict, Iterable, Optional, cast

import requests

import simplejson
from singer_sdk import typing as th
from singer_sdk.exceptions import FatalAPIError, RetriableAPIError
from singer_sdk.helpers.jsonpath import extract_jsonpath

from http import HTTPStatus

from singer_sdk import typing as th
from singer_sdk.pagination import SinglePagePaginator
from singer_sdk.streams import GraphQLStream

from tap_shopify.auth import ShopifyAuthenticator
from tap_shopify.gql_queries import schema_query
from tap_shopify.paginator import ShopifyPaginator
from tap_shopify.gql_queries import query_incremental

from datetime import datetime
from time import sleep
from singer_sdk.pagination import SinglePagePaginator

from tap_shopify.exceptions import InvalidOperation, OperationFailed
from tap_shopify.gql_queries import bulk_query, bulk_query_status
from tap_shopify.gql_queries import (
bulk_query,
bulk_query_status,
query_incremental,
)
from tap_shopify.paginator import ShopifyPaginator


def verify_recursion(func):
"""Verify if the stream is recursive."""
objs = []
connections = dict(num=0, in_conn=False)

def wrapper(*args, **kwargs):
if not [f for f in stack() if f.function == func.__name__]:
connections["in_conn"] = False
objs.clear()

field_name = args[1]["name"]
field_kind = args[1]["kind"]

if field_kind == "INTERFACE":
if connections["in_conn"] or connections["num"] >= 5:
return
connections["in_conn"] = True
connections["num"] += 1

if field_name not in objs:
if field_kind == "OBJECT":
objs.append(args[1]["name"])
Expand All @@ -54,9 +59,8 @@ class ShopifyStream(GraphQLStream):

query_name = None
single_object_params = None
ignore_objs = []
ignore_objs = ["image", "metafield", "metafields", "metafieldconnection", "privateMetafield", "privateMetafields"]
_requests_session = None
nested_connections = []
denied_fields = []

@property
Expand Down Expand Up @@ -90,15 +94,18 @@ def http_headers(self) -> dict:
def schema_gql(self) -> dict:
"""Return the schema for the stream."""
return self._tap.schema_gql

@cached_property
def additional_arguments(self) -> dict:
"""Return the schema for the stream."""
gql_query = next(q for q in self._tap.queries_gql if q["name"]==self.query_name)
gql_query = next(
q for q in self._tap.queries_gql if q["name"] == self.query_name
)
if "includeClosed" in [a["name"] for a in gql_query["args"]]:
return ["includeClosed: true"]
return []

# @verify_connections
@verify_recursion
def extract_field_type(self, field) -> str:
"""Extract the field type from the schema."""
Expand All @@ -120,9 +127,16 @@ def extract_field_type(self, field) -> str:
return th.ObjectType(*properties)
elif kind == "LIST":
obj_type = field["ofType"]["ofType"]
if not obj_type:
return None
list_field_type = self.extract_field_type(obj_type)
if list_field_type:
if obj_type["name"].endswith("Edge") and not "node" in list_field_type.type_dict["properties"].keys():
return None
return th.ArrayType(list_field_type)
elif kind == "INTERFACE" and self.config.get("bulk"):
obj_schema = self.extract_gql_schema(name)
properties = self.get_fields_schema(obj_schema["fields"])
elif kind == "ENUM":
return th.StringType
elif kind == "NON_NULL":
Expand All @@ -133,25 +147,27 @@ def extract_field_type(self, field) -> str:

def get_fields_schema(self, fields) -> dict:
"""Build the schema for the stream."""
# Filtering the fields that are not needed
field_names = [f["name"] for f in fields]
if "edges" in field_names:
fields = [f for f in fields if f["name"]=="edges"]
elif "node" in field_names:
fields = [f for f in fields if f["name"]=="node"]

properties = []
for field in fields:
field_name = field["name"]
type_def = field.get("type", field)
type_def = type_def["ofType"] or type_def
# Ignore all the fields that need arguments
if field.get("isDeprecated") and self.config.get("ignore_deprecated"):
continue
if field.get("args"):
if field["args"][0]["name"] == "first":
self.nested_connections.append(field_name)
continue
if field_name in self.ignore_objs:
continue
if field["type"]["kind"] == "INTERFACE":
continue

required = field["type"].get("kind") == "NON_NULL"
type_def = field.get("type", field)
type_def = type_def["ofType"] or type_def
field_type = self.extract_field_type(type_def)

if field_type:
property = th.Property(field_name, field_type, required=required)
properties.append(property)
Expand Down Expand Up @@ -180,7 +196,7 @@ def schema(self) -> dict:
stream_catalog = next(stream, None)
if stream_catalog:
return stream_catalog["schema"]

stream_type = self.extract_gql_schema(self.gql_type)
properties = self.get_fields_schema(stream_type["fields"])
return th.PropertiesList(*properties).to_dict()
Expand Down Expand Up @@ -223,7 +239,6 @@ def denest_schema(schema):

return denest_schema(catalog)


def validate_response(self, response: requests.Response) -> None:
"""Validate HTTP response."""

Expand All @@ -235,11 +250,11 @@ def validate_response(self, response: requests.Response) -> None:
):
msg = self.response_error_message(response)
raise RetriableAPIError(msg, response)

json_resp = response.json()

if errors:=json_resp.get("errors"):
if len(errors)==1:
if errors := json_resp.get("errors"):
if len(errors) == 1:
error = errors[0]
code = error.get("extensions", {}).get("code")
if code in ["THROTTLED", "MAX_COST_EXCEEDED"]:
Expand All @@ -260,7 +275,7 @@ def convert_id_fields(self, row: dict) -> dict:
if not isinstance(row, dict):
return row
for key, value in row.items():
if key=="id" and isinstance(value, str):
if key == "id" and isinstance(value, str):
row["id"] = row["id"].split("/")[-1].split("?")[0]
elif isinstance(value, dict):
row[key] = self.convert_id_fields(value)
Expand All @@ -280,7 +295,6 @@ def post_process(

return row


def query_gql(self) -> str:
"""Set or return the GraphQL query string."""
base_query = query_incremental
Expand Down Expand Up @@ -310,7 +324,7 @@ def get_url_params(
if self.single_object_params:
params = self.single_object_params
return params

def prepare_request_payload(
self, context: Optional[dict], next_page_token: Optional[Any]
) -> Optional[dict]:
Expand All @@ -334,7 +348,6 @@ def parse_response_gql(self, response: requests.Response) -> Iterable[dict]:

yield from extract_jsonpath(json_path, json_resp)


def query_bulk(self) -> str:
"""Set or return the GraphQL query string."""
base_query = bulk_query
Expand Down Expand Up @@ -411,9 +424,7 @@ def parse_response_bulk(self, response: requests.Response) -> Iterable[dict]:
errors = next(extract_jsonpath(error_jsonpath, json_resp), None)
if errors:
raise InvalidOperation(simplejson.dumps(errors))
operation_id = next(
extract_jsonpath(operation_id_jsonpath, json_resp)
)
operation_id = next(extract_jsonpath(operation_id_jsonpath, json_resp))

url = self.check_status(operation_id)

Expand All @@ -431,12 +442,12 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]:
@cached_property
def query(self) -> str:
"""Set or return the GraphQL query string."""
self.evaluate_query()
# TODO: figure out how to handle interfaces
# self.evaluate_query()
if self.config.get("bulk"):
return self.query_bulk()
return self.query_gql()


def evaluate_query(self) -> dict:
query = self.query_gql().lstrip()
params = self.get_url_params(None, None)
Expand Down Expand Up @@ -464,4 +475,6 @@ def evaluate_query(self) -> dict:
self.denied_fields.append(message.split(" ")[3])
else:
raise FatalAPIError(error.get("message", ""), response)
self.evaluate_query()
self.evaluate_query()

# TODO: get query cost from here
107 changes: 0 additions & 107 deletions tap_shopify/client_bulk.py

This file was deleted.

Loading

0 comments on commit fde2114

Please sign in to comment.