fetch_server: add support for bazel qualifiers
In addition to the checksum.sri qualifier from the Remote Asset API
spec, Bazel includes its own set of qualifiers in each FetchBlob
request.

https://cs.opensource.google/bazel/bazel/+/master:src/main/java/com/google/devtools/build/lib/remote/downloader/GrpcRemoteDownloader.java;l=76-85;drc=618c0abbfe518f4e29de523a2e63ca9179050e94
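
For illustration, a FetchBlob request carrying these qualifiers might look
roughly like the Go sketch below. All values (URI, token, checksum) are
made-up examples, and the rapb import path is an assumption (the upstream
remote-apis Go bindings), not necessarily the alias used in this repo:

package example

import (
	// Assumed import path: the upstream remote-apis Go bindings. This repo
	// aliases its own generated protos as rapb instead.
	rapb "github.com/bazelbuild/remote-apis/build/bazel/remote/asset/v1"
)

// exampleFetchBlobRequest builds a request similar to what Bazel's
// GrpcRemoteDownloader would send to the fetch server. All values are
// illustrative, not taken from this change.
func exampleFetchBlobRequest() *rapb.FetchBlobRequest {
	return &rapb.FetchBlobRequest{
		Uris: []string{"https://example.com/archive.tar.gz"},
		Qualifiers: []*rapb.Qualifier{
			{
				// Standard Remote Asset API qualifier: "<digest-func>-<base64 hash>".
				// This example value is the base64 SHA-256 of empty input; it
				// decodes to the hex digest e3b0c442...7852b855.
				Name:  "checksum.sri",
				Value: "sha256-47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=",
			},
			{
				// Bazel-specific: an opaque ID Bazel attaches to the download.
				Name:  "bazel.canonical_id",
				Value: "my-canonical-id",
			},
			{
				// Bazel-specific: everything after "http_header:" is the HTTP header name.
				Name:  "http_header:Authorization",
				Value: "Bearer <example-token>",
			},
		},
	}
}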

This change adds initial support for the http_header qualifier and
acknowledges the existence of the bazel.canonical_id qualifier
(without actually using it). Relevant tests were also added to
demonstrate the current behavior.

In a future patch, we may start rejecting unknown qualifiers as
recommended by the Remote Asset API spec, and thus flip the
assertions in the test. See bazelbuild/remote-apis#301 for more info.
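
As a rough sketch only (not part of this change), that stricter mode could
collect the unrecognized qualifier names and fail the request, e.g. with a
hypothetical helper in fetch_server.go:

// rejectUnknownQualifiers is a hypothetical helper (not in this commit) that
// returns an InvalidArgument error naming every qualifier the server does
// not recognize. FetchBlob could call it before processing the request.
func rejectUnknownQualifiers(qualifiers []*rapb.Qualifier) error {
	var unknown []string
	for _, q := range qualifiers {
		switch {
		case q.GetName() == ChecksumQualifier,
			q.GetName() == BazelCanonicalIDQualifier,
			strings.HasPrefix(q.GetName(), BazelHttpHeaderPrefixQualifier):
			// Recognized qualifiers; handled elsewhere in FetchBlob.
		default:
			unknown = append(unknown, q.GetName())
		}
	}
	if len(unknown) > 0 {
		return status.InvalidArgumentErrorf("unsupported qualifiers: %s", strings.Join(unknown, ", "))
	}
	return nil
}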

Another change to consider in the future is support for custom
url-specific header credentials. This is being proposed in Bazel via
bazelbuild/bazel#23578.
sluongng committed Sep 11, 2024
1 parent 7bcc7ab commit 66ca793
Showing 2 changed files with 123 additions and 21 deletions.
86 changes: 65 additions & 21 deletions server/remote_asset/fetch_server/fetch_server.go
@@ -31,8 +31,10 @@ import (
 )
 
 const (
-	ChecksumQualifier = "checksum.sri"
-	maxHTTPTimeout = 60 * time.Minute
+	ChecksumQualifier = "checksum.sri"
+	BazelCanonicalIDQualifier = "bazel.canonical_id"
+	BazelHttpHeaderPrefixQualifier = "http_header:"
+	maxHTTPTimeout = 60 * time.Minute
 )
 
 type FetchServer struct {
@@ -100,6 +102,24 @@ func timeoutHTTPClient(ctx context.Context, protoTimeout *durationpb.Duration) *
 	}
 }
 
+// parseChecksumQualifier returns a digest function and digest hash
+// given a "checksum.sri" qualifier.
+func parseChecksumQualifier(qualifier *rapb.Qualifier) (repb.DigestFunction_Value, string, error) {
+	for _, digestFunc := range digest.SupportedDigestFunctions() {
+		pr := fmt.Sprintf("%s-", strings.ToLower(repb.DigestFunction_Value_name[int32(digestFunc)]))
+		if strings.HasPrefix(qualifier.GetValue(), pr) {
+			b64hash := strings.TrimPrefix(qualifier.GetValue(), pr)
+			decodedHash, err := base64.StdEncoding.DecodeString(b64hash)
+			if err != nil {
+				return repb.DigestFunction_UNKNOWN, "", status.FailedPreconditionErrorf("Error decoding qualifier %q: %s", qualifier.GetName(), err.Error())
+			}
+			expectedChecksum := fmt.Sprintf("%x", decodedHash)
+			return digestFunc, expectedChecksum, nil
+		}
+	}
+	return repb.DigestFunction_UNKNOWN, "", nil
+}
+
 func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest) (*rapb.FetchBlobResponse, error) {
 	ctx, err := prefix.AttachUserPrefixToContext(ctx, p.env)
 	if err != nil {
@@ -110,28 +130,27 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
 	if storageFunc == repb.DigestFunction_UNKNOWN {
 		storageFunc = repb.DigestFunction_SHA256
 	}
+	header := make(http.Header)
 	var checksumFunc repb.DigestFunction_Value
 	var expectedChecksum string
 	for _, qualifier := range req.GetQualifiers() {
-		var prefix string
 		if qualifier.GetName() == ChecksumQualifier {
-			for _, digestFunc := range digest.SupportedDigestFunctions() {
-				pr := fmt.Sprintf("%s-", strings.ToLower(repb.DigestFunction_Value_name[int32(digestFunc)]))
-				if strings.HasPrefix(qualifier.GetValue(), pr) {
-					checksumFunc = digestFunc
-					prefix = pr
-					break
-				}
-			}
-		}
-		if prefix != "" {
-			b64hash := strings.TrimPrefix(qualifier.GetValue(), prefix)
-			decodedHash, err := base64.StdEncoding.DecodeString(b64hash)
+			checksumFunc, expectedChecksum, err = parseChecksumQualifier(qualifier)
 			if err != nil {
-				return nil, status.FailedPreconditionErrorf("Error decoding qualifier %q: %s", qualifier.GetName(), err.Error())
+				return nil, err
 			}
-			expectedChecksum = fmt.Sprintf("%x", decodedHash)
-			break
+			continue
 		}
+		if strings.HasPrefix(qualifier.GetName(), BazelHttpHeaderPrefixQualifier) {
+			header.Add(
+				strings.TrimPrefix(qualifier.GetName(), BazelHttpHeaderPrefixQualifier),
+				qualifier.GetValue(),
+			)
+			continue
+		}
+		if qualifier.GetName() == BazelCanonicalIDQualifier {
+			// TODO: Implement canonical ID handling.
+			continue
+		}
 	}
 	if len(expectedChecksum) != 0 {
@@ -161,7 +180,17 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
 		if err != nil {
 			return nil, status.InvalidArgumentErrorf("unparsable URI: %q", uri)
 		}
-		blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), httpClient, uri, storageFunc, checksumFunc, expectedChecksum)
+		blobDigest, err := mirrorToCache(
+			ctx,
+			p.env.GetByteStreamClient(),
+			req.GetInstanceName(),
+			httpClient,
+			uri,
+			header,
+			storageFunc,
+			checksumFunc,
+			expectedChecksum,
+		)
 		if err != nil {
 			lastFetchErr = err
 			log.CtxWarningf(ctx, "Failed to mirror %q to cache: %s", uri, err)
@@ -266,9 +295,24 @@ func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string,
 // returning the digest. The fetched contents are checked against the given
 // expectedChecksum (if non-empty), and if there is a mismatch then an error is
 // returned.
-func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, httpClient *http.Client, uri string, storageFunc repb.DigestFunction_Value, checksumFunc repb.DigestFunction_Value, expectedChecksum string) (*repb.Digest, error) {
+func mirrorToCache(
+	ctx context.Context,
+	bsClient bspb.ByteStreamClient,
+	remoteInstanceName string,
+	httpClient *http.Client,
+	uri string,
+	header http.Header,
+	storageFunc repb.DigestFunction_Value,
+	checksumFunc repb.DigestFunction_Value,
+	expectedChecksum string,
+) (*repb.Digest, error) {
 	log.CtxDebugf(ctx, "Fetching %s", uri)
-	rsp, err := httpClient.Get(uri)
+	req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
+	if err != nil {
+		return nil, status.UnavailableErrorf("failed to fetch %q: create request failed: %s", uri, err)
+	}
+	req.Header = header
+	rsp, err := httpClient.Do(req)
 	if err != nil {
 		return nil, status.UnavailableErrorf("failed to fetch %q: HTTP GET failed: %s", uri, err)
 	}
58 changes: 58 additions & 0 deletions server/remote_asset/fetch_server/fetch_server_test.go
@@ -401,6 +401,64 @@ func TestSubsequentRequestCacheHit(t *testing.T) {
 	}
 }
 
+func TestFetchBlobWithBazelQualifiers(t *testing.T) {
+	ctx := context.Background()
+	te := testenv.GetTestEnv(t)
+	require.NoError(t, scratchspace.Init())
+	clientConn := runFetchServer(ctx, t, te)
+	fetchClient := rapb.NewFetchClient(clientConn)
+
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, r.Header.Get("hkey"), "hvalue")
+		fmt.Fprint(w, "some blob")
+	}))
+	defer ts.Close()
+
+	request := &rapb.FetchBlobRequest{
+		Uris: []string{ts.URL},
+		Qualifiers: []*rapb.Qualifier{
+			{
+				Name: fetch_server.BazelCanonicalIDQualifier,
+				Value: "some-bazel-id",
+			},
+			{
+				Name: fetch_server.BazelHttpHeaderPrefixQualifier + "hkey",
+				Value: "hvalue",
+			},
+		},
+	}
+	resp, err := fetchClient.FetchBlob(ctx, request)
+	assert.NoError(t, err)
+	assert.NotNil(t, resp)
+}
+
+func TestFetchBlobWithUnknownQualifiers(t *testing.T) {
+	ctx := context.Background()
+	te := testenv.GetTestEnv(t)
+	require.NoError(t, scratchspace.Init())
+	clientConn := runFetchServer(ctx, t, te)
+	fetchClient := rapb.NewFetchClient(clientConn)
+
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		fmt.Fprint(w, "some blob")
+	}))
+	defer ts.Close()
+
+	request := &rapb.FetchBlobRequest{
+		Uris: []string{ts.URL},
+		Qualifiers: []*rapb.Qualifier{
+			{
+				Name: "unknown-qualifier",
+				Value: "some-value",
+			},
+		},
+	}
+	resp, err := fetchClient.FetchBlob(ctx, request)
+	// TODO: Return an error when an unknown qualifier is used
+	assert.NoError(t, err)
+	assert.NotNil(t, resp)
+}
+
 func TestFetchDirectory(t *testing.T) {
 	ctx := context.Background()
 	te := testenv.GetTestEnv(t)