Skip to content

Commit

Permalink
chore: support grpc codecV2
Browse files Browse the repository at this point in the history
Signed-off-by: Valery Piashchynski <piashchynski.valery@gmail.com>
  • Loading branch information
rustatian committed Aug 30, 2024
1 parent f49b6e3 commit d7fc9e4
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 21 deletions.
29 changes: 23 additions & 6 deletions codec/codec.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package codec

import "google.golang.org/grpc/encoding"
import (
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto"
)

type RawMessage []byte

Expand All @@ -15,7 +19,7 @@ func (RawMessage) ProtoMessage() {}
func (RawMessage) String() string { return rm }

type Codec struct {
Base encoding.Codec
Base encoding.CodecV2
}

// Marshal returns the wire format of v. rawMessages would be returned without encoding.
Expand All @@ -24,17 +28,30 @@ func (c *Codec) Marshal(v any) ([]byte, error) {
return raw, nil
}

return c.Base.Marshal(v)
data, err := c.Base.Marshal(v)
if err != nil {
return nil, err
}

return data.Materialize(), nil
}

// Unmarshal parses the wire format into v. rawMessages would not be unmarshalled.
func (c *Codec) Unmarshal(data []byte, v any) error {
if raw, ok := v.(*RawMessage); ok {
*raw = data
switch msg := v.(type) {
case *RawMessage:
*msg = data
return nil
case proto.Message:
err := proto.Unmarshal(data, msg)
if err != nil {
return err
}
default:
return c.Base.Unmarshal(mem.BufferSlice{mem.NewBuffer(&data, mem.DefaultBufferPool())}, v)
}

return c.Base.Unmarshal(data, v)
return nil
}

func (c *Codec) Name() string {
Expand Down
19 changes: 14 additions & 5 deletions codec/codec_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
package codec

import (
"encoding/json"
"testing"

json "github.com/goccy/go-json"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/mem"
)

type jsonCodec struct{}

func (jsonCodec) Marshal(v any) ([]byte, error) {
return json.Marshal(v)
func (jsonCodec) Marshal(v any) (mem.BufferSlice, error) {
data, err := json.Marshal(v)
if err != nil {
return nil, err
}

buf := mem.NewBuffer(&data, mem.DefaultBufferPool())
bs := mem.BufferSlice{buf}
return bs, nil
}

func (jsonCodec) Unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
func (jsonCodec) Unmarshal(data mem.BufferSlice, v any) error {
out := data.Materialize()
return json.Unmarshal(out, v)
}

func (jsonCodec) Name() string {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ toolchain go1.23.0

require (
github.com/emicklei/proto v1.13.2
github.com/goccy/go-json v0.10.3
github.com/prometheus/client_golang v1.20.2
github.com/roadrunner-server/api/v4 v4.16.0
github.com/roadrunner-server/endure/v2 v2.5.0
Expand All @@ -33,6 +32,7 @@ require (
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
Expand Down
3 changes: 2 additions & 1 deletion plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ func (p *Plugin) Init(cfg common.Configurer, log common.Logger, server common.Se
if !cfg.Has(pluginName) {
return errors.E(errors.Disabled)
}

// register the codec
encoding.RegisterCodec(&codec.Codec{
Base: encoding.GetCodec(codec.Name),
Base: encoding.GetCodecV2(codec.Name),
})

err := cfg.UnmarshalKey(pluginName, &p.config)
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (p *Plugin) serverOptions() ([]grpc.ServerOption, error) {
var err error

if p.config.EnableTLS() {
// if client CA is not empty we combine it with Cert and Key
// if client CA is not empty, we combine it with Cert and Key
if p.config.TLS.RootCA != "" {
cert, err = tls.LoadX509KeyPair(p.config.TLS.Cert, p.config.TLS.Key)
if err != nil {
Expand Down
9 changes: 2 additions & 7 deletions tests/grpc_plugin_gzip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ func TestGrpcRqRsGzip(t *testing.T) {
cfg := &config.Plugin{
Version: "2023.3.0",
Path: "configs/.rr-grpc-rq.yaml",
Prefix: "rr",
}

err := cont.RegisterAll(
Expand Down Expand Up @@ -112,7 +111,6 @@ func TestGrpcRqRsMultipleGzip(t *testing.T) {
cfg := &config.Plugin{
Version: "2023.3.0",
Path: "configs/.rr-grpc-rq-multiple.yaml",
Prefix: "rr",
}

err := cont.RegisterAll(
Expand Down Expand Up @@ -206,7 +204,6 @@ func TestGrpcRqRsTLSGzip(t *testing.T) {
cfg := &config.Plugin{
Version: "2023.3.0",
Path: "configs/.rr-grpc-rq-tls.yaml",
Prefix: "rr",
}

err := cont.RegisterAll(
Expand Down Expand Up @@ -291,7 +288,6 @@ func TestGrpcRqRsTLSRootCAGzip(t *testing.T) {
cfg := &config.Plugin{
Version: "2023.3.0",
Path: "configs/.rr-grpc-rq-tls-rootca.yaml",
Prefix: "rr",
}

err := cont.RegisterAll(
Expand Down Expand Up @@ -356,7 +352,7 @@ func TestGrpcRqRsTLSRootCAGzip(t *testing.T) {
MinVersion: tls.VersionTLS12,
}

conn, err := grpc.Dial("127.0.0.1:9003", grpc.WithTransportCredentials(credentials.NewTLS(tlscfg)), grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
conn, err := grpc.NewClient("127.0.0.1:9003", grpc.WithTransportCredentials(credentials.NewTLS(tlscfg)), grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
require.NoError(t, err)
require.NotNil(t, conn)

Expand All @@ -376,7 +372,6 @@ func TestGrpcRqRsTLS_WithResetGzip(t *testing.T) {
cfg := &config.Plugin{
Version: "2023.3.0",
Path: "configs/.rr-grpc-rq-tls.yaml",
Prefix: "rr",
}

err := cont.RegisterAll(
Expand Down Expand Up @@ -442,7 +437,7 @@ func TestGrpcRqRsTLS_WithResetGzip(t *testing.T) {
MinVersion: tls.VersionTLS12,
}

conn, err := grpc.Dial("localhost:9002", grpc.WithTransportCredentials(credentials.NewTLS(tlscfg)), grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
conn, err := grpc.NewClient("localhost:9002", grpc.WithTransportCredentials(credentials.NewTLS(tlscfg)), grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
require.NoError(t, err)
require.NotNil(t, conn)

Expand Down

0 comments on commit d7fc9e4

Please sign in to comment.