Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move some imports into runtime functions #894

Merged
merged 1 commit into from
Aug 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions dpgen/dispatcher/AWS.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@
from dpgen.dispatcher.JobStatus import JobStatus
from dpgen import dlog

try:
import boto3
except ModuleNotFoundError:
pass
else:
batch_client = boto3.client('batch')

class AWS(Batch):
_query_time_interval = 30
_job_id_map_status = {}
_jobQueue = ""
_query_next_allow_time = datetime.now().timestamp()

def __init__(self, context, uuid_names=True):
import boto3
self.batch_client = boto3.client('batch')
super().__init__(context, uuid_names)

@staticmethod
def map_aws_status_to_dpgen_status(aws_status):
map_dict = {'SUBMITTED': JobStatus.waiting,
Expand Down Expand Up @@ -47,7 +46,7 @@ def AWS_check_status(cls, job_id=""):
for status in ['SUBMITTED', 'PENDING', 'RUNNABLE', 'STARTING', 'RUNNING','SUCCEEDED', 'FAILED']:
nextToken = ''
while nextToken is not None:
status_response = batch_client.list_jobs(jobQueue=cls._jobQueue, jobStatus=status, maxResults=100, nextToken=nextToken)
status_response = self.batch_client.list_jobs(jobQueue=cls._jobQueue, jobStatus=status, maxResults=100, nextToken=nextToken)
status_list=status_response.get('jobSummaryList')
nextToken = status_response.get('nextToken', None)
for job_dict in status_list:
Expand All @@ -66,7 +65,7 @@ def job_id(self):
except AttributeError:
if self.context.check_file_exists(self.job_id_name):
self._job_id = self.context.read_file(self.job_id_name)
response_list = batch_client.describe_jobs(jobs=[self._job_id]).get('jobs')
response_list = self.batch_client.describe_jobs(jobs=[self._job_id]).get('jobs')
try:
response = response_list[0]
jobQueue = response['jobQueue']
Expand Down Expand Up @@ -134,7 +133,7 @@ def do_submit(self,
"""
jobName = os.path.join(self.context.remote_root,job_dirs.pop())[1:].replace('/','-').replace('.','_')
jobName += ("_" + str(self.context.job_uuid))
response = batch_client.submit_job(jobName=jobName,
response = self.batch_client.submit_job(jobName=jobName,
jobQueue=res['jobQueue'],
jobDefinition=res['jobDefinition'],
parameters={'task_command':script_str},
Expand Down
10 changes: 5 additions & 5 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,7 @@
from dpgen import ROOT_PATH
from pymatgen.io.vasp import Incar,Kpoints,Potcar
from dpgen.auto_test.lib.vasp import make_kspacing_kpoints
try:
from gromacs.fileformats.mdp import MDP
except ImportError:
dlog.info("GromacsWrapper>=0.8.0 is needed for DP-GEN + Gromacs.")
pass


template_name = 'template'
train_name = '00.train'
Expand Down Expand Up @@ -1209,6 +1205,10 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
sys_counter += 1

def _make_model_devi_native_gromacs(iter_index, jdata, mdata, conf_systems):
try:
from gromacs.fileformats.mdp import MDP
except ImportError as e:
raise RuntimeError("GromacsWrapper>=0.8.0 is needed for DP-GEN + Gromacs.") from e
# only support for deepmd v2.0
if LooseVersion(mdata['deepmd_version']) < LooseVersion('2.0'):
raise RuntimeError("Only support deepmd-kit 2.x for model_devi_engine='gromacs'")
Expand Down