Skip to content

Commit

Permalink
move some imports into runtime functions (#894)
Browse files Browse the repository at this point in the history
gromacswrapper throws lots of warnings. boto3 may cause critical errors.
We move these imports into the runtime functions.

Fix #674.
  • Loading branch information
njzjz authored Aug 28, 2022
1 parent 4cb8550 commit 5bb5561
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
17 changes: 8 additions & 9 deletions dpgen/dispatcher/AWS.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@
from dpgen.dispatcher.JobStatus import JobStatus
from dpgen import dlog

try:
import boto3
except ModuleNotFoundError:
pass
else:
batch_client = boto3.client('batch')

class AWS(Batch):
_query_time_interval = 30
_job_id_map_status = {}
_jobQueue = ""
_query_next_allow_time = datetime.now().timestamp()

def __init__(self, context, uuid_names=True):
import boto3
self.batch_client = boto3.client('batch')
super().__init__(context, uuid_names)

@staticmethod
def map_aws_status_to_dpgen_status(aws_status):
map_dict = {'SUBMITTED': JobStatus.waiting,
Expand Down Expand Up @@ -47,7 +46,7 @@ def AWS_check_status(cls, job_id=""):
for status in ['SUBMITTED', 'PENDING', 'RUNNABLE', 'STARTING', 'RUNNING','SUCCEEDED', 'FAILED']:
nextToken = ''
while nextToken is not None:
status_response = batch_client.list_jobs(jobQueue=cls._jobQueue, jobStatus=status, maxResults=100, nextToken=nextToken)
status_response = self.batch_client.list_jobs(jobQueue=cls._jobQueue, jobStatus=status, maxResults=100, nextToken=nextToken)
status_list=status_response.get('jobSummaryList')
nextToken = status_response.get('nextToken', None)
for job_dict in status_list:
Expand All @@ -66,7 +65,7 @@ def job_id(self):
except AttributeError:
if self.context.check_file_exists(self.job_id_name):
self._job_id = self.context.read_file(self.job_id_name)
response_list = batch_client.describe_jobs(jobs=[self._job_id]).get('jobs')
response_list = self.batch_client.describe_jobs(jobs=[self._job_id]).get('jobs')
try:
response = response_list[0]
jobQueue = response['jobQueue']
Expand Down Expand Up @@ -134,7 +133,7 @@ def do_submit(self,
"""
jobName = os.path.join(self.context.remote_root,job_dirs.pop())[1:].replace('/','-').replace('.','_')
jobName += ("_" + str(self.context.job_uuid))
response = batch_client.submit_job(jobName=jobName,
response = self.batch_client.submit_job(jobName=jobName,
jobQueue=res['jobQueue'],
jobDefinition=res['jobDefinition'],
parameters={'task_command':script_str},
Expand Down
10 changes: 5 additions & 5 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,7 @@
from dpgen import ROOT_PATH
from pymatgen.io.vasp import Incar,Kpoints,Potcar
from dpgen.auto_test.lib.vasp import make_kspacing_kpoints
try:
from gromacs.fileformats.mdp import MDP
except ImportError:
dlog.info("GromacsWrapper>=0.8.0 is needed for DP-GEN + Gromacs.")
pass


template_name = 'template'
train_name = '00.train'
Expand Down Expand Up @@ -1209,6 +1205,10 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
sys_counter += 1

def _make_model_devi_native_gromacs(iter_index, jdata, mdata, conf_systems):
try:
from gromacs.fileformats.mdp import MDP
except ImportError as e:
raise RuntimeError("GromacsWrapper>=0.8.0 is needed for DP-GEN + Gromacs.") from e
# only support for deepmd v2.0
if LooseVersion(mdata['deepmd_version']) < LooseVersion('2.0'):
raise RuntimeError("Only support deepmd-kit 2.x for model_devi_engine='gromacs'")
Expand Down

0 comments on commit 5bb5561

Please sign in to comment.