Commit 3ec0518

fix(sdk): more efficient encryption in experimental TDF Writer (#2904)
### Proposed Changes

* Better locking to allow parallel encryption (a minimal sketch of the pattern follows below).
* Fewer memory allocations, by reusing buffers.

### Checklist

- [ ] I have added or updated unit tests
- [ ] I have added or updated integration tests (if appropriate)
- [ ] I have added or updated documentation

### Testing Instructions

PR benchmark:

```
endpoint: http://localhost:8080
# Benchmark Experimental TDF Writer Results:
| Metric | Value |
|--------------------|--------------|
| Payload Size (B) | 1048576999 |
| Output Size (B) | 6716081 |
| Total Time | 222.619708ms |
```

main benchmark:

```
endpoint: http://localhost:8080
# Benchmark Experimental TDF Writer Results:
| Metric | Value |
|--------------------|--------------|
| Payload Size (B) | 1048576999 |
| Output Size (B) | 6721718 |
| Total Time | 1.37738475s |
```
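For orientation, here is a minimal, self-contained sketch of the lock-splitting pattern the `writer.go` change applies. The types and names below are hypothetical stand-ins, not the SDK's actual code: the mutex is held only to reserve a segment slot and later to record its metadata, while the expensive per-segment work runs outside the lock so segments can be processed in parallel.

```go
package main

import (
	"fmt"
	"sync"
)

// segMeta is a stand-in for the writer's per-segment metadata.
type segMeta struct {
	size int64 // -1 means "reserved but not yet filled"
}

// sketchWriter mimics the shape of the experimental TDF writer's bookkeeping.
type sketchWriter struct {
	mu       sync.Mutex
	segments map[int]*segMeta
}

// writeSegment reserves the slot under the lock, does the expensive work
// (a placeholder for encryption here) outside it, then re-locks only long
// enough to record the result.
func (w *sketchWriter) writeSegment(index int, data []byte) error {
	w.mu.Lock()
	if _, exists := w.segments[index]; exists {
		w.mu.Unlock()
		return fmt.Errorf("segment %d already written", index)
	}
	seg := &segMeta{size: -1}
	w.segments[index] = seg
	w.mu.Unlock()

	encrypted := make([]byte, len(data)) // placeholder for the real encryption step
	copy(encrypted, data)

	w.mu.Lock()
	seg.size = int64(len(encrypted))
	w.mu.Unlock()
	return nil
}

func main() {
	w := &sketchWriter{segments: make(map[int]*segMeta)}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			if err := w.writeSegment(idx, []byte("chunk")); err != nil {
				fmt.Println(err)
			}
		}(i)
	}
	wg.Wait()
	fmt.Println("segments written:", len(w.segments))
}
```

Reserving the slot first, with a sentinel size of -1, keeps the duplicate-index check correct even though encryption happens while the lock is released.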
1 parent f8cbe15 commit 3ec0518

11 files changed: +241 −81 lines changed

New file — 109 additions & 0 deletions:

```go
//nolint:forbidigo // We use Println here extensively because we are printing markdown.
package cmd

import (
	"context"
	"crypto/rand"
	"fmt"
	"sync"
	"time"

	"connectrpc.com/connect"
	"github.com/opentdf/platform/lib/ocrypto"
	kasp "github.com/opentdf/platform/protocol/go/kas"
	"github.com/opentdf/platform/protocol/go/kas/kasconnect"
	"github.com/opentdf/platform/protocol/go/policy"

	"github.com/opentdf/platform/sdk/experimental/tdf"
	"github.com/opentdf/platform/sdk/httputil"
	"github.com/spf13/cobra"
)

var (
	payloadSize  int
	segmentChunk int
	testAttr     = "https://example.com/attr/attr1/value/value1"
)

func init() {
	benchmarkCmd := &cobra.Command{
		Use:   "benchmark-experimental-writer",
		Short: "Benchmark experimental TDF writer speed",
		Long:  `Benchmark the experimental TDF writer with configurable payload size.`,
		RunE:  runExperimentalWriterBenchmark,
	}
	//nolint:mnd // no magic number, this is just the default value for payload size
	benchmarkCmd.Flags().IntVar(&payloadSize, "payload-size", 1024*1024, "Payload size in bytes") // Default 1 MiB
	//nolint:mnd // same as above
	benchmarkCmd.Flags().IntVar(&segmentChunk, "segment-chunks", 16*1024, "Segment chunk size in bytes") // Default 16 KiB
	ExamplesCmd.AddCommand(benchmarkCmd)
}

func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error {
	payload := make([]byte, payloadSize)
	_, err := rand.Read(payload)
	if err != nil {
		return fmt.Errorf("failed to generate random payload: %w", err)
	}

	http := httputil.SafeHTTPClient()
	fmt.Println("endpoint:", platformEndpoint)
	serviceClient := kasconnect.NewAccessServiceClient(http, platformEndpoint)
	resp, err := serviceClient.PublicKey(context.Background(), connect.NewRequest(&kasp.PublicKeyRequest{Algorithm: string(ocrypto.RSA2048Key)}))
	if err != nil {
		return fmt.Errorf("failed to get public key from KAS: %w", err)
	}
	var attrs []*policy.Value

	simpleKey := &policy.SimpleKasKey{
		KasUri: platformEndpoint,
		KasId:  "id",
		PublicKey: &policy.SimpleKasPublicKey{
			Kid:       resp.Msg.GetKid(),
			Pem:       resp.Msg.GetPublicKey(),
			Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
		},
	}

	attrs = append(attrs, &policy.Value{Fqn: testAttr, KasKeys: []*policy.SimpleKasKey{simpleKey}, Attribute: &policy.Attribute{Namespace: &policy.Namespace{Name: "example.com"}, Fqn: testAttr}})
	writer, err := tdf.NewWriter(context.Background(), tdf.WithDefaultKASForWriter(simpleKey), tdf.WithInitialAttributes(attrs), tdf.WithSegmentIntegrityAlgorithm(tdf.HS256))
	if err != nil {
		return fmt.Errorf("failed to create writer: %w", err)
	}

	i := 0
	wg := sync.WaitGroup{}
	segs := len(payload) / segmentChunk
	wg.Add(segs)
	start := time.Now()
	for i < segs {
		segment := i
		go func() {
			// Use the captured segment index, not the shared loop counter,
			// and a goroutine-local error to avoid data races.
			segStart := segment * segmentChunk
			segEnd := min(segStart+segmentChunk, len(payload))
			_, segErr := writer.WriteSegment(context.Background(), segment, payload[segStart:segEnd])
			if segErr != nil {
				fmt.Println(segErr)
				panic(segErr)
			}
			wg.Done()
		}()
		i++
	}
	wg.Wait()

	end := time.Now()
	result, err := writer.Finalize(context.Background())
	if err != nil {
		return fmt.Errorf("failed to finalize writer: %w", err)
	}
	totalTime := end.Sub(start)

	fmt.Printf("# Benchmark Experimental TDF Writer Results:\n")
	fmt.Printf("| Metric | Value |\n")
	fmt.Printf("|--------------------|--------------|\n")
	fmt.Printf("| Payload Size (B) | %d |\n", payloadSize)
	fmt.Printf("| Output Size (B) | %d |\n", len(result.Data))
	fmt.Printf("| Total Time | %s |\n", totalTime)

	return nil
}
```

examples/go.mod

Lines changed: 1 addition & 1 deletion

```diff
@@ -5,6 +5,7 @@ go 1.24.0
 toolchain go1.24.9
 
 require (
+	connectrpc.com/connect v1.18.1
 	github.com/opentdf/platform/lib/ocrypto v0.7.0
 	github.com/opentdf/platform/protocol/go v0.13.0
 	github.com/opentdf/platform/sdk v0.10.1
@@ -16,7 +17,6 @@ require (
 
 require (
 	buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.6-20250603165357-b52ab10f4468.1 // indirect
-	connectrpc.com/connect v1.18.1 // indirect
 	github.com/Masterminds/semver/v3 v3.3.1 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect
 	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
```

lib/ocrypto/aes_gcm.go

Lines changed: 15 additions & 0 deletions

```diff
@@ -47,6 +47,21 @@ func (aesGcm AesGcm) Encrypt(data []byte) ([]byte, error) {
 	return cipherText, nil
 }
 
+func (aesGcm AesGcm) EncryptInPlace(data []byte) ([]byte, []byte, error) {
+	nonce, err := RandomBytes(GcmStandardNonceSize)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	gcm, err := cipher.NewGCMWithNonceSize(aesGcm.block, GcmStandardNonceSize)
+	if err != nil {
+		return nil, nil, fmt.Errorf("cipher.NewGCMWithNonceSize failed: %w", err)
+	}
+
+	cipherText := gcm.Seal(data[:0], nonce, data, nil)
+	return cipherText, nonce, nil
+}
+
 // EncryptWithIV encrypts data with symmetric key.
 // NOTE: This method use default auth tag as aes block size(16 bytes)
 // and expects iv of 16 bytes.
```
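A hedged usage sketch of the new `EncryptInPlace` helper, assuming lib/ocrypto's existing `NewAESGcm` and `RandomBytes` helpers. Unlike `Encrypt`, it returns the nonce separately rather than bundled with the ciphertext, and `Seal(data[:0], ...)` reuses the plaintext slice's storage when it has spare capacity for the GCM tag, which is what saves the per-segment allocation:

```go
package main

import (
	"fmt"

	"github.com/opentdf/platform/lib/ocrypto"
)

func main() {
	// A 32-byte key selects AES-256-GCM. RandomBytes is the same helper the diff uses.
	key, err := ocrypto.RandomBytes(32)
	if err != nil {
		panic(err)
	}
	gcm, err := ocrypto.NewAESGcm(key)
	if err != nil {
		panic(err)
	}

	plaintext := []byte("segment plaintext")
	// The ciphertext is sealed into plaintext's backing array when capacity
	// allows; either way the original plaintext bytes are overwritten.
	cipherText, nonce, err := gcm.EncryptInPlace(plaintext)
	if err != nil {
		panic(err)
	}

	// The caller must keep nonce and ciphertext together, e.g. by writing
	// the nonce followed by the ciphertext, as the writer.go change below does.
	fmt.Printf("nonce: %d bytes, ciphertext: %d bytes\n", len(nonce), len(cipherText))
}
```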

sdk/experimental/tdf/writer.go

Lines changed: 45 additions & 26 deletions

```diff
@@ -10,6 +10,7 @@ import (
 	"errors"
 	"fmt"
 	"hash/crc32"
+	"io"
 	"log/slog"
 	"sort"
 	"sync"
@@ -34,12 +35,11 @@
 
 // SegmentResult contains the result of writing a segment
 type SegmentResult struct {
-	Data          []byte `json:"data"`          // Encrypted segment bytes (for streaming)
-	Index         int    `json:"index"`         // Segment index
-	Hash          string `json:"hash"`          // Base64-encoded integrity hash
-	PlaintextSize int64  `json:"plaintextSize"` // Original data size
-	EncryptedSize int64  `json:"encryptedSize"` // Encrypted data size
-	CRC32         uint32 `json:"crc32"`         // CRC32 checksum
+	TDFData       io.Reader // Reader for the full TDF segment (nonce + encrypted data + zip structures)
+	Index         int       `json:"index"`         // Segment index
+	Hash          string    `json:"hash"`          // Base64-encoded integrity hash
+	PlaintextSize int64     `json:"plaintextSize"` // Original data size
+	EncryptedSize int64     `json:"encryptedSize"` // Encrypted data size
 }
 
 // FinalizeResult contains the complete TDF creation result
@@ -110,7 +110,7 @@ type Writer struct {
 
 	// segments stores segment metadata using sparse map for memory efficiency
 	// Maps segment index to Segment metadata (hash, size information)
-	segments map[int]Segment
+	segments map[int]*Segment
 	// maxSegmentIndex tracks the highest segment index written
 	maxSegmentIndex int
 
@@ -183,7 +183,7 @@ func NewWriter(_ context.Context, opts ...Option[*WriterConfig]) (*Writer, error
 		WriterConfig:      *config,
 		archiveWriter:     archiveWriter,
 		dek:               dek,
-		segments:          make(map[int]Segment), // Initialize sparse storage
+		segments:          make(map[int]*Segment), // Initialize sparse storage
 		block:             block,
 		initialAttributes: config.initialAttributes,
 		initialDefaultKAS: config.initialDefaultKAS,
@@ -232,58 +232,76 @@ func NewWriter(_ context.Context, opts ...Option[*WriterConfig]) (*Writer, error
 //	uploadToS3(segment1, "part-001")
 func (w *Writer) WriteSegment(ctx context.Context, index int, data []byte) (*SegmentResult, error) {
 	w.mutex.Lock()
-	defer w.mutex.Unlock()
 
 	if w.finalized {
+		w.mutex.Unlock()
 		return nil, ErrAlreadyFinalized
 	}
 
 	if index < 0 {
+		w.mutex.Unlock()
 		return nil, ErrInvalidSegmentIndex
 	}
 
 	// Check for duplicate segments using map lookup
 	if _, exists := w.segments[index]; exists {
+		w.mutex.Unlock()
 		return nil, ErrSegmentAlreadyWritten
 	}
 
 	if index > w.maxSegmentIndex {
 		w.maxSegmentIndex = index
 	}
+	seg := &Segment{
+		Size: -1, // indicates not filled yet
+	}
+	w.segments[index] = seg
 
-	// Calculate CRC32 before encryption for integrity tracking
-	crc32Checksum := crc32.ChecksumIEEE(data)
+	w.mutex.Unlock()
 
 	// Encrypt directly without unnecessary copying - the archive layer will handle copying if needed
-	segmentCipher, err := w.block.Encrypt(data)
+	segmentCipher, nonce, err := w.block.EncryptInPlace(data)
 	if err != nil {
 		return nil, err
 	}
-
 	segmentSig, err := calculateSignature(segmentCipher, w.dek, w.segmentIntegrityAlgorithm, false) // Don't ever hex encode new tdf's
 	if err != nil {
 		return nil, err
 	}
 
 	segmentHash := string(ocrypto.Base64Encode([]byte(segmentSig)))
-	w.segments[index] = Segment{
-		Hash:          segmentHash,
-		Size:          int64(len(data)), // Use original data length
-		EncryptedSize: int64(len(segmentCipher)),
-	}
+	w.mutex.Lock()
+	seg.Size = int64(len(data))
+	seg.EncryptedSize = int64(len(segmentCipher)) + int64(len(nonce))
+	seg.Hash = segmentHash
+	w.mutex.Unlock()
 
-	zipBytes, err := w.archiveWriter.WriteSegment(ctx, index, segmentCipher)
+	crc := crc32.NewIEEE()
+	_, err = crc.Write(nonce)
+	if err != nil {
+		return nil, err
+	}
+	_, err = crc.Write(segmentCipher)
+	if err != nil {
+		return nil, err
+	}
+	header, err := w.archiveWriter.WriteSegment(ctx, index, uint64(seg.EncryptedSize), crc.Sum32())
 	if err != nil {
 		return nil, err
 	}
+	var reader io.Reader
+	if len(header) == 0 {
+		reader = io.MultiReader(bytes.NewReader(nonce), bytes.NewReader(segmentCipher))
+	} else {
+		reader = io.MultiReader(bytes.NewReader(header), bytes.NewReader(nonce), bytes.NewReader(segmentCipher))
+	}
 
 	return &SegmentResult{
-		Data:          zipBytes,
+		TDFData:       reader,
 		Index:         index,
-		Hash:          segmentHash,
-		PlaintextSize: int64(len(data)),
-		EncryptedSize: int64(len(segmentCipher)),
-		CRC32:         crc32Checksum,
+		Hash:          seg.Hash,
+		PlaintextSize: seg.Size,
+		EncryptedSize: seg.EncryptedSize,
 	}, nil
 }
 
@@ -505,7 +523,7 @@ func (w *Writer) getManifest(ctx context.Context, cfg *WriterFinalizeConfig) (*M
 	// Copy segments to manifest in finalize order (pack densely)
 	for i, idx := range order {
 		if segment, exists := w.segments[idx]; exists {
-			encryptInfo.Segments[i] = segment
+			encryptInfo.Segments[i] = *segment
 		}
 	}
 
@@ -524,7 +542,8 @@
 	var totalPlaintextSize, totalEncryptedSize int64
 	for _, i := range order {
 		segment, exists := w.segments[i]
-		if !exists {
+		// if size is negative, the segment was not written and Finalize was called too early
+		if !exists || w.segments[i].Size < 0 {
 			return nil, 0, 0, fmt.Errorf("segment %d not written; cannot finalize", i)
 		}
 		if segment.Hash != "" {
```
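Because `SegmentResult` now carries an `io.Reader` (`TDFData`) instead of a `[]byte`, callers stream each finished segment to its destination rather than holding a separate copy of the encrypted bytes. A hedged sketch of a caller-side helper; the helper and its names are illustrative, not part of the SDK:

```go
package example

import (
	"context"
	"fmt"
	"io"

	"github.com/opentdf/platform/sdk/experimental/tdf"
)

// streamSegment writes one plaintext segment and copies the resulting TDF
// bytes (zip header if any, then nonce, then ciphertext) straight to dst,
// e.g. a file or a multipart-upload part writer.
func streamSegment(ctx context.Context, w *tdf.Writer, dst io.Writer, index int, plaintext []byte) (int64, error) {
	res, err := w.WriteSegment(ctx, index, plaintext)
	if err != nil {
		return 0, fmt.Errorf("write segment %d: %w", index, err)
	}
	// res.TDFData reads over buffers WriteSegment already produced,
	// so nothing extra is allocated here.
	return io.Copy(dst, res.TDFData)
}
```

Note that the per-segment CRC32 is no longer surfaced on `SegmentResult`; per the diff, the writer now computes it over the nonce plus ciphertext and hands it to the zipstream layer itself.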

sdk/experimental/tdf/writer_test.go

Lines changed: 1 addition & 1 deletion

```diff
@@ -669,7 +669,7 @@ func testErrorConditions(t *testing.T) {
 	require.NoError(t, err)
 
 	// Manually corrupt segment hash to test error handling
-	writer.segments[0] = Segment{Hash: "", Size: 10, EncryptedSize: 26}
+	writer.segments[0] = &Segment{Hash: "", Size: 10, EncryptedSize: 26}
 	writer.maxSegmentIndex = 0
 
 	attributes := []*policy.Value{
```

sdk/internal/zipstream/benchmark_test.go

Lines changed: 11 additions & 5 deletions

```diff
@@ -3,6 +3,7 @@
 package zipstream
 
 import (
+	"hash/crc32"
 	"testing"
 )
 
@@ -55,7 +56,8 @@ func BenchmarkSegmentWriter_CRC32ContiguousProcessing(b *testing.B) {
 
 		// Write segments in specified order
 		for _, segIdx := range writeOrder {
-			_, err := writer.WriteSegment(ctx, segIdx, segmentData)
+			crc := crc32.ChecksumIEEE(segmentData)
+			_, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
 			if err != nil {
 				b.Fatal(err)
 			}
@@ -99,7 +101,8 @@ func BenchmarkSegmentWriter_SparseIndices(b *testing.B) {
 		// Write sparse indices in order
 		for k := 0; k < n; k++ {
 			idx := k * stride
-			if _, err := w.WriteSegment(ctx, idx, data); err != nil {
+			crc := crc32.ChecksumIEEE(data)
+			if _, err := w.WriteSegment(ctx, idx, uint64(len(data)), crc); err != nil {
 				b.Fatal(err)
 			}
 		}
@@ -153,7 +156,8 @@ func BenchmarkSegmentWriter_VariableSegmentSizes(b *testing.B) {
 				segmentData[j] = byte((segIdx * j) % 256)
 			}
 
-			_, err := writer.WriteSegment(ctx, segIdx, segmentData)
+			crc := crc32.ChecksumIEEE(segmentData)
+			_, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
 			if err != nil {
 				b.Fatal(err)
 			}
@@ -234,7 +238,8 @@ func BenchmarkSegmentWriter_MemoryPressure(b *testing.B) {
 				segmentData[j] = byte((orderIdx * j) % 256)
 			}
 
-			_, err := writer.WriteSegment(ctx, segIdx, segmentData)
+			crc := crc32.ChecksumIEEE(segmentData)
+			_, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
 			if err != nil {
 				b.Fatal(err)
 			}
@@ -305,7 +310,8 @@ func BenchmarkSegmentWriter_ZIPGeneration(b *testing.B) {
 
 		// Write all segments
 		for segIdx := 0; segIdx < tc.segmentCount; segIdx++ {
-			_, err := writer.WriteSegment(ctx, segIdx, segmentData)
+			crc := crc32.ChecksumIEEE(segmentData)
+			_, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
 			if err != nil {
 				b.Fatal(err)
 			}
```