diff --git a/storage/grpc_client.go b/storage/grpc_client.go index d47c41c1908f..eb327a3eeb48 100644 --- a/storage/grpc_client.go +++ b/storage/grpc_client.go @@ -16,6 +16,7 @@ package storage import ( "context" + "encoding/binary" "errors" "fmt" "hash/crc32" @@ -36,6 +37,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/encoding" + "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protowire" @@ -956,37 +958,48 @@ func (c *grpcStorageClient) RewriteObject(ctx context.Context, req *rewriteObjec return r, nil } -// bytesCodec is a grpc codec which permits receiving messages as either -// protobuf messages, or as raw []bytes. -type bytesCodec struct { - encoding.Codec +// Custom codec to be used for unmarshaling ReadObjectResponse messages. +// This is used to avoid a copy of object data in proto.Unmarshal. +type bytesCodecV2 struct { } -func (bytesCodec) Marshal(v any) ([]byte, error) { +var _ encoding.CodecV2 = bytesCodecV2{} + +// Marshal is used to encode messages to send for bytesCodecV2. Since we are only +// using this to send ReadObjectRequest messages we don't need to recycle buffers +// here. +func (bytesCodecV2) Marshal(v any) (mem.BufferSlice, error) { vv, ok := v.(proto.Message) if !ok { return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) } - return proto.Marshal(vv) + var data mem.BufferSlice + buf, err := proto.Marshal(vv) + if err != nil { + return nil, err + } + data = append(data, mem.SliceBuffer(buf)) + return data, nil } -func (bytesCodec) Unmarshal(data []byte, v any) error { +// Unmarshal is used for data received for ReadObjectResponse. We want to preserve +// the mem.BufferSlice in most cases rather than copying and calling proto.Unmarshal. +func (bytesCodecV2) Unmarshal(data mem.BufferSlice, v any) error { switch v := v.(type) { - case *[]byte: - // If gRPC could recycle the data []byte after unmarshaling (through - // buffer pools), we would need to make a copy here. + case *mem.BufferSlice: *v = data + // Pick up a reference to the data so that it is not freed while decoding. + data.Ref() return nil case proto.Message: - return proto.Unmarshal(data, v) + buf := data.MaterializeToBuffer(mem.DefaultBufferPool()) + return proto.Unmarshal(buf.ReadOnlyData(), v) default: - return fmt.Errorf("can not unmarshal type %T", v) + return fmt.Errorf("cannot unmarshal type %T, want proto.Message or mem.BufferSlice", v) } } -func (bytesCodec) Name() string { - // If this isn't "", then gRPC sets the content-subtype of the call to this - // value and we get errors. +func (bytesCodecV2) Name() string { return "" } @@ -997,7 +1010,7 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange s := callSettings(c.settings, opts...) s.gax = append(s.gax, gax.WithGRPCOptions( - grpc.ForceCodec(bytesCodec{}), + grpc.ForceCodecV2(bytesCodecV2{}), )) if s.userProject != "" { @@ -1015,8 +1028,6 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange req.Generation = params.gen } - var databuf []byte - // Define a function that initiates a Read with offset and length, assuming // we have already read seen bytes. reopen := func(seen int64) (*readStreamResponse, context.CancelFunc, error) { @@ -1042,8 +1053,8 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange } var stream storagepb.Storage_ReadObjectClient - var msg *storagepb.ReadObjectResponse var err error + var decoder *readResponseDecoder err = run(cc, func(ctx context.Context) error { stream, err = c.raw.ReadObject(ctx, req, s.gax...) @@ -1053,7 +1064,8 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange // Receive the message into databuf as a wire-encoded message so we can // use a custom decoder to avoid an extra copy at the protobuf layer. - err := stream.RecvMsg(&databuf) + databufs := mem.BufferSlice{} + err := stream.RecvMsg(&databufs) // These types of errors show up on the Recv call, rather than the // initialization of the stream via ReadObject above. if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { @@ -1063,22 +1075,26 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange return err } // Use a custom decoder that uses protobuf unmarshalling for all - // fields except the checksummed data. - // Subsequent receives in Read calls will skip all protobuf - // unmarshalling and directly read the content from the gRPC []byte - // response, since only the first call will contain other fields. - msg, err = readFullObjectResponse(databuf) - + // fields except the object data. Object data is handled separately + // to avoid a copy. + decoder = &readResponseDecoder{ + databufs: databufs, + } + err = decoder.readFullObjectResponse() return err }, s.retry, s.idempotent) if err != nil { // Close the stream context we just created to ensure we don't leak // resources. cancel() + // Free any buffers. + if decoder != nil && decoder.databufs != nil { + decoder.databufs.Free() + } return nil, nil, err } - return &readStreamResponse{stream, msg}, cancel, nil + return &readStreamResponse{stream, decoder}, cancel, nil } res, cancel, err := reopen(0) @@ -1088,7 +1104,7 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange // The first message was Recv'd on stream open, use it to populate the // object metadata. - msg := res.response + msg := res.decoder.msg obj := msg.GetMetadata() // This is the size of the entire object, even if only a range was requested. size := obj.GetSize() @@ -1121,12 +1137,10 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange reopen: reopen, cancel: cancel, size: size, - // Store the content from the first Recv in the - // client buffer for reading later. - leftovers: msg.GetChecksummedData().GetContent(), + // Preserve the decoder to read out object data when Read/WriteTo is called. + currMsg: res.decoder, settings: s, zeroRange: params.length == 0, - databuf: databuf, wantCRC: wantCRC, checkCRC: checkCRC, }, @@ -1352,8 +1366,8 @@ func setUserProjectMetadata(ctx context.Context, project string) context.Context } type readStreamResponse struct { - stream storagepb.Storage_ReadObjectClient - response *storagepb.ReadObjectResponse + stream storagepb.Storage_ReadObjectClient + decoder *readResponseDecoder } type gRPCReader struct { @@ -1362,7 +1376,7 @@ type gRPCReader struct { stream storagepb.Storage_ReadObjectClient reopen func(seen int64) (*readStreamResponse, context.CancelFunc, error) leftovers []byte - databuf []byte + currMsg *readResponseDecoder // decoder for the current message cancel context.CancelFunc settings *settings checkCRC bool // should we check the CRC? @@ -1405,18 +1419,21 @@ func (r *gRPCReader) Read(p []byte) (int, error) { } var n int - // Read leftovers and return what was available to conform to the Reader + + // If there is data remaining in the current message, return what was + // available to conform to the Reader // interface: https://pkg.go.dev/io#Reader. - if len(r.leftovers) > 0 { - n = copy(p, r.leftovers) + if !r.currMsg.done { + n = r.currMsg.readAndUpdateCRC(p, func(b []byte) { + r.updateCRC(b) + }) r.seen += int64(n) - r.updateCRC(p[:n]) - r.leftovers = r.leftovers[n:] return n, nil } // Attempt to Recv the next message on the stream. - content, err := r.recv() + // This will update r.currMsg with the decoder for the new message. + err := r.recv() if err != nil { return 0, err } @@ -1428,16 +1445,11 @@ func (r *gRPCReader) Read(p []byte) (int, error) { // present in the response here. // TODO: Figure out if we need to support decompressive transcoding // https://cloud.google.com/storage/docs/transcoding. - n = copy(p[n:], content) - leftover := len(content) - n - if leftover > 0 { - // Wasn't able to copy all of the data in the message, store for - // future Read calls. - r.leftovers = content[n:] - } - r.seen += int64(n) - r.updateCRC(p[:n]) + n = r.currMsg.readAndUpdateCRC(p, func(b []byte) { + r.updateCRC(b) + }) + r.seen += int64(n) return n, nil } @@ -1464,14 +1476,14 @@ func (r *gRPCReader) WriteTo(w io.Writer) (int64, error) { // Track bytes written during before call. var alreadySeen = r.seen - // Write any leftovers to the stream. There will be some leftovers from the + // Write any already received message to the stream. There will be some leftovers from the // original NewRangeReader call. - if len(r.leftovers) > 0 { - // Write() will write the entire leftovers slice unless there is an error. - written, err := w.Write(r.leftovers) + if r.currMsg != nil && !r.currMsg.done { + written, err := r.currMsg.writeToAndUpdateCRC(w, func(b []byte) { + r.updateCRC(b) + }) r.seen += int64(written) - r.updateCRC(r.leftovers) - r.leftovers = nil + r.currMsg = nil if err != nil { return r.seen - alreadySeen, err } @@ -1482,7 +1494,7 @@ func (r *gRPCReader) WriteTo(w io.Writer) (int64, error) { // Attempt to receive the next message on the stream. // Will terminate with io.EOF once data has all come through. // recv() handles stream reopening and retry logic so no need for retries here. - msg, err := r.recv() + err := r.recv() if err != nil { if err == io.EOF { // We are done; check the checksum if necessary and return. @@ -1498,9 +1510,10 @@ func (r *gRPCReader) WriteTo(w io.Writer) (int64, error) { // present in the response here. // TODO: Figure out if we need to support decompressive transcoding // https://cloud.google.com/storage/docs/transcoding. - written, err := w.Write(msg) + written, err := r.currMsg.writeToAndUpdateCRC(w, func(b []byte) { + r.updateCRC(b) + }) r.seen += int64(written) - r.updateCRC(msg) if err != nil { return r.seen - alreadySeen, err } @@ -1509,12 +1522,13 @@ func (r *gRPCReader) WriteTo(w io.Writer) (int64, error) { } // Close cancels the read stream's context in order for it to be closed and -// collected. +// collected, and frees any currently in use buffers. func (r *gRPCReader) Close() error { if r.cancel != nil { r.cancel() } r.stream = nil + r.currMsg = nil return nil } @@ -1529,8 +1543,9 @@ func (r *gRPCReader) Close() error { // // The last error received is the one that is returned, which could be from // an attempt to reopen the stream. -func (r *gRPCReader) recv() ([]byte, error) { - err := r.stream.RecvMsg(&r.databuf) +func (r *gRPCReader) recv() error { + databufs := mem.BufferSlice{} + err := r.stream.RecvMsg(&databufs) var shouldRetry = ShouldRetry if r.settings.retry != nil && r.settings.retry.shouldRetry != nil { @@ -1540,16 +1555,16 @@ func (r *gRPCReader) recv() ([]byte, error) { // This will "close" the existing stream and immediately attempt to // reopen the stream, but will backoff if further attempts are necessary. // Reopening the stream Recvs the first message, so if retrying is - // successful, the next logical chunk will be returned. - msg, err := r.reopenStream() - return msg.GetChecksummedData().GetContent(), err + // successful, r.currMsg will be updated to include the new data. + return r.reopenStream() } if err != nil { - return nil, err + return err } - return readObjectResponseContent(r.databuf) + r.currMsg = &readResponseDecoder{databufs: databufs} + return r.currMsg.readFullObjectResponse() } // ReadObjectResponse field and subfield numbers. @@ -1562,21 +1577,297 @@ const ( metadataField = protowire.Number(4) ) -// readObjectResponseContent returns the checksummed_data.content field of a -// ReadObjectResponse message, or an error if the message is invalid. -// This can be used on recvs of objects after the first recv, since only the -// first message will contain non-data fields. -func readObjectResponseContent(b []byte) ([]byte, error) { - checksummedData, err := readProtoBytes(b, checksummedDataField) +// readResponseDecoder is a wrapper on the raw message, used to decode one message +// without copying object data. It also has methods to write out the resulting object +// data to the user application. +type readResponseDecoder struct { + databufs mem.BufferSlice // raw bytes of the message being processed + // Decoding offsets + off uint64 // offset in the messsage relative to the data as a whole + currBuf int // index of the current buffer being processed + currOff uint64 // offset in the current buffer + // Processed data + msg *storagepb.ReadObjectResponse // processed response message with all fields other than object data populated + dataOffsets bufferSliceOffsets // offsets of the object data in the message. + done bool // true if the data has been completely read. +} + +type bufferSliceOffsets struct { + startBuf, endBuf int // indices of start and end buffers of object data in the msg + startOff, endOff uint64 // offsets within these buffers where the data starts and ends. + currBuf int // index of current buffer being read out to the user application. + currOff uint64 // offset of read in current buffer. +} + +// peek ahead 10 bytes from the current offset in the databufs. This will return a +// slice of the current buffer if the bytes are all in one buffer, but will copy +// the bytes into a new buffer if the distance is split across buffers. Use this +// to allow protowire methods to be used to parse tags & fixed values. +// The max length of a varint tag is 10 bytes, see +// https://protobuf.dev/programming-guides/encoding/#varints . Other int types +// are shorter. +func (d *readResponseDecoder) peek() []byte { + b := d.databufs[d.currBuf].ReadOnlyData() + // Check if the tag will fit in the current buffer. If not, copy the next 10 + // bytes into a new buffer to ensure that we can read the tag correctly + // without it being divided between buffers. + tagBuf := b[d.currOff:] + remainingInBuf := len(tagBuf) + // If we have less than 10 bytes remaining and are not in the final buffer, + // copy up to 10 bytes ahead from the next buffer. + if remainingInBuf < binary.MaxVarintLen64 && d.currBuf != len(d.databufs)-1 { + tagBuf = d.copyNextBytes(10) + } + return tagBuf +} + +// Copies up to next n bytes into a new buffer, or fewer if fewer bytes remain in the +// buffers overall. Does not advance offsets. +func (d *readResponseDecoder) copyNextBytes(n int) []byte { + remaining := n + if r := d.databufs.Len() - int(d.off); r < remaining { + remaining = r + } + currBuf := d.currBuf + currOff := d.currOff + var buf []byte + for remaining > 0 { + b := d.databufs[currBuf].ReadOnlyData() + remainingInCurr := len(b[currOff:]) + if remainingInCurr < remaining { + buf = append(buf, b[currOff:]...) + remaining -= remainingInCurr + currBuf++ + currOff = 0 + } else { + buf = append(buf, b[currOff:currOff+uint64(remaining)]...) + remaining = 0 + } + } + return buf +} + +// Advance current buffer & byte offset in the decoding by n bytes. Returns an error if we +// go past the end of the data. +func (d *readResponseDecoder) advanceOffset(n uint64) error { + remaining := n + for remaining > 0 { + remainingInCurr := uint64(d.databufs[d.currBuf].Len()) - d.currOff + if remainingInCurr <= remaining { + remaining -= remainingInCurr + d.currBuf++ + d.currOff = 0 + } else { + d.currOff += remaining + remaining = 0 + } + } + // If we have advanced past the end of the buffers, something went wrong. + if (d.currBuf == len(d.databufs) && d.currOff > 0) || d.currBuf > len(d.databufs) { + return errors.New("decoding: truncated message, cannot advance offset") + } + d.off += n + return nil + +} + +// This copies object data from the message into the buffer and returns the number of +// bytes copied. The data offsets are incremented in the message. The updateCRC +// function is called on the copied bytes. +func (d *readResponseDecoder) readAndUpdateCRC(p []byte, updateCRC func([]byte)) int { + // For a completely empty message, just return 0 + if len(d.databufs) == 0 { + return 0 + } + databuf := d.databufs[d.dataOffsets.currBuf] + startOff := d.dataOffsets.currOff + var b []byte + if d.dataOffsets.currBuf == d.dataOffsets.endBuf { + b = databuf.ReadOnlyData()[startOff:d.dataOffsets.endOff] + } else { + b = databuf.ReadOnlyData()[startOff:] + } + n := copy(p, b) + updateCRC(b[:n]) + d.dataOffsets.currOff += uint64(n) + + // We've read all the data from this message. Free the underlying buffers. + if d.dataOffsets.currBuf == d.dataOffsets.endBuf && d.dataOffsets.currOff == d.dataOffsets.endOff { + d.done = true + d.databufs.Free() + } + // We are at the end of the current buffer + if d.dataOffsets.currBuf != d.dataOffsets.endBuf && d.dataOffsets.currOff == uint64(databuf.Len()) { + d.dataOffsets.currOff = 0 + d.dataOffsets.currBuf++ + } + return n +} + +func (d *readResponseDecoder) writeToAndUpdateCRC(w io.Writer, updateCRC func([]byte)) (int64, error) { + // For a completely empty message, just return 0 + if len(d.databufs) == 0 { + return 0, nil + } + var written int64 + for !d.done { + databuf := d.databufs[d.dataOffsets.currBuf] + startOff := d.dataOffsets.currOff + var b []byte + if d.dataOffsets.currBuf == d.dataOffsets.endBuf { + b = databuf.ReadOnlyData()[startOff:d.dataOffsets.endOff] + } else { + b = databuf.ReadOnlyData()[startOff:] + } + var n int + // Write all remaining data from the current buffer + n, err := w.Write(b) + written += int64(n) + updateCRC(b) + if err != nil { + return written, err + } + d.dataOffsets.currOff = 0 + // We've read all the data from this message. + if d.dataOffsets.currBuf == d.dataOffsets.endBuf { + d.done = true + d.databufs.Free() + } else { + d.dataOffsets.currBuf++ + } + } + return written, nil +} + +// Consume the next available tag in the input data and return the field number and type. +// Advances the relevant offsets in the data. +func (d *readResponseDecoder) consumeTag() (protowire.Number, protowire.Type, error) { + tagBuf := d.peek() + + // Consume the next tag. This will tell us which field is next in the + // buffer, its type, and how much space it takes up. + fieldNum, fieldType, tagLength := protowire.ConsumeTag(tagBuf) + if tagLength < 0 { + return 0, 0, protowire.ParseError(tagLength) + } + // Update the offsets and current buffer depending on the tag length. + if err := d.advanceOffset(uint64(tagLength)); err != nil { + return 0, 0, fmt.Errorf("consuming tag: %w", err) + } + return fieldNum, fieldType, nil +} + +// Consume a varint that represents the length of a bytes field. Return the length of +// the data, and advance the offsets by the length of the varint. +func (d *readResponseDecoder) consumeVarint() (uint64, error) { + tagBuf := d.peek() + + // Consume the next tag. This will tell us which field is next in the + // buffer, its type, and how much space it takes up. + dataLength, tagLength := protowire.ConsumeVarint(tagBuf) + if tagLength < 0 { + return 0, protowire.ParseError(tagLength) + } + + // Update the offsets and current buffer depending on the tag length. + d.advanceOffset(uint64(tagLength)) + return dataLength, nil +} + +func (d *readResponseDecoder) consumeFixed32() (uint32, error) { + valueBuf := d.peek() + + // Consume the next tag. This will tell us which field is next in the + // buffer, its type, and how much space it takes up. + value, tagLength := protowire.ConsumeFixed32(valueBuf) + if tagLength < 0 { + return 0, protowire.ParseError(tagLength) + } + + // Update the offsets and current buffer depending on the tag length. + d.advanceOffset(uint64(tagLength)) + return value, nil +} + +func (d *readResponseDecoder) consumeFixed64() (uint64, error) { + valueBuf := d.peek() + + // Consume the next tag. This will tell us which field is next in the + // buffer, its type, and how much space it takes up. + value, tagLength := protowire.ConsumeFixed64(valueBuf) + if tagLength < 0 { + return 0, protowire.ParseError(tagLength) + } + + // Update the offsets and current buffer depending on the tag length. + d.advanceOffset(uint64(tagLength)) + return value, nil +} + +// Consume any field values up to the end offset provided and don't return anything. +// This is used to skip any values which are not going to be used. +// msgEndOff is indexed in terms of the overall data across all buffers. +func (d *readResponseDecoder) consumeFieldValue(fieldNum protowire.Number, fieldType protowire.Type) error { + // reimplement protowire.ConsumeFieldValue without the extra case for groups (which + // are are complicted and not a thing in proto3). + var err error + switch fieldType { + case protowire.VarintType: + _, err = d.consumeVarint() + case protowire.Fixed32Type: + _, err = d.consumeFixed32() + case protowire.Fixed64Type: + _, err = d.consumeFixed64() + case protowire.BytesType: + _, err = d.consumeBytes() + default: + return fmt.Errorf("unknown field type %v in field %v", fieldType, fieldNum) + } if err != nil { - return b, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", err) + return fmt.Errorf("consuming field %v of type %v: %w", fieldNum, fieldType, err) } - content, err := readProtoBytes(checksummedData, checksummedDataContentField) + + return nil +} + +// Consume a bytes field from the input. Returns offsets for the data in the buffer slices +// and an error. +func (d *readResponseDecoder) consumeBytes() (bufferSliceOffsets, error) { + // m is the length of the data past the tag. + m, err := d.consumeVarint() if err != nil { - return content, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", err) + return bufferSliceOffsets{}, fmt.Errorf("consuming bytes field: %w", err) + } + offsets := bufferSliceOffsets{ + startBuf: d.currBuf, + startOff: d.currOff, + currBuf: d.currBuf, + currOff: d.currOff, } - return content, nil + // Advance offsets to lengths of bytes field and capture where we end. + d.advanceOffset(m) + offsets.endBuf = d.currBuf + offsets.endOff = d.currOff + return offsets, nil +} + +// Consume a bytes field from the input and copy into a new buffer if +// necessary (if the data is split across buffers in databuf). This can be +// used to leverage proto.Unmarshal for small bytes fields (i.e. anything +// except object data). +func (d *readResponseDecoder) consumeBytesCopy() ([]byte, error) { + // m is the length of the bytes data. + m, err := d.consumeVarint() + if err != nil { + return nil, fmt.Errorf("consuming varint: %w", err) + } + // Copy the data into a buffer and advance the offset + b := d.copyNextBytes(int(m)) + if err := d.advanceOffset(m); err != nil { + return nil, fmt.Errorf("advancing offset: %w", err) + } + return b, nil } // readFullObjectResponse returns the ReadObjectResponse that is encoded in the @@ -1586,21 +1877,17 @@ func readObjectResponseContent(b []byte) ([]byte, error) { // This function is essentially identical to proto.Unmarshal, except it aliases // the data in the input []byte. If the proto library adds a feature to // Unmarshal that does that, this function can be dropped. -func readFullObjectResponse(b []byte) (*storagepb.ReadObjectResponse, error) { +func (d *readResponseDecoder) readFullObjectResponse() error { msg := &storagepb.ReadObjectResponse{} // Loop over the entire message, extracting fields as we go. This does not // handle field concatenation, in which the contents of a single field // are split across multiple protobuf tags. - off := 0 - for off < len(b) { - // Consume the next tag. This will tell us which field is next in the - // buffer, its type, and how much space it takes up. - fieldNum, fieldType, fieldLength := protowire.ConsumeTag(b[off:]) - if fieldLength < 0 { - return nil, protowire.ParseError(fieldLength) + for d.off < uint64(d.databufs.Len()) { + fieldNum, fieldType, err := d.consumeTag() + if err != nil { + return fmt.Errorf("consuming next tag: %w", err) } - off += fieldLength // Unmarshal the field according to its type. Only fields that are not // nil will be present. @@ -1609,142 +1896,95 @@ func readFullObjectResponse(b []byte) (*storagepb.ReadObjectResponse, error) { // The ChecksummedData field was found. Initialize the struct. msg.ChecksummedData = &storagepb.ChecksummedData{} - // Get the bytes corresponding to the checksummed data. - fieldContent, n := protowire.ConsumeBytes(b[off:]) - if n < 0 { - return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", protowire.ParseError(n)) + bytesFieldLen, err := d.consumeVarint() + if err != nil { + return fmt.Errorf("consuming bytes: %v", err) } - off += n - - // Get the nested fields. We need to do this manually as it contains - // the object content bytes. - contentOff := 0 - for contentOff < len(fieldContent) { - gotNum, gotTyp, n := protowire.ConsumeTag(fieldContent[contentOff:]) - if n < 0 { - return nil, protowire.ParseError(n) + + var contentEndOff = d.off + bytesFieldLen + for d.off < contentEndOff { + gotNum, gotTyp, err := d.consumeTag() + if err != nil { + return fmt.Errorf("consuming checksummedData tag: %w", err) } - contentOff += n switch { case gotNum == checksummedDataContentField && gotTyp == protowire.BytesType: - // Get the content bytes. - bytes, n := protowire.ConsumeBytes(fieldContent[contentOff:]) - if n < 0 { - return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", protowire.ParseError(n)) + // Get the offsets of the content bytes. + d.dataOffsets, err = d.consumeBytes() + if err != nil { + return fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %w", err) } - msg.ChecksummedData.Content = bytes - contentOff += n case gotNum == checksummedDataCRC32CField && gotTyp == protowire.Fixed32Type: - v, n := protowire.ConsumeFixed32(fieldContent[contentOff:]) - if n < 0 { - return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Crc32C: %v", protowire.ParseError(n)) + v, err := d.consumeFixed32() + if err != nil { + return fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Crc32C: %w", err) } msg.ChecksummedData.Crc32C = &v - contentOff += n default: - n = protowire.ConsumeFieldValue(gotNum, gotTyp, fieldContent[contentOff:]) - if n < 0 { - return nil, protowire.ParseError(n) + err := d.consumeFieldValue(gotNum, gotTyp) + if err != nil { + return fmt.Errorf("invalid field in ReadObjectResponse.ChecksummedData: %w", err) } - contentOff += n } } case fieldNum == objectChecksumsField && fieldType == protowire.BytesType: // The field was found. Initialize the struct. msg.ObjectChecksums = &storagepb.ObjectChecksums{} - - // Get the bytes corresponding to the checksums. - bytes, n := protowire.ConsumeBytes(b[off:]) - if n < 0 { - return nil, fmt.Errorf("invalid ReadObjectResponse.ObjectChecksums: %v", protowire.ParseError(n)) + // Consume the bytes and copy them into a single buffer if they are split across buffers. + buf, err := d.consumeBytesCopy() + if err != nil { + return fmt.Errorf("invalid ReadObjectResponse.ObjectChecksums: %v", err) } - off += n - // Unmarshal. - if err := proto.Unmarshal(bytes, msg.ObjectChecksums); err != nil { - return nil, err + if err := proto.Unmarshal(buf, msg.ObjectChecksums); err != nil { + return err } case fieldNum == contentRangeField && fieldType == protowire.BytesType: msg.ContentRange = &storagepb.ContentRange{} - - bytes, n := protowire.ConsumeBytes(b[off:]) - if n < 0 { - return nil, fmt.Errorf("invalid ReadObjectResponse.ContentRange: %v", protowire.ParseError(n)) + buf, err := d.consumeBytesCopy() + if err != nil { + return fmt.Errorf("invalid ReadObjectResponse.ContentRange: %v", err) } - off += n - - if err := proto.Unmarshal(bytes, msg.ContentRange); err != nil { - return nil, err + if err := proto.Unmarshal(buf, msg.ContentRange); err != nil { + return err } case fieldNum == metadataField && fieldType == protowire.BytesType: msg.Metadata = &storagepb.Object{} - bytes, n := protowire.ConsumeBytes(b[off:]) - if n < 0 { - return nil, fmt.Errorf("invalid ReadObjectResponse.Metadata: %v", protowire.ParseError(n)) + buf, err := d.consumeBytesCopy() + if err != nil { + return fmt.Errorf("invalid ReadObjectResponse.Metadata: %v", err) } - off += n - if err := proto.Unmarshal(bytes, msg.Metadata); err != nil { - return nil, err + if err := proto.Unmarshal(buf, msg.Metadata); err != nil { + return err } default: - fieldLength = protowire.ConsumeFieldValue(fieldNum, fieldType, b[off:]) - if fieldLength < 0 { - return nil, fmt.Errorf("default: %v", protowire.ParseError(fieldLength)) - } - off += fieldLength - } - } - - return msg, nil -} - -// readProtoBytes returns the contents of the protobuf field with number num -// and type bytes from a wire-encoded message. If the field cannot be found, -// the returned slice will be nil and no error will be returned. -// -// It does not handle field concatenation, in which the contents of a single field -// are split across multiple protobuf tags. Encoded data containing split fields -// of this form is technically permissable, but uncommon. -func readProtoBytes(b []byte, num protowire.Number) ([]byte, error) { - off := 0 - for off < len(b) { - gotNum, gotTyp, n := protowire.ConsumeTag(b[off:]) - if n < 0 { - return nil, protowire.ParseError(n) - } - off += n - if gotNum == num && gotTyp == protowire.BytesType { - b, n := protowire.ConsumeBytes(b[off:]) - if n < 0 { - return nil, protowire.ParseError(n) + err := d.consumeFieldValue(fieldNum, fieldType) + if err != nil { + return fmt.Errorf("invalid field in ReadObjectResponse: %w", err) } - return b, nil } - n = protowire.ConsumeFieldValue(gotNum, gotTyp, b[off:]) - if n < 0 { - return nil, protowire.ParseError(n) - } - off += n } - return nil, nil + d.msg = msg + return nil } // reopenStream "closes" the existing stream and attempts to reopen a stream and // sets the Reader's stream and cancelStream properties in the process. -func (r *gRPCReader) reopenStream() (*storagepb.ReadObjectResponse, error) { +func (r *gRPCReader) reopenStream() error { // Close existing stream and initialize new stream with updated offset. r.Close() res, cancel, err := r.reopen(r.seen) if err != nil { - return nil, err + return err } r.stream = res.stream + r.currMsg = res.decoder r.cancel = cancel - return res.response, nil + return nil } func newGRPCWriter(c *grpcStorageClient, params *openWriterParams, r io.Reader) *gRPCWriter { diff --git a/storage/grpc_client_test.go b/storage/grpc_client_test.go index 5c1eb0f12838..d75975e40ecc 100644 --- a/storage/grpc_client_test.go +++ b/storage/grpc_client_test.go @@ -15,6 +15,7 @@ package storage import ( + "bytes" "crypto/md5" "hash/crc32" "math/rand" @@ -23,11 +24,12 @@ import ( "cloud.google.com/go/storage/internal/apiv2/storagepb" "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/mem" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" ) -func TestBytesCodec(t *testing.T) { +func TestBytesCodecV2(t *testing.T) { // Generate some random content. content := make([]byte, 1<<10+1) // 1 kib + 1 byte rand.New(rand.NewSource(0)).Read(content) @@ -85,8 +87,9 @@ func TestBytesCodec(t *testing.T) { } for _, test := range []struct { - desc string - resp *storagepb.ReadObjectResponse + desc string + resp *storagepb.ReadObjectResponse + wantContent []byte }{ { desc: "filled object response", @@ -106,41 +109,111 @@ func TestBytesCodec(t *testing.T) { }, Metadata: metadata, }, + wantContent: content, }, { - desc: "empty object response", - resp: &storagepb.ReadObjectResponse{}, + desc: "empty object response", + resp: &storagepb.ReadObjectResponse{}, + wantContent: []byte{}, }, { desc: "partially empty", resp: &storagepb.ReadObjectResponse{ ChecksummedData: &storagepb.ChecksummedData{}, ObjectChecksums: &storagepb.ObjectChecksums{Md5Hash: md5}, - Metadata: &storagepb.Object{}, }, + wantContent: []byte{}, }, } { t.Run(test.desc, func(t *testing.T) { - // Encode the response. - encodedResp, err := proto.Marshal(test.resp) - if err != nil { - t.Fatalf("proto.Marshal: %v", err) - } + for _, subtest := range []struct { + desc string + splitBuffer func([]byte) [][]byte // call this to split the message into multiple buffers. + }{ + { + desc: "single buffer", + splitBuffer: func(b []byte) [][]byte { + return [][]byte{b} + }, + }, + { + desc: "split every 100 bytes", + splitBuffer: func(b []byte) [][]byte { + var bufs [][]byte + var i int + for i = 0; i < len(b)-100; i += 100 { + bufs = append(bufs, b[i:i+100]) + } + bufs = append(bufs, b[i:]) + return bufs + }, + }, + { + desc: "split every 8 bytes", + splitBuffer: func(b []byte) [][]byte { + var bufs [][]byte + var i int + for i = 0; i < len(b)-8; i += 8 { + bufs = append(bufs, b[i:i+8]) + } + bufs = append(bufs, b[i:]) + return bufs + }, + }, + { + desc: "split every byte", + splitBuffer: func(b []byte) [][]byte { + var bufs [][]byte + for i := 0; i < len(b); i++ { + bufs = append(bufs, b[i:i+1]) + } + return bufs + }, + }, + } { + t.Run(subtest.desc, func(t *testing.T) { + // Encode the response. + encodedResp, err := proto.Marshal(test.resp) + if err != nil { + t.Fatalf("proto.Marshal: %v", err) + } + // Convert response data into mem.BufferSlice, potentially split across multiple buffers. + var respData mem.BufferSlice + slices := subtest.splitBuffer(encodedResp) + for _, s := range slices { + respData = append(respData, mem.SliceBuffer(s)) + } + // Unmarshal and decode response using custom decoding. + var encodedBytes mem.BufferSlice = mem.BufferSlice{} + if err := bytesCodecV2.Unmarshal(bytesCodecV2{}, respData, &encodedBytes); err != nil { + t.Fatalf("unmarshal: %v", err) + } - // Unmarshal and decode response using custom decoding. - encodedBytes := &[]byte{} - if err := bytesCodec.Unmarshal(bytesCodec{}, encodedResp, encodedBytes); err != nil { - t.Fatalf("unmarshal: %v", err) - } + decoder := &readResponseDecoder{ + databufs: encodedBytes, + } - got, err := readFullObjectResponse(*encodedBytes) - if err != nil { - t.Fatalf("readFullObjectResponse: %v", err) - } + err = decoder.readFullObjectResponse() + if err != nil { + t.Fatalf("readFullObjectResponse: %v", err) + } + + // Compare the result with the original ReadObjectResponse, without the content + if diff := cmp.Diff(decoder.msg, test.resp, protocmp.Transform(), protocmp.IgnoreMessages(&storagepb.ChecksummedData{})); diff != "" { + t.Errorf("cmp.Diff message: got(-),want(+):\n%s", diff) + } + + // Read out the data and compare length and content. + buf := &bytes.Buffer{} + n, err := decoder.writeToAndUpdateCRC(buf, func([]byte) {}) + if n != int64(len(test.wantContent)) { + t.Errorf("mismatched content length: got %d, want %d, offsets %+v", n, len(content), decoder.dataOffsets) + } + if !bytes.Equal(buf.Bytes(), test.wantContent) { + t.Errorf("returned message content did not match") + } - // Compare the result with the original ReadObjectResponse. - if diff := cmp.Diff(got, test.resp, protocmp.Transform()); diff != "" { - t.Errorf("cmp.Diff got(-),want(+):\n%s", diff) + }) } }) }