diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d1de7029f..14f9311b8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,7 +12,11 @@ repos:
fondant/.*|
tests/.*|
)$
- args: [--fix, --exit-non-zero-on-fix]
+ args: [
+ "--target-version=py38",
+ "--fix",
+ "--exit-non-zero-on-fix",
+ ]
- repo: https://github.com/PyCQA/bandit
diff --git a/components/caption_images/src/main.py b/components/caption_images/src/main.py
index ff9cec607..7404f0687 100644
--- a/components/caption_images/src/main.py
+++ b/components/caption_images/src/main.py
@@ -40,7 +40,7 @@ def caption_image_batch(
*,
model: BlipForConditionalGeneration,
processor: BlipProcessor,
- max_new_tokens: int
+ max_new_tokens: int,
) -> pd.Series:
"""Caption a batch of images."""
input_batch = torch.cat(image_batch.tolist())
@@ -67,7 +67,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
images = dataframe["images"]["data"].apply(
process_image,
processor=self.processor,
- device=self.device
+ device=self.device,
)
results: t.List[pd.Series] = []
@@ -78,8 +78,8 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
batch,
model=self.model,
processor=self.processor,
- max_new_tokens=self.max_new_tokens
- ).T
+ max_new_tokens=self.max_new_tokens,
+ ).T,
)
return pd.concat(results).to_frame(name=("captions", "text"))
diff --git a/components/download_images/src/main.py b/components/download_images/src/main.py
index 36c92a546..9b222e3b0 100644
--- a/components/download_images/src/main.py
+++ b/components/download_images/src/main.py
@@ -49,7 +49,7 @@ def download_image(url, timeout, user_agent_token, disallowed_header_directives)
f"+https://github.com/rom1504/img2dataset)"
try:
request = urllib.request.Request(
- url, data=None, headers={"User-Agent": user_agent_string}
+ url, data=None, headers={"User-Agent": user_agent_string},
)
with urllib.request.urlopen(request, timeout=timeout) as r:
if disallowed_header_directives and is_disallowed(
@@ -77,7 +77,7 @@ def download_image_with_retry(
):
for _ in range(retries + 1):
img_stream = download_image(
- url, timeout, user_agent_token, disallowed_header_directives
+ url, timeout, user_agent_token, disallowed_header_directives,
)
if img_stream is not None:
# resize the image
@@ -114,7 +114,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
dataframe[[
("images", "data"),
("images", "width"),
- ("images", "height")
+ ("images", "height"),
]] = dataframe.apply(
lambda example: download_image_with_retry(
url=example["images"]["url"],
diff --git a/components/embedding_based_laion_retrieval/src/main.py b/components/embedding_based_laion_retrieval/src/main.py
index 54e21831d..1c8e1672c 100644
--- a/components/embedding_based_laion_retrieval/src/main.py
+++ b/components/embedding_based_laion_retrieval/src/main.py
@@ -21,7 +21,7 @@ def setup(
*,
num_images: int,
aesthetic_score: int,
- aesthetic_weight: float
+ aesthetic_weight: float,
) -> None:
"""
@@ -54,7 +54,7 @@ async def async_query():
futures = [
loop.run_in_executor(
executor,
- functools.partial(self.client.query, embedding_input=embedding.tolist())
+ functools.partial(self.client.query, embedding_input=embedding.tolist()),
)
for embedding in dataframe["embeddings"]["data"]
]
@@ -64,7 +64,7 @@ async def async_query():
loop.run_until_complete(async_query())
results_df = pd.DataFrame(results)[["id", "url"]]
- results_df.set_index("id", inplace=True)
+ results_df = results_df.set_index("id")
results_df.columns = [["images"], ["url"]]
return results_df
diff --git a/components/filter_comments/src/main.py b/components/filter_comments/src/main.py
index 3a5282866..aa2698657 100644
--- a/components/filter_comments/src/main.py
+++ b/components/filter_comments/src/main.py
@@ -20,7 +20,7 @@ def transform(
*,
dataframe: dd.DataFrame,
min_comments_ratio: float,
- max_comments_ratio: float
+ max_comments_ratio: float,
) -> dd.DataFrame:
"""
Args:
@@ -31,16 +31,14 @@ def transform(
Filtered dask dataframe.
"""
# Apply the function to the desired column and filter the DataFrame
- filtered_df = dataframe[
+ return dataframe[
dataframe["code_content"].map_partitions(
lambda example: example.map(get_comments_to_code_ratio).between(
- min_comments_ratio, max_comments_ratio
- )
+ min_comments_ratio, max_comments_ratio,
+ ),
)
]
- return filtered_df
-
if __name__ == "__main__":
component = FilterCommentsComponent.from_args()
diff --git a/components/filter_comments/src/utils/text_extraction.py b/components/filter_comments/src/utils/text_extraction.py
index 0221b7738..0462029e0 100644
--- a/components/filter_comments/src/utils/text_extraction.py
+++ b/components/filter_comments/src/utils/text_extraction.py
@@ -69,9 +69,7 @@ def get_comments(source: str) -> str:
for toknum, tokval, _, _, _ in g:
if toknum == tokenize.COMMENT:
comments.append((toknum, tokval))
- result = tokenize.untokenize(comments).replace("#", "")
-
- return result
+ return tokenize.untokenize(comments).replace("#", "")
def get_docstrings(source: str) -> t.List[str]:
@@ -88,13 +86,13 @@ def get_docstrings(source: str) -> t.List[str]:
source = source.read()
docstrings = sorted(
- parse_docstrings(source), key=lambda x: (NODE_TYPES.get(type(x[0])), x[1])
+ parse_docstrings(source), key=lambda x: (NODE_TYPES.get(type(x[0])), x[1]),
)
grouped = groupby(docstrings, key=lambda x: NODE_TYPES.get(type(x[0])))
results = []
for _, group in grouped:
- for _, name, docstring in group:
+ for _, _name, docstring in group:
if docstring:
results.append(docstring)
return results
@@ -116,7 +114,7 @@ def get_text_python(source: str, extract_comments: bool = True) -> str:
except Exception:
docstrings = ""
warnings.warn(
- "code couldn't be parsed due to compilation failure, no docstring is extracted"
+ "code couldn't be parsed due to compilation failure, no docstring is extracted",
)
if extract_comments:
@@ -142,4 +140,4 @@ def get_comments_to_code_ratio(text: str) -> float:
"""
comments = get_text_python(text)
- return len(comments) / len(text)
\ No newline at end of file
+ return len(comments) / len(text)
diff --git a/components/filter_line_length/src/main.py b/components/filter_line_length/src/main.py
index ed04b5564..9401af8ac 100644
--- a/components/filter_line_length/src/main.py
+++ b/components/filter_line_length/src/main.py
@@ -20,7 +20,7 @@ def transform(
dataframe: dd.DataFrame,
avg_line_length_threshold: int,
max_line_length_threshold: int,
- alphanum_fraction_threshold: float
+ alphanum_fraction_threshold: float,
) -> dd.DataFrame:
"""
Args:
@@ -31,14 +31,12 @@ def transform(
Returns:
Filtered dask dataframe.
""" - filtered_df = dataframe[ + return dataframe[ (dataframe["code_avg_line_length"] > avg_line_length_threshold) & (dataframe["code_max_line_length"] > max_line_length_threshold) & (dataframe["code_alphanum_fraction"] > alphanum_fraction_threshold) ] - return filtered_df - if __name__ == "__main__": component = FilterLineLengthComponent.from_args() diff --git a/components/image_cropping/src/image_crop.py b/components/image_cropping/src/image_crop.py index b767c3eef..a6072192e 100644 --- a/components/image_cropping/src/image_crop.py +++ b/components/image_cropping/src/image_crop.py @@ -47,7 +47,7 @@ def get_image_borders(image: Image.Image) -> t.Tuple: def remove_borders( - image_bytes: bytes, cropping_threshold: int = -30, padding: int = 10 + image_bytes: bytes, cropping_threshold: int = -30, padding: int = 10, ) -> bytes: """This method removes borders by checking the overlap between a color and the original image. By subtracting these two @@ -89,12 +89,12 @@ def remove_borders( if image_crop.size[0] > image_crop.size[1]: padding = int((image_crop.size[0] - image_crop.size[1]) / 2) image_crop = ImageOps.expand( - image_crop, border=(0, padding), fill=color_common + image_crop, border=(0, padding), fill=color_common, ) else: padding = int((image_crop.size[1] - image_crop.size[0]) / 2) image_crop = ImageOps.expand( - image_crop, border=(padding, 0), fill=color_common + image_crop, border=(padding, 0), fill=color_common, ) # serialize image to JPEG diff --git a/components/image_cropping/src/main.py b/components/image_cropping/src/main.py index 664d6f19d..85445169f 100644 --- a/components/image_cropping/src/main.py +++ b/components/image_cropping/src/main.py @@ -36,7 +36,7 @@ def transform( *, dataframe: dd.DataFrame, cropping_threshold: int = -30, - padding: int = 10 + padding: int = 10, ) -> dd.DataFrame: """ Args: diff --git a/components/image_embedding/src/main.py b/components/image_embedding/src/main.py index 7fcc0d796..99758af1c 100644 --- a/components/image_embedding/src/main.py +++ b/components/image_embedding/src/main.py @@ -75,7 +75,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: images = dataframe["images"]["data"].apply( process_image, processor=self.processor, - device=self.device + device=self.device, ) results: t.List[pd.Series] = [] for batch in np.split(images, np.arange(self.batch_size, len(images), self.batch_size)): diff --git a/components/image_resolution_filtering/src/main.py b/components/image_resolution_filtering/src/main.py index 63b612b6c..f1cff0aba 100644 --- a/components/image_resolution_filtering/src/main.py +++ b/components/image_resolution_filtering/src/main.py @@ -12,7 +12,7 @@ class ImageFilterComponent(DaskTransformComponent): """Component that filters images based on height and width.""" def transform( - self, *, dataframe: dd.DataFrame, min_width: int, min_height: int + self, *, dataframe: dd.DataFrame, min_width: int, min_height: int, ) -> dd.DataFrame: """ Args: @@ -38,4 +38,4 @@ def transform( if __name__ == "__main__": component = ImageFilterComponent.from_args() - component.run() \ No newline at end of file + component.run() diff --git a/components/load_from_hf_hub/src/main.py b/components/load_from_hf_hub/src/main.py index 0f3d774db..cc0bb3a39 100644 --- a/components/load_from_hf_hub/src/main.py +++ b/components/load_from_hf_hub/src/main.py @@ -35,7 +35,7 @@ def load(self, if image_column_names is not None: for image_column_name in image_column_names: dask_df[image_column_name] = dask_df[image_column_name].map( - lambda x: 
x["bytes"], meta=("bytes", bytes) + lambda x: x["bytes"], meta=("bytes", bytes), ) # 3) Rename columns diff --git a/components/pii_redaction/src/main.py b/components/pii_redaction/src/main.py index 7404f7626..e05a6aaa0 100644 --- a/components/pii_redaction/src/main.py +++ b/components/pii_redaction/src/main.py @@ -40,7 +40,7 @@ def transform( # redact PII # we use random replacements by default - with open("replacements.json", "r") as f: + with open("replacements.json") as f: replacements = json.load(f) dataframe["code_content"] = dataframe.apply( @@ -54,7 +54,7 @@ def transform( meta=(None, "str"), ) dataframe = dataframe.drop( - ["code_secrets", "code_has_secrets", "code_number_secrets"], axis=1 + ["code_secrets", "code_has_secrets", "code_number_secrets"], axis=1, ) return dataframe diff --git a/components/pii_redaction/src/pii_detection.py b/components/pii_redaction/src/pii_detection.py index 7ac977f8c..e98c78665 100644 --- a/components/pii_redaction/src/pii_detection.py +++ b/components/pii_redaction/src/pii_detection.py @@ -26,12 +26,12 @@ def scan_pii(text, key_detector="other"): if key_detector == "regex": # use a regex to detect keys + emails + ips secrets = secrets + detect_email_addresses( - text, tag_types={"KEY", "EMAIL", "IP_ADDRESS"} + text, tag_types={"KEY", "EMAIL", "IP_ADDRESS"}, ) else: # detect emails and ip addresses with regexes secrets = secrets + detect_email_addresses( - text, tag_types={"EMAIL", "IP_ADDRESS"} + text, tag_types={"EMAIL", "IP_ADDRESS"}, ) # for keys use detect-secrets tool secrets = secrets + detect_keys(text) diff --git a/components/pii_redaction/src/pii_redaction.py b/components/pii_redaction/src/pii_redaction.py index d2b29e986..f880d458e 100644 --- a/components/pii_redaction/src/pii_redaction.py +++ b/components/pii_redaction/src/pii_redaction.py @@ -21,7 +21,6 @@ ], } -# providergs = ["google", "cloudfare", "alternate-dns", "quad9","open-dns", "comodo", "adguard"] POPULAR_DNS_SERVERS = [ "8.8.8.8", "8.8.4.4", @@ -113,7 +112,7 @@ def redact_pii_text(text, secrets, replacements, add_references=False): last_text = text for secret in secrets: # skip secret if it's an IP address for private networks or popular DNS servers - if secret["tag"] == "IP_ADDRESS": + if secret["tag"] == "IP_ADDRESS": # ruff: noqa: SIM102 # if secret value in popular DNS servers, skip it if is_private_ip(secret["value"]) or ( secret["value"] in POPULAR_DNS_SERVERS @@ -146,10 +145,9 @@ def redact_pii_text(text, secrets, replacements, add_references=False): else: new_text = text references = "" - result = ( + return ( (new_text, references, modified) if add_references else (new_text, modified) ) - return result def redact_pii(text, secrets, has_secrets, replacements): @@ -160,5 +158,5 @@ def redact_pii(text, secrets, has_secrets, replacements): if has_secrets: new_text, _ = redact_pii_text(text, secrets, replacements) return new_text - else: - return text + + return text diff --git a/components/pii_redaction/src/utils/emails_ip_addresses_detection.py b/components/pii_redaction/src/utils/emails_ip_addresses_detection.py index b27f759fa..b8e65c0ed 100644 --- a/components/pii_redaction/src/utils/emails_ip_addresses_detection.py +++ b/components/pii_redaction/src/utils/emails_ip_addresses_detection.py @@ -28,19 +28,19 @@ year_patterns = [ regex.compile( - r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([1-2][0-9]{3}[\p{Pd}/][1-2][0-9]{3})(?:$|[\s@,?!;:\'\"(.\p{Han}])" + r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([1-2][0-9]{3}[\p{Pd}/][1-2][0-9]{3})(?:$|[\s@,?!;:\'\"(.\p{Han}])", ), # yyyy-yyyy or 
yyyy/yyyy regex.compile( - r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([1-2][0-9]{3}[\p{Pd}/.][0-3][0-9][\p{Pd}/.][0-3][0-9])(?:$|[\s@,?!;:\'\"(.\p{Han}])" + r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([1-2][0-9]{3}[\p{Pd}/.][0-3][0-9][\p{Pd}/.][0-3][0-9])(?:$|[\s@,?!;:\'\"(.\p{Han}])", ), # yyyy-mm-dd or yyyy-dd-mm or yyyy/mm/dd or yyyy/dd/mm or yyyy.mm.dd or yyyy.dd.mm regex.compile( - r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([0-3][0-9][\p{Pd}/.][0-3][0-9][\p{Pd}/.](?:[0-9]{2}|[1-2][0-9]{3}))(?:$|[\s@,?!;:\'\"(.\p{Han}])" + r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([0-3][0-9][\p{Pd}/.][0-3][0-9][\p{Pd}/.](?:[0-9]{2}|[1-2][0-9]{3}))(?:$|[\s@,?!;:\'\"(.\p{Han}])", ), # mm-dd-yyyy or dd-mm-yyyy or mm/dd/yyyy or dd/mm/yyyy or mm.dd.yyyy or dd.mm.yyyy or the same but with yy instead of yyyy regex.compile( - r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([0-3][0-9][\p{Pd}/](?:[0-9]{2}|[1-2][0-9]{3}))(?:$|[\s@,?!;:\'\"(.\p{Han}])" + r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([0-3][0-9][\p{Pd}/](?:[0-9]{2}|[1-2][0-9]{3}))(?:$|[\s@,?!;:\'\"(.\p{Han}])", ), # mm-yyyy or mm/yyyy or the same but with yy regex.compile( - r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([1-2][0-9]{3}-[0-3][0-9])(?:$|[\s@,?!;:\'\"(.\p{Han}])" + r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])([1-2][0-9]{3}-[0-3][0-9])(?:$|[\s@,?!;:\'\"(.\p{Han}])", ), # yyyy-mm or yyyy/mm ] @@ -51,7 +51,7 @@ ip_pattern = ( r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])(" + r"|".join([ipv4_pattern, ipv6_pattern]) - + ")(?:$|[\s@,?!;:'\"(.\p{Han}])" + + ")(?:$|[\\s@,?!;:'\"(.\\p{Han}])" ) # Note: to reduce false positives, a number of technically-valid-but-rarely-used @@ -116,10 +116,7 @@ def ip_has_digit(matched_str): def matches_date_pattern(matched_str): # Screen out date false positives - for year_regex in year_patterns: - if year_regex.match(matched_str): - return True - return False + return any(year_regex.match(matched_str) for year_regex in year_patterns) def filter_versions(matched_str, context): @@ -129,9 +126,8 @@ def filter_versions(matched_str, context): # count occurrence of dots dot_count = matched_str.count(".") exclude = dot_count == 3 and len(matched_str) == 7 # noqa: PLR2004 (magic value) - if exclude: - if "dns" in context.lower() or "server" in context.lower(): - return False + if exclude and ("dns" in context.lower() or "server" in context.lower()): + return False return exclude @@ -178,7 +174,7 @@ def detect_email_addresses(content, tag_types={"EMAIL", "IP_ADDRESS"}): if match.groups(): if len(match.groups()) > 1 and match.groups()[1]: sys.stderr.write( - "Warning: Found substring matches in the main match." 
+ "Warning: Found substring matches in the main match.", ) # setup outputs value = match.group(1) @@ -191,23 +187,23 @@ def detect_email_addresses(content, tag_types={"EMAIL", "IP_ADDRESS"}): if matches_date_pattern(value): continue if filter_versions( - value, content[start - 100 : end + 100] + value, content[start - 100 : end + 100], ) or not_ip_address(value): continue # combine if conditions in one - if tag == "KEY": - # Filter out false positive keys - if not is_gibberish(value): - continue + # Filter out false positive keys + if tag == "KEY" and not is_gibberish(value): + continue matches.append( { "tag": tag, "value": value, "start": start, "end": end, - } + }, ) else: - raise ValueError("No match found inside groups") + msg = "No match found inside groups" + raise ValueError(msg) return matches diff --git a/components/pii_redaction/src/utils/keys_detection.py b/components/pii_redaction/src/utils/keys_detection.py index e542a292a..31d1c9e34 100644 --- a/components/pii_redaction/src/utils/keys_detection.py +++ b/components/pii_redaction/src/utils/keys_detection.py @@ -16,7 +16,7 @@ {"path": "detect_secrets.filters.heuristic.is_templated_secret"}, {"path": "detect_secrets.filters.heuristic.is_sequential_string"}, ] -plugins = [ +plugins = [ # ruff: noqa: ERA001 {"name": "ArtifactoryDetector"}, {"name": "AWSKeyDetector"}, # the entropy detectors esp Base64 need the gibberish detector on top @@ -116,7 +116,7 @@ def detect_keys(content, suffix=".txt"): fp.close() secrets = SecretsCollection() with transient_settings( - {"plugins_used": plugins, "filters_used": filters} + {"plugins_used": plugins, "filters_used": filters}, ): secrets.scan_file(fp.name) os.unlink(fp.name) @@ -136,6 +136,6 @@ def detect_keys(content, suffix=".txt"): "value": secret.secret_value, "start": start, "end": end, - } + }, ) return matches diff --git a/components/prompt_based_laion_retrieval/src/main.py b/components/prompt_based_laion_retrieval/src/main.py index d1ac482f2..5109e94e5 100644 --- a/components/prompt_based_laion_retrieval/src/main.py +++ b/components/prompt_based_laion_retrieval/src/main.py @@ -20,7 +20,7 @@ def setup( *, num_images: int, aesthetic_score: int, - aesthetic_weight: float + aesthetic_weight: float, ) -> None: """ @@ -53,7 +53,7 @@ async def async_query(): loop.run_in_executor( executor, self.client.query, - prompt + prompt, ) for prompt in dataframe["prompts"]["text"] ] @@ -63,7 +63,7 @@ async def async_query(): loop.run_until_complete(async_query()) results_df = pd.DataFrame(results)[["id", "url"]] - results_df.set_index("id", inplace=True) + results_df = results_df.set_index("id") results_df.columns = [["images"], ["url"]] return results_df diff --git a/components/segment_images/src/main.py b/components/segment_images/src/main.py index a509e8249..b666127b3 100644 --- a/components/segment_images/src/main.py +++ b/components/segment_images/src/main.py @@ -26,7 +26,7 @@ def convert_to_rgb(seg: np.array): color_seg: 3D segmentation map contain RGB values for each pixel. 
""" color_seg = np.zeros( - (seg.shape[0], seg.shape[1], 3), dtype=np.uint8 + (seg.shape[0], seg.shape[1], 3), dtype=np.uint8, ) # height, width, 3 for label, color in enumerate(palette): @@ -67,7 +67,7 @@ def segment_image_batch(image_batch: pd.DataFrame, *, model: AutoModelForSemanti input_batch = torch.cat(image_batch.tolist()) output_batch = model(input_batch) post_processed_batch = processor.post_process_semantic_segmentation( - output_batch + output_batch, ) segmentations_batch = [convert_to_rgb(seg.cpu().numpy()) for seg in post_processed_batch] return pd.Series(segmentations_batch, index=image_batch.index) @@ -94,7 +94,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: images = dataframe["images"]["data"].apply( process_image, processor=self.processor, - device=self.device + device=self.device, ) results: t.List[pd.Series] = [] @@ -105,7 +105,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: batch, model=self.model, processor=self.processor, - ).T + ).T, ) return pd.concat(results).to_frame(name=("segmentations", "data")) diff --git a/components/segment_images/src/palette.py b/components/segment_images/src/palette.py index 5a9a6005e..6be7711fa 100644 --- a/components/segment_images/src/palette.py +++ b/components/segment_images/src/palette.py @@ -161,5 +161,5 @@ [25, 194, 194], [102, 255, 0], [92, 0, 255], - ] + ], ) diff --git a/components/write_to_hf_hub/src/main.py b/components/write_to_hf_hub/src/main.py index 6592255e1..a81bcb5c9 100644 --- a/components/write_to_hf_hub/src/main.py +++ b/components/write_to_hf_hub/src/main.py @@ -39,7 +39,7 @@ def write( username: str, dataset_name: str, image_column_names: t.Optional[list], - column_name_mapping: t.Optional[dict] + column_name_mapping: t.Optional[dict], ): """ Args: @@ -83,7 +83,7 @@ def write( for image_column_name in image_column_names: dataframe[image_column_name] = dataframe[image_column_name].map( lambda x: convert_bytes_to_image(x, feature_encoder), - meta=(image_column_name, "object") + meta=(image_column_name, "object"), ) # Map column names to hf data format diff --git a/fondant/__init__.py b/fondant/__init__.py index 0db5752a5..795a81798 100644 --- a/fondant/__init__.py +++ b/fondant/__init__.py @@ -1,5 +1,6 @@ import logging logging.basicConfig( - format="[%(asctime)s | %(name)s | %(levelname)s] %(message)s", level=logging.INFO + format="[%(asctime)s | %(name)s | %(levelname)s] %(message)s", + level=logging.INFO, ) diff --git a/fondant/cli.py b/fondant/cli.py index 3f1ee7b01..34b1dec77 100644 --- a/fondant/cli.py +++ b/fondant/cli.py @@ -45,22 +45,24 @@ def argument(*name_or_flags, **kwargs): return (list(name_or_flags), kwargs) -def distill_arguments(args: argparse.Namespace, remove: t.List[str] = []): +def distill_arguments(args: argparse.Namespace, remove: t.Optional[t.List[str]] = None): """Helper function to distill arguments to be passed on to the function.""" args_dict = vars(args) args_dict.pop("func") - for arg in remove: - args_dict.pop(arg) + if remove is not None: + for arg in remove: + args_dict.pop(arg) return args_dict -def subcommand(name, parent_parser=subparsers, help=None, args=[]): +def subcommand(name, parent_parser=subparsers, help=None, args=None): """Decorator to add a subcommand to the CLI.""" def decorator(func): parser = parent_parser.add_parser(name, help=help) - for arg in args: - parser.add_argument(*arg[0], **arg[1]) + if args is not None: + for arg in args: + parser.add_argument(*arg[0], **arg[1]) parser.set_defaults(func=func) return decorator @@ 
-95,7 +97,7 @@ def decorator(func): argument( "--credentials", "-c", - help="""Path mapping of the source (local) and target (docker file system) + help="""Path mapping of the source (local) and target (docker file system) credential paths in the format of src:target \nExamples:\n Google Cloud: $HOME/.config/gcloud/application_default_credentials.json:/root/." @@ -115,7 +117,7 @@ def explore(args): logging.warning( "You have not provided a data directory." + "To access local files, provide a local data directory" - + " with the --data-directory flag." + + " with the --data-directory flag.", ) else: logging.info(f"Using data directory: {args.data_directory}") @@ -124,7 +126,7 @@ def explore(args): if not args.credentials: logging.warning( "You have not provided a credentials file. If you wish to access data " - "from a cloud provider, mount the credentials file with the --credentials flag." + "from a cloud provider, mount the credentials file with the --credentials flag.", ) if not shutil.which("docker"): @@ -143,24 +145,27 @@ def pipeline_from_string(import_string: str) -> Pipeline: if not attr_str or not module_str: raise ImportFromStringError( f"{import_string} is not a valid import string." - + "Please provide a valid import string in the format of module:attr" + + "Please provide a valid import string in the format of module:attr", ) try: module = importlib.import_module(module_str) except ImportError: + msg = f"{module_str} is not a valid module. Please provide a valid module." raise ImportFromStringError( - f"{module_str} is not a valid module. Please provide a valid module." + msg, ) try: for attr_str_element in attr_str.split("."): instance = getattr(module, attr_str_element) except AttributeError: - raise ImportFromStringError(f"{attr_str} is not found in {module}.") + msg = f"{attr_str} is not found in {module}." + raise ImportFromStringError(msg) if not isinstance(instance, Pipeline): - raise ImportFromStringError(f"{module}:{instance} is not a valid pipeline.") + msg = f"{module}:{instance} is not a valid pipeline." + raise ImportFromStringError(msg) return instance @@ -182,10 +187,15 @@ def pipeline_from_string(import_string: str) -> Pipeline: choices=["local", "kubeflow"], ), argument( - "--output-path", "-o", help="Output directory", default="docker-compose.yml" + "--output-path", + "-o", + help="Output directory", + default="docker-compose.yml", ), argument( - "--extra-volumes", help="Extra volumes to mount in containers", nargs="+" + "--extra-volumes", + help="Extra volumes to mount in containers", + nargs="+", ), ], ) @@ -195,4 +205,5 @@ def compile(args): function_args = distill_arguments(args, remove=["mode"]) compiler.compile(**function_args) else: - raise NotImplementedError("Kubeflow mode is not implemented yet.") + msg = "Kubeflow mode is not implemented yet." + raise NotImplementedError(msg) diff --git a/fondant/compiler.py b/fondant/compiler.py index dfb734848..4fa0a3613 100644 --- a/fondant/compiler.py +++ b/fondant/compiler.py @@ -57,7 +57,7 @@ def compile( self, pipeline: Pipeline, output_path: str = "docker-compose.yml", - extra_volumes: list = [], + extra_volumes: t.Optional[list] = None, ) -> None: """Compile a pipeline to docker-compose spec and save it to a specified output path. @@ -68,6 +68,9 @@ def compile( https://docs.docker.com/compose/compose-file/05-services/#short-syntax-5) to mount in the docker-compose spec. 
""" + if extra_volumes is None: + extra_volumes = [] + logger.info(f"Compiling {pipeline.name} to {output_path}") spec = self._generate_spec(pipeline=pipeline, extra_volumes=extra_volumes) with open(output_path, "w") as outfile: @@ -90,10 +93,12 @@ def _patch_path(self, base_path: str) -> t.Tuple[str, t.Optional[DockerVolume]]: # check if base path is an existing local folder if p_base_path.exists(): logger.info( - f"Base path found on local system, setting up {base_path} as mount volume" + f"Base path found on local system, setting up {base_path} as mount volume", ) volume = DockerVolume( - type="bind", source=str(p_base_path), target=f"/{p_base_path.stem}" + type="bind", + source=str(p_base_path), + target=f"/{p_base_path.stem}", ) path = f"/{p_base_path.stem}" else: @@ -126,7 +131,7 @@ def _generate_spec(self, pipeline: Pipeline, extra_volumes: list) -> dict: [ "--output_manifest_path", f"{path}/{safe_component_name}/manifest.json", - ] + ], ) # add arguments if any to command @@ -142,14 +147,14 @@ def _generate_spec(self, pipeline: Pipeline, extra_volumes: list) -> dict: for dependency in component["dependencies"]: safe_dependency = self._safe_component_name(dependency) depends_on[safe_dependency] = { - "condition": "service_completed_successfully" + "condition": "service_completed_successfully", } # there is only an input manifest if the component has dependencies command.extend( [ "--input_manifest_path", f"{path}/{safe_dependency}/manifest.json", - ] + ], ) volumes = [] diff --git a/fondant/component.py b/fondant/component.py index d10514ebb..8f679e9e0 100644 --- a/fondant/component.py +++ b/fondant/component.py @@ -42,7 +42,8 @@ def __init__( @classmethod def from_file( - cls, path: t.Union[str, Path] = "../fondant_component.yaml" + cls, + path: t.Union[str, Path] = "../fondant_component.yaml", ) -> "Component": """Create a component from a component spec file. @@ -60,7 +61,8 @@ def from_args(cls) -> "Component": args, _ = parser.parse_known_args() if "component_spec" not in args: - raise ValueError("Error: The --component_spec argument is required.") + msg = "Error: The --component_spec argument is required." + raise ValueError(msg) component_spec = ComponentSpec(args.component_spec) @@ -174,14 +176,12 @@ def optional_fondant_arguments() -> t.List[str]: def _load_or_create_manifest(self) -> Manifest: component_id = self.spec.name.lower().replace(" ", "_") - manifest = Manifest.create( + return Manifest.create( base_path=self.metadata["base_path"], run_id=self.metadata["run_id"], component_id=component_id, ) - return manifest - @abstractmethod def load(self, *args, **kwargs) -> dd.DataFrame: """Abstract method that loads the initial dataframe.""" @@ -193,9 +193,7 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: A `dd.DataFrame` instance with initial data'. """ # Load the dataframe according to the custom function provided to the user - df = self.load(**self.user_arguments) - - return df + return self.load(**self.user_arguments) class TransformComponent(Component): @@ -244,9 +242,9 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: A `dd.DataFrame` instance with updated data based on the applied data transformations. 
""" data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) - df = data_loader.load_dataframe() - df = self.transform(dataframe=df, **self.user_arguments) - return df + dataframe = data_loader.load_dataframe() + dataframe = self.transform(dataframe, **self.user_arguments) + return dataframe class PandasTransformComponent(TransformComponent): @@ -281,7 +279,7 @@ def wrapped_transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: for (subset, field) in dataframe.columns if subset not in self.spec.produces or field not in self.spec.produces[subset].fields - ] + ], ) dataframe.columns = [ "_".join(column) for column in dataframe.columns.to_flat_index() @@ -297,7 +295,7 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: A `dd.DataFrame` instance with updated data based on the applied data transformations. """ data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) - df = data_loader.load_dataframe() + dataframe = data_loader.load_dataframe() # Call the component setup method with user provided argument self.setup(**self.user_arguments) @@ -307,21 +305,21 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: for subset_name, subset in self.spec.produces.items(): for field_name, field in subset.fields.items(): meta_dict[f"{subset_name}_{field_name}"] = pd.Series( - dtype=pd.ArrowDtype(field.type.value) + dtype=pd.ArrowDtype(field.type.value), ) meta_df = pd.DataFrame(meta_dict).set_index("id") # Call the component transform method for each partition - df = df.map_partitions( + dataframe = dataframe.map_partitions( self.wrapped_transform, meta=meta_df, ) # Clear divisions if component spec indicates that the index is changed if self._infer_index_change(): - df.clear_divisions() + dataframe.clear_divisions() - return df + return dataframe def _infer_index_change(self) -> bool: """Infer if this component changes the index based on its component spec.""" @@ -332,10 +330,9 @@ def _infer_index_change(self) -> bool: for subset in self.spec.consumes.values(): if not subset.additional_fields: return True - for subset in self.spec.produces.values(): - if not subset.additional_fields: - return True - return False + return any( + not subset.additional_fields for subset in self.spec.produces.values() + ) class WriteComponent(Component): @@ -368,12 +365,11 @@ def _process_dataset(self, manifest: Manifest) -> None: A `dd.DataFrame` instance with updated data based on the applied data transformations. 
""" data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) - df = data_loader.load_dataframe() - self.write(dataframe=df, **self.user_arguments) + dataframe = data_loader.load_dataframe() + self.write(dataframe, **self.user_arguments) def _write_data(self, dataframe: dd.DataFrame, *, manifest: Manifest): """Create a data writer given a manifest and writes out the index and subsets.""" - pass def upload_manifest(self, manifest: Manifest, save_path: str): pass diff --git a/fondant/component_spec.py b/fondant/component_spec.py index 4de75f0fd..9e7cd9ee3 100644 --- a/fondant/component_spec.py +++ b/fondant/component_spec.py @@ -82,7 +82,7 @@ def fields(self) -> t.Mapping[str, Field]: { name: Field(name=name, type=Type.from_json(field)) for name, field in self._specification["fields"].items() - } + }, ) @property @@ -110,10 +110,11 @@ def _validate_spec(self) -> None: spec_data = pkgutil.get_data("fondant", "schemas/component_spec.json") if spec_data is None: - raise FileNotFoundError("component_spec.json not found in fondant schema") - else: - spec_str = spec_data.decode("utf-8") - spec_schema = json.loads(spec_str) + msg = "component_spec.json not found in fondant schema" + raise FileNotFoundError(msg) + + spec_str = spec_data.decode("utf-8") + spec_schema = json.loads(spec_str) base_uri = (Path(__file__).parent / "schemas").as_uri() resolver = RefResolver(base_uri=f"{base_uri}/", referrer=spec_schema) @@ -160,7 +161,7 @@ def consumes(self) -> t.Mapping[str, ComponentSubset]: name: ComponentSubset(subset) for name, subset in self._specification.get("consumes", {}).items() if name != "additionalSubsets" - } + }, ) @property @@ -171,7 +172,7 @@ def produces(self) -> t.Mapping[str, ComponentSubset]: name: ComponentSubset(subset) for name, subset in self._specification.get("produces", {}).items() if name != "additionalSubsets" - } + }, ) @property @@ -224,7 +225,8 @@ def __init__(self, specification: t.Dict[str, t.Any]) -> None: @classmethod def from_fondant_component_spec( - cls, fondant_component: ComponentSpec + cls, + fondant_component: ComponentSpec, ) -> "KubeflowComponentSpec": """Create a Kubeflow component spec from a Fondant component spec.""" specification = { @@ -280,7 +282,7 @@ def from_fondant_component_spec( "--output_manifest_path", {"outputPath": "output_manifest_path"}, ], - } + }, }, } return cls(specification) @@ -325,7 +327,7 @@ def input_arguments(self) -> t.Mapping[str, Argument]: default=info["default"] if "default" in info else None, ) for info in self._specification["inputs"] - } + }, ) @property @@ -339,7 +341,7 @@ def output_arguments(self) -> t.Mapping[str, Argument]: type=info["type"], ) for info in self._specification["outputs"] - } + }, ) def __repr__(self) -> str: diff --git a/fondant/data_io.py b/fondant/data_io.py index 87c409a08..b3c519273 100644 --- a/fondant/data_io.py +++ b/fondant/data_io.py @@ -37,7 +37,7 @@ def _load_subset(self, subset_name: str, fields: t.List[str]) -> dd.DataFrame: # add subset prefix to columns subset_df = subset_df.rename( - columns={col: subset_name + "_" + col for col in subset_df.columns} + columns={col: subset_name + "_" + col for col in subset_df.columns}, ) return subset_df @@ -55,9 +55,7 @@ def _load_index(self) -> dd.DataFrame: remote_path = index.location # load index from parquet, expecting id and source columns - index_df = dd.read_parquet(remote_path) - - return index_df + return dd.read_parquet(remote_path) def load_dataframe(self) -> dd.DataFrame: """ @@ -69,22 +67,22 @@ def load_dataframe(self) -> 
dd.DataFrame: as well as the index columns. """ # load index into dataframe - df = self._load_index() + dataframe = self._load_index() for name, subset in self.component_spec.consumes.items(): fields = list(subset.fields.keys()) subset_df = self._load_subset(name, fields) # left joins -> filter on index - df = dd.merge( - df, + dataframe = dd.merge( + dataframe, subset_df, left_index=True, right_index=True, how="left", ) - logging.info(f"Columns of dataframe: {list(df.columns)}") + logging.info(f"Columns of dataframe: {list(dataframe.columns)}") - return df + return dataframe class DaskDataWriter(DataIO): @@ -96,16 +94,22 @@ def write_dataframe(self, dataframe: dd.DataFrame) -> None: # Turn index into an empty dataframe so we can write it index_df = dataframe.index.to_frame().drop(columns=["id"]) write_index_task = self._write_subset( - index_df, subset_name="index", subset_spec=self.component_spec.index + index_df, + subset_name="index", + subset_spec=self.component_spec.index, ) write_tasks.append(write_index_task) for subset_name, subset_spec in self.component_spec.produces.items(): subset_df = self._extract_subset_dataframe( - dataframe, subset_name=subset_name, subset_spec=subset_spec + dataframe, + subset_name=subset_name, + subset_spec=subset_spec, ) write_subset_task = self._write_subset( - subset_df, subset_name=subset_name, subset_spec=subset_spec + subset_df, + subset_name=subset_name, + subset_spec=subset_spec, ) write_tasks.append(write_subset_task) @@ -115,7 +119,10 @@ def write_dataframe(self, dataframe: dd.DataFrame) -> None: @staticmethod def _extract_subset_dataframe( - dataframe: dd.DataFrame, *, subset_name: str, subset_spec: ComponentSubset + dataframe: dd.DataFrame, + *, + subset_name: str, + subset_spec: ComponentSubset, ) -> dd.DataFrame: """Create subset dataframe to save with the original field name as the column name.""" # Create a new dataframe with only the columns needed for the output subset @@ -123,20 +130,27 @@ def _extract_subset_dataframe( try: subset_df = dataframe[subset_columns] except KeyError as e: - raise ValueError( + msg = ( f"Field {e.args[0]} defined in output subset {subset_name} " f"but not found in dataframe" ) + raise ValueError( + msg, + ) # Remove the subset prefix from the column names subset_df = subset_df.rename( - columns={col: col[(len(f"{subset_name}_")) :] for col in subset_columns} + columns={col: col[(len(f"{subset_name}_")) :] for col in subset_columns}, ) return subset_df def _write_subset( - self, dataframe: dd.DataFrame, *, subset_name: str, subset_spec: ComponentSubset + self, + dataframe: dd.DataFrame, + *, + subset_name: str, + subset_spec: ComponentSubset, ) -> dd.core.Scalar: if subset_name == "index": location = self.manifest.index.location @@ -149,7 +163,10 @@ def _write_subset( @staticmethod def _create_write_task( - dataframe: dd.DataFrame, *, location: str, schema: t.Dict[str, str] + dataframe: dd.DataFrame, + *, + location: str, + schema: t.Dict[str, str], ) -> dd.core.Scalar: """ Creates a delayed Dask task to upload the given DataFrame to the remote storage location @@ -165,7 +182,11 @@ def _create_write_task( executed. 
""" write_task = dd.to_parquet( - dataframe, location, schema=schema, overwrite=False, compute=False + dataframe, + location, + schema=schema, + overwrite=False, + compute=False, ) logging.info(f"Creating write task for: {location}") return write_task diff --git a/fondant/explorer.py b/fondant/explorer.py index 02f373e01..a701623ea 100644 --- a/fondant/explorer.py +++ b/fondant/explorer.py @@ -36,7 +36,7 @@ def run_explorer_app( [ "-v", credentials, - ] + ], ) # mount the local data directory to the container @@ -48,11 +48,11 @@ def run_explorer_app( cmd.extend( [ f"{shlex.quote(container)}:{shlex.quote(tag)}", - ] + ], ) logging.info( - f"Running image from registry: {container} with tag: {tag} on port: {port}" + f"Running image from registry: {container} with tag: {tag} on port: {port}", ) logging.info(f"Access the explorer at http://localhost:{port}") diff --git a/fondant/import_utils.py b/fondant/import_utils.py index 92868e11d..9b004b551 100644 --- a/fondant/import_utils.py +++ b/fondant/import_utils.py @@ -16,7 +16,7 @@ `{0}` requires the 🤗 Datasets library but it was not found in your environment. Please install fondant using the 'datasets' extra. Note that if you have a local folder named `datasets` or a local python file named - `datasets.py` in your current working directory, python may try to import this instead of the 🤗 + `datasets.py` in your current working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or that python file if that's the case. Please note that you may need to restart your runtime after installation. """ @@ -46,8 +46,8 @@ def is_package_available(package_name: str, import_error_msg: str) -> bool: if package_available: return package_available - else: - raise ModuleNotFoundError(import_error_msg.format(Path(sys.argv[0]).stem)) + + raise ModuleNotFoundError(import_error_msg.format(Path(sys.argv[0]).stem)) def is_datasets_available(): diff --git a/fondant/manifest.py b/fondant/manifest.py index ee384e4c4..25d03a826 100644 --- a/fondant/manifest.py +++ b/fondant/manifest.py @@ -41,12 +41,13 @@ def fields(self) -> t.Mapping[str, Field]: { name: Field(name=name, type=Type.from_json(field)) for name, field in self._specification["fields"].items() - } + }, ) def add_field(self, name: str, type_: Type, *, overwrite: bool = False) -> None: if not overwrite and name in self._specification["fields"]: - raise ValueError(f"A field with name {name} already exists") + msg = f"A field with name {name} already exists" + raise ValueError(msg) self._specification["fields"][name] = type_.to_json() @@ -88,10 +89,11 @@ def _validate_spec(self) -> None: spec_data = pkgutil.get_data("fondant", "schemas/manifest.json") if spec_data is None: - raise FileNotFoundError("schemas/manifest.json not found") - else: - spec_str = spec_data.decode("utf-8") - spec_schema = json.loads(spec_str) + msg = "schemas/manifest.json not found" + raise FileNotFoundError(msg) + + spec_str = spec_data.decode("utf-8") + spec_schema = json.loads(spec_str) base_uri = (Path(__file__).parent / "schemas").as_uri() resolver = RefResolver(base_uri=f"{base_uri}/", referrer=spec_schema) @@ -168,14 +170,17 @@ def subsets(self) -> t.Mapping[str, Subset]: { name: Subset(subset, base_path=self.base_path) for name, subset in self._specification["subsets"].items() - } + }, ) def add_subset( - self, name: str, fields: t.Iterable[t.Union[Field, t.Tuple[str, Type]]] + self, + name: str, + fields: t.Iterable[t.Union[Field, t.Tuple[str, Type]]], ) -> None: if name in 
self._specification["subsets"]: - raise ValueError(f"A subset with name {name} already exists") + msg = f"A subset with name {name} already exists" + raise ValueError(msg) self._specification["subsets"][name] = { "location": f"/{name}/{self.run_id}/{self.component_id}", @@ -184,12 +189,14 @@ def add_subset( def remove_subset(self, name: str) -> None: if name not in self._specification["subsets"]: - raise ValueError(f"Subset {name} not found in specification") + msg = f"Subset {name} not found in specification" + raise ValueError(msg) del self._specification["subsets"][name] def evolve( # noqa : PLR0912 (too many branches) - self, component_spec: ComponentSpec + self, + component_spec: ComponentSpec, ) -> "Manifest": """Evolve the manifest based on the component spec. The resulting manifest is the expected result if the current manifest is provided @@ -223,13 +230,12 @@ def evolve( # noqa : PLR0912 (too many branches) # If additionalFields is False for a consumed subset, # Remove all fields from that subset that are not listed for subset_name, subset in component_spec.consumes.items(): - if subset_name in evolved_manifest.subsets: - if not subset.additional_fields: - for field_name in evolved_manifest.subsets[subset_name].fields: - if field_name not in subset.fields: - evolved_manifest.subsets[subset_name].remove_field( - field_name - ) + if subset_name in evolved_manifest.subsets and not subset.additional_fields: + for field_name in evolved_manifest.subsets[subset_name].fields: + if field_name not in subset.fields: + evolved_manifest.subsets[subset_name].remove_field( + field_name, + ) # For each output subset defined in the component, add or update it for subset_name, subset in component_spec.produces.items(): @@ -241,14 +247,16 @@ def evolve( # noqa : PLR0912 (too many branches) for field_name in evolved_manifest.subsets[subset_name].fields: if field_name not in subset.fields: evolved_manifest.subsets[subset_name].remove_field( - field_name + field_name, ) # Add fields defined in the component spec produces section # Overwrite to persist changes to the field (eg. type of column) for field in subset.fields.values(): evolved_manifest.subsets[subset_name].add_field( - field.name, field.type, overwrite=True + field.name, + field.type, + overwrite=True, ) # Update subset location as this is currently always rewritten diff --git a/fondant/pipeline.py b/fondant/pipeline.py index c1b5a8899..96d1dd532 100644 --- a/fondant/pipeline.py +++ b/fondant/pipeline.py @@ -77,7 +77,7 @@ def extend_arguments(self): """Add the component specification to the arguments if not already present.""" if not self.arguments.get("component_spec"): self.arguments["component_spec"] = json.dumps( - self.component_spec.specification + self.component_spec.specification, ) @classmethod @@ -115,7 +115,8 @@ def from_registry( component_spec_path = t.cast(Path, component_spec_path) if not (component_spec_path.exists() and component_spec_path.is_file()): - raise ValueError(f"No reusable component with name {name} found.") + msg = f"No reusable component with name {name} found." + raise ValueError(msg) return ComponentOp( component_spec_path, @@ -165,8 +166,9 @@ def add_op( """ if dependencies is None: if self.task_without_dependencies_added: + msg = "At most one task can be defined without dependencies." raise InvalidPipelineDefinition( - "At most one task can be defined without " "dependencies." 
+ msg, ) dependencies = [] self.task_without_dependencies_added = True @@ -174,12 +176,14 @@ def add_op( dependencies = [dependencies] if len(dependencies) > 1: + msg = ( + f"Multiple component dependencies provided for component " + f"`{task.component_spec.name}`. The current version of Fondant can only handle " + f"components with a single dependency. Please note that the behavior of the " + f"pipeline may be unpredictable or incorrect." + ) raise InvalidPipelineDefinition( - f"Multiple component dependencies provided for component" - f" `{task.component_spec.name}`. " - f"The current version of Fondant can only handle components with a single " - f"dependency. Please note that the behavior of the pipeline may be unpredictable" - f" or incorrect." + msg, ) dependencies_names = [ @@ -232,7 +236,9 @@ def _validate_pipeline_definition(self, run_id: str): # Create initial manifest manifest = Manifest.create( - base_path=self.base_path, run_id=run_id, component_id=load_component_name + base_path=self.base_path, + run_id=run_id, + component_id=load_component_name, ) for operation_specs in self._graph.values(): fondant_component_op = operation_specs["fondant_component_op"] @@ -244,10 +250,13 @@ def _validate_pipeline_definition(self, run_id: str): component_subset, ) in component_spec.consumes.items(): if component_subset_name not in manifest.subsets: + msg = ( + f"Component '{component_spec.name}' is trying to invoke the subset " + f"'{component_subset_name}', which has not been defined or created " + f"in the previous components." + ) raise InvalidPipelineDefinition( - f"Component '{component_spec.name}' " - f"is trying to invoke the subset '{component_subset_name}', " - f"which has not been defined or created in the previous components." + msg, ) # Get the corresponding manifest fields @@ -257,24 +266,29 @@ def _validate_pipeline_definition(self, run_id: str): for field_name, subset_field in component_subset.fields.items(): # Check if invoked field exists if field_name not in manifest_fields: - raise InvalidPipelineDefinition( - f"The invoked subset '{component_subset_name}' of the" - f" '{component_spec.name}' component does not match " - f"the previously created subset definition.\n The component is" - f" trying to invoke the field '{field_name}' which has not been" - f" previously defined. Current available fields are " + msg = ( + f"The invoked subset '{component_subset_name}' of the " + f"'{component_spec.name}' component does not match the " + f"previously created subset definition.\n The component is " + f"trying to invoke the field '{field_name}' which has not been " + f"previously defined. 
Current available fields are " f"{manifest_fields}\n" ) + raise InvalidPipelineDefinition( + msg, + ) # Check if the invoked field schema matches the current schema if subset_field != manifest_fields[field_name]: + msg = ( + f"The invoked subset '{component_subset_name}' of the " + f"'{component_spec.name}' component does not match the " + f"previously created subset definition.\n The '{field_name}' " + f"field is currently defined with the following schema:\n" + f"{manifest_fields[field_name]}\nThe current component to " + f"trying to invoke it with this schema:\n{subset_field}" + ) raise InvalidPipelineDefinition( - f"The invoked subset '{component_subset_name}' of the" - f" '{component_spec.name}' component does not match " - f" the previously created subset definition.\n The '{field_name}'" - f" field is currently defined with the following schema:\n" - f"{manifest_fields[field_name]}\n" - f"The current component to trying to invoke it with this schema:\n" - f"{subset_field}" + msg, ) manifest = manifest.evolve(component_spec) load_component = False @@ -303,7 +317,7 @@ def _get_component_function( Callable: The Kubeflow component. """ return kfp.components.load_component( - text=fondant_component_operation.component_spec.kubeflow_specification.to_string() + text=fondant_component_operation.component_spec.kubeflow_specification.to_string(), ) def _set_task_configuration(task, fondant_component_operation): @@ -361,7 +375,7 @@ def pipeline(): ) else: metadata = json.dumps( - {"base_path": self.base_path, "run_id": run_id} + {"base_path": self.base_path, "run_id": run_id}, ) # Add metadata to the first component component_task = kubeflow_component_op( @@ -372,7 +386,8 @@ def pipeline(): metadata = "" # Set optional configurations component_task = _set_task_configuration( - component_task, fondant_component_op + component_task, + fondant_component_op, ) # Set the execution order of the component task to be after the previous # component task. @@ -422,7 +437,7 @@ def get_pipeline_id(self, pipeline_name: str) -> str: def get_pipeline_version_ids(self, pipeline_id: str) -> t.List[str]: """Function that returns the versions of a pipeline given a pipeline id.""" pipeline_versions = self.client.list_pipeline_versions(pipeline_id).versions - return [getattr(version, "id") for version in pipeline_versions] + return [version.id for version in pipeline_versions] def delete_pipeline(self, pipeline_name: str): """ @@ -439,7 +454,7 @@ def delete_pipeline(self, pipeline_name: str): self.client.delete_pipeline(pipeline_id) logger.info( - f"Pipeline {pipeline_name} already exists. Deleting old pipeline..." + f"Pipeline {pipeline_name} already exists. Deleting old pipeline...", ) else: logger.info(f"No existing pipeline under `{pipeline_name}` name was found.") @@ -461,18 +476,18 @@ def compile_and_upload( Raises: Exception: If there was an error uploading the pipeline package. """ - # self.delete_pipeline(pipeline.name) - pipeline.compile() logger.info(f"Uploading pipeline: {pipeline.name}") try: self.client.upload_pipeline( - pipeline_package_path=pipeline.package_path, pipeline_name=pipeline.name + pipeline_package_path=pipeline.package_path, + pipeline_name=pipeline.name, ) except Exception as e: - raise Exception(f"Error uploading pipeline package: {str(e)}") + msg = f"Error uploading pipeline package: {str(e)}" + raise Exception(msg) # Delete the pipeline package file if specified. 
if delete_pipeline_package: @@ -501,7 +516,7 @@ def compile_and_run( except ValueError: logger.info( f"Defined experiment '{experiment_name}' not found. Creating new experiment" - f"under this name" + f"under this name", ) experiment = self.client.create_experiment(experiment_name) diff --git a/fondant/schema.py b/fondant/schema.py index aff68c0bd..ab4f116d0 100644 --- a/fondant/schema.py +++ b/fondant/schema.py @@ -67,9 +67,12 @@ def _validate_data_type(data_type: t.Union[str, pa.DataType]) -> pa.DataType: try: data_type = _TYPES[data_type] except KeyError: + msg = ( + f"Invalid schema provided {data_type} with type {type(data_type)}. " + f"Current available data types are: {_TYPES.keys()}" + ) raise InvalidTypeSchema( - f"Invalid schema provided {data_type} with type {type(data_type)}." - f" Current available data types are: {_TYPES.keys()}" + msg, ) return data_type @@ -87,7 +90,7 @@ def list(cls, data_type: t.Union[str, pa.DataType, "Type"]) -> "Type": """ data_type = cls._validate_data_type(data_type) return cls( - pa.list_(data_type.value if isinstance(data_type, Type) else data_type) + pa.list_(data_type.value if isinstance(data_type, Type) else data_type), ) @classmethod @@ -106,8 +109,9 @@ def from_json(cls, json_schema: dict): items = json_schema["items"] if isinstance(items, dict): return cls.list(cls.from_json(items)) - else: - return cls(json_schema["type"]) + return None + + return cls(json_schema["type"]) def to_json(self) -> dict: """ diff --git a/pyproject.toml b/pyproject.toml index aa824483d..595a54e04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,8 @@ requires = ["poetry-core>=1.2.0"] build-backend = "poetry.core.masonry.api" [tool.ruff] -select = ["D", "E", "F", "I", "PL", "RUF"] +select = ["C4", "COM", "D", "E", "EM", "ERA", "F", "I", "NPY", "PD", "PIE", "PL", "PT", "RET", + "RSE", "RUF", "SIM", "TCH", "TID", "UP", "W"] ignore = [ "D100", # Missing docstring in public module "D101", # Missing docstring in public class diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 0514dd35c..90b4c66a0 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -37,15 +37,16 @@ arguments={"storage_args": "a dummy string arg"}, ), ComponentOp.from_registry( - name="image_cropping", arguments={"cropping_threshold": 0, "padding": 0} + name="image_cropping", + arguments={"cropping_threshold": 0, "padding": 0}, ), ], ), ] -@pytest.fixture -def freeze_time(monkeypatch): +@pytest.fixture() +def _freeze_time(monkeypatch): class FrozenDatetime(datetime.datetime): @classmethod def now(cls): @@ -80,20 +81,22 @@ def setup_pipeline(request, tmp_path, monkeypatch): return (example_dir, pipeline) -def test_docker_compiler(setup_pipeline, freeze_time, tmp_path_factory): +@pytest.mark.usefixtures("_freeze_time") +def test_docker_compiler(setup_pipeline, tmp_path_factory): """Test compiling a pipeline to docker-compose.""" example_dir, pipeline = setup_pipeline compiler = DockerCompiler() with tmp_path_factory.mktemp("temp") as fn: output_path = str(fn / "docker-compose.yml") compiler.compile(pipeline=pipeline, output_path=output_path) - with open(output_path, "r") as src, open( - VALID_DOCKER_PIPELINE / example_dir / "docker-compose.yml", "r" + with open(output_path) as src, open( + VALID_DOCKER_PIPELINE / example_dir / "docker-compose.yml", ) as truth: assert yaml.safe_load(src) == yaml.safe_load(truth) -def test_docker_local_path(setup_pipeline, freeze_time, tmp_path_factory): +@pytest.mark.usefixtures("_freeze_time") +def 
test_docker_local_path(setup_pipeline, tmp_path_factory): """Test that a local path is applied correctly as a volume and in the arguments.""" # volumes are only created for local existing directories with tmp_path_factory.mktemp("temp") as fn: @@ -115,7 +118,7 @@ def test_docker_local_path(setup_pipeline, freeze_time, tmp_path_factory): "source": str(fn), "target": work_dir, "type": "bind", - } + }, ] # check if commands are patched to use the working dir commands_with_dir = [ @@ -126,7 +129,8 @@ def test_docker_local_path(setup_pipeline, freeze_time, tmp_path_factory): assert command in service["command"] -def test_docker_remote_path(setup_pipeline, freeze_time, tmp_path_factory): +@pytest.mark.usefixtures("_freeze_time") +def test_docker_remote_path(setup_pipeline, tmp_path_factory): """Test that a remote path is applied correctly in the arguments and no volume.""" _, pipeline = setup_pipeline remote_dir = "gs://somebucket/artifacts" @@ -151,7 +155,8 @@ def test_docker_remote_path(setup_pipeline, freeze_time, tmp_path_factory): assert command in service["command"] -def test_docker_extra_volumes(setup_pipeline, freeze_time, tmp_path_factory): +@pytest.mark.usefixtures("_freeze_time") +def test_docker_extra_volumes(setup_pipeline, tmp_path_factory): """Test that extra volumes are applied correctly.""" with tmp_path_factory.mktemp("temp") as fn: # this is the directory mounted in the container @@ -169,7 +174,7 @@ def test_docker_extra_volumes(setup_pipeline, freeze_time, tmp_path_factory): # read the generated docker-compose file with open(fn / "docker-compose.yml") as f_spec: spec = yaml.safe_load(f_spec) - for name, service in spec["services"].items(): + for _name, service in spec["services"].items(): assert all( extra_volume in service["volumes"] for extra_volume in extra_volumes ) diff --git a/tests/test_component.py b/tests/test_component.py index 8833bb2d8..3407f040c 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -24,14 +24,13 @@ def yaml_file_to_json_string(file_path): - with open(file_path, "r") as file: + with open(file_path) as file: data = yaml.safe_load(file) - json_string = json.dumps(data) - return json_string + return json.dumps(data) -@pytest.fixture -def patched_data_loading(monkeypatch): +@pytest.fixture() +def _patched_data_loading(monkeypatch): """Mock data loading so no actual data is loaded.""" def mocked_load_dataframe(self): @@ -40,8 +39,8 @@ def mocked_load_dataframe(self): monkeypatch.setattr(DaskDataLoader, "load_dataframe", mocked_load_dataframe) -@pytest.fixture -def patched_data_writing(monkeypatch): +@pytest.fixture() +def _patched_data_writing(monkeypatch): """Mock data loading so no actual data is written.""" def mocked_write_dataframe(self, dataframe): @@ -49,7 +48,9 @@ def mocked_write_dataframe(self, dataframe): monkeypatch.setattr(DaskDataWriter, "write_dataframe", mocked_write_dataframe) monkeypatch.setattr( - Component, "upload_manifest", lambda self, manifest, save_path: None + Component, + "upload_manifest", + lambda self, manifest, save_path: None, ) @@ -103,7 +104,8 @@ def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]: } -def test_load_component(patched_data_writing): +@pytest.mark.usefixtures("_patched_data_writing") +def test_load_component(): # Mock CLI argumentsload sys.argv = [ "", @@ -136,7 +138,8 @@ def load(self, *, flag, value): load.assert_called_once() -def test_dask_transform_component(patched_data_loading, patched_data_writing): +@pytest.mark.usefixtures("_patched_data_loading", 
"_patched_data_writing") +def test_dask_transform_component(): # Mock CLI arguments sys.argv = [ "", @@ -163,13 +166,16 @@ def transform(self, dataframe, *, flag, value): component = MyDaskComponent.from_args() with mock.patch.object( - MyDaskComponent, "transform", wraps=component.transform + MyDaskComponent, + "transform", + wraps=component.transform, ) as transform: component.run() transform.assert_called_once() -def test_pandas_transform_component(patched_data_loading, patched_data_writing): +@pytest.mark.usefixtures("_patched_data_loading", "_patched_data_writing") +def test_pandas_transform_component(): # Mock CLI arguments sys.argv = [ "", @@ -199,7 +205,9 @@ def transform(self, dataframe): component = MyPandasComponent.from_args() setup = mock.patch.object(MyPandasComponent, "setup", wraps=component.setup) transform = mock.patch.object( - MyPandasComponent, "transform", wraps=component.transform + MyPandasComponent, + "transform", + wraps=component.transform, ) with setup as setup, transform as transform: component.run() @@ -207,7 +215,8 @@ def transform(self, dataframe): assert transform.call_count == N_PARTITIONS -def test_write_component(patched_data_loading): +@pytest.mark.usefixtures("_patched_data_loading") +def test_write_component(): # Mock CLI arguments sys.argv = [ "", diff --git a/tests/test_component_specs.py b/tests/test_component_specs.py index 68eb7c1e8..698e7bd38 100644 --- a/tests/test_component_specs.py +++ b/tests/test_component_specs.py @@ -14,25 +14,25 @@ component_specs_path = Path(__file__).parent / "example_specs/component_specs" -@pytest.fixture +@pytest.fixture() def valid_fondant_schema() -> dict: with open(component_specs_path / "valid_component.yaml") as f: return yaml.safe_load(f) -@pytest.fixture +@pytest.fixture() def valid_fondant_schema_no_args() -> dict: with open(component_specs_path / "valid_component_no_args.yaml") as f: return yaml.safe_load(f) -@pytest.fixture +@pytest.fixture() def valid_kubeflow_schema() -> dict: with open(component_specs_path / "kubeflow_component.yaml") as f: return yaml.safe_load(f) -@pytest.fixture +@pytest.fixture() def invalid_fondant_schema() -> dict: with open(component_specs_path / "invalid_component.yaml") as f: return yaml.safe_load(f) @@ -64,7 +64,7 @@ def test_attribute_access(valid_fondant_schema): assert fondant_component.description == "This is an example component" assert fondant_component.consumes["images"].fields["data"].type == Type("binary") assert fondant_component.consumes["embeddings"].fields["data"].type == Type.list( - Type("float32") + Type("float32"), ) @@ -92,7 +92,7 @@ def test_component_spec_to_file(valid_fondant_schema): file_path = os.path.join(temp_dir, "component_spec.yaml") component_spec.to_file(file_path) - with open(file_path, "r") as f: + with open(file_path) as f: written_data = yaml.safe_load(f) # check if the written data is the same as the original data @@ -107,7 +107,7 @@ def test_kubeflow_component_spec_to_file(valid_kubeflow_schema): file_path = os.path.join(temp_dir, "kubeflow_component_spec.yaml") kubeflow_component_spec.to_file(file_path) - with open(file_path, "r") as f: + with open(file_path) as f: written_data = yaml.safe_load(f) # check if the written data is the same as the original data diff --git a/tests/test_data_io.py b/tests/test_data_io.py index 2bab0605f..650e73d97 100644 --- a/tests/test_data_io.py +++ b/tests/test_data_io.py @@ -13,17 +13,17 @@ NUMBER_OF_TEST_ROWS = 151 -@pytest.fixture +@pytest.fixture() def manifest(): return Manifest.from_file(manifest_path) 
-@pytest.fixture +@pytest.fixture() def component_spec(): return ComponentSpec.from_file(component_spec_path) -@pytest.fixture +@pytest.fixture() def dataframe(manifest, component_spec): data_loader = DaskDataLoader(manifest=manifest, component_spec=component_spec) return data_loader.load_dataframe() @@ -49,15 +49,15 @@ def test_load_subset(manifest, component_spec): def test_load_dataframe(manifest, component_spec): """Test merging of subsets in a dataframe based on a component_spec.""" dl = DaskDataLoader(manifest=manifest, component_spec=component_spec) - df = dl.load_dataframe() - assert len(df) == NUMBER_OF_TEST_ROWS - assert list(df.columns) == [ + dataframe = dl.load_dataframe() + assert len(dataframe) == NUMBER_OF_TEST_ROWS + assert list(dataframe.columns) == [ "properties_Name", "properties_HP", "types_Type 1", "types_Type 2", ] - assert df.index.name == "id" + assert dataframe.index.name == "id" def test_write_index(tmp_path_factory, dataframe, manifest, component_spec): @@ -69,9 +69,9 @@ def test_write_index(tmp_path_factory, dataframe, manifest, component_spec): # write out index to temp dir data_writer.write_dataframe(dataframe) # read written data and assert - df = dd.read_parquet(fn / "index") - assert len(df) == NUMBER_OF_TEST_ROWS - assert df.index.name == "id" + dataframe = dd.read_parquet(fn / "index") + assert len(dataframe) == NUMBER_OF_TEST_ROWS + assert dataframe.index.name == "id" def test_write_subsets(tmp_path_factory, dataframe, manifest, component_spec): @@ -90,10 +90,10 @@ def test_write_subsets(tmp_path_factory, dataframe, manifest, component_spec): data_writer.write_dataframe(dataframe) # read written data and assert for subset, subset_columns in subset_columns_dict.items(): - df = dd.read_parquet(fn / subset) - assert len(df) == NUMBER_OF_TEST_ROWS - assert list(df.columns) == subset_columns - assert df.index.name == "id" + dataframe = dd.read_parquet(fn / subset) + assert len(dataframe) == NUMBER_OF_TEST_ROWS + assert list(dataframe.columns) == subset_columns + assert dataframe.index.name == "id" def test_write_reset_index(tmp_path_factory, dataframe, manifest, component_spec): @@ -108,13 +108,17 @@ def test_write_reset_index(tmp_path_factory, dataframe, manifest, component_spec data_writer.write_dataframe(dataframe) for subset in ["properties", "types", "index"]: - df = dd.read_parquet(fn / subset) - assert df.index.name == "id" + dataframe = dd.read_parquet(fn / subset) + assert dataframe.index.name == "id" @pytest.mark.parametrize("partitions", list(range(1, 5))) def test_write_divisions( - tmp_path_factory, dataframe, manifest, component_spec, partitions + tmp_path_factory, + dataframe, + manifest, + component_spec, + partitions, ): """Test writing out index and subsets and asserting they have the divisions of the dataframe.""" # repartition the dataframe (default is 3 partitions) @@ -127,9 +131,9 @@ def test_write_divisions( data_writer.write_dataframe(dataframe) for target in ["properties", "types", "index"]: - df = dd.read_parquet(fn / target) - assert df.index.name == "id" - assert df.npartitions == partitions + dataframe = dd.read_parquet(fn / target) + assert dataframe.index.name == "id" + assert dataframe.npartitions == partitions def test_write_subsets_invalid(tmp_path_factory, dataframe, manifest, component_spec): @@ -140,5 +144,9 @@ def test_write_subsets_invalid(tmp_path_factory, dataframe, manifest, component_ # Drop one of the columns required in the output dataframe = dataframe.drop(["types_Type 2"], axis=1) data_writer = 
DaskDataWriter(manifest=manifest, component_spec=component_spec) - with pytest.raises(ValueError): + expected_error_msg = ( + r"Field \['types_Type 2'\] not in index defined in output subset " + r"types but not found in dataframe" + ) + with pytest.raises(ValueError, match=expected_error_msg): data_writer.write_dataframe(dataframe) diff --git a/tests/test_import_utils.py b/tests/test_import_utils.py index cc2496fd0..63120b05c 100644 --- a/tests/test_import_utils.py +++ b/tests/test_import_utils.py @@ -13,7 +13,7 @@ @pytest.mark.parametrize( - "package_name, import_error_msg, expected_result", + ("package_name", "import_error_msg", "expected_result"), [ ("jsonschema", "jsonschema package is not installed.", True), ( @@ -42,7 +42,8 @@ def test_available_packages(importlib_util_find_spec, importlib_metadata_version @mock.patch( - "importlib.metadata.version", side_effect=importlib.metadata.PackageNotFoundError + "importlib.metadata.version", + side_effect=importlib.metadata.PackageNotFoundError, ) def test_unavailable_packages(mock_importlib_metadata_version): """Test that is_datasets_available returns False when 'datasets' is not available.""" diff --git a/tests/test_manifest.py b/tests/test_manifest.py index b0f013af2..b39786195 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -10,13 +10,13 @@ manifest_path = Path(__file__).parent / "example_specs/manifests" -@pytest.fixture +@pytest.fixture() def valid_manifest(): with open(manifest_path / "valid_manifest.json") as f: return json.load(f) -@pytest.fixture +@pytest.fixture() def invalid_manifest(): with open(manifest_path / "invalid_manifest.json") as f: return json.load(f) @@ -64,7 +64,7 @@ def test_subset_fields(): assert "data2" in subset.fields # add a duplicate field - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="A field with name data2 already exists"): subset.add_field(name="data2", type_=Type("binary")) # add a duplicate field but overwrite @@ -121,7 +121,9 @@ def test_manifest_creation(): component_id = "component_id" manifest = Manifest.create( - base_path=base_path, run_id=run_id, component_id=component_id + base_path=base_path, + run_id=run_id, + component_id=component_id, ) manifest.add_subset("images", [("width", Type("int32")), ("height", Type("int32"))]) manifest.subsets["images"].add_field("data", Type("binary")) @@ -147,7 +149,7 @@ def test_manifest_creation(): "type": "binary", }, }, - } + }, }, } @@ -167,14 +169,16 @@ def test_manifest_alteration(valid_manifest): # test adding a subset manifest.add_subset( - "images2", [("width", Type("int32")), ("height", Type("int32"))] + "images2", + [("width", Type("int32")), ("height", Type("int32"))], ) assert "images2" in manifest.subsets # test adding a duplicate subset - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="A subset with name images2 already exists"): manifest.add_subset( - "images2", [("width", Type("int32")), ("height", Type("int32"))] + "images2", + [("width", Type("int32")), ("height", Type("int32"))], ) # test removing a subset @@ -182,7 +186,7 @@ def test_manifest_alteration(valid_manifest): assert "images2" not in manifest.subsets # test removing a nonexistant subset - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Subset pictures not found in specification"): manifest.remove_subset("pictures") diff --git a/tests/test_manifest_evolution.py b/tests/test_manifest_evolution.py index d6a5dbbe6..bec99c280 100644 --- a/tests/test_manifest_evolution.py +++ 
b/tests/test_manifest_evolution.py @@ -10,7 +10,7 @@ examples_path = Path(__file__).parent / "example_specs/evolution_examples" -@pytest.fixture +@pytest.fixture() def input_manifest(): with open(examples_path / "input_manifest.json") as f: return json.load(f) @@ -20,12 +20,12 @@ def examples(): """Returns examples as tuples of component and expected output_manifest.""" for directory in (f for f in examples_path.iterdir() if f.is_dir()): with open(directory / "component.yaml") as c, open( - directory / "output_manifest.json" + directory / "output_manifest.json", ) as o: yield yaml.safe_load(c), json.load(o) -@pytest.mark.parametrize("component_spec, output_manifest", list(examples())) +@pytest.mark.parametrize(("component_spec", "output_manifest"), list(examples())) def test_evolution(input_manifest, component_spec, output_manifest): manifest = Manifest(input_manifest) component_spec = ComponentSpec(component_spec) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 47df7645e..01e7b6541 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -15,12 +15,11 @@ def yaml_file_to_dict(file_path): - with open(file_path, "r") as file: - data = yaml.safe_load(file) - return data + with open(file_path) as file: + return yaml.safe_load(file) -@pytest.fixture +@pytest.fixture() def default_pipeline_args(): return { "pipeline_name": "pipeline", @@ -38,7 +37,10 @@ def default_pipeline_args(): ], ) def test_valid_pipeline( - default_pipeline_args, valid_pipeline_example, tmp_path, monkeypatch + default_pipeline_args, + valid_pipeline_example, + tmp_path, + monkeypatch, ): """Test that a valid pipeline definition can be compiled without errors.""" example_dir, component_names = valid_pipeline_example @@ -51,13 +53,16 @@ def test_valid_pipeline( monkeypatch.setattr(pipeline, "package_path", str(tmp_path / "test_pipeline.tgz")) first_component_op = ComponentOp( - Path(components_path / component_names[0]), arguments=component_args + Path(components_path / component_names[0]), + arguments=component_args, ) second_component_op = ComponentOp( - Path(components_path / component_names[1]), arguments=component_args + Path(components_path / component_names[1]), + arguments=component_args, ) third_component_op = ComponentOp( - Path(components_path / component_names[2]), arguments=component_args + Path(components_path / component_names[2]), + arguments=component_args, ) pipeline.add_op(third_component_op, dependencies=second_component_op) @@ -87,7 +92,9 @@ def test_valid_pipeline( ], ) def test_invalid_pipeline_dependencies( - default_pipeline_args, valid_pipeline_example, tmp_path + default_pipeline_args, + valid_pipeline_example, + tmp_path, ): """ Test that an InvalidPipelineDefinition exception is raised when attempting to create a pipeline @@ -100,13 +107,16 @@ def test_invalid_pipeline_dependencies( pipeline = Pipeline(**default_pipeline_args) first_component_op = ComponentOp( - Path(components_path / component_names[0]), arguments=component_args + Path(components_path / component_names[0]), + arguments=component_args, ) second_component_op = ComponentOp( - Path(components_path / component_names[1]), arguments=component_args + Path(components_path / component_names[1]), + arguments=component_args, ) third_component_op = ComponentOp( - Path(components_path / component_names[2]), arguments=component_args + Path(components_path / component_names[2]), + arguments=component_args, ) pipeline.add_op(third_component_op, dependencies=second_component_op) @@ -123,7 +133,9 @@ def 
test_invalid_pipeline_dependencies( ], ) def test_invalid_pipeline_compilation( - default_pipeline_args, invalid_pipeline_example, tmp_path + default_pipeline_args, + invalid_pipeline_example, + tmp_path, ): """ Test that an InvalidPipelineDefinition exception is raised when attempting to compile @@ -136,10 +148,12 @@ def test_invalid_pipeline_compilation( pipeline = Pipeline(**default_pipeline_args) first_component_op = ComponentOp( - Path(components_path / component_names[0]), arguments=component_args + Path(components_path / component_names[0]), + arguments=component_args, ) second_component_op = ComponentOp( - Path(components_path / component_names[1]), arguments=component_args + Path(components_path / component_names[1]), + arguments=component_args, ) pipeline.add_op(first_component_op) @@ -162,10 +176,11 @@ def test_invalid_argument(default_pipeline_args, invalid_component_args, tmp_pat component does not match the ones specified in the fondant specifications. """ components_spec_path = Path( - valid_pipeline_path / "example_1" / "first_component.yaml" + valid_pipeline_path / "example_1" / "first_component.yaml", ) component_operation = ComponentOp( - components_spec_path, arguments=invalid_component_args + components_spec_path, + arguments=invalid_component_args, ) pipeline = Pipeline(**default_pipeline_args) @@ -183,9 +198,13 @@ def test_reusable_component_op(): ) assert laion_retrieval_op.component_spec, "component_spec_path could not be loaded" - with pytest.raises(ValueError): + component_name = "this_component_does_not_exist" + with pytest.raises( + ValueError, + match=f"No reusable component with name {component_name} " "found.", + ): ComponentOp.from_registry( - name="this_component_does_not_exist", + name=component_name, ) @@ -200,7 +219,7 @@ def test_defining_reusable_component_op_with_custom_spec(): ) load_from_hub_op_default_spec = ComponentSpec( - yaml_file_to_dict(load_from_hub_op_default_op.component_spec_path) + yaml_file_to_dict(load_from_hub_op_default_op.component_spec_path), ) load_from_hub_op_custom_op = ComponentOp.from_registry( diff --git a/tests/test_schema.py b/tests/test_schema.py index 96f3d4c91..7d69331a7 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -22,26 +22,31 @@ def test_valid_json_schema(): """Test that Type class initialized with a json schema matches the expected pyarrow schema.""" assert Type.from_json({"type": "string"}).value == pa.string() assert Type.from_json( - {"type": "array", "items": {"type": "int8"}} + {"type": "array", "items": {"type": "int8"}}, ).value == pa.list_(pa.int8()) assert Type.from_json( - {"type": "array", "items": {"type": "array", "items": {"type": "int8"}}} + {"type": "array", "items": {"type": "array", "items": {"type": "int8"}}}, ).value == pa.list_(pa.list_(pa.int8())) -def test_invalid_json_schema(): +@pytest.mark.parametrize( + "statement", + [ + 'Type("invalid_type")', + 'Type("invalid_type").to_json()', + 'Type.list(Type("invalid_type"))', + 'Type.list(Type("invalid_type")).to_json()', + 'Type.from_json({"type": "invalid_value"})', + 'Type.from_json({"type": "invalid_value", "items": {"type": "int8"}})', + 'Type.from_json({"type": "array", "items": {"type": "invalid_type"}})', + ], +) +def test_invalid_json_schema(statement): """Test that an invalid type or schema specified with the Type class raise an invalid type schema error. 
""" with pytest.raises(InvalidTypeSchema): - Type("invalid_type") - Type("invalid_type").to_json() - Type.list(Type("invalid_type")) - Type.list(Type("invalid_type")).to_json() - Type.from_json({"invalid_key": "int8"}) - Type.from_json({"type": "invalid_value"}) - Type.from_json({"type": "invalid_value", "items": {"type": "int8"}}) - Type.from_json({"type": "array", "items": {"type": "invalid_type"}}) + eval(statement) def test_equality():