Skip to content

Commit

Permalink
table: simplify the locate partition args for PartitionedTable (#53432
Browse files Browse the repository at this point in the history
)

close #53431
  • Loading branch information
lcwangchao authored May 22, 2024
1 parent 78e4db9 commit a6b4fca
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 67 deletions.
2 changes: 1 addition & 1 deletion pkg/ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -3366,7 +3366,7 @@ func (w *reorgPartitionWorker) fetchRowColVals(txn kv.Transaction, taskRange reo
}
tmpRow[offset] = d
}
p, err := w.reorgedTbl.GetPartitionByRow(w.sessCtx.GetExprCtx(), tmpRow)
p, err := w.reorgedTbl.GetPartitionByRow(w.sessCtx.GetExprCtx().GetEvalCtx(), tmpRow)
if err != nil {
return false, errors.Trace(err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/batch_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func getKeysNeedCheckOneRow(ctx sessionctx.Context, t table.Table, row []types.D
pkIdxInfo *model.IndexInfo, result []toBeCheckedRow) ([]toBeCheckedRow, error) {
var err error
if p, ok := t.(table.PartitionedTable); ok {
t, err = p.GetPartitionByRow(ctx.GetExprCtx(), row)
t, err = p.GetPartitionByRow(ctx.GetExprCtx().GetEvalCtx(), row)
if err != nil {
if terr, ok := errors.Cause(err).(*terror.Error); ok && (terr.Code() == errno.ErrNoPartitionForGivenValue || terr.Code() == errno.ErrRowDoesNotMatchGivenPartitionSet) {
ec := ctx.GetSessionVars().StmtCtx.ErrCtx()
Expand Down
6 changes: 3 additions & 3 deletions pkg/executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3529,7 +3529,7 @@ func (builder *dataReaderBuilder) prunePartitionForInnerExecutor(tbl table.Table
for i, data := range content.Keys {
locateKey[keyColOffsets[i]] = data
}
p, err := partitionTbl.GetPartitionByRow(exprCtx, locateKey)
p, err := partitionTbl.GetPartitionByRow(exprCtx.GetEvalCtx(), locateKey)
if table.ErrNoPartitionForGivenValue.Equal(err) {
continue
}
Expand Down Expand Up @@ -4168,7 +4168,7 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte
for i, data := range content.Keys {
locateKey[keyColOffsets[i]] = data
}
p, err := pt.GetPartitionByRow(exprCtx, locateKey)
p, err := pt.GetPartitionByRow(exprCtx.GetEvalCtx(), locateKey)
if table.ErrNoPartitionForGivenValue.Equal(err) {
continue
}
Expand Down Expand Up @@ -4216,7 +4216,7 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte
for i, data := range content.Keys {
locateKey[keyColOffsets[i]] = data
}
p, err := pt.GetPartitionByRow(exprCtx, locateKey)
p, err := pt.GetPartitionByRow(exprCtx.GetEvalCtx(), locateKey)
if table.ErrNoPartitionForGivenValue.Equal(err) {
continue
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func addUnchangedKeysForLockByRow(
count := 0
physicalID := t.Meta().ID
if pt, ok := t.(table.PartitionedTable); ok {
p, err := pt.GetPartitionByRow(sctx.GetExprCtx(), row)
p, err := pt.GetPartitionByRow(sctx.GetExprCtx().GetEvalCtx(), row)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -330,7 +330,7 @@ func checkRowForExchangePartition(sctx table.MutateContext, row []types.Datum, t
return errors.Errorf("exchange partition process assert table partition failed")
}
err := p.CheckForExchangePartition(
sctx.GetExprCtx(),
sctx.GetExprCtx().GetEvalCtx(),
pt.Meta().Partition,
row,
tbl.ExchangePartitionInfo.ExchangePartitionDefID,
Expand Down
6 changes: 3 additions & 3 deletions pkg/planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func (p *PointGetPlan) PrunePartitions(sctx sessionctx.Context) bool {
}
dVal.Copy(&row[p.HandleColOffset])
}
partIdx, err := pt.GetPartitionIdxByRow(sctx.GetExprCtx(), row)
partIdx, err := pt.GetPartitionIdxByRow(sctx.GetExprCtx().GetEvalCtx(), row)
if err != nil {
partIdx = -1
p.PartitionIdx = &partIdx
Expand Down Expand Up @@ -662,7 +662,7 @@ func (p *BatchPointGetPlan) getPartitionIdxs(sctx sessionctx.Context) []int {
for j := range rows[i] {
rows[i][j].Copy(&r[p.IndexInfo.Columns[j].Offset])
}
pIdx, err := pTbl.GetPartitionIdxByRow(sctx.GetExprCtx(), r)
pIdx, err := pTbl.GetPartitionIdxByRow(sctx.GetExprCtx().GetEvalCtx(), r)
if err != nil {
// Skip on any error, like:
// No matching partition, overflow etc.
Expand Down Expand Up @@ -760,7 +760,7 @@ func (p *BatchPointGetPlan) PrunePartitionsAndValues(sctx sessionctx.Context) ([
d = types.NewIntDatum(handle.IntValue())
}
d.Copy(&r[p.HandleColOffset])
pIdx, err := pTbl.GetPartitionIdxByRow(sctx.GetExprCtx(), r)
pIdx, err := pTbl.GetPartitionIdxByRow(sctx.GetExprCtx().GetEvalCtx(), r)
if err != nil ||
!isInExplicitPartitions(pi, pIdx, p.PartitionNames) ||
(p.SinglePartition &&
Expand Down
6 changes: 3 additions & 3 deletions pkg/table/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,12 @@ type PhysicalTable interface {
type PartitionedTable interface {
Table
GetPartition(physicalID int64) PhysicalTable
GetPartitionByRow(expression.BuildContext, []types.Datum) (PhysicalTable, error)
GetPartitionIdxByRow(expression.BuildContext, []types.Datum) (int, error)
GetPartitionByRow(expression.EvalContext, []types.Datum) (PhysicalTable, error)
GetPartitionIdxByRow(expression.EvalContext, []types.Datum) (int, error)
GetAllPartitionIDs() []int64
GetPartitionColumnIDs() []int64
GetPartitionColumnNames() []model.CIStr
CheckForExchangePartition(ctx expression.BuildContext, pi *model.PartitionInfo, r []types.Datum, partID, ntID int64) error
CheckForExchangePartition(ctx expression.EvalContext, pi *model.PartitionInfo, r []types.Datum, partID, ntID int64) error
}

// TableFromMeta builds a table.Table from *model.TableInfo.
Expand Down
82 changes: 28 additions & 54 deletions pkg/table/tables/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ func PartitionRecordKey(pid int64, handle int64) kv.Key {
return tablecodec.EncodeRecordKey(recordPrefix, kv.IntHandle(handle))
}

func (t *partitionedTable) CheckForExchangePartition(ctx expression.BuildContext, pi *model.PartitionInfo, r []types.Datum, partID, ntID int64) error {
func (t *partitionedTable) CheckForExchangePartition(ctx expression.EvalContext, pi *model.PartitionInfo, r []types.Datum, partID, ntID int64) error {
defID, err := t.locatePartition(ctx, r)
if err != nil {
return err
Expand All @@ -1281,7 +1281,7 @@ func (t *partitionedTable) CheckForExchangePartition(ctx expression.BuildContext
}

// locatePartitionCommon returns the partition idx of the input record.
func (t *partitionedTable) locatePartitionCommon(ctx expression.BuildContext, tp model.PartitionType, partitionExpr *PartitionExpr, num uint64, columnsPartitioned bool, r []types.Datum) (int, error) {
func (t *partitionedTable) locatePartitionCommon(ctx expression.EvalContext, tp model.PartitionType, partitionExpr *PartitionExpr, num uint64, columnsPartitioned bool, r []types.Datum) (int, error) {
var err error
var idx int
switch tp {
Expand All @@ -1293,11 +1293,11 @@ func (t *partitionedTable) locatePartitionCommon(ctx expression.BuildContext, tp
}
case model.PartitionTypeHash:
// Note that only LIST and RANGE supports REORGANIZE PARTITION
idx, err = t.locateHashPartition(ctx.GetEvalCtx(), partitionExpr, num, r)
idx, err = t.locateHashPartition(ctx, partitionExpr, num, r)
case model.PartitionTypeKey:
idx, err = partitionExpr.LocateKeyPartition(num, r)
case model.PartitionTypeList:
idx, err = partitionExpr.locateListPartition(ctx.GetEvalCtx(), r)
idx, err = partitionExpr.locateListPartition(ctx, r)
case model.PartitionTypeNone:
idx = 0
}
Expand All @@ -1307,7 +1307,7 @@ func (t *partitionedTable) locatePartitionCommon(ctx expression.BuildContext, tp
return idx, nil
}

func (t *partitionedTable) locatePartitionIdx(ctx expression.BuildContext, r []types.Datum) (int, error) {
func (t *partitionedTable) locatePartitionIdx(ctx expression.EvalContext, r []types.Datum) (int, error) {
pi := t.Meta().GetPartitionInfo()
columnsSet := len(t.meta.Partition.Columns) > 0
idx, err := t.locatePartitionCommon(ctx, pi.Type, t.partitionExpr, pi.Num, columnsSet, r)
Expand All @@ -1317,7 +1317,7 @@ func (t *partitionedTable) locatePartitionIdx(ctx expression.BuildContext, r []t
return idx, nil
}

func (t *partitionedTable) locatePartition(ctx expression.BuildContext, r []types.Datum) (int64, error) {
func (t *partitionedTable) locatePartition(ctx expression.EvalContext, r []types.Datum) (int64, error) {
idx, err := t.locatePartitionIdx(ctx, r)
if err != nil {
return 0, errors.Trace(err)
Expand All @@ -1326,7 +1326,7 @@ func (t *partitionedTable) locatePartition(ctx expression.BuildContext, r []type
return pi.Definitions[idx].ID, nil
}

func (t *partitionedTable) locateReorgPartition(ctx expression.BuildContext, r []types.Datum) (int64, error) {
func (t *partitionedTable) locateReorgPartition(ctx expression.EvalContext, r []types.Datum) (int64, error) {
pi := t.Meta().GetPartitionInfo()
columnsSet := len(pi.DDLColumns) > 0
// Note that for KEY/HASH partitioning, since we do not support LINEAR,
Expand All @@ -1347,14 +1347,14 @@ func (t *partitionedTable) locateReorgPartition(ctx expression.BuildContext, r [
return pi.AddingDefinitions[idx].ID, nil
}

func (t *partitionedTable) locateRangeColumnPartition(ctx expression.BuildContext, partitionExpr *PartitionExpr, r []types.Datum) (int, error) {
func (t *partitionedTable) locateRangeColumnPartition(ctx expression.EvalContext, partitionExpr *PartitionExpr, r []types.Datum) (int, error) {
upperBounds := partitionExpr.UpperBounds
var lastError error
evalBuffer := t.evalBufferPool.Get().(*chunk.MutRow)
defer t.evalBufferPool.Put(evalBuffer)
idx := sort.Search(len(upperBounds), func(i int) bool {
evalBuffer.SetDatums(r...)
ret, isNull, err := upperBounds[i].EvalInt(ctx.GetEvalCtx(), evalBuffer.ToRow())
ret, isNull, err := upperBounds[i].EvalInt(ctx, evalBuffer.ToRow())
if err != nil {
lastError = err
return true // Does not matter, will propagate the last error anyway.
Expand All @@ -1370,21 +1370,7 @@ func (t *partitionedTable) locateRangeColumnPartition(ctx expression.BuildContex
return 0, errors.Trace(lastError)
}
if idx >= len(upperBounds) {
// The data does not belong to any of the partition returns `table has no partition for value %s`.
var valueMsg string
if t.meta.Partition.Expr != "" {
e, err := expression.ParseSimpleExpr(ctx, t.meta.Partition.Expr, expression.WithTableInfo("", t.meta))
if err == nil {
val, _, err := e.EvalInt(ctx.GetEvalCtx(), chunk.MutRowFromDatums(r).ToRow())
if err == nil {
valueMsg = strconv.FormatInt(val, 10)
}
}
} else {
// When the table is partitioned by range columns.
valueMsg = "from column_list"
}
return 0, table.ErrNoPartitionForGivenValue.GenWithStackByArgs(valueMsg)
return 0, table.ErrNoPartitionForGivenValue.GenWithStackByArgs("from column_list")
}
return idx, nil
}
Expand All @@ -1398,7 +1384,7 @@ func (pe *PartitionExpr) locateListPartition(ctx expression.EvalContext, r []typ
return lp.locateListColumnsPartitionByRow(tc, ec, r)
}

func (t *partitionedTable) locateRangePartition(ctx expression.BuildContext, partitionExpr *PartitionExpr, r []types.Datum) (int, error) {
func (t *partitionedTable) locateRangePartition(ctx expression.EvalContext, partitionExpr *PartitionExpr, r []types.Datum) (int, error) {
var (
ret int64
val int64
Expand All @@ -1414,7 +1400,7 @@ func (t *partitionedTable) locateRangePartition(ctx expression.BuildContext, par
evalBuffer := t.evalBufferPool.Get().(*chunk.MutRow)
defer t.evalBufferPool.Put(evalBuffer)
evalBuffer.SetDatums(r...)
val, isNull, err = partitionExpr.Expr.EvalInt(ctx.GetEvalCtx(), evalBuffer.ToRow())
val, isNull, err = partitionExpr.Expr.EvalInt(ctx, evalBuffer.ToRow())
if err != nil {
return 0, err
}
Expand All @@ -1435,22 +1421,10 @@ func (t *partitionedTable) locateRangePartition(ctx expression.BuildContext, par
if pos < 0 || pos >= length {
// The data does not belong to any of the partition returns `table has no partition for value %s`.
var valueMsg string
// TODO: Test with ALTER TABLE t PARTITION BY with a different expression / type
if t.meta.Partition.Expr != "" {
e, err := expression.ParseSimpleExpr(ctx, t.meta.Partition.Expr, expression.WithTableInfo("", t.meta))
if err == nil {
val, _, err := e.EvalInt(ctx.GetEvalCtx(), chunk.MutRowFromDatums(r).ToRow())
if err == nil {
if unsigned {
valueMsg = fmt.Sprintf("%d", uint64(val))
} else {
valueMsg = fmt.Sprintf("%d", val)
}
}
}
if unsigned {
valueMsg = fmt.Sprintf("%d", uint64(ret))
} else {
// When the table is partitioned by range columns.
valueMsg = "from column_list"
valueMsg = fmt.Sprintf("%d", ret)
}
return 0, table.ErrNoPartitionForGivenValue.GenWithStackByArgs(valueMsg)
}
Expand Down Expand Up @@ -1542,7 +1516,7 @@ func GetReorganizedPartitionedTable(t table.Table) (table.PartitionedTable, erro
}

// GetPartitionByRow returns a Table, which is actually a Partition.
func (t *partitionedTable) GetPartitionByRow(ctx expression.BuildContext, r []types.Datum) (table.PhysicalTable, error) {
func (t *partitionedTable) GetPartitionByRow(ctx expression.EvalContext, r []types.Datum) (table.PhysicalTable, error) {
pid, err := t.locatePartition(ctx, r)
if err != nil {
return nil, errors.Trace(err)
Expand All @@ -1551,12 +1525,12 @@ func (t *partitionedTable) GetPartitionByRow(ctx expression.BuildContext, r []ty
}

// GetPartitionIdxByRow returns the index in PartitionDef for the matching partition
func (t *partitionedTable) GetPartitionIdxByRow(ctx expression.BuildContext, r []types.Datum) (int, error) {
func (t *partitionedTable) GetPartitionIdxByRow(ctx expression.EvalContext, r []types.Datum) (int, error) {
return t.locatePartitionIdx(ctx, r)
}

// GetPartitionByRow returns a Table, which is actually a Partition.
func (t *partitionTableWithGivenSets) GetPartitionByRow(ctx expression.BuildContext, r []types.Datum) (table.PhysicalTable, error) {
func (t *partitionTableWithGivenSets) GetPartitionByRow(ctx expression.EvalContext, r []types.Datum) (table.PhysicalTable, error) {
pid, err := t.locatePartition(ctx, r)
if err != nil {
return nil, errors.Trace(err)
Expand Down Expand Up @@ -1606,7 +1580,7 @@ func (t *partitionedTable) AddRecord(ctx table.MutateContext, r []types.Datum, o
}

func partitionedTableAddRecord(ctx table.MutateContext, t *partitionedTable, r []types.Datum, partitionSelection map[int64]struct{}, opts []table.AddRecordOption) (recordID kv.Handle, err error) {
pid, err := t.locatePartition(ctx.GetExprCtx(), r)
pid, err := t.locatePartition(ctx.GetExprCtx().GetEvalCtx(), r)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -1637,7 +1611,7 @@ func partitionedTableAddRecord(ctx table.MutateContext, t *partitionedTable, r [
}
if _, ok := t.reorganizePartitions[pid]; ok {
// Double write to the ongoing reorganized partition
pid, err = t.locateReorgPartition(ctx.GetExprCtx(), r)
pid, err = t.locateReorgPartition(ctx.GetExprCtx().GetEvalCtx(), r)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -1685,7 +1659,7 @@ func (t *partitionTableWithGivenSets) GetAllPartitionIDs() []int64 {
// RemoveRecord implements table.Table RemoveRecord interface.
func (t *partitionedTable) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []types.Datum) error {
ectx := ctx.GetExprCtx()
pid, err := t.locatePartition(ectx, r)
pid, err := t.locatePartition(ectx.GetEvalCtx(), r)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -1697,7 +1671,7 @@ func (t *partitionedTable) RemoveRecord(ctx table.MutateContext, h kv.Handle, r
}

if _, ok := t.reorganizePartitions[pid]; ok {
pid, err = t.locateReorgPartition(ectx, r)
pid, err = t.locateReorgPartition(ectx.GetEvalCtx(), r)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -1734,11 +1708,11 @@ func (t *partitionTableWithGivenSets) UpdateRecord(ctx context.Context, sctx tab

func partitionedTableUpdateRecord(gctx context.Context, ctx table.MutateContext, t *partitionedTable, h kv.Handle, currData, newData []types.Datum, touched []bool, partitionSelection map[int64]struct{}) error {
ectx := ctx.GetExprCtx()
from, err := t.locatePartition(ectx, currData)
from, err := t.locatePartition(ectx.GetEvalCtx(), currData)
if err != nil {
return errors.Trace(err)
}
to, err := t.locatePartition(ectx, newData)
to, err := t.locatePartition(ectx.GetEvalCtx(), newData)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -1786,14 +1760,14 @@ func partitionedTableUpdateRecord(gctx context.Context, ctx table.MutateContext,

newTo, newFrom := int64(0), int64(0)
if _, ok := t.reorganizePartitions[to]; ok {
newTo, err = t.locateReorgPartition(ectx, newData)
newTo, err = t.locateReorgPartition(ectx.GetEvalCtx(), newData)
// There might be valid cases when errors should be accepted?
if err != nil {
return errors.Trace(err)
}
}
if _, ok := t.reorganizePartitions[from]; ok {
newFrom, err = t.locateReorgPartition(ectx, currData)
newFrom, err = t.locateReorgPartition(ectx.GetEvalCtx(), currData)
// There might be valid cases when errors should be accepted?
if err != nil {
return errors.Trace(err)
Expand Down Expand Up @@ -1833,11 +1807,11 @@ func partitionedTableUpdateRecord(gctx context.Context, ctx table.MutateContext,
if _, ok := t.reorganizePartitions[to]; ok {
// Even if to == from, in the reorganized partitions they may differ
// like in case of a split
newTo, err := t.locateReorgPartition(ectx, newData)
newTo, err := t.locateReorgPartition(ectx.GetEvalCtx(), newData)
if err != nil {
return errors.Trace(err)
}
newFrom, err := t.locateReorgPartition(ectx, currData)
newFrom, err := t.locateReorgPartition(ectx.GetEvalCtx(), currData)
if err != nil {
return errors.Trace(err)
}
Expand Down

0 comments on commit a6b4fca

Please sign in to comment.