Skip to content

Commit

Permalink
[TorchGen] Use std::optional in generated code (pytorch#121454)
Browse files Browse the repository at this point in the history
This PR changes TorchGen to generate std::optional.

Pull Request resolved: pytorch#121454
Approved by: https://github.com/ezyang
  • Loading branch information
cyyever authored and pytorchmergebot committed Mar 29, 2024
1 parent 375a804 commit fb90b4d
Show file tree
Hide file tree
Showing 21 changed files with 202 additions and 192 deletions.
12 changes: 6 additions & 6 deletions aten/src/ATen/FunctionalTensorWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ c10::optional<Tensor> to_functional_tensor(const c10::optional<Tensor>& tensor)
}
return c10::nullopt;
}
c10::List<c10::optional<Tensor>> to_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
c10::List<c10::optional<Tensor>> outputs;
c10::List<::std::optional<Tensor>> to_functional_tensor(const c10::List<::std::optional<Tensor>>& t_list) {
c10::List<::std::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(to_functional_tensor(t_list[i]));
Expand Down Expand Up @@ -536,8 +536,8 @@ std::vector<Tensor> from_functional_tensor(ITensorListRef t_list) {
}
return outputs;
}
c10::List<c10::optional<Tensor>> from_functional_tensor(const c10::List<c10::optional<Tensor>>& t_list) {
c10::List<c10::optional<Tensor>> outputs;
c10::List<::std::optional<Tensor>> from_functional_tensor(const c10::List<::std::optional<Tensor>>& t_list) {
c10::List<::std::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(from_functional_tensor(t_list[i], /*assert_functional=*/false));
Expand Down Expand Up @@ -572,7 +572,7 @@ void sync(ITensorListRef t_list) {
sync(t);
}
}
void sync(const c10::List<c10::optional<Tensor>>& t_list) {
void sync(const c10::List<::std::optional<Tensor>>& t_list) {
for (const auto i : c10::irange(t_list.size())) {
sync(t_list[i]);
}
Expand Down Expand Up @@ -652,7 +652,7 @@ bool isFunctionalTensor(const c10::optional<Tensor>& t) {
}
}

bool isFunctionalTensor(const c10::List<c10::optional<Tensor>>& t_list) {
bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
if (t_list.empty()) return false;
auto functional_count = 0;
for (const auto i : c10::irange(t_list.size())) {
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/TensorIndexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,10 @@ static inline void recordTensorIndex(
(*dim_ptr)++;
};

static inline c10::List<c10::optional<Tensor>> typeConvertIndices(
static inline c10::List<::std::optional<Tensor>> typeConvertIndices(
const Tensor& /*self*/,
std::vector<Tensor>&& indices) {
c10::List<c10::optional<Tensor>> converted_inds;
c10::List<::std::optional<Tensor>> converted_inds;
converted_inds.reserve(indices.size());
for (auto&& i : std::move(indices)) {
converted_inds.push_back(std::move(i));
Expand Down
20 changes: 10 additions & 10 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1154,15 +1154,15 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
"(int[]? a) -> int[]?");

// Test list of optional (with empty list)
testArgTypes<c10::List<c10::optional<int64_t>>>::test(
c10::List<c10::optional<int64_t>>(c10::List<c10::optional<int64_t>>({})), [] (const c10::List<c10::optional<int64_t>>& v) {EXPECT_EQ(0, v.size());},
c10::List<c10::optional<int64_t>>(c10::List<c10::optional<int64_t>>({})), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<c10::optional<int64_t>>>().size());},
testArgTypes<c10::List<::std::optional<int64_t>>>::test(
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({})), [] (const c10::List<::std::optional<int64_t>>& v) {EXPECT_EQ(0, v.size());},
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({})), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<::std::optional<int64_t>>>().size());},
"(int?[] a) -> int?[]");

// Test list of optional (with values)
testArgTypes<c10::List<c10::optional<int64_t>>>::test(
c10::List<c10::optional<int64_t>>(c10::List<c10::optional<int64_t>>({3, c10::nullopt, 2})), [] (const c10::List<c10::optional<int64_t>>& v) {expectListEquals<c10::optional<int64_t>>({3, c10::nullopt, 2}, v);},
c10::List<c10::optional<int64_t>>(c10::List<c10::optional<int64_t>>({3, c10::nullopt, 2})), [] (const IValue& v) {expectListEquals<c10::optional<int64_t>>({3, c10::nullopt, 2}, v.to<c10::List<c10::optional<int64_t>>>());},
testArgTypes<c10::List<::std::optional<int64_t>>>::test(
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, c10::nullopt, 2})), [] (const c10::List<::std::optional<int64_t>>& v) {expectListEquals<c10::optional<int64_t>>({3, c10::nullopt, 2}, v);},
c10::List<::std::optional<int64_t>>(c10::List<::std::optional<int64_t>>({3, c10::nullopt, 2})), [] (const IValue& v) {expectListEquals<c10::optional<int64_t>>({3, c10::nullopt, 2}, v.to<c10::List<::std::optional<int64_t>>>());},
"(int?[] a) -> int?[]");

// dict types
Expand Down Expand Up @@ -1234,15 +1234,15 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
"(Dict(int, Tensor) a) -> Dict(int, Tensor)");

// weird deeply nested type
using DeeplyNestedType = c10::List<c10::Dict<std::string, c10::List<c10::optional<c10::Dict<int64_t, std::string>>>>>;
using DeeplyNestedType = c10::List<c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>>>;
auto makeDeeplyNestedObject = [] () -> DeeplyNestedType {
c10::Dict<int64_t, std::string> inner3;
inner3.insert(1, "1");
c10::List<c10::optional<c10::Dict<int64_t, std::string>>> inner2;
c10::List<::std::optional<c10::Dict<int64_t, std::string>>> inner2;
inner2.push_back(std::move(inner3));
c10::Dict<std::string, c10::List<c10::optional<c10::Dict<int64_t, std::string>>>> inner1;
c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>> inner1;
inner1.insert("key", std::move(inner2));
c10::List<c10::Dict<std::string, c10::List<c10::optional<c10::Dict<int64_t, std::string>>>>> result;
c10::List<c10::Dict<std::string, c10::List<::std::optional<c10::Dict<int64_t, std::string>>>>> result;
result.push_back(inner1);
return result;
};
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/templates/RegisterFunctionalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ inline c10::List<Tensor> to_meta(const c10::List<Tensor>& t_list) {
return outputs;
}

inline c10::List<c10::optional<Tensor>> to_meta(const c10::List<c10::optional<Tensor>>& t_list) {
c10::List<c10::optional<Tensor>> outputs;
inline c10::List<::std::optional<Tensor>> to_meta(const c10::List<::std::optional<Tensor>>& t_list) {
c10::List<::std::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(to_meta(t_list[i]));
Expand Down
4 changes: 2 additions & 2 deletions caffe2/contrib/aten/aten_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ namespace caffe2 {
namespace internal {
at::Tensor index_with_uint8_handling(
const at::Tensor& self,
const torch::List<c10::optional<at::Tensor>>& indices) {
const torch::List<std::optional<at::Tensor>>& indices) {
// Support BC only for the simplest case of mask indexing
if (indices.size() == 1) {
c10::optional<at::Tensor> first = indices[0];
std::optional<at::Tensor> first = indices[0];
if (first.has_value()
&& first->scalar_type() == at::kByte) {
TORCH_WARN(
Expand Down
6 changes: 3 additions & 3 deletions caffe2/contrib/aten/aten_op_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using at::Half; // for AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...)
namespace internal {
TORCH_API at::Tensor index_with_uint8_handling(
const at::Tensor& self,
const torch::List<c10::optional<at::Tensor>>& indices);
const torch::List<std::optional<at::Tensor>>& indices);
}

template <class Context>
Expand Down Expand Up @@ -94,8 +94,8 @@ class ATenOp : public Operator<Context> {
return results;
}

torch::List<c10::optional<at::Tensor>> peekSliceOptionals(size_t i, size_t len, size_t N) {
torch::List<c10::optional<at::Tensor>> results;
torch::List<std::optional<at::Tensor>> peekSliceOptionals(size_t i, size_t len, size_t N) {
torch::List<std::optional<at::Tensor>> results;
results.reserve(len);
for (size_t ii = i; ii < i + len; ++ii) {
results.push_back(peek(ii, N));
Expand Down
4 changes: 2 additions & 2 deletions caffe2/contrib/aten/gen_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def value_is_tensor_type(v):
TENSORLIST_TYPE = [
'at::TensorList',
'const at::ITensorListRef &',
'const c10::List<c10::optional<at::Tensor>> &',
'const c10::List<::std::optional<at::Tensor>> &',
]

# for each aten type, how do we handle a return value of that type?
Expand Down Expand Up @@ -298,7 +298,7 @@ def emit_assignments(o, env):
env['statements'].append(
'auto {} = peekSlice({}, InputSize() - {}, InputSize());'
.format(arg['name'], real_inputs, static_tensor_inputs))
elif arg['type'] == 'const c10::List<c10::optional<at::Tensor>> &':
elif arg['type'] == 'const c10::List<::std::optional<at::Tensor>> &':
# NOTE: do not advance real_inputs here. After this we will
# switch to indexing the "stack" from the end
env['statements'].append(
Expand Down
2 changes: 1 addition & 1 deletion test/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def _simple_type_parser(func, arg_name, arg_type):
return instance_gen()
elif arg_type == "TensorList" or arg_type == "ITensorListRef":
return [instance_gen(), instance_gen()]
elif arg_type == "c10::List<c10::optional<Tensor>>":
elif arg_type == "c10::List<::std::optional<Tensor>>":
return [instance_gen(), instance_gen()]
elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef":
size = arg.get("size", 2)
Expand Down
Loading

0 comments on commit fb90b4d

Please sign in to comment.