Skip to content

Commit

Permalink
fix subconnection
Browse files Browse the repository at this point in the history
  • Loading branch information
sehnem committed Sep 21, 2023
1 parent 4ff0bec commit 1317e62
Showing 1 changed file with 48 additions and 24 deletions.
72 changes: 48 additions & 24 deletions tap_shopify/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import SinglePagePaginator
from singer_sdk.streams import GraphQLStream
from memoization import cached

from tap_shopify.exceptions import InvalidOperation, OperationFailed
from tap_shopify.gql_queries import (
Expand All @@ -35,7 +36,7 @@ def wrapper(*args, **kwargs):
if not [f for f in stack() if f.function == func.__name__]:
connections["in_conn"] = False
objs.clear()

field_name = args[1]["name"]
field_kind = args[1]["kind"]

Expand All @@ -59,7 +60,14 @@ class ShopifyStream(GraphQLStream):

query_name = None
single_object_params = None
ignore_objs = ["image", "metafield", "metafields", "metafieldconnection", "privateMetafield", "privateMetafields"]
ignore_objs = [
"image",
"metafield",
"metafields",
"metafieldconnection",
"privateMetafield",
"privateMetafields",
]
_requests_session = None
denied_fields = []
stream_connections = []
Expand Down Expand Up @@ -106,7 +114,7 @@ def additional_arguments(self) -> dict:
return ["includeClosed: true"]
return []

# @verify_connections
@cached
@verify_recursion
def extract_field_type(self, field) -> str:
"""Extract the field type from the schema."""
Expand All @@ -132,7 +140,10 @@ def extract_field_type(self, field) -> str:
return None
list_field_type = self.extract_field_type(obj_type)
if list_field_type:
if obj_type["name"].endswith("Edge") and not "node" in list_field_type.type_dict["properties"].keys():
if (
obj_type["name"].endswith("Edge")
and not "node" in list_field_type.type_dict["properties"].keys()
):
return None
return th.ArrayType(list_field_type)
elif kind == "INTERFACE" and self.config.get("bulk"):
Expand All @@ -146,14 +157,20 @@ def extract_field_type(self, field) -> str:
elif kind == "SCALAR":
return type_mapping.get(name, th.StringType)

def extract_gql_schema(self, gql_type):
    """Return the introspected schema entry whose name matches ``gql_type``.

    The comparison is case-insensitive. Entries of ``self.schema_gql`` are
    scanned in order and the first match wins.

    Args:
        gql_type: Name of the GraphQL type to look up.

    Returns:
        The first matching entry of ``self.schema_gql``, or ``None`` when
        ``gql_type`` is empty/``None`` or no entry matches.
    """
    # Nothing can match an empty name; avoid calling .lower() on None.
    if not gql_type:
        return None
    wanted = gql_type.lower()
    for entry in self.schema_gql:
        entry_name = entry.get("name")
        # Skip entries without a usable name instead of raising.
        if entry_name and entry_name.lower() == wanted:
            return entry
    return None

def get_fields_schema(self, fields) -> dict:
"""Build the schema for the stream."""
# Filtering the fields that are not needed
field_names = [f["name"] for f in fields]
if "edges" in field_names:
fields = [f for f in fields if f["name"]=="edges"]
fields = [f for f in fields if f["name"] == "edges"]
elif "node" in field_names:
fields = [f for f in fields if f["name"]=="node"]
fields = [f for f in fields if f["name"] == "node"]

properties = []
for field in fields:
Expand All @@ -167,10 +184,9 @@ def get_fields_schema(self, fields) -> dict:
continue

if type_def.get("name") and type_def["name"].endswith("Connection"):
self.stream_connections.append(dict(
name=field_name,
of_type=type_def["name"][:-10]
))
self.stream_connections.append(
dict(name=field_name, of_type=type_def["name"][:-10])
)

required = field["type"].get("kind") == "NON_NULL"
field_type = self.extract_field_type(type_def)
Expand All @@ -180,12 +196,6 @@ def get_fields_schema(self, fields) -> dict:
properties.append(property)
return properties

def extract_gql_schema(self, gql_type):
    """Return the introspected schema entry whose name matches ``gql_type``.

    The comparison is case-insensitive. Entries of ``self.schema_gql`` are
    scanned in order and the first match wins.

    Args:
        gql_type: Name of the GraphQL type to look up.

    Returns:
        The first matching entry of ``self.schema_gql``, or ``None`` when
        ``gql_type`` is empty/``None`` or no entry matches.
    """
    # Nothing can match an empty name; avoid calling .lower() on None.
    if not gql_type:
        return None
    wanted = gql_type.lower()
    for entry in self.schema_gql:
        entry_name = entry.get("name")
        # Skip entries without a usable name instead of raising.
        if entry_name and entry_name.lower() == wanted:
            return entry
    return None

@cached_property
def catalog_dict(self):
"""Return the catalog for the stream."""
Expand All @@ -202,7 +212,9 @@ def schema(self) -> dict:
stream = (s for s in streams if s["tap_stream_id"] == self.name)
stream_catalog = next(stream, None)
if stream_catalog:
metadata = next(f for f in stream_catalog["metadata"] if not f["breadcrumb"])
metadata = next(
f for f in stream_catalog["metadata"] if not f["breadcrumb"]
)
if not metadata["metadata"].get("selected"):
return stream_catalog["schema"]

Expand All @@ -224,7 +236,7 @@ def selected_properties(self):
or field_name == self.replication_key
):
selected_properties.append(field_name)
return selected_properties
return selected_properties

@property
def gql_selected_fields(self):
Expand All @@ -248,6 +260,13 @@ def denest_schema(schema):

return denest_schema(catalog)

@cached_property
def selected_connections(self):
    """Return the stream connections whose field name is selected.

    Filters ``self.stream_connections`` (dicts with a ``"name"`` key) down
    to those whose name appears in ``self.selected_properties``, keeping
    the original order.
    """
    # Read selected_properties once and use a set, so it is not re-read
    # and linearly searched for every connection.
    selected_names = set(self.selected_properties)
    return [
        conn for conn in self.stream_connections if conn["name"] in selected_names
    ]

def validate_response(self, response: requests.Response) -> None:
"""Validate HTTP response."""

Expand Down Expand Up @@ -313,6 +332,10 @@ def query_gql(self) -> str:
additional_args = ", " + ", ".join(self.additional_arguments)
query = query.replace("__additional_args__", additional_args)

if self.selected_connections:
conn_names = list(set([c["name"] for c in self.selected_connections]))
for conn_name in conn_names:
query = query.replace(conn_name, f"{conn_name}(first:1)")
return query

def get_url_params(
Expand Down Expand Up @@ -404,7 +427,7 @@ def get_operation_status(self):

return response

def check_status(self, operation_id, sleep_time=10, timeout=1800):
def check_status(self, operation_id, sleep_time=10, timeout=3600):
status_jsonpath = "$.data.currentBulkOperation"
start = datetime.now().timestamp()

Expand Down Expand Up @@ -441,20 +464,21 @@ def parse_response_bulk(self, response: requests.Response) -> Iterable[dict]:
main_item = None
for line in output.iter_lines():
line = simplejson.loads(line)
selected_connections = [c for c in self.stream_connections if c["name"] in self.selected_properties]
if "__parentId" not in line.keys():
if main_item:
yield main_item
main_item = line
for sc in selected_connections:
for sc in self.selected_connections:
main_item[sc["name"]] = {}
main_item[sc["name"]]["edges"] = []
main_item["variants"] = {}
main_item["variants"]["edges"] = []
elif main_item["id"]==line["__parentId"]:
elif main_item["id"] == line["__parentId"]:
del line["__parentId"]
line_type = line["id"].split("/")[-2]
field_name = next(c["name"] for c in selected_connections if c["of_type"]==line_type)
field_name = next(
c["name"] for c in self.selected_connections if c["of_type"] == line_type
)
main_item[field_name]["edges"].append(dict(node=line))
else:
pass
Expand All @@ -469,7 +493,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]:
def query(self) -> str:
"""Set or return the GraphQL query string."""
# TODO: figure out how to handle interfaces
# self.evaluate_query()
self.evaluate_query()
if self.config.get("bulk"):
return self.query_bulk()
return self.query_gql()
Expand Down

0 comments on commit 1317e62

Please sign in to comment.