diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index bdc2b37d7b..67bf896fbe 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -720,12 +720,12 @@ def _filter_lower( raise RuntimeError('compression of type embedded descriptor is not supported at the moment') # natom x 4 x outputs_size if self.compress and (not is_exclude): - info = [self.lower, self.upper, self.upper * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]] - if self.type_one_side: - net = 'filter_-1_net_' + str(type_i) - else: - net = 'filter_' + str(type_input) + '_net_' + str(type_i) - return op_module.tabulate_fusion_se_a(tf.cast(self.table.data[net], self.filter_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) + if self.type_one_side: + net = 'filter_-1_net_' + str(type_i) + else: + net = 'filter_' + str(type_input) + '_net_' + str(type_i) + info = [self.lower[net], self.upper[net], self.upper[net] * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]] + return op_module.tabulate_fusion_se_a(tf.cast(self.table.data[net], self.filter_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) else: if (not is_exclude): # with (natom x nei_type_i) x out_size diff --git a/deepmd/descriptor/se_r.py b/deepmd/descriptor/se_r.py index b9e5d5aabd..b85e9ecc49 100644 --- a/deepmd/descriptor/se_r.py +++ b/deepmd/descriptor/se_r.py @@ -540,8 +540,8 @@ def _filter_r(self, # with (natom x nei_type_i) x 1 xyz_scatter = tf.reshape(inputs_i, [-1, 1]) if self.compress and ((type_input, type_i) not in self.exclude_types): - info = [self.lower, self.upper, self.upper * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]] net = 'filter_' + str(type_input) + '_net_' + str(type_i) + info = [self.lower[net], self.upper[net], self.upper[net] * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]] xyz_scatter = op_module.tabulate_fusion_se_r(tf.cast(self.table.data[net], self.filter_precision), info, inputs_i, last_layer_size = outputs_size[-1]) elif (type_input, type_i) not in self.exclude_types: xyz_scatter = embedding_net(xyz_scatter, diff --git a/deepmd/descriptor/se_t.py b/deepmd/descriptor/se_t.py index 1735757dcb..b52883023b 100644 --- a/deepmd/descriptor/se_t.py +++ b/deepmd/descriptor/se_t.py @@ -559,8 +559,8 @@ def _filter(self, # with (natom x nei_type_i x nei_type_j) ebd_env_ij = tf.reshape(env_ij, [-1, 1]) if self.compress: - info = [self.lower, self.upper, self.upper * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]] net = 'filter_' + str(type_i) + '_net_' + str(type_j) + info = [self.lower[net], self.upper[net], self.upper[net] * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]] res_ij = op_module.tabulate_fusion_se_t(tf.cast(self.table.data[net], self.filter_precision), info, ebd_env_ij, env_ij, last_layer_size = outputs_size[-1]) else: # with (natom x nei_type_i x nei_type_j) x out_size diff --git a/deepmd/utils/tabulate.py b/deepmd/utils/tabulate.py index ee1088bd3c..a6d3f40f2c 100644 --- a/deepmd/utils/tabulate.py +++ b/deepmd/utils/tabulate.py @@ -1,9 +1,8 @@ -import math import logging import numpy as np import deepmd from typing import Callable -from typing import Tuple, List +from typing import Tuple, List, Dict from functools import lru_cache from scipy.special import comb from deepmd.env import tf @@ -137,12 +136,15 @@ def __init__(self, self.data = {} + self.upper = {} + self.lower = {} + def build(self, min_nbor_dist : float, extrapolate : float, stride0 : float, - stride1 : float) -> Tuple[int, int]: + stride1 : float) -> Tuple[Dict[str, int], Dict[str, int]]: """ Build the tables for model compression @@ -161,81 +163,100 @@ def build(self, Returns ---------- - lower - The lower boundary of environment matrix - upper - The upper boundary of environment matrix + lower : dict[str, int] + The lower boundary of environment matrix by net + upper : dict[str, int] + The upper boundary of environment matrix by net """ # tabulate range [lower, upper] with stride0 'stride0' lower, upper = self._get_env_mat_range(min_nbor_dist) - if isinstance(self.descrpt, deepmd.descriptor.DescrptSeA): - xx = np.arange(lower, upper, stride0, dtype = self.data_type) - xx = np.append(xx, np.arange(upper, extrapolate * upper, stride1, dtype = self.data_type)) - xx = np.append(xx, np.array([extrapolate * upper], dtype = self.data_type)) - self.nspline = int((upper - lower) / stride0 + (extrapolate * upper - upper) / stride1) for ii in range(self.table_size): if (self.type_one_side and not self._all_excluded(ii)) or (not self.type_one_side and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types): if self.type_one_side: net = "filter_-1_net_" + str(ii) + # upper and lower should consider all types which are not excluded and sel>0 + idx = [(type_i, ii) not in self.exclude_types and self.sel_a[type_i] > 0 for type_i in range(self.ntypes)] + uu = np.max(upper[idx]) + ll = np.min(lower[idx]) else: - net = "filter_" + str(ii // self.ntypes) + "_net_" + str(ii % self.ntypes) - self._build_lower(net, xx, ii, upper, lower, stride0, stride1, extrapolate) + ielement = ii // self.ntypes + net = "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) + uu = upper[ielement] + ll = lower[ielement] + xx = np.arange(ll, uu, stride0, dtype = self.data_type) + xx = np.append(xx, np.arange(uu, extrapolate * uu, stride1, dtype = self.data_type)) + xx = np.append(xx, np.array([extrapolate * uu], dtype = self.data_type)) + nspline = ((uu - ll) / stride0 + (extrapolate * uu - uu) / stride1).astype(int) + self._build_lower(net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline) elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeT): - xx = np.arange(extrapolate * lower, lower, stride1, dtype = self.data_type) - xx = np.append(xx, np.arange(lower, upper, stride0, dtype = self.data_type)) - xx = np.append(xx, np.arange(upper, extrapolate * upper, stride1, dtype = self.data_type)) - xx = np.append(xx, np.array([extrapolate * upper], dtype = self.data_type)) - self.nspline = int((upper - lower) / stride0 + 2 * ((extrapolate * upper - upper) / stride1)) + xx_all = [] + for ii in range(self.ntypes): + xx = np.arange(extrapolate * lower[ii], lower[ii], stride1, dtype = self.data_type) + xx = np.append(xx, np.arange(lower[ii], upper[ii], stride0, dtype = self.data_type)) + xx = np.append(xx, np.arange(upper[ii], extrapolate * upper[ii], stride1, dtype = self.data_type)) + xx = np.append(xx, np.array([extrapolate * upper[ii]], dtype = self.data_type)) + xx_all.append(xx) + nspline = ((upper - lower) / stride0 + 2 * ((extrapolate * upper - upper) / stride1)).astype(int) idx = 0 for ii in range(self.ntypes): for jj in range(ii, self.ntypes): net = "filter_" + str(ii) + "_net_" + str(jj) - self._build_lower(net, xx, idx, upper, lower, stride0, stride1, extrapolate) + self._build_lower(net, xx_all[ii], idx, upper[ii], lower[ii], stride0, stride1, extrapolate, nspline[ii]) idx += 1 elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): - xx = np.arange(lower, upper, stride0, dtype = self.data_type) - xx = np.append(xx, np.arange(upper, extrapolate * upper, stride1, dtype = self.data_type)) - xx = np.append(xx, np.array([extrapolate * upper], dtype = self.data_type)) - self.nspline = int((upper - lower) / stride0 + (extrapolate * upper - upper) / stride1) for ii in range(self.table_size): if (self.type_one_side and not self._all_excluded(ii)) or (not self.type_one_side and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types): if self.type_one_side: net = "filter_-1_net_" + str(ii) + # upper and lower should consider all types which are not excluded and sel>0 + idx = [(type_i, ii) not in self.exclude_types and self.sel_a[type_i] > 0 for type_i in range(self.ntypes)] + uu = np.max(upper[idx]) + ll = np.min(lower[idx]) else: - net = "filter_" + str(ii // self.ntypes) + "_net_" + str(ii % self.ntypes) - self._build_lower(net, xx, ii, upper, lower, stride0, stride1, extrapolate) + ielement = ii // self.ntypes + net = "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) + uu = upper[ielement] + ll = lower[ielement] + xx = np.arange(ll, uu, stride0, dtype = self.data_type) + xx = np.append(xx, np.arange(uu, extrapolate * uu, stride1, dtype = self.data_type)) + xx = np.append(xx, np.array([extrapolate * uu], dtype = self.data_type)) + nspline = ((uu - ll) / stride0 + (extrapolate * uu - uu) / stride1).astype(int) + self._build_lower(net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline) else: raise RuntimeError("Unsupported descriptor") - return lower, upper + return self.lower, self.upper - def _build_lower(self, net, xx, idx, upper, lower, stride0, stride1, extrapolate): + def _build_lower(self, net, xx, idx, upper, lower, stride0, stride1, extrapolate, nspline): vv, dd, d2 = self._make_data(xx, idx) - self.data[net] = np.zeros([self.nspline, 6 * self.last_layer_size], dtype = self.data_type) + self.data[net] = np.zeros([nspline, 6 * self.last_layer_size], dtype = self.data_type) - # tt.shape: [self.nspline, self.last_layer_size] + # tt.shape: [nspline, self.last_layer_size] if isinstance(self.descrpt, deepmd.descriptor.DescrptSeA): - tt = np.full((self.nspline, self.last_layer_size), stride1) + tt = np.full((nspline, self.last_layer_size), stride1) tt[:int((upper - lower) / stride0), :] = stride0 elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeT): - tt = np.full((self.nspline, self.last_layer_size), stride1) + tt = np.full((nspline, self.last_layer_size), stride1) tt[int((lower - extrapolate * lower) / stride1) + 1:(int((lower - extrapolate * lower) / stride1) + int((upper - lower) / stride0)), :] = stride0 elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): - tt = np.full((self.nspline, self.last_layer_size), stride1) + tt = np.full((nspline, self.last_layer_size), stride1) tt[:int((upper - lower) / stride0), :] = stride0 else: raise RuntimeError("Unsupported descriptor") - # hh.shape: [self.nspline, self.last_layer_size] - hh = vv[1:self.nspline+1, :self.last_layer_size] - vv[:self.nspline, :self.last_layer_size] + # hh.shape: [nspline, self.last_layer_size] + hh = vv[1:nspline+1, :self.last_layer_size] - vv[:nspline, :self.last_layer_size] + + self.data[net][:, :6 * self.last_layer_size:6] = vv[:nspline, :self.last_layer_size] + self.data[net][:, 1:6 * self.last_layer_size:6] = dd[:nspline, :self.last_layer_size] + self.data[net][:, 2:6 * self.last_layer_size:6] = 0.5 * d2[:nspline, :self.last_layer_size] + self.data[net][:, 3:6 * self.last_layer_size:6] = (1 / (2 * tt * tt * tt)) * (20 * hh - (8 * dd[1:nspline+1, :self.last_layer_size] + 12 * dd[:nspline, :self.last_layer_size]) * tt - (3 * d2[:nspline, :self.last_layer_size] - d2[1:nspline+1, :self.last_layer_size]) * tt * tt) + self.data[net][:, 4:6 * self.last_layer_size:6] = (1 / (2 * tt * tt * tt * tt)) * (-30 * hh + (14 * dd[1:nspline+1, :self.last_layer_size] + 16 * dd[:nspline, :self.last_layer_size]) * tt + (3 * d2[:nspline, :self.last_layer_size] - 2 * d2[1:nspline+1, :self.last_layer_size]) * tt * tt) + self.data[net][:, 5:6 * self.last_layer_size:6] = (1 / (2 * tt * tt * tt * tt * tt)) * (12 * hh - 6 * (dd[1:nspline+1, :self.last_layer_size] + dd[:nspline, :self.last_layer_size]) * tt + (d2[1:nspline+1, :self.last_layer_size] - d2[:nspline, :self.last_layer_size]) * tt * tt) - self.data[net][:, :6 * self.last_layer_size:6] = vv[:self.nspline, :self.last_layer_size] - self.data[net][:, 1:6 * self.last_layer_size:6] = dd[:self.nspline, :self.last_layer_size] - self.data[net][:, 2:6 * self.last_layer_size:6] = 0.5 * d2[:self.nspline, :self.last_layer_size] - self.data[net][:, 3:6 * self.last_layer_size:6] = (1 / (2 * tt * tt * tt)) * (20 * hh - (8 * dd[1:self.nspline+1, :self.last_layer_size] + 12 * dd[:self.nspline, :self.last_layer_size]) * tt - (3 * d2[:self.nspline, :self.last_layer_size] - d2[1:self.nspline+1, :self.last_layer_size]) * tt * tt) - self.data[net][:, 4:6 * self.last_layer_size:6] = (1 / (2 * tt * tt * tt * tt)) * (-30 * hh + (14 * dd[1:self.nspline+1, :self.last_layer_size] + 16 * dd[:self.nspline, :self.last_layer_size]) * tt + (3 * d2[:self.nspline, :self.last_layer_size] - 2 * d2[1:self.nspline+1, :self.last_layer_size]) * tt * tt) - self.data[net][:, 5:6 * self.last_layer_size:6] = (1 / (2 * tt * tt * tt * tt * tt)) * (12 * hh - 6 * (dd[1:self.nspline+1, :self.last_layer_size] + dd[:self.nspline, :self.last_layer_size]) * tt + (d2[1:self.nspline+1, :self.last_layer_size] - d2[:self.nspline, :self.last_layer_size]) * tt * tt) + self.upper[net] = upper + self.lower[net] = lower def _load_sub_graph(self): sub_graph_def = tf.GraphDef() @@ -387,24 +408,23 @@ def _layer_1(self, x, w, b): # Change the embedding net range to sw / min_nbor_dist def _get_env_mat_range(self, min_nbor_dist): - lower = +100.0 - upper = -100.0 sw = self._spline5_switch(min_nbor_dist, self.rcut_smth, self.rcut) if isinstance(self.descrpt, deepmd.descriptor.DescrptSeA): - lower = np.min(-self.davg[:, 0] / self.dstd[:, 0]) - upper = np.max(((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0]) + lower = -self.davg[:, 0] / self.dstd[:, 0] + upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0] elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeT): var = np.square(sw / (min_nbor_dist * self.dstd[:, 1:4])) - lower = np.min(-var) - upper = np.max(var) + lower = np.min(-var, axis=1) + upper = np.max(var, axis=1) elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): - lower = np.min(-self.davg[:, 0] / self.dstd[:, 0]) - upper = np.max(((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0]) + lower = -self.davg[:, 0] / self.dstd[:, 0] + upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0] else: raise RuntimeError("Unsupported descriptor") log.info('training data with lower boundary: ' + str(lower)) log.info('training data with upper boundary: ' + str(upper)) - return math.floor(lower), math.ceil(upper) + # returns element-wise lower and upper + return np.floor(lower), np.ceil(upper) def _spline5_switch(self, xx,