From 392899ba9027292d78725782aeb19b7837fe12ea Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 11 Mar 2024 14:11:28 -0400 Subject: [PATCH 1/2] Add linting with black --- ...oliynyk_test_atom_mixing_formatted_log.csv | 2 +- filter/occupancy.py | 46 ++-- postprocess/bond.py | 11 +- postprocess/excel.py | 53 +++-- postprocess/histogram.py | 32 +-- postprocess/writer.py | 29 ++- preprocess/cif_parser.py | 67 +++--- preprocess/cif_parser_handler.py | 30 ++- preprocess/supercell.py | 201 +++++++++++------- tests/conftest.py | 7 +- tests/filter/test_occupancy.py | 92 ++++---- tests/postprocess/test_bond.py | 7 +- tests/postprocess/test_pair_order.py | 6 +- tests/preprocess/test_cif_parser.py | 12 +- tests/test_single_cif.py | 7 +- util/folder.py | 22 +- util/prompt.py | 47 ++-- util/string_parser.py | 12 +- util/unit.py | 4 +- 19 files changed, 399 insertions(+), 288 deletions(-) diff --git a/20240229_oliynyk_test_atom_mixing_formatted/csv/20240229_oliynyk_test_atom_mixing_formatted_log.csv b/20240229_oliynyk_test_atom_mixing_formatted/csv/20240229_oliynyk_test_atom_mixing_formatted_log.csv index 3ec4892..4cb0089 100644 --- a/20240229_oliynyk_test_atom_mixing_formatted/csv/20240229_oliynyk_test_atom_mixing_formatted_log.csv +++ b/20240229_oliynyk_test_atom_mixing_formatted/csv/20240229_oliynyk_test_atom_mixing_formatted_log.csv @@ -1,2 +1,2 @@ File,Number of atoms in supercell,Processing time (s) -539016,334,3.585 +539016,334,3.945 diff --git a/filter/occupancy.py b/filter/occupancy.py index dc5a5a9..6f6c90c 100644 --- a/filter/occupancy.py +++ b/filter/occupancy.py @@ -17,9 +17,7 @@ def get_coord_occupancy_sum(cif_loop_values): coord_occupancy_sum = {} for i in range(num_atom_labels): - _, occupancy, coordinates = cif_parser.get_atom_info( - cif_loop_values, i - ) + _, occupancy, coordinates = cif_parser.get_atom_info(cif_loop_values, i) occupancy_num = coord_occupancy_sum.get(coordinates, 0) + occupancy coord_occupancy_sum[coordinates] = occupancy_num @@ -69,7 +67,7 @@ def get_all_possible_ordered_label_pairs(cif_loop_values): # Get a list of unique pairs from atomic labels label_list = cif_parser.get_atom_label_list(cif_loop_values) all_possible_label_pairs = list(product(label_list, repeat=2)) - + # Step 1: Sort each pair to standardize order sorted_pairs = pair_order.sort_tuple_in_list(all_possible_label_pairs) @@ -77,17 +75,16 @@ def get_all_possible_ordered_label_pairs(cif_loop_values): unique_sorted_pairs = list(set(sorted_pairs)) # Step 3. Order pairs based on Mendeleev ordering - unique_sorted_pairs_ordered = ( - [tuple(pair_order.order_pair_by_mendeleev(pair)) - for pair in unique_sorted_pairs] - ) + unique_sorted_pairs_ordered = [ + tuple(pair_order.order_pair_by_mendeleev(pair)) + for pair in unique_sorted_pairs + ] return unique_sorted_pairs_ordered # Get atom site mixing label for all pairs possible -def get_atom_site_mixing_dict( - atom_site_mixing_file_info, cif_loop_values): +def get_atom_site_mixing_dict(atom_site_mixing_file_info, cif_loop_values): """ Gets atomic site mixing dictionary for all possible label pairs using cif loop values. """ @@ -122,7 +119,7 @@ def get_atom_site_mixing_dict( if first_label_occ == 1 and second_label_occ == 1: atom_site_pair_dict[pair] = "4" continue - + # Step 4. Check deficiecny at the pair level # Check whehter one of the sites is deficient is_first_label_site_deficient = None @@ -136,7 +133,7 @@ def get_atom_site_mixing_dict( if occupancy_sum[second_label_coord] < 1: is_second_label_deficient = True - + else: is_second_label_deficient = False @@ -162,28 +159,29 @@ def get_atom_site_mixing_dict( # Assign "3" for "deficiency_no_atomic_mixing" # Check 1. One of the labels is deficient # Check 2. Both labels are not atomic mixed - if ((is_first_label_site_deficient or - is_second_label_deficient) and - (not is_first_label_atomic_mixed and - not is_second_label_atomic_mixed)): + if ( + is_first_label_site_deficient or is_second_label_deficient + ) and ( + not is_first_label_atomic_mixed + and not is_second_label_atomic_mixed + ): atom_site_pair_dict[pair] = "3" # Assign "2" for "full_occupancy_atomic_mixing" # Check 1. Both labels are not deficient # Check 2. At least one label is atomic mixed - if ((not is_first_label_site_deficient - and not is_second_label_deficient) and - (is_first_label_atomic_mixed or - is_second_label_atomic_mixed)): + if ( + not is_first_label_site_deficient + and not is_second_label_deficient + ) and (is_first_label_atomic_mixed or is_second_label_atomic_mixed): atom_site_pair_dict[pair] = "2" # Assign "1" for "deficiency" # Check 1. At least one label is deficient # Check 2. At least one label mixed - if ((is_first_label_site_deficient or - is_second_label_deficient) and - (is_first_label_atomic_mixed or - is_second_label_atomic_mixed)): + if ( + is_first_label_site_deficient or is_second_label_deficient + ) and (is_first_label_atomic_mixed or is_second_label_atomic_mixed): atom_site_pair_dict[pair] = "1" return atom_site_pair_dict diff --git a/postprocess/bond.py b/postprocess/bond.py index 13e1e59..83438e4 100644 --- a/postprocess/bond.py +++ b/postprocess/bond.py @@ -86,7 +86,9 @@ def get_sorted_missing_pairs(pair_dict): ) # Sort the pairs in the data as well before comparison - missing_label_pairs = [pair for pair in all_pairs if pair not in pairs_found] + missing_label_pairs = [ + pair for pair in all_pairs if pair not in pairs_found + ] return missing_label_pairs @@ -118,7 +120,8 @@ def get_unique_pairs_dict(ordered_pairs, filename): # if this pair is shorter than the previous pair if ( label_tuple not in unique_pairs_dict[filename] - or pair["distance"] < unique_pairs_dict[filename][label_tuple]["distance"] + or pair["distance"] + < unique_pairs_dict[filename][label_tuple]["distance"] ): # Add this pair to the dictionary unique_pairs_dict[filename][label_tuple] = pair @@ -126,7 +129,9 @@ def get_unique_pairs_dict(ordered_pairs, filename): return unique_pairs_dict -def get_dist_mix_pair_dict(dist_pair_dict, unique_pairs_dict, label_pair_mixing_dict): +def get_dist_mix_pair_dict( + dist_pair_dict, unique_pairs_dict, label_pair_mixing_dict +): """ Returns dict containing files and dist per pair. """ diff --git a/postprocess/excel.py b/postprocess/excel.py index 732cf71..a1748fa 100644 --- a/postprocess/excel.py +++ b/postprocess/excel.py @@ -17,10 +17,12 @@ def write_label_pair_dict_to_excel_json(input_dict, pair_tpye, dir_path): os.makedirs(output_dir) folder_name = os.path.basename(os.path.normpath(dir_path)) - excel_file_path = os.path.join(output_dir, - f"{folder_name}_{pair_tpye}_pairs.xlsx") - json_file_path = os.path.join(output_dir, - f"{folder_name}_{pair_tpye}_pairs.json") + excel_file_path = os.path.join( + output_dir, f"{folder_name}_{pair_tpye}_pairs.xlsx" + ) + json_file_path = os.path.join( + output_dir, f"{folder_name}_{pair_tpye}_pairs.json" + ) category_mapping = { 1: "deficiency", @@ -43,14 +45,17 @@ def write_label_pair_dict_to_excel_json(input_dict, pair_tpye, dir_path): inplace=True, ) - df["Distance"] = pd.to_numeric(df["Distance"], - errors="coerce").astype(float) + df["Distance"] = pd.to_numeric( + df["Distance"], errors="coerce" + ).astype(float) # Convert 'Atomic Mixing' column to numeric, coerce errors - df["Atomic Mixing"] = df["Atomic Mixing"].apply(pd.to_numeric, - errors="coerce") + df["Atomic Mixing"] = df["Atomic Mixing"].apply( + pd.to_numeric, errors="coerce" + ) df["Atomic Mixing"] = ( - df["Atomic Mixing"].map(category_mapping).fillna("Unknown")) + df["Atomic Mixing"].map(category_mapping).fillna("Unknown") + ) df["File"] = df["File"].apply(lambda x: f"{x}.cif") df.sort_values(by="Distance", inplace=True) @@ -76,10 +81,12 @@ def write_element_pair_dict_to_excel_json(input_dict, pair_type, dir_path): os.makedirs(output_dir, exist_ok=True) folder_name = os.path.basename(os.path.normpath(dir_path)) - excel_file_path = os.path.join(output_dir, - f"{folder_name}_{pair_type}_pairs.xlsx") - json_file_path = os.path.join(output_dir, - f"{folder_name}_{pair_type}_pairs.json") + excel_file_path = os.path.join( + output_dir, f"{folder_name}_{pair_type}_pairs.xlsx" + ) + json_file_path = os.path.join( + output_dir, f"{folder_name}_{pair_type}_pairs.json" + ) category_mapping = { "1": "deficiency", @@ -96,23 +103,27 @@ def write_element_pair_dict_to_excel_json(input_dict, pair_type, dir_path): for info in infos: # Here infos is a list of dictionaries info_copy = info.copy() info_copy[ - "File"] = f"{file_id}.cif" # Add the file ID as 'File' + "File" + ] = f"{file_id}.cif" # Add the file ID as 'File' aggregated_info.append(info_copy) # Create a DataFrame from the aggregated information df = pd.DataFrame(aggregated_info) # Rename columns to match the expected format - df.rename(columns={ - "dist": "Distance", - "mixing": "Atomic Mixing" - }, - inplace=True) + df.rename( + columns={"dist": "Distance", "mixing": "Atomic Mixing"}, + inplace=True, + ) # Apply numeric transformation and category mapping df["Distance"] = pd.to_numeric(df["Distance"], errors="coerce") - df["Atomic Mixing"] = (df["Atomic Mixing"].astype(str).map( - category_mapping).fillna("Unknown")) + df["Atomic Mixing"] = ( + df["Atomic Mixing"] + .astype(str) + .map(category_mapping) + .fillna("Unknown") + ) df.sort_values(by="Distance", inplace=True) # Specify the desired column order diff --git a/postprocess/histogram.py b/postprocess/histogram.py index 15f1463..2b883da 100644 --- a/postprocess/histogram.py +++ b/postprocess/histogram.py @@ -54,8 +54,9 @@ def plot_histograms_from_label_dict(data, directory_path): labels = [] for cat, _ in categories_colors.items(): category_distances = [ - float(pair_info["dist"]) for sub_key, - pair_info in pair_info.items() if pair_info["mixing"] == cat + float(pair_info["dist"]) + for sub_key, pair_info in pair_info.items() + if pair_info["mixing"] == cat ] if category_distances: stacked_data.append(category_distances) @@ -69,18 +70,19 @@ def plot_histograms_from_label_dict(data, directory_path): color=[categories_colors[cat] for cat in labels], label=[categories_mapping[cat] for cat in labels], stacked=True, - edgecolor='black' + edgecolor="black", ) ax.set_title(atomic_pair) ax.set_xlabel("Distance (Å)") ax.set_ylabel("Count") - ax.legend(loc='upper right') + ax.legend(loc="upper right") plt.tight_layout() - plt.savefig(os.path.join( - directory_path, - "output", "histograms_label_pair.png"), dpi=150) + plt.savefig( + os.path.join(directory_path, "output", "histograms_label_pair.png"), + dpi=150, + ) plt.close() @@ -114,7 +116,7 @@ def plot_histograms_from_element_dict(data, directory_path): for atomic_pair, records in data.items(): for infos in records.values(): for info in infos: - distances.append(float(info['dist'])) + distances.append(float(info["dist"])) all_distances = sorted(distances) bins = np.linspace(min(all_distances), max(all_distances), 21) @@ -129,8 +131,10 @@ def plot_histograms_from_element_dict(data, directory_path): labels = [] for cat, color in categories_colors.items(): category_distances = [ - float(info['dist']) for infos in records.values() - for info in infos if info['mixing'] == cat + float(info["dist"]) + for infos in records.values() + for info in infos + if info["mixing"] == cat ] if category_distances: stacked_data.append(category_distances) @@ -144,15 +148,17 @@ def plot_histograms_from_element_dict(data, directory_path): color=[categories_colors[cat] for cat in labels], label=[categories_mapping[cat] for cat in labels], stacked=True, - edgecolor='black' + edgecolor="black", ) ax.set_xlabel("Distance (Å)") ax.set_ylabel("Count") - ax.legend(loc='upper right') + ax.legend(loc="upper right") plt.tight_layout() output_dir = os.path.join(directory_path, "output") os.makedirs(output_dir, exist_ok=True) # Ensure output directory exists - plt.savefig(os.path.join(output_dir, "histograms_element_pair.png"), dpi=150) + plt.savefig( + os.path.join(output_dir, "histograms_element_pair.png"), dpi=150 + ) plt.close() diff --git a/postprocess/writer.py b/postprocess/writer.py index 6cca398..8a62c8e 100644 --- a/postprocess/writer.py +++ b/postprocess/writer.py @@ -3,10 +3,8 @@ def write_summary_and_missing_pairs( - dist_mix_pair_dict, - missing_pairs, - text_filename, - dir_path): + dist_mix_pair_dict, missing_pairs, text_filename, dir_path +): """ Writes a summary of unique atomic pairs, including counts and distances, and a list of missing pairs to a file. @@ -22,16 +20,16 @@ def write_summary_and_missing_pairs( # Step 1: Collect data data = [] for pair, files in dist_mix_pair_dict.items(): - distances = sorted(float(info['dist']) for info in files.values()) + distances = sorted(float(info["dist"]) for info in files.values()) count = len(distances) - dists = ', '.join(f"{distance:.3f}" for distance in distances) + dists = ", ".join(f"{distance:.3f}" for distance in distances) data.append((pair, count, dists)) # Step 2: Sort the data first by count (descending) then by pair name sorted_data = sorted(data, key=lambda x: (-x[1], x[0])) # Step 3: Write sorted data to file - with open(file_path, 'w', encoding="utf-8") as file: + with open(file_path, "w", encoding="utf-8") as file: file.write("Summary:\n") for pair, count, dists in sorted_data: file.write(f"Pair: {pair}, Count: {count} Distances: {dists}\n") @@ -54,10 +52,8 @@ def write_summary_and_missing_pairs( def write_summary_and_missing_pairs_with_element_dict( - dist_mix_pair_dict, - missing_pairs, - text_filename, - dir_path): + dist_mix_pair_dict, missing_pairs, text_filename, dir_path +): """ Writes a summary of unique atomic pairs, including counts and distances, and a list of missing pairs to a file. @@ -76,25 +72,26 @@ def write_summary_and_missing_pairs_with_element_dict( distances = [] for file_infos in files.values(): for info in file_infos: # Access each list in the dictionary - distances.append(float(info['dist'])) + distances.append(float(info["dist"])) distances = sorted(distances) count = len(distances) - dists = ', '.join(f"{distance:.3f}" for distance in distances) + dists = ", ".join(f"{distance:.3f}" for distance in distances) data.append((pair, count, dists)) # Step 2: Sort the data first by count (descending) then by pair name sorted_data = sorted(data, key=lambda x: (-x[1], x[0])) # Step 3: Write sorted data to file - with open(file_path, 'w', encoding="utf-8") as file: - + with open(file_path, "w", encoding="utf-8") as file: print("\nMissing pairs:") file.write("Summary:\n") for pair, count, dists in sorted_data: file.write(f"Pair: {pair}, Count: {count}, Distances: {dists}\n") file.write("\nMissing pairs:\n") - missing_pairs_sorted = sorted(missing_pairs, key=lambda x: (x[0][0], x[0], x[1])) + missing_pairs_sorted = sorted( + missing_pairs, key=lambda x: (x[0][0], x[0], x[1]) + ) for pair in missing_pairs_sorted: atom_1, atom_2 = pair file.write(f"{atom_1}-{atom_2}\n") diff --git a/preprocess/cif_parser.py b/preprocess/cif_parser.py index 5e6dc50..07662f5 100755 --- a/preprocess/cif_parser.py +++ b/preprocess/cif_parser.py @@ -8,10 +8,10 @@ def get_atom_type(label): """ Returns the element from the given label """ - parts = re.split(r'[()]', label) + parts = re.split(r"[()]", label) for part in parts: # Attempt to extract the atom type - match = re.search(r'([A-Z][a-z]*)', part) + match = re.search(r"([A-Z][a-z]*)", part) if match: return match.group(1) return None @@ -21,14 +21,16 @@ def get_loop_tags(): """ Returns tags commonly used for atomic description. """ - loop_tags = ["_atom_site_label", - "_atom_site_type_symbol", - "_atom_site_symmetry_multiplicity", - "_atom_site_Wyckoff_symbol", - "_atom_site_fract_x", - "_atom_site_fract_y", - "_atom_site_fract_z", - "_atom_site_occupancy"] + loop_tags = [ + "_atom_site_label", + "_atom_site_type_symbol", + "_atom_site_symmetry_multiplicity", + "_atom_site_Wyckoff_symbol", + "_atom_site_fract_x", + "_atom_site_fract_y", + "_atom_site_fract_z", + "_atom_site_occupancy", + ] return loop_tags @@ -37,22 +39,16 @@ def get_unit_cell_lengths_angles(block): """ Returns the unit cell lengths and angles from a given block. """ - keys_lengths = [ - '_cell_length_a', - '_cell_length_b', - '_cell_length_c' + keys_lengths = ["_cell_length_a", "_cell_length_b", "_cell_length_c"] + keys_angles = ["_cell_angle_alpha", "_cell_angle_beta", "_cell_angle_gamma"] + + lengths = [ + remove_string_braket(block.find_value(key)) for key in keys_lengths ] - keys_angles = [ - '_cell_angle_alpha', - '_cell_angle_beta', - '_cell_angle_gamma' + angles = [ + remove_string_braket(block.find_value(key)) for key in keys_angles ] - lengths = [remove_string_braket(block.find_value(key)) - for key in keys_lengths] - angles = [remove_string_braket(block.find_value(key)) - for key in keys_angles] - return tuple(lengths + angles) @@ -74,9 +70,11 @@ def get_loop_values(block, loop_tags): loop_values = [block.find_loop(tag) for tag in loop_tags] # Check for zero or missing coordinates - if (len(loop_values[4]) == 0 or - len(loop_values[5]) == 0 or - len(loop_values[6]) == 0): + if ( + len(loop_values[4]) == 0 + or len(loop_values[5]) == 0 + or len(loop_values[6]) == 0 + ): raise RuntimeError("Missing atomic coordinates") return loop_values @@ -88,7 +86,14 @@ def get_cell_lenghts_angles_rad(CIF_block): """ # Extract cell dimensions and angles from CIF block cell_lengths_angles = get_unit_cell_lengths_angles(CIF_block) - cell_len_a, cell_len_b, cell_len_c, alpha_deg, beta_deg, gamma_deg = cell_lengths_angles + ( + cell_len_a, + cell_len_b, + cell_len_c, + alpha_deg, + beta_deg, + gamma_deg, + ) = cell_lengths_angles # Convert angles from degrees to radians alpha_rad, beta_rad, gamma_rad = get_radians_from_degrees( @@ -140,9 +145,11 @@ def get_atom_info(cif_loop_values, i): """ label = cif_loop_values[0][i] occupancy = float(remove_string_braket(cif_loop_values[7][i])) - coordinates = (remove_string_braket(cif_loop_values[4][i]), - remove_string_braket(cif_loop_values[5][i]), - remove_string_braket(cif_loop_values[6][i])) + coordinates = ( + remove_string_braket(cif_loop_values[4][i]), + remove_string_braket(cif_loop_values[5][i]), + remove_string_braket(cif_loop_values[6][i]), + ) return label, occupancy, coordinates diff --git a/preprocess/cif_parser_handler.py b/preprocess/cif_parser_handler.py index af0f3ee..34f0d15 100644 --- a/preprocess/cif_parser_handler.py +++ b/preprocess/cif_parser_handler.py @@ -9,17 +9,24 @@ def get_cif_info(file_path, loop_tags, supercell_generation_method=3): Parse CIF data from file path. """ cif_block = cif_parser.get_cif_block(file_path) - cell_lengths, cell_angles_rad = cif_parser.get_cell_lenghts_angles_rad(cif_block) + cell_lengths, cell_angles_rad = cif_parser.get_cell_lenghts_angles_rad( + cif_block + ) cif_loop_values = cif_parser.get_loop_values(cif_block, loop_tags) all_coords_list = supercell.get_coords_list(cif_block, cif_loop_values) all_points, unique_labels, atom_site_list = supercell.get_points_and_labels( - all_coords_list, - cif_loop_values, - supercell_generation_method + all_coords_list, cif_loop_values, supercell_generation_method ) - return cif_block, cell_lengths, cell_angles_rad, \ - all_coords_list, all_points, unique_labels, atom_site_list + return ( + cif_block, + cell_lengths, + cell_angles_rad, + all_coords_list, + all_points, + unique_labels, + atom_site_list, + ) def get_cif_loop_values(file_path: str) -> list: @@ -33,9 +40,7 @@ def get_cif_loop_values(file_path: str) -> list: return cif_loop_values -def get_folder_and_files_info( - script_directory: str, - is_interactive_mode: bool): +def get_folder_and_files_info(script_directory: str, is_interactive_mode: bool): """ Get info about folders and files. """ @@ -51,8 +56,11 @@ def get_folder_and_files_info( filtered_folder_name = f"{folder_name}_filter_dist_min" filtered_folder = os.path.join(folder_info, filtered_folder_name) - files_lst = [os.path.join(folder_info, file) - for file in os.listdir(folder_info) if file.endswith('.cif')] + files_lst = [ + os.path.join(folder_info, file) + for file in os.listdir(folder_info) + if file.endswith(".cif") + ] num_of_files = len(files_lst) loop_tags = cif_parser.get_loop_tags() diff --git a/preprocess/supercell.py b/preprocess/supercell.py index 7a14254..65e1c46 100755 --- a/preprocess/supercell.py +++ b/preprocess/supercell.py @@ -22,9 +22,30 @@ def calculate_distance(point1, point2, cell_lengths, angles): dz_sq = (cell_lengths[2] * delta_z) ** 2 # Calculate cross terms - cross_x = 2 * cell_lengths[1] * cell_lengths[2] * np.cos(angles[0]) * delta_y * delta_z - cross_y = 2 * cell_lengths[2] * cell_lengths[0] * np.cos(angles[1]) * delta_z * delta_x - cross_z = 2 * cell_lengths[0] * cell_lengths[1] * np.cos(angles[2]) * delta_x * delta_y + cross_x = ( + 2 + * cell_lengths[1] + * cell_lengths[2] + * np.cos(angles[0]) + * delta_y + * delta_z + ) + cross_y = ( + 2 + * cell_lengths[2] + * cell_lengths[0] + * np.cos(angles[1]) + * delta_z + * delta_x + ) + cross_z = ( + 2 + * cell_lengths[0] + * cell_lengths[1] + * np.cos(angles[2]) + * delta_x + * delta_y + ) # Calculate squared distance result = dx_sq + dy_sq + dz_sq + cross_x + cross_y + cross_z @@ -35,10 +56,9 @@ def calculate_distance(point1, point2, cell_lengths, angles): return distance, label1, label2 - -def shift_and_append_points(points, atom_site_label, - num_unitcell_atom, - supercell_generation_method): +def shift_and_append_points( + points, atom_site_label, num_unitcell_atom, supercell_generation_method +): """ Shift and duplicate points to create supercell. """ @@ -55,52 +75,83 @@ def shift_and_append_points(points, atom_site_label, all_points = [] for point_group in shifted_points: for point in point_group: - new_point = (*np.round(point,5), atom_site_label) + new_point = (*np.round(point, 5), atom_site_label) all_points.append(new_point) return all_points - if supercell_generation_method == 2: - shifts = np.array([ - [0, 0, 0], [1, 0, 0], [0, 1, 0], - [1, 1, 0], [0, 0, 1], [1, 0, 1], - [0, 1, 1], [1, 1, 1] - ]) + if supercell_generation_method == 2: + shifts = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ] + ) shifted_points = points[:, None, :] + shifts[None, :, :] all_points = [] for point_group in shifted_points: for point in point_group: - new_point = (*np.round(point,5), atom_site_label) + new_point = (*np.round(point, 5), atom_site_label) all_points.append(new_point) return all_points - if supercell_generation_method == 3: - shifts = np.array([ - [0, 0, 0],[1, 0, 0], [0, 1, 0], - [1, 1, 0], [0, 0, 1], [1, 0, 1], - [0, 1, 1], [1, 1, 1], [-1, 0, 0], - [0, -1, 0], [-1, -1, 0], [0, 0, -1], - [1, 0, -1], [0, -1, -1], [-1, -1, -1] - ]) - + if supercell_generation_method == 3: + shifts = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + [-1, 0, 0], + [0, -1, 0], + [-1, -1, 0], + [0, 0, -1], + [1, 0, -1], + [0, -1, -1], + [-1, -1, -1], + ] + ) + shifted_points = points[:, None, :] + shifts[None, :, :] all_points = [] for point_group in shifted_points: for point in point_group: - new_point = (*np.round(point,5), atom_site_label) + new_point = (*np.round(point, 5), atom_site_label) all_points.append(new_point) return all_points # General method for files below 200 atoms in the unit cell - shifts = np.array([ - [0, 0, 0], [1, 0, 0], [0, 1, 0], - [1, 1, 0], [0, 0, 1], [1, 0, 1], - [0, 1, 1], [1, 1, 1], [-1, 0, 0], - [0, -1, 0], [-1, -1, 0], [0, 0, -1], - [1, 0, -1], [0, -1, -1], [-1, -1, -1] - ]) + shifts = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + [-1, 0, 0], + [0, -1, 0], + [-1, -1, 0], + [0, 0, -1], + [1, 0, -1], + [0, -1, -1], + [-1, -1, -1], + ] + ) shifted_points = points[:, None, :] + shifts[None, :, :] all_points = [] @@ -114,7 +165,7 @@ def shift_and_append_points(points, atom_site_label, def get_coords_list(block, loop_values): """ - Computes the new coordinates after applying + Computes the new coordinates after applying symmetry operations to the initial coordinates. """ @@ -127,22 +178,24 @@ def get_coords_list(block, loop_values): atom_site_label = loop_values[0][i] coords_after_symmetry_operations = get_coords_after_sym_operations( - block, + block, float(atom_site_x), float(atom_site_y), float(atom_site_z), - atom_site_label + atom_site_label, ) coords_list.append(coords_after_symmetry_operations) return coords_list -def get_coords_after_sym_operations(block, - atom_site_fract_x, - atom_site_fract_y, - atom_site_fract_z, - atom_site_label): +def get_coords_after_sym_operations( + block, + atom_site_fract_x, + atom_site_fract_y, + atom_site_fract_z, + atom_site_label, +): """ Generates a list of coordinates for each atom site """ @@ -151,11 +204,9 @@ def get_coords_after_sym_operations(block, operation = operation.replace("'", "") try: op = gemmi.Op(operation) - new_x, new_y, new_z = op.apply_to_xyz([ - atom_site_fract_x, - atom_site_fract_y, - atom_site_fract_z - ]) + new_x, new_y, new_z = op.apply_to_xyz( + [atom_site_fract_x, atom_site_fract_y, atom_site_fract_z] + ) new_x = round(new_x, 5) new_y = round(new_y, 5) new_z = round(new_z, 5) @@ -164,14 +215,16 @@ def get_coords_after_sym_operations(block, except RuntimeError as e: print(f"Skipping operation '{operation}': {str(e)}") - raise RuntimeError("An error occurred while processing symmetry operation") from e + raise RuntimeError( + "An error occurred while processing symmetry operation" + ) from e return list(all_coords) -def get_points_and_labels(all_coords_list, - loop_values, - supercell_generation_method): +def get_points_and_labels( + all_coords_list, loop_values, supercell_generation_method +): """ Process coordinates and loop values to extract points, labels, and atom types. """ @@ -183,30 +236,37 @@ def get_points_and_labels(all_coords_list, num_unitcell_atom = 0 for i, all_coords in enumerate(all_coords_list): - points = np.array([list(map(float, coord[:-1])) for coord in all_coords]) + points = np.array( + [list(map(float, coord[:-1])) for coord in all_coords] + ) num_unitcell_atom += len(points) - for i, all_coords in enumerate(all_coords_list): - points = np.array([list(map(float, coord[:-1])) for coord in all_coords]) + points = np.array( + [list(map(float, coord[:-1])) for coord in all_coords] + ) atom_site_label = loop_values[0][i] atom_site_type = loop_values[1][i] unique_labels.append(atom_site_label) unique_atoms_tuple.append(atom_site_type) - all_points.extend(shift_and_append_points( - points, - atom_site_label, - num_unitcell_atom, - supercell_generation_method - )) + all_points.extend( + shift_and_append_points( + points, + atom_site_label, + num_unitcell_atom, + supercell_generation_method, + ) + ) if atom_site_type in atom_site_label: continue if cif_parser.get_atom_type(atom_site_label) != atom_site_type: - raise RuntimeError("Different elements found in atom site and label") + raise RuntimeError( + "Different elements found in atom site and label" + ) return list(set(all_points)), unique_labels, unique_atoms_tuple @@ -224,23 +284,22 @@ def get_atomic_pair_list(flattened_points, cell_lengths, angles): for j, point2 in enumerate(flattened_points): if i != j: pair = tuple(sorted([i, j])) - if pair not in pairs_set: + if pair not in pairs_set: distance, atom_label1, atom_label2 = calculate_distance( - point1, - point2, - cell_lengths, - angles + point1, point2, cell_lengths, angles ) if abs(distance) > 1e-3: - distances_from_point_i.append({ - 'point_pair': (i + 1, j + 1), - 'labels': (atom_label1, atom_label2), - 'coordinates': (point1[:3], point2[:3]), - 'distance': np.round(distance, 5) - }) + distances_from_point_i.append( + { + "point_pair": (i + 1, j + 1), + "labels": (atom_label1, atom_label2), + "coordinates": (point1[:3], point2[:3]), + "distance": np.round(distance, 5), + } + ) pairs_set.add(pair) - distances_from_point_i.sort(key=lambda x: x['distance']) + distances_from_point_i.sort(key=lambda x: x["distance"]) atomic_info_list.extend(distances_from_point_i) return atomic_info_list diff --git a/tests/conftest.py b/tests/conftest.py index cbe41d2..f1f620f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ -#conftest.py +# conftest.py import pytest import preprocess.cif_parser_handler as cif_parser_handler + @pytest.fixture def get_cif_527000_loop_values(): CIF_loop_values = cif_parser_handler.get_cif_loop_values( @@ -26,7 +27,7 @@ def get_cif_300160_loop_values(): return CIF_loop_values -@pytest.fixture +@pytest.fixture def get_cif_1831432_loop_values(): CIF_loop_values = cif_parser_handler.get_cif_loop_values( "tests/filter/cifs/1831432.cif" @@ -56,5 +57,3 @@ def get_cif_URhIn_loop_values(): "tests/filter/cifs/URhIn.cif" ) return CIF_loop_values - - diff --git a/tests/filter/test_occupancy.py b/tests/filter/test_occupancy.py index 94aa840..be67be4 100644 --- a/tests/filter/test_occupancy.py +++ b/tests/filter/test_occupancy.py @@ -10,6 +10,7 @@ def test_get_atom_site_mixing_info(get_cif_527000_loop_values): ) assert atom_mixing_type == "3" + @pytest.mark.fast def test_get_atom_site_mixing_info(get_cif_1803318_loop_values): atom_mixing_type = occupancy.get_atom_site_mixing_info( @@ -20,25 +21,28 @@ def test_get_atom_site_mixing_info(get_cif_1803318_loop_values): @pytest.mark.fast def test_get_all_possible_ordered_label_pair_tuples_300160( - get_cif_300160_loop_values): - + get_cif_300160_loop_values, +): ordered_label_pairs = occupancy.get_all_possible_ordered_label_pairs( get_cif_300160_loop_values ) assert len(ordered_label_pairs) == 6 - assert sorted(ordered_label_pairs) == sorted([ - ("Rh1", "Rh1"), - ("Ge1", "Ge1"), - ("Sm1", "Sm1"), - ("Sm1", "Rh1"), - ("Sm1", "Ge1"), - ("Rh1", "Ge1") - ]) + assert sorted(ordered_label_pairs) == sorted( + [ + ("Rh1", "Rh1"), + ("Ge1", "Ge1"), + ("Sm1", "Sm1"), + ("Sm1", "Rh1"), + ("Sm1", "Ge1"), + ("Rh1", "Ge1"), + ] + ) + @pytest.mark.fast def test_get_all_possible_ordered_label_pair_tuples_URhIn( - get_cif_URhIn_loop_values): - + get_cif_URhIn_loop_values, +): ordered_label_pairs = occupancy.get_all_possible_ordered_label_pairs( get_cif_URhIn_loop_values ) @@ -47,18 +51,21 @@ def test_get_all_possible_ordered_label_pair_tuples_URhIn( # In1, U1, Rh1, Rh2 assert len(ordered_label_pairs) == 10 - assert sorted(ordered_label_pairs) == sorted([ - ("In1", "In1"), - ("U1", "U1"), - ("Rh2", "Rh2"), - ("Rh1", "Rh1"), # 4 same pairs - ("U1", "In1"), - ("Rh1", "In1"), - ("Rh2", "In1"), # 3 pairs below In1 - ("Rh1", "Rh2"), - ("U1", "Rh2"), # 2 pairs below Rh2 - ("U1", "Rh1") # 1 pair below Rh1 - ]) + assert sorted(ordered_label_pairs) == sorted( + [ + ("In1", "In1"), + ("U1", "U1"), + ("Rh2", "Rh2"), + ("Rh1", "Rh1"), # 4 same pairs + ("U1", "In1"), + ("Rh1", "In1"), + ("Rh2", "In1"), # 3 pairs below In1 + ("Rh1", "Rh2"), + ("U1", "Rh2"), # 2 pairs below Rh2 + ("U1", "Rh1"), # 1 pair below Rh1 + ] + ) + @pytest.mark.fast def test_get_atom_site_mixing_dict_1(get_cif_300160_loop_values): @@ -67,8 +74,7 @@ def test_get_atom_site_mixing_dict_1(get_cif_300160_loop_values): ) atom_site_pair_dict = occupancy.get_atom_site_mixing_dict( - atom_site_mixing_file_info, - get_cif_300160_loop_values + atom_site_mixing_file_info, get_cif_300160_loop_values ) # Mendeleev # - Ge 79, Rh 59, Sm 23 @@ -83,17 +89,16 @@ def test_get_atom_site_mixing_dict_1(get_cif_300160_loop_values): @pytest.mark.fast def test_get_atom_site_mixing_dict_2(get_cif_527000_loop_values): - ''' + """ Pair: Rh2-Si 2.28 Å - deficiency_no_atomic_mixing Pair: Rh1-Rh1 2.524 Å - full_occupancy - ''' + """ atom_site_mixing_file_info = occupancy.get_atom_site_mixing_info( get_cif_527000_loop_values ) atom_site_pair_dict = occupancy.get_atom_site_mixing_dict( - atom_site_mixing_file_info, - get_cif_527000_loop_values + atom_site_mixing_file_info, get_cif_527000_loop_values ) # Mendeleev # - Rh 59, Si 78 @@ -108,32 +113,31 @@ def test_get_atom_site_mixing_dict_2(get_cif_527000_loop_values): @pytest.mark.fast def test_get_atom_site_mixing_dict_3(get_cif_1831432_loop_values): - ''' + """ Mendeleev # - Fe 55, Ge 79 1831432.cif Fe Fe 8 b 0.375 0.375 0.375 0.01 Ge1 Ge 8 a 0.125 0.125 0.125 0.944 Fe2 Fe 8 a 0.125 0.125 0.125 0.056 - + Result: Fe-Fe 2.448 deficiency, Fe-Ge 2.448 mixing-deficiency, Fe-Fe 2.448 mixing-deficiency - ''' + """ atom_site_mixing_file_info = occupancy.get_atom_site_mixing_info( get_cif_1831432_loop_values ) atom_site_pair_dict = occupancy.get_atom_site_mixing_dict( - atom_site_mixing_file_info, - get_cif_1831432_loop_values + atom_site_mixing_file_info, get_cif_1831432_loop_values ) assert len(atom_site_pair_dict) == 6 assert atom_site_pair_dict[("Fe", "Fe")] == "3" assert atom_site_pair_dict[("Fe", "Fe2")] == "1" - assert atom_site_pair_dict[("Fe", "Ge1")] == "1" + assert atom_site_pair_dict[("Fe", "Ge1")] == "1" assert atom_site_pair_dict[("Fe2", "Ge1")] == "2" assert atom_site_pair_dict[("Fe2", "Fe2")] == "2" assert atom_site_pair_dict[("Ge1", "Ge1")] == "2" @@ -141,7 +145,7 @@ def test_get_atom_site_mixing_dict_3(get_cif_1831432_loop_values): @pytest.mark.fast def test_get_atom_site_mixing_dict_4(get_cif_529848_loop_values): - ''' + """ Mendeleev # - Ni 61, Sb 85 529848.cif Ni1 Ni 4 a 0 0 0 0.92 @@ -149,14 +153,13 @@ def test_get_atom_site_mixing_dict_4(get_cif_529848_loop_values): Result: 529848: Ni-Sb 2.531 mixing - ''' + """ atom_site_mixing_file_info = occupancy.get_atom_site_mixing_info( get_cif_529848_loop_values ) atom_site_pair_dict = occupancy.get_atom_site_mixing_dict( - atom_site_mixing_file_info, - get_cif_529848_loop_values + atom_site_mixing_file_info, get_cif_529848_loop_values ) assert len(atom_site_pair_dict) == 3 @@ -167,23 +170,22 @@ def test_get_atom_site_mixing_dict_4(get_cif_529848_loop_values): @pytest.mark.fast def test_get_atom_site_mixing_dict_5(get_cif_1617211_loop_values): - ''' + """ Mendeleev # - Fe 55, Si 78 1617211.cif Si1 Si 2 h 0.5 0.5 0.2700 1 Fe1A Fe 1 a 0 0 0 0.85008 Si1B Si 1 a 0 0 0 0.06992 - + Result: 529848: Ni-Sb 2.531 mixing - ''' + """ atom_site_mixing_file_info = occupancy.get_atom_site_mixing_info( get_cif_1617211_loop_values ) atom_site_pair_dict = occupancy.get_atom_site_mixing_dict( - atom_site_mixing_file_info, - get_cif_1617211_loop_values + atom_site_mixing_file_info, get_cif_1617211_loop_values ) assert len(atom_site_pair_dict) == 6 diff --git a/tests/postprocess/test_bond.py b/tests/postprocess/test_bond.py index fb625ec..0d94aa0 100644 --- a/tests/postprocess/test_bond.py +++ b/tests/postprocess/test_bond.py @@ -3,7 +3,7 @@ @pytest.mark.fast def test_remove_duplicate_pairs(): - ''' + """ unique_pairs_distances_test = { ('Ga1A', 'Ga1'): ['2.601'], ('Ga1', 'La1'): ['3.291'], @@ -11,7 +11,7 @@ def test_remove_duplicate_pairs(): ('Ga1', 'Ga1A'): ['2.601'], ('Ga1', 'Ga1'): ['2.358']} - to + to adjusted_pairs_test == { ('Ga1', 'Ga1A'): ['2.601'], @@ -19,8 +19,7 @@ def test_remove_duplicate_pairs(): ('Co1B', 'Ga1'): ['2.601'], ('Ga1', 'Ga1A'): ['2.601'], ('Ga1', 'Ga1'): ['2.358']} - ''' - + """ # # 560709.cif diff --git a/tests/postprocess/test_pair_order.py b/tests/postprocess/test_pair_order.py index a5ef643..3edf707 100644 --- a/tests/postprocess/test_pair_order.py +++ b/tests/postprocess/test_pair_order.py @@ -27,7 +27,7 @@ def test_order_pair_by_mendeleev_and_label(): expected = pair_order.order_pair_by_mendeleev(("In", "Rh")) assert expected == ("Rh", "In") - + expected = pair_order.order_pair_by_mendeleev(("Rh4", "Rh2")) assert expected == ("Rh2", "Rh4") @@ -51,7 +51,3 @@ def test_sort_tuple_in_list(): tuple_pairs = [("Co2A", "Co1A")] sorted_tuple_pairs = pair_order.sort_tuple_in_list(tuple_pairs) assert sorted_tuple_pairs == [("Co1A", "Co2A")] - - - - diff --git a/tests/preprocess/test_cif_parser.py b/tests/preprocess/test_cif_parser.py index a5cf5e8..69c01c3 100644 --- a/tests/preprocess/test_cif_parser.py +++ b/tests/preprocess/test_cif_parser.py @@ -6,9 +6,7 @@ def test_get_unique_element_list(get_cif_527000_loop_values): CIF_loop_values = get_cif_527000_loop_values - unique_element_list = cif_parser.get_unique_element_list( - CIF_loop_values - ) + unique_element_list = cif_parser.get_unique_element_list(CIF_loop_values) assert set(unique_element_list) == set(["Rh", "Si"]) @@ -16,9 +14,7 @@ def test_get_unique_element_list(get_cif_527000_loop_values): def test_get_atom_label_list(get_cif_527000_loop_values): CIF_loop_values = get_cif_527000_loop_values - label_list = cif_parser.get_atom_label_list( - CIF_loop_values - ) + label_list = cif_parser.get_atom_label_list(CIF_loop_values) assert label_list == ["Rh2", "Si", "Rh1"] @@ -26,7 +22,5 @@ def test_get_atom_label_list(get_cif_527000_loop_values): def test_get_num_of_atom_labels(get_cif_527000_loop_values): CIF_loop_values = get_cif_527000_loop_values - num_of_atom_labels = cif_parser.get_num_of_atom_labels( - CIF_loop_values - ) + num_of_atom_labels = cif_parser.get_num_of_atom_labels(CIF_loop_values) assert num_of_atom_labels == 3 diff --git a/tests/test_single_cif.py b/tests/test_single_cif.py index 9b9f546..f58207a 100644 --- a/tests/test_single_cif.py +++ b/tests/test_single_cif.py @@ -1,4 +1,6 @@ -from main import main # Assuming main.py and this test file are in the same directory +from main import ( + main, +) # Assuming main.py and this test file are in the same directory import os import json import util.folder as folder @@ -10,6 +12,7 @@ def cleanup(dir_path): output_dir_path = os.path.join(dir_path, "output") folder.remove_directories([csv_dir_path, output_dir_path]) + # def run_test(dir_path, expected_json): # cif_folder_name = os.path.basename(dir_path) # output_dir_path = os.path.join(dir_path, "output") @@ -17,7 +20,7 @@ def cleanup(dir_path): # # Run # main(False, dir_path) - + # # Load output # with open(json_output_path, 'r') as file: # actual_output = json.load(file) diff --git a/util/folder.py b/util/folder.py index db56401..4963945 100644 --- a/util/folder.py +++ b/util/folder.py @@ -8,10 +8,16 @@ def choose_CIF_directory(script_directory): """ Allows the user to select a directory from the given path. """ - directories = [d for d in os.listdir(script_directory) - if os.path.isdir(join(script_directory, d)) - and any(file.endswith('.cif') for file in os.listdir(join(script_directory, d)))] - + directories = [ + d + for d in os.listdir(script_directory) + if os.path.isdir(join(script_directory, d)) + and any( + file.endswith(".cif") + for file in os.listdir(join(script_directory, d)) + ) + ] + if not directories: print("No directories found in the current path containing .cif files!") return None @@ -23,9 +29,11 @@ def choose_CIF_directory(script_directory): try: choice = int(input("\nEnter folder # having .cif files: ")) if 1 <= choice <= len(directories): - return join(script_directory, directories[choice-1]) + return join(script_directory, directories[choice - 1]) else: - print(f"Please enter a number between 1 and {len(directories)}.") + print( + f"Please enter a number between 1 and {len(directories)}." + ) except ValueError: print("Invalid input. Please enter a number.") @@ -35,7 +43,7 @@ def save_to_csv_directory(folder_info, df, base_filename): Saves the dataframe as a CSV inside a 'csv' sub-directory. """ # Create the sub-directory for CSVs if it doesn't exist - + csv_directory = join(folder_info, "csv") if not os.path.exists(csv_directory): os.mkdir(csv_directory) diff --git a/util/prompt.py b/util/prompt.py index 0d2b708..ab9f5a1 100644 --- a/util/prompt.py +++ b/util/prompt.py @@ -3,9 +3,11 @@ from click import style, echo import json + def print_intro_prompt(): """Filters and moves CIF files based on the shortest atomic distance.""" - intro_prompt = textwrap.dedent("""\ + intro_prompt = textwrap.dedent( + """\ === Welcome to the CIF Bond Analyzer! @@ -19,28 +21,41 @@ def print_intro_prompt(): Let's get started! === - """) + """ + ) print(intro_prompt) def get_user_input_on_supercell_method(): - click.echo("\nDo you want to modify the supercell generation method for CIF files with more than 100 atoms in the unit cell?") - is_supercell_generation_method_modified = click.confirm('(Default: N)', default=False) + click.echo( + "\nDo you want to modify the supercell generation method for CIF files with more than 100 atoms in the unit cell?" + ) + is_supercell_generation_method_modified = click.confirm( + "(Default: N)", default=False + ) if is_supercell_generation_method_modified: click.echo("\nChoose a supercell generation method:") click.echo("1. No shift (fastest)") click.echo("2. +1 +1 +1 shifts in x, y, z directions") - click.echo("3. +-1, +-1, +-1 shifts (2x2x2 supercell generation, requires heavy computation, slowest)") - - method = click.prompt("Choose your option by entering a number", type=int) - + click.echo( + "3. +-1, +-1, +-1 shifts (2x2x2 supercell generation, requires heavy computation, slowest)" + ) + + method = click.prompt( + "Choose your option by entering a number", type=int + ) + if method == 1: click.echo("You've selected: No shift (fastest)\n") elif method == 2: - click.echo("You've selected: +1 +1 +1 shifts in x, y, z directions\n") + click.echo( + "You've selected: +1 +1 +1 shifts in x, y, z directions\n" + ) elif method == 3: - click.echo("You've selected: +-1, +-1, +-1 shifts (2x2x2 supercell generation, slowest)\n") + click.echo( + "You've selected: +-1, +-1, +-1 shifts (2x2x2 supercell generation, slowest)\n" + ) else: click.echo("Invalid option. Defaulting to No shift (fastest)\n") method = 1 @@ -52,11 +67,13 @@ def get_user_input_on_supercell_method(): def print_progress(filename_with_ext, num_of_atoms, elapsed_time, is_finished): if is_finished: - echo(style( - f"Processed {filename_with_ext} with {num_of_atoms} atoms in " - f"{round(elapsed_time, 2)} s\n", - fg="blue" - )) + echo( + style( + f"Processed {filename_with_ext} with {num_of_atoms} atoms in " + f"{round(elapsed_time, 2)} s\n", + fg="blue", + ) + ) def print_dict_in_json(data): diff --git a/util/string_parser.py b/util/string_parser.py index 36e4e66..fe1b988 100755 --- a/util/string_parser.py +++ b/util/string_parser.py @@ -2,7 +2,11 @@ def remove_string_braket(value_string): - """ - Removes parentheses from a value string and convert to float if possible. - """ - return float(value_string.split('(')[0]) if '(' in value_string else float(value_string) + """ + Removes parentheses from a value string and convert to float if possible. + """ + return ( + float(value_string.split("(")[0]) + if "(" in value_string + else float(value_string) + ) diff --git a/util/unit.py b/util/unit.py index e26adfb..f26eba4 100755 --- a/util/unit.py +++ b/util/unit.py @@ -14,7 +14,5 @@ def rounded_distance(distance, precision=2): """ Round a distance value to a specified precision. """ - - return round(distance, precision) - + return round(distance, precision) From 54ed34684365dc074b867ecbd0e3aaea791465a4 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 11 Mar 2024 14:16:25 -0400 Subject: [PATCH 2/2] Added linting --- main.py | 86 +++++++++++++++---------------------------- util/prompt.py | 4 +- util/string_parser.py | 3 -- 3 files changed, 31 insertions(+), 62 deletions(-) diff --git a/main.py b/main.py index 6970fcf..3a032e3 100644 --- a/main.py +++ b/main.py @@ -41,8 +41,10 @@ def main(is_iteractive_mode=True, dir_path=None): # If the user chooses no option, then it's simply 3 if not supercell_method: - print("\nYour default option is generating a 2-2-2 supercell for", - "files more than 100 atoms in the unit cell.") + print( + "\nYour default option is generating a 2-2-2 supercell for", + "files more than 100 atoms in the unit cell.", + ) supercell_method = 1 if not is_iteractive_mode: @@ -52,7 +54,7 @@ def main(is_iteractive_mode=True, dir_path=None): file_path_list = folder.get_cif_file_path_list(dir_path) # PART 2: PREPROCESS - + dist_mix_pair_dict = {} overall_start_time = time.perf_counter() @@ -62,17 +64,13 @@ def main(is_iteractive_mode=True, dir_path=None): filename_with_ext = os.path.basename(file_path) filename, ext = os.path.splitext(filename_with_ext) num_of_atoms = None - + # Process CIF files and return a list of coordinates result = cif_parser_handler.get_cif_info( - file_path, - cif_parser.get_loop_tags(), - supercell_method + file_path, cif_parser.get_loop_tags(), supercell_method ) - CIF_loop_values = cif_parser_handler.get_cif_loop_values( - file_path - ) + CIF_loop_values = cif_parser_handler.get_cif_loop_values(file_path) _, lenghts, angles_rad, _, all_points, _, atom_site_list = result @@ -82,14 +80,13 @@ def main(is_iteractive_mode=True, dir_path=None): echo( style( f"Processing {filename_with_ext} with " - f"{num_of_atoms} atoms {index}", fg="yellow" + f"{num_of_atoms} atoms {index}", + fg="yellow", ) ) atomic_pair_list = supercell.get_atomic_pair_list( - all_points, - lenghts, - angles_rad + all_points, lenghts, angles_rad ) # Get atomic site mixing info -> String @@ -104,35 +101,26 @@ def main(is_iteractive_mode=True, dir_path=None): # Find the shortest pair from each reference atom ordered_pairs = bond.process_and_order_pairs( - all_points, - atomic_pair_list + all_points, atomic_pair_list ) # Determine unique pairs and get the shortest dist for each pair - unique_pairs_dict = bond.get_unique_pairs_dict( - ordered_pairs, - filename - ) + unique_pairs_dict = bond.get_unique_pairs_dict(ordered_pairs, filename) dist_mix_pair_dict = bond.get_dist_mix_pair_dict( - dist_mix_pair_dict, - unique_pairs_dict, - label_pair_mixing_dict + dist_mix_pair_dict, unique_pairs_dict, label_pair_mixing_dict ) elapsed_time = time.perf_counter() - start_time prompt.print_progress( - filename_with_ext, - num_of_atoms, - elapsed_time, - is_finished=True + filename_with_ext, num_of_atoms, elapsed_time, is_finished=True ) data = { - 'File': filename, + "File": filename, "Number of atoms in supercell": num_of_atoms, - "Processing time (s)": round(elapsed_time, 3) + "Processing time (s)": round(elapsed_time, 3), } log_list.append(data) @@ -147,10 +135,8 @@ def main(is_iteractive_mode=True, dir_path=None): prompt.print_dict_in_json(dist_mix_element_pair_dict) - missing_label_pairs = bond.get_sorted_missing_pairs( - dist_mix_pair_dict - ) - + missing_label_pairs = bond.get_sorted_missing_pairs(dist_mix_pair_dict) + missing_element_pairs = bond.get_sorted_missing_pairs( dist_mix_element_pair_dict ) @@ -170,54 +156,40 @@ def main(is_iteractive_mode=True, dir_path=None): dist_mix_pair_dict, missing_label_pairs, "summary_label.txt", - dir_path + dir_path, ) - + # Save Excel file with label pair excel.write_label_pair_dict_to_excel_json( - dist_mix_pair_dict, - "label", - dir_path + dist_mix_pair_dict, "label", dir_path ) # Draw histograms with label pari - histogram.plot_histograms_from_label_dict( - dist_mix_pair_dict, - dir_path - ) - + histogram.plot_histograms_from_label_dict(dist_mix_pair_dict, dir_path) + # Write elesummary-element.txt writer.write_summary_and_missing_pairs_with_element_dict( dist_mix_element_pair_dict, missing_element_pairs, "summary_element.txt", - dir_path - ) + dir_path, + ) # Save Excel file with element pair excel.write_element_pair_dict_to_excel_json( - dist_mix_element_pair_dict, - "element", - dir_path + dist_mix_element_pair_dict, "element", dir_path ) # Draw histograms with element pair histogram.plot_histograms_from_element_dict( - dist_mix_element_pair_dict, - dir_path + dist_mix_element_pair_dict, dir_path ) - - total_elapsed_time = time.perf_counter() - overall_start_time print(f"Total processing time: {total_elapsed_time:.2f}s") # Save log csv - folder.save_to_csv_directory( - dir_path, - pd.DataFrame(log_list), - "log" - ) + folder.save_to_csv_directory(dir_path, pd.DataFrame(log_list), "log") # print("\nAll files successfully processed.") diff --git a/util/prompt.py b/util/prompt.py index ab9f5a1..727b350 100644 --- a/util/prompt.py +++ b/util/prompt.py @@ -39,7 +39,7 @@ def get_user_input_on_supercell_method(): click.echo("1. No shift (fastest)") click.echo("2. +1 +1 +1 shifts in x, y, z directions") click.echo( - "3. +-1, +-1, +-1 shifts (2x2x2 supercell generation, requires heavy computation, slowest)" + "3. +-1, +-1, +-1 shifts (2x2x2 supercell generation, slowest)" ) method = click.prompt( @@ -54,7 +54,7 @@ def get_user_input_on_supercell_method(): ) elif method == 3: click.echo( - "You've selected: +-1, +-1, +-1 shifts (2x2x2 supercell generation, slowest)\n" + "You've selected: +-1, +-1, +-1 shifts (2x2x2 supercell, slowest)\n" ) else: click.echo("Invalid option. Defaulting to No shift (fastest)\n") diff --git a/util/string_parser.py b/util/string_parser.py index fe1b988..64064dc 100755 --- a/util/string_parser.py +++ b/util/string_parser.py @@ -1,6 +1,3 @@ -import re - - def remove_string_braket(value_string): """ Removes parentheses from a value string and convert to float if possible.