Skip to content

Commit

Permalink
Add compressed writes to cas.go.
Browse files Browse the repository at this point in the history
This follows the current tentative API being worked on in
bazelbuild/remote-apis#168. While there's technically room for it to
change, it has reached a somewhat stable point worth implementing.
  • Loading branch information
rubensf committed Nov 23, 2020
1 parent 25688f6 commit d17dd1b
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 124 deletions.
11 changes: 5 additions & 6 deletions go/pkg/chunker/chunker.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ var IOBufferSize = 10 * 1024 * 1024
// ErrEOF is returned when Next is called when HasNext is false.
var ErrEOF = errors.New("ErrEOF")

// Compressor for full blobs. It is *only* thread-safe for EncodeAll calls and
// should not be used for streamed compression.
var fullCompressor, _ = zstd.NewWriter(nil)
// Compressor for full blobs
// It is *only* thread-safe for EncodeAll calls and should not be used for streamed compression.
// While we avoid sending 0 len blobs, we do want to create zero len compressed blobs if
// necessary.
var fullCompressor, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))

type UploadEntry struct {
digest digest.Digest
Expand Down Expand Up @@ -81,9 +83,6 @@ type Chunker struct {
}

func New(ue *UploadEntry, compressed bool, chunkSize int) (*Chunker, error) {
if compressed {
return nil, errors.New("compression is not supported yet")
}
if chunkSize < 1 {
chunkSize = DefaultChunkSize
}
Expand Down
21 changes: 21 additions & 0 deletions go/pkg/client/cas.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ import (
log "github.com/golang/glog"
)

// DefaultCompressedBytestreamThreshold is the default threshold for transferring blobs compressed on ByteStream.Write RPCs.
const DefaultCompressedBytestreamThreshold = 1024

const logInterval = 25

type requestMetadata struct {
Expand Down Expand Up @@ -858,6 +861,13 @@ func (c *Client) ResourceNameWrite(hash string, sizeBytes int64) string {
return fmt.Sprintf("%s/uploads/%s/blobs/%s/%d", c.InstanceName, uuid.New(), hash, sizeBytes)
}

// ResourceNameCompressedWrite generates a valid write resource name.
// TODO(rubensf): Converge compressor to proto in https://github.com/bazelbuild/remote-apis/pull/168 once
// that gets merged in.
func (c *Client) ResourceNameCompressedWrite(hash string, sizeBytes int64) string {
return fmt.Sprintf("%s/uploads/%s/compressed-blobs/zstd/%s/%d", c.InstanceName, uuid.New(), hash, sizeBytes)
}

// GetDirectoryTree returns the entire directory tree rooted at the given digest (which must target
// a Directory stored in the CAS).
func (c *Client) GetDirectoryTree(ctx context.Context, d *repb.Digest) (result []*repb.Directory, err error) {
Expand Down Expand Up @@ -1377,3 +1387,14 @@ func (c *Client) DownloadFiles(ctx context.Context, execRoot string, outputs map
}
return nil
}

func (c *Client) shouldCompress(sizeBytes int64) bool {
return int64(c.CompressedBytestreamThreshold) >= 0 && int64(c.CompressedBytestreamThreshold) <= sizeBytes
}

func (c *Client) writeRscName(dg digest.Digest) string {
if c.shouldCompress(dg.Size) {
return c.ResourceNameCompressedWrite(dg.Hash, dg.Size)
}
return c.ResourceNameWrite(dg.Hash, dg.Size)
}
152 changes: 76 additions & 76 deletions go/pkg/client/cas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,23 +286,26 @@ func TestWrite(t *testing.T) {
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
gotDg, err := c.WriteBlob(ctx, tc.blob)
if err != nil {
t.Errorf("c.WriteBlob(ctx, blob) gave error %s, wanted nil", err)
}
if fake.Err != nil {
t.Errorf("c.WriteBlob(ctx, blob) caused the server to return error %s (possibly unseen by c)", fake.Err)
}
if !bytes.Equal(tc.blob, fake.Buf) {
t.Errorf("c.WriteBlob(ctx, blob) had diff on blobs, want %v, got %v:", tc.blob, fake.Buf)
}
dg := digest.NewFromBlob(tc.blob)
if dg != gotDg {
t.Errorf("c.WriteBlob(ctx, blob) had diff on digest returned (want %s, got %s)", dg, gotDg)
}
})
for _, cmp := range []client.CompressedBytestreamThreshold{-1, 0} {
for _, tc := range tests {
t.Run(fmt.Sprintf("%s - CompressionThresh:%d", tc.name, cmp), func(t *testing.T) {
cmp.Apply(c)
gotDg, err := c.WriteBlob(ctx, tc.blob)
if err != nil {
t.Errorf("c.WriteBlob(ctx, blob) gave error %s, wanted nil", err)
}
if fake.Err != nil {
t.Errorf("c.WriteBlob(ctx, blob) caused the server to return error %s (possibly unseen by c)", fake.Err)
}
if !bytes.Equal(tc.blob, fake.Buf) {
t.Errorf("c.WriteBlob(ctx, blob) had diff on blobs, want %v, got %v:", tc.blob, fake.Buf)
}
dg := digest.NewFromBlob(tc.blob)
if dg != gotDg {
t.Errorf("c.WriteBlob(ctx, blob) had diff on digest returned (want %s, got %s)", dg, gotDg)
}
})
}
}
}

Expand Down Expand Up @@ -711,74 +714,71 @@ func TestUpload(t *testing.T) {

for _, ub := range []client.UseBatchOps{false, true} {
for _, uo := range []client.UnifiedCASOps{false, true} {
t.Run(fmt.Sprintf("UsingBatch:%t,UnifiedCASOps:%t", ub, uo), func(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
e, cleanup := fakes.NewTestEnv(t)
defer cleanup()
fake := e.Server.CAS
c := e.Client.GrpcClient
client.CASConcurrency(defaultCASConcurrency).Apply(c)
ub.Apply(c)
uo.Apply(c)
if tc.concurrency > 0 {
tc.concurrency.Apply(c)
}

present := make(map[digest.Digest]bool)
for _, blob := range tc.present {
fake.Put(blob)
present[digest.NewFromBlob(blob)] = true
}
var input []*chunker.UploadEntry
for _, blob := range tc.input {
input = append(input, chunker.EntryFromBlob(blob))
}
for _, cmp := range []client.CompressedBytestreamThreshold{-1, 0} {
t.Run(fmt.Sprintf("UsingBatch:%t,UnifiedCASOps:%t,CompressionThresh:%d", ub, uo, cmp), func(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
e, cleanup := fakes.NewTestEnv(t)
defer cleanup()
fake := e.Server.CAS
c := e.Client.GrpcClient
client.CASConcurrency(defaultCASConcurrency).Apply(c)
cmp.Apply(c)
ub.Apply(c)
uo.Apply(c)
if tc.concurrency > 0 {
tc.concurrency.Apply(c)
}

missing, err := c.UploadIfMissing(ctx, input...)
if err != nil {
t.Errorf("c.UploadIfMissing(ctx, input) gave error %v, expected nil", err)
}
present := make(map[digest.Digest]bool)
for _, blob := range tc.present {
fake.Put(blob)
present[digest.NewFromBlob(blob)] = true
}
var input []*chunker.UploadEntry
for _, blob := range tc.input {
input = append(input, chunker.EntryFromBlob(blob))
}

missingSet := make(map[digest.Digest]struct{})
for _, dg := range missing {
missingSet[dg] = struct{}{}
}
for _, ue := range input {
dg := ue.Digest()
ch, err := chunker.New(ue, false, int(c.ChunkMaxSize))
missing, err := c.UploadIfMissing(ctx, input...)
if err != nil {
t.Fatalf("chunker.New(ue): failed to create chunker from UploadEntry: %v", err)
t.Errorf("c.UploadIfMissing(ctx, input) gave error %v, expected nil", err)
}
blob, err := ch.FullData()
if err != nil {
t.Errorf("ch.FullData() returned an error: %v", err)

missingSet := make(map[digest.Digest]struct{})
for _, dg := range missing {
missingSet[dg] = struct{}{}
}
if present[dg] {
if fake.BlobWrites(dg) > 0 {
t.Errorf("blob %v with digest %s was uploaded even though it was already present in the CAS", blob, dg)

for i, ue := range input {
dg := ue.Digest()
blob := tc.input[i]
if present[dg] {
if fake.BlobWrites(dg) > 0 {
t.Errorf("blob %v with digest %s was uploaded even though it was already present in the CAS", blob, dg)
}
if _, ok := missingSet[dg]; ok {
t.Errorf("Stats said that blob %v with digest %s was missing in the CAS", blob, dg)
}
continue
}
if _, ok := missingSet[dg]; ok {
t.Errorf("Stats said that blob %v with digest %s was missing in the CAS", blob, dg)
if gotBlob, ok := fake.Get(dg); !ok {
t.Errorf("blob %v with digest %s was not uploaded, expected it to be present in the CAS", blob, dg)
} else if !bytes.Equal(blob, gotBlob) {
t.Errorf("blob digest %s had diff on uploaded blob: want %v, got %v", dg, blob, gotBlob)
}
if _, ok := missingSet[dg]; !ok {
t.Errorf("Stats said that blob %v with digest %s was present in the CAS", blob, dg)
}
continue
}
if gotBlob, ok := fake.Get(dg); !ok {
t.Errorf("blob %v with digest %s was not uploaded, expected it to be present in the CAS", blob, dg)
} else if !bytes.Equal(blob, gotBlob) {
t.Errorf("blob digest %s had diff on uploaded blob: want %v, got %v", dg, blob, gotBlob)
}
if _, ok := missingSet[dg]; !ok {
t.Errorf("Stats said that blob %v with digest %s was present in the CAS", blob, dg)
if fake.MaxConcurrency() > defaultCASConcurrency {
t.Errorf("CAS concurrency %v was higher than max %v", fake.MaxConcurrency(), defaultCASConcurrency)
}
}
if fake.MaxConcurrency() > defaultCASConcurrency {
t.Errorf("CAS concurrency %v was higher than max %v", fake.MaxConcurrency(), defaultCASConcurrency)
}
})
}
})
})
}
})
}
}
}
}
Expand Down
57 changes: 35 additions & 22 deletions go/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ type Client struct {
StartupCapabilities StartupCapabilities
// ChunkMaxSize is maximum chunk size to use for CAS uploads/downloads.
ChunkMaxSize ChunkMaxSize
// CompressedBytestreamThreshold is the threshold in bytes for which blobs are read and written
// compressed. Use 0 for all writes being compressed, and a negative number for all writes being
// uncompressed. TODO(rubensf): Make sure this will throw an error if the server doesn't support compression,
// pending https://github.com/bazelbuild/remote-apis/pull/168 being submitted.
CompressedBytestreamThreshold CompressedBytestreamThreshold
// MaxBatchDigests is maximum amount of digests to batch in batched operations.
MaxBatchDigests MaxBatchDigests
// MaxBatchSize is maximum size in bytes of a batch request for batch operations.
Expand Down Expand Up @@ -145,6 +150,13 @@ func (s ChunkMaxSize) Apply(c *Client) {
c.ChunkMaxSize = s
}

type CompressedBytestreamThreshold int64

// Apply sets the client's maximal chunk size s.
func (s CompressedBytestreamThreshold) Apply(c *Client) {
c.CompressedBytestreamThreshold = s
}

// UtilizeLocality is to specify whether client downloads files utilizing disk access locality.
type UtilizeLocality bool

Expand Down Expand Up @@ -460,28 +472,29 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts
return nil, err
}
client := &Client{
InstanceName: instanceName,
actionCache: regrpc.NewActionCacheClient(casConn),
byteStream: bsgrpc.NewByteStreamClient(casConn),
cas: regrpc.NewContentAddressableStorageClient(casConn),
execution: regrpc.NewExecutionClient(conn),
operations: opgrpc.NewOperationsClient(conn),
rpcTimeouts: DefaultRPCTimeouts,
Connection: conn,
CASConnection: casConn,
ChunkMaxSize: chunker.DefaultChunkSize,
MaxBatchDigests: DefaultMaxBatchDigests,
MaxBatchSize: DefaultMaxBatchSize,
DirMode: DefaultDirMode,
ExecutableMode: DefaultExecutableMode,
RegularMode: DefaultRegularMode,
useBatchOps: true,
StartupCapabilities: true,
casConcurrency: DefaultCASConcurrency,
casUploaders: semaphore.NewWeighted(DefaultCASConcurrency),
casDownloaders: semaphore.NewWeighted(DefaultCASConcurrency),
casUploads: make(map[digest.Digest]*uploadState),
Retrier: RetryTransient(),
InstanceName: instanceName,
actionCache: regrpc.NewActionCacheClient(casConn),
byteStream: bsgrpc.NewByteStreamClient(casConn),
cas: regrpc.NewContentAddressableStorageClient(casConn),
execution: regrpc.NewExecutionClient(conn),
operations: opgrpc.NewOperationsClient(conn),
rpcTimeouts: DefaultRPCTimeouts,
Connection: conn,
CASConnection: casConn,
CompressedBytestreamThreshold: DefaultCompressedBytestreamThreshold,
ChunkMaxSize: chunker.DefaultChunkSize,
MaxBatchDigests: DefaultMaxBatchDigests,
MaxBatchSize: DefaultMaxBatchSize,
DirMode: DefaultDirMode,
ExecutableMode: DefaultExecutableMode,
RegularMode: DefaultRegularMode,
useBatchOps: true,
StartupCapabilities: true,
casConcurrency: DefaultCASConcurrency,
casUploaders: semaphore.NewWeighted(DefaultCASConcurrency),
casDownloaders: semaphore.NewWeighted(DefaultCASConcurrency),
casUploads: make(map[digest.Digest]*uploadState),
Retrier: RetryTransient(),
}
for _, o := range opts {
o.Apply(client)
Expand Down
1 change: 1 addition & 0 deletions go/pkg/fakes/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ go_library(
"@com_github_golang_glog//:go_default_library",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_golang_protobuf//ptypes:go_default_library_gen",
"@com_github_klauspost_compress//zstd:go_default_library",
"@com_github_pborman_uuid//:go_default_library",
"@go_googleapis//google/bytestream:bytestream_go_proto",
"@go_googleapis//google/longrunning:longrunning_go_proto",
Expand Down
Loading

0 comments on commit d17dd1b

Please sign in to comment.