diff --git a/DEPS.bzl b/DEPS.bzl index 2a7a047cd108..ab45420a1e4e 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -9222,6 +9222,16 @@ def go_deps(): "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/zeebo/assert/com_github_zeebo_assert-v1.3.0.zip", ], ) + go_repository( + name = "com_github_zeebo_errs", + build_file_proto_mode = "disable_global", + importpath = "github.com/zeebo/errs", + sha256 = "d2fa293e275c21bfb413e2968d79036931a55f503d8b62381563ed189b523cd2", + strip_prefix = "github.com/zeebo/errs@v1.2.2", + urls = [ + "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/zeebo/errs/com_github_zeebo_errs-v1.2.2.zip", + ], + ) go_repository( name = "com_github_zeebo_xxh3", build_file_proto_mode = "disable_global", @@ -11355,6 +11365,16 @@ def go_deps(): "https://storage.googleapis.com/cockroach-godeps/gomod/rsc.io/sampler/io_rsc_sampler-v1.3.0.zip", ], ) + go_repository( + name = "io_storj_drpc", + build_file_proto_mode = "disable_global", + importpath = "storj.io/drpc", + sha256 = "e297ccead2763d354959a3c04b0c9c27c9c84c99d129f216ec07da663ee0091a", + strip_prefix = "storj.io/drpc@v0.0.34", + urls = [ + "https://storage.googleapis.com/cockroach-godeps/gomod/storj.io/drpc/io_storj_drpc-v0.0.34.zip", + ], + ) go_repository( name = "io_vitess_vitess", build_file_proto_mode = "disable_global", diff --git a/build/bazelutil/distdir_files.bzl b/build/bazelutil/distdir_files.bzl index 0e8cbb7b9221..5658ad5581d1 100644 --- a/build/bazelutil/distdir_files.bzl +++ b/build/bazelutil/distdir_files.bzl @@ -1052,6 +1052,7 @@ DISTDIR_FILES = { "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/z-division/go-zookeeper/com_github_z_division_go_zookeeper-v0.0.0-20190128072838-6d7457066b9b.zip": "b0a67a3bb3cfbb1be18618b84b02588979795966e040f18c5bb4be036888cabd", "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/zabawaba99/go-gitignore/com_github_zabawaba99_go_gitignore-v0.0.0-20200117185801-39e6bddfb292.zip": 
"6c837b93e1c73e53123941c8e866de1deae6b645cc49a7d30d493c146178f8e8", "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/zeebo/assert/com_github_zeebo_assert-v1.3.0.zip": "1f01421d74ff37cb8247988155be9e6877d336029bcd887a1d035fd32d7ab6ae", + "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/zeebo/errs/com_github_zeebo_errs-v1.2.2.zip": "d2fa293e275c21bfb413e2968d79036931a55f503d8b62381563ed189b523cd2", "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/zeebo/xxh3/com_github_zeebo_xxh3-v1.0.2.zip": "190e5ef1f672e9321a1580bdd31c6440fde6044ca8168d2b489cf50cdc4f58a6", "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/zenazn/goji/com_github_zenazn_goji-v0.9.0.zip": "0807a255d9d715d18427a6eedd8e4f5a22670b09e5f45fddd229c1ae38da25a9", "https://storage.googleapis.com/cockroach-godeps/gomod/gitlab.com/golang-commonmark/html/com_gitlab_golang_commonmark_html-v0.0.0-20191124015941-a22733972181.zip": "f2ba8985dc9d6be347a17d9200a0be0cee5ab3bce4dc601c0651a77ef2bbffc3", @@ -1190,6 +1191,7 @@ DISTDIR_FILES = { "https://storage.googleapis.com/cockroach-godeps/gomod/sigs.k8s.io/structured-merge-diff/v4/io_k8s_sigs_structured_merge_diff_v4-v4.1.2.zip": "b32af97dadd79179a8f62aaf4ef1e0562e051be77053a60c7a4e724a5cbd00ce", "https://storage.googleapis.com/cockroach-godeps/gomod/sigs.k8s.io/yaml/io_k8s_sigs_yaml-v1.2.0.zip": "55ed08c5df448a033bf7e2c2912d4daa85b856a05c854b0c87ccc85c7f3fbfc7", "https://storage.googleapis.com/cockroach-godeps/gomod/sourcegraph.com/sourcegraph/appdash/com_sourcegraph_sourcegraph_appdash-v0.0.0-20190731080439-ebfcffb1b5c0.zip": "bd2492d9db05362c2fecd0b3d0f6002c89a6d90d678fb93b4158298ab883736f", + "https://storage.googleapis.com/cockroach-godeps/gomod/storj.io/drpc/io_storj_drpc-v0.0.34.zip": "e297ccead2763d354959a3c04b0c9c27c9c84c99d129f216ec07da663ee0091a", "https://storage.googleapis.com/public-bazel-artifacts/bazel/88ef31b429631b787ceb5e4556d773b20ad797c8.zip": 
"92a89a2bbe6c6db2a8b87da4ce723aff6253656e8417f37e50d362817c39b98b", "https://storage.googleapis.com/public-bazel-artifacts/bazel/bazel-gazelle-v0.39.1.tar.gz": "b760f7fe75173886007f7c2e616a21241208f3d90e8657dc65d36a771e916b6a", "https://storage.googleapis.com/public-bazel-artifacts/bazel/bazel-lib-v1.42.3.tar.gz": "d0529773764ac61184eb3ad3c687fb835df5bee01afedf07f0cf1a45515c96bc", diff --git a/go.mod b/go.mod index bcd5c7337f28..752b5beea53c 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 google.golang.org/grpc v1.56.3 google.golang.org/protobuf v1.35.1 + storj.io/drpc v0.0.34 ) // If any of the following dependencies get updated as a side-effect @@ -429,6 +430,7 @@ require ( github.com/twpayne/go-kml v1.5.2 // indirect github.com/urfave/cli/v2 v2.3.0 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect + github.com/zeebo/errs v1.2.2 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181 // indirect gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect diff --git a/go.sum b/go.sum index 0dc48916a769..bf117e5afb38 100644 --- a/go.sum +++ b/go.sum @@ -2363,6 +2363,8 @@ github.com/zabawaba99/go-gitignore v0.0.0-20200117185801-39e6bddfb292 h1:vpcCVk+ github.com/zabawaba99/go-gitignore v0.0.0-20200117185801-39e6bddfb292/go.mod h1:qcqv8IHwbR0JmjY1LZy4PeytlwxDPn1vUkjX7Wq0VaY= github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/errs v1.2.2 h1:5NFypMTuSdoySVTqlNs1dEoU21QVamMQJxW/Fii5O7g= +github.com/zeebo/errs v1.2.2/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= github.com/zenazn/goji v0.9.0/go.mod 
h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= @@ -3323,3 +3325,5 @@ sigs.k8s.io/structured-merge-diff/v4 v4.1.2/go.mod h1:j/nl6xW8vLS49O8YvXW1ocPhZa sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= +storj.io/drpc v0.0.34 h1:q9zlQKfJ5A7x8NQNFk8x7eKUF78FMhmAbZLnFK+og7I= +storj.io/drpc v0.0.34/go.mod h1:Y9LZaa8esL1PW2IDMqJE7CFSNq7d5bQ3RI7mGPtmKMg= diff --git a/pkg/kv/kvpb/BUILD.bazel b/pkg/kv/kvpb/BUILD.bazel index 9f4d3bb90f64..c1da83856de8 100644 --- a/pkg/kv/kvpb/BUILD.bazel +++ b/pkg/kv/kvpb/BUILD.bazel @@ -8,17 +8,21 @@ load(":gen.bzl", "batch_gen") go_library( name = "kvpb", srcs = [ + ":gen-batch-generated", # keep + ":gen-errordetailtype-stringer", # keep + ":gen-method-stringer", # keep "ambiguous_result_error.go", "api.go", + # DRPC protobuf file (api_drpc.pb.go) is currently generated manually. + # TODO (chandrat): Remove this once DRPC protobuf generation is + # integrated into the build process. 
+ "api_drpc_hacky.go", "batch.go", "data.go", "errors.go", "method.go", "node_decommissioned_error.go", "replica_unavailable_error.go", - ":gen-batch-generated", # keep - ":gen-errordetailtype-stringer", # keep - ":gen-method-stringer", # keep ], embed = [":kvpb_go_proto"], importpath = "github.com/cockroachdb/cockroach/pkg/kv/kvpb", @@ -46,6 +50,8 @@ go_library( "@com_github_gogo_protobuf//types", "@com_github_gogo_status//:status", "@com_github_golang_mock//gomock", # keep + "@io_storj_drpc//:drpc", + "@io_storj_drpc//drpcerr", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//metadata", # keep ], diff --git a/pkg/kv/kvpb/api_drpc_hacky.go b/pkg/kv/kvpb/api_drpc_hacky.go new file mode 100644 index 000000000000..1dc6b640ee39 --- /dev/null +++ b/pkg/kv/kvpb/api_drpc_hacky.go @@ -0,0 +1,207 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +// This file was manually generated with the DRPC protogen plugin using a dummy +// `api.proto` that includes a subset of relevant service methods. +// +// For instance, to generate this file, following proto file was used: +// +// -- api.proto -- begin -- +// syntax = "proto3"; +// package cockroach.kv.kvpb; +// option go_package = "github.com/cockroachdb/cockroach/pkg/kv/kvpb"; +// service Batch { +// rpc Batch (BatchRequest) returns (BatchResponse) {} +// rpc BatchStream (stream BatchRequest) returns (stream BatchResponse) {} +// } +// message BatchRequest{} +// message BatchResponse{} +// -- api.proto -- end -- +// +// NB: The use of empty BatchRequest and BatchResponse messages is a deliberate +// decision to avoid dependencies. +// +// +// To generate this file using DRPC protogen plugin from the dummy `api.proto` +// defined above, use the following command: +// +// ``` +// protoc --gogo_out=paths=source_relative:. \ +// --go-drpc_out=paths=source_relative,protolib=github.com/gogo/protobuf:. 
\ +// api.proto +// ``` +// +// NB: Make sure you have `protoc` installed and `protoc-gen-gogoroach` is +// built from $COCKROACH_SRC/pkg/cmd/protoc-gen-gogoroach. +// +// This code-gen should be automated as part of productionizing drpc. + +package kvpb + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/cockroachdb/errors" + "storj.io/drpc" + "storj.io/drpc/drpcerr" +) + +type drpcEncoding_File_api_proto struct{} + +func (drpcEncoding_File_api_proto) Marshal(msg drpc.Message) ([]byte, error) { + return protoutil.Marshal(msg.(protoutil.Message)) +} + +func (drpcEncoding_File_api_proto) Unmarshal(buf []byte, msg drpc.Message) error { + return protoutil.Unmarshal(buf, msg.(protoutil.Message)) +} + +type DRPCBatchClient interface { + DRPCConn() drpc.Conn + + Batch(ctx context.Context, in *BatchRequest) (*BatchResponse, error) + BatchStream(ctx context.Context) (DRPCBatch_BatchStreamClient, error) +} + +type drpcBatchClient struct { + cc drpc.Conn +} + +func NewDRPCBatchClient(cc drpc.Conn) DRPCBatchClient { + return &drpcBatchClient{cc} +} + +func (c *drpcBatchClient) DRPCConn() drpc.Conn { return c.cc } + +func (c *drpcBatchClient) Batch(ctx context.Context, in *BatchRequest) (*BatchResponse, error) { + out := new(BatchResponse) + err := c.cc.Invoke(ctx, "/cockroach.kv.kvpb.Batch/Batch", drpcEncoding_File_api_proto{}, in, out) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *drpcBatchClient) BatchStream(ctx context.Context) (DRPCBatch_BatchStreamClient, error) { + stream, err := c.cc.NewStream(ctx, "/cockroach.kv.kvpb.Batch/BatchStream", drpcEncoding_File_api_proto{}) + if err != nil { + return nil, err + } + x := &drpcBatch_BatchStreamClient{stream} + return x, nil +} + +type DRPCBatch_BatchStreamClient interface { + drpc.Stream + Send(*BatchRequest) error + Recv() (*BatchResponse, error) +} + +type drpcBatch_BatchStreamClient struct { + drpc.Stream +} + +func (x *drpcBatch_BatchStreamClient) 
GetStream() drpc.Stream { + return x.Stream +} + +func (x *drpcBatch_BatchStreamClient) Send(m *BatchRequest) error { + return x.MsgSend(m, drpcEncoding_File_api_proto{}) +} + +func (x *drpcBatch_BatchStreamClient) Recv() (*BatchResponse, error) { + m := new(BatchResponse) + if err := x.MsgRecv(m, drpcEncoding_File_api_proto{}); err != nil { + return nil, err + } + return m, nil +} + +func (x *drpcBatch_BatchStreamClient) RecvMsg(m *BatchResponse) error { + return x.MsgRecv(m, drpcEncoding_File_api_proto{}) +} + +type DRPCBatchServer interface { + Batch(context.Context, *BatchRequest) (*BatchResponse, error) + BatchStream(DRPCBatch_BatchStreamStream) error +} + +type DRPCBatchUnimplementedServer struct{} + +func (s *DRPCBatchUnimplementedServer) Batch( + context.Context, *BatchRequest, +) (*BatchResponse, error) { + return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + +func (s *DRPCBatchUnimplementedServer) BatchStream(DRPCBatch_BatchStreamStream) error { + return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + +type DRPCBatchDescription struct{} + +func (DRPCBatchDescription) NumMethods() int { return 2 } + +func (DRPCBatchDescription) Method( + n int, +) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) { + switch n { + case 0: + return "/cockroach.kv.kvpb.Batch/Batch", drpcEncoding_File_api_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return srv.(DRPCBatchServer). + Batch( + ctx, + in1.(*BatchRequest), + ) + }, DRPCBatchServer.Batch, true + case 1: + return "/cockroach.kv.kvpb.Batch/BatchStream", drpcEncoding_File_api_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return nil, srv.(DRPCBatchServer). 
+ BatchStream( + &drpcBatch_BatchStreamStream{in1.(drpc.Stream)}, + ) + }, DRPCBatchServer.BatchStream, true + default: + return "", nil, nil, nil, false + } +} + +func DRPCRegisterBatch(mux drpc.Mux, impl DRPCBatchServer) error { + return mux.Register(impl, DRPCBatchDescription{}) +} + +type DRPCBatch_BatchStream interface { + drpc.Stream + SendAndClose(*BatchResponse) error +} + +type DRPCBatch_BatchStreamStream interface { + drpc.Stream + Send(*BatchResponse) error + Recv() (*BatchRequest, error) +} + +type drpcBatch_BatchStreamStream struct { + drpc.Stream +} + +func (x *drpcBatch_BatchStreamStream) Send(m *BatchResponse) error { + return x.MsgSend(m, drpcEncoding_File_api_proto{}) +} + +func (x *drpcBatch_BatchStreamStream) Recv() (*BatchRequest, error) { + m := new(BatchRequest) + if err := x.MsgRecv(m, drpcEncoding_File_api_proto{}); err != nil { + return nil, err + } + return m, nil +} + +func (x *drpcBatch_BatchStreamStream) RecvMsg(m *BatchRequest) error { + return x.MsgRecv(m, drpcEncoding_File_api_proto{}) +} diff --git a/pkg/kv/kvserver/loqrecovery/server.go b/pkg/kv/kvserver/loqrecovery/server.go index 9898724a3943..cf8adfa23a4c 100644 --- a/pkg/kv/kvserver/loqrecovery/server.go +++ b/pkg/kv/kvserver/loqrecovery/server.go @@ -750,7 +750,7 @@ func visitNodeWithRetry( // Note that we use ConnectNoBreaker here to avoid any race with probe // running on current node and target node restarting. Errors from circuit // breaker probes could confuse us and present node as unavailable. - conn, err = rpcCtx.GRPCDialNode(addr.String(), node.NodeID, node.Locality, rpc.DefaultClass).ConnectNoBreaker(ctx) + conn, _, err = rpcCtx.GRPCDialNode(addr.String(), node.NodeID, node.Locality, rpc.DefaultClass).ConnectNoBreaker(ctx) // Nodes would contain dead nodes that we don't need to visit. We can skip // them and let caller handle incomplete info. 
if err != nil { @@ -803,7 +803,7 @@ func makeVisitNode(g *gossip.Gossip, loc roachpb.Locality, rpcCtx *rpc.Context) // Note that we use ConnectNoBreaker here to avoid any race with probe // running on current node and target node restarting. Errors from circuit // breaker probes could confuse us and present node as unavailable. - conn, err = rpcCtx.GRPCDialNode(addr.String(), node.NodeID, node.Locality, rpc.DefaultClass).ConnectNoBreaker(ctx) + conn, _, err = rpcCtx.GRPCDialNode(addr.String(), node.NodeID, node.Locality, rpc.DefaultClass).ConnectNoBreaker(ctx) if err != nil { if grpcutil.IsClosedConnection(err) { log.Infof(ctx, "can't dial node n%d because connection is permanently closed: %s", diff --git a/pkg/rpc/BUILD.bazel b/pkg/rpc/BUILD.bazel index 567eb2ab197b..3f3ca878487e 100644 --- a/pkg/rpc/BUILD.bazel +++ b/pkg/rpc/BUILD.bazel @@ -15,11 +15,13 @@ go_library( "connection_class.go", "context.go", "context_testutils.go", + "drpc.go", "errors.go", "heartbeat.go", "keepalive.go", "metrics.go", "peer.go", + "peer_drpc.go", "peer_map.go", "restricted_internal_client.go", "settings.go", @@ -45,6 +47,7 @@ go_library( "//pkg/settings/cluster", "//pkg/ts/tspb", "//pkg/util", + "//pkg/util/buildutil", "//pkg/util/circuit", "//pkg/util/envutil", "//pkg/util/growstack", @@ -72,6 +75,15 @@ go_library( "@com_github_montanaflynn_stats//:stats", "@com_github_vividcortex_ewma//:ewma", "@io_opentelemetry_go_otel//attribute", + "@io_storj_drpc//:drpc", + "@io_storj_drpc//drpcconn", + "@io_storj_drpc//drpcmanager", + "@io_storj_drpc//drpcmigrate", + "@io_storj_drpc//drpcmux", + "@io_storj_drpc//drpcpool", + "@io_storj_drpc//drpcserver", + "@io_storj_drpc//drpcstream", + "@io_storj_drpc//drpcwire", "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//backoff", "@org_golang_google_grpc//codes", diff --git a/pkg/rpc/connection.go b/pkg/rpc/connection.go index 2adfd4560370..232df445aa80 100644 --- a/pkg/rpc/connection.go +++ b/pkg/rpc/connection.go @@ -11,6 +11,7 @@ 
import ( "github.com/cockroachdb/cockroach/pkg/util/circuit" "github.com/cockroachdb/errors" "google.golang.org/grpc" + "storj.io/drpc/drpcpool" ) // Connection is a wrapper around grpc.ClientConn. It prevents the underlying @@ -39,7 +40,8 @@ type Connection struct { // RPCs. // // The pool is only initialized once the ClientConn is resolved. - batchStreamPool BatchStreamPool + batchStreamPool BatchStreamPool + drpcBatchStreamPool DRPCBatchStreamPool } // newConnectionToNodeID makes a Connection for the given node, class, and nontrivial Signal @@ -53,7 +55,8 @@ func newConnectionToNodeID( connFuture: connFuture{ ready: make(chan struct{}), }, - batchStreamPool: makeStreamPool(opts.Stopper, newBatchStream), + batchStreamPool: makeStreamPool(opts.Stopper, newBatchStream), + drpcBatchStreamPool: makeStreamPool(opts.Stopper, newDRPCBatchStream), } return c } @@ -65,14 +68,14 @@ func newConnectionToNodeID( // block but fall back to defErr in this case. func (c *Connection) waitOrDefault( ctx context.Context, defErr error, sig circuit.Signal, -) (*grpc.ClientConn, error) { +) (*grpc.ClientConn, drpcpool.Conn, error) { // Check the circuit breaker first. If it is already tripped now, we // want it to take precedence over connFuture below (which is closed in // the common case of a connection going bad after having been healthy // for a while). 
select { case <-sig.C(): - return nil, sig.Err() + return nil, nil, sig.Err() default: } @@ -83,26 +86,26 @@ func (c *Connection) waitOrDefault( select { case <-c.connFuture.C(): case <-sig.C(): - return nil, sig.Err() + return nil, nil, sig.Err() case <-ctx.Done(): - return nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr) + return nil, nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr) } } else { select { case <-c.connFuture.C(): case <-sig.C(): - return nil, sig.Err() + return nil, nil, sig.Err() case <-ctx.Done(): - return nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr) + return nil, nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr) default: - return nil, defErr + return nil, nil, defErr } } // Done waiting, c.connFuture has resolved, return the result. Note that this // conn could be unhealthy (or there may not even be a conn, i.e. Err() != // nil), if that's what the caller wanted (ConnectNoBreaker). - return c.connFuture.Conn(), c.connFuture.Err() + return c.connFuture.Conn(), c.connFuture.DRPCConn(), c.connFuture.Err() } // Connect returns the underlying grpc.ClientConn after it has been validated, @@ -112,6 +115,13 @@ func (c *Connection) waitOrDefault( // an error. In rare cases, this behavior is undesired and ConnectNoBreaker may // be used instead. func (c *Connection) Connect(ctx context.Context) (*grpc.ClientConn, error) { + cc, _, err := c.waitOrDefault(ctx, nil /* defErr */, c.breakerSignalFn()) + return cc, err +} + +// ConnectEx is similar to Connect but it addition to gRPC connection, it also +// returns underlying drpc connection after it has been validated. 
+func (c *Connection) ConnectEx(ctx context.Context) (*grpc.ClientConn, drpcpool.Conn, error) { return c.waitOrDefault(ctx, nil /* defErr */, c.breakerSignalFn()) } @@ -133,7 +143,9 @@ func (s *neverTripSignal) IsTripped() bool { // that it will latch onto (or start) an existing connection attempt even if // previous attempts have not succeeded. This may be preferable to Connect // if the caller is already certain that a peer is available. -func (c *Connection) ConnectNoBreaker(ctx context.Context) (*grpc.ClientConn, error) { +func (c *Connection) ConnectNoBreaker( + ctx context.Context, +) (*grpc.ClientConn, drpcpool.Conn, error) { // For ConnectNoBreaker we don't use the default Signal but pass a dummy one // that never trips. (The probe tears down the Conn on quiesce so we don't rely // on the Signal for that). @@ -157,7 +169,7 @@ func (c *Connection) ConnectNoBreaker(ctx context.Context) (*grpc.ClientConn, er // latest heartbeat. Returns ErrNotHeartbeated if the peer was just contacted for // the first time and the first heartbeat has not occurred yet. func (c *Connection) Health() error { - _, err := c.waitOrDefault(context.Background(), ErrNotHeartbeated, c.breakerSignalFn()) + _, _, err := c.waitOrDefault(context.Background(), ErrNotHeartbeated, c.breakerSignalFn()) return err } @@ -172,9 +184,17 @@ func (c *Connection) BatchStreamPool() *BatchStreamPool { return &c.batchStreamPool } +func (c *Connection) DRPCBatchStreamPool() *DRPCBatchStreamPool { + if !c.connFuture.Resolved() { + panic("DRPCBatchStreamPool called on unresolved connection") + } + return &c.drpcBatchStreamPool +} + type connFuture struct { ready chan struct{} cc *grpc.ClientConn + dc drpcpool.Conn err error } @@ -201,6 +221,14 @@ func (s *connFuture) Conn() *grpc.ClientConn { return s.cc } +// DRPCConn must only be called after C() has been closed. 
+func (s *connFuture) DRPCConn() drpcpool.Conn { + if s.err != nil { + return nil + } + return s.dc +} + func (s *connFuture) Resolved() bool { select { case <-s.ready: @@ -212,12 +240,12 @@ func (s *connFuture) Resolved() bool { // Resolve is idempotent. Only the first call has any effect. // Not thread safe. -func (s *connFuture) Resolve(cc *grpc.ClientConn, err error) { +func (s *connFuture) Resolve(cc *grpc.ClientConn, dc drpcpool.Conn, err error) { select { case <-s.ready: // Already resolved, noop. default: - s.cc, s.err = cc, err + s.cc, s.dc, s.err = cc, dc, err close(s.ready) } } diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index d54574e1d3c5..8810686384e4 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -55,7 +55,7 @@ import ( // either expects incoming connections from KV nodes, or from tenant SQL // servers. func NewServer(ctx context.Context, rpcCtx *Context, opts ...ServerOption) (*grpc.Server, error) { - srv, _ /* interceptors */, err := NewServerEx(ctx, rpcCtx, opts...) + srv, _, _, err := NewServerEx(ctx, rpcCtx, opts...) return srv, err } @@ -83,7 +83,7 @@ type ClientInterceptorInfo struct { // internalClientAdapter does). func NewServerEx( ctx context.Context, rpcCtx *Context, opts ...ServerOption, -) (s *grpc.Server, sii ServerInterceptorInfo, err error) { +) (s *grpc.Server, d *DRPCServer, sii ServerInterceptorInfo, err error) { var o serverOpts for _, f := range opts { f(&o) @@ -112,7 +112,7 @@ func NewServerEx( if !rpcCtx.ContextOptions.Insecure { tlsConfig, err := rpcCtx.GetServerTLSConfig() if err != nil { - return nil, sii, err + return nil, nil, sii, err } grpcOpts = append(grpcOpts, grpc.Creds(credentials.NewTLS(tlsConfig))) } @@ -184,8 +184,13 @@ func NewServerEx( grpcOpts = append(grpcOpts, grpc.ChainStreamInterceptor(streamInterceptor...)) s = grpc.NewServer(grpcOpts...) 
+ d, err = newDRPCServer(ctx, rpcCtx) + if err != nil { + return nil, nil, ServerInterceptorInfo{}, err + } RegisterHeartbeatServer(s, rpcCtx.NewHeartbeatService()) - return s, ServerInterceptorInfo{ + + return s, d, ServerInterceptorInfo{ UnaryInterceptors: unaryInterceptor, StreamInterceptors: streamInterceptor, }, nil diff --git a/pkg/rpc/context_test.go b/pkg/rpc/context_test.go index d81a8ff5f6d8..7a01440d8f20 100644 --- a/pkg/rpc/context_test.go +++ b/pkg/rpc/context_test.go @@ -434,7 +434,7 @@ func TestInternalClientAdapterRunsInterceptors(t *testing.T) { serverCtx.AdvertiseAddr = "127.0.0.1:8888" serverCtx.NodeID.Set(context.Background(), 1) - _ /* server */, serverInterceptors, err := NewServerEx(ctx, serverCtx) + _ /* gRPC server */, _ /* drpc server */, serverInterceptors, err := NewServerEx(ctx, serverCtx) require.NoError(t, err) // Pile on one more interceptor to make sure it's called. @@ -535,7 +535,7 @@ func TestInternalClientAdapterWithClientStreamInterceptors(t *testing.T) { serverCtx.AdvertiseAddr = "127.0.0.1:8888" serverCtx.NodeID.Set(context.Background(), 1) - _ /* server */, serverInterceptors, err := NewServerEx(ctx, serverCtx) + _ /* gRPC server */, _ /* drpc server */, serverInterceptors, err := NewServerEx(ctx, serverCtx) require.NoError(t, err) var clientInterceptors ClientInterceptorInfo var s *testClientStream @@ -598,7 +598,7 @@ func TestInternalClientAdapterWithServerStreamInterceptors(t *testing.T) { serverCtx.AdvertiseAddr = "127.0.0.1:8888" serverCtx.NodeID.Set(context.Background(), 1) - _ /* server */, serverInterceptors, err := NewServerEx(ctx, serverCtx) + _ /* gRPC server */, _ /* drpc server */, serverInterceptors, err := NewServerEx(ctx, serverCtx) require.NoError(t, err) const int1Name = "interceptor 1" @@ -736,7 +736,7 @@ func BenchmarkInternalClientAdapter(b *testing.B) { serverCtx.AdvertiseAddr = "127.0.0.1:8888" serverCtx.NodeID.Set(context.Background(), 1) - _, interceptors, err := NewServerEx(ctx, serverCtx) + _ /* 
gRPC server */, _ /* drpc server */, interceptors, err := NewServerEx(ctx, serverCtx) require.NoError(b, err) internal := &internalServer{} diff --git a/pkg/rpc/drpc.go b/pkg/rpc/drpc.go new file mode 100644 index 000000000000..c67f1d6642b5 --- /dev/null +++ b/pkg/rpc/drpc.go @@ -0,0 +1,170 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package rpc + +import ( + "context" + "crypto/tls" + "math" + "net" + + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" + "storj.io/drpc" + "storj.io/drpc/drpcconn" + "storj.io/drpc/drpcmanager" + "storj.io/drpc/drpcmigrate" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcpool" + "storj.io/drpc/drpcserver" + "storj.io/drpc/drpcstream" + "storj.io/drpc/drpcwire" +) + +// ErrDRPCDisabled is returned from hosts that in principle could but do not +// have the DRPC server enabled. +var ErrDRPCDisabled = errors.New("DRPC is not enabled") + +type drpcServerI interface { + Serve(ctx context.Context, lis net.Listener) error +} + +type drpcMuxI interface { + Register(srv interface{}, desc drpc.Description) error +} + +type DRPCServer struct { + Srv drpcServerI + Mux drpcMuxI + TLSCfg *tls.Config +} + +var _ drpcServerI = (*drpcserver.Server)(nil) +var _ drpcServerI = (*drpcOffServer)(nil) + +func newDRPCServer(_ context.Context, rpcCtx *Context) (*DRPCServer, error) { + var dmux drpcMuxI = &drpcOffServer{} + var dsrv drpcServerI = &drpcOffServer{} + var tlsCfg *tls.Config + + if ExperimentalDRPCEnabled.Get(&rpcCtx.Settings.SV) { + mux := drpcmux.New() + dsrv = drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + log.Warningf(context.Background(), "drpc server error %v", err) + }, + // The reader's max buffer size defaults to 4mb, and if it is exceeded (such + // as happens with AddSSTable) the RPCs fail. 
+ Manager: drpcmanager.Options{Reader: drpcwire.ReaderOptions{MaximumBufferSize: math.MaxInt}}, + }) + dmux = mux + + var err error + tlsCfg, err = rpcCtx.GetServerTLSConfig() + if err != nil { + return nil, err + } + + // NB: any server middleware (server interceptors in gRPC parlance) would go + // here: + // dmux = whateverMiddleware1(dmux) + // dmux = whateverMiddleware2(dmux) + // ... + // + // Each middleware must implement the Handler interface: + // + // HandleRPC(stream Stream, rpc string) error + // + // where Stream + // See here for an example: + // https://github.com/bryk-io/pkg/blob/4da5fbfef47770be376e4022eab5c6c324984bf7/net/drpc/server.go#L91-L101 + } + + return &DRPCServer{ + Srv: dsrv, + Mux: dmux, + TLSCfg: tlsCfg, + }, nil +} + +func dialDRPC(rpcCtx *Context) func(ctx context.Context, target string) (drpcpool.Conn, error) { + return func(ctx context.Context, target string) (drpcpool.Conn, error) { + // TODO(server): could use connection class instead of empty key here. + pool := drpcpool.New[struct{}, drpcpool.Conn](drpcpool.Options{}) + pooledConn := pool.Get(ctx /* unused */, struct{}{}, func(ctx context.Context, + _ struct{}) (drpcpool.Conn, error) { + + netConn, err := drpcmigrate.DialWithHeader(ctx, "tcp", target, drpcmigrate.DRPCHeader) + if err != nil { + return nil, err + } + + opts := drpcconn.Options{ + Manager: drpcmanager.Options{ + Reader: drpcwire.ReaderOptions{ + MaximumBufferSize: math.MaxInt, + }, + Stream: drpcstream.Options{ + MaximumBufferSize: 0, // unlimited + }, + }, + } + var conn *drpcconn.Conn + if rpcCtx.ContextOptions.Insecure { + conn = drpcconn.NewWithOptions(netConn, opts) + } else { + tlsConfig, err := rpcCtx.GetClientTLSConfig() + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + // TODO(server): remove this hack which is necessary at least in + // testing to get TestDRPCSelectQuery to pass. 
+ tlsConfig.InsecureSkipVerify = true + conn = drpcconn.NewWithOptions(tlsConn, opts) + } + + return conn, nil + }) + // `pooledConn.Close` doesn't tear down any of the underlying TCP + // connections but simply marks the pooledConn handle as returning + // errors. When we "close" this conn, we want to tear down all of + // the connections in the pool (in effect mirroring the behavior of + // gRPC where a single conn is shared). + return &closeEntirePoolConn{ + Conn: pooledConn, + pool: pool, + }, nil + } +} + +type closeEntirePoolConn struct { + drpcpool.Conn + pool *drpcpool.Pool[struct{}, drpcpool.Conn] +} + +func (c *closeEntirePoolConn) Close() error { + _ = c.Conn.Close() + return c.pool.Close() +} + +// drpcOffServer is used for drpcServerI and drpcMuxI if the DRPC server is +// disabled. It immediately closes accepted connections and returns +// ErrDRPCDisabled. +type drpcOffServer struct{} + +func (srv *drpcOffServer) Serve(_ context.Context, lis net.Listener) error { + conn, err := lis.Accept() + if err != nil { + return err + } + _ = conn.Close() + return ErrDRPCDisabled +} + +func (srv *drpcOffServer) Register(interface{}, drpc.Description) error { + return nil +} diff --git a/pkg/rpc/nodedialer/BUILD.bazel b/pkg/rpc/nodedialer/BUILD.bazel index b405795e4d32..76e785a8bab0 100644 --- a/pkg/rpc/nodedialer/BUILD.bazel +++ b/pkg/rpc/nodedialer/BUILD.bazel @@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "nodedialer", - srcs = ["nodedialer.go"], + srcs = [ + "nodedialer.go", + "nodedialer_drpc.go", + ], importpath = "github.com/cockroachdb/cockroach/pkg/rpc/nodedialer", visibility = ["//visibility:public"], deps = [ @@ -20,6 +23,7 @@ go_library( "//pkg/util/stop", "//pkg/util/tracing", "@com_github_cockroachdb_errors//:errors", + "@io_storj_drpc//drpcpool", "@org_golang_google_grpc//:grpc", ], ) diff --git a/pkg/rpc/nodedialer/nodedialer.go b/pkg/rpc/nodedialer/nodedialer.go index 4bebfd4f5dd4..3faff95ea406 
100644 --- a/pkg/rpc/nodedialer/nodedialer.go +++ b/pkg/rpc/nodedialer/nodedialer.go @@ -25,6 +25,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" "google.golang.org/grpc" + "storj.io/drpc/drpcpool" ) // An AddressResolver translates NodeIDs into addresses. @@ -100,7 +101,7 @@ func (n *Dialer) Dial( err = errors.Wrapf(err, "failed to resolve n%d", nodeID) return nil, err } - conn, _, err := n.dial(ctx, nodeID, addr, locality, true, class) + conn, _, _, _, err := n.dial(ctx, nodeID, addr, locality, true, class) return conn, err } @@ -117,7 +118,7 @@ func (n *Dialer) DialNoBreaker( if err != nil { return nil, err } - conn, _, err := n.dial(ctx, nodeID, addr, locality, false, class) + conn, _, _, _, err := n.dial(ctx, nodeID, addr, locality, false, class) return conn, err } @@ -147,14 +148,31 @@ func (n *Dialer) DialInternalClient( return nil, errors.Wrap(err, "resolver error") } log.VEventf(ctx, 2, "sending request to %s", addr) - conn, pool, err := n.dial(ctx, nodeID, addr, locality, true, class) + conn, pool, dconn, drpcBatchStreamPool, err := n.dial(ctx, nodeID, addr, locality, true, class) if err != nil { return nil, err } + client := newBaseInternalClient(conn) - if shouldUseBatchStreamPoolClient(ctx, n.rpcContext.Settings) { + useStreamPoolClient := shouldUseBatchStreamPoolClient(ctx, n.rpcContext.Settings) + if useStreamPoolClient { client = newBatchStreamPoolClient(pool) } + + if rpc.ExperimentalDRPCEnabled.Get(&n.rpcContext.Settings.SV) { + // TODO(server): gRPC version of batch stream pool implements + // rpc.RestrictedInternalClient and is allocation-optimized, + // whereas here we allocate a new throw-away + // unaryDRPCBatchServiceToInternalAdapter. 
+ client = &unaryDRPCBatchServiceToInternalAdapter{ + useStreamPoolClient: useStreamPoolClient, + RestrictedInternalClient: client, // for RangeFeed only + drpcClient: kvpb.NewDRPCBatchClient(dconn), + drpcStreamPool: drpcBatchStreamPool, + } + return client, nil + } + client = maybeWrapInTracingClient(ctx, client) return client, nil } @@ -169,28 +187,29 @@ func (n *Dialer) dial( locality roachpb.Locality, checkBreaker bool, class rpc.ConnectionClass, -) (_ *grpc.ClientConn, _ *rpc.BatchStreamPool, err error) { +) (*grpc.ClientConn, *rpc.BatchStreamPool, drpcpool.Conn, *rpc.DRPCBatchStreamPool, error) { const ctxWrapMsg = "dial" // Don't trip the breaker if we're already canceled. if ctxErr := ctx.Err(); ctxErr != nil { - return nil, nil, errors.Wrap(ctxErr, ctxWrapMsg) + return nil, nil, nil, nil, errors.Wrap(ctxErr, ctxWrapMsg) } rpcConn := n.rpcContext.GRPCDialNode(addr.String(), nodeID, locality, class) - connect := rpcConn.Connect + connect := rpcConn.ConnectEx if !checkBreaker { connect = rpcConn.ConnectNoBreaker } - conn, err := connect(ctx) + conn, dconn, err := connect(ctx) if err != nil { // If we were canceled during the dial, don't trip the breaker. if ctxErr := ctx.Err(); ctxErr != nil { - return nil, nil, errors.Wrap(ctxErr, ctxWrapMsg) + return nil, nil, nil, nil, errors.Wrap(ctxErr, ctxWrapMsg) } err = errors.Wrapf(err, "failed to connect to n%d at %v", nodeID, addr) - return nil, nil, err + return nil, nil, nil, nil, err } pool := rpcConn.BatchStreamPool() - return conn, pool, nil + drpcStreamPool := rpcConn.DRPCBatchStreamPool() + return conn, pool, dconn, drpcStreamPool, nil } // ConnHealth returns nil if we have an open connection of the request diff --git a/pkg/rpc/nodedialer/nodedialer_drpc.go b/pkg/rpc/nodedialer/nodedialer_drpc.go new file mode 100644 index 000000000000..d4ff92fdd1df --- /dev/null +++ b/pkg/rpc/nodedialer/nodedialer_drpc.go @@ -0,0 +1,35 @@ +// Copyright 2025 The Cockroach Authors. 
+// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package nodedialer + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/errors" + "google.golang.org/grpc" +) + +type unaryDRPCBatchServiceToInternalAdapter struct { + useStreamPoolClient bool + rpc.RestrictedInternalClient + drpcClient kvpb.DRPCBatchClient + drpcStreamPool *rpc.DRPCBatchStreamPool +} + +func (a *unaryDRPCBatchServiceToInternalAdapter) Batch( + ctx context.Context, in *kvpb.BatchRequest, opts ...grpc.CallOption, +) (*kvpb.BatchResponse, error) { + if len(opts) > 0 { + return nil, errors.New("CallOptions unsupported") + } + if a.useStreamPoolClient && a.drpcStreamPool != nil { + return a.drpcStreamPool.Send(ctx, in) + } + + return a.drpcClient.Batch(ctx, in) +} diff --git a/pkg/rpc/peer.go b/pkg/rpc/peer.go index 3cc0bb599168..05ca9e972be4 100644 --- a/pkg/rpc/peer.go +++ b/pkg/rpc/peer.go @@ -27,6 +27,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/status" + "storj.io/drpc/drpcpool" ) type peerStatus int @@ -125,6 +126,7 @@ type peer struct { heartbeatInterval time.Duration heartbeatTimeout time.Duration dial func(ctx context.Context, target string, class ConnectionClass) (*grpc.ClientConn, error) + dialDRPC func(ctx context.Context, target string) (drpcpool.Conn, error) // b maintains connection health. This breaker's async probe is always // active - it is the heartbeat loop and manages `mu.c.` (including // recreating it after the connection fails and has to be redialed). @@ -245,6 +247,7 @@ func (rpcCtx *Context) newPeer(k peerKey, locality roachpb.Locality) *peer { additionalDialOpts = append(additionalDialOpts, rpcCtx.testingDialOpts...) return rpcCtx.grpcDialRaw(ctx, target, class, additionalDialOpts...) 
}, + dialDRPC: dialDRPC(rpcCtx), heartbeatInterval: rpcCtx.RPCHeartbeatInterval, heartbeatTimeout: rpcCtx.RPCHeartbeatTimeout, } @@ -381,6 +384,13 @@ func (p *peer) runOnce(ctx context.Context, report func(error)) error { defer func() { _ = cc.Close() // nolint:grpcconnclose }() + dc, err := p.dialDRPC(ctx, p.k.TargetAddr) + if err != nil { + return err + } + defer func() { + _ = dc.Close() + }() // Set up notifications on a channel when gRPC tears down, so that we // can trigger another instant heartbeat for expedited circuit breaker @@ -399,7 +409,7 @@ func (p *peer) runOnce(ctx context.Context, report func(error)) error { return err } - p.onInitialHeartbeatSucceeded(ctx, p.opts.Clock.Now(), cc, report) + p.onInitialHeartbeatSucceeded(ctx, p.opts.Clock.Now(), cc, dc, report) return p.runHeartbeatUntilFailure(ctx, connFailedCh) } @@ -563,7 +573,7 @@ func logOnHealthy(ctx context.Context, disconnected, now time.Time) { } func (p *peer) onInitialHeartbeatSucceeded( - ctx context.Context, now time.Time, cc *grpc.ClientConn, report func(err error), + ctx context.Context, now time.Time, cc *grpc.ClientConn, dc drpcpool.Conn, report func(err error), ) { // First heartbeat succeeded. By convention we update the breaker // before updating the peer. The other way is fine too, just the @@ -586,10 +596,11 @@ func (p *peer) onInitialHeartbeatSucceeded( // ahead of signaling the connFuture, so that the stream pool is ready for use // by the time the connFuture is resolved. p.mu.c.batchStreamPool.Bind(ctx, cc) + p.mu.c.drpcBatchStreamPool.Bind(ctx, dc) // Close the channel last which is helpful for unit tests that // first waitOrDefault for a healthy conn to then check metrics. - p.mu.c.connFuture.Resolve(cc, nil /* err */) + p.mu.c.connFuture.Resolve(cc, dc, nil /* err */) logOnHealthy(ctx, p.mu.disconnected, now) } @@ -706,7 +717,7 @@ func (p *peer) onHeartbeatFailed( // someone might be waiting on it in ConnectNoBreaker who is not paying // attention to the circuit breaker. 
err = &netutil.InitialHeartbeatFailedError{WrappedErr: err} - ls.c.connFuture.Resolve(nil /* cc */, err) + ls.c.connFuture.Resolve(nil /* cc */, nil /* dc */, err) } // Close down the stream pool that was bound to this connection. @@ -746,7 +757,7 @@ func (p *peer) onQuiesce(report func(error)) { // NB: it's important that connFuture is resolved, or a caller sitting on // `c.ConnectNoBreaker` would never be unblocked; after all, the probe won't // start again in the future. - p.snap().c.connFuture.Resolve(nil, errQuiescing) + p.snap().c.connFuture.Resolve(nil, nil, errQuiescing) } func (p PeerSnap) deletable(now time.Time) bool { diff --git a/pkg/rpc/peer_drpc.go b/pkg/rpc/peer_drpc.go new file mode 100644 index 000000000000..e1248e1f329d --- /dev/null +++ b/pkg/rpc/peer_drpc.go @@ -0,0 +1,34 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package rpc + +import ( + "github.com/cockroachdb/cockroach/pkg/settings" + "github.com/cockroachdb/cockroach/pkg/util/buildutil" + "github.com/cockroachdb/cockroach/pkg/util/envutil" + "github.com/cockroachdb/errors" +) + +var envExperimentalDRPCEnabled = envutil.EnvOrDefaultBool("COCKROACH_EXPERIMENTAL_DRPC_ENABLED", false) + +// ExperimentalDRPCEnabled determines whether a drpc server accepting BatchRequest +// is enabled. This server is experimental and completely unsuitable to production +// usage (for example, does not implement authorization checks). +var ExperimentalDRPCEnabled = settings.RegisterBoolSetting( + settings.ApplicationLevel, + "rpc.experimental_drpc.enabled", + "if true, use drpc to execute Batch RPCs (instead of gRPC)", + envExperimentalDRPCEnabled, + settings.WithValidateBool(func(values *settings.Values, b bool) error { + // drpc support is highly experimental and should not be enabled in production. 
+ // Since authorization is not implemented, we only even host the server if the + // env var is set or it's a CRDB test build. Consequently, these are prereqs + // for setting the cluster setting. + if b && !(envExperimentalDRPCEnabled || buildutil.CrdbTestBuild) { + return errors.New("experimental drpc is not allowed in this environment") + } + return nil + })) diff --git a/pkg/rpc/stream_pool.go b/pkg/rpc/stream_pool.go index 20773d714f0a..b12ea472881b 100644 --- a/pkg/rpc/stream_pool.go +++ b/pkg/rpc/stream_pool.go @@ -17,6 +17,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/errors" "google.golang.org/grpc" + "storj.io/drpc" ) // streamClient is a type constraint that is satisfied by a bidirectional gRPC @@ -331,3 +332,12 @@ type BatchStreamClient = streamClient[*kvpb.BatchRequest, *kvpb.BatchResponse] func newBatchStream(ctx context.Context, cc *grpc.ClientConn) (BatchStreamClient, error) { return kvpb.NewInternalClient(cc).BatchStream(ctx) } + +type DRPCBatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse, drpc.Conn] + +type DRPCBatchStreamClient = streamClient[*kvpb.BatchRequest, *kvpb.BatchResponse] + +// newDRPCBatchStream constructs a BatchStreamClient from a drpc.Conn. 
+func newDRPCBatchStream(ctx context.Context, dc drpc.Conn) (DRPCBatchStreamClient, error) { + return kvpb.NewDRPCBatchClient(dc).BatchStream(ctx) +} diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index b31d88567417..2ed0eb3f88de 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -64,6 +64,7 @@ go_library( "span_download.go", "span_stats_server.go", "sql_stats.go", + "start_drpc_listener.go", "start_listen.go", "statement_details.go", "statement_diagnostics_requests.go", @@ -368,6 +369,7 @@ go_library( "@com_github_prometheus_client_model//go", "@com_github_prometheus_common//expfmt", "@in_gopkg_yaml_v2//:yaml_v2", + "@io_storj_drpc//drpcmigrate", "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//metadata", @@ -439,6 +441,7 @@ go_test( "critical_nodes_test.go", "distsql_flows_test.go", "drain_test.go", + "drpc_test.go", "get_local_files_test.go", "graphite_test.go", "grpc_gateway_test.go", @@ -460,6 +463,7 @@ go_test( "purge_auth_session_test.go", "server_controller_http_test.go", "server_controller_test.go", + "server_drpc_test.go", "server_http_test.go", "server_import_ts_test.go", "server_internal_executor_factory_test.go", @@ -600,6 +604,8 @@ go_test( "@com_github_stretchr_testify//require", "@in_gopkg_yaml_v2//:yaml_v2", "@io_opentelemetry_go_otel//attribute", + "@io_storj_drpc//drpcconn", + "@io_storj_drpc//drpcmigrate", "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//metadata", diff --git a/pkg/server/drpc_test.go b/pkg/server/drpc_test.go new file mode 100644 index 000000000000..7423e0f59fe5 --- /dev/null +++ b/pkg/server/drpc_test.go @@ -0,0 +1,81 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. 
+ +package server_test + +import ( + "context" + "crypto/tls" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" + "storj.io/drpc/drpcconn" + "storj.io/drpc/drpcmigrate" +) + +// TestDRPCBatchServer verifies that CRDB nodes can host a drpc server that +// serves BatchRequest. It doesn't verify that nodes use drpc to communicate with +// each other. +func TestDRPCBatchServer(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + const numNodes = 1 + + testutils.RunTrueAndFalse(t, "insecure", func(t *testing.T, insecure bool) { + args := base.TestClusterArgs{} + args.ServerArgs.Insecure = insecure + args.ReplicationMode = base.ReplicationManual + args.ServerArgs.Settings = cluster.MakeClusterSettings() + rpc.ExperimentalDRPCEnabled.Override(ctx, &args.ServerArgs.Settings.SV, true) + c := testcluster.StartTestCluster(t, numNodes, args) + defer c.Stopper().Stop(ctx) + + require.Equal(t, insecure, c.Server(0).RPCContext().Insecure) + + rpcAddr := c.Server(0).RPCAddr() + + // Dial the drpc server with the drpc connection header. 
+ rawconn, err := drpcmigrate.DialWithHeader(ctx, "tcp", rpcAddr, drpcmigrate.DRPCHeader) + require.NoError(t, err) + + var conn *drpcconn.Conn + if !insecure { + cm, err := c.Server(0).RPCContext().GetCertificateManager() + require.NoError(t, err) + tlsCfg, err := cm.GetNodeClientTLSConfig() + require.NoError(t, err) + tlsCfg = tlsCfg.Clone() + tlsCfg.ServerName = "*.local" + tlsConn := tls.Client(rawconn, tlsCfg) + conn = drpcconn.New(tlsConn) + } else { + conn = drpcconn.New(rawconn) + } + defer func() { require.NoError(t, conn.Close()) }() + + desc := c.LookupRangeOrFatal(t, c.ScratchRange(t)) + + client := kvpb.NewDRPCBatchClient(conn) + ba := &kvpb.BatchRequest{} + ba.RangeID = desc.RangeID + var ok bool + ba.Replica, ok = desc.GetReplicaDescriptor(1) + require.True(t, ok) + req := &kvpb.LeaseInfoRequest{} + req.Key = desc.StartKey.AsRawKey() + ba.Add(req) + _, err = client.Batch(ctx, ba) + require.NoError(t, err) + }) +} diff --git a/pkg/server/grpc_server.go b/pkg/server/grpc_server.go index 7fbc809af58a..862410d1af18 100644 --- a/pkg/server/grpc_server.go +++ b/pkg/server/grpc_server.go @@ -22,6 +22,7 @@ import ( // RPCs. 
type grpcServer struct { *grpc.Server + drpc *rpc.DRPCServer serverInterceptorsInfo rpc.ServerInterceptorInfo mode serveMode } @@ -29,7 +30,7 @@ type grpcServer struct { func newGRPCServer(ctx context.Context, rpcCtx *rpc.Context) (*grpcServer, error) { s := &grpcServer{} s.mode.set(modeInitializing) - srv, interceptorInfo, err := rpc.NewServerEx( + srv, dsrv, interceptorInfo, err := rpc.NewServerEx( ctx, rpcCtx, rpc.WithInterceptor(func(path string) error { return s.intercept(path) })) @@ -37,6 +38,7 @@ func newGRPCServer(ctx context.Context, rpcCtx *rpc.Context) (*grpcServer, error return nil, err } s.Server = srv + s.drpc = dsrv s.serverInterceptorsInfo = interceptorInfo return s, nil } diff --git a/pkg/server/node.go b/pkg/server/node.go index b319a9892a3a..3dad84c9670d 100644 --- a/pkg/server/node.go +++ b/pkg/server/node.go @@ -1875,6 +1875,18 @@ func (n *Node) Batch(ctx context.Context, args *kvpb.BatchRequest) (*kvpb.BatchR // BatchStream implements the kvpb.InternalServer interface. func (n *Node) BatchStream(stream kvpb.Internal_BatchStreamServer) error { + return n.batchStreamImpl(stream, func(ba *kvpb.BatchRequest) error { + return stream.RecvMsg(ba) + }) +} + +func (n *Node) batchStreamImpl( + stream interface { + Context() context.Context + Send(response *kvpb.BatchResponse) error + }, + recvMsg func(*kvpb.BatchRequest) error, +) error { ctx := stream.Context() for { argsAlloc := new(struct { @@ -1884,10 +1896,8 @@ func (n *Node) BatchStream(stream kvpb.Internal_BatchStreamServer) error { args := &argsAlloc.args args.Requests = argsAlloc.reqs[:0] - err := stream.RecvMsg(args) + err := recvMsg(args) if err != nil { - // From grpc.ServerStream.Recv: - // > It returns io.EOF when the client has performed a CloseSend. 
if errors.Is(err, io.EOF) { return nil } @@ -1905,6 +1915,26 @@ func (n *Node) BatchStream(stream kvpb.Internal_BatchStreamServer) error { } } +func (n *Node) AsDRPCBatchServer() kvpb.DRPCBatchServer { + return (*drpcNode)(n) +} + +type drpcNode Node + +func (n *drpcNode) Batch( + ctx context.Context, request *kvpb.BatchRequest, +) (*kvpb.BatchResponse, error) { + return (*Node)(n).Batch(ctx, request) +} + +func (n *drpcNode) BatchStream(stream kvpb.DRPCBatch_BatchStreamStream) error { + return (*Node)(n).batchStreamImpl(stream, func(ba *kvpb.BatchRequest) error { + return stream.(interface { + RecvMsg(request *kvpb.BatchRequest) error + }).RecvMsg(ba) + }) +} + // spanForRequest is the retval of setupSpanForIncomingRPC. It groups together a // few variables needed when finishing an RPC's span. // diff --git a/pkg/server/server.go b/pkg/server/server.go index c462243e4ebb..0666cb3c9fb5 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -957,6 +957,9 @@ func NewServer(cfg Config, stopper *stop.Stopper) (serverctl.ServerStartupInterf cfg.LicenseEnforcer, ) kvpb.RegisterInternalServer(grpcServer.Server, node) + if err := kvpb.DRPCRegisterBatch(grpcServer.drpc.Mux, node.AsDRPCBatchServer()); err != nil { + return nil, err + } kvserver.RegisterPerReplicaServer(grpcServer.Server, node.perReplicaServer) kvserver.RegisterPerStoreServer(grpcServer.Server, node.perReplicaServer) ctpb.RegisterSideTransportServer(grpcServer.Server, ctReceiver) diff --git a/pkg/server/server_drpc_test.go b/pkg/server/server_drpc_test.go new file mode 100644 index 000000000000..2873169a103a --- /dev/null +++ b/pkg/server/server_drpc_test.go @@ -0,0 +1,58 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. 
+ +package server + +import ( + "context" + "math/rand" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +func TestDRPCSelectQuery(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testutils.RunTrueAndFalse(t, "insecure", func(t *testing.T, insecure bool) { + ctx, cancel := context.WithTimeout(context.Background(), testutils.SucceedsSoonDuration()) + defer cancel() + + st := cluster.MakeTestingClusterSettings() + rpc.ExperimentalDRPCEnabled.Override(ctx, &st.SV, true) + + tc := serverutils.StartCluster(t, 3, base.TestClusterArgs{ + ServerArgs: base.TestServerArgs{ + Settings: st, + Insecure: insecure, + }, + }) + defer tc.Stopper().Stop(ctx) + + idx := rand.Intn(tc.NumServers()) + t.Logf("querying from node %d", idx+1) + db := tc.ServerConn(idx) + defer db.Close() + + rows, err := db.QueryContext(ctx, "SELECT count(*) FROM system.tenants") + require.NoError(t, err) + defer rows.Close() + + for rows.Next() { + var count int + require.NoError(t, rows.Scan(&count)) + require.Equal(t, 1, count) + } + require.NoError(t, rows.Err()) + }) +} diff --git a/pkg/server/start_drpc_listener.go b/pkg/server/start_drpc_listener.go new file mode 100644 index 000000000000..96bd999a100c --- /dev/null +++ b/pkg/server/start_drpc_listener.go @@ -0,0 +1,66 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. 
+ +package server + +import ( + "bytes" + "io" + "net" + + "storj.io/drpc/drpcmigrate" +) + +var drpcMatcher = func(reader io.Reader) bool { + buf := make([]byte, len(drpcmigrate.DRPCHeader)) + if _, err := io.ReadFull(reader, buf); err != nil { + return false + } + return bytes.Equal(buf, []byte(drpcmigrate.DRPCHeader)) +} + +type dropDRPCHeaderListener struct { + wrapped net.Listener +} + +func (ln *dropDRPCHeaderListener) Accept() (net.Conn, error) { + conn, err := ln.wrapped.Accept() + if err != nil { + return nil, err + } + buf := make([]byte, len(drpcmigrate.DRPCHeader)) + if _, err := io.ReadFull(conn, buf); err != nil { + return nil, err + } + return conn, nil +} + +func (ln *dropDRPCHeaderListener) Close() error { + return ln.wrapped.Close() +} + +func (ln *dropDRPCHeaderListener) Addr() net.Addr { + return ln.wrapped.Addr() +} + +type noopListener struct{ done chan struct{} } + +func (l *noopListener) Accept() (net.Conn, error) { + <-l.done + return nil, net.ErrClosed +} + +func (l *noopListener) Close() error { + if l.done == nil { + return nil + } + close(l.done) + l.done = nil + return nil +} + +func (l *noopListener) Addr() net.Addr { + return nil +} diff --git a/pkg/server/start_listen.go b/pkg/server/start_listen.go index 456cd7206ba7..64b9eec20bd1 100644 --- a/pkg/server/start_listen.go +++ b/pkg/server/start_listen.go @@ -7,6 +7,7 @@ package server import ( "context" + "crypto/tls" "io" "net" "sync" @@ -131,21 +132,42 @@ func startListenRPCAndSQL( } } - anyL := m.Match(cmux.Any()) + // Host drpc only if it's _possible_ to turn it on (this requires a test build + // or env var). If the setting _is_ on, then it was overridden in testing and + // we want to host the server too. + hostDRPC := rpc.ExperimentalDRPCEnabled.Validate(nil /* not used */, true) == nil || + rpc.ExperimentalDRPCEnabled.Get(&cfg.Settings.SV) + + // If we're not hosting drpc, make a listener that never accepts anything. 
+ // We will start the dRPC server all the same; it barely consumes any + // resources. + var drpcL net.Listener = &noopListener{make(chan struct{})} + if hostDRPC { + // Throw away the header before passing the conn to the drpc server. This + // would not be required explicitly if we used `drpcmigrate.ListenMux` but + // cmux keeps the prefix. + drpcL = &dropDRPCHeaderListener{wrapped: m.Match(drpcMatcher)} + } + + grpcL := m.Match(cmux.Any()) if serverTestKnobs, ok := cfg.TestingKnobs.Server.(*TestingKnobs); ok { if serverTestKnobs.ContextTestingKnobs.InjectedLatencyOracle != nil { - anyL = rpc.NewDelayingListener(anyL, serverTestKnobs.ContextTestingKnobs.InjectedLatencyEnabled) + grpcL = rpc.NewDelayingListener(grpcL, serverTestKnobs.ContextTestingKnobs.InjectedLatencyEnabled) + drpcL = rpc.NewDelayingListener(drpcL, serverTestKnobs.ContextTestingKnobs.InjectedLatencyEnabled) } } rpcLoopbackL := netutil.NewLoopbackListener(ctx, stopper) sqlLoopbackL := netutil.NewLoopbackListener(ctx, stopper) + drpcCtx, drpcCancel := context.WithCancel(workersCtx) // The remainder shutdown worker. waitForQuiesce := func(context.Context) { <-stopper.ShouldQuiesce() + drpcCancel() // TODO(bdarnell): Do we need to also close the other listeners? 
- netutil.FatalIfUnexpected(anyL.Close()) + netutil.FatalIfUnexpected(grpcL.Close()) + netutil.FatalIfUnexpected(drpcL.Close()) netutil.FatalIfUnexpected(rpcLoopbackL.Close()) netutil.FatalIfUnexpected(sqlLoopbackL.Close()) netutil.FatalIfUnexpected(ln.Close()) @@ -160,12 +182,14 @@ func startListenRPCAndSQL( netutil.FatalIfUnexpected(m.Serve()) }) } + stopper.AddCloser(stop.CloserFn(stopGRPC)) if err := stopper.RunAsyncTask( - workersCtx, "grpc-quiesce", waitForQuiesce, + workersCtx, "grpc-drpc-quiesce", waitForQuiesce, ); err != nil { waitForQuiesce(ctx) stopGRPC() + drpcCancel() return nil, nil, nil, nil, err } stopper.AddCloser(stop.CloserFn(stopGRPC)) @@ -177,7 +201,15 @@ func startListenRPCAndSQL( startRPCServer = func(ctx context.Context) { // Serve the gRPC endpoint. _ = stopper.RunAsyncTask(workersCtx, "serve-grpc", func(context.Context) { - netutil.FatalIfUnexpected(grpc.Serve(anyL)) + netutil.FatalIfUnexpected(grpc.Serve(grpcL)) + }) + _ = stopper.RunAsyncTask(drpcCtx, "serve-drpc", func(ctx context.Context) { + if cfg := grpc.drpc.TLSCfg; cfg != nil { + drpcTLSL := tls.NewListener(drpcL, cfg) + netutil.FatalIfUnexpected(grpc.drpc.Srv.Serve(ctx, drpcTLSL)) + } else { + netutil.FatalIfUnexpected(grpc.drpc.Srv.Serve(ctx, drpcL)) + } }) _ = stopper.RunAsyncTask(workersCtx, "serve-loopback-grpc", func(context.Context) { netutil.FatalIfUnexpected(grpc.Serve(rpcLoopbackL))