From 7d9158df41704eb6432022a3bb6a0face0c5d63c Mon Sep 17 00:00:00 2001 From: Boris Glimcher Date: Fri, 20 Oct 2023 00:24:58 +0300 Subject: [PATCH] refactor: move proto to utils Signed-off-by: Boris Glimcher --- pkg/bridge/bridge_test.go | 6 +++--- pkg/bridge/common.go | 19 ------------------- pkg/bridge/grpc.go | 5 +++-- pkg/port/common.go | 18 ------------------ pkg/port/grpc.go | 5 +++-- pkg/port/port_test.go | 6 +++--- pkg/svi/common.go | 18 ------------------ pkg/svi/grpc.go | 5 +++-- pkg/svi/svi_test.go | 6 +++--- pkg/utils/proto.go | 27 +++++++++++++++++++++++++++ pkg/vrf/common.go | 19 ------------------- pkg/vrf/grpc.go | 5 +++-- pkg/vrf/vrf_test.go | 6 +++--- 13 files changed, 51 insertions(+), 94 deletions(-) create mode 100644 pkg/utils/proto.go diff --git a/pkg/bridge/bridge_test.go b/pkg/bridge/bridge_test.go index 02d82093..11a2698e 100644 --- a/pkg/bridge/bridge_test.go +++ b/pkg/bridge/bridge_test.go @@ -280,7 +280,7 @@ func Test_CreateLogicalBridge(t *testing.T) { _ = opi.store.Set(testLogicalBridgeName, &testLogicalBridgeWithStatus) } if tt.out != nil { - tt.out = protoClone(tt.out) + tt.out = utils.ProtoClone(tt.out) tt.out.Name = testLogicalBridgeName } if tt.on != nil { @@ -541,7 +541,7 @@ func Test_UpdateLogicalBridge(t *testing.T) { _ = opi.store.Set(testLogicalBridgeName, &testLogicalBridgeWithStatus) } if tt.out != nil { - tt.out = protoClone(tt.out) + tt.out = utils.ProtoClone(tt.out) tt.out.Name = testLogicalBridgeName } @@ -730,7 +730,7 @@ func Test_ListLogicalBridges(t *testing.T) { request := &pb.ListLogicalBridgesRequest{PageSize: tt.size, PageToken: tt.token} response, err := client.ListLogicalBridges(ctx, request) - if !equalProtoSlices(response.GetLogicalBridges(), tt.out) { + if !utils.EqualProtoSlices(response.GetLogicalBridges(), tt.out) { t.Error("response: expected", tt.out, "received", response.GetLogicalBridges()) } diff --git a/pkg/bridge/common.go b/pkg/bridge/common.go index fa6116dd..420e4799 100644 --- a/pkg/bridge/common.go +++ b/pkg/bridge/common.go @@ -15,7 +15,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" - "google.golang.org/protobuf/proto" pb "github.com/opiproject/opi-api/network/evpn-gw/v1alpha1/gen/go" ) @@ -35,10 +34,6 @@ func resourceIDToFullName(_ string, resourceID string) string { return fmt.Sprintf("//network.opiproject.org/bridges/%s", resourceID) } -func protoClone[T proto.Message](protoStruct T) T { - return proto.Clone(protoStruct).(T) -} - func extractPagination(pageSize int32, pageToken string, pagination map[string]int) (size int, offset int, err error) { const ( maxPageSize = 250 @@ -94,17 +89,3 @@ func dialer(opi *Server) func(context.Context, string) (net.Conn, error) { return listener.Dial() } } - -func equalProtoSlices[T proto.Message](x, y []T) bool { - if len(x) != len(y) { - return false - } - - for i := 0; i < len(x); i++ { - if !proto.Equal(x[i], y[i]) { - return false - } - } - - return true -} diff --git a/pkg/bridge/grpc.go b/pkg/bridge/grpc.go index 3a101087..c2a92f2d 100644 --- a/pkg/bridge/grpc.go +++ b/pkg/bridge/grpc.go @@ -13,6 +13,7 @@ import ( "github.com/google/uuid" "github.com/opiproject/opi-evpn-bridge/pkg/models" + "github.com/opiproject/opi-evpn-bridge/pkg/utils" pb "github.com/opiproject/opi-api/network/evpn-gw/v1alpha1/gen/go" @@ -52,7 +53,7 @@ func (s *Server) CreateLogicalBridge(ctx context.Context, in *pb.CreateLogicalBr return nil, err } // translate object - response := protoClone(in.LogicalBridge) + response := utils.ProtoClone(in.LogicalBridge) response.Status = &pb.LogicalBridgeStatus{OperStatus: pb.LBOperStatus_LB_OPER_STATUS_UP} log.Printf("new object %v", models.NewBridge(response)) // save object to the database @@ -130,7 +131,7 @@ func (s *Server) UpdateLogicalBridge(ctx context.Context, in *pb.UpdateLogicalBr return nil, err } } - response := protoClone(in.LogicalBridge) + response := utils.ProtoClone(in.LogicalBridge) response.Status = &pb.LogicalBridgeStatus{OperStatus: pb.LBOperStatus_LB_OPER_STATUS_UP} err = s.store.Set(in.LogicalBridge.Name, response) if err != nil { diff --git a/pkg/port/common.go b/pkg/port/common.go index e85bebac..e200786e 100644 --- a/pkg/port/common.go +++ b/pkg/port/common.go @@ -63,10 +63,6 @@ func resourceIDToFullName(_ string, resourceID string) string { return fmt.Sprintf("//network.opiproject.org/ports/%s", resourceID) } -func protoClone[T proto.Message](protoStruct T) T { - return proto.Clone(protoStruct).(T) -} - func extractPagination(pageSize int32, pageToken string, pagination map[string]int) (size int, offset int, err error) { const ( maxPageSize = 250 @@ -122,17 +118,3 @@ func dialer(opi *Server) func(context.Context, string) (net.Conn, error) { return listener.Dial() } } - -func equalProtoSlices[T proto.Message](x, y []T) bool { - if len(x) != len(y) { - return false - } - - for i := 0; i < len(x); i++ { - if !proto.Equal(x[i], y[i]) { - return false - } - } - - return true -} diff --git a/pkg/port/grpc.go b/pkg/port/grpc.go index ade79050..79711693 100644 --- a/pkg/port/grpc.go +++ b/pkg/port/grpc.go @@ -14,6 +14,7 @@ import ( "github.com/google/uuid" "github.com/opiproject/opi-evpn-bridge/pkg/models" + "github.com/opiproject/opi-evpn-bridge/pkg/utils" pb "github.com/opiproject/opi-api/network/evpn-gw/v1alpha1/gen/go" @@ -53,7 +54,7 @@ func (s *Server) CreateBridgePort(ctx context.Context, in *pb.CreateBridgePortRe return nil, err } // translate object - response := protoClone(in.BridgePort) + response := utils.ProtoClone(in.BridgePort) response.Status = &pb.BridgePortStatus{OperStatus: pb.BPOperStatus_BP_OPER_STATUS_UP} log.Printf("new object %v", models.NewPort(response)) // save object to the database @@ -128,7 +129,7 @@ func (s *Server) UpdateBridgePort(ctx context.Context, in *pb.UpdateBridgePortRe fmt.Printf("Failed to update link: %v", err) return nil, err } - response := protoClone(in.BridgePort) + response := utils.ProtoClone(in.BridgePort) response.Status = &pb.BridgePortStatus{OperStatus: pb.BPOperStatus_BP_OPER_STATUS_UP} err = s.store.Set(in.BridgePort.Name, response) if err != nil { diff --git a/pkg/port/port_test.go b/pkg/port/port_test.go index 90fac034..1e3cc661 100644 --- a/pkg/port/port_test.go +++ b/pkg/port/port_test.go @@ -322,7 +322,7 @@ func Test_CreateBridgePort(t *testing.T) { _ = opi.store.Set(testBridgePortName, &testBridgePortWithStatus) } if tt.out != nil { - tt.out = protoClone(tt.out) + tt.out = utils.ProtoClone(tt.out) tt.out.Name = testBridgePortName } if tt.on != nil { @@ -571,7 +571,7 @@ func Test_UpdateBridgePort(t *testing.T) { _ = opi.store.Set(testBridgePortName, &testBridgePortWithStatus) } if tt.out != nil { - tt.out = protoClone(tt.out) + tt.out = utils.ProtoClone(tt.out) tt.out.Name = testBridgePortName } @@ -760,7 +760,7 @@ func Test_ListBridgePorts(t *testing.T) { request := &pb.ListBridgePortsRequest{PageSize: tt.size, PageToken: tt.token} response, err := client.ListBridgePorts(ctx, request) - if !equalProtoSlices(response.GetBridgePorts(), tt.out) { + if !utils.EqualProtoSlices(response.GetBridgePorts(), tt.out) { t.Error("response: expected", tt.out, "received", response.GetBridgePorts()) } diff --git a/pkg/svi/common.go b/pkg/svi/common.go index 1b7a9c00..166c6f2d 100644 --- a/pkg/svi/common.go +++ b/pkg/svi/common.go @@ -96,10 +96,6 @@ func resourceIDToFullName(_ string, resourceID string) string { return fmt.Sprintf("//network.opiproject.org/svis/%s", resourceID) } -func protoClone[T proto.Message](protoStruct T) T { - return proto.Clone(protoStruct).(T) -} - func extractPagination(pageSize int32, pageToken string, pagination map[string]int) (size int, offset int, err error) { const ( maxPageSize = 250 @@ -155,17 +151,3 @@ func dialer(opi *Server) func(context.Context, string) (net.Conn, error) { return listener.Dial() } } - -func equalProtoSlices[T proto.Message](x, y []T) bool { - if len(x) != len(y) { - return false - } - - for i := 0; i < len(x); i++ { - if !proto.Equal(x[i], y[i]) { - return false - } - } - - return true -} diff --git a/pkg/svi/grpc.go b/pkg/svi/grpc.go index 29812a90..e1314728 100644 --- a/pkg/svi/grpc.go +++ b/pkg/svi/grpc.go @@ -14,6 +14,7 @@ import ( "github.com/google/uuid" "github.com/opiproject/opi-evpn-bridge/pkg/models" + "github.com/opiproject/opi-evpn-bridge/pkg/utils" pb "github.com/opiproject/opi-api/network/evpn-gw/v1alpha1/gen/go" @@ -82,7 +83,7 @@ func (s *Server) CreateSvi(ctx context.Context, in *pb.CreateSviRequest) (*pb.Sv return nil, err } // translate object - response := protoClone(in.Svi) + response := utils.ProtoClone(in.Svi) response.Status = &pb.SviStatus{OperStatus: pb.SVIOperStatus_SVI_OPER_STATUS_UP} log.Printf("new object %v", models.NewSvi(response)) // save object to the database @@ -197,7 +198,7 @@ func (s *Server) UpdateSvi(ctx context.Context, in *pb.UpdateSviRequest) (*pb.Sv fmt.Printf("Failed to update link: %v", err) return nil, err } - response := protoClone(in.Svi) + response := utils.ProtoClone(in.Svi) response.Status = &pb.SviStatus{OperStatus: pb.SVIOperStatus_SVI_OPER_STATUS_UP} err = s.store.Set(in.Svi.Name, response) if err != nil { diff --git a/pkg/svi/svi_test.go b/pkg/svi/svi_test.go index f42f6ad9..c55972b3 100644 --- a/pkg/svi/svi_test.go +++ b/pkg/svi/svi_test.go @@ -422,7 +422,7 @@ func Test_CreateSvi(t *testing.T) { _ = opi.store.Set(testSviName, &testSviWithStatus) } if tt.out != nil { - tt.out = protoClone(tt.out) + tt.out = utils.ProtoClone(tt.out) tt.out.Name = testSviName } if tt.on != nil { @@ -700,7 +700,7 @@ func Test_UpdateSvi(t *testing.T) { _ = opi.store.Set(testSviName, &testSviWithStatus) } if tt.out != nil { - tt.out = protoClone(tt.out) + tt.out = utils.ProtoClone(tt.out) tt.out.Name = testSviName } @@ -889,7 +889,7 @@ func Test_ListSvis(t *testing.T) { request := &pb.ListSvisRequest{PageSize: tt.size, PageToken: tt.token} response, err := client.ListSvis(ctx, request) - if !equalProtoSlices(response.GetSvis(), tt.out) { + if !utils.EqualProtoSlices(response.GetSvis(), tt.out) { t.Error("response: expected", tt.out, "received", response.GetSvis()) } diff --git a/pkg/utils/proto.go b/pkg/utils/proto.go new file mode 100644 index 00000000..7ed732bc --- /dev/null +++ b/pkg/utils/proto.go @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2022-2023 Dell Inc, or its subsidiaries. + +// Package utils has some utility functions and interfaces +package utils + +import "google.golang.org/protobuf/proto" + +// ProtoClone is a helper function to clone and cast protobufs +func ProtoClone[T proto.Message](protoStruct T) T { + return proto.Clone(protoStruct).(T) +} + +// EqualProtoSlices is a helper function to compare protobuf slices +func EqualProtoSlices[T proto.Message](x, y []T) bool { + if len(x) != len(y) { + return false + } + + for i := 0; i < len(x); i++ { + if !proto.Equal(x[i], y[i]) { + return false + } + } + + return true +} diff --git a/pkg/vrf/common.go b/pkg/vrf/common.go index 5e76b5aa..e63108d9 100644 --- a/pkg/vrf/common.go +++ b/pkg/vrf/common.go @@ -16,7 +16,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" - "google.golang.org/protobuf/proto" pb "github.com/opiproject/opi-api/network/evpn-gw/v1alpha1/gen/go" ) @@ -32,10 +31,6 @@ func resourceIDToFullName(_ string, resourceID string) string { return fmt.Sprintf("//network.opiproject.org/vrfs/%s", resourceID) } -func protoClone[T proto.Message](protoStruct T) T { - return proto.Clone(protoStruct).(T) -} - func generateRandMAC() ([]byte, error) { buf := make([]byte, 6) if _, err := rand.Read(buf); err != nil { @@ -103,17 +98,3 @@ func dialer(opi *Server) func(context.Context, string) (net.Conn, error) { return listener.Dial() } } - -func equalProtoSlices[T proto.Message](x, y []T) bool { - if len(x) != len(y) { - return false - } - - for i := 0; i < len(x); i++ { - if !proto.Equal(x[i], y[i]) { - return false - } - } - - return true -} diff --git a/pkg/vrf/grpc.go b/pkg/vrf/grpc.go index 0a9371b1..d4c4313f 100644 --- a/pkg/vrf/grpc.go +++ b/pkg/vrf/grpc.go @@ -15,6 +15,7 @@ import ( "github.com/google/uuid" "github.com/opiproject/opi-evpn-bridge/pkg/models" + "github.com/opiproject/opi-evpn-bridge/pkg/utils" pb "github.com/opiproject/opi-api/network/evpn-gw/v1alpha1/gen/go" @@ -69,7 +70,7 @@ func (s *Server) CreateVrf(ctx context.Context, in *pb.CreateVrfRequest) (*pb.Vr return nil, err } // translate object - response := protoClone(in.Vrf) + response := utils.ProtoClone(in.Vrf) response.Status = &pb.VrfStatus{LocalAs: 4, RoutingTable: tableID, Rmac: mac} log.Printf("new object %v", models.NewVrf(response)) // save object to the database @@ -148,7 +149,7 @@ func (s *Server) UpdateVrf(ctx context.Context, in *pb.UpdateVrfRequest) (*pb.Vr fmt.Printf("Failed to update link: %v", err) return nil, err } - response := protoClone(in.Vrf) + response := utils.ProtoClone(in.Vrf) response.Status = &pb.VrfStatus{LocalAs: 4} err = s.store.Set(in.Vrf.Name, response) if err != nil { diff --git a/pkg/vrf/vrf_test.go b/pkg/vrf/vrf_test.go index d5e88db5..0f07ccee 100644 --- a/pkg/vrf/vrf_test.go +++ b/pkg/vrf/vrf_test.go @@ -349,7 +349,7 @@ func Test_CreateVrf(t *testing.T) { _ = opi.store.Set(testVrfName, &testVrfWithStatus) } if tt.out != nil { - tt.out = protoClone(tt.out) + tt.out = utils.ProtoClone(tt.out) tt.out.Name = testVrfName } if tt.on != nil { @@ -736,7 +736,7 @@ func Test_UpdateVrf(t *testing.T) { _ = opi.store.Set(testVrfName, &testVrfWithStatus) } if tt.out != nil { - tt.out = protoClone(tt.out) + tt.out = utils.ProtoClone(tt.out) tt.out.Name = testVrfName } @@ -925,7 +925,7 @@ func Test_ListVrfs(t *testing.T) { request := &pb.ListVrfsRequest{PageSize: tt.size, PageToken: tt.token} response, err := client.ListVrfs(ctx, request) - if !equalProtoSlices(response.GetVrfs(), tt.out) { + if !utils.EqualProtoSlices(response.GetVrfs(), tt.out) { t.Error("response: expected", tt.out, "received", response.GetVrfs()) }