Skip to content

Commit

Permalink
add user arguments to DistanceConfFilter
Browse files Browse the repository at this point in the history
Signed-off-by: zjgemi <liuxin_zijian@163.com>
  • Loading branch information
zjgemi committed Aug 19, 2024
1 parent dfbe2e7 commit e963cc8
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 18 deletions.
26 changes: 24 additions & 2 deletions dpgen2/entrypoint/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from dpgen2.exploration.report import (
conv_styles,
)
from dpgen2.exploration.selector import (
conf_filter_styles,
)
from dpgen2.fp import (
fp_styles,
)
Expand Down Expand Up @@ -174,6 +177,25 @@ def variant_conf():
)


def variant_filter():
doc = "the type of the configuration filter."
var_list = []
for kk in conf_filter_styles.keys():
var_list.append(
Argument(
kk,
dict,
conf_filter_styles[kk].args(),
doc="Configuration filter of type %s" % kk,
)
)
return Variant(
"type",
var_list,
doc=doc,
)


def lmp_args():
doc_config = "Configuration of lmp exploration"
doc_max_numb_iter = "Maximum number of iterations per stage"
Expand Down Expand Up @@ -228,7 +250,7 @@ def lmp_args():
alias=["configuration"],
),
Argument("stages", List[List[dict]], optional=False, doc=doc_stages),
Argument("filters", List[dict], optional=True, default=[], doc=doc_filters),
Argument("filters", list, [], [variant_filter()], optional=True, default=[], doc=doc_filters),
]


Expand Down Expand Up @@ -313,7 +335,7 @@ def caly_args():
alias=["configuration"],
),
Argument("stages", List[List[dict]], optional=False, doc=doc_stages),
Argument("filters", List[dict], optional=True, default=[], doc=doc_filters),
Argument("filters", list, [], [variant_filter()], optional=True, default=[], doc=doc_filters),
]


Expand Down
84 changes: 68 additions & 16 deletions dpgen2/exploration/selector/distance_conf_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import shutil

import dpdata
from copy import deepcopy
import dargs
from dargs import Argument
from typing import List
import numpy as np

from . import (
Expand Down Expand Up @@ -107,8 +107,6 @@
"Cm": 2.8,
"Cf": 2.3,
}
for k in safe_dist_dict:
safe_dist_dict[k] *= 0.441


def check_multiples(a, b, c, multiple):
Expand All @@ -126,8 +124,11 @@ def check_multiples(a, b, c, multiple):


class DistanceConfFilter(ConfFilter):
def __init__(self):
pass
def __init__(self, custom_safe_dist=None, safe_dist_ratio=1.0, theta=60.0, length_ratio=5.0):
self.custom_safe_dist = custom_safe_dist if custom_safe_dist is not None else {}
self.safe_dist_ratio = safe_dist_ratio
self.theta = theta
self.length_ratio = length_ratio

def check(
self,
Expand All @@ -143,20 +144,26 @@ def check(
make_supercell,
)

atom_names = list(safe_dist_dict)
safe_dist = deepcopy(safe_dist_dict)
safe_dist.update(self.custom_safe_dist)
for k in safe_dist:
# bohr -> ang and multiply by a relaxation ratio
safe_dist[k] *= 0.529/1.2*self.safe_dist_ratio

atom_names = list(safe_dist)
structure = Atoms(
positions=coords,
numbers=[atom_names.index(n) + 1 for n in atom_types],
cell=cell,
pbc=(not nopbc),
)

cell = structure.get_cell()
cell, _ = structure.get_cell().standard_form() # type: ignore

if (
cell[1][0] > 1.732 * cell[1][1]
or cell[2][0] > 1.732 * cell[2][2]
or cell[2][1] > 1.732 * cell[2][2]
cell[1][0] > np.tan(self.theta/180.*np.pi) * cell[1][1]
or cell[2][0] > np.tan(self.theta/180.*np.pi) * cell[2][2]
or cell[2][1] > np.tan(self.theta/180.*np.pi) * cell[2][2]
):
print("Inclined box")
return False
Expand All @@ -165,8 +172,8 @@ def check(
b = cell[1][1]
c = cell[2][2]

if check_multiples(a, b, c, 5):
print("One side is 5 larger than another")
if check_multiples(a, b, c, self.length_ratio):
print("One side is %s larger than another" % self.length_ratio)
return False

P = [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
Expand All @@ -181,7 +188,7 @@ def check(
dist = extended_structure.get_distance(i, j, mic=True)
type_i = symbols[i]
type_j = symbols[j]
dr = safe_dist_dict[type_i] + safe_dist_dict[type_j]
dr = safe_dist[type_i] + safe_dist[type_j]

if dist < dr:
print(
Expand All @@ -191,3 +198,48 @@ def check(

print("Valid structure")
return True

@staticmethod
def args() -> List[dargs.Argument]:
r"""The argument definition of the `ConfFilter`.
Returns
-------
arguments: List[dargs.Argument]
List of dargs.Argument defines the arguments of the `ConfFilter`.
"""

doc_custom_safe_dist = "Custom safe distance for each element"
doc_safe_dist_ratio = "The ratio multiplied to the safe distance"
doc_theta = "The threshold for the angle of the box"
doc_length_ratio = "The threshold for the length ratio of the box"
return [
Argument(
"custom_safe_dist",
dict,
optional=True,
default={},
doc=doc_custom_safe_dist,
),
Argument(
"safe_dist_ratio",
float,
optional=True,
default=1.0,
doc=doc_safe_dist_ratio,
),
Argument(
"theta",
float,
optional=True,
default=60.0,
doc=doc_theta,
),
Argument(
"length_ratio",
float,
optional=True,
default=5.0,
doc=doc_length_ratio,
),
]

0 comments on commit e963cc8

Please sign in to comment.