Skip to content

Commit

Permalink
Rename DynamicCastToGenerated/DownCastToGenerated to
Browse files Browse the repository at this point in the history
`DynamicCastMessage`/`DownCastMessage`.
The target does not necessarily need to be a generated type. For example, it
also supports `Message` itself. This makes the API friendlier to generic code and less verbose.

Replace all uses of dynamic_cast/down_cast/**ToGenerated with the new names.
Also, remove checks for RTTI in tests where we only need the casts to work. They don't need RTTI anymore.

PiperOrigin-RevId: 638278948
  • Loading branch information
protobuf-github-bot authored and copybara-github committed May 29, 2024
1 parent 01ec3fa commit 18da465
Show file tree
Hide file tree
Showing 28 changed files with 223 additions and 191 deletions.
3 changes: 2 additions & 1 deletion python/google/protobuf/pyext/descriptor_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "google/protobuf/descriptor.pb.h"
#include "absl/log/absl_log.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"

Expand Down Expand Up @@ -55,7 +56,7 @@ static bool GetFileDescriptorProto(PyObject* py_descriptor,
message->message->GetDescriptor() == filedescriptor_descriptor) {
// Fast path: Just use the pointer.
FileDescriptorProto* file_proto =
google::protobuf::DownCastToGenerated<FileDescriptorProto>(message->message);
google::protobuf::DownCastMessage<FileDescriptorProto>(message->message);
*output = *file_proto;
return true;
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/google/protobuf/arena.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using type_info = ::type_info;
#endif

#include "absl/base/attributes.h"
#include "google/protobuf/stubs/common.h"
#include "absl/base/macros.h"
#include "absl/log/absl_check.h"
#include "absl/utility/internal/if_constexpr.h"
#include "google/protobuf/arena_align.h"
Expand Down
2 changes: 1 addition & 1 deletion src/google/protobuf/arena_align.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
#include <cstddef>
#include <cstdint>

#include "google/protobuf/stubs/common.h"
#include "absl/base/macros.h"
#include "absl/log/absl_check.h"
#include "absl/numeric/bits.h"

Expand Down
6 changes: 3 additions & 3 deletions src/google/protobuf/arena_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ TEST(ArenaTest, ReleaseFromArenaMessageUsingReflectionMakesCopy) {
const Reflection* r = arena_message->GetReflection();
const FieldDescriptor* f = arena_message->GetDescriptor()->FindFieldByName(
"optional_nested_message");
nested_msg = DownCastToGenerated<TestAllTypes::NestedMessage>(
nested_msg = DownCastMessage<TestAllTypes::NestedMessage>(
r->ReleaseMessage(arena_message, f));
}
EXPECT_EQ(42, nested_msg->bb());
Expand Down Expand Up @@ -1491,7 +1491,7 @@ TEST(ArenaTest, MutableMessageReflection) {
const Descriptor* d = message->GetDescriptor();
const FieldDescriptor* field = d->FindFieldByName("optional_nested_message");
TestAllTypes::NestedMessage* submessage =
DownCastToGenerated<TestAllTypes::NestedMessage>(
DownCastMessage<TestAllTypes::NestedMessage>(
r->MutableMessage(message, field));
TestAllTypes::NestedMessage* submessage_expected =
message->mutable_optional_nested_message();
Expand All @@ -1501,7 +1501,7 @@ TEST(ArenaTest, MutableMessageReflection) {

const FieldDescriptor* oneof_field =
d->FindFieldByName("oneof_nested_message");
submessage = DownCastToGenerated<TestAllTypes::NestedMessage>(
submessage = DownCastMessage<TestAllTypes::NestedMessage>(
r->MutableMessage(message, oneof_field));
submessage_expected = message->mutable_oneof_nested_message();

Expand Down
5 changes: 2 additions & 3 deletions src/google/protobuf/compiler/cpp/service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,8 @@ void ServiceGenerator::GenerateCallMethodCases(io::Printer* printer) {
},
R"cc(
case $index$:
$name$(controller,
::$proto_ns$::DownCastToGenerated<$input$>(request),
::$proto_ns$::DownCastToGenerated<$output$>(response), done);
$name$(controller, ::$proto_ns$::DownCastMessage<$input$>(request),
::$proto_ns$::DownCastMessage<$output$>(response), done);
break;
)cc");
}
Expand Down
14 changes: 7 additions & 7 deletions src/google/protobuf/compiler/cpp/unittest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1310,17 +1310,17 @@ TEST_F(GENERATED_SERVICE_TEST_NAME, CallMethod) {
TEST_F(GENERATED_SERVICE_TEST_NAME, CallMethodTypeFailure) {
// Verify death if we call Foo() with Bar's message types.

#if PROTOBUF_RTTI && GTEST_HAS_DEATH_TEST // death tests do not work on Windows yet
#if GTEST_HAS_DEATH_TEST // death tests do not work on Windows yet
EXPECT_DEBUG_DEATH(
mock_service_.CallMethod(foo_, &mock_controller_,
&foo_request_, &bar_response_, done_.get()),
"DynamicCastToGenerated");
mock_service_.CallMethod(foo_, &mock_controller_, &foo_request_,
&bar_response_, done_.get()),
"DynamicCastMessage");

mock_service_.Reset();
EXPECT_DEBUG_DEATH(
mock_service_.CallMethod(foo_, &mock_controller_,
&bar_request_, &foo_response_, done_.get()),
"DynamicCastToGenerated");
mock_service_.CallMethod(foo_, &mock_controller_, &bar_request_,
&foo_response_, done_.get()),
"DynamicCastMessage");
#endif // GTEST_HAS_DEATH_TEST
}

Expand Down
1 change: 1 addition & 0 deletions src/google/protobuf/compiler/objectivec/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ cc_test(
":line_consumer",
"//src/google/protobuf/io",
"//src/google/protobuf/stubs",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <vector>

#include <gtest/gtest.h>
#include "google/protobuf/stubs/common.h"
#include "absl/base/macros.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
Expand Down
5 changes: 2 additions & 3 deletions src/google/protobuf/compiler/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1565,9 +1565,8 @@ bool Parser::ParseOption(Message* options,
}

UninterpretedOption* uninterpreted_option =
DownCastToGenerated<UninterpretedOption>(
options->GetReflection()->AddMessage(options,
uninterpreted_option_field));
DownCastMessage<UninterpretedOption>(options->GetReflection()->AddMessage(
options, uninterpreted_option_field));

// Parse dot-separated name.
{
Expand Down
2 changes: 1 addition & 1 deletion src/google/protobuf/descriptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8584,7 +8584,7 @@ bool DescriptorBuilder::OptionInterpreter::InterpretOptionsImpl(
*original_options, original_uninterpreted_options_field);
for (int i = 0; i < num_uninterpreted_options; ++i) {
src_path.push_back(i);
uninterpreted_option_ = DownCastToGenerated<UninterpretedOption>(
uninterpreted_option_ = DownCastMessage<UninterpretedOption>(
&original_options->GetReflection()->GetRepeatedMessage(
*original_options, original_uninterpreted_options_field, i));
if (!InterpretSingleOption(options, src_path,
Expand Down
3 changes: 2 additions & 1 deletion src/google/protobuf/extension_set_heavy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ size_t ExtensionSet::Extension::SpaceUsedExcludingSelfLong() const {
if (is_lazy) {
total_size += lazymessage_value->SpaceUsedLong();
} else {
total_size += DownCastToMessage(message_value)->SpaceUsedLong();
total_size +=
DownCastMessage<Message>(message_value)->SpaceUsedLong();
}
break;
default:
Expand Down
6 changes: 1 addition & 5 deletions src/google/protobuf/extension_set_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1323,11 +1323,7 @@ TEST(ExtensionSetTest, DynamicExtensions) {
const Message& sub_message =
message.GetReflection()->GetMessage(message, message_extension);
const unittest::ForeignMessage* typed_sub_message =
#if PROTOBUF_RTTI
dynamic_cast<const unittest::ForeignMessage*>(&sub_message);
#else
static_cast<const unittest::ForeignMessage*>(&sub_message);
#endif
google::protobuf::DynamicCastMessage<unittest::ForeignMessage>(&sub_message);
ASSERT_TRUE(typed_sub_message != nullptr);
EXPECT_EQ(456, typed_sub_message->c());
}
Expand Down
4 changes: 2 additions & 2 deletions src/google/protobuf/generated_message_tctable_full.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ const char* TcParser::ReflectionFallback(PROTOBUF_TC_PARAM_DECL) {
return ptr;
}

auto* full_msg = DownCastToMessage(msg);
auto* full_msg = DownCastMessage<Message>(msg);
auto* descriptor = full_msg->GetDescriptor();
auto* reflection = full_msg->GetReflection();
int field_number = WireFormatLite::GetTagFieldNumber(tag);
Expand All @@ -87,7 +87,7 @@ const char* TcParser::ReflectionParseLoop(PROTOBUF_TC_PARAM_DECL) {
(void)table;
(void)hasbits;
// Call into the wire format reflective parse loop.
return WireFormat::_InternalParse(DownCastToMessage(msg), ptr, ctx);
return WireFormat::_InternalParse(DownCastMessage<Message>(msg), ptr, ctx);
}

const char* TcParser::MessageSetWireFormatParseLoop(
Expand Down
1 change: 1 addition & 0 deletions src/google/protobuf/io/zero_copy_stream_impl_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "google/protobuf/stubs/callback.h"
#include "google/protobuf/stubs/common.h"
#include "absl/base/attributes.h"
#include "absl/base/macros.h"
#include "absl/strings/cord.h"
#include "absl/strings/cord_buffer.h"
#include "google/protobuf/io/zero_copy_stream.h"
Expand Down
43 changes: 20 additions & 23 deletions src/google/protobuf/lite_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1330,83 +1330,80 @@ TEST(LiteBasicTest, CodedInputStreamRollback) {
using CastType1 = protobuf_unittest::TestAllTypesLite;
using CastType2 = protobuf_unittest::TestPackedTypesLite;

TEST(LiteTest, DynamicCastToGenerated) {
TEST(LiteTest, DynamicCastMessage) {
CastType1 test_type_1;

MessageLite* test_type_1_pointer = &test_type_1;
EXPECT_EQ(&test_type_1,
DynamicCastToGenerated<CastType1>(test_type_1_pointer));
EXPECT_EQ(nullptr, DynamicCastToGenerated<CastType2>(test_type_1_pointer));
EXPECT_EQ(&test_type_1, DynamicCastMessage<CastType1>(test_type_1_pointer));
EXPECT_EQ(nullptr, DynamicCastMessage<CastType2>(test_type_1_pointer));

const MessageLite* test_type_1_pointer_const = &test_type_1;
EXPECT_EQ(&test_type_1,
DynamicCastToGenerated<const CastType1>(test_type_1_pointer_const));
DynamicCastMessage<const CastType1>(test_type_1_pointer_const));
EXPECT_EQ(nullptr,
DynamicCastToGenerated<const CastType2>(test_type_1_pointer_const));
DynamicCastMessage<const CastType2>(test_type_1_pointer_const));

MessageLite* test_type_1_pointer_nullptr = nullptr;
EXPECT_EQ(nullptr,
DynamicCastToGenerated<CastType1>(test_type_1_pointer_nullptr));
DynamicCastMessage<CastType1>(test_type_1_pointer_nullptr));

MessageLite& test_type_1_pointer_ref = test_type_1;
EXPECT_EQ(&test_type_1,
&DynamicCastToGenerated<CastType1>(test_type_1_pointer_ref));
&DynamicCastMessage<CastType1>(test_type_1_pointer_ref));

const MessageLite& test_type_1_pointer_const_ref = test_type_1;
EXPECT_EQ(&test_type_1,
&DynamicCastToGenerated<CastType1>(test_type_1_pointer_const_ref));
&DynamicCastMessage<CastType1>(test_type_1_pointer_const_ref));
}

#if GTEST_HAS_DEATH_TEST
TEST(LiteTest, DynamicCastToGeneratedInvalidReferenceType) {
TEST(LiteTest, DynamicCastMessageInvalidReferenceType) {
CastType1 test_type_1;
const MessageLite& test_type_1_pointer_const_ref = test_type_1;
ASSERT_DEATH(DynamicCastToGenerated<CastType2>(test_type_1_pointer_const_ref),
ASSERT_DEATH(DynamicCastMessage<CastType2>(test_type_1_pointer_const_ref),
"Cannot downcast " + test_type_1.GetTypeName() + " to " +
CastType2::default_instance().GetTypeName());
}
#endif // GTEST_HAS_DEATH_TEST

TEST(LiteTest, DownCastToGeneratedValidType) {
TEST(LiteTest, DownCastMessageValidType) {
CastType1 test_type_1;

MessageLite* test_type_1_pointer = &test_type_1;
EXPECT_EQ(&test_type_1, DownCastToGenerated<CastType1>(test_type_1_pointer));
EXPECT_EQ(&test_type_1, DownCastMessage<CastType1>(test_type_1_pointer));

const MessageLite* test_type_1_pointer_const = &test_type_1;
EXPECT_EQ(&test_type_1,
DownCastToGenerated<const CastType1>(test_type_1_pointer_const));
DownCastMessage<const CastType1>(test_type_1_pointer_const));

MessageLite* test_type_1_pointer_nullptr = nullptr;
EXPECT_EQ(nullptr,
DownCastToGenerated<CastType1>(test_type_1_pointer_nullptr));
EXPECT_EQ(nullptr, DownCastMessage<CastType1>(test_type_1_pointer_nullptr));

MessageLite& test_type_1_pointer_ref = test_type_1;
EXPECT_EQ(&test_type_1,
&DownCastToGenerated<CastType1>(test_type_1_pointer_ref));
EXPECT_EQ(&test_type_1, &DownCastMessage<CastType1>(test_type_1_pointer_ref));

const MessageLite& test_type_1_pointer_const_ref = test_type_1;
EXPECT_EQ(&test_type_1,
&DownCastToGenerated<CastType1>(test_type_1_pointer_const_ref));
&DownCastMessage<CastType1>(test_type_1_pointer_const_ref));
}

#if GTEST_HAS_DEATH_TEST
TEST(LiteTest, DownCastToGeneratedInvalidPointerType) {
TEST(LiteTest, DownCastMessageInvalidPointerType) {
CastType1 test_type_1;

MessageLite* test_type_1_pointer = &test_type_1;

ASSERT_DEBUG_DEATH(DownCastToGenerated<CastType2>(test_type_1_pointer),
ASSERT_DEBUG_DEATH(DownCastMessage<CastType2>(test_type_1_pointer),
"Cannot downcast " + test_type_1.GetTypeName() + " to " +
CastType2::default_instance().GetTypeName());
}

TEST(LiteTest, DownCastToGeneratedInvalidReferenceType) {
TEST(LiteTest, DownCastMessageInvalidReferenceType) {
CastType1 test_type_1;

MessageLite& test_type_1_pointer = test_type_1;

ASSERT_DEBUG_DEATH(DownCastToGenerated<CastType2>(test_type_1_pointer),
ASSERT_DEBUG_DEATH(DownCastMessage<CastType2>(test_type_1_pointer),
"Cannot downcast " + test_type_1.GetTypeName() + " to " +
CastType2::default_instance().GetTypeName());
}
Expand Down
30 changes: 15 additions & 15 deletions src/google/protobuf/map_test.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,7 @@ TEST_F(MapFieldReflectionTest, RegularFields) {
message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message =
DownCastToGenerated<ForeignMessage>(
DownCastMessage<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
Expand Down Expand Up @@ -1615,7 +1615,7 @@ TEST_F(MapFieldReflectionTest, RegularFields) {
message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message =
DownCastToGenerated<ForeignMessage>(
DownCastMessage<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
Expand Down Expand Up @@ -1652,7 +1652,7 @@ TEST_F(MapFieldReflectionTest, RegularFields) {
int32_t key_int32_message =
message_int32_message->GetReflection()->GetInt32(
*message_int32_message, fd_map_int32_foreign_message_key);
ForeignMessage* value_int32_message = DownCastToGenerated<ForeignMessage>(
ForeignMessage* value_int32_message = DownCastMessage<ForeignMessage>(
message_int32_message->GetReflection()->MutableMessage(
message_int32_message, fd_map_int32_foreign_message_value));
value_int32_message->set_c(Func(key_int32_message, -6));
Expand Down Expand Up @@ -1808,7 +1808,7 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message =
DownCastToGenerated<ForeignMessage>(
DownCastMessage<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
Expand Down Expand Up @@ -1849,7 +1849,7 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
message_int32_message.GetReflection()->GetInt32(
message_int32_message, fd_map_int32_foreign_message_key);
const ForeignMessage& value_int32_message =
DownCastToGenerated<ForeignMessage>(
DownCastMessage<ForeignMessage>(
message_int32_message.GetReflection()->GetMessage(
message_int32_message, fd_map_int32_foreign_message_value));
EXPECT_EQ(value_int32_message.c(), Func(key_int32_message, 6));
Expand Down Expand Up @@ -1966,8 +1966,8 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
const Message& message = *it;
int32_t key = message.GetReflection()->GetInt32(
message, fd_map_int32_foreign_message_key);
const ForeignMessage& sub_message = DownCastToGenerated<ForeignMessage>(
message.GetReflection()->GetMessage(
const ForeignMessage& sub_message =
DownCastMessage<ForeignMessage>(message.GetReflection()->GetMessage(
message, fd_map_int32_foreign_message_value));
result[key].MergeFrom(sub_message);
++index;
Expand Down Expand Up @@ -2120,29 +2120,29 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
{
const Message& message0a =
mmf_int32_foreign_message.Get(0, entry_int32_foreign_message.get());
const ForeignMessage& sub_message0a = DownCastToGenerated<ForeignMessage>(
message0a.GetReflection()->GetMessage(
const ForeignMessage& sub_message0a =
DownCastMessage<ForeignMessage>(message0a.GetReflection()->GetMessage(
message0a, fd_map_int32_foreign_message_value));
int32_t int32_value0a = sub_message0a.c();
const Message& message9a =
mmf_int32_foreign_message.Get(9, entry_int32_foreign_message.get());
const ForeignMessage& sub_message9a = DownCastToGenerated<ForeignMessage>(
message9a.GetReflection()->GetMessage(
const ForeignMessage& sub_message9a =
DownCastMessage<ForeignMessage>(message9a.GetReflection()->GetMessage(
message9a, fd_map_int32_foreign_message_value));
int32_t int32_value9a = sub_message9a.c();

mmf_int32_foreign_message.SwapElements(0, 9);

const Message& message0b =
mmf_int32_foreign_message.Get(0, entry_int32_foreign_message.get());
const ForeignMessage& sub_message0b = DownCastToGenerated<ForeignMessage>(
message0b.GetReflection()->GetMessage(
const ForeignMessage& sub_message0b =
DownCastMessage<ForeignMessage>(message0b.GetReflection()->GetMessage(
message0b, fd_map_int32_foreign_message_value));
int32_t int32_value0b = sub_message0b.c();
const Message& message9b =
mmf_int32_foreign_message.Get(9, entry_int32_foreign_message.get());
const ForeignMessage& sub_message9b = DownCastToGenerated<ForeignMessage>(
message9b.GetReflection()->GetMessage(
const ForeignMessage& sub_message9b =
DownCastMessage<ForeignMessage>(message9b.GetReflection()->GetMessage(
message9b, fd_map_int32_foreign_message_value));
int32_t int32_value9b = sub_message9b.c();

Expand Down
Loading

0 comments on commit 18da465

Please sign in to comment.