Skip to content

Commit

Permalink
implement GetPeer RPC
Browse files Browse the repository at this point in the history
  • Loading branch information
dennis-tra committed Nov 7, 2024
1 parent 2cd96dc commit dbb5081
Show file tree
Hide file tree
Showing 5 changed files with 397 additions and 53 deletions.
92 changes: 88 additions & 4 deletions cmd/nebula/cmd_serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package main

import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"strings"
"time"

"connectrpc.com/connect"
Expand All @@ -14,6 +17,8 @@ import (
"github.com/urfave/cli/v2"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"

"github.com/dennis-tra/nebula-crawler/db"
)

var serveConfig = &config.Serve{
Expand Down Expand Up @@ -47,8 +52,27 @@ var ServeCommand = &cli.Command{

// ServeAction is the function that is called when running `nebula resolve`.
func ServeAction(c *cli.Context) error {
log.Infoln("Start serving Nebula data...")
defer log.Infoln("Stopped serving Nebula data.")

ctx := c.Context

// initialize a new database client based on the given configuration.
// Options are Postgres, JSON, and noop (dry-run).
dbc, err := db.NewServerClient(ctx, rootConfig.Database)
if err != nil {
return fmt.Errorf("new database client: %w", err)
}
defer func() {
if err := dbc.Close(); err != nil && !errors.Is(err, sql.ErrConnDone) && !strings.Contains(err.Error(), "use of closed network connection") {
log.WithError(err).Warnln("Failed closing database handle")
}
}()

mux := http.NewServeMux()
path, handler := nebulav1connect.NewNebulaServiceHandler(&nebulaServiceServer{})
path, handler := nebulav1connect.NewNebulaServiceHandler(&nebulaServiceServer{
dbc: dbc,
})
mux.Handle(path, handler)

address := fmt.Sprintf("%s:%d", serveConfig.Host, serveConfig.Port)
Expand Down Expand Up @@ -85,15 +109,75 @@ func ServeAction(c *cli.Context) error {
}

// petStoreServiceServer implements the PetStoreService API.
type nebulaServiceServer struct{}
type nebulaServiceServer struct {
dbc db.ServerClient
}

var _ nebulav1connect.NebulaServiceHandler = (*nebulaServiceServer)(nil)

func (n nebulaServiceServer) GetPeer(ctx context.Context, c *connect.Request[v1.GetPeerRequest]) (*connect.Response[v1.GetPeerResponse], error) {
func (n *nebulaServiceServer) GetPeer(ctx context.Context, c *connect.Request[v1.GetPeerRequest]) (*connect.Response[v1.GetPeerResponse], error) {
log.WithField("multihash", c.Msg.MultiHash).Info("GetPeer")

dbPeer, dbProtocols, err := n.dbc.GetPeer(ctx, c.Msg.MultiHash)
if err != nil {
return nil, err
}

v1Maddrs := make([]*v1.MultiAddress, 0, len(dbPeer.R.MultiAddresses))
for _, dbMaddr := range dbPeer.R.MultiAddresses {
var asn *int32
if !dbMaddr.Asn.IsZero() {
val := int32(dbMaddr.Asn.Int)
asn = &val
}

var isCloud *int32
if !dbMaddr.IsCloud.IsZero() {
val := int32(dbMaddr.IsCloud.Int)
asn = &val
}

var country *string
if !dbMaddr.Country.IsZero() {
country = &dbMaddr.Country.String
}

var continent *string
if !dbMaddr.Continent.IsZero() {
continent = &dbMaddr.Country.String
}

var ip *string
if !dbMaddr.Addr.IsZero() {
ip = &dbMaddr.Addr.String
}

v1Maddrs = append(v1Maddrs, &v1.MultiAddress{
MultiAddress: dbMaddr.Maddr,
Asn: asn,
IsCloud: isCloud,
Country: country,
Continent: continent,
Ip: ip,
})
}

protocols := make([]string, 0, len(dbProtocols))
for _, dbProtocol := range dbProtocols {
protocols = append(protocols, dbProtocol.Protocol)
}

var av *string
if dbPeer.R.AgentVersion != nil {
av = &dbPeer.R.AgentVersion.AgentVersion
}

resp := &connect.Response[v1.GetPeerResponse]{
Msg: &v1.GetPeerResponse{
MultiHash: c.Msg.MultiHash,
MultiHash: dbPeer.MultiHash,
AgentVersion: av,
MultiAddresses: v1Maddrs,
Protocols: protocols,
},
}

Expand Down
27 changes: 27 additions & 0 deletions db/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ type Client interface {
PersistNeighbors(ctx context.Context, crawl *models.Crawl, dbPeerID *int, peerID peer.ID, errorBits uint16, dbNeighborsIDs []int, neighbors []peer.ID) error
}

type ServerClient interface {
io.Closer
GetPeer(ctx context.Context, multiHash string) (*models.Peer, models.ProtocolSlice, error)
}

// NewClient will initialize the right database client based on the given
// configuration. This can either be a Postgres, JSON, or noop client. The noop
// client is a dummy implementation of the [Client] interface that does nothing
Expand Down Expand Up @@ -53,3 +58,25 @@ func NewClient(ctx context.Context, cfg *config.Database) (Client, error) {

return dbc, nil
}

func NewServerClient(ctx context.Context, cfg *config.Database) (ServerClient, error) {
var (
dbc ServerClient
err error
)

// dry run has precedence. Then, if a JSON output directory is given, use
// the JSON client. In any other case, use the Postgres database client.
if cfg.DryRun {
return nil, fmt.Errorf("server client not implemented")
} else if cfg.JSONOut != "" {
return nil, fmt.Errorf("server client not implemented")
} else {
dbc, err = InitDBServerClient(ctx, cfg)
}
if err != nil {
return nil, fmt.Errorf("init db client: %w", err)
}

return dbc, nil
}
105 changes: 105 additions & 0 deletions db/client_db_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package db

import (
"context"
"database/sql"
"fmt"

"github.com/dennis-tra/nebula-crawler/db/models"
log "github.com/sirupsen/logrus"
"github.com/uptrace/opentelemetry-go-extra/otelsql"
"github.com/volatiletech/sqlboiler/v4/queries/qm"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"

"github.com/dennis-tra/nebula-crawler/config"
)

type DBServerClient struct {
ctx context.Context

// Reference to the configuration
cfg *config.Database

// Database handler
dbh *sql.DB

// reference to all relevant db telemetry
telemetry *telemetry
}

var _ ServerClient = (*DBServerClient)(nil)

// InitDBServerClient establishes a database connection with the provided
// configuration
func InitDBServerClient(ctx context.Context, cfg *config.Database) (*DBServerClient, error) {
log.WithFields(log.Fields{
"host": cfg.DatabaseHost,
"port": cfg.DatabasePort,
"name": cfg.DatabaseName,
"user": cfg.DatabaseUser,
"ssl": cfg.DatabaseSSLMode,
}).Infoln("Initializing database client")

dbh, err := otelsql.Open("postgres", cfg.DatabaseSourceName(),
otelsql.WithAttributes(semconv.DBSystemPostgreSQL),
otelsql.WithMeterProvider(cfg.MeterProvider),
otelsql.WithTracerProvider(cfg.TracerProvider),
)
if err != nil {
return nil, fmt.Errorf("opening database: %w", err)
}

// Set to match the writer worker
dbh.SetMaxIdleConns(cfg.MaxIdleConns) // default is 2 which leads to many connection open/closings

otelsql.ReportDBStatsMetrics(dbh, otelsql.WithMeterProvider(cfg.MeterProvider))

// Ping database to verify connection.
if err = dbh.Ping(); err != nil {
return nil, fmt.Errorf("pinging database: %w", err)
}

telemetry, err := newTelemetry(cfg.TracerProvider, cfg.MeterProvider)
if err != nil {
return nil, fmt.Errorf("new telemetry: %w", err)
}

client := &DBServerClient{ctx: ctx, cfg: cfg, dbh: dbh, telemetry: telemetry}

return client, nil
}

func (d *DBServerClient) Close() error {
return d.dbh.Close()
}

func (d *DBServerClient) GetPeer(ctx context.Context, multiHash string) (*models.Peer, models.ProtocolSlice, error) {
// write a hand-crafted query to avoid two DB round-trips

dbPeer, err := models.Peers(
models.PeerWhere.MultiHash.EQ(multiHash),
qm.Load(models.PeerRels.AgentVersion),
qm.Load(models.PeerRels.MultiAddresses),
qm.Load(models.PeerRels.ProtocolsSet),
).One(ctx, d.dbh)
if err != nil {
return nil, nil, fmt.Errorf("getting peer: %w", err)
}

if dbPeer.R.ProtocolsSet == nil {
return dbPeer, nil, nil
}

protocolIDs := dbPeer.R.ProtocolsSet.ProtocolIds
ids := make([]int, 0, len(protocolIDs))
for _, id := range protocolIDs {
ids = append(ids, int(id))
}

dbProtocols, err := models.Protocols(models.ProtocolWhere.ID.IN(ids)).All(ctx, d.dbh)
if err != nil {
return dbPeer, nil, fmt.Errorf("getting protocols: %w", err)
}

return dbPeer, dbProtocols, nil
}
Loading

0 comments on commit dbb5081

Please sign in to comment.