Skip to content

Commit b1ca705

Browse files
committed
Merge branch 'basemod_gpu' into 'master'
GPU mod base inference enabled See merge request machine-learning/bonito!83
2 parents 4a0d7cc + 37044f3 commit b1ca705

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ Modified base calling is handled by [Remora](https://github.com/nanoporetech/rem
4343
$ bonito basecaller dna_r10.4_e8.1_sup@v3.4 /data/reads --modified-bases 5mC --reference ref.mmi > basecalls_with_mods.bam
4444
```
4545

46+
To use the GPU-powered modified bases inference the `onnxruntime-gpu` package is required.
47+
4648
See available modified base models with the ``remora model list_pretrained`` command.
4749

4850
## Training your own model

bonito/cli/basecaller.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def main(args):
7373
if args.modified_base_model is not None or args.modified_bases is not None:
7474
sys.stderr.write("> loading modified base model\n")
7575
mods_model = load_mods_model(
76-
args.modified_bases, args.model_directory, args.modified_base_model
76+
args.modified_bases, args.model_directory, args.modified_base_model,
77+
device=args.modified_device,
7778
)
7879
sys.stderr.write(f"> {mods_model[1]['alphabet_str']}\n")
7980

@@ -131,9 +132,12 @@ def main(args):
131132
)
132133

133134
if mods_model is not None:
134-
results = process_itemmap(
135-
partial(call_mods, mods_model), results, n_proc=args.modified_procs
136-
)
135+
if args.modified_device:
136+
results = ((k, call_mods(mods_model, k, v)) for k, v in results)
137+
else:
138+
results = process_itemmap(
139+
partial(call_mods, mods_model), results, n_proc=args.modified_procs
140+
)
137141
if aligner:
138142
results = align_map(aligner, results, n_thread=args.alignment_threads)
139143

@@ -167,6 +171,7 @@ def argparser():
167171
parser.add_argument("--modified-bases", nargs="+")
168172
parser.add_argument("--modified-base-model")
169173
parser.add_argument("--modified-procs", default=8, type=int)
174+
parser.add_argument("--modified-device", default=None)
170175
parser.add_argument("--read-ids")
171176
parser.add_argument("--device", default="cuda")
172177
parser.add_argument("--seed", default=25, type=int)

bonito/mod_util.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def format(self, record):
3131
log.CONSOLE.setFormatter(CustomFormatter())
3232

3333

34-
def load_mods_model(mod_bases, bc_model_str, model_path):
34+
def load_mods_model(mod_bases, bc_model_str, model_path, device=None):
3535
if mod_bases is not None:
3636
try:
3737
bc_model_type, model_version = bc_model_str.split('@')
@@ -50,8 +50,9 @@ def load_mods_model(mod_bases, bc_model_str, model_path):
5050
basecall_model_version=model_version,
5151
modified_bases=mod_bases,
5252
quiet=True,
53+
device=device,
5354
)
54-
return load_model(model_path, quiet=True)
55+
return load_model(model_path, quiet=True, device=device)
5556

5657

5758
def mods_tags_to_str(mods_tags):

0 commit comments

Comments
 (0)