diff --git a/pkg/cmd/cli/pair.go b/pkg/cmd/cli/pair.go index 6c7803c..65d8144 100644 --- a/pkg/cmd/cli/pair.go +++ b/pkg/cmd/cli/pair.go @@ -69,7 +69,7 @@ func NewPAIRConfig(ctx context.Context, token string, threads int, key string) ( }, nil } -func (c *pairConfig) hashEncryt(ctx context.Context, input string) error { +func (c *pairConfig) hashEncryt(ctx context.Context, input string) (err error) { logger := zerolog.Ctx(ctx) logger.Info().Msg("Step 1: Hash and encrypt the advertiser data.") @@ -85,6 +85,12 @@ func (c *pairConfig) hashEncryt(ctx context.Context, input string) error { return fmt.Errorf("bucket.NewBucketCompleter: %w", err) } defer func() { + // don't complete the bucket if there was an error to prevent writing + // unwanted files. + if err != nil { + return + } + if err := bucketCompleter.Complete(ctx); err != nil { logger.Error().Err(err).Msg("failed to write .Completed file to bucket") return @@ -96,6 +102,12 @@ func (c *pairConfig) hashEncryt(ctx context.Context, input string) error { return fmt.Errorf("bucket.NewBucket: %w", err) } defer func() { + // don't close the bucket if there was an error to prevent writing + // unwanted files. + if err != nil { + return + } + if err := b.Close(); err != nil { logger.Error().Err(err).Msg("failed to close bucket") return @@ -117,10 +129,10 @@ func (c *pairConfig) hashEncryt(ctx context.Context, input string) error { logger.Info().Msg("Step 1: Hash and encrypt the advertiser data completed.") - return nil + return } -func (c *pairConfig) reEncrypt(ctx context.Context, publisherPAIRIDsPath string) error { +func (c *pairConfig) reEncrypt(ctx context.Context, publisherPAIRIDsPath string) (err error) { logger := zerolog.Ctx(ctx) logger.Info().Msg("Step 2: Re-encrypt the publisher's hashed and encrypted PAIR IDs.") @@ -130,6 +142,12 @@ func (c *pairConfig) reEncrypt(ctx context.Context, publisherPAIRIDsPath string) return fmt.Errorf("bucket.NewBucketCompleter: %w", err) } defer func() { + // don't complete the bucket if there was an error to prevent writing + // unwanted files. + if err != nil { + return + } + if err := bucketCompleter.Complete(ctx); err != nil { logger.Error().Err(err).Msg("failed to write .Completed file to bucket") return @@ -141,6 +159,12 @@ func (c *pairConfig) reEncrypt(ctx context.Context, publisherPAIRIDsPath string) return fmt.Errorf("bucket.NewBucket: %w", err) } defer func() { + // don't close the bucket if there was an error to prevent writing + // unwanted files. + if err != nil { + return + } + if err := b.Close(); err != nil { logger.Error().Err(err).Msg("failed to close bucket") return @@ -177,7 +201,7 @@ func (c *pairConfig) reEncrypt(ctx context.Context, publisherPAIRIDsPath string) logger.Info().Msg("Step 2: Re-encrypt the publisher's hashed and encrypted PAIR IDs completed.") - return nil + return } func (c *pairConfig) match(ctx context.Context, outputPath string, publisherPAIRIDsPath string) error { diff --git a/pkg/pair/pair.go b/pkg/pair/pair.go index 2e1ba3b..6180ce1 100644 --- a/pkg/pair/pair.go +++ b/pkg/pair/pair.go @@ -19,11 +19,16 @@ import ( ) const ( - batchSize = 1024 + batchSize = 1024 + minimumIDCount = 1000 maxOperationRunTime = 4 * time.Hour ) +var ( + ErrInputBelowThreshold = errors.New("not enough identifiers for a secure PAIR ID match") +) + type ( pairIDReadWriter struct { reader *pairIDReader @@ -215,6 +220,10 @@ func runPAIROperation(ctx context.Context, p *pairIDReadWriter, numWorkers int, } close(done) + if p.reader.read.Load() < minimumIDCount { + return ErrInputBelowThreshold + } + logger.Debug().Msgf("%s: read %d IDs, written %d PAIR IDs in %s", op, p.reader.read.Load(), p.written.Load(), time.Since(startTime)) return nil case <-ctx.Done():