diff --git a/pkg/scalers/azure/azure_eventhub_checkpoint.go b/pkg/scalers/azure/azure_eventhub_checkpoint.go index f9aa53f742f..474e919b20c 100644 --- a/pkg/scalers/azure/azure_eventhub_checkpoint.go +++ b/pkg/scalers/azure/azure_eventhub_checkpoint.go @@ -82,6 +82,11 @@ type goSdkCheckpointer struct { containerName string } +type daprCheckpointer struct { + partitionID string + containerName string +} + type defaultCheckpointer struct { partitionID string containerName string @@ -100,6 +105,11 @@ func newCheckpointer(info EventHubInfo, partitionID string) checkpointer { containerName: info.BlobContainer, partitionID: partitionID, } + case (info.CheckpointStrategy == "dapr"): + return &daprCheckpointer{ + containerName: info.BlobContainer, + partitionID: partitionID, + } case (info.CheckpointStrategy == "blobMetadata"): return &blobMetadataCheckpointer{ containerName: info.BlobContainer, @@ -176,19 +186,27 @@ func (checkpointer *goSdkCheckpointer) resolvePath(info EventHubInfo) (*url.URL, // extract checkpoint for goSdkCheckpointer func (checkpointer *goSdkCheckpointer) extractCheckpoint(get *azblob.DownloadResponse) (Checkpoint, error) { - var checkpoint goCheckpoint - err := readToCheckpointFromBody(get, &checkpoint) + return extractCheckpointGoSdk(get) +} + +// resolve path for daprCheckpointer +func (checkpointer *daprCheckpointer) resolvePath(info EventHubInfo) (*url.URL, error) { + _, eventHubName, err := getHubAndNamespace(info) if err != nil { - return Checkpoint{}, err + return nil, err } - return Checkpoint{ - SequenceNumber: checkpoint.Checkpoint.SequenceNumber, - baseCheckpoint: baseCheckpoint{ - Offset: checkpoint.Checkpoint.Offset, - }, - PartitionID: checkpoint.PartitionID, - }, nil + path, err := url.Parse(fmt.Sprintf("/%s/dapr-%s-%s-%s", checkpointer.containerName, eventHubName, info.EventHubConsumerGroup, checkpointer.partitionID)) + if err != nil { + return nil, err + } + + return path, nil +} + +// extract checkpoint for daprCheckpointer +func (checkpointer *daprCheckpointer) extractCheckpoint(get *azblob.DownloadResponse) (Checkpoint, error) { + return extractCheckpointGoSdk(get) } // resolve path for DefaultCheckpointer @@ -223,6 +241,22 @@ func (checkpointer *defaultCheckpointer) extractCheckpoint(get *azblob.DownloadR return checkpoint, err } +func extractCheckpointGoSdk(get *azblob.DownloadResponse) (Checkpoint, error) { + var checkpoint goCheckpoint + err := readToCheckpointFromBody(get, &checkpoint) + if err != nil { + return Checkpoint{}, err + } + + return Checkpoint{ + SequenceNumber: checkpoint.Checkpoint.SequenceNumber, + baseCheckpoint: baseCheckpoint{ + Offset: checkpoint.Checkpoint.Offset, + }, + PartitionID: checkpoint.PartitionID, + }, nil +} + func getCheckpoint(ctx context.Context, httpClient util.HTTPDoer, info EventHubInfo, checkpointer checkpointer) (Checkpoint, error) { blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(ctx, httpClient, kedav1alpha1.AuthPodIdentity{Provider: kedav1alpha1.PodIdentityProviderNone}, info.StorageConnection, "", "") diff --git a/pkg/scalers/azure/azure_eventhub_test.go b/pkg/scalers/azure/azure_eventhub_test.go index 8594224113d..052a18e8e1a 100644 --- a/pkg/scalers/azure/azure_eventhub_test.go +++ b/pkg/scalers/azure/azure_eventhub_test.go @@ -227,6 +227,47 @@ func TestCheckpointFromBlobStorageGoSdk(t *testing.T) { assert.Equal(t, check, expectedCheckpoint) } +func TestCheckpointFromBlobStorageDapr(t *testing.T) { + if StorageConnectionString == "" { + return + } + + partitionID := "0" + offset := "1003" + + sequencenumber := int64(1) + + containerName := "daprcontainer" + checkpointFormat := "{\"partitionID\":\"%s\",\"epoch\":0,\"owner\":\"\",\"checkpoint\":{\"offset\":\"%s\",\"sequenceNumber\":%d,\"enqueueTime\":\"\"},\"state\":\"\",\"token\":\"\"}" + checkpoint := fmt.Sprintf(checkpointFormat, partitionID, offset, sequencenumber) + + urlPath := "" + + ctx, err := createNewCheckpointInStorage(urlPath, containerName, partitionID, checkpoint, nil) + assert.Equal(t, err, nil) + + expectedCheckpoint := Checkpoint{ + baseCheckpoint: baseCheckpoint{ + Offset: offset, + }, + PartitionID: partitionID, + SequenceNumber: sequencenumber, + } + + eventHubInfo := EventHubInfo{ + EventHubConnection: "Endpoint=sb://eventhubnamespace.servicebus.windows.net/;EntityPath=hub", + StorageConnection: StorageConnectionString, + EventHubName: "hub", + BlobContainer: containerName, + CheckpointStrategy: "dapr", + } + + check, _ := GetCheckpointFromBlobStorage(ctx, http.DefaultClient, eventHubInfo, partitionID) + _ = check.Offset + _ = expectedCheckpoint.Offset + assert.Equal(t, check, expectedCheckpoint) +} + func TestShouldParseCheckpointForFunction(t *testing.T) { eventHubInfo := EventHubInfo{ EventHubConnection: "Endpoint=sb://eventhubnamespace.servicebus.windows.net/;EntityPath=hub-test",