diff --git a/go/lib/infra/modules/segfetcher/fetcher.go b/go/lib/infra/modules/segfetcher/fetcher.go index 22605eb887..19fe4ecb58 100644 --- a/go/lib/infra/modules/segfetcher/fetcher.go +++ b/go/lib/infra/modules/segfetcher/fetcher.go @@ -16,6 +16,7 @@ package segfetcher import ( "context" + "net" "time" "github.com/scionproto/scion/go/lib/addr" @@ -49,6 +50,10 @@ type FetcherConfig struct { Validator Validator // Splitter is used to split requests. Splitter Splitter + // CryptoLookupAtLocalCS indicates whether crypto to verify path material + // should be fetched from the local CS or from the sender of the path + // material. + CryptoLookupAtLocalCS bool } // New creates a new fetcher from the configuration. @@ -62,24 +67,26 @@ func (cfg FetcherConfig) New() *Fetcher { Verifier: &SegVerifier{Verifier: cfg.VerificationFactory.NewVerifier()}, Storage: &DefaultStorage{PathDB: cfg.PathDB, RevCache: cfg.RevCache}, }, - PathDB: cfg.PathDB, - QueryInterval: cfg.QueryInterval, + PathDB: cfg.PathDB, + QueryInterval: cfg.QueryInterval, + CryptoLookupAtLocalCS: cfg.CryptoLookupAtLocalCS, } } // Fetcher fetches, verifies and stores segments for a given path request. type Fetcher struct { - Validator Validator - Splitter Splitter - Resolver Resolver - Requester Requester - ReplyHandler ReplyHandler - PathDB pathdb.PathDB - QueryInterval time.Duration + Validator Validator + Splitter Splitter + Resolver Resolver + Requester Requester + ReplyHandler ReplyHandler + PathDB pathdb.PathDB + QueryInterval time.Duration + CryptoLookupAtLocalCS bool } // FetchSegs fetches the required segments to build a path between src and dst -// of the request. Firs the request is validated and then depending on the +// of the request. First the request is validated and then depending on the // cache the segments are fetched from the remote server. func (f *Fetcher) FetchSegs(ctx context.Context, req Request) (Segments, error) { if f.Validator != nil { @@ -126,7 +133,7 @@ func (f *Fetcher) waitOnProcessed(ctx context.Context, replies <-chan ReplyOrErr if reply.Reply == nil || reply.Reply.Recs == nil { continue } - r := f.ReplyHandler.Handle(ctx, reply.Reply, reply.Peer, nil) + r := f.ReplyHandler.Handle(ctx, reply.Reply, f.verifyServer(reply), nil) select { case <-r.FullReplyProcessed(): if err := r.Err(); err != nil { @@ -143,3 +150,10 @@ func (f *Fetcher) waitOnProcessed(ctx context.Context, replies <-chan ReplyOrErr } return nil } + +func (f *Fetcher) verifyServer(reply ReplyOrErr) net.Addr { + if f.CryptoLookupAtLocalCS { + return nil + } + return reply.Peer +} diff --git a/go/lib/infra/modules/segfetcher/fetcher_test.go b/go/lib/infra/modules/segfetcher/fetcher_test.go index 1f2b4c4fef..0d3f68cf06 100644 --- a/go/lib/infra/modules/segfetcher/fetcher_test.go +++ b/go/lib/infra/modules/segfetcher/fetcher_test.go @@ -114,8 +114,8 @@ func TestFetcher(t *testing.T) { ErrorAssertion: require.NoError, ExpectedSegs: segfetcher.Segments{Up: seg.Segments{tg.seg130_111}}, }, - // XXX(lukedirtwalker): testing the full loop is quite involved, not - // sure if it would be worth it. + // XXX(lukedirtwalker): testing the full loop is quite involved, and is + // therefore currently omitted. } for name, test := range tests { t.Run(name, func(t *testing.T) {