@@ -73,7 +73,8 @@ def main(args):
73
73
if args .modified_base_model is not None or args .modified_bases is not None :
74
74
sys .stderr .write ("> loading modified base model\n " )
75
75
mods_model = load_mods_model (
76
- args .modified_bases , args .model_directory , args .modified_base_model
76
+ args .modified_bases , args .model_directory , args .modified_base_model ,
77
+ device = args .modified_device ,
77
78
)
78
79
sys .stderr .write (f"> { mods_model [1 ]['alphabet_str' ]} \n " )
79
80
@@ -131,9 +132,12 @@ def main(args):
131
132
)
132
133
133
134
if mods_model is not None :
134
- results = process_itemmap (
135
- partial (call_mods , mods_model ), results , n_proc = args .modified_procs
136
- )
135
+ if args .modified_device :
136
+ results = ((k , call_mods (mods_model , k , v )) for k , v in results )
137
+ else :
138
+ results = process_itemmap (
139
+ partial (call_mods , mods_model ), results , n_proc = args .modified_procs
140
+ )
137
141
if aligner :
138
142
results = align_map (aligner , results , n_thread = args .alignment_threads )
139
143
@@ -167,6 +171,7 @@ def argparser():
167
171
parser .add_argument ("--modified-bases" , nargs = "+" )
168
172
parser .add_argument ("--modified-base-model" )
169
173
parser .add_argument ("--modified-procs" , default = 8 , type = int )
174
+ parser .add_argument ("--modified-device" , default = None )
170
175
parser .add_argument ("--read-ids" )
171
176
parser .add_argument ("--device" , default = "cuda" )
172
177
parser .add_argument ("--seed" , default = 25 , type = int )
0 commit comments