diff --git a/tests/create_jit_models.py b/tests/create_jit_models.py index 06fc4f48..376a76ae 100644 --- a/tests/create_jit_models.py +++ b/tests/create_jit_models.py @@ -1,4 +1,5 @@ import torch +from typing import Dict, Tuple class Foo(torch.jit.ScriptModule): def __init__(self, v): @@ -107,3 +108,12 @@ def make_input_object(self, foo, bar): foo_7 = TorchScriptExample() foo_7.save("foo7.pt") + +# https://github.com/LaurentMazare/tch-rs/issues/597 +class DictExample(torch.jit.ScriptModule): + @torch.jit.script_method + def generate(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + return batch["foo"], batch["bar"] + +foo_8 = DictExample() +foo_8.save("foo8.pt") diff --git a/tests/foo.pt b/tests/foo.pt index 4759d17f..ae7b2940 100644 Binary files a/tests/foo.pt and b/tests/foo.pt differ diff --git a/tests/foo1.pt b/tests/foo1.pt index 8831620b..0a6cc41e 100644 Binary files a/tests/foo1.pt and b/tests/foo1.pt differ diff --git a/tests/foo2.pt b/tests/foo2.pt index 3fcc2e2b..5c611c16 100644 Binary files a/tests/foo2.pt and b/tests/foo2.pt differ diff --git a/tests/foo3.pt b/tests/foo3.pt index e5e9a1fc..b82ffbf3 100644 Binary files a/tests/foo3.pt and b/tests/foo3.pt differ diff --git a/tests/foo4.pt b/tests/foo4.pt index ea3607e3..420a3d2d 100644 Binary files a/tests/foo4.pt and b/tests/foo4.pt differ diff --git a/tests/foo5.pt b/tests/foo5.pt index a3259085..83787825 100644 Binary files a/tests/foo5.pt and b/tests/foo5.pt differ diff --git a/tests/foo6.pt b/tests/foo6.pt index 4e7084ae..a8798663 100644 Binary files a/tests/foo6.pt and b/tests/foo6.pt differ diff --git a/tests/foo7.pt b/tests/foo7.pt index 878e9817..af51dece 100644 Binary files a/tests/foo7.pt and b/tests/foo7.pt differ diff --git a/tests/foo8.pt b/tests/foo8.pt new file mode 100644 index 00000000..e6bee10d Binary files /dev/null and b/tests/foo8.pt differ diff --git a/tests/jit_tests.rs b/tests/jit_tests.rs index f3381d5c..5118595b 100644 --- a/tests/jit_tests.rs +++ b/tests/jit_tests.rs @@ -1,4 +1,4 @@ -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use tch::{IValue, Kind, Tensor}; #[test] @@ -160,3 +160,20 @@ fn jit_double_free() { }; assert_eq!(Vec::::from(&result), [5.0, 7.0, 9.0]) } + +// https://github.com/LaurentMazare/tch-rs/issues/597 +#[test] +fn specialized_dict() { + let mod_ = tch::CModule::load("tests/foo8.pt").unwrap(); + let input = IValue::GenericDict(vec![ + (IValue::String("bar".to_owned()), IValue::Tensor(Tensor::of_slice(&[1_f32, 7_f32]))), + ( + IValue::String("foo".to_owned()), + IValue::Tensor(Tensor::of_slice(&[1_f32, 2_f32, 3_f32])), + ), + ]); + let result = mod_.method_is("generate", &[input]).unwrap(); + let result: (Tensor, Tensor) = result.try_into().unwrap(); + assert_eq!(Vec::::from(&result.0), [1.0, 2.0, 3.0]); + assert_eq!(Vec::::from(&result.1), [1.0, 7.0]) +} diff --git a/torch-sys/libtch/torch_api.cpp b/torch-sys/libtch/torch_api.cpp index 091834f4..bc2f3ef3 100644 --- a/torch-sys/libtch/torch_api.cpp +++ b/torch-sys/libtch/torch_api.cpp @@ -1212,11 +1212,27 @@ ivalue ati_generic_list(ivalue *is, int nvalues) { return nullptr; } +using generic_dict = c10::Dict; + ivalue ati_generic_dict(ivalue *is, int nvalues) { - c10::Dict dict(c10::AnyType::get(), c10::AnyType::get()); PROTECT( - for (int i = 0; i < nvalues; ++i) dict.insert(*(is[2*i]), *(is[2*i+1])); - return new torch::jit::IValue(dict); + bool all_keys_are_str = true; + for (int i = 0; i < nvalues; ++i) { + if (!is[2*i]->isString()) all_keys_are_str = false; + } + bool all_values_are_tensor = true; + for (int i = 0; i < nvalues; ++i) { + if (!is[2*i+1]->isTensor()) all_values_are_tensor = false; + } + if (all_keys_are_str && all_values_are_tensor) { + generic_dict dict(c10::StringType::get(), c10::TensorType::get()); + for (int i = 0; i < nvalues; ++i) dict.insert(is[2*i]->toString(), is[2*i+1]->toTensor()); + return new torch::jit::IValue(dict); + } else { + generic_dict dict(c10::AnyType::get(), c10::AnyType::get()); + for (int i = 0; i < nvalues; ++i) dict.insert(*(is[2*i]), *(is[2*i+1])); + return new torch::jit::IValue(dict); + } ) return nullptr; }