Skip to content

Commit

Permalink
Fix CSI health server
Browse files Browse the repository at this point in the history
  • Loading branch information
angelini committed Jul 10, 2024
1 parent 89f9554 commit 17e1be7
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 120 deletions.
4 changes: 4 additions & 0 deletions internal/testutil/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ func (tc *TestCtx) Context() context.Context {
return tc.ctx
}

func (tc *TestCtx) Auth() auth.Auth {
return tc.Context().Value(auth.AuthCtxKey).(auth.Auth)
}

func (tc *TestCtx) Connect() pgx.Tx {
tx, _, err := tc.dbConn.Connect(tc.ctx)
require.NoError(tc.t, err, "connecting to db")
Expand Down
40 changes: 40 additions & 0 deletions pkg/cached/cached.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@ import (
"context"
"crypto/ed25519"
"crypto/tls"
"fmt"
"net"
"net/url"
"os"
"path"
"path/filepath"

"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/gadget-inc/dateilager/internal/auth"
"github.com/gadget-inc/dateilager/internal/logger"
"github.com/gadget-inc/dateilager/internal/pb"
Expand Down Expand Up @@ -63,6 +69,40 @@ func (s *CachedServer) RegisterCached(cached *api.Cached) {
pb.RegisterCachedServer(s.Grpc, cached)
}

func (s *CachedServer) RegisterCSI(cached *api.Cached) {
csi.RegisterIdentityServer(s.Grpc, cached)
csi.RegisterNodeServer(s.Grpc, cached)
}

func (s *CachedServer) Serve(lis net.Listener) error {
return s.Grpc.Serve(lis)
}

func (s *CachedServer) ServeCSI(listenSocketPath string) error {
u, err := url.Parse(listenSocketPath)
if err != nil {
return fmt.Errorf("unable to parse address: %q", err)
}

addr := path.Join(u.Host, filepath.FromSlash(u.Path))
if u.Host == "" {
addr = filepath.FromSlash(u.Path)
}

// CSI plugins talk only over UNIX sockets currently
if u.Scheme != "unix" {
return fmt.Errorf("currently only unix domain sockets are supported, have incorrect protocol: %s", u.Scheme)
} else {
// remove the socket if it's already there. This can happen if we deploy a new version and the socket was created from the old running plugin.
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove unix domain socket file %s, error: %s", addr, err)
}
}

listener, err := net.Listen(u.Scheme, addr)
if err != nil {
return fmt.Errorf("failed to listen: %v", err)
}

return s.Grpc.Serve(listener)
}
82 changes: 0 additions & 82 deletions pkg/cached/cachedcsi.go

This file was deleted.

37 changes: 15 additions & 22 deletions pkg/cli/cached.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,18 @@ func NewCacheDaemonCommand() *cobra.Command {
return err
}

var s *cached.CachedServer
if csiSocket == "" {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("cannot open TLS cert and key files (%s, %s): %w", certFile, keyFile, err)
}

pasetoKey, err := parsePublicKey(pasetoFile)
if err != nil {
return fmt.Errorf("cannot parse Paseto public key %s: %w", pasetoFile, err)
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("cannot open TLS cert and key files (%s, %s): %w", certFile, keyFile, err)
}

s = cached.NewServer(ctx, &cert, pasetoKey)
} else {
s = cached.NewCSIServer(ctx, cl, stagingPath)
pasetoKey, err := parsePublicKey(pasetoFile)
if err != nil {
return fmt.Errorf("cannot parse Paseto public key %s: %w", pasetoFile, err)
}

s := cached.NewServer(ctx, &cert, pasetoKey)

logger.Info(ctx, "register Cached")
cached := &api.Cached{
Env: env,
Expand All @@ -122,14 +117,6 @@ func NewCacheDaemonCommand() *cobra.Command {
}
s.RegisterCached(cached)

if csiSocket != "" {
logger.Info(ctx, "register CSI")
s.RegisterCSI(cached)
}

osSignals := make(chan os.Signal, 1)
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)

err = cached.Prepare(ctx)
if err != nil {
return fmt.Errorf("failed to prepare cache daemon in %s: %w", stagingPath, err)
Expand All @@ -138,6 +125,9 @@ func NewCacheDaemonCommand() *cobra.Command {
group, ctx := errgroup.WithContext(ctx)

if csiSocket != "" {
logger.Info(ctx, "register CSI")
s.RegisterCSI(cached)

group.Go(func() error {
logger.Info(ctx, "start CSI server")
return s.ServeCSI(csiSocket)
Expand All @@ -154,6 +144,9 @@ func NewCacheDaemonCommand() *cobra.Command {
return s.Serve(listen)
})

osSignals := make(chan os.Signal, 1)
signal.Notify(osSignals, os.Interrupt, syscall.SIGTERM)

group.Go(func() error {
<-osSignals
s.Grpc.GracefulStop()
Expand Down
9 changes: 2 additions & 7 deletions test/client_new_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,12 @@ func TestClientNewProjectDuplicateReportsError(t *testing.T) {
/** Create Project Again**/

tcSecond := util.NewTestCtx(t, auth.Admin, projectId)
defer tc.Close()
defer tcSecond.Close()

cSecond, _, closeSecond := createTestClient(tc)
defer closeSecond()

errSecond := cSecond.NewProject(tcSecond.Context(), projectId, nil, nil)
want := "project id already exists"

require.Error(t, errSecond, "NewProject")

if errSecond == nil {
t.Errorf("got %s want %s", errSecond, want)
}
require.Error(t, errSecond, "NewProject already exists error")
}
23 changes: 14 additions & 9 deletions test/shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ func verifyDir(t *testing.T, dir string, version int64, files map[string]expecte
}
}

func createTestGRPCServer(tc util.TestCtx, reqAuth auth.Auth) (*bufconn.Listener, *grpc.Server, func() *grpc.ClientConn) {
lis := bufconn.Listen(bufSize)
func createTestGRPCServer(tc util.TestCtx) (*bufconn.Listener, *grpc.Server, func() *grpc.ClientConn) {
reqAuth := tc.Auth()
s := grpc.NewServer(
grpc.UnaryInterceptor(
grpc.UnaryServerInterceptor(func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
Expand All @@ -397,6 +397,7 @@ func createTestGRPCServer(tc util.TestCtx, reqAuth auth.Auth) (*bufconn.Listener
),
)

lis := bufconn.Listen(bufSize)
dialer := func(context.Context, string) (net.Conn, error) {
return lis.Dial()
}
Expand All @@ -412,7 +413,15 @@ func createTestGRPCServer(tc util.TestCtx, reqAuth auth.Auth) (*bufconn.Listener

func createTestCachedCSIServer(tc util.TestCtx, tmpDir string) (*api.Cached, string, func()) {
cl, _, closeClient := createTestClient(tc)
s := cached.NewCSIServer(tc.Context(), cl, path.Join(tmpDir, "cached", "staging"))
_, grpcServer, _ := createTestGRPCServer(tc)

s := cached.CachedServer{
Grpc: grpcServer,
}

cached := tc.CachedApi(cl, path.Join(tmpDir, "cached", "staging"))
s.RegisterCached(cached)
s.RegisterCSI(cached)

socket := path.Join(tmpDir, "csi.sock")
endpoint := "unix://" + socket
Expand All @@ -422,15 +431,11 @@ func createTestCachedCSIServer(tc util.TestCtx, tmpDir string) (*api.Cached, str
require.NoError(tc.T(), err, "CSI Server exited")
}()

cached := tc.CachedApi(cl, path.Join(tmpDir, "cached", "staging"))
s.RegisterCached(cached)
s.RegisterCSI(cached)

return cached, endpoint, func() { closeClient(); s.Grpc.Stop() }
}

func createTestClient(tc util.TestCtx) (*client.Client, *api.Fs, func()) {
lis, s, getConn := createTestGRPCServer(tc, tc.Context().Value(auth.AuthCtxKey).(auth.Auth))
lis, s, getConn := createTestGRPCServer(tc)

fs := tc.FsApi()
pb.RegisterFsServer(s, fs)
Expand All @@ -448,7 +453,7 @@ func createTestClient(tc util.TestCtx) (*client.Client, *api.Fs, func()) {
// Make a new client that connects to a test cached server
// Under the hood, this creates a test storage server and connects to that
func createTestCachedClient(tc util.TestCtx) (*client.CachedClient, *api.Cached, func()) {
lis, s, getConn := createTestGRPCServer(tc, tc.Context().Value(auth.AuthCtxKey).(auth.Auth))
lis, s, getConn := createTestGRPCServer(tc)

cl, _, closeClient := createTestClient(tc)
stagingPath := emptyTmpDir(tc.T())
Expand Down

0 comments on commit 17e1be7

Please sign in to comment.