Skip to content

Commit

Permalink
fetch_server: add support for bazel qualifiers (#7412)
Browse files Browse the repository at this point in the history
  • Loading branch information
sluongng authored Sep 13, 2024
1 parent 3b1bc06 commit 49939ef
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 21 deletions.
86 changes: 65 additions & 21 deletions server/remote_asset/fetch_server/fetch_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ import (
)

const (
ChecksumQualifier = "checksum.sri"
maxHTTPTimeout = 60 * time.Minute
ChecksumQualifier = "checksum.sri"
BazelCanonicalIDQualifier = "bazel.canonical_id"
BazelHttpHeaderPrefixQualifier = "http_header:"
maxHTTPTimeout = 60 * time.Minute
)

type FetchServer struct {
Expand Down Expand Up @@ -100,6 +102,24 @@ func timeoutHTTPClient(ctx context.Context, protoTimeout *durationpb.Duration) *
}
}

// parseChecksumQualifier returns a digest function and digest hash
// given a "checksum.sri" qualifier.
func parseChecksumQualifier(qualifier *rapb.Qualifier) (repb.DigestFunction_Value, string, error) {
for _, digestFunc := range digest.SupportedDigestFunctions() {
pr := fmt.Sprintf("%s-", strings.ToLower(repb.DigestFunction_Value_name[int32(digestFunc)]))
if strings.HasPrefix(qualifier.GetValue(), pr) {
b64hash := strings.TrimPrefix(qualifier.GetValue(), pr)
decodedHash, err := base64.StdEncoding.DecodeString(b64hash)
if err != nil {
return repb.DigestFunction_UNKNOWN, "", status.FailedPreconditionErrorf("Error decoding qualifier %q: %s", qualifier.GetName(), err.Error())
}
expectedChecksum := fmt.Sprintf("%x", decodedHash)
return digestFunc, expectedChecksum, nil
}
}
return repb.DigestFunction_UNKNOWN, "", nil
}

func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest) (*rapb.FetchBlobResponse, error) {
ctx, err := prefix.AttachUserPrefixToContext(ctx, p.env)
if err != nil {
Expand All @@ -110,28 +130,27 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
if storageFunc == repb.DigestFunction_UNKNOWN {
storageFunc = repb.DigestFunction_SHA256
}
header := make(http.Header)
var checksumFunc repb.DigestFunction_Value
var expectedChecksum string
for _, qualifier := range req.GetQualifiers() {
var prefix string
if qualifier.GetName() == ChecksumQualifier {
for _, digestFunc := range digest.SupportedDigestFunctions() {
pr := fmt.Sprintf("%s-", strings.ToLower(repb.DigestFunction_Value_name[int32(digestFunc)]))
if strings.HasPrefix(qualifier.GetValue(), pr) {
checksumFunc = digestFunc
prefix = pr
break
}
}
}
if prefix != "" {
b64hash := strings.TrimPrefix(qualifier.GetValue(), prefix)
decodedHash, err := base64.StdEncoding.DecodeString(b64hash)
checksumFunc, expectedChecksum, err = parseChecksumQualifier(qualifier)
if err != nil {
return nil, status.FailedPreconditionErrorf("Error decoding qualifier %q: %s", qualifier.GetName(), err.Error())
return nil, err
}
expectedChecksum = fmt.Sprintf("%x", decodedHash)
break
continue
}
if strings.HasPrefix(qualifier.GetName(), BazelHttpHeaderPrefixQualifier) {
header.Add(
strings.TrimPrefix(qualifier.GetName(), BazelHttpHeaderPrefixQualifier),
qualifier.GetValue(),
)
continue
}
if qualifier.GetName() == BazelCanonicalIDQualifier {
// TODO: Implement canonical ID handling.
continue
}
}
if len(expectedChecksum) != 0 {
Expand Down Expand Up @@ -161,7 +180,17 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
if err != nil {
return nil, status.InvalidArgumentErrorf("unparsable URI: %q", uri)
}
blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), httpClient, uri, storageFunc, checksumFunc, expectedChecksum)
blobDigest, err := mirrorToCache(
ctx,
p.env.GetByteStreamClient(),
req.GetInstanceName(),
httpClient,
uri,
header,
storageFunc,
checksumFunc,
expectedChecksum,
)
if err != nil {
lastFetchErr = err
log.CtxWarningf(ctx, "Failed to mirror %q to cache: %s", uri, err)
Expand Down Expand Up @@ -266,9 +295,24 @@ func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string,
// returning the digest. The fetched contents are checked against the given
// expectedChecksum (if non-empty), and if there is a mismatch then an error is
// returned.
func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, httpClient *http.Client, uri string, storageFunc repb.DigestFunction_Value, checksumFunc repb.DigestFunction_Value, expectedChecksum string) (*repb.Digest, error) {
func mirrorToCache(
ctx context.Context,
bsClient bspb.ByteStreamClient,
remoteInstanceName string,
httpClient *http.Client,
uri string,
header http.Header,
storageFunc repb.DigestFunction_Value,
checksumFunc repb.DigestFunction_Value,
expectedChecksum string,
) (*repb.Digest, error) {
log.CtxDebugf(ctx, "Fetching %s", uri)
rsp, err := httpClient.Get(uri)
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
if err != nil {
return nil, status.UnavailableErrorf("failed to fetch %q: create request failed: %s", uri, err)
}
req.Header = header
rsp, err := httpClient.Do(req)
if err != nil {
return nil, status.UnavailableErrorf("failed to fetch %q: HTTP GET failed: %s", uri, err)
}
Expand Down
58 changes: 58 additions & 0 deletions server/remote_asset/fetch_server/fetch_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,64 @@ func TestSubsequentRequestCacheHit(t *testing.T) {
}
}

func TestFetchBlobWithBazelQualifiers(t *testing.T) {
ctx := context.Background()
te := testenv.GetTestEnv(t)
require.NoError(t, scratchspace.Init())
clientConn := runFetchServer(ctx, t, te)
fetchClient := rapb.NewFetchClient(clientConn)

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("hkey"), "hvalue")
fmt.Fprint(w, "some blob")
}))
defer ts.Close()

request := &rapb.FetchBlobRequest{
Uris: []string{ts.URL},
Qualifiers: []*rapb.Qualifier{
{
Name: fetch_server.BazelCanonicalIDQualifier,
Value: "some-bazel-id",
},
{
Name: fetch_server.BazelHttpHeaderPrefixQualifier + "hkey",
Value: "hvalue",
},
},
}
resp, err := fetchClient.FetchBlob(ctx, request)
assert.NoError(t, err)
assert.NotNil(t, resp)
}

func TestFetchBlobWithUnknownQualifiers(t *testing.T) {
ctx := context.Background()
te := testenv.GetTestEnv(t)
require.NoError(t, scratchspace.Init())
clientConn := runFetchServer(ctx, t, te)
fetchClient := rapb.NewFetchClient(clientConn)

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "some blob")
}))
defer ts.Close()

request := &rapb.FetchBlobRequest{
Uris: []string{ts.URL},
Qualifiers: []*rapb.Qualifier{
{
Name: "unknown-qualifier",
Value: "some-value",
},
},
}
resp, err := fetchClient.FetchBlob(ctx, request)
// TODO: Return an error when an unknown qualifier is used
assert.NoError(t, err)
assert.NotNil(t, resp)
}

func TestFetchDirectory(t *testing.T) {
ctx := context.Background()
te := testenv.GetTestEnv(t)
Expand Down

0 comments on commit 49939ef

Please sign in to comment.