From e963cc8f26a211649d53573fe7fa0e03d0420d92 Mon Sep 17 00:00:00 2001 From: zjgemi Date: Mon, 19 Aug 2024 11:36:27 +0800 Subject: [PATCH] add user arguments to DistanceConfFilter Signed-off-by: zjgemi --- dpgen2/entrypoint/args.py | 26 +++++- .../selector/distance_conf_filter.py | 84 +++++++++++++++---- 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/dpgen2/entrypoint/args.py b/dpgen2/entrypoint/args.py index a7cab63c..faa1f5c3 100644 --- a/dpgen2/entrypoint/args.py +++ b/dpgen2/entrypoint/args.py @@ -19,6 +19,9 @@ from dpgen2.exploration.report import ( conv_styles, ) +from dpgen2.exploration.selector import ( + conf_filter_styles, +) from dpgen2.fp import ( fp_styles, ) @@ -174,6 +177,25 @@ def variant_conf(): ) +def variant_filter(): + doc = "the type of the configuration filter." + var_list = [] + for kk in conf_filter_styles.keys(): + var_list.append( + Argument( + kk, + dict, + conf_filter_styles[kk].args(), + doc="Configuration filter of type %s" % kk, + ) + ) + return Variant( + "type", + var_list, + doc=doc, + ) + + def lmp_args(): doc_config = "Configuration of lmp exploration" doc_max_numb_iter = "Maximum number of iterations per stage" @@ -228,7 +250,7 @@ def lmp_args(): alias=["configuration"], ), Argument("stages", List[List[dict]], optional=False, doc=doc_stages), - Argument("filters", List[dict], optional=True, default=[], doc=doc_filters), + Argument("filters", list, [], [variant_filter()], optional=True, default=[], doc=doc_filters), ] @@ -313,7 +335,7 @@ def caly_args(): alias=["configuration"], ), Argument("stages", List[List[dict]], optional=False, doc=doc_stages), - Argument("filters", List[dict], optional=True, default=[], doc=doc_filters), + Argument("filters", list, [], [variant_filter()], optional=True, default=[], doc=doc_filters), ] diff --git a/dpgen2/exploration/selector/distance_conf_filter.py b/dpgen2/exploration/selector/distance_conf_filter.py index 69268197..c61fd89a 100644 --- a/dpgen2/exploration/selector/distance_conf_filter.py +++ b/dpgen2/exploration/selector/distance_conf_filter.py @@ -1,7 +1,7 @@ -import os -import shutil - -import dpdata +from copy import deepcopy +import dargs +from dargs import Argument +from typing import List import numpy as np from . import ( @@ -107,8 +107,6 @@ "Cm": 2.8, "Cf": 2.3, } -for k in safe_dist_dict: - safe_dist_dict[k] *= 0.441 def check_multiples(a, b, c, multiple): @@ -126,8 +124,11 @@ def check_multiples(a, b, c, multiple): class DistanceConfFilter(ConfFilter): - def __init__(self): - pass + def __init__(self, custom_safe_dist=None, safe_dist_ratio=1.0, theta=60.0, length_ratio=5.0): + self.custom_safe_dist = custom_safe_dist if custom_safe_dist is not None else {} + self.safe_dist_ratio = safe_dist_ratio + self.theta = theta + self.length_ratio = length_ratio def check( self, @@ -143,7 +144,13 @@ def check( make_supercell, ) - atom_names = list(safe_dist_dict) + safe_dist = deepcopy(safe_dist_dict) + safe_dist.update(self.custom_safe_dist) + for k in safe_dist: + # bohr -> ang and multiply by a relaxation ratio + safe_dist[k] *= 0.529/1.2*self.safe_dist_ratio + + atom_names = list(safe_dist) structure = Atoms( positions=coords, numbers=[atom_names.index(n) + 1 for n in atom_types], @@ -151,12 +158,12 @@ def check( pbc=(not nopbc), ) - cell = structure.get_cell() + cell, _ = structure.get_cell().standard_form() # type: ignore if ( - cell[1][0] > 1.732 * cell[1][1] - or cell[2][0] > 1.732 * cell[2][2] - or cell[2][1] > 1.732 * cell[2][2] + cell[1][0] > np.tan(self.theta/180.*np.pi) * cell[1][1] + or cell[2][0] > np.tan(self.theta/180.*np.pi) * cell[2][2] + or cell[2][1] > np.tan(self.theta/180.*np.pi) * cell[2][2] ): print("Inclined box") return False @@ -165,8 +172,8 @@ def check( b = cell[1][1] c = cell[2][2] - if check_multiples(a, b, c, 5): - print("One side is 5 larger than another") + if check_multiples(a, b, c, self.length_ratio): + print("One side is %s larger than another" % self.length_ratio) return False P = [[2, 0, 0], [0, 2, 0], [0, 0, 2]] @@ -181,7 +188,7 @@ def check( dist = extended_structure.get_distance(i, j, mic=True) type_i = symbols[i] type_j = symbols[j] - dr = safe_dist_dict[type_i] + safe_dist_dict[type_j] + dr = safe_dist[type_i] + safe_dist[type_j] if dist < dr: print( @@ -191,3 +198,48 @@ def check( print("Valid structure") return True + + @staticmethod + def args() -> List[dargs.Argument]: + r"""The argument definition of the `ConfFilter`. + + Returns + ------- + arguments: List[dargs.Argument] + List of dargs.Argument defines the arguments of the `ConfFilter`. + """ + + doc_custom_safe_dist = "Custom safe distance for each element" + doc_safe_dist_ratio = "The ratio multiplied to the safe distance" + doc_theta = "The threshold for the angle of the box" + doc_length_ratio = "The threshold for the length ratio of the box" + return [ + Argument( + "custom_safe_dist", + dict, + optional=True, + default={}, + doc=doc_custom_safe_dist, + ), + Argument( + "safe_dist_ratio", + float, + optional=True, + default=1.0, + doc=doc_safe_dist_ratio, + ), + Argument( + "theta", + float, + optional=True, + default=60.0, + doc=doc_theta, + ), + Argument( + "length_ratio", + float, + optional=True, + default=5.0, + doc=doc_length_ratio, + ), + ]