diff --git a/integration_tests/resource_tag_test.go b/integration_tests/resource_tag_test.go new file mode 100644 index 0000000000..a37f50aabf --- /dev/null +++ b/integration_tests/resource_tag_test.go @@ -0,0 +1,149 @@ +package tikv_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" +) + +var _ tikv.Client = &resourceGroupTagMockClient{} + +type resourceGroupTagMockClient struct { + t *testing.T + inner tikv.Client + expectedTag []byte + requestCount int +} + +func (c *resourceGroupTagMockClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { + if len(req.ResourceGroupTag) == 0 { + return c.inner.SendRequest(ctx, addr, req, timeout) + } + c.requestCount++ + assert.Equal(c.t, c.expectedTag, req.ResourceGroupTag) + return c.inner.SendRequest(ctx, addr, req, timeout) +} + +func (c *resourceGroupTagMockClient) Close() error { + return c.inner.Close() +} + +func TestResourceGroupTag(t *testing.T) { + testTag1 := []byte("TEST-TAG-1") + testTag2 := []byte("TEST-TAG-2") + testTagger := tikvrpc.ResourceGroupTagger(func(req *tikvrpc.Request) { + req.ResourceGroupTag = testTag2 + }) + + /* Get */ + + // SetResourceGroupTag + store := NewTestStore(t) + client := &resourceGroupTagMockClient{t: t, inner: store.GetTiKVClient(), expectedTag: testTag1} + store.SetTiKVClient(client) + txn, err := store.Begin() + assert.NoError(t, err) + txn.SetResourceGroupTag(testTag1) + _, _ = txn.Get(context.Background(), []byte{}) + assert.Equal(t, 1, client.requestCount) + assert.NoError(t, store.Close()) + + // SetResourceGroupTagger + store = NewTestStore(t) + client = &resourceGroupTagMockClient{t: t, inner: store.GetTiKVClient(), expectedTag: testTag2} + store.SetTiKVClient(client) + txn, err = store.Begin() + assert.NoError(t, err) + txn.SetResourceGroupTagger(testTagger) + _, _ = txn.Get(context.Background(), []byte{}) + assert.Equal(t, 1, client.requestCount) + assert.NoError(t, store.Close()) + + // SetResourceGroupTag + SetResourceGroupTagger + store = NewTestStore(t) + client = &resourceGroupTagMockClient{t: t, inner: store.GetTiKVClient(), expectedTag: testTag1} + store.SetTiKVClient(client) + txn, err = store.Begin() + assert.NoError(t, err) + txn.SetResourceGroupTag(testTag1) + txn.SetResourceGroupTagger(testTagger) + _, _ = txn.Get(context.Background(), []byte{}) + assert.Equal(t, 1, client.requestCount) + assert.NoError(t, store.Close()) + + /* BatchGet */ + + // SetResourceGroupTag + store = NewTestStore(t) + client = &resourceGroupTagMockClient{t: t, inner: store.GetTiKVClient(), expectedTag: testTag1} + store.SetTiKVClient(client) + txn, err = store.Begin() + assert.NoError(t, err) + txn.SetResourceGroupTag(testTag1) + _, _ = txn.BatchGet(context.Background(), [][]byte{[]byte("k")}) + assert.Equal(t, 1, client.requestCount) + assert.NoError(t, store.Close()) + + // SetResourceGroupTagger + store = NewTestStore(t) + client = &resourceGroupTagMockClient{t: t, inner: store.GetTiKVClient(), expectedTag: testTag2} + store.SetTiKVClient(client) + txn, err = store.Begin() + assert.NoError(t, err) + txn.SetResourceGroupTagger(testTagger) + _, _ = txn.BatchGet(context.Background(), [][]byte{[]byte("k")}) + assert.Equal(t, 1, client.requestCount) + assert.NoError(t, store.Close()) + + // SetResourceGroupTag + SetResourceGroupTagger + store = NewTestStore(t) + client = &resourceGroupTagMockClient{t: t, inner: store.GetTiKVClient(), expectedTag: testTag1} + store.SetTiKVClient(client) + txn, err = store.Begin() + assert.NoError(t, err) + txn.SetResourceGroupTag(testTag1) + txn.SetResourceGroupTagger(testTagger) + _, _ = txn.BatchGet(context.Background(), [][]byte{[]byte("k")}) + assert.Equal(t, 1, client.requestCount) + assert.NoError(t, store.Close()) + + /* Scan */ + + // SetResourceGroupTag + store = NewTestStore(t) + client = &resourceGroupTagMockClient{t: t, inner: store.GetTiKVClient(), expectedTag: testTag1} + store.SetTiKVClient(client) + txn, err = store.Begin() + assert.NoError(t, err) + txn.SetResourceGroupTag(testTag1) + _, _ = txn.Iter([]byte("abc"), []byte("def")) + assert.Equal(t, 1, client.requestCount) + assert.NoError(t, store.Close()) + + // SetResourceGroupTagger + store = NewTestStore(t) + client = &resourceGroupTagMockClient{t: t, inner: store.GetTiKVClient(), expectedTag: testTag2} + store.SetTiKVClient(client) + txn, err = store.Begin() + assert.NoError(t, err) + txn.SetResourceGroupTagger(testTagger) + _, _ = txn.Iter([]byte("abc"), []byte("def")) + assert.Equal(t, 1, client.requestCount) + assert.NoError(t, store.Close()) + + // SetResourceGroupTag + SetResourceGroupTagger + store = NewTestStore(t) + client = &resourceGroupTagMockClient{t: t, inner: store.GetTiKVClient(), expectedTag: testTag1} + store.SetTiKVClient(client) + txn, err = store.Begin() + assert.NoError(t, err) + txn.SetResourceGroupTag(testTag1) + txn.SetResourceGroupTagger(testTagger) + _, _ = txn.Iter([]byte("abc"), []byte("def")) + assert.Equal(t, 1, client.requestCount) + assert.NoError(t, store.Close()) +} diff --git a/tikvrpc/tikvrpc.go b/tikvrpc/tikvrpc.go index 8b4c0bfcfa..8ef8214e51 100644 --- a/tikvrpc/tikvrpc.go +++ b/tikvrpc/tikvrpc.go @@ -1178,3 +1178,6 @@ func (req *Request) IsTxnWriteRequest() bool { } return false } + +// ResourceGroupTagger is used to fill the ResourceGroupTag in the kvrpcpb.Context. +type ResourceGroupTagger func(req *Request) diff --git a/txnkv/transaction/2pc.go b/txnkv/transaction/2pc.go index 229f9de914..a4b55301fd 100644 --- a/txnkv/transaction/2pc.go +++ b/txnkv/transaction/2pc.go @@ -168,7 +168,8 @@ type twoPhaseCommitter struct { binlog BinlogExecutor - resourceGroupTag []byte + resourceGroupTag []byte + resourceGroupTagger tikvrpc.ResourceGroupTagger // use this when resourceGroupTag is nil // allowed when tikv disk full happened. diskFullOpt kvrpcpb.DiskFullOpt @@ -495,6 +496,7 @@ func (c *twoPhaseCommitter) initKeysAndMutations() error { c.priority = txn.priority.ToPB() c.syncLog = txn.syncLog c.resourceGroupTag = txn.resourceGroupTag + c.resourceGroupTagger = txn.resourceGroupTagger c.setDetail(commitDetail) return nil } diff --git a/txnkv/transaction/cleanup.go b/txnkv/transaction/cleanup.go index 6e85079ff3..bc6bd2609a 100644 --- a/txnkv/transaction/cleanup.go +++ b/txnkv/transaction/cleanup.go @@ -64,6 +64,9 @@ func (actionCleanup) handleSingleBatch(c *twoPhaseCommitter, bo *retry.Backoffer StartVersion: c.startTS, }, kvrpcpb.Context{Priority: c.priority, SyncLog: c.syncLog, ResourceGroupTag: c.resourceGroupTag, MaxExecutionDurationMs: uint64(client.MaxWriteExecutionTime.Milliseconds())}) + if c.resourceGroupTag == nil && c.resourceGroupTagger != nil { + c.resourceGroupTagger(req) + } resp, err := c.store.SendReq(bo, req, batch.region, client.ReadTimeoutShort) if err != nil { return err diff --git a/txnkv/transaction/commit.go b/txnkv/transaction/commit.go index 520cd50e4f..fda0d611c6 100644 --- a/txnkv/transaction/commit.go +++ b/txnkv/transaction/commit.go @@ -73,6 +73,9 @@ func (actionCommit) handleSingleBatch(c *twoPhaseCommitter, bo *retry.Backoffer, }, kvrpcpb.Context{Priority: c.priority, SyncLog: c.syncLog, ResourceGroupTag: c.resourceGroupTag, DiskFullOpt: c.diskFullOpt, MaxExecutionDurationMs: uint64(client.MaxWriteExecutionTime.Milliseconds())}) + if c.resourceGroupTag == nil && c.resourceGroupTagger != nil { + c.resourceGroupTagger(req) + } tBegin := time.Now() attempts := 0 diff --git a/txnkv/transaction/prewrite.go b/txnkv/transaction/prewrite.go index a700f82715..f30c58a217 100644 --- a/txnkv/transaction/prewrite.go +++ b/txnkv/transaction/prewrite.go @@ -141,9 +141,13 @@ func (c *twoPhaseCommitter) buildPrewriteRequest(batch batchMutations, txnSize u req.TryOnePc = true } - return tikvrpc.NewRequest(tikvrpc.CmdPrewrite, req, + r := tikvrpc.NewRequest(tikvrpc.CmdPrewrite, req, kvrpcpb.Context{Priority: c.priority, SyncLog: c.syncLog, ResourceGroupTag: c.resourceGroupTag, DiskFullOpt: c.diskFullOpt, MaxExecutionDurationMs: uint64(client.MaxWriteExecutionTime.Milliseconds())}) + if c.resourceGroupTag == nil && c.resourceGroupTagger != nil { + c.resourceGroupTagger(r) + } + return r } func (action actionPrewrite) handleSingleBatch(c *twoPhaseCommitter, bo *retry.Backoffer, batch batchMutations) (err error) { diff --git a/txnkv/transaction/txn.go b/txnkv/transaction/txn.go index 5f5613eaeb..35e2fc3f6d 100644 --- a/txnkv/transaction/txn.go +++ b/txnkv/transaction/txn.go @@ -58,6 +58,7 @@ import ( "github.com/tikv/client-go/v2/internal/unionstore" tikv "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/txnkv/txnsnapshot" "github.com/tikv/client-go/v2/txnkv/txnutil" "github.com/tikv/client-go/v2/util" @@ -98,18 +99,19 @@ type KVTxn struct { // commitCallback is called after current transaction gets committed commitCallback func(info string, err error) - binlog BinlogExecutor - schemaLeaseChecker SchemaLeaseChecker - syncLog bool - priority txnutil.Priority - isPessimistic bool - enableAsyncCommit bool - enable1PC bool - causalConsistency bool - scope string - kvFilter KVFilter - resourceGroupTag []byte - diskFullOpt kvrpcpb.DiskFullOpt + binlog BinlogExecutor + schemaLeaseChecker SchemaLeaseChecker + syncLog bool + priority txnutil.Priority + isPessimistic bool + enableAsyncCommit bool + enable1PC bool + causalConsistency bool + scope string + kvFilter KVFilter + resourceGroupTag []byte + resourceGroupTagger tikvrpc.ResourceGroupTagger // use this when resourceGroupTag is nil + diskFullOpt kvrpcpb.DiskFullOpt } // NewTiKVTxn creates a new KVTxn. @@ -232,6 +234,14 @@ func (txn *KVTxn) SetResourceGroupTag(tag []byte) { txn.GetSnapshot().SetResourceGroupTag(tag) } +// SetResourceGroupTagger sets the resource tagger for both write and read. +// Before sending the request, if resourceGroupTag is not nil, use +// resourceGroupTag directly, otherwise use resourceGroupTagger. +func (txn *KVTxn) SetResourceGroupTagger(tagger tikvrpc.ResourceGroupTagger) { + txn.resourceGroupTagger = tagger + txn.GetSnapshot().SetResourceGroupTagger(tagger) +} + // SetSchemaAmender sets an amender to update mutations after schema change. func (txn *KVTxn) SetSchemaAmender(sa SchemaAmender) { txn.schemaAmender = sa diff --git a/txnkv/txnsnapshot/scan.go b/txnkv/txnsnapshot/scan.go index b3f9d3fe1e..354ca81956 100644 --- a/txnkv/txnsnapshot/scan.go +++ b/txnkv/txnsnapshot/scan.go @@ -240,6 +240,9 @@ func (s *Scanner) getData(bo *retry.Backoffer) error { TaskId: s.snapshot.mu.taskID, ResourceGroupTag: s.snapshot.resourceGroupTag, }) + if s.snapshot.resourceGroupTag == nil && s.snapshot.resourceGroupTagger != nil { + s.snapshot.resourceGroupTagger(req) + } s.snapshot.mu.RUnlock() resp, err := sender.SendReq(bo, req, loc.Region, client.ReadTimeoutMedium) if err != nil { diff --git a/txnkv/txnsnapshot/snapshot.go b/txnkv/txnsnapshot/snapshot.go index 422c7a9f16..ef5b684217 100644 --- a/txnkv/txnsnapshot/snapshot.go +++ b/txnkv/txnsnapshot/snapshot.go @@ -134,6 +134,8 @@ type KVSnapshot struct { sampleStep uint32 // resourceGroupTag is use to set the kv request resource group tag. resourceGroupTag []byte + // resourceGroupTagger is use to set the kv request resource group tag if resourceGroupTag is nil. + resourceGroupTagger tikvrpc.ResourceGroupTagger } // NewTiKVSnapshot creates a snapshot of an TiKV store. @@ -353,6 +355,9 @@ func (s *KVSnapshot) batchGetSingleRegion(bo *retry.Backoffer, batch batchKeys, TaskId: s.mu.taskID, ResourceGroupTag: s.resourceGroupTag, }) + if s.resourceGroupTag == nil && s.resourceGroupTagger != nil { + s.resourceGroupTagger(req) + } scope := s.mu.readReplicaScope isStaleness := s.mu.isStaleness matchStoreLabels := s.mu.matchStoreLabels @@ -520,6 +525,9 @@ func (s *KVSnapshot) get(ctx context.Context, bo *retry.Backoffer, k []byte) ([] TaskId: s.mu.taskID, ResourceGroupTag: s.resourceGroupTag, }) + if s.resourceGroupTag == nil && s.resourceGroupTagger != nil { + s.resourceGroupTagger(req) + } isStaleness := s.mu.isStaleness matchStoreLabels := s.mu.matchStoreLabels scope := s.mu.readReplicaScope @@ -714,11 +722,18 @@ func (s *KVSnapshot) SetMatchStoreLabels(labels []*metapb.StoreLabel) { s.mu.matchStoreLabels = labels } -// SetResourceGroupTag sets resource group of the kv request. +// SetResourceGroupTag sets resource group tag of the kv request. func (s *KVSnapshot) SetResourceGroupTag(tag []byte) { s.resourceGroupTag = tag } +// SetResourceGroupTagger sets resource group tagger of the kv request. +// Before sending the request, if resourceGroupTag is not nil, use +// resourceGroupTag directly, otherwise use resourceGroupTagger. +func (s *KVSnapshot) SetResourceGroupTagger(tagger tikvrpc.ResourceGroupTagger) { + s.resourceGroupTagger = tagger +} + // SnapCacheHitCount gets the snapshot cache hit count. Only for test. func (s *KVSnapshot) SnapCacheHitCount() int { return int(atomic.LoadInt64(&s.mu.hitCnt))