diff --git a/fixedwindow.go b/fixedwindow.go index 3057e9c..8467056 100644 --- a/fixedwindow.go +++ b/fixedwindow.go @@ -199,6 +199,17 @@ func NewFixedWindowDynamoDB(client *dynamodb.Client, partitionKey string, props } } +type contextKey int + +var fixedWindowDynamoDBPartitionKey contextKey + +// NewFixedWindowDynamoDBContext creates a context for FixedWindowDynamoDB with a partition key. +// +// This context can be used to control the partition key per-request. +func NewFixedWindowDynamoDBContext(ctx context.Context, partitionKey string) context.Context { + return context.WithValue(ctx, fixedWindowDynamoDBPartitionKey, partitionKey) +} + const ( fixedWindowDynamoDBUpdateExpression = "SET #C = if_not_exists(#C, :def) + :inc, #TTL = :ttl" dynamodbWindowCountKey = "Count" @@ -212,9 +223,13 @@ func (f *FixedWindowDynamoDB) Increment(ctx context.Context, window time.Time, t done := make(chan struct{}) go func() { defer close(done) + partitionKey := f.partitionKey + if key, ok := ctx.Value(fixedWindowDynamoDBPartitionKey).(string); ok { + partitionKey = key + } resp, err = f.client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ Key: map[string]types.AttributeValue{ - f.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: f.partitionKey}, + f.tableProps.PartitionKeyName: &types.AttributeValueMemberS{Value: partitionKey}, f.tableProps.SortKeyName: &types.AttributeValueMemberS{Value: strconv.FormatInt(window.UnixNano(), 10)}, }, UpdateExpression: aws.String(fixedWindowDynamoDBUpdateExpression), diff --git a/fixedwindow_test.go b/fixedwindow_test.go index 1a2c656..2ca38ab 100644 --- a/fixedwindow_test.go +++ b/fixedwindow_test.go @@ -104,6 +104,23 @@ func (s *LimitersTestSuite) TestFixedWindowOverflow() { } } +func (s *LimitersTestSuite) TestFixedWindowDynamoDBPartitionKey() { + clock := newFakeClockWithTime(time.Date(2019, 8, 30, 0, 0, 0, 0, time.UTC)) + incrementor := l.NewFixedWindowDynamoDB(s.dynamodbClient, "partitionKey1", s.dynamoDBTableProps) + window := l.NewFixedWindow(2, time.Millisecond*100, incrementor, clock) + + w, err := window.Limit(context.TODO()) + s.Require().NoError(err) + s.Equal(time.Duration(0), w) + w, err = window.Limit(context.TODO()) + s.Require().NoError(err) + s.Equal(time.Duration(0), w) + // The third call should fail for the "partitionKey1", but succeed for "partitionKey2". + w, err = window.Limit(l.NewFixedWindowDynamoDBContext(context.Background(), "partitionKey2")) + s.Require().NoError(err) + s.Equal(time.Duration(0), w) +} + func BenchmarkFixedWindows(b *testing.B) { s := new(LimitersTestSuite) s.SetT(&testing.T{})