Skip to content

Commit f000631

Browse files
anjali411facebook-github-bot
authored andcommitted
Add support for complex valued keys for dict in TS (pytorch#51472)
Summary: Pull Request resolved: pytorch#51472 Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D26177963 Pulled By: anjali411 fbshipit-source-id: 5841159c36b07290b1d88d4df27a0bf8c17d9df8
1 parent 9c474c9 commit f000631

File tree

6 files changed

+30
-2
lines changed

6 files changed

+30
-2
lines changed

Diff for: aten/src/ATen/core/Dict.h

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using valid_dict_key_types = guts::typelist::typelist<
2121
int64_t,
2222
std::string,
2323
double,
24+
c10::complex<double>,
2425
bool,
2526
at::Tensor
2627
>;

Diff for: aten/src/ATen/core/Dict_inl.h

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <ATen/core/ivalue.h>
4+
#include <c10/util/hash.h>
45

56
namespace c10 {
67
namespace detail {
@@ -43,6 +44,8 @@ inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
4344
return std::hash<std::string>()(ivalue.toStringRef());
4445
} else if (ivalue.isDouble()) {
4546
return std::hash<double>()(ivalue.toDouble());
47+
} else if (ivalue.isComplexDouble()) {
48+
return c10::hash<c10::complex<double>>()(ivalue.toComplexDouble());
4649
} else if (ivalue.isBool()) {
4750
return std::hash<bool>()(ivalue.toBool());
4851
} else if (ivalue.isTensor()) {

Diff for: aten/src/ATen/core/jit_type.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -738,14 +738,15 @@ struct TORCH_API DictType : public Type {
738738
case TypeKind::IntType:
739739
case TypeKind::BoolType:
740740
case TypeKind::FloatType:
741+
case TypeKind::ComplexDoubleType:
741742
case TypeKind::StringType:
742743
case TypeKind::TensorType:
743744
return DictTypePtr(new DictType(key, value));
744745
default:
745746
AT_ERROR(
746747
"Cannot create dict for key type '",
747748
key->str(),
748-
"', only int, float, Tensor and string keys are supported");
749+
"', only int, float, complex, Tensor and string keys are supported");
749750
}
750751
}
751752

Diff for: aten/src/ATen/test/ivalue_test.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <gtest/gtest.h>
33
#include <torch/torch.h>
44
#include <c10/util/intrusive_ptr.h>
5+
#include <ATen/core/Dict.h>
56

67
namespace c10 {
78

@@ -77,6 +78,18 @@ TEST(IValueTest, Basic) {
7778
ASSERT_EQ(complex_tuple.toTuple()->elements()[1], foo1);
7879
}
7980

81+
TEST(IValueTest, ComplexDict) {
82+
typedef c10::complex<double> c_type;
83+
c10::Dict<c_type, c_type> m;
84+
auto num1 = c_type(2.3, -3.5);
85+
auto num2 = c_type(0, 5);
86+
m.insert(num1, 2 * num1);
87+
m.insert(num2, 2 * num2);
88+
IValue dict(std::move(m));
89+
auto m_ = dict.toGenericDict();
90+
ASSERT_EQ(m_.at(num1), 2 * num1);
91+
ASSERT_EQ(m_.at(num2), 2 * num2);
92+
}
8093
static std::array<IValue, 5> makeSampleIValues() {
8194
return { at::rand({3, 4}), "hello", 42, true, 1.5 };
8295
}

Diff for: test/jit/test_complex.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import sys
44
from torch.testing._internal.jit_utils import JitTestCase
5-
from typing import List
5+
from typing import List, Dict
66

77
# Make the helper files in test/ importable
88
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -22,16 +22,25 @@ def fn(a: List[complex], idx: int):
2222
input = [1j, 2, 3 + 4j, -5, -7j]
2323
self.checkScript(fn, (input, 2))
2424

25+
def test_complexdict(self):
26+
def fn(a: Dict[complex, complex], key: complex) -> complex:
27+
return a[key]
28+
29+
input = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
30+
self.checkScript(fn, (input, -4.3 - 2j))
31+
2532
def test_pickle(self):
2633
class ComplexModule(torch.jit.ScriptModule):
2734
def __init__(self):
2835
super().__init__()
2936
self.a = 3 + 5j
3037
self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
38+
self.c = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
3139

3240
def forward(self, b: int):
3341
return b
3442

3543
loaded = self.getExportImportCopy(ComplexModule())
3644
self.assertEqual(loaded.a, 3 + 5j)
3745
self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
46+
self.assertEqual(loaded.c, {2 + 3j : 2 - 3j, -4.3 - 2j: 3j})

Diff for: torch/csrc/jit/runtime/register_prim_ops.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1356,6 +1356,7 @@ RegisterOperators reg_dict_ops({
13561356
CREATE_DICT_OPS("int"),
13571357
CREATE_DICT_OPS("bool"),
13581358
CREATE_DICT_OPS("float"),
1359+
CREATE_DICT_OPS("complex"),
13591360
CREATE_DICT_OPS("Tensor"),
13601361
});
13611362

0 commit comments

Comments
 (0)