Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom Error Handling for Server #588

Merged
merged 7 commits into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## 3.1.1-dev

* Require Dart 2.17 or greater.
* Fix issue [#51](https://github.com/grpc/grpc-dart/issues/51), add support for custom error handling.

## 3.1.0

Expand Down
7 changes: 3 additions & 4 deletions example/helloworld/bin/server.dart
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ class GreeterService extends GreeterServiceBase {
}

Future<void> main(List<String> args) async {
final server = Server(
[GreeterService()],
const <Interceptor>[],
CodecRegistry(codecs: const [GzipCodec(), IdentityCodec()]),
final server = Server.create(
services: [GreeterService()],
codecRegistry: CodecRegistry(codecs: const [GzipCodec(), IdentityCodec()]),
);
await server.serve(port: 50051);
print('Server listening on port ${server.port}...');
Expand Down
2 changes: 1 addition & 1 deletion example/helloworld/bin/unix_server.dart
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class GreeterService extends GreeterServiceBase {
Future<void> main(List<String> args) async {
final udsAddress =
InternetAddress('localhost', type: InternetAddressType.unix);
final server = Server([GreeterService()]);
final server = Server.create(services: [GreeterService()]);
await server.serve(address: udsAddress);
print('Start UNIX Server @localhost...');
}
2 changes: 1 addition & 1 deletion example/metadata/lib/src/server.dart
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class MetadataService extends MetadataServiceBase {

class Server {
Future<void> main(List<String> args) async {
final server = grpc.Server([MetadataService()]);
final server = grpc.Server.create(services: [MetadataService()]);
await server.serve(port: 8080);
print('Server listening on port ${server.port}...');
}
Expand Down
2 changes: 1 addition & 1 deletion example/route_guide/lib/src/server.dart
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class RouteGuideService extends RouteGuideServiceBase {

class Server {
Future<void> main(List<String> args) async {
final server = grpc.Server([RouteGuideService()]);
final server = grpc.Server.create(services: [RouteGuideService()]);
await server.serve(port: 8080);
print('Server listening on port ${server.port}...');
}
Expand Down
2 changes: 1 addition & 1 deletion interop/bin/server.dart
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ Future<void> main(List<String> args) async {

final services = [TestService()];

final server = Server(services);
final server = Server.create(services: services);

ServerTlsCredentials? tlsCredentials;
if (arguments['use_tls'] == 'true') {
Expand Down
54 changes: 30 additions & 24 deletions lib/src/server/handler.dart
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,16 @@ import 'call.dart';
import 'interceptor.dart';
import 'service.dart';

typedef ServiceLookup = Service? Function(String service);
typedef GrpcErrorHandler = void Function(GrpcError error, StackTrace? trace);

/// Handles an incoming gRPC call.
class ServerHandlerImpl extends ServiceCall {
class ServerHandler extends ServiceCall {
final ServerTransportStream _stream;
final Service? Function(String service) _serviceLookup;
final ServiceLookup _serviceLookup;
final List<Interceptor> _interceptors;
final CodecRegistry? _codecRegistry;
final GrpcErrorHandler? _errorHandler;

// ignore: cancel_subscriptions
StreamSubscription<GrpcMessage>? _incomingSubscription;
Expand All @@ -61,9 +65,19 @@ class ServerHandlerImpl extends ServiceCall {
Timer? _timeoutTimer;
final X509Certificate? _clientCertificate;

ServerHandlerImpl(this._serviceLookup, this._stream, this._interceptors,
this._codecRegistry,
[this._clientCertificate]);
ServerHandler({
required ServerTransportStream stream,
required ServiceLookup serviceLookup,
required List<Interceptor> interceptors,
required CodecRegistry? codecRegistry,
X509Certificate? clientCertificate,
GrpcErrorHandler? errorHandler,
}) : _stream = stream,
_serviceLookup = serviceLookup,
_interceptors = interceptors,
_codecRegistry = codecRegistry,
_clientCertificate = clientCertificate,
_errorHandler = errorHandler;

@override
DateTime? get deadline => _deadline;
Expand Down Expand Up @@ -254,12 +268,12 @@ class ServerHandlerImpl extends ServiceCall {
Object? request;
try {
request = _descriptor.deserialize(data.data);
} catch (error) {
} catch (error, trace) {
final grpcError =
GrpcError.internal('Error deserializing request: $error');
_sendError(grpcError);
_sendError(grpcError, trace);
_requests!
..addError(grpcError)
..addError(grpcError, trace)
..close();
return;
}
Expand All @@ -276,15 +290,15 @@ class ServerHandlerImpl extends ServiceCall {
sendHeaders();
}
_stream.sendData(frame(bytes, _callEncodingCodec));
} catch (error) {
} catch (error, trace) {
final grpcError = GrpcError.internal('Error sending response: $error');
if (!_requests!.isClosed) {
// If we can, alert the handler that things are going wrong.
_requests!
..addError(grpcError)
..close();
}
_sendError(grpcError);
_sendError(grpcError, trace);
_cancelResponseSubscription();
}
}
Expand All @@ -293,11 +307,11 @@ class ServerHandlerImpl extends ServiceCall {
sendTrailers();
}

void _onResponseError(error) {
void _onResponseError(error, trace) {
if (error is GrpcError) {
_sendError(error);
_sendError(error, trace);
} else {
_sendError(GrpcError.unknown(error.toString()));
_sendError(GrpcError.unknown(error.toString()), trace);
}
}

Expand Down Expand Up @@ -410,7 +424,9 @@ class ServerHandlerImpl extends ServiceCall {
..onDone(_onDone);
}

void _sendError(GrpcError error) {
void _sendError(GrpcError error, [StackTrace? trace]) {
_errorHandler?.call(error, trace);

sendTrailers(
status: error.code,
message: error.message,
Expand All @@ -424,13 +440,3 @@ class ServerHandlerImpl extends ServiceCall {
_cancelResponseSubscription();
}
}

class ServerHandler extends ServerHandlerImpl {
// ignore: use_super_parameters
ServerHandler(Service Function(String service) serviceLookup, stream,
[List<Interceptor> interceptors = const <Interceptor>[],
CodecRegistry? codecRegistry,
X509Certificate? clientCertificate])
: super(serviceLookup, stream, interceptors, codecRegistry,
clientCertificate);
}
58 changes: 41 additions & 17 deletions lib/src/server/server.dart
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class ConnectionServer {
final Map<String, Service> _services = {};
final List<Interceptor> _interceptors;
final CodecRegistry? _codecRegistry;
final GrpcErrorHandler? _errorHandler;

final _connections = <ServerTransportConnection>[];

Expand All @@ -95,19 +96,23 @@ class ConnectionServer {
List<Service> services, [
List<Interceptor> interceptors = const <Interceptor>[],
CodecRegistry? codecRegistry,
GrpcErrorHandler? errorHandler,
]) : _codecRegistry = codecRegistry,
_interceptors = interceptors {
_interceptors = interceptors,
_errorHandler = errorHandler {
for (final service in services) {
_services[service.$name] = service;
}
}

Service? lookupService(String service) => _services[service];

Future<void> serveConnection(ServerTransportConnection connection,
[X509Certificate? clientCertificate]) async {
Future<void> serveConnection(
ServerTransportConnection connection, [
X509Certificate? clientCertificate,
]) async {
_connections.add(connection);
ServerHandlerImpl? handler;
ServerHandler? handler;
// TODO(jakobr): Set active state handlers, close connection after idle
// timeout.
connection.incomingStreams.listen((stream) {
Expand All @@ -127,12 +132,18 @@ class ConnectionServer {
}

@visibleForTesting
ServerHandlerImpl serveStream_(ServerTransportStream stream,
[X509Certificate? clientCertificate]) {
return ServerHandlerImpl(
lookupService, stream, _interceptors, _codecRegistry,
ServerHandler serveStream_(
ServerTransportStream stream, [
X509Certificate? clientCertificate,
]) {
return ServerHandler(
stream: stream,
serviceLookup: lookupService,
interceptors: _interceptors,
codecRegistry: _codecRegistry,
// ignore: unnecessary_cast
clientCertificate as io_bits.X509Certificate?,
clientCertificate: clientCertificate as io_bits.X509Certificate?,
errorHandler: _errorHandler,
)..handle();
}
}
Expand All @@ -145,12 +156,22 @@ class Server extends ConnectionServer {
SecureServerSocket? _secureServer;

/// Create a server for the given [services].
@Deprecated('use Server.create() instead')
Server(
super.services, [
super.interceptors,
super.codecRegistry,
super.errorHandler,
]);

/// Create a server for the given [services].
Server.create({
required List<Service> services,
List<Interceptor> interceptors = const <Interceptor>[],
CodecRegistry? codecRegistry,
GrpcErrorHandler? errorHandler,
}) : super(services, interceptors, codecRegistry, errorHandler);

/// The port that the server is listening on, or `null` if the server is not
/// active.
int? get port {
Expand Down Expand Up @@ -223,15 +244,18 @@ class Server extends ConnectionServer {

@override
@visibleForTesting
ServerHandlerImpl serveStream_(ServerTransportStream stream,
[X509Certificate? clientCertificate]) {
return ServerHandlerImpl(
lookupService,
stream,
_interceptors,
_codecRegistry,
ServerHandler serveStream_(
ServerTransportStream stream, [
X509Certificate? clientCertificate,
]) {
return ServerHandler(
stream: stream,
serviceLookup: lookupService,
interceptors: _interceptors,
codecRegistry: _codecRegistry,
// ignore: unnecessary_cast
clientCertificate as io_bits.X509Certificate?,
clientCertificate: clientCertificate as io_bits.X509Certificate?,
errorHandler: _errorHandler,
)..handle();
}

Expand Down
6 changes: 5 additions & 1 deletion test/client_certificate_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class EchoService extends EchoServiceBase {
}

const String address = 'localhost';

Future<void> main() async {
test('Client certificate required', () async {
// Server
Expand Down Expand Up @@ -80,7 +81,7 @@ Future<void> main() async {
}

Future<Server> _setUpServer([bool requireClientCertificate = false]) async {
final server = Server([EchoService()]);
final server = Server.create(services: [EchoService()]);
final serverContext = SecurityContextServerCredentials.baseSecurityContext();
serverContext.useCertificateChain('test/data/localhost.crt');
serverContext.usePrivateKey('test/data/localhost.key');
Expand All @@ -102,6 +103,7 @@ class SecurityContextChannelCredentials extends ChannelCredentials {
{super.authority, super.onBadCertificate})
: _securityContext = securityContext,
super.secure();

@override
SecurityContext get securityContext => _securityContext;

Expand All @@ -116,8 +118,10 @@ class SecurityContextServerCredentials extends ServerTlsCredentials {
SecurityContextServerCredentials(SecurityContext securityContext)
: _securityContext = securityContext,
super();

@override
SecurityContext get securityContext => _securityContext;

static SecurityContext baseSecurityContext() {
return createSecurityContext(true);
}
Expand Down
7 changes: 5 additions & 2 deletions test/client_handles_bad_connections_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TestClient extends grpc.Client {
(List<int> value) => value[0]);

TestClient(super.channel);

grpc.ResponseStream<int> stream(int request, {grpc.CallOptions? options}) {
return $createStreamingCall(_$stream, Stream.value(request),
options: options);
Expand All @@ -42,11 +43,13 @@ class TestService extends grpc.Service {
class FixedConnectionClientChannel extends ClientChannelBase {
final Http2ClientConnection clientConnection;
List<grpc.ConnectionState> states = <grpc.ConnectionState>[];

FixedConnectionClientChannel(this.clientConnection) {
onConnectionStateChanged.listen((state) {
states.add(state);
});
}

@override
ClientConnection createConnection() => clientConnection;
}
Expand All @@ -55,7 +58,7 @@ Future<void> main() async {
testTcpAndUds('client reconnects after the connection gets old',
(address) async {
// client reconnect after a short delay.
final server = grpc.Server([TestService()]);
final server = grpc.Server.create(services: [TestService()]);
await server.serve(address: address, port: 0);

final channel = FixedConnectionClientChannel(Http2ClientConnection(
Expand All @@ -80,7 +83,7 @@ Future<void> main() async {

testTcpAndUds('client reconnects when stream limit is used', (address) async {
// client reconnect after setting stream limit.
final server = grpc.Server([TestService()]);
final server = grpc.Server.create(services: [TestService()]);
await server.serve(
address: address,
port: 0,
Expand Down
3 changes: 2 additions & 1 deletion test/grpc_web_server.dart
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ static_resources:
// with an error. Otherwise if verbose is specified it will be dumped
// to stdout unconditionally.
final output = <String>[];

void _info(String line) {
if (!verbose) {
output.add(line);
Expand All @@ -99,7 +100,7 @@ void _info(String line) {

Future<void> hybridMain(StreamChannel channel) async {
// Spawn a gRPC server.
final server = Server([EchoService()]);
final server = Server.create(services: [EchoService()]);
await server.serve(port: 0);
_info('grpc server listening on ${server.port}');

Expand Down
Loading